From ec2029b821f7b87742e2c1a668aaf4db0dcc2036 Mon Sep 17 00:00:00 2001
From: cuicheng01
Date: Tue, 22 Nov 2022 08:48:40 +0000
Subject: [PATCH] Add NextViT code and docs

---
 docs/zh_CN/models/ImageNet1k/NextViT.md       | 114 ++++
 docs/zh_CN/models/ImageNet1k/README.md        |  23 +
 ppcls/arch/backbone/__init__.py               |   1 +
 ppcls/arch/backbone/model_zoo/nextvit.py      | 643 ++++++++++++++++++
 .../ImageNet/NextViT/NextViT_base_224.yaml    | 169 +++++
 .../ImageNet/NextViT/NextViT_base_384.yaml    | 169 +++++
 .../ImageNet/NextViT/NextViT_large_224.yaml   | 169 +++++
 .../ImageNet/NextViT/NextViT_large_384.yaml   | 169 +++++
 .../ImageNet/NextViT/NextViT_small_224.yaml   | 169 +++++
 .../ImageNet/NextViT/NextViT_small_384.yaml   | 169 +++++
 .../NextViT_base_224_train_infer_python.txt   |  54 ++
 .../NextViT_base_384_train_infer_python.txt   |  54 ++
 .../NextViT_large_224_train_infer_python.txt  |  54 ++
 .../NextViT_large_384_train_infer_python.txt  |  54 ++
 .../NextViT_small_224_train_infer_python.txt  |  54 ++
 .../NextViT_small_384_train_infer_python.txt  |  54 ++
 16 files changed, 2119 insertions(+)
 create mode 100644 docs/zh_CN/models/ImageNet1k/NextViT.md
 create mode 100644 ppcls/arch/backbone/model_zoo/nextvit.py
 create mode 100644 ppcls/configs/ImageNet/NextViT/NextViT_base_224.yaml
 create mode 100644 ppcls/configs/ImageNet/NextViT/NextViT_base_384.yaml
 create mode 100644 ppcls/configs/ImageNet/NextViT/NextViT_large_224.yaml
 create mode 100644 ppcls/configs/ImageNet/NextViT/NextViT_large_384.yaml
 create mode 100644 ppcls/configs/ImageNet/NextViT/NextViT_small_224.yaml
 create mode 100644 ppcls/configs/ImageNet/NextViT/NextViT_small_384.yaml
 create mode 100644 test_tipc/configs/NextViT/NextViT_base_224_train_infer_python.txt
 create mode 100644 test_tipc/configs/NextViT/NextViT_base_384_train_infer_python.txt
 create mode 100644 test_tipc/configs/NextViT/NextViT_large_224_train_infer_python.txt
 create mode 100644 test_tipc/configs/NextViT/NextViT_large_384_train_infer_python.txt
 create mode 100644 test_tipc/configs/NextViT/NextViT_small_224_train_infer_python.txt
 create mode 100644 test_tipc/configs/NextViT/NextViT_small_384_train_infer_python.txt

diff --git a/docs/zh_CN/models/ImageNet1k/NextViT.md b/docs/zh_CN/models/ImageNet1k/NextViT.md
new file mode 100644
index 00000000..4ce8b99e
--- /dev/null
+++ b/docs/zh_CN/models/ImageNet1k/NextViT.md
@@ -0,0 +1,114 @@
+# NextViT
+-----
+
+## Catalogue
+
+- [1. Introduction](#1)
+    - [1.1 Model Introduction](#1.1)
+    - [1.2 Model Metrics](#1.2)
+- [2. Quick Start](#2)
+- [3. Training, Evaluation and Inference](#3)
+- [4. Inference Deployment](#4)
+  - [4.1 Preparing the Inference Model](#4.1)
+  - [4.2 Inference with the Python Prediction Engine](#4.2)
+  - [4.3 Inference with the C++ Prediction Engine](#4.3)
+  - [4.4 Serving Deployment](#4.4)
+  - [4.5 Deployment on Mobile](#4.5)
+  - [4.6 Paddle2ONNX Model Conversion and Prediction](#4.6)
+
+<a name="1"></a>
+
+## 1. Introduction
+
+<a name="1.1"></a>
+
+### 1.1 Model Introduction
+
+NextViT is a new vision Transformer network that can serve as a general-purpose backbone for computer vision tasks. The authors propose Next-ViT, a next-generation vision Transformer designed for efficient deployment in realistic industrial scenarios, which dominates both CNNs and ViTs from the latency/accuracy trade-off perspective. In this work, the Next Convolution Block (NCB) and the Next Transformer Block (NTB) are developed to capture local and global information with deployment-friendly mechanisms. On top of them, the Next Hybrid Strategy (NHS) stacks NCB and NTB in an efficient hybrid paradigm, which improves the performance on various downstream tasks. NextViT achieves state-of-the-art results on multiple tasks. [Paper](https://arxiv.org/pdf/2207.05501.pdf).
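+
+The six variants of this series are exposed as factory functions in `ppcls.arch.backbone` (added by this PR). A minimal forward-pass sketch, assuming PaddleClas is installed from source (`class_num` defaults to 1000):
+
+```python
+import paddle
+
+from ppcls.arch.backbone import NextViT_small_224
+
+# build the backbone without loading pretrained weights
+model = NextViT_small_224(pretrained=False)
+x = paddle.randn([1, 3, 224, 224])
+logits = model(x)  # shape: [1, 1000]
+```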
+
+<a name="1.2"></a>
+
+### 1.2 Model Metrics
+
+| Models | Top1 | Top5 | Reference<br>top1 | Reference<br>top5 | FLOPs<br>(G) | Params<br>(M) |
+|:--:|:--:|:--:|:--:|:--:|:--:|:--:|
+| NextViT_small_224 | 0.8248 | 0.9616 | 0.825 | - | 5.79 | 31.80 |
+| NextViT_base_224 | 0.8324 | 0.9658 | 0.832 | - | 8.26 | 44.88 |
+| NextViT_large_224 | 0.8363 | 0.9661 | 0.836 | - | 10.73 | 57.95 |
+| NextViT_small_384 | 0.8401 | 0.9698 | 0.836 | - | 17.00 | 31.80 |
+| NextViT_base_384 | 0.8465 | 0.9723 | 0.843 | - | 24.27 | 44.88 |
+| NextViT_large_384 | 0.8492 | 0.9728 | 0.847 | - | 31.53 | 57.95 |
+| NextViT_small_224_ssld | 0.8472 | 0.9734 | 0.848 | - | 5.79 | 31.80 |
+| NextViT_base_224_ssld | 0.8500 | 0.9753 | 0.851 | - | 8.26 | 44.88 |
+| NextViT_large_224_ssld | 0.8536 | 0.9762 | 0.854 | - | 10.73 | 57.95 |
+| NextViT_small_384_ssld | 0.8597 | 0.9790 | 0.858 | - | 17.00 | 31.80 |
+| NextViT_base_384_ssld | 0.8634 | 0.9806 | 0.861 | - | 24.27 | 44.88 |
+| NextViT_large_384_ssld | 0.8654 | 0.9814 | 0.864 | - | 31.53 | 57.95 |
+
+**Note:**
+- The pretrained weights PaddleClas provides for this series are all converted from the officially released weights. PaddleClas has verified that the accuracy of NextViT_small_224 is aligned with the paper.
+- The `_ssld` models here are not obtained with the `SSLD distillation` method in PaddleClas; they are trained on a dataset mined in a way similar to `SSLD distillation`.
+
+<a name="2"></a>
+
+## 2. Quick Start
+
+After installing paddlepaddle and paddleclas, you can make predictions on images right away. For details, please refer to [Quick Start of ResNet50](./ResNet.md#2-模型快速体验).
+
+<a name="3"></a>
+
+## 3. Training, Evaluation and Inference
+
+This section covers configuring the training environment, preparing the ImageNet data, and training, evaluating and predicting with this model on ImageNet. The training configurations are provided under `ppcls/configs/ImageNet/NextViT/`; for how to launch them, please refer to [ResNet50: training, evaluation and inference](./ResNet.md#3-模型训练评估和预测).
+
+**Note:** The NextViT series uses 8 GPUs by default, so 8 GPUs need to be specified for training, e.g. `python3 -m paddle.distributed.launch --gpus="0,1,2,3,4,5,6,7" tools/train.py -c xxx.yaml`. If 4 GPUs are used instead, the default learning rate needs to be halved, and some loss of accuracy is possible.
+
+<a name="4"></a>
+
+## 4. Inference Deployment
+
+<a name="4.1"></a>
+
+### 4.1 Preparing the Inference Model
+
+Paddle Inference is the native inference library of PaddlePaddle, targeting server and cloud deployment with high-performance inference. Compared with predicting directly from the trained model, Paddle Inference can accelerate prediction with MKLDNN, CUDNN and TensorRT, achieving better inference performance. For more details about the Paddle Inference engine, see the [Paddle Inference tutorial](https://www.paddlepaddle.org.cn/documentation/docs/zh/guides/infer/inference/inference_cn.html).
+
+For how to obtain the inference model, please refer to [ResNet50: preparing the inference model](./ResNet.md#41-推理模型准备).
+
+<a name="4.2"></a>
+
+### 4.2 Inference with the Python Prediction Engine
+
+PaddleClas provides an example of inference with the Python prediction engine. You can refer to [ResNet50: inference with the Python prediction engine](./ResNet.md#42-基于-python-预测引擎推理).
+
+<a name="4.3"></a>
+
+### 4.3 Inference with the C++ Prediction Engine
+
+PaddleClas provides an example of inference with the C++ prediction engine. You can refer to [Server-side C++ inference](../../deployment/image_classification/cpp/linux.md) to finish the deployment. If you are on Windows, you can refer to the [Visual Studio 2019 Community CMake compilation guide](../../deployment/image_classification/cpp/windows.md) to compile the prediction library and run the model.
+
+<a name="4.4"></a>
+
+### 4.4 Serving Deployment
+
+Paddle Serving provides high-performance, flexible and easy-to-use industrial-grade online inference services. It supports multiple protocols such as RESTful, gRPC and bRPC, and offers inference solutions for a variety of heterogeneous hardware and operating systems. For more details, see the [Paddle Serving repository](https://github.com/PaddlePaddle/Serving).
+
+PaddleClas provides an example of serving deployment based on Paddle Serving. You can refer to [Model serving deployment](../../deployment/image_classification/paddle_serving.md) to finish the deployment.
+
+<a name="4.5"></a>
+
+### 4.5 Deployment on Mobile
+
+Paddle Lite is a high-performance, lightweight, flexible and easily extensible deep learning inference framework, positioned to support multiple hardware platforms including mobile, embedded and server side. For more details, see the [Paddle Lite repository](https://github.com/PaddlePaddle/Paddle-Lite).
+
+PaddleClas provides an example of mobile deployment based on Paddle Lite. You can refer to [Deployment on mobile](../../deployment/image_classification/paddle_lite.md) to finish the deployment.
+
+<a name="4.6"></a>
+
+### 4.6 Paddle2ONNX Model Conversion and Prediction
+
+Paddle2ONNX converts models in the PaddlePaddle format to the ONNX format, through which a Paddle model can be deployed to many inference engines, including TensorRT/OpenVINO/MNN/TNN/NCNN, as well as other engines or hardware that support the open ONNX standard. For more details, see the [Paddle2ONNX repository](https://github.com/PaddlePaddle/Paddle2ONNX).
+
+PaddleClas provides an example of converting an inference model to ONNX and running prediction with it. You can refer to [Paddle2ONNX model conversion and prediction](../../deployment/image_classification/paddle2onnx.md) to finish the deployment.
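+
+As an illustrative sketch only, once the inference model from [4.1](#4.1) is available, the conversion itself is a single command (the directory and file names below are assumptions based on the default export layout; adjust them to your own paths):
+
+```shell
+paddle2onnx --model_dir ./inference/NextViT_small_224_infer \
+    --model_filename inference.pdmodel \
+    --params_filename inference.pdiparams \
+    --save_file ./NextViT_small_224.onnx \
+    --opset_version 11
+```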
diff --git a/docs/zh_CN/models/ImageNet1k/README.md b/docs/zh_CN/models/ImageNet1k/README.md
index 8e98a6f7..229909ea 100644
--- a/docs/zh_CN/models/ImageNet1k/README.md
+++ b/docs/zh_CN/models/ImageNet1k/README.md
@@ -49,6 +49,7 @@
   - [PVTV2 series](#PVTV2)
   - [LeViT series](#LeViT)
   - [TNT series](#TNT)
+  - [NextViT series](#NextViT)
   - [4.2 Lightweight models](#Transformer_lite)
     - [MobileViT series](#MobileViT)
 - [5. References](#reference)
@@ -701,6 +702,26 @@ Accuracy and speed metrics of the DeiT (Data-efficient Image Transformers) series
 
 **Note**: In the data preprocessing of the TNT series, `mean` and `std` of `NormalizeImage` are both 0.5.
 
+<a name="NextViT"></a>
+
+## NextViT series [[47](#ref47)]
+
+The accuracy and speed metrics of the NextViT series are shown in the table below. For more details, please refer to the [NextViT series documentation](NextViT.md).
+
+| Model | Top-1 Acc | Top-5 Acc | time(ms)<br>bs=1 | time(ms)<br>bs=4 | FLOPs(G) | Params(M) | Pretrained model download link | Inference model download link |
+| ---------- | --------- | --------- | ---------------- | ---------------- | -------- | --------- | ------------------------------------------------------------ | ------------------------------------------------------------ |
+| NextViT_small_224 | 0.8248 | 0.9616 | - | - | 5.79 | 31.80 | [Download](https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/NextViT_small_224_pretrained.pdparams) | [Download](https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/inference/NextViT_small_224_infer.tar) |
+| NextViT_base_224 | 0.8324 | 0.9658 | - | - | 8.26 | 44.88 | [Download](https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/NextViT_base_224_pretrained.pdparams) | [Download](https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/inference/NextViT_base_224_infer.tar) |
+| NextViT_large_224 | 0.8363 | 0.9661 | - | - | 10.73 | 57.95 | [Download](https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/NextViT_large_224_pretrained.pdparams) | [Download](https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/inference/NextViT_large_224_infer.tar) |
+| NextViT_small_384 | 0.8401 | 0.9698 | - | - | 17.00 | 31.80 | [Download](https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/NextViT_small_384_pretrained.pdparams) | [Download](https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/inference/NextViT_small_384_infer.tar) |
+| NextViT_base_384 | 0.8465 | 0.9723 | - | - | 24.27 | 44.88 | [Download](https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/NextViT_base_384_pretrained.pdparams) | [Download](https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/inference/NextViT_base_384_infer.tar) |
+| NextViT_large_384 | 0.8492 | 0.9728 | - | - | 31.53 | 57.95 | [Download](https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/NextViT_large_384_pretrained.pdparams) | [Download](https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/inference/NextViT_large_384_infer.tar) |
+| NextViT_small_224_ssld | 0.8472 | 0.9734 | - | - | 5.79 | 31.80 | [Download](https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/NextViT_small_224_ssld_pretrained.pdparams) | [Download](https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/inference/NextViT_small_224_ssld_infer.tar) |
+| NextViT_base_224_ssld | 0.8500 | 0.9753 | - | - | 8.26 | 44.88 | [Download](https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/NextViT_base_224_ssld_pretrained.pdparams) | [Download](https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/inference/NextViT_base_224_ssld_infer.tar) |
+| NextViT_large_224_ssld | 0.8536 | 0.9762 | - | - | 10.73 | 57.95 | [Download](https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/NextViT_large_224_ssld_pretrained.pdparams) | [Download](https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/inference/NextViT_large_224_ssld_infer.tar) |
+| NextViT_small_384_ssld | 0.8597 | 0.9790 | - | - | 17.00 | 31.80 | [Download](https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/NextViT_small_384_ssld_pretrained.pdparams) | [Download](https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/inference/NextViT_small_384_ssld_infer.tar) |
+| NextViT_base_384_ssld | 0.8634 | 0.9806 | - | - | 24.27 | 44.88 | [Download](https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/NextViT_base_384_ssld_pretrained.pdparams) | [Download](https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/inference/NextViT_base_384_ssld_infer.tar) |
+| NextViT_large_384_ssld | 0.8654 | 0.9814 | - | - | 31.53 | 57.95 | [Download](https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/NextViT_large_384_ssld_pretrained.pdparams) | [Download](https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/inference/NextViT_large_384_ssld_infer.tar) |
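+
+Besides the manual download links above, the weights can be fetched automatically. A minimal sketch, assuming PaddleClas is installed from this repository:
+
+```python
+from ppcls.arch.backbone import NextViT_small_224
+
+# pretrained=True downloads and loads the pdparams registered for this model;
+# the other entries in the table work the same way via their factory functions
+model = NextViT_small_224(pretrained=True)
+model.eval()
+```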
+
 
 <a name="Transformer_lite"></a>
 
 ### 4.2 Lightweight models
 
@@ -813,3 +834,5 @@ TRANSFORMERS FOR IMAGE RECOGNITION AT SCALE.
 [45]Robert J. Wang, Xiang Li, Charles X. Ling. Pelee: A Real-Time Object Detection System on Mobile Devices
 [46]Chien-Yao Wang, Hong-Yuan Mark Liao, I-Hau Yeh, Yueh-Hua Wu, Ping-Yang Chen, Jun-Wei Hsieh. CSPNet: A New Backbone that can Enhance Learning Capability of CNN
+
+[47]Jiashi Li, Xin Xia, Wei Li, Huixia Li, Xing Wang, Xuefeng Xiao, Rui Wang, Min Zheng, Xin Pan. Next-ViT: Next Generation Vision Transformer for Efficient Deployment in Realistic Industrial Scenarios.
diff --git a/ppcls/arch/backbone/__init__.py b/ppcls/arch/backbone/__init__.py
index 0660dcf1..5220ea16 100644
--- a/ppcls/arch/backbone/__init__.py
+++ b/ppcls/arch/backbone/__init__.py
@@ -69,6 +69,7 @@ from .model_zoo.repvgg import RepVGG_A0, RepVGG_A1, RepVGG_A2, RepVGG_B0, RepVGG
 from .model_zoo.van import VAN_B0
 from .model_zoo.peleenet import PeleeNet
 from .model_zoo.convnext import ConvNeXt_tiny
+from .model_zoo.nextvit import NextViT_small_224, NextViT_base_224, NextViT_large_224, NextViT_small_384, NextViT_base_384, NextViT_large_384
 from .model_zoo.cae import cae_base_patch16_224, cae_large_patch16_224
 
 from .variant_models.resnet_variant import ResNet50_last_stage_stride1
diff --git a/ppcls/arch/backbone/model_zoo/nextvit.py b/ppcls/arch/backbone/model_zoo/nextvit.py
new file mode 100644
index 00000000..c4a5972f
--- /dev/null
+++ b/ppcls/arch/backbone/model_zoo/nextvit.py
@@ -0,0 +1,643 @@
+# copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+ +# Code was based on https://github.com/bytedance/Next-ViT/blob/main/classification/nextvit.py +# reference: https://arxiv.org/abs/2207.05501 + +from functools import partial + +import paddle +from paddle import nn +from paddle.nn.initializer import TruncatedNormal, Constant, Normal +from .vision_transformer import trunc_normal_, zeros_, ones_, to_2tuple, DropPath, Identity + +from ....utils.save_load import load_dygraph_pretrain, load_dygraph_pretrain_from_url + +MODEL_URLS = { + "NextViT_small_224": + "https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/NextViT_small_224_pretrained.pdparams", + "NextViT_base_224": + "https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/NextViT_base_224_pretrained.pdparams", + "NextViT_large_224": + "https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/NextViT_large_224_pretrained.pdparams", + "NextViT_small_384": + "https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/NextViT_small_384_pretrained.pdparams", + "NextViT_base_384": + "https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/NextViT_base_384_pretrained.pdparams", + "NextViT_large_384": + "https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/NextViT_large_384_pretrained.pdparams", +} + +__all__ = list(MODEL_URLS.keys()) + +NORM_EPS = 1e-5 + + +def rearrange(x, pattern, **axes_lengths): + if 'b (h w) c -> b c h w' == pattern: + b, n, c = x.shape + h = axes_lengths.pop('h', -1) + w = axes_lengths.pop('w', -1) + h = h if w == -1 else n // w + w = w if h == -1 else n // h + return x.transpose([0, 2, 1]).reshape([b, c, h, w]) + if 'b c h w -> b (h w) c' == pattern: + b, c, h, w = x.shape + return x.reshape([b, c, h * w]).transpose([0, 2, 1]) + if 'b t (h d) -> b h t d' == pattern: + b, t, h_d = x.shape + h = axes_lengths['h'] + return x.reshape([b, t, h, h_d // h]).transpose([0, 2, 1, 3]) + if 'b h t d -> b t (h d)' == pattern: + b, h, t, d = x.shape + return x.transpose([0, 2, 1, 3]).reshape([b, t, h * d]) + + raise NotImplementedError( + "Rearrangement '{}' has not been implemented.".format(pattern)) + + +def merge_pre_bn(layer, pre_bn_1, pre_bn_2=None): + """ Merge pre BN to reduce inference runtime. 
+    """
+    weight = layer.weight
+    if isinstance(layer, nn.Linear):
+        weight = weight.transpose([1, 0])
+    bias = layer.bias
+    if pre_bn_2 is None:
+        # fold a single preceding BatchNorm into the layer
+        scale_invstd = (pre_bn_1._variance + pre_bn_1._epsilon).pow(-0.5)
+        extra_weight = scale_invstd * pre_bn_1.weight
+        extra_bias = pre_bn_1.bias - pre_bn_1.weight * pre_bn_1._mean * scale_invstd
+    else:
+        # fold two stacked BatchNorms into the layer
+        scale_invstd_1 = (pre_bn_1._variance + pre_bn_1._epsilon).pow(-0.5)
+        scale_invstd_2 = (pre_bn_2._variance + pre_bn_2._epsilon).pow(-0.5)
+
+        extra_weight = scale_invstd_1 * pre_bn_1.weight * scale_invstd_2 * pre_bn_2.weight
+        extra_bias = scale_invstd_2 * pre_bn_2.weight * (
+            pre_bn_1.bias - pre_bn_1.weight * pre_bn_1._mean * scale_invstd_1 -
+            pre_bn_2._mean) + pre_bn_2.bias
+    if isinstance(layer, nn.Linear):
+        extra_bias = weight @ extra_bias
+
+        weight = weight.multiply(
+            extra_weight.reshape([1, weight.shape[1]]).expand_as(weight))
+        weight = weight.transpose([1, 0])
+    elif isinstance(layer, nn.Conv2D):
+        assert weight.shape[2] == 1 and weight.shape[3] == 1
+
+        weight = weight.reshape([weight.shape[0], weight.shape[1]])
+        extra_bias = weight @ extra_bias
+        weight = weight.multiply(
+            extra_weight.reshape([1, weight.shape[1]]).expand_as(weight))
+        weight = weight.reshape([weight.shape[0], weight.shape[1], 1, 1])
+    bias = bias.add(extra_bias)
+
+    layer.weight.set_value(weight)
+    layer.bias.set_value(bias)
+
+
+def _make_divisible(v, divisor, min_value=None):
+    if min_value is None:
+        min_value = divisor
+    new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
+    # Make sure that round down does not go down by more than 10%.
+    if new_v < 0.9 * v:
+        new_v += divisor
+    return new_v
+
+
+class ConvBNReLU(nn.Layer):
+    # Conv + BN + ReLU block; padding is fixed to 1, intended for 3x3 kernels
+    def __init__(self,
+                 in_channels,
+                 out_channels,
+                 kernel_size,
+                 stride,
+                 groups=1):
+        super(ConvBNReLU, self).__init__()
+        self.conv = nn.Conv2D(
+            in_channels,
+            out_channels,
+            kernel_size=kernel_size,
+            stride=stride,
+            padding=1,
+            groups=groups,
+            bias_attr=False)
+        self.norm = nn.BatchNorm2D(out_channels, epsilon=NORM_EPS)
+        self.act = nn.ReLU()
+
+    def forward(self, x):
+        x = self.conv(x)
+        x = self.norm(x)
+        x = self.act(x)
+        return x
+
+
+class PatchEmbed(nn.Layer):
+    def __init__(self, in_channels, out_channels, stride=1):
+        super(PatchEmbed, self).__init__()
+        norm_layer = partial(nn.BatchNorm2D, epsilon=NORM_EPS)
+        if stride == 2:
+            # downsample with avg-pool, then a 1x1 conv to change channels
+            self.avgpool = nn.AvgPool2D((2, 2), stride=2, ceil_mode=True)
+            self.conv = nn.Conv2D(
+                in_channels,
+                out_channels,
+                kernel_size=1,
+                stride=1,
+                bias_attr=False)
+            self.norm = norm_layer(out_channels)
+        elif in_channels != out_channels:
+            self.avgpool = nn.Identity()
+            self.conv = nn.Conv2D(
+                in_channels,
+                out_channels,
+                kernel_size=1,
+                stride=1,
+                bias_attr=False)
+            self.norm = norm_layer(out_channels)
+        else:
+            self.avgpool = nn.Identity()
+            self.conv = nn.Identity()
+            self.norm = nn.Identity()
+
+    def forward(self, x):
+        return self.norm(self.conv(self.avgpool(x)))
+
+
+class MHCA(nn.Layer):
+    """
+    Multi-Head Convolutional Attention
+    """
+
+    def __init__(self, out_channels, head_dim):
+        super(MHCA, self).__init__()
+        norm_layer = partial(nn.BatchNorm2D, epsilon=NORM_EPS)
+        self.group_conv3x3 = nn.Conv2D(
+            out_channels,
+            out_channels,
+            kernel_size=3,
+            stride=1,
+            padding=1,
+            groups=out_channels // head_dim,
+            bias_attr=False)
+        self.norm = norm_layer(out_channels)
+        self.act = nn.ReLU()
+        self.projection = nn.Conv2D(
+            out_channels, out_channels, kernel_size=1, bias_attr=False)
+
+    def forward(self, x):
+        out = self.group_conv3x3(x)
+        out = self.norm(out)
+        out = self.act(out)
+        out = self.projection(out)
+        return out
+
+
+class Mlp(nn.Layer):
+    def __init__(self,
+                 in_features,
+                 out_features=None,
+                 mlp_ratio=None,
+                 drop=0.,
+                 bias=True):
+        super().__init__()
+        out_features = out_features or in_features
+        hidden_dim = _make_divisible(in_features * mlp_ratio, 32)
+        # bias_attr=None keeps the default bias parameter; False disables it
+        self.conv1 = nn.Conv2D(
+            in_features,
+            hidden_dim,
+            kernel_size=1,
+            bias_attr=None if bias else False)
+        self.act = nn.ReLU()
+        self.conv2 = nn.Conv2D(
+            hidden_dim,
+            out_features,
+            kernel_size=1,
+            bias_attr=None if bias else False)
+        self.drop = nn.Dropout(drop)
+
+    def merge_bn(self, pre_norm):
+        merge_pre_bn(self.conv1, pre_norm)
+        self.is_bn_merged = True
+
+    def forward(self, x):
+        x = self.conv1(x)
+        x = self.act(x)
+        x = self.drop(x)
+        x = self.conv2(x)
+        x = self.drop(x)
+        return x
+
+
+class NCB(nn.Layer):
+    """
+    Next Convolution Block
+    """
+
+    def __init__(self,
+                 in_channels,
+                 out_channels,
+                 stride=1,
+                 path_dropout=0.0,
+                 drop=0.0,
+                 head_dim=32,
+                 mlp_ratio=3):
+        super(NCB, self).__init__()
+        self.in_channels = in_channels
+        self.out_channels = out_channels
+        norm_layer = partial(nn.BatchNorm2D, epsilon=NORM_EPS)
+        assert out_channels % head_dim == 0
+
+        self.patch_embed = PatchEmbed(in_channels, out_channels, stride)
+        self.mhca = MHCA(out_channels, head_dim)
+        self.attention_path_dropout = DropPath(path_dropout)
+
+        self.norm = norm_layer(out_channels)
+        self.mlp = Mlp(out_channels, mlp_ratio=mlp_ratio, drop=drop, bias=True)
+        self.mlp_path_dropout = DropPath(path_dropout)
+        self.is_bn_merged = False
+
+    def merge_bn(self):
+        if not self.is_bn_merged:
+            self.mlp.merge_bn(self.norm)
+            self.is_bn_merged = True
+
+    def forward(self, x):
+        x = self.patch_embed(x)
+        x = x + self.attention_path_dropout(self.mhca(x))
+
+        # after merge_bn(), self.norm has been folded into the MLP weights
+        if not self.is_bn_merged:
+            out = self.norm(x)
+        else:
+            out = x
+        x = x + self.mlp_path_dropout(self.mlp(out))
+        return x
+
+
+class E_MHSA(nn.Layer):
+    """
+    Efficient Multi-Head Self Attention
+    """
+
+    def __init__(self,
+                 dim,
+                 out_dim=None,
+                 head_dim=32,
+                 qkv_bias=True,
+                 qk_scale=None,
+                 attn_drop=0,
+                 proj_drop=0.,
+                 sr_ratio=1):
+        super().__init__()
+        self.dim = dim
+        self.out_dim = out_dim if out_dim is not None else dim
+        self.num_heads = self.dim // head_dim
+        self.scale = qk_scale or head_dim**-0.5
+        self.q = nn.Linear(dim, self.dim, bias_attr=qkv_bias)
+        self.k = nn.Linear(dim, self.dim, bias_attr=qkv_bias)
+        self.v = nn.Linear(dim, self.dim, bias_attr=qkv_bias)
+        self.proj = nn.Linear(self.dim, self.out_dim)
+        self.attn_drop = nn.Dropout(attn_drop)
+        self.proj_drop = nn.Dropout(proj_drop)
+
+        self.sr_ratio = sr_ratio
+        # keys/values are average-pooled over sr_ratio**2 tokens to cut cost
+        self.N_ratio = sr_ratio**2
+        if sr_ratio > 1:
+            self.sr = nn.AvgPool1D(
+                kernel_size=self.N_ratio, stride=self.N_ratio)
+            self.norm = nn.BatchNorm1D(dim, epsilon=NORM_EPS)
+        self.is_bn_merged = False
+
+    def merge_bn(self, pre_bn):
+        merge_pre_bn(self.q, pre_bn)
+        if self.sr_ratio > 1:
+            merge_pre_bn(self.k, pre_bn, self.norm)
+            merge_pre_bn(self.v, pre_bn, self.norm)
+        else:
+            merge_pre_bn(self.k, pre_bn)
+            merge_pre_bn(self.v, pre_bn)
+        self.is_bn_merged = True
+
+    def forward(self, x):
+        B, N, C = x.shape
+        q = self.q(x)
+        q = q.reshape(
+            [B, N, self.num_heads, int(C // self.num_heads)]).transpose(
+                [0, 2, 1, 3])
+        if self.sr_ratio > 1:
+            x_ = x.transpose([0, 2, 1])
+            x_ = self.sr(x_)
+            if not self.is_bn_merged:
+                x_ = self.norm(x_)
+            x_ = x_.transpose([0, 2, 1])
+            k = self.k(x_)
+            k = k.reshape(
+                [B, k.shape[1], self.num_heads, int(C // self.num_heads)
+                 ]).transpose([0, 2, 3, 1])
+            v = self.v(x_)
+            v = v.reshape(
+                [B, v.shape[1], self.num_heads, int(C // self.num_heads)
+                 ]).transpose([0, 2, 1, 3])
+        else:
+            k = self.k(x)
+            k = k.reshape(
+                [B, k.shape[1], self.num_heads, int(C // self.num_heads)
+                 ]).transpose([0, 2, 3, 1])
+            v = self.v(x)
+            v = v.reshape(
+                [B, v.shape[1], self.num_heads, int(C // self.num_heads)
+                 ]).transpose([0, 2, 1, 3])
+        attn = (q @ k) * self.scale
+        attn = nn.functional.softmax(attn, axis=-1)
+        attn = self.attn_drop(attn)
+
+        x = (attn @ v).transpose([0, 2, 1, 3]).reshape([B, N, C])
+        x = self.proj(x)
+        x = self.proj_drop(x)
+        return x
+
+
+class NTB(nn.Layer):
+    """
+    Next Transformer Block
+    """
+
+    def __init__(
+            self,
+            in_channels,
+            out_channels,
+            path_dropout,
+            stride=1,
+            sr_ratio=1,
+            mlp_ratio=2,
+            head_dim=32,
+            mix_block_ratio=0.75,
+            attn_drop=0.0,
+            drop=0.0, ):
+        super(NTB, self).__init__()
+        self.in_channels = in_channels
+        self.out_channels = out_channels
+        self.mix_block_ratio = mix_block_ratio
+        norm_func = partial(nn.BatchNorm2D, epsilon=NORM_EPS)
+
+        # mix_block_ratio splits the channels between the E_MHSA and MHCA paths
+        self.mhsa_out_channels = _make_divisible(
+            int(out_channels * mix_block_ratio), 32)
+        self.mhca_out_channels = out_channels - self.mhsa_out_channels
+
+        self.patch_embed = PatchEmbed(in_channels, self.mhsa_out_channels,
+                                      stride)
+        self.norm1 = norm_func(self.mhsa_out_channels)
+        self.e_mhsa = E_MHSA(
+            self.mhsa_out_channels,
+            head_dim=head_dim,
+            sr_ratio=sr_ratio,
+            attn_drop=attn_drop,
+            proj_drop=drop)
+        self.mhsa_path_dropout = DropPath(path_dropout * mix_block_ratio)
+
+        self.projection = PatchEmbed(
+            self.mhsa_out_channels, self.mhca_out_channels, stride=1)
+        self.mhca = MHCA(self.mhca_out_channels, head_dim=head_dim)
+        self.mhca_path_dropout = DropPath(path_dropout * (1 - mix_block_ratio))
+
+        self.norm2 = norm_func(out_channels)
+        self.mlp = Mlp(out_channels, mlp_ratio=mlp_ratio, drop=drop)
+        self.mlp_path_dropout = DropPath(path_dropout)
+
+        self.is_bn_merged = False
+
+    def merge_bn(self):
+        if not self.is_bn_merged:
+            self.e_mhsa.merge_bn(self.norm1)
+            self.mlp.merge_bn(self.norm2)
+            self.is_bn_merged = True
+
+    def forward(self, x):
+        x = self.patch_embed(x)
+
+        B, C, H, W = x.shape
+        if not self.is_bn_merged:
+            out = self.norm1(x)
+        else:
+            out = x
+        out = rearrange(out, "b c h w -> b (h w) c")  # b n c
+        out = self.e_mhsa(out)
+        out = self.mhsa_path_dropout(out)
+        x = x + rearrange(out, "b (h w) c -> b c h w", h=H)
+
+        out = self.projection(x)
+        out = out + self.mhca_path_dropout(self.mhca(out))
+        # concatenate the global (E_MHSA) and local (MHCA) branches channel-wise
+        x = paddle.concat([x, out], axis=1)
+
+        if not self.is_bn_merged:
+            out = self.norm2(x)
+        else:
+            out = x
+        x = x + self.mlp_path_dropout(self.mlp(out))
+        return x
+
+
+class NextViT(nn.Layer):
+    def __init__(self,
+                 stem_chs,
+                 depths,
+                 path_dropout,
+                 attn_drop=0,
+                 drop=0,
+                 class_num=1000,
+                 strides=[1, 2, 2, 2],
+                 sr_ratios=[8, 4, 2, 1],
+                 head_dim=32,
+                 mix_block_ratio=0.75):
+        super(NextViT, self).__init__()
+
+        self.stage_out_channels = [
+            [96] * (depths[0]), [192] * (depths[1] - 1) + [256],
+            [384, 384, 384, 384, 512] * (depths[2] // 5),
+            [768] * (depths[3] - 1) + [1024]
+        ]
+
+        # Next Hybrid Strategy
+        self.stage_block_types = [[NCB] * depths[0],
+                                  [NCB] * (depths[1] - 1) + [NTB],
+                                  [NCB, NCB, NCB, NCB, NTB] * (depths[2] // 5),
+                                  [NCB] * (depths[3] - 1) + [NTB]]
+
+        self.stem = nn.Sequential(
+            ConvBNReLU(
+                3, stem_chs[0], kernel_size=3, stride=2),
+            ConvBNReLU(
+                stem_chs[0], stem_chs[1], kernel_size=3, stride=1),
+            ConvBNReLU(
+                stem_chs[1], stem_chs[2], kernel_size=3, stride=1),
+            ConvBNReLU(
+                stem_chs[2], stem_chs[2], kernel_size=3, stride=2), )
+        input_channel = stem_chs[-1]
+        features = []
+        idx = 0
+        dpr = [
+            x.item() for x in paddle.linspace(0, path_dropout, sum(depths))
+        ]  # stochastic depth decay rule
+        for stage_id in range(len(depths)):
+            numrepeat = depths[stage_id]
+            output_channels = self.stage_out_channels[stage_id]
+            block_types = self.stage_block_types[stage_id]
+            for block_id in range(numrepeat):
+                # only the first block of a downsampling stage uses stride 2
+                if strides[stage_id] == 2 and block_id == 0:
+                    stride = 2
+                else:
+                    stride = 1
+                output_channel = output_channels[block_id]
+                block_type = block_types[block_id]
+                if block_type is NCB:
+                    layer = NCB(input_channel,
+                                output_channel,
+                                stride=stride,
+                                path_dropout=dpr[idx + block_id],
+                                drop=drop,
+                                head_dim=head_dim)
+                    features.append(layer)
+                elif block_type is NTB:
+                    layer = NTB(input_channel,
+                                output_channel,
+                                path_dropout=dpr[idx + block_id],
+                                stride=stride,
+                                sr_ratio=sr_ratios[stage_id],
+                                head_dim=head_dim,
+                                mix_block_ratio=mix_block_ratio,
+                                attn_drop=attn_drop,
+                                drop=drop)
+                    features.append(layer)
+                input_channel = output_channel
+            idx += numrepeat
+        self.features = nn.Sequential(*features)
+
+        self.norm = nn.BatchNorm2D(output_channel, epsilon=NORM_EPS)
+
+        self.avgpool = nn.AdaptiveAvgPool2D((1, 1))
+        self.proj_head = nn.Sequential(nn.Linear(output_channel, class_num), )
+
+        self.stage_out_idx = [
+            sum(depths[:idx + 1]) - 1 for idx in range(len(depths))
+        ]
+        self._initialize_weights()
+
+    def merge_bn(self):
+        self.eval()
+        for _, layer in self.named_sublayers():
+            if isinstance(layer, NCB) or isinstance(layer, NTB):
+                layer.merge_bn()
+
+    def _initialize_weights(self):
+        for n, m in self.named_sublayers():
+            if isinstance(m, (nn.BatchNorm2D, nn.GroupNorm, nn.LayerNorm,
+                              nn.BatchNorm1D)):
+                ones_(m.weight)
+                zeros_(m.bias)
+            elif isinstance(m, nn.Linear):
+                trunc_normal_(m.weight)
+                if hasattr(m, 'bias') and m.bias is not None:
+                    zeros_(m.bias)
+            elif isinstance(m, nn.Conv2D):
+                trunc_normal_(m.weight)
+                if hasattr(m, 'bias') and m.bias is not None:
+                    zeros_(m.bias)
+
+    def forward(self, x):
+        x = self.stem(x)
+        for layer in self.features:
+            x = layer(x)
+        x = self.norm(x)
+        x = self.avgpool(x)
+        x = paddle.flatten(x, 1)
+        x = self.proj_head(x)
+        return x
+
+
+def _load_pretrained(pretrained, model, model_url, use_ssld=False):
+    if pretrained is False:
+        pass
+    elif pretrained is True:
+        load_dygraph_pretrain_from_url(model, model_url, use_ssld=use_ssld)
+    elif isinstance(pretrained, str):
+        load_dygraph_pretrain(model, pretrained)
+    else:
+        raise RuntimeError(
+            "pretrained type is not available. Please use `string` or `boolean` type."
+ ) + + +def NextViT_small_224(pretrained=False, use_ssld=False, **kwargs): + model = NextViT( + stem_chs=[64, 32, 64], + depths=[3, 4, 10, 3], + path_dropout=0.1, + **kwargs) + _load_pretrained( + pretrained, model, MODEL_URLS["NextViT_small_224"], use_ssld=use_ssld) + return model + + +def NextViT_base_224(pretrained=False, use_ssld=False, **kwargs): + model = NextViT( + stem_chs=[64, 32, 64], + depths=[3, 4, 20, 3], + path_dropout=0.2, + **kwargs) + _load_pretrained( + pretrained, model, MODEL_URLS["NextViT_base_224"], use_ssld=use_ssld) + return model + + +def NextViT_large_224(pretrained=False, use_ssld=False, **kwargs): + model = NextViT( + stem_chs=[64, 32, 64], + depths=[3, 4, 30, 3], + path_dropout=0.2, + **kwargs) + _load_pretrained( + pretrained, model, MODEL_URLS["NextViT_large_224"], use_ssld=use_ssld) + return model + + +def NextViT_small_384(pretrained=False, use_ssld=False, **kwargs): + model = NextViT( + stem_chs=[64, 32, 64], + depths=[3, 4, 10, 3], + path_dropout=0.1, + **kwargs) + _load_pretrained( + pretrained, model, MODEL_URLS["NextViT_small_384"], use_ssld=use_ssld) + return model + + +def NextViT_base_384(pretrained=False, use_ssld=False, **kwargs): + model = NextViT( + stem_chs=[64, 32, 64], + depths=[3, 4, 20, 3], + path_dropout=0.2, + **kwargs) + _load_pretrained( + pretrained, model, MODEL_URLS["NextViT_base_384"], use_ssld=use_ssld) + return model + + +def NextViT_large_384(pretrained=False, use_ssld=False, **kwargs): + model = NextViT( + stem_chs=[64, 32, 64], + depths=[3, 4, 30, 3], + path_dropout=0.2, + **kwargs) + _load_pretrained( + pretrained, model, MODEL_URLS["NextViT_large_384"], use_ssld=use_ssld) + return model diff --git a/ppcls/configs/ImageNet/NextViT/NextViT_base_224.yaml b/ppcls/configs/ImageNet/NextViT/NextViT_base_224.yaml new file mode 100644 index 00000000..688a7d1d --- /dev/null +++ b/ppcls/configs/ImageNet/NextViT/NextViT_base_224.yaml @@ -0,0 +1,169 @@ +# global configs +Global: + checkpoints: null + pretrained_model: null + output_dir: ./output/ + device: gpu + save_interval: 1 + eval_during_train: True + eval_interval: 1 + epochs: 300 + print_batch_step: 10 + use_visualdl: False + # used for static mode and model export + image_shape: [3, 224, 224] + save_inference_dir: ./inference + # training model under @to_static + to_static: False + + +# mixed precision training +AMP: + scale_loss: 128.0 + use_dynamic_loss_scaling: True + # O1: mixed fp16 + level: O1 + + +# model architecture +Arch: + name: NextViT_base_224 + class_num: 1000 + +# loss function config for traing/eval process +Loss: + Train: + - CELoss: + weight: 1.0 + epsilon: 0.1 + Eval: + - CELoss: + weight: 1.0 + + +Optimizer: + name: AdamW + beta1: 0.9 + beta2: 0.999 + epsilon: 1e-8 + weight_decay: 0.1 + no_weight_decay_name: .bias norm + one_dim_param_no_weight_decay: True + lr: + # for 8 cards + name: Cosine + learning_rate: 1e-3 + eta_min: 1e-5 + warmup_epoch: 20 + warmup_start_lr: 1e-6 + + +# data loader for train and eval +DataLoader: + Train: + dataset: + name: ImageNetDataset + image_root: ./dataset/ILSVRC2012/ + cls_label_path: ./dataset/ILSVRC2012/train_list.txt + transform_ops: + - DecodeImage: + to_rgb: True + channel_first: False + - RandCropImage: + size: 224 + interpolation: bicubic + backend: pil + - RandFlipImage: + flip_code: 1 + - TimmAutoAugment: + config_str: rand-m9-mstd0.5-inc1 + interpolation: bicubic + img_size: 224 + - NormalizeImage: + scale: 1.0/255.0 + mean: [0.485, 0.456, 0.406] + std: [0.229, 0.224, 0.225] + order: '' + - RandomErasing: + EPSILON: 
0.25 + sl: 0.02 + sh: 1.0/3.0 + r1: 0.3 + attempt: 10 + use_log_aspect: True + mode: pixel + batch_transform_ops: + - OpSampler: + MixupOperator: + alpha: 0.8 + prob: 0.5 + CutmixOperator: + alpha: 1.0 + prob: 0.5 + + sampler: + name: DistributedBatchSampler + batch_size: 128 + drop_last: False + shuffle: True + loader: + num_workers: 8 + use_shared_memory: True + + Eval: + dataset: + name: ImageNetDataset + image_root: ./dataset/ILSVRC2012/ + cls_label_path: ./dataset/ILSVRC2012/val_list.txt + transform_ops: + - DecodeImage: + to_rgb: True + channel_first: False + - ResizeImage: + resize_short: 256 + interpolation: bicubic + backend: pil + - CropImage: + size: 224 + - NormalizeImage: + scale: 1.0/255.0 + mean: [0.485, 0.456, 0.406] + std: [0.229, 0.224, 0.225] + order: '' + sampler: + name: DistributedBatchSampler + batch_size: 128 + drop_last: False + shuffle: False + loader: + num_workers: 4 + use_shared_memory: True + +Infer: + infer_imgs: ./deploy/images/ImageNet/ILSVRC2012_val_00000010.jpeg + batch_size: 10 + transforms: + - DecodeImage: + to_rgb: True + channel_first: False + - ResizeImage: + resize_short: 256 + interpolation: bicubic + backend: pil + - CropImage: + size: 224 + - NormalizeImage: + scale: 1.0/255.0 + mean: [0.485, 0.456, 0.406] + std: [0.229, 0.224, 0.225] + order: '' + - ToCHWImage: + PostProcess: + name: Topk + topk: 5 + class_id_map_file: ppcls/utils/imagenet1k_label_list.txt + +Metric: + Eval: + - TopkAcc: + topk: [1, 5] diff --git a/ppcls/configs/ImageNet/NextViT/NextViT_base_384.yaml b/ppcls/configs/ImageNet/NextViT/NextViT_base_384.yaml new file mode 100644 index 00000000..27d991ce --- /dev/null +++ b/ppcls/configs/ImageNet/NextViT/NextViT_base_384.yaml @@ -0,0 +1,169 @@ +# global configs +Global: + checkpoints: null + pretrained_model: null + output_dir: ./output/ + device: gpu + save_interval: 1 + eval_during_train: True + eval_interval: 1 + epochs: 300 + print_batch_step: 10 + use_visualdl: False + # used for static mode and model export + image_shape: [3, 224, 224] + save_inference_dir: ./inference + # training model under @to_static + to_static: False + + +# mixed precision training +AMP: + scale_loss: 128.0 + use_dynamic_loss_scaling: True + # O1: mixed fp16 + level: O1 + + +# model architecture +Arch: + name: NextViT_base_384 + class_num: 1000 + +# loss function config for traing/eval process +Loss: + Train: + - CELoss: + weight: 1.0 + epsilon: 0.1 + Eval: + - CELoss: + weight: 1.0 + + +Optimizer: + name: AdamW + beta1: 0.9 + beta2: 0.999 + epsilon: 1e-8 + weight_decay: 0.1 + no_weight_decay_name: .bias norm + one_dim_param_no_weight_decay: True + lr: + # for 8 cards + name: Cosine + learning_rate: 1e-3 + eta_min: 1e-5 + warmup_epoch: 20 + warmup_start_lr: 1e-6 + + +# data loader for train and eval +DataLoader: + Train: + dataset: + name: ImageNetDataset + image_root: ./dataset/ILSVRC2012/ + cls_label_path: ./dataset/ILSVRC2012/train_list.txt + transform_ops: + - DecodeImage: + to_rgb: True + channel_first: False + - RandCropImage: + size: 384 + interpolation: bicubic + backend: pil + - RandFlipImage: + flip_code: 1 + - TimmAutoAugment: + config_str: rand-m9-mstd0.5-inc1 + interpolation: bicubic + img_size: 384 + - NormalizeImage: + scale: 1.0/255.0 + mean: [0.485, 0.456, 0.406] + std: [0.229, 0.224, 0.225] + order: '' + - RandomErasing: + EPSILON: 0.25 + sl: 0.02 + sh: 1.0/3.0 + r1: 0.3 + attempt: 10 + use_log_aspect: True + mode: pixel + batch_transform_ops: + - OpSampler: + MixupOperator: + alpha: 0.8 + prob: 0.5 + CutmixOperator: + alpha: 1.0 + 
prob: 0.5 + + sampler: + name: DistributedBatchSampler + batch_size: 128 + drop_last: False + shuffle: True + loader: + num_workers: 8 + use_shared_memory: True + + Eval: + dataset: + name: ImageNetDataset + image_root: ./dataset/ILSVRC2012/ + cls_label_path: ./dataset/ILSVRC2012/val_list.txt + transform_ops: + - DecodeImage: + to_rgb: True + channel_first: False + - ResizeImage: + resize_short: 384 + interpolation: bicubic + backend: pil + - CropImage: + size: 384 + - NormalizeImage: + scale: 1.0/255.0 + mean: [0.485, 0.456, 0.406] + std: [0.229, 0.224, 0.225] + order: '' + sampler: + name: DistributedBatchSampler + batch_size: 128 + drop_last: False + shuffle: False + loader: + num_workers: 4 + use_shared_memory: True + +Infer: + infer_imgs: ./deploy/images/ImageNet/ILSVRC2012_val_00000010.jpeg + batch_size: 10 + transforms: + - DecodeImage: + to_rgb: True + channel_first: False + - ResizeImage: + resize_short: 384 + interpolation: bicubic + backend: pil + - CropImage: + size: 384 + - NormalizeImage: + scale: 1.0/255.0 + mean: [0.485, 0.456, 0.406] + std: [0.229, 0.224, 0.225] + order: '' + - ToCHWImage: + PostProcess: + name: Topk + topk: 5 + class_id_map_file: ppcls/utils/imagenet1k_label_list.txt + +Metric: + Eval: + - TopkAcc: + topk: [1, 5] diff --git a/ppcls/configs/ImageNet/NextViT/NextViT_large_224.yaml b/ppcls/configs/ImageNet/NextViT/NextViT_large_224.yaml new file mode 100644 index 00000000..bb54db87 --- /dev/null +++ b/ppcls/configs/ImageNet/NextViT/NextViT_large_224.yaml @@ -0,0 +1,169 @@ +# global configs +Global: + checkpoints: null + pretrained_model: null + output_dir: ./output/ + device: gpu + save_interval: 1 + eval_during_train: True + eval_interval: 1 + epochs: 300 + print_batch_step: 10 + use_visualdl: False + # used for static mode and model export + image_shape: [3, 224, 224] + save_inference_dir: ./inference + # training model under @to_static + to_static: False + + +# mixed precision training +AMP: + scale_loss: 128.0 + use_dynamic_loss_scaling: True + # O1: mixed fp16 + level: O1 + + +# model architecture +Arch: + name: NextViT_large_224 + class_num: 1000 + +# loss function config for traing/eval process +Loss: + Train: + - CELoss: + weight: 1.0 + epsilon: 0.1 + Eval: + - CELoss: + weight: 1.0 + + +Optimizer: + name: AdamW + beta1: 0.9 + beta2: 0.999 + epsilon: 1e-8 + weight_decay: 0.1 + no_weight_decay_name: .bias norm + one_dim_param_no_weight_decay: True + lr: + # for 8 cards + name: Cosine + learning_rate: 1e-3 + eta_min: 1e-5 + warmup_epoch: 20 + warmup_start_lr: 1e-6 + + +# data loader for train and eval +DataLoader: + Train: + dataset: + name: ImageNetDataset + image_root: ./dataset/ILSVRC2012/ + cls_label_path: ./dataset/ILSVRC2012/train_list.txt + transform_ops: + - DecodeImage: + to_rgb: True + channel_first: False + - RandCropImage: + size: 224 + interpolation: bicubic + backend: pil + - RandFlipImage: + flip_code: 1 + - TimmAutoAugment: + config_str: rand-m9-mstd0.5-inc1 + interpolation: bicubic + img_size: 224 + - NormalizeImage: + scale: 1.0/255.0 + mean: [0.485, 0.456, 0.406] + std: [0.229, 0.224, 0.225] + order: '' + - RandomErasing: + EPSILON: 0.25 + sl: 0.02 + sh: 1.0/3.0 + r1: 0.3 + attempt: 10 + use_log_aspect: True + mode: pixel + batch_transform_ops: + - OpSampler: + MixupOperator: + alpha: 0.8 + prob: 0.5 + CutmixOperator: + alpha: 1.0 + prob: 0.5 + + sampler: + name: DistributedBatchSampler + batch_size: 128 + drop_last: False + shuffle: True + loader: + num_workers: 8 + use_shared_memory: True + + Eval: + dataset: + name: 
ImageNetDataset + image_root: ./dataset/ILSVRC2012/ + cls_label_path: ./dataset/ILSVRC2012/val_list.txt + transform_ops: + - DecodeImage: + to_rgb: True + channel_first: False + - ResizeImage: + resize_short: 256 + interpolation: bicubic + backend: pil + - CropImage: + size: 224 + - NormalizeImage: + scale: 1.0/255.0 + mean: [0.485, 0.456, 0.406] + std: [0.229, 0.224, 0.225] + order: '' + sampler: + name: DistributedBatchSampler + batch_size: 128 + drop_last: False + shuffle: False + loader: + num_workers: 4 + use_shared_memory: True + +Infer: + infer_imgs: ./deploy/images/ImageNet/ILSVRC2012_val_00000010.jpeg + batch_size: 10 + transforms: + - DecodeImage: + to_rgb: True + channel_first: False + - ResizeImage: + resize_short: 256 + interpolation: bicubic + backend: pil + - CropImage: + size: 224 + - NormalizeImage: + scale: 1.0/255.0 + mean: [0.485, 0.456, 0.406] + std: [0.229, 0.224, 0.225] + order: '' + - ToCHWImage: + PostProcess: + name: Topk + topk: 5 + class_id_map_file: ppcls/utils/imagenet1k_label_list.txt + +Metric: + Eval: + - TopkAcc: + topk: [1, 5] diff --git a/ppcls/configs/ImageNet/NextViT/NextViT_large_384.yaml b/ppcls/configs/ImageNet/NextViT/NextViT_large_384.yaml new file mode 100644 index 00000000..0f1ccd4f --- /dev/null +++ b/ppcls/configs/ImageNet/NextViT/NextViT_large_384.yaml @@ -0,0 +1,169 @@ +# global configs +Global: + checkpoints: null + pretrained_model: null + output_dir: ./output/ + device: gpu + save_interval: 1 + eval_during_train: True + eval_interval: 1 + epochs: 300 + print_batch_step: 10 + use_visualdl: False + # used for static mode and model export + image_shape: [3, 224, 224] + save_inference_dir: ./inference + # training model under @to_static + to_static: False + + +# mixed precision training +AMP: + scale_loss: 128.0 + use_dynamic_loss_scaling: True + # O1: mixed fp16 + level: O1 + + +# model architecture +Arch: + name: NextViT_large_384 + class_num: 1000 + +# loss function config for traing/eval process +Loss: + Train: + - CELoss: + weight: 1.0 + epsilon: 0.1 + Eval: + - CELoss: + weight: 1.0 + + +Optimizer: + name: AdamW + beta1: 0.9 + beta2: 0.999 + epsilon: 1e-8 + weight_decay: 0.1 + no_weight_decay_name: .bias norm + one_dim_param_no_weight_decay: True + lr: + # for 8 cards + name: Cosine + learning_rate: 1e-3 + eta_min: 1e-5 + warmup_epoch: 20 + warmup_start_lr: 1e-6 + + +# data loader for train and eval +DataLoader: + Train: + dataset: + name: ImageNetDataset + image_root: ./dataset/ILSVRC2012/ + cls_label_path: ./dataset/ILSVRC2012/train_list.txt + transform_ops: + - DecodeImage: + to_rgb: True + channel_first: False + - RandCropImage: + size: 384 + interpolation: bicubic + backend: pil + - RandFlipImage: + flip_code: 1 + - TimmAutoAugment: + config_str: rand-m9-mstd0.5-inc1 + interpolation: bicubic + img_size: 384 + - NormalizeImage: + scale: 1.0/255.0 + mean: [0.485, 0.456, 0.406] + std: [0.229, 0.224, 0.225] + order: '' + - RandomErasing: + EPSILON: 0.25 + sl: 0.02 + sh: 1.0/3.0 + r1: 0.3 + attempt: 10 + use_log_aspect: True + mode: pixel + batch_transform_ops: + - OpSampler: + MixupOperator: + alpha: 0.8 + prob: 0.5 + CutmixOperator: + alpha: 1.0 + prob: 0.5 + + sampler: + name: DistributedBatchSampler + batch_size: 128 + drop_last: False + shuffle: True + loader: + num_workers: 8 + use_shared_memory: True + + Eval: + dataset: + name: ImageNetDataset + image_root: ./dataset/ILSVRC2012/ + cls_label_path: ./dataset/ILSVRC2012/val_list.txt + transform_ops: + - DecodeImage: + to_rgb: True + channel_first: False + - ResizeImage: + 
resize_short: 384 + interpolation: bicubic + backend: pil + - CropImage: + size: 384 + - NormalizeImage: + scale: 1.0/255.0 + mean: [0.485, 0.456, 0.406] + std: [0.229, 0.224, 0.225] + order: '' + sampler: + name: DistributedBatchSampler + batch_size: 128 + drop_last: False + shuffle: False + loader: + num_workers: 4 + use_shared_memory: True + +Infer: + infer_imgs: ./deploy/images/ImageNet/ILSVRC2012_val_00000010.jpeg + batch_size: 10 + transforms: + - DecodeImage: + to_rgb: True + channel_first: False + - ResizeImage: + resize_short: 384 + interpolation: bicubic + backend: pil + - CropImage: + size: 384 + - NormalizeImage: + scale: 1.0/255.0 + mean: [0.485, 0.456, 0.406] + std: [0.229, 0.224, 0.225] + order: '' + - ToCHWImage: + PostProcess: + name: Topk + topk: 5 + class_id_map_file: ppcls/utils/imagenet1k_label_list.txt + +Metric: + Eval: + - TopkAcc: + topk: [1, 5] diff --git a/ppcls/configs/ImageNet/NextViT/NextViT_small_224.yaml b/ppcls/configs/ImageNet/NextViT/NextViT_small_224.yaml new file mode 100644 index 00000000..996d42ae --- /dev/null +++ b/ppcls/configs/ImageNet/NextViT/NextViT_small_224.yaml @@ -0,0 +1,169 @@ +# global configs +Global: + checkpoints: null + pretrained_model: null + output_dir: ./output/ + device: gpu + save_interval: 1 + eval_during_train: True + eval_interval: 1 + epochs: 300 + print_batch_step: 10 + use_visualdl: False + # used for static mode and model export + image_shape: [3, 224, 224] + save_inference_dir: ./inference + # training model under @to_static + to_static: False + + +# mixed precision training +AMP: + scale_loss: 128.0 + use_dynamic_loss_scaling: True + # O1: mixed fp16 + level: O1 + + +# model architecture +Arch: + name: NextViT_small_224 + class_num: 1000 + +# loss function config for traing/eval process +Loss: + Train: + - CELoss: + weight: 1.0 + epsilon: 0.1 + Eval: + - CELoss: + weight: 1.0 + + +Optimizer: + name: AdamW + beta1: 0.9 + beta2: 0.999 + epsilon: 1e-8 + weight_decay: 0.1 + no_weight_decay_name: .bias norm + one_dim_param_no_weight_decay: True + lr: + # for 8 cards + name: Cosine + learning_rate: 1e-3 + eta_min: 1e-5 + warmup_epoch: 20 + warmup_start_lr: 1e-6 + + +# data loader for train and eval +DataLoader: + Train: + dataset: + name: ImageNetDataset + image_root: ./dataset/ILSVRC2012/ + cls_label_path: ./dataset/ILSVRC2012/train_list.txt + transform_ops: + - DecodeImage: + to_rgb: True + channel_first: False + - RandCropImage: + size: 224 + interpolation: bicubic + backend: pil + - RandFlipImage: + flip_code: 1 + - TimmAutoAugment: + config_str: rand-m9-mstd0.5-inc1 + interpolation: bicubic + img_size: 224 + - NormalizeImage: + scale: 1.0/255.0 + mean: [0.485, 0.456, 0.406] + std: [0.229, 0.224, 0.225] + order: '' + - RandomErasing: + EPSILON: 0.25 + sl: 0.02 + sh: 1.0/3.0 + r1: 0.3 + attempt: 10 + use_log_aspect: True + mode: pixel + batch_transform_ops: + - OpSampler: + MixupOperator: + alpha: 0.8 + prob: 0.5 + CutmixOperator: + alpha: 1.0 + prob: 0.5 + + sampler: + name: DistributedBatchSampler + batch_size: 128 + drop_last: False + shuffle: True + loader: + num_workers: 8 + use_shared_memory: True + + Eval: + dataset: + name: ImageNetDataset + image_root: ./dataset/ILSVRC2012/ + cls_label_path: ./dataset/ILSVRC2012/val_list.txt + transform_ops: + - DecodeImage: + to_rgb: True + channel_first: False + - ResizeImage: + resize_short: 256 + interpolation: bicubic + backend: pil + - CropImage: + size: 224 + - NormalizeImage: + scale: 1.0/255.0 + mean: [0.485, 0.456, 0.406] + std: [0.229, 0.224, 0.225] + order: '' + 
sampler: + name: DistributedBatchSampler + batch_size: 128 + drop_last: False + shuffle: False + loader: + num_workers: 4 + use_shared_memory: True + +Infer: + infer_imgs: ./deploy/images/ImageNet/ILSVRC2012_val_00000010.jpeg + batch_size: 10 + transforms: + - DecodeImage: + to_rgb: True + channel_first: False + - ResizeImage: + resize_short: 256 + interpolation: bicubic + backend: pil + - CropImage: + size: 224 + - NormalizeImage: + scale: 1.0/255.0 + mean: [0.485, 0.456, 0.406] + std: [0.229, 0.224, 0.225] + order: '' + - ToCHWImage: + PostProcess: + name: Topk + topk: 5 + class_id_map_file: ppcls/utils/imagenet1k_label_list.txt + +Metric: + Eval: + - TopkAcc: + topk: [1, 5] diff --git a/ppcls/configs/ImageNet/NextViT/NextViT_small_384.yaml b/ppcls/configs/ImageNet/NextViT/NextViT_small_384.yaml new file mode 100644 index 00000000..009dddfa --- /dev/null +++ b/ppcls/configs/ImageNet/NextViT/NextViT_small_384.yaml @@ -0,0 +1,169 @@ +# global configs +Global: + checkpoints: null + pretrained_model: null + output_dir: ./output/ + device: gpu + save_interval: 1 + eval_during_train: True + eval_interval: 1 + epochs: 300 + print_batch_step: 10 + use_visualdl: False + # used for static mode and model export + image_shape: [3, 224, 224] + save_inference_dir: ./inference + # training model under @to_static + to_static: False + + +# mixed precision training +AMP: + scale_loss: 128.0 + use_dynamic_loss_scaling: True + # O1: mixed fp16 + level: O1 + + +# model architecture +Arch: + name: NextViT_small_384 + class_num: 1000 + +# loss function config for traing/eval process +Loss: + Train: + - CELoss: + weight: 1.0 + epsilon: 0.1 + Eval: + - CELoss: + weight: 1.0 + + +Optimizer: + name: AdamW + beta1: 0.9 + beta2: 0.999 + epsilon: 1e-8 + weight_decay: 0.1 + no_weight_decay_name: .bias norm + one_dim_param_no_weight_decay: True + lr: + # for 8 cards + name: Cosine + learning_rate: 1e-3 + eta_min: 1e-5 + warmup_epoch: 20 + warmup_start_lr: 1e-6 + + +# data loader for train and eval +DataLoader: + Train: + dataset: + name: ImageNetDataset + image_root: ./dataset/ILSVRC2012/ + cls_label_path: ./dataset/ILSVRC2012/train_list.txt + transform_ops: + - DecodeImage: + to_rgb: True + channel_first: False + - RandCropImage: + size: 384 + interpolation: bicubic + backend: pil + - RandFlipImage: + flip_code: 1 + - TimmAutoAugment: + config_str: rand-m9-mstd0.5-inc1 + interpolation: bicubic + img_size: 384 + - NormalizeImage: + scale: 1.0/255.0 + mean: [0.485, 0.456, 0.406] + std: [0.229, 0.224, 0.225] + order: '' + - RandomErasing: + EPSILON: 0.25 + sl: 0.02 + sh: 1.0/3.0 + r1: 0.3 + attempt: 10 + use_log_aspect: True + mode: pixel + batch_transform_ops: + - OpSampler: + MixupOperator: + alpha: 0.8 + prob: 0.5 + CutmixOperator: + alpha: 1.0 + prob: 0.5 + + sampler: + name: DistributedBatchSampler + batch_size: 128 + drop_last: False + shuffle: True + loader: + num_workers: 8 + use_shared_memory: True + + Eval: + dataset: + name: ImageNetDataset + image_root: ./dataset/ILSVRC2012/ + cls_label_path: ./dataset/ILSVRC2012/val_list.txt + transform_ops: + - DecodeImage: + to_rgb: True + channel_first: False + - ResizeImage: + resize_short: 384 + interpolation: bicubic + backend: pil + - CropImage: + size: 384 + - NormalizeImage: + scale: 1.0/255.0 + mean: [0.485, 0.456, 0.406] + std: [0.229, 0.224, 0.225] + order: '' + sampler: + name: DistributedBatchSampler + batch_size: 128 + drop_last: False + shuffle: False + loader: + num_workers: 4 + use_shared_memory: True + +Infer: + infer_imgs: 
./deploy/images/ImageNet/ILSVRC2012_val_00000010.jpeg + batch_size: 10 + transforms: + - DecodeImage: + to_rgb: True + channel_first: False + - ResizeImage: + resize_short: 384 + interpolation: bicubic + backend: pil + - CropImage: + size: 384 + - NormalizeImage: + scale: 1.0/255.0 + mean: [0.485, 0.456, 0.406] + std: [0.229, 0.224, 0.225] + order: '' + - ToCHWImage: + PostProcess: + name: Topk + topk: 5 + class_id_map_file: ppcls/utils/imagenet1k_label_list.txt + +Metric: + Eval: + - TopkAcc: + topk: [1, 5] diff --git a/test_tipc/configs/NextViT/NextViT_base_224_train_infer_python.txt b/test_tipc/configs/NextViT/NextViT_base_224_train_infer_python.txt new file mode 100644 index 00000000..2a9c98b2 --- /dev/null +++ b/test_tipc/configs/NextViT/NextViT_base_224_train_infer_python.txt @@ -0,0 +1,54 @@ +===========================train_params=========================== +model_name:NextViT_base_224 +python:python3.7 +gpu_list:0|0,1 +-o Global.device:gpu +-o Global.auto_cast:null +-o Global.epochs:lite_train_lite_infer=2|whole_train_whole_infer=120 +-o Global.output_dir:./output/ +-o DataLoader.Train.sampler.batch_size:8 +-o Global.pretrained_model:null +train_model_name:latest +train_infer_img_dir:./dataset/ILSVRC2012/val +null:null +## +trainer:norm_train +norm_train:tools/train.py -c ppcls/configs/ImageNet/NextViT/NextViT_base_224.yaml -o Global.seed=1234 -o DataLoader.Train.sampler.shuffle=False -o DataLoader.Train.loader.num_workers=0 -o DataLoader.Train.loader.use_shared_memory=False -o Global.eval_during_train=False -o Global.save_interval=2 +pact_train:null +fpgm_train:null +distill_train:null +null:null +null:null +## +===========================eval_params=========================== +eval:tools/eval.py -c ppcls/configs/ImageNet/NextViT/NextViT_base_224.yaml +null:null +## +===========================infer_params========================== +-o Global.save_inference_dir:./inference +-o Global.pretrained_model: +norm_export:tools/export_model.py -c ppcls/configs/ImageNet/NextViT/NextViT_base_224.yaml +quant_export:null +fpgm_export:null +distill_export:null +kl_quant:null +export2:null +pretrained_model_url:https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/NextViT_base_224_pretrained.pdparams +infer_model:../inference/ +infer_export:True +infer_quant:Fasle +inference:python/predict_cls.py -c configs/inference_cls.yaml -o PreProcess.transform_ops.0.ResizeImage.resize_short=256 +-o Global.use_gpu:True|False +-o Global.enable_mkldnn:False +-o Global.cpu_num_threads:1 +-o Global.batch_size:1 +-o Global.use_tensorrt:False +-o Global.use_fp16:False +-o Global.inference_model_dir:../inference +-o Global.infer_imgs:../dataset/ILSVRC2012/val/ILSVRC2012_val_00000001.JPEG +-o Global.save_log_path:null +-o Global.benchmark:False +null:null +null:null +===========================infer_benchmark_params========================== +random_infer_input:[{float32,[3,224,224]}] diff --git a/test_tipc/configs/NextViT/NextViT_base_384_train_infer_python.txt b/test_tipc/configs/NextViT/NextViT_base_384_train_infer_python.txt new file mode 100644 index 00000000..d5ecbd2b --- /dev/null +++ b/test_tipc/configs/NextViT/NextViT_base_384_train_infer_python.txt @@ -0,0 +1,54 @@ +===========================train_params=========================== +model_name:NextViT_base_384 +python:python3.7 +gpu_list:0|0,1 +-o Global.device:gpu +-o Global.auto_cast:null +-o Global.epochs:lite_train_lite_infer=2|whole_train_whole_infer=120 +-o Global.output_dir:./output/ +-o DataLoader.Train.sampler.batch_size:8 +-o 
Global.pretrained_model:null +train_model_name:latest +train_infer_img_dir:./dataset/ILSVRC2012/val +null:null +## +trainer:norm_train +norm_train:tools/train.py -c ppcls/configs/ImageNet/NextViT/NextViT_base_384.yaml -o Global.seed=1234 -o DataLoader.Train.sampler.shuffle=False -o DataLoader.Train.loader.num_workers=0 -o DataLoader.Train.loader.use_shared_memory=False -o Global.eval_during_train=False -o Global.save_interval=2 +pact_train:null +fpgm_train:null +distill_train:null +null:null +null:null +## +===========================eval_params=========================== +eval:tools/eval.py -c ppcls/configs/ImageNet/NextViT/NextViT_base_384.yaml +null:null +## +===========================infer_params========================== +-o Global.save_inference_dir:./inference +-o Global.pretrained_model: +norm_export:tools/export_model.py -c ppcls/configs/ImageNet/NextViT/NextViT_base_384.yaml +quant_export:null +fpgm_export:null +distill_export:null +kl_quant:null +export2:null +pretrained_model_url:https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/NextViT_base_384_pretrained.pdparams +infer_model:../inference/ +infer_export:True +infer_quant:Fasle +inference:python/predict_cls.py -c configs/inference_cls.yaml -o PreProcess.transform_ops.0.ResizeImage.resize_short=384 -o PreProcess.transform_ops.1.CropImage.size=384 +-o Global.use_gpu:True|False +-o Global.enable_mkldnn:False +-o Global.cpu_num_threads:1 +-o Global.batch_size:1 +-o Global.use_tensorrt:False +-o Global.use_fp16:False +-o Global.inference_model_dir:../inference +-o Global.infer_imgs:../dataset/ILSVRC2012/val/ILSVRC2012_val_00000001.JPEG +-o Global.save_log_path:null +-o Global.benchmark:False +null:null +null:null +===========================infer_benchmark_params========================== +random_infer_input:[{float32,[3,384,384]}] diff --git a/test_tipc/configs/NextViT/NextViT_large_224_train_infer_python.txt b/test_tipc/configs/NextViT/NextViT_large_224_train_infer_python.txt new file mode 100644 index 00000000..64a11b24 --- /dev/null +++ b/test_tipc/configs/NextViT/NextViT_large_224_train_infer_python.txt @@ -0,0 +1,54 @@ +===========================train_params=========================== +model_name:NextViT_large_224 +python:python3.7 +gpu_list:0|0,1 +-o Global.device:gpu +-o Global.auto_cast:null +-o Global.epochs:lite_train_lite_infer=2|whole_train_whole_infer=120 +-o Global.output_dir:./output/ +-o DataLoader.Train.sampler.batch_size:8 +-o Global.pretrained_model:null +train_model_name:latest +train_infer_img_dir:./dataset/ILSVRC2012/val +null:null +## +trainer:norm_train +norm_train:tools/train.py -c ppcls/configs/ImageNet/NextViT/NextViT_large_224.yaml -o Global.seed=1234 -o DataLoader.Train.sampler.shuffle=False -o DataLoader.Train.loader.num_workers=0 -o DataLoader.Train.loader.use_shared_memory=False -o Global.eval_during_train=False -o Global.save_interval=2 +pact_train:null +fpgm_train:null +distill_train:null +null:null +null:null +## +===========================eval_params=========================== +eval:tools/eval.py -c ppcls/configs/ImageNet/NextViT/NextViT_large_224.yaml +null:null +## +===========================infer_params========================== +-o Global.save_inference_dir:./inference +-o Global.pretrained_model: +norm_export:tools/export_model.py -c ppcls/configs/ImageNet/NextViT/NextViT_large_224.yaml +quant_export:null +fpgm_export:null +distill_export:null +kl_quant:null +export2:null 
+pretrained_model_url:https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/NextViT_large_224_pretrained.pdparams +infer_model:../inference/ +infer_export:True +infer_quant:Fasle +inference:python/predict_cls.py -c configs/inference_cls.yaml -o PreProcess.transform_ops.0.ResizeImage.resize_short=256 +-o Global.use_gpu:True|False +-o Global.enable_mkldnn:False +-o Global.cpu_num_threads:1 +-o Global.batch_size:1 +-o Global.use_tensorrt:False +-o Global.use_fp16:False +-o Global.inference_model_dir:../inference +-o Global.infer_imgs:../dataset/ILSVRC2012/val/ILSVRC2012_val_00000001.JPEG +-o Global.save_log_path:null +-o Global.benchmark:False +null:null +null:null +===========================infer_benchmark_params========================== +random_infer_input:[{float32,[3,224,224]}] diff --git a/test_tipc/configs/NextViT/NextViT_large_384_train_infer_python.txt b/test_tipc/configs/NextViT/NextViT_large_384_train_infer_python.txt new file mode 100644 index 00000000..67bd5022 --- /dev/null +++ b/test_tipc/configs/NextViT/NextViT_large_384_train_infer_python.txt @@ -0,0 +1,54 @@ +===========================train_params=========================== +model_name:NextViT_large_384 +python:python3.7 +gpu_list:0|0,1 +-o Global.device:gpu +-o Global.auto_cast:null +-o Global.epochs:lite_train_lite_infer=2|whole_train_whole_infer=120 +-o Global.output_dir:./output/ +-o DataLoader.Train.sampler.batch_size:8 +-o Global.pretrained_model:null +train_model_name:latest +train_infer_img_dir:./dataset/ILSVRC2012/val +null:null +## +trainer:norm_train +norm_train:tools/train.py -c ppcls/configs/ImageNet/NextViT/NextViT_large_384.yaml -o Global.seed=1234 -o DataLoader.Train.sampler.shuffle=False -o DataLoader.Train.loader.num_workers=0 -o DataLoader.Train.loader.use_shared_memory=False -o Global.eval_during_train=False -o Global.save_interval=2 +pact_train:null +fpgm_train:null +distill_train:null +null:null +null:null +## +===========================eval_params=========================== +eval:tools/eval.py -c ppcls/configs/ImageNet/NextViT/NextViT_large_384.yaml +null:null +## +===========================infer_params========================== +-o Global.save_inference_dir:./inference +-o Global.pretrained_model: +norm_export:tools/export_model.py -c ppcls/configs/ImageNet/NextViT/NextViT_large_384.yaml +quant_export:null +fpgm_export:null +distill_export:null +kl_quant:null +export2:null +pretrained_model_url:https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/NextViT_large_384_pretrained.pdparams +infer_model:../inference/ +infer_export:True +infer_quant:Fasle +inference:python/predict_cls.py -c configs/inference_cls.yaml -o PreProcess.transform_ops.0.ResizeImage.resize_short=384 -o PreProcess.transform_ops.1.CropImage.size=384 +-o Global.use_gpu:True|False +-o Global.enable_mkldnn:False +-o Global.cpu_num_threads:1 +-o Global.batch_size:1 +-o Global.use_tensorrt:False +-o Global.use_fp16:False +-o Global.inference_model_dir:../inference +-o Global.infer_imgs:../dataset/ILSVRC2012/val/ILSVRC2012_val_00000001.JPEG +-o Global.save_log_path:null +-o Global.benchmark:False +null:null +null:null +===========================infer_benchmark_params========================== +random_infer_input:[{float32,[3,384,384]}] diff --git a/test_tipc/configs/NextViT/NextViT_small_224_train_infer_python.txt b/test_tipc/configs/NextViT/NextViT_small_224_train_infer_python.txt new file mode 100644 index 00000000..5140bb46 --- /dev/null +++ b/test_tipc/configs/NextViT/NextViT_small_224_train_infer_python.txt @@ -0,0 +1,54 
+===========================train_params===========================
+model_name:NextViT_small_224
+python:python3.7
+gpu_list:0|0,1
+-o Global.device:gpu
+-o Global.auto_cast:null
+-o Global.epochs:lite_train_lite_infer=2|whole_train_whole_infer=120
+-o Global.output_dir:./output/
+-o DataLoader.Train.sampler.batch_size:8
+-o Global.pretrained_model:null
+train_model_name:latest
+train_infer_img_dir:./dataset/ILSVRC2012/val
+null:null
+##
+trainer:norm_train
+norm_train:tools/train.py -c ppcls/configs/ImageNet/NextViT/NextViT_small_224.yaml -o Global.seed=1234 -o DataLoader.Train.sampler.shuffle=False -o DataLoader.Train.loader.num_workers=0 -o DataLoader.Train.loader.use_shared_memory=False -o Global.eval_during_train=False -o Global.save_interval=2
+pact_train:null
+fpgm_train:null
+distill_train:null
+null:null
+null:null
+##
+===========================eval_params===========================
+eval:tools/eval.py -c ppcls/configs/ImageNet/NextViT/NextViT_small_224.yaml
+null:null
+##
+===========================infer_params==========================
+-o Global.save_inference_dir:./inference
+-o Global.pretrained_model:
+norm_export:tools/export_model.py -c ppcls/configs/ImageNet/NextViT/NextViT_small_224.yaml
+quant_export:null
+fpgm_export:null
+distill_export:null
+kl_quant:null
+export2:null
+pretrained_model_url:https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/NextViT_small_224_pretrained.pdparams
+infer_model:../inference/
+infer_export:True
+infer_quant:False
+inference:python/predict_cls.py -c configs/inference_cls.yaml -o PreProcess.transform_ops.0.ResizeImage.resize_short=256
+-o Global.use_gpu:True|False
+-o Global.enable_mkldnn:False
+-o Global.cpu_num_threads:1
+-o Global.batch_size:1
+-o Global.use_tensorrt:False
+-o Global.use_fp16:False
+-o Global.inference_model_dir:../inference
+-o Global.infer_imgs:../dataset/ILSVRC2012/val/ILSVRC2012_val_00000001.JPEG
+-o Global.save_log_path:null
+-o Global.benchmark:False
+null:null
+null:null
+===========================infer_benchmark_params==========================
+random_infer_input:[{float32,[3,224,224]}]
diff --git a/test_tipc/configs/NextViT/NextViT_small_384_train_infer_python.txt b/test_tipc/configs/NextViT/NextViT_small_384_train_infer_python.txt
new file mode 100644
index 00000000..527ab0e2
--- /dev/null
+++ b/test_tipc/configs/NextViT/NextViT_small_384_train_infer_python.txt
@@ -0,0 +1,54 @@
+===========================train_params===========================
+model_name:NextViT_small_384
+python:python3.7
+gpu_list:0|0,1
+-o Global.device:gpu
+-o Global.auto_cast:null
+-o Global.epochs:lite_train_lite_infer=2|whole_train_whole_infer=120
+-o Global.output_dir:./output/
+-o DataLoader.Train.sampler.batch_size:8
+-o Global.pretrained_model:null
+train_model_name:latest
+train_infer_img_dir:./dataset/ILSVRC2012/val
+null:null
+##
+trainer:norm_train
+norm_train:tools/train.py -c ppcls/configs/ImageNet/NextViT/NextViT_small_384.yaml -o Global.seed=1234 -o DataLoader.Train.sampler.shuffle=False -o DataLoader.Train.loader.num_workers=0 -o DataLoader.Train.loader.use_shared_memory=False -o Global.eval_during_train=False -o Global.save_interval=2
+pact_train:null
+fpgm_train:null
+distill_train:null
+null:null
+null:null
+##
+===========================eval_params===========================
+eval:tools/eval.py -c ppcls/configs/ImageNet/NextViT/NextViT_small_384.yaml
+null:null
+##
+===========================infer_params==========================
+-o Global.save_inference_dir:./inference
+-o Global.pretrained_model:
+norm_export:tools/export_model.py -c ppcls/configs/ImageNet/NextViT/NextViT_small_384.yaml
+quant_export:null
+fpgm_export:null
+distill_export:null
+kl_quant:null
+export2:null
+pretrained_model_url:https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/NextViT_small_384_pretrained.pdparams
+infer_model:../inference/
+infer_export:True
+infer_quant:False
+inference:python/predict_cls.py -c configs/inference_cls.yaml -o PreProcess.transform_ops.0.ResizeImage.resize_short=384 -o PreProcess.transform_ops.1.CropImage.size=384
+-o Global.use_gpu:True|False
+-o Global.enable_mkldnn:False
+-o Global.cpu_num_threads:1
+-o Global.batch_size:1
+-o Global.use_tensorrt:False
+-o Global.use_fp16:False
+-o Global.inference_model_dir:../inference
+-o Global.infer_imgs:../dataset/ILSVRC2012/val/ILSVRC2012_val_00000001.JPEG
+-o Global.save_log_path:null
+-o Global.benchmark:False
+null:null
+null:null
+===========================infer_benchmark_params==========================
+random_infer_input:[{float32,[3,384,384]}]
--
GitLab
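Note for reviewers: the six TIPC files added above all follow the standard positional `train_infer_python.txt` layout, so the existing TIPC driver scripts should be able to exercise them end to end (2-epoch lite training, model export, then Python-engine inference). A minimal smoke test for one variant might look like the sketch below; it assumes the stock `test_tipc/prepare.sh` and `test_tipc/test_train_inference_python.sh` entry points already present in the repository, which this patch does not modify.

    # Download the lite dataset and pretrained weights declared in the config,
    # then run the lite-train -> export -> Python-inference chain it describes.
    bash test_tipc/prepare.sh test_tipc/configs/NextViT/NextViT_small_224_train_infer_python.txt lite_train_lite_infer
    bash test_tipc/test_train_inference_python.sh test_tipc/configs/NextViT/NextViT_small_224_train_infer_python.txt lite_train_lite_infer

Swapping in any of the other five config paths runs the same chain for that variant; the 384 configs differ only in the resize/crop overrides on their `inference:` line and the [3,384,384] benchmark input shape.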