{ "cells": [ { "cell_type": "markdown", "id": "ae69ce68", "metadata": {}, "source": [ "## 1. Introduction to PLSC-SwinTransformer\n" ] }, { "cell_type": "markdown", "id": "35485bc6", "metadata": {}, "source": [ "PLSC-SwinTransformer implements a vision classification model based on [Swin Transformer](https://github.com/microsoft/Swin-Transformer). Swin Transformer is a hierarchical Vision Transformer (ViT); the name Swin stands for Shifted windows. Unlike ViT, Swin computes self-attention within non-overlapping local windows while allowing cross-window connections so that information is shared between windows. As a result, Swin Transformer is more efficient than ViT, which computes self-attention globally. Swin Transformer can serve as a general-purpose backbone for computer vision. The model architecture is shown below.\n", "\n", "![Figure 1 from paper](https://github.com/microsoft/Swin-Transformer/blob/main/figures/teaser.png?raw=true)\n" ] }, { "cell_type": "markdown", "id": "97e174e6", "metadata": {}, "source": [ "## 2. Model Performance" ] }, { "cell_type": "markdown", "id": "78137a72", "metadata": {}, "source": [ "| Model | DType | Phase | Dataset | GPU | img/sec | Top-1 Acc | Official |\n", "| --- | --- | --- | --- | --- | --- | --- | --- |\n", "| Swin-B | FP16 O1 | pretrain | ImageNet2012 | A100*N1C8 | 2155 | 0.83362 | 0.835 |\n", "| Swin-B | FP16 O2 | pretrain | ImageNet2012 | A100*N1C8 | 3006 | 0.83223 | 0.835 |\n" ] }, { "cell_type": "markdown", "id": "ace3c48d", "metadata": {}, "source": [ "## 3. How to Use the Model" ] }, { "cell_type": "markdown", "id": "a97a5f56", "metadata": {}, "source": [ "### 3.1 Install PLSC" ] }, { "cell_type": "markdown", "id": "492fa769-2fe0-4220-b6d9-bbc32f8cca10", "metadata": {}, "source": [ "```\n", "git clone https://github.com/PaddlePaddle/PLSC.git\n", "cd /path/to/PLSC/\n", "# [optional] pip install -r requirements.txt\n", "python setup.py develop\n", "```" ] }, { "cell_type": "markdown", "id": "6b22824d", "metadata": {}, "source": [ "### 3.2 Model Training" ] }, { "cell_type": "markdown", "id": "d68ca5fb", "metadata": {}, "source": [ "1. Enter the task directory\n", "\n", "```\n", "cd task/classification/swin\n", "```" ] }, { "cell_type": "markdown", "id": "9048df01", "metadata": {}, "source": [ "2. 
Prepare the data\n", "\n", "Organize the dataset in the following layout:\n", "```text\n", "dataset/\n", "└── ILSVRC2012\n", "    ├── train\n", "    ├── val\n", "    ├── train_list.txt\n", "    └── val_list.txt\n", "```" ] }, { "cell_type": "markdown", "id": "bea743ea", "metadata": {}, "source": [ "3. Run the training command\n", "\n", "```shell\n", "export PADDLE_NNODES=1\n", "export PADDLE_MASTER=\"\"\n", "export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7\n", "\n", "python -m paddle.distributed.launch \\\n", "    --nnodes=$PADDLE_NNODES \\\n", "    --master=$PADDLE_MASTER \\\n", "    --devices=$CUDA_VISIBLE_DEVICES \\\n", "    plsc-train \\\n", "    -c ./configs/swin_base_patch4_window7_224_in1k_1n8c_dp_fp16o1.yaml\n", "```\n", "\n", "For training tutorials covering more models, see the [Swin training documentation](https://github.com/PaddlePaddle/PLSC/blob/master/task/classification/swin/README.md)." ] }, { "cell_type": "markdown", "id": "186a0c17", "metadata": {}, "source": [ "### 3.3 Model Inference" ] }, { "cell_type": "markdown", "id": "e97c527c", "metadata": {}, "source": [ "1. Download the pretrained model and a test image\n", "\n", "```shell\n", "# download pretrained model\n", "mkdir -p pretrained/swin/Swin_base/\n", "wget -O ./pretrained/swin/Swin_base/swin_base_patch4_window7_224_fp16o1.pdparams \\\n", "    https://plsc.bj.bcebos.com/models/swin/v2.5/swin_base_patch4_window7_224_fp16o1.pdparams\n", "\n", "# download image\n", "mkdir -p images/\n", "wget -O ./images/zebra.png https://plsc.bj.bcebos.com/dataset/test_images/zebra.png\n", "```" ] }, { "cell_type": "markdown", "id": "a07c6549", "metadata": {}, "source": [ "2. Export the inference model\n", "\n", "```shell\n", "plsc-export -c ./configs/swin_base_patch4_window7_224_in1k_1n8c_dp_fp16o1.yaml \\\n", "    -o Global.pretrained_model=./pretrained/swin/Swin_base/swin_base_patch4_window7_224_fp16o1 \\\n", "    -o Model.data_format=NCHW \\\n", "    -o FP16.level=O0\n", "```\n" ] }, { "cell_type": "markdown", "id": "3ded8e73-3dba-49ce-bfb3-fcf7f3f0fc1d", "metadata": {}, "source": [ "3. 
Run prediction on an image" ] }, { "cell_type": "code", "execution_count": null, "id": "9533d4df-acb3-474f-b591-f210639a0a02", "metadata": {}, "outputs": [], "source": [ "import numpy as np\n", "\n", "from plsc.data.dataset import default_loader\n", "from plsc.data.preprocess import Resize\n", "from plsc.engine.inference import Predictor\n", "\n", "\n", "def preprocess(img):\n", "    # Resize the image, then normalize it with the ImageNet mean and std.\n", "    resize = Resize(size=224,\n", "                    interpolation=\"bicubic\",\n", "                    backend=\"pil\")\n", "    img = np.array(resize(img))\n", "    scale = 1.0 / 255.0\n", "    mean = np.array([0.485, 0.456, 0.406])\n", "    std = np.array([0.229, 0.224, 0.225])\n", "    img = (img * scale - mean) / std\n", "    # Add a batch dimension and convert NHWC to NCHW.\n", "    img = img[np.newaxis, :, :, :]\n", "    img = img.transpose((0, 3, 1, 2))\n", "    return {'x': img.astype('float32')}\n", "\n", "\n", "def postprocess(logits):\n", "\n", "    def softmax(x):\n", "        # Subtract the maximum before exponentiating to avoid overflow.\n", "        exp_x = np.exp(x - np.max(x))\n", "        return exp_x / np.sum(exp_x)\n", "\n", "    pred = np.array(logits).squeeze()\n", "    pred = softmax(pred)\n", "    # Return the top-1 class index and its probability.\n", "    pred_class_idx = int(np.argmax(pred))\n", "    return pred_class_idx, pred[pred_class_idx]\n", "\n", "\n", "infer_model = \"./output/swin_base_patch4_window7_224/swin_base_patch4_window7_224.pdmodel\"\n", "infer_params = \"./output/swin_base_patch4_window7_224/swin_base_patch4_window7_224.pdiparams\"\n", "\n", "predictor = Predictor(\n", "    model_file=infer_model,\n", "    params_file=infer_params,\n", "    preprocess_fn=preprocess,\n", "    postprocess_fn=postprocess)\n", "\n", "image = default_loader(\"./images/zebra.png\")\n", "pred_class_idx, pred_score = predictor.predict(image)" ] }, { "cell_type": "markdown", "id": "d375934d", "metadata": {}, "source": [ "## 4. 
Related Paper and Citation\n" ] }, { "cell_type": "markdown", "id": "29f05b07-d323-45e4-b00d-0728eafb5af7", "metadata": {}, "source": [ "```text\n", "@inproceedings{liu2021Swin,\n", "  title={Swin Transformer: Hierarchical Vision Transformer using Shifted Windows},\n", "  author={Liu, Ze and Lin, Yutong and Cao, Yue and Hu, Han and Wei, Yixuan and Zhang, Zheng and Lin, Stephen and Guo, Baining},\n", "  booktitle={Proceedings of the IEEE/CVF International Conference on Computer Vision (ICCV)},\n", "  year={2021}\n", "}\n", "```" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.7.10" } }, "nbformat": 4, "nbformat_minor": 5 }