diff --git a/docs/zh_CN/models/ImageNet1k/README.md b/docs/zh_CN/models/ImageNet1k/README.md index 4044fef11e46d87ad9ad1a86eb4850149edf67ca..a9b2149014d696d8d4078328b92910afa92368c0 100644 --- a/docs/zh_CN/models/ImageNet1k/README.md +++ b/docs/zh_CN/models/ImageNet1k/README.md @@ -50,6 +50,7 @@ - [LeViT 系列](#LeViT) - [TNT 系列](#TNT) - [NextViT 系列](#NextViT) + - [UniFormer 系列](#UniFormer) - [4.2 轻量级模型](#Transformer_lite) - [MobileViT 系列](#MobileViT) - [五、参考文献](#reference) @@ -703,7 +704,7 @@ DeiT(Data-efficient Image Transformers)系列模型的精度、速度指标 **注**:TNT 模型的数据预处理部分 `NormalizeImage` 中的 `mean` 与 `std` 均为 0.5。 -## NextViT 系列 [[35](#ref47)] +## NextViT 系列 [[47](#ref47)] 关于 NextViT 系列模型的精度、速度指标如下表所示,更多介绍可以参考:[NextViT 系列模型文档](NextViT.md)。 | 模型 | Top-1 Acc | Top-5 Acc | time(ms)
bs=1 | time(ms)
bs=4 | time(ms)
bs=8 | FLOPs(G) | Params(M) | 预训练模型下载地址 | inference模型下载地址 | @@ -721,6 +722,21 @@ DeiT(Data-efficient Image Transformers)系列模型的精度、速度指标 | NextViT_base_384_ssld | 0.8634 | 0.9806 | - | - | - | 24.27 | 44.88 | [下载链接](https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/NextViT_base_384_ssld_pretrained.pdparams) | [下载链接](https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/inference/NextViT_base_384_ssld_infer.tar) | | NextViT_large_384_ssld | 0.8654 | 0.9814 | - | - | - | 31.53 | 57.95 | [下载链接](https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/NextViT_large_384_ssld_pretrained.pdparams) | [下载链接](https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/inference/NextViT_large_384_ssld_infer.tar) | + + +## UniFormer 系列 [[48](#ref48)] + +关于 UniFormer 系列模型的精度、速度指标如下表所示,更多介绍可以参考:[UniFomer 系列模型文档](UniFormer.md)。 + +| 模型 | Top-1 Acc | Top-5 Acc | time(ms)
bs=1 | time(ms)
bs=4 | time(ms)
bs=8 | FLOPs(G) | Params(M) | 预训练模型下载地址 | inference模型下载地址 | +| ---------- | --------- | --------- | ---------------- | ---------------- | -------- | --------- | ------------------------------------------------------------ | ------------------------------------------------------------ | ------------------------------------------------------------ | +| UniFormer_small | 0.8294 | 0.9631 | - | - | - | 3.44 | 21.55 | [下载链接](https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/UniFormer_small_pretrained.pdparams) | [下载链接](https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/inference/UniFormer_small_infer.tar) | +| UniFormer_small_plus | 0.8329 | 0.9656 | - | - | - | 3.99 | 24.04 | [下载链接](https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/UniFormer_small_plus_pretrained.pdparams) | [下载链接](https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/inference/UniFormer_small_plus_infer.tar) | +| UniFormer_small_plus_dim64 | 0.8325 | 0.9649 | - | - | - | 3.99 | 24.04 | [下载链接](https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/UniFormer_small_plus_dim64_pretrained.pdparams) | [下载链接](https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/inference/UniFormer_small_plus_dim64_infer.tar) | +| UniFormer_base | 0.8376 | 0.9672 | - | - |- | 7.77 | 49.78 | [下载链接](https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/UniFormer_base_pretrained.pdparams) | [下载链接](https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/inference/UniFormer_base_infer.tar) | +| UniFormer_base_ls | 0.8398 | 0.9675 | - | - | - | 7.77 | 49.78 | [下载链接](https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/UniFormer_base_ls_pretrained.pdparams) | [下载链接](https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/inference/UniFormer_base_ls_infer.tar) | + + ### 4.2 轻量级模型 @@ -834,4 +850,6 @@ TRANSFORMERS FOR IMAGE RECOGNITION AT SCALE. [46]Chien-Yao Wang, Hong-Yuan Mark Liao, I-Hau Yeh, Yueh-Hua Wu, Ping-Yang Chen, Jun-Wei Hsieh. CSPNet: A New Backbone that can Enhance Learning Capability of CNN -[46]Jiashi Li, Xin Xia, Wei Li, Huixia Li, Xing Wang, Xuefeng Xiao, Rui Wang, Min Zheng, Xin Pan. Next-ViT: Next Generation Vision Transformer for Efficient Deployment in Realistic Industrial Scenarios. +[47]Jiashi Li, Xin Xia, Wei Li, Huixia Li, Xing Wang, Xuefeng Xiao, Rui Wang, Min Zheng, Xin Pan. Next-ViT: Next Generation Vision Transformer for Efficient Deployment in Realistic Industrial Scenarios. + +[48]Kunchang Li, Yali Wang, Junhao Zhang, Peng Gao, Guanglu Song, Yu Liu, Hongsheng Li, Yu Qiao. UniFormer: Unifying Convolution and Self-attention for Visual Recognition diff --git a/docs/zh_CN/models/ImageNet1k/UniFormer.md b/docs/zh_CN/models/ImageNet1k/UniFormer.md new file mode 100644 index 0000000000000000000000000000000000000000..d5442e4dd507c2be587e39bf71b4caa78732839d --- /dev/null +++ b/docs/zh_CN/models/ImageNet1k/UniFormer.md @@ -0,0 +1,104 @@ +# UniFormer +----- + +## 目录 + +- [1. 模型介绍](#1) + - [1.1 模型简介](#1.1) + - [1.2 模型指标](#1.2) +- [2. 模型快速体验](#2) +- [3. 模型训练、评估和预测](#3) +- [4. 模型推理部署](#4) + - [4.1 推理模型准备](#4.1) + - [4.2 基于 Python 预测引擎推理](#4.2) + - [4.3 基于 C++ 预测引擎推理](#4.3) + - [4.4 服务化部署](#4.4) + - [4.5 端侧部署](#4.5) + - [4.6 Paddle2ONNX 模型转换与预测](#4.6) + + + +## 1. 模型介绍 + + + +### 1.1 模型简介 + +UniFormer 是一种新的视觉 Transformer 网络,可以用作计算机视觉领域的通用骨干网路。作者针对图像识别领域所面临的局部冗余与全局依赖复杂两个问题提出解决办法,设计MHRA(Multi-Head Relation Aggregator)结构在不同特征层使用不同特征学习算子,将convolution和self-attention有机地结合起来,在精度和速度上都有了进一步的提升。[论文地址](https://arxiv.org/abs/2201.09450)。 + + + +### 1.2 模型指标 + +| Models | Top1 | Top5 | Reference
top1 | Reference
top5 | FLOPs
(G) | Params
(M) | +|:--:|:--:|:--:|:--:|:--:|:--:|:--:| +| UniFormer_small | 0.8294 | 0.9631 | 0.829 | 0.962 | 3.44 | 21.55 | +| UniFormer_small_plus | 0.8329 | 0.9656 | 0.833 | 0.965 | 3.99 | 24.04 | +| UniFormer_small_plus_dim64 | 0.8325 | 0.9649 | 0.832 | 0.964 | 3.99 | 24.04 | +| UniFormer_base | 0.8376 | 0.9672 | 0.839 | - | 7.77 | 49.78 | +| UniFormer_base_ls | 0.8398 | 0.9675 | 0.839 | 0.967 | 7.77 | 49.78 | + + +**备注:** PaddleClas 所提供的该系列模型的预训练模型权重,均是基于其官方提供的权重转得。 + + + +## 2. 模型快速体验 + +安装 paddlepaddle 和 paddleclas 即可快速对图片进行预测,体验方法可以参考[ResNet50 模型快速体验](./ResNet.md#2-模型快速体验)。 + + + +## 3. 模型训练、评估和预测 + +此部分内容包括训练环境配置、ImageNet数据的准备、该模型在 ImageNet 上的训练、评估、预测等内容。在 `ppcls/configs/ImageNet/UniFormer/` 中提供了该模型的训练配置,启动训练方法可以参考:[ResNet50 模型训练、评估和预测](./ResNet.md#3-模型训练评估和预测)。 + +**备注:** 由于 UniFormer 系列模型默认使用的 GPU 数量为 8 个,所以在训练时,需要指定8个GPU,如`python3 -m paddle.distributed.launch --gpus="0,1,2,3,4,5,6,7" tools/train.py -c xxx.yaml`, 如果使用 4 个 GPU 训练,默认学习率需要减小一半,精度可能有损。 + + + +## 4. 模型推理部署 + + + +### 4.1 推理模型准备 + +Paddle Inference 是飞桨的原生推理库, 作用于服务器端和云端,提供高性能的推理能力。相比于直接基于预训练模型进行预测,Paddle Inference可使用 MKLDNN、CUDNN、TensorRT 进行预测加速,从而实现更优的推理性能。更多关于Paddle Inference推理引擎的介绍,可以参考[Paddle Inference官网教程](https://www.paddlepaddle.org.cn/documentation/docs/zh/guides/infer/inference/inference_cn.html)。 + +Inference 的获取可以参考 [ResNet50 推理模型准备](./ResNet.md#41-推理模型准备) 。 + + + +### 4.2 基于 Python 预测引擎推理 + +PaddleClas 提供了基于 python 预测引擎推理的示例。您可以参考[ResNet50 基于 Python 预测引擎推理](./ResNet.md#42-基于-python-预测引擎推理) 。 + + + +### 4.3 基于 C++ 预测引擎推理 + +PaddleClas 提供了基于 C++ 预测引擎推理的示例,您可以参考[服务器端 C++ 预测](../../deployment/image_classification/cpp/linux.md)来完成相应的推理部署。如果您使用的是 Windows 平台,可以参考[基于 Visual Studio 2019 Community CMake 编译指南](../../deployment/image_classification/cpp/windows.md)完成相应的预测库编译和模型预测工作。 + + + +### 4.4 服务化部署 + +Paddle Serving 提供高性能、灵活易用的工业级在线推理服务。Paddle Serving 支持 RESTful、gRPC、bRPC 等多种协议,提供多种异构硬件和多种操作系统环境下推理解决方案。更多关于Paddle Serving 的介绍,可以参考[Paddle Serving 代码仓库](https://github.com/PaddlePaddle/Serving)。 + +PaddleClas 提供了基于 Paddle Serving 来完成模型服务化部署的示例,您可以参考[模型服务化部署](../../deployment/image_classification/paddle_serving.md)来完成相应的部署工作。 + + + +### 4.5 端侧部署 + +Paddle Lite 是一个高性能、轻量级、灵活性强且易于扩展的深度学习推理框架,定位于支持包括移动端、嵌入式以及服务器端在内的多硬件平台。更多关于 Paddle Lite 的介绍,可以参考[Paddle Lite 代码仓库](https://github.com/PaddlePaddle/Paddle-Lite)。 + +PaddleClas 提供了基于 Paddle Lite 来完成模型端侧部署的示例,您可以参考[端侧部署](../../deployment/image_classification/paddle_lite.md)来完成相应的部署工作。 + + + +### 4.6 Paddle2ONNX 模型转换与预测 + +Paddle2ONNX 支持将 PaddlePaddle 模型格式转化到 ONNX 模型格式。通过 ONNX 可以完成将 Paddle 模型到多种推理引擎的部署,包括TensorRT/OpenVINO/MNN/TNN/NCNN,以及其它对 ONNX 开源格式进行支持的推理引擎或硬件。更多关于 Paddle2ONNX 的介绍,可以参考[Paddle2ONNX 代码仓库](https://github.com/PaddlePaddle/Paddle2ONNX)。 + +PaddleClas 提供了基于 Paddle2ONNX 来完成 inference 模型转换 ONNX 模型并作推理预测的示例,您可以参考[Paddle2ONNX 模型转换与预测](../../deployment/image_classification/paddle2onnx.md)来完成相应的部署工作。 diff --git a/ppcls/arch/backbone/__init__.py b/ppcls/arch/backbone/__init__.py index 5220ea16b7e759cfe8f715cfa5164f2b01e53d64..d7159fd80682480e3695d6c491620b1ccad44ae0 100644 --- a/ppcls/arch/backbone/__init__.py +++ b/ppcls/arch/backbone/__init__.py @@ -79,6 +79,7 @@ from .variant_models.pp_lcnet_variant import PPLCNet_x2_5_Tanh from .variant_models.pp_lcnetv2_variant import PPLCNetV2_base_ShiTu from .model_zoo.adaface_ir_net import AdaFace_IR_18, AdaFace_IR_34, AdaFace_IR_50, AdaFace_IR_101, AdaFace_IR_152, AdaFace_IR_SE_50, AdaFace_IR_SE_101, AdaFace_IR_SE_152, AdaFace_IR_SE_200 from .model_zoo.wideresnet import WideResNet +from .model_zoo.uniformer import UniFormer_small, UniFormer_small_plus, UniFormer_small_plus_dim64, UniFormer_base, UniFormer_base_ls # help whl get all the models' api (class type) and components' api (func type) diff --git a/ppcls/arch/backbone/model_zoo/uniformer.py b/ppcls/arch/backbone/model_zoo/uniformer.py new file mode 100644 index 0000000000000000000000000000000000000000..fa7dc9c7f66997da7dc072a824a40f8e66a2f3a7 --- /dev/null +++ b/ppcls/arch/backbone/model_zoo/uniformer.py @@ -0,0 +1,552 @@ +# copyright (c) 2022 PaddlePaddle Authors. All Rights Reserve. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Code was based on https://github.com/bytedance/UniFormer/blob/main/classification/uniformer.py +# reference: https://arxiv.org/abs/2201.09450 + +from collections import OrderedDict +from functools import partial +import paddle +import paddle.nn as nn +import paddle.nn.functional as F +import math +from .vision_transformer import trunc_normal_, zeros_, ones_, to_2tuple, DropPath, Identity, Mlp + +from ....utils.save_load import load_dygraph_pretrain, load_dygraph_pretrain_from_url + +MODEL_URLS = { + "UniFormer_small": + "https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/UniFormer_small_pretrained.pdparams", + "UniFormer_small_plus": + "https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/UniFormer_small_plus_pretrained.pdparams", + "UniFormer_small_plus_dim64": + "https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/UniFormer_small_plus_dim64_pretrained.pdparams", + "UniFormer_base": + "https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/UniFormer_base_pretrained.pdparams", + "UniFormer_base_ls": + "https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/UniFormer_base_ls_pretrained.pdparams", +} + +__all__ = list(MODEL_URLS.keys()) + +layer_scale = False +init_value = 1e-6 + + +class CMlp(nn.Layer): + def __init__(self, + in_features, + hidden_features=None, + out_features=None, + act_layer=nn.GELU, + drop=0.): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.fc1_conv = nn.Conv2D(in_features, hidden_features, 1) + self.act = act_layer() + self.fc2_conv = nn.Conv2D(hidden_features, out_features, 1) + self.drop = nn.Dropout(drop) + + def forward(self, x): + x = self.fc1_conv(x) + x = self.act(x) + x = self.drop(x) + x = self.fc2_conv(x) + x = self.drop(x) + return x + + +class Attention(nn.Layer): + def __init__(self, + dim, + num_heads=8, + qkv_bias=False, + qk_scale=None, + attn_drop=0., + proj_drop=0.): + super().__init__() + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = qk_scale or head_dim**-0.5 + + self.qkv = nn.Linear(dim, dim * 3, bias_attr=qkv_bias) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + + def forward(self, x): + B, N, C = x.shape + qkv = self.qkv(x).reshape( + shape=[B, N, 3, self.num_heads, C // self.num_heads]).transpose( + perm=[2, 0, 3, 1, 4]) + q, k, v = qkv[0], qkv[1], qkv[2] + + attn = (q @k.transpose(perm=[0, 1, 3, 2])) * self.scale + attn = nn.Softmax(axis=-1)(attn) + attn = self.attn_drop(attn) + + x = (attn @v).transpose(perm=[0, 2, 1, 3]).reshape(shape=[B, N, C]) + x = self.proj(x) + x = self.proj_drop(x) + return x + + +class CBlock(nn.Layer): + def __init__(self, + dim, + num_heads, + mlp_ratio=4., + qkv_bias=False, + qk_scale=None, + drop=0., + attn_drop=0., + drop_path=0., + act_layer=nn.GELU, + norm_layer=nn.LayerNorm): + super().__init__() + self.pos_embed = nn.Conv2D(dim, dim, 3, padding=1, groups=dim) + self.norm1 = nn.BatchNorm2D(dim) + self.conv1 = nn.Conv2D(dim, dim, 1) + self.conv2 = nn.Conv2D(dim, dim, 1) + self.attn = nn.Conv2D(dim, dim, 5, padding=2, groups=dim) + self.drop_path = DropPath( + drop_path) if drop_path > 0. else nn.Identity() + self.norm2 = nn.BatchNorm2D(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = CMlp( + in_features=dim, + hidden_features=mlp_hidden_dim, + act_layer=act_layer, + drop=drop) + + def forward(self, x): + x = x + self.pos_embed(x) + x = x + self.drop_path( + self.conv2(self.attn(self.conv1(self.norm1(x))))) + x = x + self.drop_path(self.mlp(self.norm2(x))) + return x + + +class SABlock(nn.Layer): + def __init__(self, + dim, + num_heads, + mlp_ratio=4., + qkv_bias=False, + qk_scale=None, + drop=0., + attn_drop=0., + drop_path=0., + act_layer=nn.GELU, + norm_layer=nn.LayerNorm): + super().__init__() + self.pos_embed = nn.Conv2D(dim, dim, 3, padding=1, groups=dim) + self.norm1 = norm_layer(dim) + self.attn = Attention( + dim, + num_heads=num_heads, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + attn_drop=attn_drop, + proj_drop=drop) + self.drop_path = DropPath( + drop_path) if drop_path > 0. else nn.Identity() + self.norm2 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = Mlp(in_features=dim, + hidden_features=mlp_hidden_dim, + act_layer=act_layer, + drop=drop) + global layer_scale + self.ls = layer_scale + if self.ls: + global init_value + print(f"Use layer_scale: {layer_scale}, init_values: {init_value}") + self.gamma_1 = self.create_parameter( + [dim], + dtype='float32', + default_initializer=nn.initializer.Constant(value=init_value)) + self.gamma_2 = self.create_parameter( + [dim], + dtype='float32', + default_initializer=nn.initializer.Constant(value=init_value)) + + def forward(self, x): + x = x + self.pos_embed(x) + B, N, H, W = x.shape + x = x.flatten(2).transpose(perm=[0, 2, 1]) + if self.ls: + x = x + self.drop_path(self.gamma_1 * self.attn(self.norm1(x))) + x = x + self.drop_path(self.gamma_2 * self.mlp(self.norm2(x))) + else: + x = x + self.drop_path(self.attn(self.norm1(x))) + x = x + self.drop_path(self.mlp(self.norm2(x))) + x = x.transpose(perm=[0, 2, 1]).reshape(shape=[B, N, H, W]) + return x + + +class HeadEmbedding(nn.Layer): + def __init__(self, in_channels, out_channels): + super().__init__() + + self.proj = nn.Sequential( + nn.Conv2D( + in_channels, + out_channels // 2, + kernel_size=(3, 3), + stride=(2, 2), + padding=(1, 1)), + nn.BatchNorm2D(out_channels // 2), + nn.GELU(), + nn.Conv2D( + out_channels // 2, + out_channels, + kernel_size=(3, 3), + stride=(2, 2), + padding=(1, 1)), + nn.BatchNorm2D(out_channels)) + + def forward(self, x): + x = self.proj(x) + return x + + +class MiddleEmbedding(nn.Layer): + def __init__(self, in_channels, out_channels): + super().__init__() + + self.proj = nn.Sequential( + nn.Conv2D( + in_channels, + out_channels, + kernel_size=(3, 3), + stride=(2, 2), + padding=(1, 1)), + nn.BatchNorm2D(out_channels)) + + def forward(self, x): + x = self.proj(x) + return x + + +class PatchEmbed(nn.Layer): + """ Image to Patch Embedding + """ + + def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768): + super().__init__() + img_size = to_2tuple(img_size) + patch_size = to_2tuple(patch_size) + num_patches = (img_size[1] // patch_size[1]) * (img_size[0] // + patch_size[0]) + self.img_size = img_size + self.patch_size = patch_size + self.num_patches = num_patches + self.norm = nn.LayerNorm(embed_dim) + self.proj_conv = nn.Conv2D( + in_chans, embed_dim, kernel_size=patch_size, stride=patch_size) + + def forward(self, x): + B, C, H, W = x.shape + assert H == self.img_size[0] and W == self.img_size[1], \ + f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})." + x = self.proj_conv(x) + B, C, H, W = x.shape + x = x.flatten(2).transpose(perm=[0, 2, 1]) + x = self.norm(x) + x = x.reshape(shape=[B, H, W, C]).transpose(perm=[0, 3, 1, 2]) + return x + + +class UniFormer(nn.Layer): + """ UniFormer + A PaddlePaddle impl of : `UniFormer: Unifying Convolution and Self-attention for Visual Recognition` - + https://arxiv.org/abs/2201.09450 + """ + + def __init__(self, + depth=[3, 4, 8, 3], + img_size=224, + in_chans=3, + class_num=1000, + embed_dim=[64, 128, 320, 512], + head_dim=64, + mlp_ratio=4., + qkv_bias=True, + qk_scale=None, + representation_size=None, + drop_rate=0., + attn_drop_rate=0., + drop_path_rate=0., + norm_layer=None, + conv_stem=False): + """ + Args: + depth (list): depth of each stage + img_size (int, tuple): input image size + in_chans (int): number of input channels + class_num (int): number of classes for classification head + embed_dim (list): embedding dimension of each stage + head_dim (int): head dimension + mlp_ratio (int): ratio of mlp hidden dim to embedding dim + qkv_bias (bool): enable bias for qkv if True + qk_scale (float): override default qk scale of head_dim ** -0.5 if set + representation_size (Optional[int]): enable and set representation layer (pre-logits) to this value if set + drop_rate (float): dropout rate + attn_drop_rate (float): attention dropout rate + drop_path_rate (float): stochastic depth rate + norm_layer (nn.Module): normalization layer + conv_stem (bool): whether use overlapped patch stem + """ + super().__init__() + self.class_num = class_num + self.num_features = self.embed_dim = embed_dim + norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6) + if conv_stem: + self.patch_embed1 = HeadEmbedding( + in_channels=in_chans, out_channels=embed_dim[0]) + self.patch_embed2 = MiddleEmbedding( + in_channels=embed_dim[0], out_channels=embed_dim[1]) + self.patch_embed3 = MiddleEmbedding( + in_channels=embed_dim[1], out_channels=embed_dim[2]) + self.patch_embed4 = MiddleEmbedding( + in_channels=embed_dim[2], out_channels=embed_dim[3]) + else: + self.patch_embed1 = PatchEmbed( + img_size=img_size, + patch_size=4, + in_chans=in_chans, + embed_dim=embed_dim[0]) + self.patch_embed2 = PatchEmbed( + img_size=img_size // 4, + patch_size=2, + in_chans=embed_dim[0], + embed_dim=embed_dim[1]) + self.patch_embed3 = PatchEmbed( + img_size=img_size // 8, + patch_size=2, + in_chans=embed_dim[1], + embed_dim=embed_dim[2]) + self.patch_embed4 = PatchEmbed( + img_size=img_size // 16, + patch_size=2, + in_chans=embed_dim[2], + embed_dim=embed_dim[3]) + + self.pos_drop = nn.Dropout(p=drop_rate) + dpr = [ + x.item() for x in paddle.linspace(0, drop_path_rate, sum(depth)) + ] # stochastic depth decay rule + num_heads = [dim // head_dim for dim in embed_dim] + self.blocks1 = nn.LayerList([ + CBlock( + dim=embed_dim[0], + num_heads=num_heads[0], + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + drop=drop_rate, + attn_drop=attn_drop_rate, + drop_path=dpr[i], + norm_layer=norm_layer) for i in range(depth[0]) + ]) + self.blocks2 = nn.LayerList([ + CBlock( + dim=embed_dim[1], + num_heads=num_heads[1], + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + drop=drop_rate, + attn_drop=attn_drop_rate, + drop_path=dpr[i + depth[0]], + norm_layer=norm_layer) for i in range(depth[1]) + ]) + self.blocks3 = nn.LayerList([ + SABlock( + dim=embed_dim[2], + num_heads=num_heads[2], + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + drop=drop_rate, + attn_drop=attn_drop_rate, + drop_path=dpr[i + depth[0] + depth[1]], + norm_layer=norm_layer) for i in range(depth[2]) + ]) + self.blocks4 = nn.LayerList([ + SABlock( + dim=embed_dim[3], + num_heads=num_heads[3], + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + drop=drop_rate, + attn_drop=attn_drop_rate, + drop_path=dpr[i + depth[0] + depth[1] + depth[2]], + norm_layer=norm_layer) for i in range(depth[3]) + ]) + self.norm = nn.BatchNorm2D(embed_dim[-1]) + + # Representation layer + if representation_size: + self.num_features = representation_size + self.pre_logits = nn.Sequential( + OrderedDict([('fc', nn.Linear(embed_dim, representation_size)), + ('act', nn.Tanh())])) + else: + self.pre_logits = nn.Identity() + + # Classifier head + self.head = nn.Linear(embed_dim[-1], + class_num) if class_num > 0 else nn.Identity() + + self.apply(self._init_weights) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight) + if isinstance(m, nn.Linear) and m.bias is not None: + zeros_(m.bias) + elif isinstance(m, nn.LayerNorm): + zeros_(m.bias) + ones_(m.weight) + + def forward_features(self, x): + x = self.patch_embed1(x) + x = self.pos_drop(x) + for blk in self.blocks1: + x = blk(x) + x = self.patch_embed2(x) + for blk in self.blocks2: + x = blk(x) + x = self.patch_embed3(x) + for blk in self.blocks3: + x = blk(x) + x = self.patch_embed4(x) + for blk in self.blocks4: + x = blk(x) + x = self.norm(x) + x = self.pre_logits(x) + return x + + def forward(self, x): + x = self.forward_features(x) + x = x.flatten(2).mean(-1) + x = self.head(x) + return x + + +def _load_pretrained(pretrained, model, model_url, use_ssld=False): + if pretrained is False: + pass + elif pretrained is True: + load_dygraph_pretrain_from_url(model, model_url, use_ssld=use_ssld) + elif isinstance(pretrained, str): + load_dygraph_pretrain(model, pretrained) + else: + raise RuntimeError( + "pretrained type is not available. Please use `string` or `boolean` type." + ) + + +def UniFormer_small(pretrained=True, use_ssld=False, **kwargs): + model = UniFormer( + depth=[3, 4, 8, 3], + embed_dim=[64, 128, 320, 512], + head_dim=64, + mlp_ratio=4, + qkv_bias=True, + norm_layer=partial( + nn.LayerNorm, epsilon=1e-6), + drop_path_rate=0.1, + **kwargs) + _load_pretrained( + pretrained, model, MODEL_URLS["UniFormer_small"], use_ssld=use_ssld) + return model + + +def UniFormer_small_plus(pretrained=True, use_ssld=False, **kwargs): + model = UniFormer( + depth=[3, 5, 9, 3], + conv_stem=True, + embed_dim=[64, 128, 320, 512], + head_dim=32, + mlp_ratio=4, + qkv_bias=True, + norm_layer=partial( + nn.LayerNorm, epsilon=1e-6), + drop_path_rate=0.1, + **kwargs) + _load_pretrained( + pretrained, + model, + MODEL_URLS["UniFormer_small_plus"], + use_ssld=use_ssld) + return model + + +def UniFormer_small_plus_dim64(pretrained=True, use_ssld=False, **kwargs): + model = UniFormer( + depth=[3, 5, 9, 3], + conv_stem=True, + embed_dim=[64, 128, 320, 512], + head_dim=64, + mlp_ratio=4, + qkv_bias=True, + norm_layer=partial( + nn.LayerNorm, epsilon=1e-6), + drop_path_rate=0.1, + **kwargs) + _load_pretrained( + pretrained, + model, + MODEL_URLS["UniFormer_small_plus_dim64"], + use_ssld=use_ssld) + return model + + +def UniFormer_base(pretrained=True, use_ssld=False, **kwargs): + model = UniFormer( + depth=[5, 8, 20, 7], + embed_dim=[64, 128, 320, 512], + head_dim=64, + mlp_ratio=4, + qkv_bias=True, + norm_layer=partial( + nn.LayerNorm, epsilon=1e-6), + drop_path_rate=0.3, + **kwargs) + _load_pretrained( + pretrained, model, MODEL_URLS["UniFormer_base"], use_ssld=use_ssld) + return model + + +def UniFormer_base_ls(pretrained=True, use_ssld=False, **kwargs): + global layer_scale + layer_scale = True + model = UniFormer( + depth=[5, 8, 20, 7], + embed_dim=[64, 128, 320, 512], + head_dim=64, + mlp_ratio=4, + qkv_bias=True, + norm_layer=partial( + nn.LayerNorm, epsilon=1e-6), + drop_path_rate=0.3, + **kwargs) + _load_pretrained( + pretrained, model, MODEL_URLS["UniFormer_base_ls"], use_ssld=use_ssld) + return model diff --git a/ppcls/configs/ImageNet/UniFormer/UniFormer_base.yaml b/ppcls/configs/ImageNet/UniFormer/UniFormer_base.yaml new file mode 100644 index 0000000000000000000000000000000000000000..a374e8858da2e04a75ea817b9559bd7fff1f1feb --- /dev/null +++ b/ppcls/configs/ImageNet/UniFormer/UniFormer_base.yaml @@ -0,0 +1,162 @@ +# global configs +Global: + checkpoints: null + pretrained_model: null + output_dir: ./output/ + device: gpu + save_interval: 1 + eval_during_train: True + eval_interval: 1 + epochs: 300 + print_batch_step: 10 + use_visualdl: False + # used for static mode and model export + + image_shape: [3, 224, 224] + save_inference_dir: ./inference + # training model under @to_static + to_static: False + +# model architecture +Arch: + name: UniFormer_base + class_num: 1000 + pretrained: True + +# loss function config for traing/eval process +Loss: + Train: + - CELoss: + weight: 1.0 + epsilon: 0.1 + Eval: + - CELoss: + weight: 1.0 + + +Optimizer: + name: AdamW + beta1: 0.9 + beta2: 0.999 + epsilon: 1e-8 + weight_decay: 0.05 + no_weight_decay_name: pos_embed cls_token .bias norm + one_dim_param_no_weight_decay: True + lr: + # for 8 cards + name: Cosine + learning_rate: 1e-3 + eta_min: 1e-5 + warmup_epoch: 5 + warmup_start_lr: 1e-6 + + +# data loader for train and eval +DataLoader: + Train: + dataset: + name: ImageNetDataset + image_root: ./dataset/ILSVRC2012/ + cls_label_path: ./dataset/ILSVRC2012/train_list.txt + transform_ops: + - DecodeImage: + to_rgb: True + channel_first: False + - RandCropImage: + size: 224 + interpolation: bicubic + backend: pil + - RandFlipImage: + flip_code: 1 + - TimmAutoAugment: + config_str: rand-m9-mstd0.5-inc1 + interpolation: bicubic + img_size: 224 + - NormalizeImage: + scale: 1.0/255.0 + mean: [0.485, 0.456, 0.406] + std: [0.229, 0.224, 0.225] + order: '' + - RandomErasing: + EPSILON: 0.25 + sl: 0.02 + sh: 1.0/3.0 + r1: 0.3 + attempt: 10 + use_log_aspect: True + mode: pixel + batch_transform_ops: + - OpSampler: + MixupOperator: + alpha: 0.8 + prob: 0.5 + CutmixOperator: + alpha: 1.0 + prob: 0.5 + + sampler: + name: DistributedBatchSampler + batch_size: 128 + drop_last: False + shuffle: True + loader: + num_workers: 4 + use_shared_memory: True + + Eval: + dataset: + name: ImageNetDataset + image_root: ./dataset/ILSVRC2012/ + cls_label_path: ./dataset/ILSVRC2012/val_list.txt + transform_ops: + - DecodeImage: + to_rgb: True + channel_first: False + - ResizeImage: + resize_short: 248 + interpolation: bicubic + backend: pil + - CropImage: + size: 224 + - NormalizeImage: + scale: 1.0/255.0 + mean: [0.485, 0.456, 0.406] + std: [0.229, 0.224, 0.225] + order: '' + sampler: + name: DistributedBatchSampler + batch_size: 128 + drop_last: False + shuffle: False + loader: + num_workers: 4 + use_shared_memory: True + +Infer: + infer_imgs: docs/images/inference_deployment/whl_demo.jpg + batch_size: 10 + transforms: + - DecodeImage: + to_rgb: True + channel_first: False + - ResizeImage: + resize_short: 248 + interpolation: bicubic + backend: pil + - CropImage: + size: 224 + - NormalizeImage: + scale: 1.0/255.0 + mean: [0.485, 0.456, 0.406] + std: [0.229, 0.224, 0.225] + order: '' + - ToCHWImage: + PostProcess: + name: Topk + topk: 5 + class_id_map_file: ppcls/utils/imagenet1k_label_list.txt + +Metric: + Eval: + - TopkAcc: + topk: [1, 5] diff --git a/ppcls/configs/ImageNet/UniFormer/UniFormer_base_ls.yaml b/ppcls/configs/ImageNet/UniFormer/UniFormer_base_ls.yaml new file mode 100644 index 0000000000000000000000000000000000000000..0ab3ae6cf33deb2c31b566cf8561de4d7352fb3d --- /dev/null +++ b/ppcls/configs/ImageNet/UniFormer/UniFormer_base_ls.yaml @@ -0,0 +1,162 @@ +# global configs +Global: + checkpoints: null + pretrained_model: null + output_dir: ./output/ + device: gpu + save_interval: 1 + eval_during_train: True + eval_interval: 1 + epochs: 300 + print_batch_step: 10 + use_visualdl: False + # used for static mode and model export + + image_shape: [3, 224, 224] + save_inference_dir: ./inference + # training model under @to_static + to_static: False + +# model architecture +Arch: + name: UniFormer_base_ls + class_num: 1000 + pretrained: True + +# loss function config for traing/eval process +Loss: + Train: + - CELoss: + weight: 1.0 + epsilon: 0.1 + Eval: + - CELoss: + weight: 1.0 + + +Optimizer: + name: AdamW + beta1: 0.9 + beta2: 0.999 + epsilon: 1e-8 + weight_decay: 0.05 + no_weight_decay_name: pos_embed cls_token .bias norm + one_dim_param_no_weight_decay: True + lr: + # for 8 cards + name: Cosine + learning_rate: 1e-3 + eta_min: 1e-5 + warmup_epoch: 5 + warmup_start_lr: 1e-6 + + +# data loader for train and eval +DataLoader: + Train: + dataset: + name: ImageNetDataset + image_root: ./dataset/ILSVRC2012/ + cls_label_path: ./dataset/ILSVRC2012/train_list.txt + transform_ops: + - DecodeImage: + to_rgb: True + channel_first: False + - RandCropImage: + size: 224 + interpolation: bicubic + backend: pil + - RandFlipImage: + flip_code: 1 + - TimmAutoAugment: + config_str: rand-m9-mstd0.5-inc1 + interpolation: bicubic + img_size: 224 + - NormalizeImage: + scale: 1.0/255.0 + mean: [0.485, 0.456, 0.406] + std: [0.229, 0.224, 0.225] + order: '' + - RandomErasing: + EPSILON: 0.25 + sl: 0.02 + sh: 1.0/3.0 + r1: 0.3 + attempt: 10 + use_log_aspect: True + mode: pixel + batch_transform_ops: + - OpSampler: + MixupOperator: + alpha: 0.8 + prob: 0.5 + CutmixOperator: + alpha: 1.0 + prob: 0.5 + + sampler: + name: DistributedBatchSampler + batch_size: 128 + drop_last: False + shuffle: True + loader: + num_workers: 4 + use_shared_memory: True + + Eval: + dataset: + name: ImageNetDataset + image_root: ./dataset/ILSVRC2012/ + cls_label_path: ./dataset/ILSVRC2012/val_list.txt + transform_ops: + - DecodeImage: + to_rgb: True + channel_first: False + - ResizeImage: + resize_short: 248 + interpolation: bicubic + backend: pil + - CropImage: + size: 224 + - NormalizeImage: + scale: 1.0/255.0 + mean: [0.485, 0.456, 0.406] + std: [0.229, 0.224, 0.225] + order: '' + sampler: + name: DistributedBatchSampler + batch_size: 128 + drop_last: False + shuffle: False + loader: + num_workers: 4 + use_shared_memory: True + +Infer: + infer_imgs: docs/images/inference_deployment/whl_demo.jpg + batch_size: 10 + transforms: + - DecodeImage: + to_rgb: True + channel_first: False + - ResizeImage: + resize_short: 248 + interpolation: bicubic + backend: pil + - CropImage: + size: 224 + - NormalizeImage: + scale: 1.0/255.0 + mean: [0.485, 0.456, 0.406] + std: [0.229, 0.224, 0.225] + order: '' + - ToCHWImage: + PostProcess: + name: Topk + topk: 5 + class_id_map_file: ppcls/utils/imagenet1k_label_list.txt + +Metric: + Eval: + - TopkAcc: + topk: [1, 5] diff --git a/ppcls/configs/ImageNet/UniFormer/UniFormer_small.yaml b/ppcls/configs/ImageNet/UniFormer/UniFormer_small.yaml new file mode 100644 index 0000000000000000000000000000000000000000..a46c1c48b3fabe4c2b3400585b52c6a7188f764b --- /dev/null +++ b/ppcls/configs/ImageNet/UniFormer/UniFormer_small.yaml @@ -0,0 +1,162 @@ +# global configs +Global: + checkpoints: null + pretrained_model: null + output_dir: ./output/ + device: gpu + save_interval: 1 + eval_during_train: True + eval_interval: 1 + epochs: 300 + print_batch_step: 10 + use_visualdl: False + # used for static mode and model export + + image_shape: [3, 224, 224] + save_inference_dir: ./inference + # training model under @to_static + to_static: False + +# model architecture +Arch: + name: UniFormer_small + class_num: 1000 + pretrained: True + +# loss function config for traing/eval process +Loss: + Train: + - CELoss: + weight: 1.0 + epsilon: 0.1 + Eval: + - CELoss: + weight: 1.0 + + +Optimizer: + name: AdamW + beta1: 0.9 + beta2: 0.999 + epsilon: 1e-8 + weight_decay: 0.05 + no_weight_decay_name: pos_embed cls_token .bias norm + one_dim_param_no_weight_decay: True + lr: + # for 8 cards + name: Cosine + learning_rate: 1e-3 + eta_min: 1e-5 + warmup_epoch: 5 + warmup_start_lr: 1e-6 + + +# data loader for train and eval +DataLoader: + Train: + dataset: + name: ImageNetDataset + image_root: ./dataset/ILSVRC2012/ + cls_label_path: ./dataset/ILSVRC2012/train_list.txt + transform_ops: + - DecodeImage: + to_rgb: True + channel_first: False + - RandCropImage: + size: 224 + interpolation: bicubic + backend: pil + - RandFlipImage: + flip_code: 1 + - TimmAutoAugment: + config_str: rand-m9-mstd0.5-inc1 + interpolation: bicubic + img_size: 224 + - NormalizeImage: + scale: 1.0/255.0 + mean: [0.485, 0.456, 0.406] + std: [0.229, 0.224, 0.225] + order: '' + - RandomErasing: + EPSILON: 0.25 + sl: 0.02 + sh: 1.0/3.0 + r1: 0.3 + attempt: 10 + use_log_aspect: True + mode: pixel + batch_transform_ops: + - OpSampler: + MixupOperator: + alpha: 0.8 + prob: 0.5 + CutmixOperator: + alpha: 1.0 + prob: 0.5 + + sampler: + name: DistributedBatchSampler + batch_size: 128 + drop_last: False + shuffle: True + loader: + num_workers: 4 + use_shared_memory: True + + Eval: + dataset: + name: ImageNetDataset + image_root: ./dataset/ILSVRC2012/ + cls_label_path: ./dataset/ILSVRC2012/val_list.txt + transform_ops: + - DecodeImage: + to_rgb: True + channel_first: False + - ResizeImage: + resize_short: 248 + interpolation: bicubic + backend: pil + - CropImage: + size: 224 + - NormalizeImage: + scale: 1.0/255.0 + mean: [0.485, 0.456, 0.406] + std: [0.229, 0.224, 0.225] + order: '' + sampler: + name: DistributedBatchSampler + batch_size: 128 + drop_last: False + shuffle: False + loader: + num_workers: 4 + use_shared_memory: True + +Infer: + infer_imgs: docs/images/inference_deployment/whl_demo.jpg + batch_size: 10 + transforms: + - DecodeImage: + to_rgb: True + channel_first: False + - ResizeImage: + resize_short: 248 + interpolation: bicubic + backend: pil + - CropImage: + size: 224 + - NormalizeImage: + scale: 1.0/255.0 + mean: [0.485, 0.456, 0.406] + std: [0.229, 0.224, 0.225] + order: '' + - ToCHWImage: + PostProcess: + name: Topk + topk: 5 + class_id_map_file: ppcls/utils/imagenet1k_label_list.txt + +Metric: + Eval: + - TopkAcc: + topk: [1, 5] diff --git a/ppcls/configs/ImageNet/UniFormer/UniFormer_small_plus.yaml b/ppcls/configs/ImageNet/UniFormer/UniFormer_small_plus.yaml new file mode 100644 index 0000000000000000000000000000000000000000..d43bd477a484c790f7e196a868a89518ea147e82 --- /dev/null +++ b/ppcls/configs/ImageNet/UniFormer/UniFormer_small_plus.yaml @@ -0,0 +1,162 @@ +# global configs +Global: + checkpoints: null + pretrained_model: null + output_dir: ./output/ + device: gpu + save_interval: 1 + eval_during_train: True + eval_interval: 1 + epochs: 300 + print_batch_step: 10 + use_visualdl: False + # used for static mode and model export + + image_shape: [3, 224, 224] + save_inference_dir: ./inference + # training model under @to_static + to_static: False + +# model architecture +Arch: + name: UniFormer_small_plus + class_num: 1000 + pretrained: True + +# loss function config for traing/eval process +Loss: + Train: + - CELoss: + weight: 1.0 + epsilon: 0.1 + Eval: + - CELoss: + weight: 1.0 + + +Optimizer: + name: AdamW + beta1: 0.9 + beta2: 0.999 + epsilon: 1e-8 + weight_decay: 0.05 + no_weight_decay_name: pos_embed cls_token .bias norm + one_dim_param_no_weight_decay: True + lr: + # for 8 cards + name: Cosine + learning_rate: 1e-3 + eta_min: 1e-5 + warmup_epoch: 5 + warmup_start_lr: 1e-6 + + +# data loader for train and eval +DataLoader: + Train: + dataset: + name: ImageNetDataset + image_root: ./dataset/ILSVRC2012/ + cls_label_path: ./dataset/ILSVRC2012/train_list.txt + transform_ops: + - DecodeImage: + to_rgb: True + channel_first: False + - RandCropImage: + size: 224 + interpolation: bicubic + backend: pil + - RandFlipImage: + flip_code: 1 + - TimmAutoAugment: + config_str: rand-m9-mstd0.5-inc1 + interpolation: bicubic + img_size: 224 + - NormalizeImage: + scale: 1.0/255.0 + mean: [0.485, 0.456, 0.406] + std: [0.229, 0.224, 0.225] + order: '' + - RandomErasing: + EPSILON: 0.25 + sl: 0.02 + sh: 1.0/3.0 + r1: 0.3 + attempt: 10 + use_log_aspect: True + mode: pixel + batch_transform_ops: + - OpSampler: + MixupOperator: + alpha: 0.8 + prob: 0.5 + CutmixOperator: + alpha: 1.0 + prob: 0.5 + + sampler: + name: DistributedBatchSampler + batch_size: 128 + drop_last: False + shuffle: True + loader: + num_workers: 4 + use_shared_memory: True + + Eval: + dataset: + name: ImageNetDataset + image_root: ./dataset/ILSVRC2012/ + cls_label_path: ./dataset/ILSVRC2012/val_list.txt + transform_ops: + - DecodeImage: + to_rgb: True + channel_first: False + - ResizeImage: + resize_short: 248 + interpolation: bicubic + backend: pil + - CropImage: + size: 224 + - NormalizeImage: + scale: 1.0/255.0 + mean: [0.485, 0.456, 0.406] + std: [0.229, 0.224, 0.225] + order: '' + sampler: + name: DistributedBatchSampler + batch_size: 128 + drop_last: False + shuffle: False + loader: + num_workers: 4 + use_shared_memory: True + +Infer: + infer_imgs: docs/images/inference_deployment/whl_demo.jpg + batch_size: 10 + transforms: + - DecodeImage: + to_rgb: True + channel_first: False + - ResizeImage: + resize_short: 248 + interpolation: bicubic + backend: pil + - CropImage: + size: 224 + - NormalizeImage: + scale: 1.0/255.0 + mean: [0.485, 0.456, 0.406] + std: [0.229, 0.224, 0.225] + order: '' + - ToCHWImage: + PostProcess: + name: Topk + topk: 5 + class_id_map_file: ppcls/utils/imagenet1k_label_list.txt + +Metric: + Eval: + - TopkAcc: + topk: [1, 5] diff --git a/ppcls/configs/ImageNet/UniFormer/UniFormer_small_plus_dim64.yaml b/ppcls/configs/ImageNet/UniFormer/UniFormer_small_plus_dim64.yaml new file mode 100644 index 0000000000000000000000000000000000000000..84b44610fa43be490b1023c4ca597d6274e19977 --- /dev/null +++ b/ppcls/configs/ImageNet/UniFormer/UniFormer_small_plus_dim64.yaml @@ -0,0 +1,162 @@ +# global configs +Global: + checkpoints: null + pretrained_model: null + output_dir: ./output/ + device: gpu + save_interval: 1 + eval_during_train: True + eval_interval: 1 + epochs: 300 + print_batch_step: 10 + use_visualdl: False + # used for static mode and model export + + image_shape: [3, 224, 224] + save_inference_dir: ./inference + # training model under @to_static + to_static: False + +# model architecture +Arch: + name: UniFormer_small_plus_dim64 + class_num: 1000 + pretrained: True + +# loss function config for traing/eval process +Loss: + Train: + - CELoss: + weight: 1.0 + epsilon: 0.1 + Eval: + - CELoss: + weight: 1.0 + + +Optimizer: + name: AdamW + beta1: 0.9 + beta2: 0.999 + epsilon: 1e-8 + weight_decay: 0.05 + no_weight_decay_name: pos_embed cls_token .bias norm + one_dim_param_no_weight_decay: True + lr: + # for 8 cards + name: Cosine + learning_rate: 1e-3 + eta_min: 1e-5 + warmup_epoch: 5 + warmup_start_lr: 1e-6 + + +# data loader for train and eval +DataLoader: + Train: + dataset: + name: ImageNetDataset + image_root: ./dataset/ILSVRC2012/ + cls_label_path: ./dataset/ILSVRC2012/train_list.txt + transform_ops: + - DecodeImage: + to_rgb: True + channel_first: False + - RandCropImage: + size: 224 + interpolation: bicubic + backend: pil + - RandFlipImage: + flip_code: 1 + - TimmAutoAugment: + config_str: rand-m9-mstd0.5-inc1 + interpolation: bicubic + img_size: 224 + - NormalizeImage: + scale: 1.0/255.0 + mean: [0.485, 0.456, 0.406] + std: [0.229, 0.224, 0.225] + order: '' + - RandomErasing: + EPSILON: 0.25 + sl: 0.02 + sh: 1.0/3.0 + r1: 0.3 + attempt: 10 + use_log_aspect: True + mode: pixel + batch_transform_ops: + - OpSampler: + MixupOperator: + alpha: 0.8 + prob: 0.5 + CutmixOperator: + alpha: 1.0 + prob: 0.5 + + sampler: + name: DistributedBatchSampler + batch_size: 128 + drop_last: False + shuffle: True + loader: + num_workers: 4 + use_shared_memory: True + + Eval: + dataset: + name: ImageNetDataset + image_root: ./dataset/ILSVRC2012/ + cls_label_path: ./dataset/ILSVRC2012/val_list.txt + transform_ops: + - DecodeImage: + to_rgb: True + channel_first: False + - ResizeImage: + resize_short: 248 + interpolation: bicubic + backend: pil + - CropImage: + size: 224 + - NormalizeImage: + scale: 1.0/255.0 + mean: [0.485, 0.456, 0.406] + std: [0.229, 0.224, 0.225] + order: '' + sampler: + name: DistributedBatchSampler + batch_size: 128 + drop_last: False + shuffle: False + loader: + num_workers: 4 + use_shared_memory: True + +Infer: + infer_imgs: docs/images/inference_deployment/whl_demo.jpg + batch_size: 10 + transforms: + - DecodeImage: + to_rgb: True + channel_first: False + - ResizeImage: + resize_short: 248 + interpolation: bicubic + backend: pil + - CropImage: + size: 224 + - NormalizeImage: + scale: 1.0/255.0 + mean: [0.485, 0.456, 0.406] + std: [0.229, 0.224, 0.225] + order: '' + - ToCHWImage: + PostProcess: + name: Topk + topk: 5 + class_id_map_file: ppcls/utils/imagenet1k_label_list.txt + +Metric: + Eval: + - TopkAcc: + topk: [1, 5] diff --git a/test_tipc/configs/UniFormer/UniFormer_base_ls_train_infer_python.txt b/test_tipc/configs/UniFormer/UniFormer_base_ls_train_infer_python.txt new file mode 100644 index 0000000000000000000000000000000000000000..da7344b2e65c6bec1338373301263481a441f976 --- /dev/null +++ b/test_tipc/configs/UniFormer/UniFormer_base_ls_train_infer_python.txt @@ -0,0 +1,54 @@ +===========================train_params=========================== +model_name:UniFormer_base_ls +python:python3.7 +gpu_list:0|0,1 +-o Global.device:gpu +-o Global.auto_cast:null +-o Global.epochs:lite_train_lite_infer=5|whole_train_whole_infer=300 +-o Global.output_dir:./output/ +-o DataLoader.Train.sampler.batch_size:1024 +-o Global.pretrained_model:null +train_model_name:latest +train_infer_img_dir:./dataset/ILSVRC2012/val +null:null +## +trainer:norm_train +norm_train:tools/train.py -c ppcls/configs/ImageNet/UniFormer/UniFormer_base_ls.yaml -o Global.seed=1234 -o DataLoader.Train.sampler.shuffle=False -o DataLoader.Train.loader.num_workers=0 -o DataLoader.Train.loader.use_shared_memory=False -o Global.eval_during_train=False -o Global.save_interval=2 +pact_train:null +fpgm_train:null +distill_train:null +null:null +null:null +## +===========================eval_params=========================== +eval:tools/eval.py -c ppcls/configs/ImageNet/UniFormer/UniFormer_base_ls.yaml +null:null +## +===========================infer_params========================== +-o Global.save_inference_dir:./inference +-o Global.pretrained_model: +norm_export:tools/export_model.py -c ppcls/configs/ImageNet/UniFormer/UniFormer_base_ls.yaml +quant_export:null +fpgm_export:null +distill_export:null +kl_quant:null +export2:null +pretrained_model_url:https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/UniFormer_base_ls.pdparams +infer_model:../inference/ +infer_export:True +infer_quant:Fasle +inference:python/predict_cls.py -c configs/inference_cls.yaml -o PreProcess.transform_ops.0.ResizeImage.resize_short=248 +-o Global.use_gpu:True|False +-o Global.enable_mkldnn:False +-o Global.cpu_num_threads:1 +-o Global.batch_size:1 +-o Global.use_tensorrt:False +-o Global.use_fp16:False +-o Global.inference_model_dir:../inference +-o Global.infer_imgs:../dataset/ILSVRC2012/val/ILSVRC2012_val_00000001.JPEG +-o Global.save_log_path:null +-o Global.benchmark:False +null:null +null:null +===========================infer_benchmark_params========================== +random_infer_input:[{float32,[3,224,224]}] diff --git a/test_tipc/configs/UniFormer/UniFormer_base_train_infer_python.txt b/test_tipc/configs/UniFormer/UniFormer_base_train_infer_python.txt new file mode 100644 index 0000000000000000000000000000000000000000..22b9395d5f2bb960b271753d0f8714f02f3b7a9c --- /dev/null +++ b/test_tipc/configs/UniFormer/UniFormer_base_train_infer_python.txt @@ -0,0 +1,54 @@ +===========================train_params=========================== +model_name:UniFormer_base +python:python3.7 +gpu_list:0|0,1 +-o Global.device:gpu +-o Global.auto_cast:null +-o Global.epochs:lite_train_lite_infer=5|whole_train_whole_infer=300 +-o Global.output_dir:./output/ +-o DataLoader.Train.sampler.batch_size:1024 +-o Global.pretrained_model:null +train_model_name:latest +train_infer_img_dir:./dataset/ILSVRC2012/val +null:null +## +trainer:norm_train +norm_train:tools/train.py -c ppcls/configs/ImageNet/UniFormer/UniFormer_base.yaml -o Global.seed=1234 -o DataLoader.Train.sampler.shuffle=False -o DataLoader.Train.loader.num_workers=0 -o DataLoader.Train.loader.use_shared_memory=False -o Global.eval_during_train=False -o Global.save_interval=2 +pact_train:null +fpgm_train:null +distill_train:null +null:null +null:null +## +===========================eval_params=========================== +eval:tools/eval.py -c ppcls/configs/ImageNet/UniFormer/UniFormer_base.yaml +null:null +## +===========================infer_params========================== +-o Global.save_inference_dir:./inference +-o Global.pretrained_model: +norm_export:tools/export_model.py -c ppcls/configs/ImageNet/UniFormer/UniFormer_base.yaml +quant_export:null +fpgm_export:null +distill_export:null +kl_quant:null +export2:null +pretrained_model_url:https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/UniFormer_base.pdparams +infer_model:../inference/ +infer_export:True +infer_quant:Fasle +inference:python/predict_cls.py -c configs/inference_cls.yaml -o PreProcess.transform_ops.0.ResizeImage.resize_short=248 +-o Global.use_gpu:True|False +-o Global.enable_mkldnn:False +-o Global.cpu_num_threads:1 +-o Global.batch_size:1 +-o Global.use_tensorrt:False +-o Global.use_fp16:False +-o Global.inference_model_dir:../inference +-o Global.infer_imgs:../dataset/ILSVRC2012/val/ILSVRC2012_val_00000001.JPEG +-o Global.save_log_path:null +-o Global.benchmark:False +null:null +null:null +===========================infer_benchmark_params========================== +random_infer_input:[{float32,[3,224,224]}] diff --git a/test_tipc/configs/UniFormer/UniFormer_small_plus_dim64_train_infer_python.txt b/test_tipc/configs/UniFormer/UniFormer_small_plus_dim64_train_infer_python.txt new file mode 100644 index 0000000000000000000000000000000000000000..b40d24843cf1a0f2a504fe91ab2d26eb6af16c25 --- /dev/null +++ b/test_tipc/configs/UniFormer/UniFormer_small_plus_dim64_train_infer_python.txt @@ -0,0 +1,54 @@ +===========================train_params=========================== +model_name:UniFormer_small_plus_dim64 +python:python3.7 +gpu_list:0|0,1 +-o Global.device:gpu +-o Global.auto_cast:null +-o Global.epochs:lite_train_lite_infer=5|whole_train_whole_infer=300 +-o Global.output_dir:./output/ +-o DataLoader.Train.sampler.batch_size:1024 +-o Global.pretrained_model:null +train_model_name:latest +train_infer_img_dir:./dataset/ILSVRC2012/val +null:null +## +trainer:norm_train +norm_train:tools/train.py -c ppcls/configs/ImageNet/UniFormer/UniFormer_small_plus_dim64.yaml -o Global.seed=1234 -o DataLoader.Train.sampler.shuffle=False -o DataLoader.Train.loader.num_workers=0 -o DataLoader.Train.loader.use_shared_memory=False -o Global.eval_during_train=False -o Global.save_interval=2 +pact_train:null +fpgm_train:null +distill_train:null +null:null +null:null +## +===========================eval_params=========================== +eval:tools/eval.py -c ppcls/configs/ImageNet/UniFormer/UniFormer_small_plus_dim64.yaml +null:null +## +===========================infer_params========================== +-o Global.save_inference_dir:./inference +-o Global.pretrained_model: +norm_export:tools/export_model.py -c ppcls/configs/ImageNet/UniFormer/UniFormer_small_plus_dim64.yaml +quant_export:null +fpgm_export:null +distill_export:null +kl_quant:null +export2:null +pretrained_model_url:https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/UniFormer_small_plus_dim64.pdparams +infer_model:../inference/ +infer_export:True +infer_quant:Fasle +inference:python/predict_cls.py -c configs/inference_cls.yaml -o PreProcess.transform_ops.0.ResizeImage.resize_short=248 +-o Global.use_gpu:True|False +-o Global.enable_mkldnn:False +-o Global.cpu_num_threads:1 +-o Global.batch_size:1 +-o Global.use_tensorrt:False +-o Global.use_fp16:False +-o Global.inference_model_dir:../inference +-o Global.infer_imgs:../dataset/ILSVRC2012/val/ILSVRC2012_val_00000001.JPEG +-o Global.save_log_path:null +-o Global.benchmark:False +null:null +null:null +===========================infer_benchmark_params========================== +random_infer_input:[{float32,[3,224,224]}] diff --git a/test_tipc/configs/UniFormer/UniFormer_small_plus_train_infer_python.txt b/test_tipc/configs/UniFormer/UniFormer_small_plus_train_infer_python.txt new file mode 100644 index 0000000000000000000000000000000000000000..81ddad8382bfb7abbd70117ac500a346c569804e --- /dev/null +++ b/test_tipc/configs/UniFormer/UniFormer_small_plus_train_infer_python.txt @@ -0,0 +1,54 @@ +===========================train_params=========================== +model_name:UniFormer_small_plus +python:python3.7 +gpu_list:0|0,1 +-o Global.device:gpu +-o Global.auto_cast:null +-o Global.epochs:lite_train_lite_infer=5|whole_train_whole_infer=300 +-o Global.output_dir:./output/ +-o DataLoader.Train.sampler.batch_size:1024 +-o Global.pretrained_model:null +train_model_name:latest +train_infer_img_dir:./dataset/ILSVRC2012/val +null:null +## +trainer:norm_train +norm_train:tools/train.py -c ppcls/configs/ImageNet/UniFormer/UniFormer_small_plus.yaml -o Global.seed=1234 -o DataLoader.Train.sampler.shuffle=False -o DataLoader.Train.loader.num_workers=0 -o DataLoader.Train.loader.use_shared_memory=False -o Global.eval_during_train=False -o Global.save_interval=2 +pact_train:null +fpgm_train:null +distill_train:null +null:null +null:null +## +===========================eval_params=========================== +eval:tools/eval.py -c ppcls/configs/ImageNet/UniFormer/UniFormer_small_plus.yaml +null:null +## +===========================infer_params========================== +-o Global.save_inference_dir:./inference +-o Global.pretrained_model: +norm_export:tools/export_model.py -c ppcls/configs/ImageNet/UniFormer/UniFormer_small_plus.yaml +quant_export:null +fpgm_export:null +distill_export:null +kl_quant:null +export2:null +pretrained_model_url:https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/UniFormer_small_plus.pdparams +infer_model:../inference/ +infer_export:True +infer_quant:Fasle +inference:python/predict_cls.py -c configs/inference_cls.yaml -o PreProcess.transform_ops.0.ResizeImage.resize_short=248 +-o Global.use_gpu:True|False +-o Global.enable_mkldnn:False +-o Global.cpu_num_threads:1 +-o Global.batch_size:1 +-o Global.use_tensorrt:False +-o Global.use_fp16:False +-o Global.inference_model_dir:../inference +-o Global.infer_imgs:../dataset/ILSVRC2012/val/ILSVRC2012_val_00000001.JPEG +-o Global.save_log_path:null +-o Global.benchmark:False +null:null +null:null +===========================infer_benchmark_params========================== +random_infer_input:[{float32,[3,224,224]}] diff --git a/test_tipc/configs/UniFormer/UniFormer_small_train_infer_python.txt b/test_tipc/configs/UniFormer/UniFormer_small_train_infer_python.txt new file mode 100644 index 0000000000000000000000000000000000000000..4fb4cc493d0e807e4e5f8cf000e0087648894710 --- /dev/null +++ b/test_tipc/configs/UniFormer/UniFormer_small_train_infer_python.txt @@ -0,0 +1,54 @@ +===========================train_params=========================== +model_name:UniFormer_small +python:python3.7 +gpu_list:0|0,1 +-o Global.device:gpu +-o Global.auto_cast:null +-o Global.epochs:lite_train_lite_infer=5|whole_train_whole_infer=300 +-o Global.output_dir:./output/ +-o DataLoader.Train.sampler.batch_size:1024 +-o Global.pretrained_model:null +train_model_name:latest +train_infer_img_dir:./dataset/ILSVRC2012/val +null:null +## +trainer:norm_train +norm_train:tools/train.py -c ppcls/configs/ImageNet/UniFormer/UniFormer_small.yaml -o Global.seed=1234 -o DataLoader.Train.sampler.shuffle=False -o DataLoader.Train.loader.num_workers=0 -o DataLoader.Train.loader.use_shared_memory=False -o Global.eval_during_train=False -o Global.save_interval=2 +pact_train:null +fpgm_train:null +distill_train:null +null:null +null:null +## +===========================eval_params=========================== +eval:tools/eval.py -c ppcls/configs/ImageNet/UniFormer/UniFormer_small.yaml +null:null +## +===========================infer_params========================== +-o Global.save_inference_dir:./inference +-o Global.pretrained_model: +norm_export:tools/export_model.py -c ppcls/configs/ImageNet/UniFormer/UniFormer_small.yaml +quant_export:null +fpgm_export:null +distill_export:null +kl_quant:null +export2:null +pretrained_model_url:https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/UniFormer_small.pdparams +infer_model:../inference/ +infer_export:True +infer_quant:Fasle +inference:python/predict_cls.py -c configs/inference_cls.yaml -o PreProcess.transform_ops.0.ResizeImage.resize_short=248 +-o Global.use_gpu:True|False +-o Global.enable_mkldnn:False +-o Global.cpu_num_threads:1 +-o Global.batch_size:1 +-o Global.use_tensorrt:False +-o Global.use_fp16:False +-o Global.inference_model_dir:../inference +-o Global.infer_imgs:../dataset/ILSVRC2012/val/ILSVRC2012_val_00000001.JPEG +-o Global.save_log_path:null +-o Global.benchmark:False +null:null +null:null +===========================infer_benchmark_params========================== +random_infer_input:[{float32,[3,224,224]}]