提交 ec2029b8 编写于 作者: C cuicheng01 提交者: cuicheng01

Add NextViT code and docs

上级 509202a3
# NextViT
-----
## 目录
- [1. 模型介绍](#1)
- [1.1 模型简介](#1.1)
- [1.2 模型指标](#1.2)
- [2. 模型快速体验](#2)
- [3. 模型训练、评估和预测](#3)
- [4. 模型推理部署](#4)
- [4.1 推理模型准备](#4.1)
- [4.2 基于 Python 预测引擎推理](#4.2)
- [4.3 基于 C++ 预测引擎推理](#4.3)
- [4.4 服务化部署](#4.4)
- [4.5 端侧部署](#4.5)
- [4.6 Paddle2ONNX 模型转换与预测](#4.6)
<a name='1'></a>
## 1. 模型介绍
<a name='1.1'></a>
### 1.1 模型简介
NextViT 是一种新的视觉 Transformer 网络,可以用作计算机视觉领域的通用骨干网络。作者提出了在现实工业场景中有效部署的 Next generation Vision Transformer,即 Next-ViT,从延迟/准确性权衡的角度来看,它在 CNN 和 ViT 中均占主导地位。在这项工作中,作者分别开发了Next Convolution Block(NCB)和Next Transformer Block(NTB),以通过部署友好的机制捕获局部和全局信息。在此基础上,Next Hybrid Strategy (NHS) 旨在以高效的混合范式堆叠 NCB 和 NTB,从而提高各种下游任务的性能。
最终,NextViT 在多项任务中达到SOTA效果。[论文地址](https://arxiv.org/pdf/2207.05501.pdf)
<a name='1.2'></a>
### 1.2 模型指标
| Models | Top1 | Top5 | Reference<br>top1 | Reference<br>top5 | FLOPs<br>(G) | Params<br>(M) |
|:--:|:--:|:--:|:--:|:--:|:--:|:--:|
| NextViT_small_224 | 0.8248 | 0.9616 | 0.825 | - | 5.79 | 31.80 |
| NextViT_base_224 | 0.8324 | 0.9658 | 0.832 | - | 8.26 | 44.88 |
| NextViT_large_224 | 0.8363 | 0.9661 | 0.836 | - | 10.73 | 57.95 |
| NextViT_small_384 | 0.8401 | 0.9698 | 0.836 | - | 17.00 | 31.80 |
| NextViT_base_384 | 0.8465 | 0.9723 | 0.843 | - |24.27 | 44.88 |
| NextViT_large_384 | 0.8492 | 0.9728 | 0.847 | - | 31.53 | 57.95 |
| NextViT_small_224_ssld | 0.8472 | 0.9734 | 0.848 | - | 5.79 | 31.80 |
| NextViT_base_224_ssld | 0.8500 | 0.9753 | 0.851 | - | 8.26 | 44.88 |
| NextViT_large_224_ssld | 0.8536 | 0.9762 | 0.854 | - | 10.73 | 57.95 |
| NextViT_small_384_ssld | 0.8597 | 0.9790 | 0.858 | - | 17.00 | 31.80 |
| NextViT_base_384_ssld | 0.8634 | 0.9806 | 0.861 | - |24.27 | 44.88 |
| NextViT_large_384_ssld | 0.8654 | 0.9814 | 0.864 | - | 31.53 | 57.95 |
**备注:**
- PaddleClas 所提供的该系列模型的预训练模型权重,均是基于其官方提供的权重转得。PaddleClas 验证了 NextViT_small_224 的精度可以与论文精度对齐。
- 此处 `_ssld` 并非使用 PaddleClas 中的蒸馏的`SSLD 蒸馏`方法得到,而是使用类似`SSLD 蒸馏`挖掘的数据集训练得到。
<a name="2"></a>
## 2. 模型快速体验
安装 paddlepaddle 和 paddleclas 即可快速对图片进行预测,体验方法可以参考[ResNet50 模型快速体验](./ResNet.md#2-模型快速体验)
<a name="3"></a>
## 3. 模型训练、评估和预测
此部分内容包括训练环境配置、ImageNet数据的准备、该模型在 ImageNet 上的训练、评估、预测等内容。在 `ppcls/configs/ImageNet/NextViT/` 中提供了该模型的训练配置,启动训练方法可以参考:[ResNet50 模型训练、评估和预测](./ResNet.md#3-模型训练评估和预测)
**备注:** 由于 NextViT 系列模型默认使用的 GPU 数量为 8 个,所以在训练时,需要指定8个GPU,如`python3 -m paddle.distributed.launch --gpus="0,1,2,3,4,5,6,7" tools/train.py -c xxx.yaml`, 如果使用 4 个 GPU 训练,默认学习率需要减小一半,精度可能有损。
<a name="4"></a>
## 4. 模型推理部署
<a name="4.1"></a>
### 4.1 推理模型准备
Paddle Inference 是飞桨的原生推理库, 作用于服务器端和云端,提供高性能的推理能力。相比于直接基于预训练模型进行预测,Paddle Inference可使用 MKLDNN、CUDNN、TensorRT 进行预测加速,从而实现更优的推理性能。更多关于Paddle Inference推理引擎的介绍,可以参考[Paddle Inference官网教程](https://www.paddlepaddle.org.cn/documentation/docs/zh/guides/infer/inference/inference_cn.html)
Inference 的获取可以参考 [ResNet50 推理模型准备](./ResNet.md#41-推理模型准备)
<a name="4.2"></a>
### 4.2 基于 Python 预测引擎推理
PaddleClas 提供了基于 python 预测引擎推理的示例。您可以参考[ResNet50 基于 Python 预测引擎推理](./ResNet.md#42-基于-python-预测引擎推理)
<a name="4.3"></a>
### 4.3 基于 C++ 预测引擎推理
PaddleClas 提供了基于 C++ 预测引擎推理的示例,您可以参考[服务器端 C++ 预测](../../deployment/image_classification/cpp/linux.md)来完成相应的推理部署。如果您使用的是 Windows 平台,可以参考[基于 Visual Studio 2019 Community CMake 编译指南](../../deployment/image_classification/cpp/windows.md)完成相应的预测库编译和模型预测工作。
<a name="4.4"></a>
### 4.4 服务化部署
Paddle Serving 提供高性能、灵活易用的工业级在线推理服务。Paddle Serving 支持 RESTful、gRPC、bRPC 等多种协议,提供多种异构硬件和多种操作系统环境下推理解决方案。更多关于Paddle Serving 的介绍,可以参考[Paddle Serving 代码仓库](https://github.com/PaddlePaddle/Serving)
PaddleClas 提供了基于 Paddle Serving 来完成模型服务化部署的示例,您可以参考[模型服务化部署](../../deployment/image_classification/paddle_serving.md)来完成相应的部署工作。
<a name="4.5"></a>
### 4.5 端侧部署
Paddle Lite 是一个高性能、轻量级、灵活性强且易于扩展的深度学习推理框架,定位于支持包括移动端、嵌入式以及服务器端在内的多硬件平台。更多关于 Paddle Lite 的介绍,可以参考[Paddle Lite 代码仓库](https://github.com/PaddlePaddle/Paddle-Lite)
PaddleClas 提供了基于 Paddle Lite 来完成模型端侧部署的示例,您可以参考[端侧部署](../../deployment/image_classification/paddle_lite.md)来完成相应的部署工作。
<a name="4.6"></a>
### 4.6 Paddle2ONNX 模型转换与预测
Paddle2ONNX 支持将 PaddlePaddle 模型格式转化到 ONNX 模型格式。通过 ONNX 可以完成将 Paddle 模型到多种推理引擎的部署,包括TensorRT/OpenVINO/MNN/TNN/NCNN,以及其它对 ONNX 开源格式进行支持的推理引擎或硬件。更多关于 Paddle2ONNX 的介绍,可以参考[Paddle2ONNX 代码仓库](https://github.com/PaddlePaddle/Paddle2ONNX)
PaddleClas 提供了基于 Paddle2ONNX 来完成 inference 模型转换 ONNX 模型并作推理预测的示例,您可以参考[Paddle2ONNX 模型转换与预测](../../deployment/image_classification/paddle2onnx.md)来完成相应的部署工作。
......@@ -49,6 +49,7 @@
- [PVTV2 系列](#PVTV2)
- [LeViT 系列](#LeViT)
- [TNT 系列](#TNT)
- [NextViT 系列](#NextViT)
- [4.2 轻量级模型](#Transformer_lite)
- [MobileViT 系列](#MobileViT)
- [五、参考文献](#reference)
......@@ -701,6 +702,26 @@ DeiT(Data-efficient Image Transformers)系列模型的精度、速度指标
**注**:TNT 模型的数据预处理部分 `NormalizeImage` 中的 `mean``std` 均为 0.5。
<a name="NextViT"></a>
## NextViT 系列 <sup>[[35](#ref47)]</sup>
关于 NextViT 系列模型的精度、速度指标如下表所示,更多介绍可以参考:[NextViT 系列模型文档](NextViT.md)
| 模型 | Top-1 Acc | Top-5 Acc | time(ms)<br>bs=1 | time(ms)<br>bs=4 | FLOPs(G) | Params(M) |预训练模型下载地址 | inference模型下载地址 |
| ---------- | --------- | --------- | ---------------- | ---------------- | -------- | --------- | ------------------------------------------------------------ | ------------------------------------------------------------ | ------------------------------------------------------------ |
| NextViT_small_224 | 0.8248 | 0.9616 | - | - | 5.79 | 31.80 | [下载链接](https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/NextViT_small_224_pretrained.pdparams) | [下载链接](https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/inference/NextViT_small_224_infer.tar) |
| NextViT_base_224 | 0.8324 | 0.9658 | - | - | 8.26 | 44.88 | [下载链接](https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/NextViT_base_224_pretrained.pdparams) | [下载链接](https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/inference/NextViT_base_224_infer.tar) |
| NextViT_large_224 | 0.8363 | 0.9661 | - | - | 10.73 | 57.95 | [下载链接](https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/NextViT_large_224_pretrained.pdparams) | [下载链接](https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/inference/NextViT_large_224_infer.tar) |
| NextViT_small_384 | 0.8401 | 0.9698 | - | - | 17.00 | 31.80 | [下载链接](https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/NextViT_small_384_pretrained.pdparams) | [下载链接](https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/inference/NextViT_small_384_infer.tar) |
| NextViT_base_384 | 0.8465 | 0.9723 | - | - |24.27 | 44.88 | [下载链接](https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/NextViT_base_384_pretrained.pdparams) | [下载链接](https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/inference/NextViT_base_384_infer.tar) |
| NextViT_large_384 | 0.8492 | 0.9728 | - | - | 31.53 | 57.95 | [下载链接](https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/NextViT_large_384_pretrained.pdparams) | [下载链接](https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/inference/NextViT_large_384_infer.tar) |
| NextViT_small_224_ssld | 0.8472 | 0.9734 | - | - | 5.79 | 31.80 | [下载链接](https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/NextViT_small_224_ssld_pretrained.pdparams) | [下载链接](https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/inference/NextViT_small_224_ssld_infer.tar) |
| NextViT_base_224_ssld | 0.8500 | 0.9753 | - | - | 8.26 | 44.88 | [下载链接](https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/NextViT_base_224_ssld_pretrained.pdparams) | [下载链接](https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/inference/NextViT_base_224_ssld_infer.tar) |
| NextViT_large_224_ssld | 0.8536 | 0.9762 | - | - | 10.73 | 57.95 | [下载链接](https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/NextViT_large_224_ssld_pretrained.pdparams) | [下载链接](https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/inference/NextViT_large_224_ssld_infer.tar) |
| NextViT_small_384_ssld | 0.8597 | 0.9790 | - | - | 17.00 | 31.80 | [下载链接](https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/NextViT_small_384_ssld_pretrained.pdparams) | [下载链接](https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/inference/NextViT_small_384_ssld_infer.tar) |
| NextViT_base_384_ssld | 0.8634 | 0.9806 | - | - |24.27 | 44.88 | [下载链接](https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/NextViT_base_384_ssld_pretrained.pdparams) | [下载链接](https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/inference/NextViT_base_384_ssld_infer.tar) |
| NextViT_large_384_ssld | 0.8654 | 0.9814 | - | - | 31.53 | 57.95 | [下载链接](https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/NextViT_large_384_ssld_pretrained.pdparams) | [下载链接](https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/inference/NextViT_large_384_ssld_infer.tar) |
<a name="Transformer_lite"></a>
### 4.2 轻量级模型
......@@ -813,3 +834,5 @@ TRANSFORMERS FOR IMAGE RECOGNITION AT SCALE.
<a name="ref45">[45]</a>Robert J. Wang, Xiang Li, Charles X. Ling. Pelee: A Real-Time Object Detection System on Mobile Devices
<a name="ref46">[46]</a>Chien-Yao Wang, Hong-Yuan Mark Liao, I-Hau Yeh, Yueh-Hua Wu, Ping-Yang Chen, Jun-Wei Hsieh. CSPNet: A New Backbone that can Enhance Learning Capability of CNN
<a name="ref47">[46]</a>Jiashi Li, Xin Xia, Wei Li, Huixia Li, Xing Wang, Xuefeng Xiao, Rui Wang, Min Zheng, Xin Pan. Next-ViT: Next Generation Vision Transformer for Efficient Deployment in Realistic Industrial Scenarios.
......@@ -69,6 +69,7 @@ from .model_zoo.repvgg import RepVGG_A0, RepVGG_A1, RepVGG_A2, RepVGG_B0, RepVGG
from .model_zoo.van import VAN_B0
from .model_zoo.peleenet import PeleeNet
from .model_zoo.convnext import ConvNeXt_tiny
from .model_zoo.nextvit import NextViT_small_224, NextViT_base_224, NextViT_large_224, NextViT_small_384, NextViT_base_384, NextViT_large_384
from .model_zoo.cae import cae_base_patch16_224, cae_large_patch16_224
from .variant_models.resnet_variant import ResNet50_last_stage_stride1
......
# copyright (c) 2022 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Code was based on https://github.com/bytedance/Next-ViT/blob/main/classification/nextvit.py
# reference: https://arxiv.org/abs/2207.05501
from functools import partial
import paddle
from paddle import nn
from paddle.nn.initializer import TruncatedNormal, Constant, Normal
from .vision_transformer import trunc_normal_, zeros_, ones_, to_2tuple, DropPath, Identity
from ....utils.save_load import load_dygraph_pretrain, load_dygraph_pretrain_from_url
MODEL_URLS = {
"NextViT_small_224":
"https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/NextViT_small_224_pretrained.pdparams",
"NextViT_base_224":
"https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/NextViT_base_224_pretrained.pdparams",
"NextViT_large_224":
"https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/NextViT_large_224_pretrained.pdparams",
"NextViT_small_384":
"https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/NextViT_small_384_pretrained.pdparams",
"NextViT_base_384":
"https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/NextViT_base_384_pretrained.pdparams",
"NextViT_large_384":
"https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/NextViT_large_384_pretrained.pdparams",
}
__all__ = list(MODEL_URLS.keys())
NORM_EPS = 1e-5
def rearrange(x, pattern, **axes_lengths):
if 'b (h w) c -> b c h w' == pattern:
b, n, c = x.shape
h = axes_lengths.pop('h', -1)
w = axes_lengths.pop('w', -1)
h = h if w == -1 else n // w
w = w if h == -1 else n // h
return x.transpose([0, 2, 1]).reshape([b, c, h, w])
if 'b c h w -> b (h w) c' == pattern:
b, c, h, w = x.shape
return x.reshape([b, c, h * w]).transpose([0, 2, 1])
if 'b t (h d) -> b h t d' == pattern:
b, t, h_d = x.shape
h = axes_lengths['h']
return x.reshape([b, t, h, h_d // h]).transpose([0, 2, 1, 3])
if 'b h t d -> b t (h d)' == pattern:
b, h, t, d = x.shape
return x.transpose([0, 2, 1, 3]).reshape([b, t, h * d])
raise NotImplementedError(
"Rearrangement '{}' has not been implemented.".format(pattern))
def merge_pre_bn(layer, pre_bn_1, pre_bn_2=None):
""" Merge pre BN to reduce inference runtime.
"""
weight = layer.weight
if isinstance(layer, nn.Linear):
weight = weight.transpose([1, 0])
bias = layer.bias
if pre_bn_2 is None:
scale_invstd = (pre_bn_1._variance + pre_bn_1._epsilon).pow(-0.5)
extra_weight = scale_invstd * pre_bn_1.weight
extra_bias = pre_bn_1.bias - pre_bn_1.weight * pre_bn_1._mean * scale_invstd
else:
scale_invstd_1 = (pre_bn_1._variance + pre_bn_1._epsilon).pow(-0.5)
scale_invstd_2 = (pre_bn_2._variance + pre_bn_2._epsilon).pow(-0.5)
extra_weight = scale_invstd_1 * pre_bn_1.weight * scale_invstd_2 * pre_bn_2.weight
extra_bias = scale_invstd_2 * pre_bn_2.weight * (
pre_bn_1.bias - pre_bn_1.weight * pre_bn_1._mean * scale_invstd_1 -
pre_bn_2._mean) + pre_bn_2.bias
if isinstance(layer, nn.Linear):
extra_bias = weight @extra_bias
weight = weight.multiply(
extra_weight.reshape([1, weight.shape[1]]).expand_as(weight))
weight = weight.transpose([1, 0])
elif isinstance(layer, nn.Conv2D):
assert weight.shape[2] == 1 and weight.shape[3] == 1
weight = weight.reshape([weight.shape[0], weight.shape[1]])
extra_bias = weight @extra_bias
weight = weight.multiply(
extra_weight.reshape([1, weight.shape[1]]).expand_as(weight))
weight = weight.reshape([weight.shape[0], weight.shape[1], 1, 1])
bias = bias.add(extra_bias)
layer.weight.set_value(weight)
layer.bias.set_value(bias)
def _make_divisible(v, divisor, min_value=None):
if min_value is None:
min_value = divisor
new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
# Make sure that round down does not go down by more than 10%.
if new_v < 0.9 * v:
new_v += divisor
return new_v
class ConvBNReLU(nn.Layer):
def __init__(self,
in_channels,
out_channels,
kernel_size,
stride,
groups=1):
super(ConvBNReLU, self).__init__()
self.conv = nn.Conv2D(
in_channels,
out_channels,
kernel_size=kernel_size,
stride=stride,
padding=1,
groups=groups,
bias_attr=False)
self.norm = nn.BatchNorm2D(out_channels, epsilon=NORM_EPS)
self.act = nn.ReLU()
def forward(self, x):
x = self.conv(x)
x = self.norm(x)
x = self.act(x)
return x
class PatchEmbed(nn.Layer):
def __init__(self, in_channels, out_channels, stride=1):
super(PatchEmbed, self).__init__()
norm_layer = partial(nn.BatchNorm2D, epsilon=NORM_EPS)
if stride == 2:
self.avgpool = nn.AvgPool2D((2, 2), stride=2, ceil_mode=True)
self.conv = nn.Conv2D(
in_channels,
out_channels,
kernel_size=1,
stride=1,
bias_attr=False)
self.norm = norm_layer(out_channels)
elif in_channels != out_channels:
self.avgpool = nn.Identity()
self.conv = nn.Conv2D(
in_channels,
out_channels,
kernel_size=1,
stride=1,
bias_attr=False)
self.norm = norm_layer(out_channels)
else:
self.avgpool = nn.Identity()
self.conv = nn.Identity()
self.norm = nn.Identity()
def forward(self, x):
return self.norm(self.conv(self.avgpool(x)))
class MHCA(nn.Layer):
"""
Multi-Head Convolutional Attention
"""
def __init__(self, out_channels, head_dim):
super(MHCA, self).__init__()
norm_layer = partial(nn.BatchNorm2D, epsilon=NORM_EPS)
self.group_conv3x3 = nn.Conv2D(
out_channels,
out_channels,
kernel_size=3,
stride=1,
padding=1,
groups=out_channels // head_dim,
bias_attr=False)
self.norm = norm_layer(out_channels)
self.act = nn.ReLU()
self.projection = nn.Conv2D(
out_channels, out_channels, kernel_size=1, bias_attr=False)
def forward(self, x):
out = self.group_conv3x3(x)
out = self.norm(out)
out = self.act(out)
out = self.projection(out)
return out
class Mlp(nn.Layer):
def __init__(self,
in_features,
out_features=None,
mlp_ratio=None,
drop=0.,
bias=True):
super().__init__()
out_features = out_features or in_features
hidden_dim = _make_divisible(in_features * mlp_ratio, 32)
self.conv1 = nn.Conv2D(
in_features,
hidden_dim,
kernel_size=1,
bias_attr=None if bias == True else False)
self.act = nn.ReLU()
self.conv2 = nn.Conv2D(
hidden_dim,
out_features,
kernel_size=1,
bias_attr=None if bias == True else False)
self.drop = nn.Dropout(drop)
def merge_bn(self, pre_norm):
merge_pre_bn(self.conv1, pre_norm)
self.is_bn_merged = True
def forward(self, x):
x = self.conv1(x)
x = self.act(x)
x = self.drop(x)
x = self.conv2(x)
x = self.drop(x)
return x
class NCB(nn.Layer):
"""
Next Convolution Block
"""
def __init__(self,
in_channels,
out_channels,
stride=1,
path_dropout=0.0,
drop=0.0,
head_dim=32,
mlp_ratio=3):
super(NCB, self).__init__()
self.in_channels = in_channels
self.out_channels = out_channels
norm_layer = partial(nn.BatchNorm2D, epsilon=NORM_EPS)
assert out_channels % head_dim == 0
self.patch_embed = PatchEmbed(in_channels, out_channels, stride)
self.mhca = MHCA(out_channels, head_dim)
self.attention_path_dropout = DropPath(path_dropout)
self.norm = norm_layer(out_channels)
self.mlp = Mlp(out_channels, mlp_ratio=mlp_ratio, drop=drop, bias=True)
self.mlp_path_dropout = DropPath(path_dropout)
self.is_bn_merged = False
def merge_bn(self):
if not self.is_bn_merged:
self.mlp.merge_bn(self.norm)
self.is_bn_merged = True
def forward(self, x):
x = self.patch_embed(x)
x = x + self.attention_path_dropout(self.mhca(x))
if not self.is_bn_merged:
out = self.norm(x)
else:
out = x
x = x + self.mlp_path_dropout(self.mlp(out))
return x
class E_MHSA(nn.Layer):
"""
Efficient Multi-Head Self Attention
"""
def __init__(self,
dim,
out_dim=None,
head_dim=32,
qkv_bias=True,
qk_scale=None,
attn_drop=0,
proj_drop=0.,
sr_ratio=1):
super().__init__()
self.dim = dim
self.out_dim = out_dim if out_dim is not None else dim
self.num_heads = self.dim // head_dim
self.scale = qk_scale or head_dim**-0.5
self.q = nn.Linear(dim, self.dim, bias_attr=qkv_bias)
self.k = nn.Linear(dim, self.dim, bias_attr=qkv_bias)
self.v = nn.Linear(dim, self.dim, bias_attr=qkv_bias)
self.proj = nn.Linear(self.dim, self.out_dim)
self.attn_drop = nn.Dropout(attn_drop)
self.proj_drop = nn.Dropout(proj_drop)
self.sr_ratio = sr_ratio
self.N_ratio = sr_ratio**2
if sr_ratio > 1:
self.sr = nn.AvgPool1D(
kernel_size=self.N_ratio, stride=self.N_ratio)
self.norm = nn.BatchNorm1D(dim, epsilon=NORM_EPS)
self.is_bn_merged = False
def merge_bn(self, pre_bn):
merge_pre_bn(self.q, pre_bn)
if self.sr_ratio > 1:
merge_pre_bn(self.k, pre_bn, self.norm)
merge_pre_bn(self.v, pre_bn, self.norm)
else:
merge_pre_bn(self.k, pre_bn)
merge_pre_bn(self.v, pre_bn)
self.is_bn_merged = True
def forward(self, x):
B, N, C = x.shape
q = self.q(x)
q = q.reshape(
[B, N, self.num_heads, int(C // self.num_heads)]).transpose(
[0, 2, 1, 3])
if self.sr_ratio > 1:
x_ = x.transpose([0, 2, 1])
x_ = self.sr(x_)
if not self.is_bn_merged:
x_ = self.norm(x_)
x_ = x_.transpose([0, 2, 1])
k = self.k(x_)
k = k.reshape(
[B, k.shape[1], self.num_heads, int(C // self.num_heads)
]).transpose([0, 2, 3, 1])
v = self.v(x_)
v = v.reshape(
[B, v.shape[1], self.num_heads, int(C // self.num_heads)
]).transpose([0, 2, 1, 3])
else:
k = self.k(x)
k = k.reshape(
[B, k.shape[1], self.num_heads, int(C // self.num_heads)
]).transpose([0, 2, 3, 1])
v = self.v(x)
v = v.reshape(
[B, v.shape[1], self.num_heads, int(C // self.num_heads)
]).transpose([0, 2, 1, 3])
attn = (q @k) * self.scale
attn = nn.functional.softmax(attn, axis=-1)
attn = self.attn_drop(attn)
x = (attn @v).transpose([0, 2, 1, 3]).reshape([B, N, C])
x = self.proj(x)
x = self.proj_drop(x)
return x
class NTB(nn.Layer):
"""
Next Transformer Block
"""
def __init__(
self,
in_channels,
out_channels,
path_dropout,
stride=1,
sr_ratio=1,
mlp_ratio=2,
head_dim=32,
mix_block_ratio=0.75,
attn_drop=0.0,
drop=0.0, ):
super(NTB, self).__init__()
self.in_channels = in_channels
self.out_channels = out_channels
self.mix_block_ratio = mix_block_ratio
norm_func = partial(nn.BatchNorm2D, epsilon=NORM_EPS)
self.mhsa_out_channels = _make_divisible(
int(out_channels * mix_block_ratio), 32)
self.mhca_out_channels = out_channels - self.mhsa_out_channels
self.patch_embed = PatchEmbed(in_channels, self.mhsa_out_channels,
stride)
self.norm1 = norm_func(self.mhsa_out_channels)
self.e_mhsa = E_MHSA(
self.mhsa_out_channels,
head_dim=head_dim,
sr_ratio=sr_ratio,
attn_drop=attn_drop,
proj_drop=drop)
self.mhsa_path_dropout = DropPath(path_dropout * mix_block_ratio)
self.projection = PatchEmbed(
self.mhsa_out_channels, self.mhca_out_channels, stride=1)
self.mhca = MHCA(self.mhca_out_channels, head_dim=head_dim)
self.mhca_path_dropout = DropPath(path_dropout * (1 - mix_block_ratio))
self.norm2 = norm_func(out_channels)
self.mlp = Mlp(out_channels, mlp_ratio=mlp_ratio, drop=drop)
self.mlp_path_dropout = DropPath(path_dropout)
self.is_bn_merged = False
def merge_bn(self):
if not self.is_bn_merged:
self.e_mhsa.merge_bn(self.norm1)
self.mlp.merge_bn(self.norm2)
self.is_bn_merged = True
def forward(self, x):
x = self.patch_embed(x)
B, C, H, W = x.shape
if not self.is_bn_merged:
out = self.norm1(x)
else:
out = x
out = rearrange(out, "b c h w -> b (h w) c") # b n c
out = self.e_mhsa(out)
out = self.mhsa_path_dropout(out)
x = x + rearrange(out, "b (h w) c -> b c h w", h=H)
out = self.projection(x)
out = out + self.mhca_path_dropout(self.mhca(out))
x = paddle.concat([x, out], axis=1)
if not self.is_bn_merged:
out = self.norm2(x)
else:
out = x
x = x + self.mlp_path_dropout(self.mlp(out))
return x
class NextViT(nn.Layer):
def __init__(self,
stem_chs,
depths,
path_dropout,
attn_drop=0,
drop=0,
class_num=1000,
strides=[1, 2, 2, 2],
sr_ratios=[8, 4, 2, 1],
head_dim=32,
mix_block_ratio=0.75):
super(NextViT, self).__init__()
self.stage_out_channels = [
[96] * (depths[0]), [192] * (depths[1] - 1) + [256],
[384, 384, 384, 384, 512] * (depths[2] // 5),
[768] * (depths[3] - 1) + [1024]
]
# Next Hybrid Strategy
self.stage_block_types = [[NCB] * depths[0],
[NCB] * (depths[1] - 1) + [NTB],
[NCB, NCB, NCB, NCB, NTB] * (depths[2] // 5),
[NCB] * (depths[3] - 1) + [NTB]]
self.stem = nn.Sequential(
ConvBNReLU(
3, stem_chs[0], kernel_size=3, stride=2),
ConvBNReLU(
stem_chs[0], stem_chs[1], kernel_size=3, stride=1),
ConvBNReLU(
stem_chs[1], stem_chs[2], kernel_size=3, stride=1),
ConvBNReLU(
stem_chs[2], stem_chs[2], kernel_size=3, stride=2), )
input_channel = stem_chs[-1]
features = []
idx = 0
dpr = [
x.item() for x in paddle.linspace(0, path_dropout, sum(depths))
] # stochastic depth decay rule
for stage_id in range(len(depths)):
numrepeat = depths[stage_id]
output_channels = self.stage_out_channels[stage_id]
block_types = self.stage_block_types[stage_id]
for block_id in range(numrepeat):
if strides[stage_id] == 2 and block_id == 0:
stride = 2
else:
stride = 1
output_channel = output_channels[block_id]
block_type = block_types[block_id]
if block_type is NCB:
layer = NCB(input_channel,
output_channel,
stride=stride,
path_dropout=dpr[idx + block_id],
drop=drop,
head_dim=head_dim)
features.append(layer)
elif block_type is NTB:
layer = NTB(input_channel,
output_channel,
path_dropout=dpr[idx + block_id],
stride=stride,
sr_ratio=sr_ratios[stage_id],
head_dim=head_dim,
mix_block_ratio=mix_block_ratio,
attn_drop=attn_drop,
drop=drop)
features.append(layer)
input_channel = output_channel
idx += numrepeat
self.features = nn.Sequential(*features)
self.norm = nn.BatchNorm2D(output_channel, epsilon=NORM_EPS)
self.avgpool = nn.AdaptiveAvgPool2D((1, 1))
self.proj_head = nn.Sequential(nn.Linear(output_channel, class_num), )
self.stage_out_idx = [
sum(depths[:idx + 1]) - 1 for idx in range(len(depths))
]
self._initialize_weights()
def merge_bn(self):
self.eval()
for idx, layer in self.named_sublayers():
if isinstance(layer, NCB) or isinstance(layer, NTB):
layer.merge_bn()
def _initialize_weights(self):
for n, m in self.named_sublayers():
if isinstance(m, (nn.BatchNorm2D, nn.GroupNorm, nn.LayerNorm,
nn.BatchNorm1D)):
ones_(m.weight)
zeros_(m.bias)
elif isinstance(m, nn.Linear):
trunc_normal_(m.weight)
if hasattr(m, 'bias') and m.bias is not None:
zeros_(m.bias)
elif isinstance(m, nn.Conv2D):
trunc_normal_(m.weight)
if hasattr(m, 'bias') and m.bias is not None:
zeros_(m.bias)
def forward(self, x):
x = self.stem(x)
for layer in self.features:
x = layer(x)
x = self.norm(x)
x = self.avgpool(x)
x = paddle.flatten(x, 1)
x = self.proj_head(x)
return x
def _load_pretrained(pretrained, model, model_url, use_ssld=False):
if pretrained is False:
pass
elif pretrained is True:
load_dygraph_pretrain_from_url(model, model_url, use_ssld=use_ssld)
elif isinstance(pretrained, str):
load_dygraph_pretrain(model, pretrained)
else:
raise RuntimeError(
"pretrained type is not available. Please use `string` or `boolean` type."
)
def NextViT_small_224(pretrained=False, use_ssld=False, **kwargs):
model = NextViT(
stem_chs=[64, 32, 64],
depths=[3, 4, 10, 3],
path_dropout=0.1,
**kwargs)
_load_pretrained(
pretrained, model, MODEL_URLS["NextViT_small_224"], use_ssld=use_ssld)
return model
def NextViT_base_224(pretrained=False, use_ssld=False, **kwargs):
model = NextViT(
stem_chs=[64, 32, 64],
depths=[3, 4, 20, 3],
path_dropout=0.2,
**kwargs)
_load_pretrained(
pretrained, model, MODEL_URLS["NextViT_base_224"], use_ssld=use_ssld)
return model
def NextViT_large_224(pretrained=False, use_ssld=False, **kwargs):
model = NextViT(
stem_chs=[64, 32, 64],
depths=[3, 4, 30, 3],
path_dropout=0.2,
**kwargs)
_load_pretrained(
pretrained, model, MODEL_URLS["NextViT_large_224"], use_ssld=use_ssld)
return model
def NextViT_small_384(pretrained=False, use_ssld=False, **kwargs):
model = NextViT(
stem_chs=[64, 32, 64],
depths=[3, 4, 10, 3],
path_dropout=0.1,
**kwargs)
_load_pretrained(
pretrained, model, MODEL_URLS["NextViT_small_384"], use_ssld=use_ssld)
return model
def NextViT_base_384(pretrained=False, use_ssld=False, **kwargs):
model = NextViT(
stem_chs=[64, 32, 64],
depths=[3, 4, 20, 3],
path_dropout=0.2,
**kwargs)
_load_pretrained(
pretrained, model, MODEL_URLS["NextViT_base_384"], use_ssld=use_ssld)
return model
def NextViT_large_384(pretrained=False, use_ssld=False, **kwargs):
model = NextViT(
stem_chs=[64, 32, 64],
depths=[3, 4, 30, 3],
path_dropout=0.2,
**kwargs)
_load_pretrained(
pretrained, model, MODEL_URLS["NextViT_large_384"], use_ssld=use_ssld)
return model
# global configs
Global:
checkpoints: null
pretrained_model: null
output_dir: ./output/
device: gpu
save_interval: 1
eval_during_train: True
eval_interval: 1
epochs: 300
print_batch_step: 10
use_visualdl: False
# used for static mode and model export
image_shape: [3, 224, 224]
save_inference_dir: ./inference
# training model under @to_static
to_static: False
# mixed precision training
AMP:
scale_loss: 128.0
use_dynamic_loss_scaling: True
# O1: mixed fp16
level: O1
# model architecture
Arch:
name: NextViT_base_224
class_num: 1000
# loss function config for traing/eval process
Loss:
Train:
- CELoss:
weight: 1.0
epsilon: 0.1
Eval:
- CELoss:
weight: 1.0
Optimizer:
name: AdamW
beta1: 0.9
beta2: 0.999
epsilon: 1e-8
weight_decay: 0.1
no_weight_decay_name: .bias norm
one_dim_param_no_weight_decay: True
lr:
# for 8 cards
name: Cosine
learning_rate: 1e-3
eta_min: 1e-5
warmup_epoch: 20
warmup_start_lr: 1e-6
# data loader for train and eval
DataLoader:
Train:
dataset:
name: ImageNetDataset
image_root: ./dataset/ILSVRC2012/
cls_label_path: ./dataset/ILSVRC2012/train_list.txt
transform_ops:
- DecodeImage:
to_rgb: True
channel_first: False
- RandCropImage:
size: 224
interpolation: bicubic
backend: pil
- RandFlipImage:
flip_code: 1
- TimmAutoAugment:
config_str: rand-m9-mstd0.5-inc1
interpolation: bicubic
img_size: 224
- NormalizeImage:
scale: 1.0/255.0
mean: [0.485, 0.456, 0.406]
std: [0.229, 0.224, 0.225]
order: ''
- RandomErasing:
EPSILON: 0.25
sl: 0.02
sh: 1.0/3.0
r1: 0.3
attempt: 10
use_log_aspect: True
mode: pixel
batch_transform_ops:
- OpSampler:
MixupOperator:
alpha: 0.8
prob: 0.5
CutmixOperator:
alpha: 1.0
prob: 0.5
sampler:
name: DistributedBatchSampler
batch_size: 128
drop_last: False
shuffle: True
loader:
num_workers: 8
use_shared_memory: True
Eval:
dataset:
name: ImageNetDataset
image_root: ./dataset/ILSVRC2012/
cls_label_path: ./dataset/ILSVRC2012/val_list.txt
transform_ops:
- DecodeImage:
to_rgb: True
channel_first: False
- ResizeImage:
resize_short: 256
interpolation: bicubic
backend: pil
- CropImage:
size: 224
- NormalizeImage:
scale: 1.0/255.0
mean: [0.485, 0.456, 0.406]
std: [0.229, 0.224, 0.225]
order: ''
sampler:
name: DistributedBatchSampler
batch_size: 128
drop_last: False
shuffle: False
loader:
num_workers: 4
use_shared_memory: True
Infer:
infer_imgs: ./deploy/images/ImageNet/ILSVRC2012_val_00000010.jpeg
batch_size: 10
transforms:
- DecodeImage:
to_rgb: True
channel_first: False
- ResizeImage:
resize_short: 256
interpolation: bicubic
backend: pil
- CropImage:
size: 224
- NormalizeImage:
scale: 1.0/255.0
mean: [0.485, 0.456, 0.406]
std: [0.229, 0.224, 0.225]
order: ''
- ToCHWImage:
PostProcess:
name: Topk
topk: 5
class_id_map_file: ppcls/utils/imagenet1k_label_list.txt
Metric:
Eval:
- TopkAcc:
topk: [1, 5]
# global configs
Global:
checkpoints: null
pretrained_model: null
output_dir: ./output/
device: gpu
save_interval: 1
eval_during_train: True
eval_interval: 1
epochs: 300
print_batch_step: 10
use_visualdl: False
# used for static mode and model export
image_shape: [3, 224, 224]
save_inference_dir: ./inference
# training model under @to_static
to_static: False
# mixed precision training
AMP:
scale_loss: 128.0
use_dynamic_loss_scaling: True
# O1: mixed fp16
level: O1
# model architecture
Arch:
name: NextViT_base_384
class_num: 1000
# loss function config for traing/eval process
Loss:
Train:
- CELoss:
weight: 1.0
epsilon: 0.1
Eval:
- CELoss:
weight: 1.0
Optimizer:
name: AdamW
beta1: 0.9
beta2: 0.999
epsilon: 1e-8
weight_decay: 0.1
no_weight_decay_name: .bias norm
one_dim_param_no_weight_decay: True
lr:
# for 8 cards
name: Cosine
learning_rate: 1e-3
eta_min: 1e-5
warmup_epoch: 20
warmup_start_lr: 1e-6
# data loader for train and eval
DataLoader:
Train:
dataset:
name: ImageNetDataset
image_root: ./dataset/ILSVRC2012/
cls_label_path: ./dataset/ILSVRC2012/train_list.txt
transform_ops:
- DecodeImage:
to_rgb: True
channel_first: False
- RandCropImage:
size: 384
interpolation: bicubic
backend: pil
- RandFlipImage:
flip_code: 1
- TimmAutoAugment:
config_str: rand-m9-mstd0.5-inc1
interpolation: bicubic
img_size: 384
- NormalizeImage:
scale: 1.0/255.0
mean: [0.485, 0.456, 0.406]
std: [0.229, 0.224, 0.225]
order: ''
- RandomErasing:
EPSILON: 0.25
sl: 0.02
sh: 1.0/3.0
r1: 0.3
attempt: 10
use_log_aspect: True
mode: pixel
batch_transform_ops:
- OpSampler:
MixupOperator:
alpha: 0.8
prob: 0.5
CutmixOperator:
alpha: 1.0
prob: 0.5
sampler:
name: DistributedBatchSampler
batch_size: 128
drop_last: False
shuffle: True
loader:
num_workers: 8
use_shared_memory: True
Eval:
dataset:
name: ImageNetDataset
image_root: ./dataset/ILSVRC2012/
cls_label_path: ./dataset/ILSVRC2012/val_list.txt
transform_ops:
- DecodeImage:
to_rgb: True
channel_first: False
- ResizeImage:
resize_short: 384
interpolation: bicubic
backend: pil
- CropImage:
size: 384
- NormalizeImage:
scale: 1.0/255.0
mean: [0.485, 0.456, 0.406]
std: [0.229, 0.224, 0.225]
order: ''
sampler:
name: DistributedBatchSampler
batch_size: 128
drop_last: False
shuffle: False
loader:
num_workers: 4
use_shared_memory: True
Infer:
infer_imgs: ./deploy/images/ImageNet/ILSVRC2012_val_00000010.jpeg
batch_size: 10
transforms:
- DecodeImage:
to_rgb: True
channel_first: False
- ResizeImage:
resize_short: 384
interpolation: bicubic
backend: pil
- CropImage:
size: 384
- NormalizeImage:
scale: 1.0/255.0
mean: [0.485, 0.456, 0.406]
std: [0.229, 0.224, 0.225]
order: ''
- ToCHWImage:
PostProcess:
name: Topk
topk: 5
class_id_map_file: ppcls/utils/imagenet1k_label_list.txt
Metric:
Eval:
- TopkAcc:
topk: [1, 5]
# global configs
Global:
checkpoints: null
pretrained_model: null
output_dir: ./output/
device: gpu
save_interval: 1
eval_during_train: True
eval_interval: 1
epochs: 300
print_batch_step: 10
use_visualdl: False
# used for static mode and model export
image_shape: [3, 224, 224]
save_inference_dir: ./inference
# training model under @to_static
to_static: False
# mixed precision training
AMP:
scale_loss: 128.0
use_dynamic_loss_scaling: True
# O1: mixed fp16
level: O1
# model architecture
Arch:
name: NextViT_large_224
class_num: 1000
# loss function config for traing/eval process
Loss:
Train:
- CELoss:
weight: 1.0
epsilon: 0.1
Eval:
- CELoss:
weight: 1.0
Optimizer:
name: AdamW
beta1: 0.9
beta2: 0.999
epsilon: 1e-8
weight_decay: 0.1
no_weight_decay_name: .bias norm
one_dim_param_no_weight_decay: True
lr:
# for 8 cards
name: Cosine
learning_rate: 1e-3
eta_min: 1e-5
warmup_epoch: 20
warmup_start_lr: 1e-6
# data loader for train and eval
DataLoader:
Train:
dataset:
name: ImageNetDataset
image_root: ./dataset/ILSVRC2012/
cls_label_path: ./dataset/ILSVRC2012/train_list.txt
transform_ops:
- DecodeImage:
to_rgb: True
channel_first: False
- RandCropImage:
size: 224
interpolation: bicubic
backend: pil
- RandFlipImage:
flip_code: 1
- TimmAutoAugment:
config_str: rand-m9-mstd0.5-inc1
interpolation: bicubic
img_size: 224
- NormalizeImage:
scale: 1.0/255.0
mean: [0.485, 0.456, 0.406]
std: [0.229, 0.224, 0.225]
order: ''
- RandomErasing:
EPSILON: 0.25
sl: 0.02
sh: 1.0/3.0
r1: 0.3
attempt: 10
use_log_aspect: True
mode: pixel
batch_transform_ops:
- OpSampler:
MixupOperator:
alpha: 0.8
prob: 0.5
CutmixOperator:
alpha: 1.0
prob: 0.5
sampler:
name: DistributedBatchSampler
batch_size: 128
drop_last: False
shuffle: True
loader:
num_workers: 8
use_shared_memory: True
Eval:
dataset:
name: ImageNetDataset
image_root: ./dataset/ILSVRC2012/
cls_label_path: ./dataset/ILSVRC2012/val_list.txt
transform_ops:
- DecodeImage:
to_rgb: True
channel_first: False
- ResizeImage:
resize_short: 256
interpolation: bicubic
backend: pil
- CropImage:
size: 224
- NormalizeImage:
scale: 1.0/255.0
mean: [0.485, 0.456, 0.406]
std: [0.229, 0.224, 0.225]
order: ''
sampler:
name: DistributedBatchSampler
batch_size: 128
drop_last: False
shuffle: False
loader:
num_workers: 4
use_shared_memory: True
Infer:
infer_imgs: ./deploy/images/ImageNet/ILSVRC2012_val_00000010.jpeg
batch_size: 10
transforms:
- DecodeImage:
to_rgb: True
channel_first: False
- ResizeImage:
resize_short: 256
interpolation: bicubic
backend: pil
- CropImage:
size: 224
- NormalizeImage:
scale: 1.0/255.0
mean: [0.485, 0.456, 0.406]
std: [0.229, 0.224, 0.225]
order: ''
- ToCHWImage:
PostProcess:
name: Topk
topk: 5
class_id_map_file: ppcls/utils/imagenet1k_label_list.txt
Metric:
Eval:
- TopkAcc:
topk: [1, 5]
# global configs
Global:
checkpoints: null
pretrained_model: null
output_dir: ./output/
device: gpu
save_interval: 1
eval_during_train: True
eval_interval: 1
epochs: 300
print_batch_step: 10
use_visualdl: False
# used for static mode and model export
image_shape: [3, 224, 224]
save_inference_dir: ./inference
# training model under @to_static
to_static: False
# mixed precision training
AMP:
scale_loss: 128.0
use_dynamic_loss_scaling: True
# O1: mixed fp16
level: O1
# model architecture
Arch:
name: NextViT_large_384
class_num: 1000
# loss function config for traing/eval process
Loss:
Train:
- CELoss:
weight: 1.0
epsilon: 0.1
Eval:
- CELoss:
weight: 1.0
Optimizer:
name: AdamW
beta1: 0.9
beta2: 0.999
epsilon: 1e-8
weight_decay: 0.1
no_weight_decay_name: .bias norm
one_dim_param_no_weight_decay: True
lr:
# for 8 cards
name: Cosine
learning_rate: 1e-3
eta_min: 1e-5
warmup_epoch: 20
warmup_start_lr: 1e-6
# data loader for train and eval
DataLoader:
Train:
dataset:
name: ImageNetDataset
image_root: ./dataset/ILSVRC2012/
cls_label_path: ./dataset/ILSVRC2012/train_list.txt
transform_ops:
- DecodeImage:
to_rgb: True
channel_first: False
- RandCropImage:
size: 384
interpolation: bicubic
backend: pil
- RandFlipImage:
flip_code: 1
- TimmAutoAugment:
config_str: rand-m9-mstd0.5-inc1
interpolation: bicubic
img_size: 384
- NormalizeImage:
scale: 1.0/255.0
mean: [0.485, 0.456, 0.406]
std: [0.229, 0.224, 0.225]
order: ''
- RandomErasing:
EPSILON: 0.25
sl: 0.02
sh: 1.0/3.0
r1: 0.3
attempt: 10
use_log_aspect: True
mode: pixel
batch_transform_ops:
- OpSampler:
MixupOperator:
alpha: 0.8
prob: 0.5
CutmixOperator:
alpha: 1.0
prob: 0.5
sampler:
name: DistributedBatchSampler
batch_size: 128
drop_last: False
shuffle: True
loader:
num_workers: 8
use_shared_memory: True
Eval:
dataset:
name: ImageNetDataset
image_root: ./dataset/ILSVRC2012/
cls_label_path: ./dataset/ILSVRC2012/val_list.txt
transform_ops:
- DecodeImage:
to_rgb: True
channel_first: False
- ResizeImage:
resize_short: 384
interpolation: bicubic
backend: pil
- CropImage:
size: 384
- NormalizeImage:
scale: 1.0/255.0
mean: [0.485, 0.456, 0.406]
std: [0.229, 0.224, 0.225]
order: ''
sampler:
name: DistributedBatchSampler
batch_size: 128
drop_last: False
shuffle: False
loader:
num_workers: 4
use_shared_memory: True
Infer:
infer_imgs: ./deploy/images/ImageNet/ILSVRC2012_val_00000010.jpeg
batch_size: 10
transforms:
- DecodeImage:
to_rgb: True
channel_first: False
- ResizeImage:
resize_short: 384
interpolation: bicubic
backend: pil
- CropImage:
size: 384
- NormalizeImage:
scale: 1.0/255.0
mean: [0.485, 0.456, 0.406]
std: [0.229, 0.224, 0.225]
order: ''
- ToCHWImage:
PostProcess:
name: Topk
topk: 5
class_id_map_file: ppcls/utils/imagenet1k_label_list.txt
Metric:
Eval:
- TopkAcc:
topk: [1, 5]
# global configs
Global:
checkpoints: null
pretrained_model: null
output_dir: ./output/
device: gpu
save_interval: 1
eval_during_train: True
eval_interval: 1
epochs: 300
print_batch_step: 10
use_visualdl: False
# used for static mode and model export
image_shape: [3, 224, 224]
save_inference_dir: ./inference
# training model under @to_static
to_static: False
# mixed precision training
AMP:
scale_loss: 128.0
use_dynamic_loss_scaling: True
# O1: mixed fp16
level: O1
# model architecture
Arch:
name: NextViT_small_224
class_num: 1000
# loss function config for traing/eval process
Loss:
Train:
- CELoss:
weight: 1.0
epsilon: 0.1
Eval:
- CELoss:
weight: 1.0
Optimizer:
name: AdamW
beta1: 0.9
beta2: 0.999
epsilon: 1e-8
weight_decay: 0.1
no_weight_decay_name: .bias norm
one_dim_param_no_weight_decay: True
lr:
# for 8 cards
name: Cosine
learning_rate: 1e-3
eta_min: 1e-5
warmup_epoch: 20
warmup_start_lr: 1e-6
# data loader for train and eval
DataLoader:
Train:
dataset:
name: ImageNetDataset
image_root: ./dataset/ILSVRC2012/
cls_label_path: ./dataset/ILSVRC2012/train_list.txt
transform_ops:
- DecodeImage:
to_rgb: True
channel_first: False
- RandCropImage:
size: 224
interpolation: bicubic
backend: pil
- RandFlipImage:
flip_code: 1
- TimmAutoAugment:
config_str: rand-m9-mstd0.5-inc1
interpolation: bicubic
img_size: 224
- NormalizeImage:
scale: 1.0/255.0
mean: [0.485, 0.456, 0.406]
std: [0.229, 0.224, 0.225]
order: ''
- RandomErasing:
EPSILON: 0.25
sl: 0.02
sh: 1.0/3.0
r1: 0.3
attempt: 10
use_log_aspect: True
mode: pixel
batch_transform_ops:
- OpSampler:
MixupOperator:
alpha: 0.8
prob: 0.5
CutmixOperator:
alpha: 1.0
prob: 0.5
sampler:
name: DistributedBatchSampler
batch_size: 128
drop_last: False
shuffle: True
loader:
num_workers: 8
use_shared_memory: True
Eval:
dataset:
name: ImageNetDataset
image_root: ./dataset/ILSVRC2012/
cls_label_path: ./dataset/ILSVRC2012/val_list.txt
transform_ops:
- DecodeImage:
to_rgb: True
channel_first: False
- ResizeImage:
resize_short: 256
interpolation: bicubic
backend: pil
- CropImage:
size: 224
- NormalizeImage:
scale: 1.0/255.0
mean: [0.485, 0.456, 0.406]
std: [0.229, 0.224, 0.225]
order: ''
sampler:
name: DistributedBatchSampler
batch_size: 128
drop_last: False
shuffle: False
loader:
num_workers: 4
use_shared_memory: True
Infer:
infer_imgs: ./deploy/images/ImageNet/ILSVRC2012_val_00000010.jpeg
batch_size: 10
transforms:
- DecodeImage:
to_rgb: True
channel_first: False
- ResizeImage:
resize_short: 256
interpolation: bicubic
backend: pil
- CropImage:
size: 224
- NormalizeImage:
scale: 1.0/255.0
mean: [0.485, 0.456, 0.406]
std: [0.229, 0.224, 0.225]
order: ''
- ToCHWImage:
PostProcess:
name: Topk
topk: 5
class_id_map_file: ppcls/utils/imagenet1k_label_list.txt
Metric:
Eval:
- TopkAcc:
topk: [1, 5]
# global configs
Global:
checkpoints: null
pretrained_model: null
output_dir: ./output/
device: gpu
save_interval: 1
eval_during_train: True
eval_interval: 1
epochs: 300
print_batch_step: 10
use_visualdl: False
# used for static mode and model export
image_shape: [3, 224, 224]
save_inference_dir: ./inference
# training model under @to_static
to_static: False
# mixed precision training
AMP:
scale_loss: 128.0
use_dynamic_loss_scaling: True
# O1: mixed fp16
level: O1
# model architecture
Arch:
name: NextViT_small_384
class_num: 1000
# loss function config for traing/eval process
Loss:
Train:
- CELoss:
weight: 1.0
epsilon: 0.1
Eval:
- CELoss:
weight: 1.0
Optimizer:
name: AdamW
beta1: 0.9
beta2: 0.999
epsilon: 1e-8
weight_decay: 0.1
no_weight_decay_name: .bias norm
one_dim_param_no_weight_decay: True
lr:
# for 8 cards
name: Cosine
learning_rate: 1e-3
eta_min: 1e-5
warmup_epoch: 20
warmup_start_lr: 1e-6
# data loader for train and eval
DataLoader:
Train:
dataset:
name: ImageNetDataset
image_root: ./dataset/ILSVRC2012/
cls_label_path: ./dataset/ILSVRC2012/train_list.txt
transform_ops:
- DecodeImage:
to_rgb: True
channel_first: False
- RandCropImage:
size: 384
interpolation: bicubic
backend: pil
- RandFlipImage:
flip_code: 1
- TimmAutoAugment:
config_str: rand-m9-mstd0.5-inc1
interpolation: bicubic
img_size: 384
- NormalizeImage:
scale: 1.0/255.0
mean: [0.485, 0.456, 0.406]
std: [0.229, 0.224, 0.225]
order: ''
- RandomErasing:
EPSILON: 0.25
sl: 0.02
sh: 1.0/3.0
r1: 0.3
attempt: 10
use_log_aspect: True
mode: pixel
batch_transform_ops:
- OpSampler:
MixupOperator:
alpha: 0.8
prob: 0.5
CutmixOperator:
alpha: 1.0
prob: 0.5
sampler:
name: DistributedBatchSampler
batch_size: 128
drop_last: False
shuffle: True
loader:
num_workers: 8
use_shared_memory: True
Eval:
dataset:
name: ImageNetDataset
image_root: ./dataset/ILSVRC2012/
cls_label_path: ./dataset/ILSVRC2012/val_list.txt
transform_ops:
- DecodeImage:
to_rgb: True
channel_first: False
- ResizeImage:
resize_short: 384
interpolation: bicubic
backend: pil
- CropImage:
size: 384
- NormalizeImage:
scale: 1.0/255.0
mean: [0.485, 0.456, 0.406]
std: [0.229, 0.224, 0.225]
order: ''
sampler:
name: DistributedBatchSampler
batch_size: 128
drop_last: False
shuffle: False
loader:
num_workers: 4
use_shared_memory: True
Infer:
infer_imgs: ./deploy/images/ImageNet/ILSVRC2012_val_00000010.jpeg
batch_size: 10
transforms:
- DecodeImage:
to_rgb: True
channel_first: False
- ResizeImage:
resize_short: 384
interpolation: bicubic
backend: pil
- CropImage:
size: 384
- NormalizeImage:
scale: 1.0/255.0
mean: [0.485, 0.456, 0.406]
std: [0.229, 0.224, 0.225]
order: ''
- ToCHWImage:
PostProcess:
name: Topk
topk: 5
class_id_map_file: ppcls/utils/imagenet1k_label_list.txt
Metric:
Eval:
- TopkAcc:
topk: [1, 5]
===========================train_params===========================
model_name:NextViT_base_224
python:python3.7
gpu_list:0|0,1
-o Global.device:gpu
-o Global.auto_cast:null
-o Global.epochs:lite_train_lite_infer=2|whole_train_whole_infer=120
-o Global.output_dir:./output/
-o DataLoader.Train.sampler.batch_size:8
-o Global.pretrained_model:null
train_model_name:latest
train_infer_img_dir:./dataset/ILSVRC2012/val
null:null
##
trainer:norm_train
norm_train:tools/train.py -c ppcls/configs/ImageNet/NextViT/NextViT_base_224.yaml -o Global.seed=1234 -o DataLoader.Train.sampler.shuffle=False -o DataLoader.Train.loader.num_workers=0 -o DataLoader.Train.loader.use_shared_memory=False -o Global.eval_during_train=False -o Global.save_interval=2
pact_train:null
fpgm_train:null
distill_train:null
null:null
null:null
##
===========================eval_params===========================
eval:tools/eval.py -c ppcls/configs/ImageNet/NextViT/NextViT_base_224.yaml
null:null
##
===========================infer_params==========================
-o Global.save_inference_dir:./inference
-o Global.pretrained_model:
norm_export:tools/export_model.py -c ppcls/configs/ImageNet/NextViT/NextViT_base_224.yaml
quant_export:null
fpgm_export:null
distill_export:null
kl_quant:null
export2:null
pretrained_model_url:https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/NextViT_base_224_pretrained.pdparams
infer_model:../inference/
infer_export:True
infer_quant:Fasle
inference:python/predict_cls.py -c configs/inference_cls.yaml -o PreProcess.transform_ops.0.ResizeImage.resize_short=256
-o Global.use_gpu:True|False
-o Global.enable_mkldnn:False
-o Global.cpu_num_threads:1
-o Global.batch_size:1
-o Global.use_tensorrt:False
-o Global.use_fp16:False
-o Global.inference_model_dir:../inference
-o Global.infer_imgs:../dataset/ILSVRC2012/val/ILSVRC2012_val_00000001.JPEG
-o Global.save_log_path:null
-o Global.benchmark:False
null:null
null:null
===========================infer_benchmark_params==========================
random_infer_input:[{float32,[3,224,224]}]
===========================train_params===========================
model_name:NextViT_base_384
python:python3.7
gpu_list:0|0,1
-o Global.device:gpu
-o Global.auto_cast:null
-o Global.epochs:lite_train_lite_infer=2|whole_train_whole_infer=120
-o Global.output_dir:./output/
-o DataLoader.Train.sampler.batch_size:8
-o Global.pretrained_model:null
train_model_name:latest
train_infer_img_dir:./dataset/ILSVRC2012/val
null:null
##
trainer:norm_train
norm_train:tools/train.py -c ppcls/configs/ImageNet/NextViT/NextViT_base_384.yaml -o Global.seed=1234 -o DataLoader.Train.sampler.shuffle=False -o DataLoader.Train.loader.num_workers=0 -o DataLoader.Train.loader.use_shared_memory=False -o Global.eval_during_train=False -o Global.save_interval=2
pact_train:null
fpgm_train:null
distill_train:null
null:null
null:null
##
===========================eval_params===========================
eval:tools/eval.py -c ppcls/configs/ImageNet/NextViT/NextViT_base_384.yaml
null:null
##
===========================infer_params==========================
-o Global.save_inference_dir:./inference
-o Global.pretrained_model:
norm_export:tools/export_model.py -c ppcls/configs/ImageNet/NextViT/NextViT_base_384.yaml
quant_export:null
fpgm_export:null
distill_export:null
kl_quant:null
export2:null
pretrained_model_url:https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/NextViT_base_384_pretrained.pdparams
infer_model:../inference/
infer_export:True
infer_quant:Fasle
inference:python/predict_cls.py -c configs/inference_cls.yaml -o PreProcess.transform_ops.0.ResizeImage.resize_short=384 -o PreProcess.transform_ops.1.CropImage.size=384
-o Global.use_gpu:True|False
-o Global.enable_mkldnn:False
-o Global.cpu_num_threads:1
-o Global.batch_size:1
-o Global.use_tensorrt:False
-o Global.use_fp16:False
-o Global.inference_model_dir:../inference
-o Global.infer_imgs:../dataset/ILSVRC2012/val/ILSVRC2012_val_00000001.JPEG
-o Global.save_log_path:null
-o Global.benchmark:False
null:null
null:null
===========================infer_benchmark_params==========================
random_infer_input:[{float32,[3,384,384]}]
===========================train_params===========================
model_name:NextViT_large_224
python:python3.7
gpu_list:0|0,1
-o Global.device:gpu
-o Global.auto_cast:null
-o Global.epochs:lite_train_lite_infer=2|whole_train_whole_infer=120
-o Global.output_dir:./output/
-o DataLoader.Train.sampler.batch_size:8
-o Global.pretrained_model:null
train_model_name:latest
train_infer_img_dir:./dataset/ILSVRC2012/val
null:null
##
trainer:norm_train
norm_train:tools/train.py -c ppcls/configs/ImageNet/NextViT/NextViT_large_224.yaml -o Global.seed=1234 -o DataLoader.Train.sampler.shuffle=False -o DataLoader.Train.loader.num_workers=0 -o DataLoader.Train.loader.use_shared_memory=False -o Global.eval_during_train=False -o Global.save_interval=2
pact_train:null
fpgm_train:null
distill_train:null
null:null
null:null
##
===========================eval_params===========================
eval:tools/eval.py -c ppcls/configs/ImageNet/NextViT/NextViT_large_224.yaml
null:null
##
===========================infer_params==========================
-o Global.save_inference_dir:./inference
-o Global.pretrained_model:
norm_export:tools/export_model.py -c ppcls/configs/ImageNet/NextViT/NextViT_large_224.yaml
quant_export:null
fpgm_export:null
distill_export:null
kl_quant:null
export2:null
pretrained_model_url:https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/NextViT_large_224_pretrained.pdparams
infer_model:../inference/
infer_export:True
infer_quant:Fasle
inference:python/predict_cls.py -c configs/inference_cls.yaml -o PreProcess.transform_ops.0.ResizeImage.resize_short=256
-o Global.use_gpu:True|False
-o Global.enable_mkldnn:False
-o Global.cpu_num_threads:1
-o Global.batch_size:1
-o Global.use_tensorrt:False
-o Global.use_fp16:False
-o Global.inference_model_dir:../inference
-o Global.infer_imgs:../dataset/ILSVRC2012/val/ILSVRC2012_val_00000001.JPEG
-o Global.save_log_path:null
-o Global.benchmark:False
null:null
null:null
===========================infer_benchmark_params==========================
random_infer_input:[{float32,[3,224,224]}]
===========================train_params===========================
model_name:NextViT_large_384
python:python3.7
gpu_list:0|0,1
-o Global.device:gpu
-o Global.auto_cast:null
-o Global.epochs:lite_train_lite_infer=2|whole_train_whole_infer=120
-o Global.output_dir:./output/
-o DataLoader.Train.sampler.batch_size:8
-o Global.pretrained_model:null
train_model_name:latest
train_infer_img_dir:./dataset/ILSVRC2012/val
null:null
##
trainer:norm_train
norm_train:tools/train.py -c ppcls/configs/ImageNet/NextViT/NextViT_large_384.yaml -o Global.seed=1234 -o DataLoader.Train.sampler.shuffle=False -o DataLoader.Train.loader.num_workers=0 -o DataLoader.Train.loader.use_shared_memory=False -o Global.eval_during_train=False -o Global.save_interval=2
pact_train:null
fpgm_train:null
distill_train:null
null:null
null:null
##
===========================eval_params===========================
eval:tools/eval.py -c ppcls/configs/ImageNet/NextViT/NextViT_large_384.yaml
null:null
##
===========================infer_params==========================
-o Global.save_inference_dir:./inference
-o Global.pretrained_model:
norm_export:tools/export_model.py -c ppcls/configs/ImageNet/NextViT/NextViT_large_384.yaml
quant_export:null
fpgm_export:null
distill_export:null
kl_quant:null
export2:null
pretrained_model_url:https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/NextViT_large_384_pretrained.pdparams
infer_model:../inference/
infer_export:True
infer_quant:Fasle
inference:python/predict_cls.py -c configs/inference_cls.yaml -o PreProcess.transform_ops.0.ResizeImage.resize_short=384 -o PreProcess.transform_ops.1.CropImage.size=384
-o Global.use_gpu:True|False
-o Global.enable_mkldnn:False
-o Global.cpu_num_threads:1
-o Global.batch_size:1
-o Global.use_tensorrt:False
-o Global.use_fp16:False
-o Global.inference_model_dir:../inference
-o Global.infer_imgs:../dataset/ILSVRC2012/val/ILSVRC2012_val_00000001.JPEG
-o Global.save_log_path:null
-o Global.benchmark:False
null:null
null:null
===========================infer_benchmark_params==========================
random_infer_input:[{float32,[3,384,384]}]
===========================train_params===========================
model_name:NextViT_small_224
python:python3.7
gpu_list:0|0,1
-o Global.device:gpu
-o Global.auto_cast:null
-o Global.epochs:lite_train_lite_infer=2|whole_train_whole_infer=120
-o Global.output_dir:./output/
-o DataLoader.Train.sampler.batch_size:8
-o Global.pretrained_model:null
train_model_name:latest
train_infer_img_dir:./dataset/ILSVRC2012/val
null:null
##
trainer:norm_train
norm_train:tools/train.py -c ppcls/configs/ImageNet/NextViT/NextViT_small_224.yaml -o Global.seed=1234 -o DataLoader.Train.sampler.shuffle=False -o DataLoader.Train.loader.num_workers=0 -o DataLoader.Train.loader.use_shared_memory=False -o Global.eval_during_train=False -o Global.save_interval=2
pact_train:null
fpgm_train:null
distill_train:null
null:null
null:null
##
===========================eval_params===========================
eval:tools/eval.py -c ppcls/configs/ImageNet/NextViT/NextViT_small_224.yaml
null:null
##
===========================infer_params==========================
-o Global.save_inference_dir:./inference
-o Global.pretrained_model:
norm_export:tools/export_model.py -c ppcls/configs/ImageNet/NextViT/NextViT_small_224.yaml
quant_export:null
fpgm_export:null
distill_export:null
kl_quant:null
export2:null
pretrained_model_url:https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/NextViT_small_224_pretrained.pdparams
infer_model:../inference/
infer_export:True
infer_quant:Fasle
inference:python/predict_cls.py -c configs/inference_cls.yaml -o PreProcess.transform_ops.0.ResizeImage.resize_short=256
-o Global.use_gpu:True|False
-o Global.enable_mkldnn:False
-o Global.cpu_num_threads:1
-o Global.batch_size:1
-o Global.use_tensorrt:False
-o Global.use_fp16:False
-o Global.inference_model_dir:../inference
-o Global.infer_imgs:../dataset/ILSVRC2012/val/ILSVRC2012_val_00000001.JPEG
-o Global.save_log_path:null
-o Global.benchmark:False
null:null
null:null
===========================infer_benchmark_params==========================
random_infer_input:[{float32,[3,224,224]}]
===========================train_params===========================
model_name:NextViT_small_384
python:python3.7
gpu_list:0|0,1
-o Global.device:gpu
-o Global.auto_cast:null
-o Global.epochs:lite_train_lite_infer=2|whole_train_whole_infer=120
-o Global.output_dir:./output/
-o DataLoader.Train.sampler.batch_size:8
-o Global.pretrained_model:null
train_model_name:latest
train_infer_img_dir:./dataset/ILSVRC2012/val
null:null
##
trainer:norm_train
norm_train:tools/train.py -c ppcls/configs/ImageNet/NextViT/NextViT_small_384.yaml -o Global.seed=1234 -o DataLoader.Train.sampler.shuffle=False -o DataLoader.Train.loader.num_workers=0 -o DataLoader.Train.loader.use_shared_memory=False -o Global.eval_during_train=False -o Global.save_interval=2
pact_train:null
fpgm_train:null
distill_train:null
null:null
null:null
##
===========================eval_params===========================
eval:tools/eval.py -c ppcls/configs/ImageNet/NextViT/NextViT_small_384.yaml
null:null
##
===========================infer_params==========================
-o Global.save_inference_dir:./inference
-o Global.pretrained_model:
norm_export:tools/export_model.py -c ppcls/configs/ImageNet/NextViT/NextViT_small_384.yaml
quant_export:null
fpgm_export:null
distill_export:null
kl_quant:null
export2:null
pretrained_model_url:https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/NextViT_small_384_pretrained.pdparams
infer_model:../inference/
infer_export:True
infer_quant:Fasle
inference:python/predict_cls.py -c configs/inference_cls.yaml -o PreProcess.transform_ops.0.ResizeImage.resize_short=384 -o PreProcess.transform_ops.1.CropImage.size=384
-o Global.use_gpu:True|False
-o Global.enable_mkldnn:False
-o Global.cpu_num_threads:1
-o Global.batch_size:1
-o Global.use_tensorrt:False
-o Global.use_fp16:False
-o Global.inference_model_dir:../inference
-o Global.infer_imgs:../dataset/ILSVRC2012/val/ILSVRC2012_val_00000001.JPEG
-o Global.save_log_path:null
-o Global.benchmark:False
null:null
null:null
===========================infer_benchmark_params==========================
random_infer_input:[{float32,[3,384,384]}]
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册