{ "cells": [ { "cell_type": "markdown", "metadata": { "collapsed": false }, "source": [ "# Text Recognition in Practice\n", "\n", "The previous chapter covered the theory and introduced the main approaches to text recognition. Among them, CRNN was proposed relatively early and is still widely used in industry. This chapter explains in detail how to build, train, evaluate, and run inference with a CRNN text recognition model based on PaddleOCR. The dataset is ICDAR 2015, with 4,468 training images and 2,077 test images.\n", "\n", "\n", "After working through this chapter, you will know:\n", "\n", "1. How to use the paddleocr whl package to quickly run text recognition inference\n", "\n", "2. The basic principles and network structure of CRNN\n", "\n", "3. The required steps of model training and how to tune hyperparameters\n", "\n", "4. How to train the network on a custom dataset\n" ] }, { "cell_type": "markdown", "metadata": { "collapsed": false }, "source": [ "## 1. Quick Start\n", "\n", "### 1.1 Install Dependencies and the whl Package\n", "\n", "First make sure that paddle and paddleocr are installed. If they are already installed, skip this step." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Looking in indexes: https://pypi.tuna.tsinghua.edu.cn/simple\n", "Requirement already satisfied: paddlepaddle-gpu in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (2.1.2.post101)\n", "Requirement already satisfied: protobuf>=3.1.0 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from paddlepaddle-gpu) (3.14.0)\n", "Requirement already satisfied: numpy>=1.13; python_version >= \"3.5\" and platform_system != \"Windows\" in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from paddlepaddle-gpu) (1.20.3)\n", "Requirement already satisfied: Pillow in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from paddlepaddle-gpu) (7.1.2)\n", "Requirement already satisfied: six in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from paddlepaddle-gpu) (1.15.0)\n", "Requirement already satisfied: requests>=2.20.0 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from paddlepaddle-gpu) (2.22.0)\n", "Requirement already satisfied: astor in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from paddlepaddle-gpu) (0.8.1)\n", "Requirement already satisfied: decorator in 
/opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from paddlepaddle-gpu) (4.4.2)\n", "Requirement already satisfied: gast<=0.4.0,>=0.3.3; platform_system != \"Windows\" in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from paddlepaddle-gpu) (0.3.3)\n", "Requirement already satisfied: urllib3!=1.25.0,!=1.25.1,<1.26,>=1.21.1 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from requests>=2.20.0->paddlepaddle-gpu) (1.25.6)\n", "Requirement already satisfied: certifi>=2017.4.17 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from requests>=2.20.0->paddlepaddle-gpu) (2019.9.11)\n", "Requirement already satisfied: idna<2.9,>=2.5 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from requests>=2.20.0->paddlepaddle-gpu) (2.8)\n", "Requirement already satisfied: chardet<3.1.0,>=3.0.2 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from requests>=2.20.0->paddlepaddle-gpu) (3.0.4)\n", "Looking in indexes: https://pypi.tuna.tsinghua.edu.cn/simple\n", "Collecting pip\n", "\u001b[?25l Downloading https://pypi.tuna.tsinghua.edu.cn/packages/a4/6d/6463d49a933f547439d6b5b98b46af8742cc03ae83543e4d7688c2420f8b/pip-21.3.1-py3-none-any.whl (1.7MB)\n", "\u001b[K |████████████████████████████████| 1.7MB 8.4MB/s eta 0:00:01\n", "\u001b[?25hInstalling collected packages: pip\n", " Found existing installation: pip 19.2.3\n", " Uninstalling pip-19.2.3:\n", " Successfully uninstalled pip-19.2.3\n", "Successfully installed pip-21.3.1\n", "Looking in indexes: https://pypi.tuna.tsinghua.edu.cn/simple\n", "Collecting paddleocr\n", " Downloading https://pypi.tuna.tsinghua.edu.cn/packages/e1/b6/5486e674ce096667dff247b58bf0fb789c2ce17a10e546c2686a2bb07aec/paddleocr-2.3.0.2-py3-none-any.whl (250 kB)\n", " |████████████████████████████████| 250 kB 3.3 MB/s \n", "\u001b[?25hCollecting lmdb\n", " Downloading 
https://pypi.tuna.tsinghua.edu.cn/packages/2e/dd/ada2fd91cd7832979069c556607903f274470c3d3d2274e0a848908272e8/lmdb-1.2.1-cp37-cp37m-manylinux2010_x86_64.whl (299 kB)\n", " |████████████████████████████████| 299 kB 12.8 MB/s \n", "\u001b[?25hCollecting lxml\n", " Downloading https://pypi.tuna.tsinghua.edu.cn/packages/7b/01/16a9b80c8ce4339294bb944f08e157dbfcfbb09ba9031bde4ddf7e3e5499/lxml-4.7.1-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.manylinux_2_24_x86_64.whl (6.4 MB)\n", " |████████████████████████████████| 6.4 MB 52.4 MB/s \n", "\u001b[?25hCollecting python-Levenshtein\n", " Downloading https://pypi.tuna.tsinghua.edu.cn/packages/2a/dc/97f2b63ef0fa1fd78dcb7195aca577804f6b2b51e712516cc0e902a9a201/python-Levenshtein-0.12.2.tar.gz (50 kB)\n", " |████████████████████████████████| 50 kB 1.6 MB/s \n", "\u001b[?25h Preparing metadata (setup.py) ... \u001b[?25ldone\n", "\u001b[?25hCollecting scikit-image\n", " Downloading https://pypi.tuna.tsinghua.edu.cn/packages/9a/44/8f8c7f9c9de7fde70587a656d7df7d056e6f05192a74491f7bc074a724d0/scikit_image-0.19.1-cp37-cp37m-manylinux_2_12_x86_64.manylinux2010_x86_64.whl (13.3 MB)\n", " |████████████████████████████████| 13.3 MB 56.1 MB/s \n", "\u001b[?25hRequirement already satisfied: numpy in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from paddleocr) (1.20.3)\n", "Collecting imgaug==0.4.0\n", " Downloading https://pypi.tuna.tsinghua.edu.cn/packages/66/b1/af3142c4a85cba6da9f4ebb5ff4e21e2616309552caca5e8acefe9840622/imgaug-0.4.0-py2.py3-none-any.whl (948 kB)\n", " |████████████████████████████████| 948 kB 62.9 MB/s \n", "\u001b[?25hCollecting opencv-contrib-python==4.4.0.46\n", " Downloading https://pypi.tuna.tsinghua.edu.cn/packages/08/51/1e0a206dd5c70fea91084e6f43979dc13e8eb175760cc7a105083ec3eb68/opencv_contrib_python-4.4.0.46-cp37-cp37m-manylinux2014_x86_64.whl (55.7 MB)\n", " |████████████████████████████████| 55.7 MB 44 kB/s 0:01\n", "\u001b[?25hCollecting premailer\n", " Downloading 
https://pypi.tuna.tsinghua.edu.cn/packages/b1/07/4e8d94f94c7d41ca5ddf8a9695ad87b888104e2fd41a35546c1dc9ca74ac/premailer-3.10.0-py2.py3-none-any.whl (19 kB)\n", "Collecting shapely\n", " Downloading https://pypi.tuna.tsinghua.edu.cn/packages/ae/20/33ce377bd24d122a4d54e22ae2c445b9b1be8240edb50040b40add950cd9/Shapely-1.8.0-cp37-cp37m-manylinux_2_5_x86_64.manylinux1_x86_64.whl (1.1 MB)\n", " |████████████████████████████████| 1.1 MB 14.5 MB/s \n", "\u001b[?25hRequirement already satisfied: visualdl in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from paddleocr) (2.2.0)\n", "Collecting fasttext==0.9.1\n", " Downloading https://pypi.tuna.tsinghua.edu.cn/packages/10/61/2e01f1397ec533756c1d893c22d9d5ed3fce3a6e4af1976e0d86bb13ea97/fasttext-0.9.1.tar.gz (57 kB)\n", " |████████████████████████████████| 57 kB 9.0 MB/s \n", "\u001b[?25h Preparing metadata (setup.py) ... \u001b[?25ldone\n", "\u001b[?25hRequirement already satisfied: cython in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from paddleocr) (0.29)\n", "Requirement already satisfied: openpyxl in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from paddleocr) (3.0.5)\n", "Collecting pyclipper\n", " Downloading https://pypi.tuna.tsinghua.edu.cn/packages/c5/fa/2c294127e4f88967149a68ad5b3e43636e94e3721109572f8f17ab15b772/pyclipper-1.3.0.post2-cp37-cp37m-manylinux_2_5_x86_64.manylinux1_x86_64.whl (603 kB)\n", " |████████████████████████████████| 603 kB 7.6 MB/s \n", "\u001b[?25hRequirement already satisfied: tqdm in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from paddleocr) (4.36.1)\n", "Collecting pybind11>=2.2\n", " Using cached https://pypi.tuna.tsinghua.edu.cn/packages/a8/3b/fc246e1d4c7547a7a07df830128e93c6215e9b93dcb118b2a47a70726153/pybind11-2.8.1-py2.py3-none-any.whl (208 kB)\n", "Requirement already satisfied: setuptools>=0.7.0 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from 
fasttext==0.9.1->paddleocr) (56.2.0)\n", "Requirement already satisfied: Pillow in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from imgaug==0.4.0->paddleocr) (7.1.2)\n", "Requirement already satisfied: imageio in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from imgaug==0.4.0->paddleocr) (2.6.1)\n", "Requirement already satisfied: scipy in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from imgaug==0.4.0->paddleocr) (1.6.3)\n", "Requirement already satisfied: opencv-python in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from imgaug==0.4.0->paddleocr) (4.1.1.26)\n", "Requirement already satisfied: matplotlib in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from imgaug==0.4.0->paddleocr) (2.2.3)\n", "Requirement already satisfied: six in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from imgaug==0.4.0->paddleocr) (1.15.0)\n", "Collecting PyWavelets>=1.1.1\n", " Downloading https://pypi.tuna.tsinghua.edu.cn/packages/a1/9c/564511b6e1c4e1d835ed2d146670436036960d09339a8fa2921fe42dad08/PyWavelets-1.2.0-cp37-cp37m-manylinux_2_5_x86_64.manylinux1_x86_64.whl (6.1 MB)\n", " |████████████████████████████████| 6.1 MB 3.8 MB/s \n", "\u001b[?25hRequirement already satisfied: packaging>=20.0 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from scikit-image->paddleocr) (20.9)\n", "Requirement already satisfied: networkx>=2.2 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from scikit-image->paddleocr) (2.4)\n", "Collecting tifffile>=2019.7.26\n", " Downloading https://pypi.tuna.tsinghua.edu.cn/packages/d8/38/85ae5ed77598ca90558c17a2f79ddaba33173b31cf8d8f545d34d9134f0d/tifffile-2021.11.2-py3-none-any.whl (178 kB)\n", " |████████████████████████████████| 178 kB 7.1 MB/s \n", "\u001b[?25hRequirement already satisfied: et-xmlfile in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from 
openpyxl->paddleocr) (1.0.1)\n", "Requirement already satisfied: jdcal in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from openpyxl->paddleocr) (1.4.1)\n", "Requirement already satisfied: requests in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from premailer->paddleocr) (2.22.0)\n", "Collecting cssselect\n", " Downloading https://pypi.tuna.tsinghua.edu.cn/packages/3b/d4/3b5c17f00cce85b9a1e6f91096e1cc8e8ede2e1be8e96b87ce1ed09e92c5/cssselect-1.1.0-py2.py3-none-any.whl (16 kB)\n", "Collecting cssutils\n", " Downloading https://pypi.tuna.tsinghua.edu.cn/packages/24/c4/9db28fe567612896d360ab28ad02ee8ae107d0e92a22db39affd3fba6212/cssutils-2.3.0-py3-none-any.whl (404 kB)\n", " |████████████████████████████████| 404 kB 134 kB/s \n", "\u001b[?25hRequirement already satisfied: cachetools in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from premailer->paddleocr) (4.0.0)\n", "Requirement already satisfied: pre-commit in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from visualdl->paddleocr) (1.21.0)\n", "Requirement already satisfied: Flask-Babel>=1.0.0 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from visualdl->paddleocr) (1.0.0)\n", "Requirement already satisfied: flask>=1.1.1 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from visualdl->paddleocr) (1.1.1)\n", "Requirement already satisfied: flake8>=3.7.9 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from visualdl->paddleocr) (3.8.2)\n", "Requirement already satisfied: shellcheck-py in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from visualdl->paddleocr) (0.7.1.1)\n", "Requirement already satisfied: pandas in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from visualdl->paddleocr) (1.1.5)\n", "Requirement already satisfied: bce-python-sdk in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from 
visualdl->paddleocr) (0.8.53)\n", "Requirement already satisfied: protobuf>=3.11.0 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from visualdl->paddleocr) (3.14.0)\n", "Requirement already satisfied: pyflakes<2.3.0,>=2.2.0 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from flake8>=3.7.9->visualdl->paddleocr) (2.2.0)\n", "Requirement already satisfied: pycodestyle<2.7.0,>=2.6.0a1 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from flake8>=3.7.9->visualdl->paddleocr) (2.6.0)\n", "Requirement already satisfied: importlib-metadata in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from flake8>=3.7.9->visualdl->paddleocr) (0.23)\n", "Requirement already satisfied: mccabe<0.7.0,>=0.6.0 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from flake8>=3.7.9->visualdl->paddleocr) (0.6.1)\n", "Requirement already satisfied: itsdangerous>=0.24 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from flask>=1.1.1->visualdl->paddleocr) (1.1.0)\n", "Requirement already satisfied: Jinja2>=2.10.1 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from flask>=1.1.1->visualdl->paddleocr) (2.11.0)\n", "Requirement already satisfied: click>=5.1 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from flask>=1.1.1->visualdl->paddleocr) (7.0)\n", "Requirement already satisfied: Werkzeug>=0.15 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from flask>=1.1.1->visualdl->paddleocr) (0.16.0)\n", "Requirement already satisfied: pytz in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from Flask-Babel>=1.0.0->visualdl->paddleocr) (2019.3)\n", "Requirement already satisfied: Babel>=2.3 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from Flask-Babel>=1.0.0->visualdl->paddleocr) (2.8.0)\n", "Requirement already satisfied: decorator>=4.3.0 in 
/opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from networkx>=2.2->scikit-image->paddleocr) (4.4.2)\n", "Requirement already satisfied: pyparsing>=2.0.2 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from packaging>=20.0->scikit-image->paddleocr) (2.4.2)\n", "Requirement already satisfied: pycryptodome>=3.8.0 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from bce-python-sdk->visualdl->paddleocr) (3.9.9)\n", "Requirement already satisfied: future>=0.6.0 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from bce-python-sdk->visualdl->paddleocr) (0.18.0)\n", "Requirement already satisfied: kiwisolver>=1.0.1 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from matplotlib->imgaug==0.4.0->paddleocr) (1.1.0)\n", "Requirement already satisfied: python-dateutil>=2.1 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from matplotlib->imgaug==0.4.0->paddleocr) (2.8.0)\n", "Requirement already satisfied: cycler>=0.10 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from matplotlib->imgaug==0.4.0->paddleocr) (0.10.0)\n", "Requirement already satisfied: aspy.yaml in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from pre-commit->visualdl->paddleocr) (1.3.0)\n", "Requirement already satisfied: virtualenv>=15.2 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from pre-commit->visualdl->paddleocr) (16.7.9)\n", "Requirement already satisfied: nodeenv>=0.11.1 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from pre-commit->visualdl->paddleocr) (1.3.4)\n", "Requirement already satisfied: pyyaml in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from pre-commit->visualdl->paddleocr) (5.1.2)\n", "Requirement already satisfied: cfgv>=2.0.0 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from pre-commit->visualdl->paddleocr) 
(2.0.1)\n", "Requirement already satisfied: identify>=1.0.0 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from pre-commit->visualdl->paddleocr) (1.4.10)\n", "Requirement already satisfied: toml in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from pre-commit->visualdl->paddleocr) (0.10.0)\n", "Requirement already satisfied: urllib3!=1.25.0,!=1.25.1,<1.26,>=1.21.1 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from requests->premailer->paddleocr) (1.25.6)\n", "Requirement already satisfied: chardet<3.1.0,>=3.0.2 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from requests->premailer->paddleocr) (3.0.4)\n", "Requirement already satisfied: idna<2.9,>=2.5 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from requests->premailer->paddleocr) (2.8)\n", "Requirement already satisfied: certifi>=2017.4.17 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from requests->premailer->paddleocr) (2019.9.11)\n", "Requirement already satisfied: MarkupSafe>=0.23 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from Jinja2>=2.10.1->flask>=1.1.1->visualdl->paddleocr) (1.1.1)\n", "Requirement already satisfied: zipp>=0.5 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from importlib-metadata->flake8>=3.7.9->visualdl->paddleocr) (3.6.0)\n", "Building wheels for collected packages: fasttext, python-Levenshtein\n", " Building wheel for fasttext (setup.py) ... \u001b[?25ldone\n", "\u001b[?25h Created wheel for fasttext: filename=fasttext-0.9.1-cp37-cp37m-linux_x86_64.whl size=2584156 sha256=acb4d4fde73d31c7dfdd2ae3de0da25a558c34c672d4904e6a5c4279185fe5af\n", " Stored in directory: /home/aistudio/.cache/pip/wheels/a1/cb/b3/a25a8ce16c1a4ff102c1e40d6eaa4dfc9d5695b92d57331b36\n", " Building wheel for python-Levenshtein (setup.py) ... 
\u001b[?25ldone\n", "\u001b[?25h Created wheel for python-Levenshtein: filename=python_Levenshtein-0.12.2-cp37-cp37m-linux_x86_64.whl size=171687 sha256=56b4a2de4349a05004121050df68b488ffd253dcc59187ca07b89b62d40c0218\n", " Stored in directory: /home/aistudio/.cache/pip/wheels/38/b9/a4/3729726160fb103833de468adb5ce019b58543ae41d0b0e446\n", "Successfully built fasttext python-Levenshtein\n", "Installing collected packages: tifffile, PyWavelets, shapely, scikit-image, pybind11, lxml, cssutils, cssselect, python-Levenshtein, pyclipper, premailer, opencv-contrib-python, lmdb, imgaug, fasttext, paddleocr\n", "Successfully installed PyWavelets-1.2.0 cssselect-1.1.0 cssutils-2.3.0 fasttext-0.9.1 imgaug-0.4.0 lmdb-1.2.1 lxml-4.7.1 opencv-contrib-python-4.4.0.46 paddleocr-2.3.0.2 premailer-3.10.0 pybind11-2.8.1 pyclipper-1.3.0.post2 python-Levenshtein-0.12.2 scikit-image-0.19.1 shapely-1.8.0 tifffile-2021.11.2\n" ] } ], "source": [ "# Install the GPU build of PaddlePaddle\n", "!pip install paddlepaddle-gpu\n", "# Upgrade pip, then install the paddleocr whl package\n", "! pip install -U pip\n", "! 
pip install paddleocr" ] }, { "cell_type": "markdown", "metadata": { "collapsed": false }, "source": [ "### 1.2 Quickly Predict Text Content\n", "\n", "The paddleocr whl package automatically downloads the ppocr lightweight model and uses it as the default model.\n", "\n", "The following shows how to run recognition inference with the whl package:\n", "\n", "Test image:\n", "\n", "![](https://ai-studio-static-online.cdn.bcebos.com/531d9b3aff45449893b33bcb5dd13971057fcb4038f045578b3abd99fa3a96f2)" ] }, { "cell_type": "code", "execution_count": 5, "metadata": { "collapsed": false }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "[2021/12/23 20:28:44] root WARNING: version 2.1 not support cls models, use version 2.0 instead\n", "download https://paddleocr.bj.bcebos.com/PP-OCRv2/chinese/ch_PP-OCRv2_det_infer.tar to /home/aistudio/.paddleocr/2.2.1/ocr/det/ch/ch_PP-OCRv2_det_infer/ch_PP-OCRv2_det_infer.tar\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "/opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages/skimage/morphology/_skeletonize.py:241: DeprecationWarning: `np.bool` is a deprecated alias for the builtin `bool`. To silence this warning, use `bool` by itself. Doing this will not modify any behavior and is safe. If you specifically wanted the numpy scalar type, use `np.bool_` here.\n", "Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations\n", " 0, 1, 1, 0, 0, 1, 0, 0, 0], dtype=np.bool)\n", "/opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages/skimage/morphology/_skeletonize.py:256: DeprecationWarning: `np.bool` is a deprecated alias for the builtin `bool`. To silence this warning, use `bool` by itself. Doing this will not modify any behavior and is safe. 
If you specifically wanted the numpy scalar type, use `np.bool_` here.\n", "Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations\n", " 0, 0, 0, 0, 0, 0, 0, 0, 0], dtype=np.bool)\n", " 0%| | 0.00/3.19M [00:00\n", "\n", "\n", "### 2.2 Algorithm Details\n", "\n", "The CRNN network architecture is shown below. From bottom to top it consists of three parts: the convolutional layers, the recurrent layers, and the transcription layer:\n", "\n", "
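\n", "\n", "1) backbone:\n", "\n", "A convolutional network serves as the underlying backbone and extracts a feature sequence from the input image. Because convolution, max-pooling, and element-wise activation functions all operate on local regions, they are translation invariant. Each column of the feature map therefore corresponds to a rectangular region of the original image (its receptive field), and these regions appear in the same left-to-right order as their corresponding columns on the feature map. A CNN on its own needs the input image scaled to a fixed size to match its fixed input dimensions, so it is poorly suited to sequence-like objects whose length varies widely. To better support variable-length sequences, CRNN feeds the feature vectors output by the last layer of the backbone into the RNN layer, which converts them into sequence features.\n", "\n", "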
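\n", "\n", "2) neck: \n", "\n", "The recurrent layer builds a recurrent network on top of the convolutional network, converts the image features into sequence features, and predicts a label distribution for each frame.\n", "First, an RNN has a strong ability to capture contextual information within a sequence. Using contextual cues for image-based sequence recognition is more effective than treating each symbol independently. In scene text recognition, for example, a wide character may need several consecutive frames to be fully described, and some ambiguous characters are easier to distinguish when their context is observed. Second, an RNN can back-propagate error differentials to the convolutional layers, so the whole network can be trained as one. Third, an RNN can operate on sequences of arbitrary length, which handles text images of varying width. CRNN uses a two-layer LSTM as the recurrent layer, which mitigates the vanishing and exploding gradient problems that arise when training on long sequences.\n", "\n", "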
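\n", "\n", "\n", "3) head: \n", "\n", "The transcription layer uses a fully connected network and a softmax activation to convert the per-frame predictions into the final label sequence. CTC loss is then used to train the CNN and RNN jointly without requiring sequence alignment. CTC has its own mechanism for merging sequences: after the LSTM outputs a sequence, every time step is classified to obtain a prediction. Several consecutive time steps may map to the same class, so identical adjacent results have to be merged. To avoid collapsing characters that are genuinely repeated in the text, CTC introduces a `blank` character that is inserted between repeated characters.\n", "\n", "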
\n", "\n" ] }, { "cell_type": "markdown", "metadata": { "collapsed": false }, "source": [ "### 2.3 Code Implementation\n", "\n", "The overall network structure is very compact and the code is relatively simple, so the modules can be built one by one in the order of the inference pipeline. This section covers: data input, building the backbone, building the neck, and building the head.\n", "\n", "**[Data Input]**\n", "\n", "Before being fed to the network, the data must be resized to a uniform size (3, 32, 320) and normalized. The data augmentation needed during training is omitted here; the required preprocessing steps are shown for a single image ([source](https://github.com/PaddlePaddle/PaddleOCR/blob/release/2.3/ppocr/data/imaug/rec_img_aug.py#L126)):\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "import cv2\n", "import math\n", "import numpy as np\n", "\n", "def resize_norm_img(img):\n", "    \"\"\"\n", "    Resize and normalize the image\n", "    :param img: input image\n", "    \"\"\"\n", "\n", "    # Default input size\n", "    imgC = 3\n", "    imgH = 32\n", "    imgW = 320\n", "\n", "    # Actual height and width of the image\n", "    h, w = img.shape[:2]\n", "    # Actual aspect ratio (width / height)\n", "    ratio = w / float(h)\n", "\n", "    # Resize while keeping the aspect ratio\n", "    if math.ceil(imgH * ratio) > imgW:\n", "        # If wider than the default width, clamp the width to imgW\n", "        resized_w = imgW\n", "    else:\n", "        # Otherwise use the width that preserves the aspect ratio\n", "        resized_w = int(math.ceil(imgH * ratio))\n", "    # Resize\n", "    resized_image = cv2.resize(img, (resized_w, imgH))\n", "    resized_image = resized_image.astype('float32')\n", "    # Normalize to [-1, 1]\n", "    resized_image = resized_image.transpose((2, 0, 1)) / 255\n", "    resized_image -= 0.5\n", "    resized_image /= 0.5\n", "    # Zero-pad the positions where the width falls short\n", "    padding_im = np.zeros((imgC, imgH, imgW), dtype=np.float32)\n", "    padding_im[:, :, 0:resized_w] = resized_image\n", "    # Transpose the padded image for visualization\n", "    draw_img = padding_im.transpose((1,2,0))\n", "    return padding_im, draw_img\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).\n" ] }, { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "import matplotlib.pyplot as plt\n", "# Read the image\n", "raw_img = cv2.imread(\"/home/aistudio/work/word_1.png\")\n", "plt.figure()\n", "plt.subplot(2,1,1)\n", "# Show the original image\n", "plt.imshow(raw_img)\n", "# Resize and normalize\n", "padding_im, draw_img = resize_norm_img(raw_img)\n", "plt.subplot(2,1,2)\n", "# Show the network input image\n", "plt.imshow(draw_img)\n", "plt.show()" ] }, { "cell_type": "markdown", "metadata": { "collapsed": false }, "source": [ "**[Network Structure]**\n", "\n", "* backbone\n", "\n", "PaddleOCR uses MobileNetV3 as the backbone. The modules are assembled in the same order as the network structure. First define the shared modules used in the network ([source](https://github.com/PaddlePaddle/PaddleOCR/blob/release/2.3/ppocr/modeling/backbones/rec_mobilenet_v3.py)): ConvBNLayer, SEModule, ResidualUnit, make_divisible" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "import paddle\n", "import paddle.nn as nn\n", "import paddle.nn.functional as F\n", "\n", "class ConvBNLayer(nn.Layer):\n", "    def __init__(self,\n", "                 in_channels,\n", "                 out_channels,\n", "                 kernel_size,\n", "                 stride,\n", "                 padding,\n", "                 groups=1,\n", "                 if_act=True,\n", "                 act=None):\n", "        \"\"\"\n", "        Conv + BatchNorm layer\n", "        :param in_channels: number of input channels\n", "        :param out_channels: number of output channels\n", "        :param kernel_size: convolution kernel size\n", "        :param stride: stride\n", "        :param padding: padding size\n", "        :param groups: number of groups of the 2D convolution\n", "        :param if_act: whether to apply an activation function\n", "        :param act: name of the activation function\n", "        \"\"\"\n", "        super(ConvBNLayer, self).__init__()\n", "        self.if_act = if_act\n", "        self.act = act\n", "        self.conv = nn.Conv2D(\n", "            in_channels=in_channels,\n", "            out_channels=out_channels,\n", "            kernel_size=kernel_size,\n", "            stride=stride,\n", "            padding=padding,\n", "            groups=groups,\n", "            bias_attr=False)\n", "\n", "        self.bn = nn.BatchNorm(num_channels=out_channels, act=None)\n", "\n", "    def forward(self, x):\n", "        # Convolution\n", "        x = self.conv(x)\n", "        # Batch normalization\n", "        x = self.bn(x)\n", "        # Apply the activation function if enabled\n", "        if self.if_act:\n", "            if self.act == \"relu\":\n", "                x = F.relu(x)\n", "            elif 
self.act == \"hardswish\":\n", "                x = F.hardswish(x)\n", "            else:\n", "                print(\"The activation function({}) is selected incorrectly.\".\n", "                      format(self.act))\n", "                exit()\n", "        return x\n", "\n", "class SEModule(nn.Layer):\n", "    def __init__(self, in_channels, reduction=4):\n", "        \"\"\"\n", "        SE (squeeze-and-excitation) module\n", "        :param in_channels: number of input channels\n", "        :param reduction: channel reduction ratio\n", "        \"\"\"\n", "        super(SEModule, self).__init__()\n", "        self.avg_pool = nn.AdaptiveAvgPool2D(1)\n", "        self.conv1 = nn.Conv2D(\n", "            in_channels=in_channels,\n", "            out_channels=in_channels // reduction,\n", "            kernel_size=1,\n", "            stride=1,\n", "            padding=0)\n", "        self.conv2 = nn.Conv2D(\n", "            in_channels=in_channels // reduction,\n", "            out_channels=in_channels,\n", "            kernel_size=1,\n", "            stride=1,\n", "            padding=0)\n", "\n", "    def forward(self, inputs):\n", "        # Global average pooling\n", "        outputs = self.avg_pool(inputs)\n", "        # First convolution\n", "        outputs = self.conv1(outputs)\n", "        # ReLU activation\n", "        outputs = F.relu(outputs)\n", "        # Second convolution\n", "        outputs = self.conv2(outputs)\n", "        # hardsigmoid activation\n", "        outputs = F.hardsigmoid(outputs, slope=0.2, offset=0.5)\n", "        return inputs * outputs\n", "\n", "\n", "class ResidualUnit(nn.Layer):\n", "    def __init__(self,\n", "                 in_channels,\n", "                 mid_channels,\n", "                 out_channels,\n", "                 kernel_size,\n", "                 stride,\n", "                 use_se,\n", "                 act=None):\n", "        \"\"\"\n", "        Residual unit\n", "        :param in_channels: number of input channels\n", "        :param mid_channels: number of intermediate channels\n", "        :param out_channels: number of output channels\n", "        :param kernel_size: convolution kernel size\n", "        :param stride: stride\n", "        :param use_se: whether to use the SE module\n", "        :param act: name of the activation function\n", "        \"\"\"\n", "        super(ResidualUnit, self).__init__()\n", "        self.if_shortcut = stride == 1 and in_channels == out_channels\n", "        self.if_se = use_se\n", "\n", "        self.expand_conv = ConvBNLayer(\n", "            in_channels=in_channels,\n", "            out_channels=mid_channels,\n", "            kernel_size=1,\n", "            stride=1,\n", "            padding=0,\n", "            if_act=True,\n", "            act=act)\n", "        self.bottleneck_conv = ConvBNLayer(\n", "            in_channels=mid_channels,\n", "            
out_channels=mid_channels,\n", "            kernel_size=kernel_size,\n", "            stride=stride,\n", "            padding=int((kernel_size - 1) // 2),\n", "            groups=mid_channels,\n", "            if_act=True,\n", "            act=act)\n", "        if self.if_se:\n", "            self.mid_se = SEModule(mid_channels)\n", "        self.linear_conv = ConvBNLayer(\n", "            in_channels=mid_channels,\n", "            out_channels=out_channels,\n", "            kernel_size=1,\n", "            stride=1,\n", "            padding=0,\n", "            if_act=False,\n", "            act=None)\n", "\n", "    def forward(self, inputs):\n", "        x = self.expand_conv(inputs)\n", "        x = self.bottleneck_conv(x)\n", "        if self.if_se:\n", "            x = self.mid_se(x)\n", "        x = self.linear_conv(x)\n", "        if self.if_shortcut:\n", "            x = paddle.add(inputs, x)\n", "        return x\n", "\n", "\n", "def make_divisible(v, divisor=8, min_value=None):\n", "    \"\"\"\n", "    Round v to a multiple of divisor (8 by default)\n", "    \"\"\"\n", "    if min_value is None:\n", "        min_value = divisor\n", "    new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)\n", "    if new_v < 0.9 * v:\n", "        new_v += divisor\n", "    return new_v\n" ] }, { "cell_type": "markdown", "metadata": { "collapsed": false }, "source": [ "Build the backbone from the shared modules" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "class MobileNetV3(nn.Layer):\n", "    def __init__(self,\n", "                 in_channels=3,\n", "                 model_name='small',\n", "                 scale=0.5,\n", "                 small_stride=None,\n", "                 disable_se=False,\n", "                 **kwargs):\n", "        super(MobileNetV3, self).__init__()\n", "        self.disable_se = disable_se\n", "\n", "        small_stride = [1, 2, 2, 2]\n", "\n", "        if model_name == \"small\":\n", "            cfg = [\n", "                # k, exp, c, se, nl, s,\n", "                [3, 16, 16, True, 'relu', (small_stride[0], 1)],\n", "                [3, 72, 24, False, 'relu', (small_stride[1], 1)],\n", "                [3, 88, 24, False, 'relu', 1],\n", "                [5, 96, 40, True, 'hardswish', (small_stride[2], 1)],\n", "                [5, 240, 40, True, 'hardswish', 1],\n", "                [5, 240, 40, True, 'hardswish', 1],\n", "                [5, 120, 48, True, 'hardswish', 1],\n", "                [5, 144, 48, True, 'hardswish', 1],\n", "                [5, 288, 96, 
True, 'hardswish', (small_stride[3], 1)],\n", "                [5, 576, 96, True, 'hardswish', 1],\n", "                [5, 576, 96, True, 'hardswish', 1],\n", "            ]\n", "            cls_ch_squeeze = 576\n", "        else:\n", "            raise NotImplementedError(\"mode[\" + model_name +\n", "                                      \"_model] is not implemented!\")\n", "\n", "        supported_scale = [0.35, 0.5, 0.75, 1.0, 1.25]\n", "        assert scale in supported_scale, \\\n", "            \"supported scales are {} but input scale is {}\".format(supported_scale, scale)\n", "\n", "        inplanes = 16\n", "        # conv1\n", "        self.conv1 = ConvBNLayer(\n", "            in_channels=in_channels,\n", "            out_channels=make_divisible(inplanes * scale),\n", "            kernel_size=3,\n", "            stride=2,\n", "            padding=1,\n", "            groups=1,\n", "            if_act=True,\n", "            act='hardswish')\n", "        i = 0\n", "        block_list = []\n", "        inplanes = make_divisible(inplanes * scale)\n", "        for (k, exp, c, se, nl, s) in cfg:\n", "            se = se and not self.disable_se\n", "            block_list.append(\n", "                ResidualUnit(\n", "                    in_channels=inplanes,\n", "                    mid_channels=make_divisible(scale * exp),\n", "                    out_channels=make_divisible(scale * c),\n", "                    kernel_size=k,\n", "                    stride=s,\n", "                    use_se=se,\n", "                    act=nl))\n", "            inplanes = make_divisible(scale * c)\n", "            i += 1\n", "        self.blocks = nn.Sequential(*block_list)\n", "\n", "        self.conv2 = ConvBNLayer(\n", "            in_channels=inplanes,\n", "            out_channels=make_divisible(scale * cls_ch_squeeze),\n", "            kernel_size=1,\n", "            stride=1,\n", "            padding=0,\n", "            groups=1,\n", "            if_act=True,\n", "            act='hardswish')\n", "\n", "        self.pool = nn.MaxPool2D(kernel_size=2, stride=2, padding=0)\n", "        self.out_channels = make_divisible(scale * cls_ch_squeeze)\n", "\n", "    def forward(self, x):\n", "        x = self.conv1(x)\n", "        x = self.blocks(x)\n", "        x = self.conv2(x)\n", "        x = self.pool(x)\n", "        return x\n" ] }, { "cell_type": "markdown", "metadata": { "collapsed": false }, "source": [ "This completes the backbone definition. The whole network structure can be visualized with paddle.summary:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [ { "name": "stdout", 
"output_type": "stream", "text": [ "-------------------------------------------------------------------------------\n", " Layer (type) Input Shape Output Shape Param # \n", "===============================================================================\n", " Conv2D-1 [[1, 3, 32, 320]] [1, 8, 16, 160] 216 \n", " BatchNorm-1 [[1, 8, 16, 160]] [1, 8, 16, 160] 32 \n", " ConvBNLayer-1 [[1, 3, 32, 320]] [1, 8, 16, 160] 0 \n", " Conv2D-2 [[1, 8, 16, 160]] [1, 8, 16, 160] 64 \n", " BatchNorm-2 [[1, 8, 16, 160]] [1, 8, 16, 160] 32 \n", " ConvBNLayer-2 [[1, 8, 16, 160]] [1, 8, 16, 160] 0 \n", " Conv2D-3 [[1, 8, 16, 160]] [1, 8, 16, 160] 72 \n", " BatchNorm-3 [[1, 8, 16, 160]] [1, 8, 16, 160] 32 \n", " ConvBNLayer-3 [[1, 8, 16, 160]] [1, 8, 16, 160] 0 \n", "AdaptiveAvgPool2D-1 [[1, 8, 16, 160]] [1, 8, 1, 1] 0 \n", " Conv2D-4 [[1, 8, 1, 1]] [1, 2, 1, 1] 18 \n", " Conv2D-5 [[1, 2, 1, 1]] [1, 8, 1, 1] 24 \n", " SEModule-1 [[1, 8, 16, 160]] [1, 8, 16, 160] 0 \n", " Conv2D-6 [[1, 8, 16, 160]] [1, 8, 16, 160] 64 \n", " BatchNorm-4 [[1, 8, 16, 160]] [1, 8, 16, 160] 32 \n", " ConvBNLayer-4 [[1, 8, 16, 160]] [1, 8, 16, 160] 0 \n", " ResidualUnit-1 [[1, 8, 16, 160]] [1, 8, 16, 160] 0 \n", " Conv2D-7 [[1, 8, 16, 160]] [1, 40, 16, 160] 320 \n", " BatchNorm-5 [[1, 40, 16, 160]] [1, 40, 16, 160] 160 \n", " ConvBNLayer-5 [[1, 8, 16, 160]] [1, 40, 16, 160] 0 \n", " Conv2D-8 [[1, 40, 16, 160]] [1, 40, 8, 160] 360 \n", " BatchNorm-6 [[1, 40, 8, 160]] [1, 40, 8, 160] 160 \n", " ConvBNLayer-6 [[1, 40, 16, 160]] [1, 40, 8, 160] 0 \n", " Conv2D-9 [[1, 40, 8, 160]] [1, 16, 8, 160] 640 \n", " BatchNorm-7 [[1, 16, 8, 160]] [1, 16, 8, 160] 64 \n", " ConvBNLayer-7 [[1, 40, 8, 160]] [1, 16, 8, 160] 0 \n", " ResidualUnit-2 [[1, 8, 16, 160]] [1, 16, 8, 160] 0 \n", " Conv2D-10 [[1, 16, 8, 160]] [1, 48, 8, 160] 768 \n", " BatchNorm-8 [[1, 48, 8, 160]] [1, 48, 8, 160] 192 \n", " ConvBNLayer-8 [[1, 16, 8, 160]] [1, 48, 8, 160] 0 \n", " Conv2D-11 [[1, 48, 8, 160]] [1, 48, 8, 160] 432 \n", " BatchNorm-9 [[1, 
48, 8, 160]] [1, 48, 8, 160] 192 \n", " ConvBNLayer-9 [[1, 48, 8, 160]] [1, 48, 8, 160] 0 \n", " Conv2D-12 [[1, 48, 8, 160]] [1, 16, 8, 160] 768 \n", " BatchNorm-10 [[1, 16, 8, 160]] [1, 16, 8, 160] 64 \n", " ConvBNLayer-10 [[1, 48, 8, 160]] [1, 16, 8, 160] 0 \n", " ResidualUnit-3 [[1, 16, 8, 160]] [1, 16, 8, 160] 0 \n", " Conv2D-13 [[1, 16, 8, 160]] [1, 48, 8, 160] 768 \n", " BatchNorm-11 [[1, 48, 8, 160]] [1, 48, 8, 160] 192 \n", " ConvBNLayer-11 [[1, 16, 8, 160]] [1, 48, 8, 160] 0 \n", " Conv2D-14 [[1, 48, 8, 160]] [1, 48, 4, 160] 1,200 \n", " BatchNorm-12 [[1, 48, 4, 160]] [1, 48, 4, 160] 192 \n", " ConvBNLayer-12 [[1, 48, 8, 160]] [1, 48, 4, 160] 0 \n", "AdaptiveAvgPool2D-2 [[1, 48, 4, 160]] [1, 48, 1, 1] 0 \n", " Conv2D-15 [[1, 48, 1, 1]] [1, 12, 1, 1] 588 \n", " Conv2D-16 [[1, 12, 1, 1]] [1, 48, 1, 1] 624 \n", " SEModule-2 [[1, 48, 4, 160]] [1, 48, 4, 160] 0 \n", " Conv2D-17 [[1, 48, 4, 160]] [1, 24, 4, 160] 1,152 \n", " BatchNorm-13 [[1, 24, 4, 160]] [1, 24, 4, 160] 96 \n", " ConvBNLayer-13 [[1, 48, 4, 160]] [1, 24, 4, 160] 0 \n", " ResidualUnit-4 [[1, 16, 8, 160]] [1, 24, 4, 160] 0 \n", " Conv2D-18 [[1, 24, 4, 160]] [1, 120, 4, 160] 2,880 \n", " BatchNorm-14 [[1, 120, 4, 160]] [1, 120, 4, 160] 480 \n", " ConvBNLayer-14 [[1, 24, 4, 160]] [1, 120, 4, 160] 0 \n", " Conv2D-19 [[1, 120, 4, 160]] [1, 120, 4, 160] 3,000 \n", " BatchNorm-15 [[1, 120, 4, 160]] [1, 120, 4, 160] 480 \n", " ConvBNLayer-15 [[1, 120, 4, 160]] [1, 120, 4, 160] 0 \n", "AdaptiveAvgPool2D-3 [[1, 120, 4, 160]] [1, 120, 1, 1] 0 \n", " Conv2D-20 [[1, 120, 1, 1]] [1, 30, 1, 1] 3,630 \n", " Conv2D-21 [[1, 30, 1, 1]] [1, 120, 1, 1] 3,720 \n", " SEModule-3 [[1, 120, 4, 160]] [1, 120, 4, 160] 0 \n", " Conv2D-22 [[1, 120, 4, 160]] [1, 24, 4, 160] 2,880 \n", " BatchNorm-16 [[1, 24, 4, 160]] [1, 24, 4, 160] 96 \n", " ConvBNLayer-16 [[1, 120, 4, 160]] [1, 24, 4, 160] 0 \n", " ResidualUnit-5 [[1, 24, 4, 160]] [1, 24, 4, 160] 0 \n", " Conv2D-23 [[1, 24, 4, 160]] [1, 120, 4, 160] 2,880 \n", " BatchNorm-17 
[[1, 120, 4, 160]] [1, 120, 4, 160] 480 \n", " ConvBNLayer-17 [[1, 24, 4, 160]] [1, 120, 4, 160] 0 \n", " Conv2D-24 [[1, 120, 4, 160]] [1, 120, 4, 160] 3,000 \n", " BatchNorm-18 [[1, 120, 4, 160]] [1, 120, 4, 160] 480 \n", " ConvBNLayer-18 [[1, 120, 4, 160]] [1, 120, 4, 160] 0 \n", "AdaptiveAvgPool2D-4 [[1, 120, 4, 160]] [1, 120, 1, 1] 0 \n", " Conv2D-25 [[1, 120, 1, 1]] [1, 30, 1, 1] 3,630 \n", " Conv2D-26 [[1, 30, 1, 1]] [1, 120, 1, 1] 3,720 \n", " SEModule-4 [[1, 120, 4, 160]] [1, 120, 4, 160] 0 \n", " Conv2D-27 [[1, 120, 4, 160]] [1, 24, 4, 160] 2,880 \n", " BatchNorm-19 [[1, 24, 4, 160]] [1, 24, 4, 160] 96 \n", " ConvBNLayer-19 [[1, 120, 4, 160]] [1, 24, 4, 160] 0 \n", " ResidualUnit-6 [[1, 24, 4, 160]] [1, 24, 4, 160] 0 \n", " Conv2D-28 [[1, 24, 4, 160]] [1, 64, 4, 160] 1,536 \n", " BatchNorm-20 [[1, 64, 4, 160]] [1, 64, 4, 160] 256 \n", " ConvBNLayer-20 [[1, 24, 4, 160]] [1, 64, 4, 160] 0 \n", " Conv2D-29 [[1, 64, 4, 160]] [1, 64, 4, 160] 1,600 \n", " BatchNorm-21 [[1, 64, 4, 160]] [1, 64, 4, 160] 256 \n", " ConvBNLayer-21 [[1, 64, 4, 160]] [1, 64, 4, 160] 0 \n", "AdaptiveAvgPool2D-5 [[1, 64, 4, 160]] [1, 64, 1, 1] 0 \n", " Conv2D-30 [[1, 64, 1, 1]] [1, 16, 1, 1] 1,040 \n", " Conv2D-31 [[1, 16, 1, 1]] [1, 64, 1, 1] 1,088 \n", " SEModule-5 [[1, 64, 4, 160]] [1, 64, 4, 160] 0 \n", " Conv2D-32 [[1, 64, 4, 160]] [1, 24, 4, 160] 1,536 \n", " BatchNorm-22 [[1, 24, 4, 160]] [1, 24, 4, 160] 96 \n", " ConvBNLayer-22 [[1, 64, 4, 160]] [1, 24, 4, 160] 0 \n", " ResidualUnit-7 [[1, 24, 4, 160]] [1, 24, 4, 160] 0 \n", " Conv2D-33 [[1, 24, 4, 160]] [1, 72, 4, 160] 1,728 \n", " BatchNorm-23 [[1, 72, 4, 160]] [1, 72, 4, 160] 288 \n", " ConvBNLayer-23 [[1, 24, 4, 160]] [1, 72, 4, 160] 0 \n", " Conv2D-34 [[1, 72, 4, 160]] [1, 72, 4, 160] 1,800 \n", " BatchNorm-24 [[1, 72, 4, 160]] [1, 72, 4, 160] 288 \n", " ConvBNLayer-24 [[1, 72, 4, 160]] [1, 72, 4, 160] 0 \n", "AdaptiveAvgPool2D-6 [[1, 72, 4, 160]] [1, 72, 1, 1] 0 \n", " Conv2D-35 [[1, 72, 1, 1]] [1, 18, 1, 1] 1,314 \n", " 
Conv2D-36 [[1, 18, 1, 1]] [1, 72, 1, 1] 1,368 \n", " SEModule-6 [[1, 72, 4, 160]] [1, 72, 4, 160] 0 \n", " Conv2D-37 [[1, 72, 4, 160]] [1, 24, 4, 160] 1,728 \n", " BatchNorm-25 [[1, 24, 4, 160]] [1, 24, 4, 160] 96 \n", " ConvBNLayer-25 [[1, 72, 4, 160]] [1, 24, 4, 160] 0 \n", " ResidualUnit-8 [[1, 24, 4, 160]] [1, 24, 4, 160] 0 \n", " Conv2D-38 [[1, 24, 4, 160]] [1, 144, 4, 160] 3,456 \n", " BatchNorm-26 [[1, 144, 4, 160]] [1, 144, 4, 160] 576 \n", " ConvBNLayer-26 [[1, 24, 4, 160]] [1, 144, 4, 160] 0 \n", " Conv2D-39 [[1, 144, 4, 160]] [1, 144, 2, 160] 3,600 \n", " BatchNorm-27 [[1, 144, 2, 160]] [1, 144, 2, 160] 576 \n", " ConvBNLayer-27 [[1, 144, 4, 160]] [1, 144, 2, 160] 0 \n", "AdaptiveAvgPool2D-7 [[1, 144, 2, 160]] [1, 144, 1, 1] 0 \n", " Conv2D-40 [[1, 144, 1, 1]] [1, 36, 1, 1] 5,220 \n", " Conv2D-41 [[1, 36, 1, 1]] [1, 144, 1, 1] 5,328 \n", " SEModule-7 [[1, 144, 2, 160]] [1, 144, 2, 160] 0 \n", " Conv2D-42 [[1, 144, 2, 160]] [1, 48, 2, 160] 6,912 \n", " BatchNorm-28 [[1, 48, 2, 160]] [1, 48, 2, 160] 192 \n", " ConvBNLayer-28 [[1, 144, 2, 160]] [1, 48, 2, 160] 0 \n", " ResidualUnit-9 [[1, 24, 4, 160]] [1, 48, 2, 160] 0 \n", " Conv2D-43 [[1, 48, 2, 160]] [1, 288, 2, 160] 13,824 \n", " BatchNorm-29 [[1, 288, 2, 160]] [1, 288, 2, 160] 1,152 \n", " ConvBNLayer-29 [[1, 48, 2, 160]] [1, 288, 2, 160] 0 \n", " Conv2D-44 [[1, 288, 2, 160]] [1, 288, 2, 160] 7,200 \n", " BatchNorm-30 [[1, 288, 2, 160]] [1, 288, 2, 160] 1,152 \n", " ConvBNLayer-30 [[1, 288, 2, 160]] [1, 288, 2, 160] 0 \n", "AdaptiveAvgPool2D-8 [[1, 288, 2, 160]] [1, 288, 1, 1] 0 \n", " Conv2D-45 [[1, 288, 1, 1]] [1, 72, 1, 1] 20,808 \n", " Conv2D-46 [[1, 72, 1, 1]] [1, 288, 1, 1] 21,024 \n", " SEModule-8 [[1, 288, 2, 160]] [1, 288, 2, 160] 0 \n", " Conv2D-47 [[1, 288, 2, 160]] [1, 48, 2, 160] 13,824 \n", " BatchNorm-31 [[1, 48, 2, 160]] [1, 48, 2, 160] 192 \n", " ConvBNLayer-31 [[1, 288, 2, 160]] [1, 48, 2, 160] 0 \n", " ResidualUnit-10 [[1, 48, 2, 160]] [1, 48, 2, 160] 0 \n", " Conv2D-48 [[1, 48, 2, 
160]] [1, 288, 2, 160] 13,824 \n", " BatchNorm-32 [[1, 288, 2, 160]] [1, 288, 2, 160] 1,152 \n", " ConvBNLayer-32 [[1, 48, 2, 160]] [1, 288, 2, 160] 0 \n", " Conv2D-49 [[1, 288, 2, 160]] [1, 288, 2, 160] 7,200 \n", " BatchNorm-33 [[1, 288, 2, 160]] [1, 288, 2, 160] 1,152 \n", " ConvBNLayer-33 [[1, 288, 2, 160]] [1, 288, 2, 160] 0 \n", "AdaptiveAvgPool2D-9 [[1, 288, 2, 160]] [1, 288, 1, 1] 0 \n", " Conv2D-50 [[1, 288, 1, 1]] [1, 72, 1, 1] 20,808 \n", " Conv2D-51 [[1, 72, 1, 1]] [1, 288, 1, 1] 21,024 \n", " SEModule-9 [[1, 288, 2, 160]] [1, 288, 2, 160] 0 \n", " Conv2D-52 [[1, 288, 2, 160]] [1, 48, 2, 160] 13,824 \n", " BatchNorm-34 [[1, 48, 2, 160]] [1, 48, 2, 160] 192 \n", " ConvBNLayer-34 [[1, 288, 2, 160]] [1, 48, 2, 160] 0 \n", " ResidualUnit-11 [[1, 48, 2, 160]] [1, 48, 2, 160] 0 \n", " Conv2D-53 [[1, 48, 2, 160]] [1, 288, 2, 160] 13,824 \n", " BatchNorm-35 [[1, 288, 2, 160]] [1, 288, 2, 160] 1,152 \n", " ConvBNLayer-35 [[1, 48, 2, 160]] [1, 288, 2, 160] 0 \n", " MaxPool2D-1 [[1, 288, 2, 160]] [1, 288, 1, 80] 0 \n", "===============================================================================\n", "Total params: 259,056\n", "Trainable params: 246,736\n", "Non-trainable params: 12,320\n", "-------------------------------------------------------------------------------\n", "Input size (MB): 0.12\n", "Forward/backward pass size (MB): 44.38\n", "Params size (MB): 0.99\n", "Estimated Total Size (MB): 45.48\n", "-------------------------------------------------------------------------------\n", "\n" ] }, { "data": { "text/plain": [ "{'total_params': 259056, 'trainable_params': 246736}" ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Define the network input shape\n", "IMAGE_SHAPE_C = 3\n", "IMAGE_SHAPE_H = 32\n", "IMAGE_SHAPE_W = 320\n", "\n", "\n", "# Print the network structure\n", "paddle.summary(MobileNetV3(),[(1, IMAGE_SHAPE_C, IMAGE_SHAPE_H, IMAGE_SHAPE_W)])" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, 
"outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "backbone output: [1, 288, 1, 80]\n" ] } ], "source": [ "# Feed the image into the backbone\n", "backbone = MobileNetV3()\n", "# Convert the numpy data to a Tensor\n", "input_data = paddle.to_tensor([padding_im])\n", "# Backbone output\n", "feature = backbone(input_data)\n", "# Check the shape of the feature map\n", "print(\"backbone output:\", feature.shape)" ] }, { "cell_type": "markdown", "metadata": { "collapsed": false }, "source": [ "* neck\n", "\n", "The neck converts the visual feature map output by the backbone into a sequence of 1-D feature vectors, feeds it to an LSTM network, and outputs sequence features ([source](https://github.com/PaddlePaddle/PaddleOCR/blob/release/2.3/ppocr/modeling/necks/rnn.py)):" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "class Im2Seq(nn.Layer):\n", "    def __init__(self, in_channels, **kwargs):\n", "        \"\"\"\n", "        Convert image features into sequence features\n", "        :param in_channels: number of input channels\n", "        \"\"\"\n", "        super().__init__()\n", "        self.out_channels = in_channels\n", "\n", "    def forward(self, x):\n", "        B, C, H, W = x.shape\n", "        assert H == 1\n", "        x = x.squeeze(axis=2)\n", "        x = x.transpose([0, 2, 1])  # (NWC)(batch, width, channels)\n", "        return x\n", "\n", "class EncoderWithRNN(nn.Layer):\n", "    def __init__(self, in_channels, hidden_size):\n", "        super(EncoderWithRNN, self).__init__()\n", "        self.out_channels = hidden_size * 2\n", "        self.lstm = nn.LSTM(\n", "            in_channels, hidden_size, direction='bidirectional', num_layers=2)\n", "\n", "    def forward(self, x):\n", "        x, _ = self.lstm(x)\n", "        return x\n", "\n", "\n", "class SequenceEncoder(nn.Layer):\n", "    def __init__(self, in_channels, hidden_size=48, **kwargs):\n", "        \"\"\"\n", "        Sequence encoder\n", "        :param in_channels: number of input channels\n", "        :param hidden_size: size of the hidden layer\n", "        \"\"\"\n", "        super(SequenceEncoder, self).__init__()\n", "        self.encoder_reshape = Im2Seq(in_channels)\n", "\n", "        self.encoder = EncoderWithRNN(\n", "            self.encoder_reshape.out_channels, hidden_size)\n", "        self.out_channels = self.encoder.out_channels\n", "\n", "    def forward(self, x):\n", "        x = self.encoder_reshape(x)\n", "        x = self.encoder(x)\n", "        return x\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "sequence shape: [1, 80, 96]\n" ] } ], "source": [ "neck = SequenceEncoder(in_channels=288)\n", "sequence = neck(feature)\n", "print(\"sequence shape:\", sequence.shape)" ] }, { "cell_type": "markdown", "metadata": { "collapsed": false }, "source": [ "* head\n", "\n", "The prediction head consists of a fully connected layer and softmax. It computes the label probability distribution at each time step of the sequence features. This example only recognizes lowercase English letters and digits, i.e. 26 + 10 = 36 classes ([source](https://github.com/PaddlePaddle/PaddleOCR/blob/release/2.3/ppocr/modeling/heads/rec_ctc_head.py)):" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "class CTCHead(nn.Layer):\n", "    def __init__(self,\n", "                 in_channels,\n", "                 out_channels,\n", "                 **kwargs):\n", "        \"\"\"\n", "        CTC prediction layer\n", "        :param in_channels: number of input channels\n", "        :param out_channels: number of output channels\n", "        \"\"\"\n", "        super(CTCHead, self).__init__()\n", "        self.fc = nn.Linear(\n", "            in_channels,\n", "            out_channels)\n", "\n", "        # Think: what should out_channels be equal to?\n", "        self.out_channels = out_channels\n", "\n", "    def forward(self, x):\n", "        predicts = self.fc(x)\n", "        result = predicts\n", "\n", "        if not self.training:\n", "            predicts = F.softmax(predicts, axis=2)\n", "            result = predicts\n", "\n", "        return result" ] }, { "cell_type": "markdown", "metadata": { "collapsed": false }, "source": [ "With randomly initialized weights the output is meaningless. After softmax we obtain, for each time step, the prediction with the highest probability, where `pred_id` is the predicted label ID and `pred_scores` is the confidence of the prediction:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "predict shape: [1, 80, 37]\n", "pred_id: Tensor(shape=[1, 80], dtype=int64, place=CUDAPlace(0), stop_gradient=False,\n", " [[23, 28, 23, 23, 23, 23, 23, 23, 23, 23, 23, 30, 30, 30, 31, 23, 23, 23, 23, 23, 23, 23, 31, 23, 23, 23, 23, 23, 23, 23, 23, 23, 
23, 23, 23, 23, 23, 23, 23, 23, 23, 23, 23, 23, 23, 23, 23, 23, 23, 23, 23, 23, 23, 23, 23, 23, 23, 23, 23, 23, 23, 23, 23, 23, 23, 23, 23, 23, 23, 23, 23, 23, 23, 23, 23, 23, 23, 23, 23, 5 ]])\n", "pred_scores: Tensor(shape=[1, 80], dtype=float32, place=CUDAPlace(0), stop_gradient=False,\n", " [[0.03683758, 0.03368053, 0.03604801, 0.03504696, 0.03696444, 0.03597261, 0.03925638, 0.03650934, 0.03873367, 0.03572492, 0.03543066, 0.03618268, 0.03805700, 0.03496549, 0.03329032, 0.03565763, 0.03846950, 0.03922413, 0.03970327, 0.03638541, 0.03572393, 0.03618102, 0.03565401, 0.03636984, 0.03691722, 0.03718850, 0.03623354, 0.03877943, 0.03731697, 0.03563465, 0.03447339, 0.03365586, 0.03312979, 0.03285240, 0.03273271, 0.03269565, 0.03269779, 0.03271412, 0.03273287, 0.03274929, 0.03276210, 0.03277146, 0.03277802, 0.03278249, 0.03278547, 0.03278742, 0.03278869, 0.03278949, 0.03279000, 0.03279032, 0.03279052, 0.03279064, 0.03279071, 0.03279077, 0.03279081, 0.03279087, 0.03279094, 0.03279106, 0.03279124, 0.03279152, 0.03279196, 0.03279264, 0.03279363, 0.03279509, 0.03279718, 0.03280006, 0.03280392, 0.03280888, 0.03281487, 0.03282148, 0.03282760, 0.03283087, 0.03282646, 0.03280647, 0.03275031, 0.03263619, 0.03242587, 0.03194289, 0.03122442, 0.02986610]])\n" ] } ], "source": [ "ctc_head = CTCHead(in_channels=96, out_channels=37)\n", "predict = ctc_head(sequence)\n", "print(\"predict shape:\", predict.shape)\n", "result = F.softmax(predict, axis=2)\n", "pred_id = paddle.argmax(result, axis=2)\n", "pred_scores = paddle.max(result, axis=2)\n", "print(\"pred_id:\", pred_id)\n", "print(\"pred_scores:\", pred_scores)" ] }, { "cell_type": "markdown", "metadata": { "collapsed": false }, "source": [ "* Post-processing\n", "\n", "The recognition network ultimately returns the index with the highest probability at each time step, while the desired output is the corresponding text. CRNN post-processing is therefore a decoding step, whose main logic is as follows:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "def decode(text_index, text_prob=None, is_remove_duplicate=False):\n", "    \"\"\" convert text-index into text-label. \"\"\"\n", "    character = \"-0123456789abcdefghijklmnopqrstuvwxyz\"\n", "    result_list = []\n", "    # Ignored tokens: index 0 is the CTC blank\n", "    ignored_tokens = [0]\n", "    batch_size = len(text_index)\n", "    for batch_idx in range(batch_size):\n", "        char_list = []\n", "        conf_list = []\n", "        for idx in range(len(text_index[batch_idx])):\n", "            if text_index[batch_idx][idx] in ignored_tokens:\n", "                continue\n", "            # Merge consecutive identical characters that are not separated by a blank\n", "            if is_remove_duplicate:\n", "                # only for predict\n", "                if idx > 0 and text_index[batch_idx][idx - 1] == text_index[\n", "                        batch_idx][idx]:\n", "                    continue\n", "            # Store the decoded character in char_list\n", "            char_list.append(character[int(text_index[batch_idx][\n", "                idx])])\n", "            # Record the confidence\n", "            if text_prob is not None:\n", "                conf_list.append(text_prob[batch_idx][idx])\n", "            else:\n", "                conf_list.append(1)\n", "        text = ''.join(char_list)\n", "        # Collect the result\n", "        result_list.append((text, np.mean(conf_list)))\n", "    return result_list" ] }, { "cell_type": "markdown", "metadata": { "collapsed": false }, "source": [ "Taking the prediction of the randomly initialized head as an example, decoding gives:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Tensor(shape=[1, 80], dtype=int64, place=CUDAPlace(0), stop_gradient=False,\n", " [[23, 28, 23, 23, 23, 23, 23, 23, 23, 23, 23, 30, 30, 30, 31, 23, 23, 23, 23, 23, 23, 23, 31, 23, 23, 23, 23, 23, 23, 23, 23, 23, 23, 23, 23, 23, 23, 23, 23, 23, 23, 23, 23, 23, 23, 23, 23, 23, 23, 23, 23, 23, 23, 23, 23, 23, 23, 23, 23, 23, 23, 23, 23, 23, 23, 23, 23, 23, 23, 23, 23, 23, 23, 23, 23, 23, 23, 23, 23, 5 ]])\n", "decode out: [('mrmmmmmmmmmtttummmmmmmummmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmmm4', 0.034180813)]\n" ] } ], "source": [ "pred_id = paddle.argmax(result, axis=2)\n", "pred_scores = paddle.max(result, axis=2)\n", "print(pred_id)\n", "decode_out = decode(pred_id, pred_scores)\n", "print(\"decode out:\", decode_out)" ] }, { "cell_type": 
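"markdown", "metadata": { "collapsed": false }, "source": [ "Note that the call above leaves `is_remove_duplicate` at its default `False`, so repeated indices are kept, which is why the decoded string is so long. As a minimal sketch of standard CTC greedy decoding, the cell below merges consecutive repeats first and then drops blanks on a hand-made index sequence. The helper `ctc_greedy_collapse` and the toy indices are illustrative assumptions, not PaddleOCR code; the result is intended to match what `decode` returns with `is_remove_duplicate=True`." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "# Hypothetical helper for illustration only (not part of PaddleOCR)\n", "CHARSET = \"-0123456789abcdefghijklmnopqrstuvwxyz\"  # index 0 is the CTC blank\n", "\n", "def ctc_greedy_collapse(indices, blank=0):\n", "    \"\"\"Merge consecutive repeats, then drop blanks.\"\"\"\n", "    chars = []\n", "    prev = None\n", "    for idx in indices:\n", "        # Keep an index only if it differs from the previous step and is not blank\n", "        if idx != prev and idx != blank:\n", "            chars.append(CHARSET[idx])\n", "        prev = idx\n", "    return \"\".join(chars)\n", "\n", "# 30 -> 't', 25 -> 'o'; the blank between the two runs of 25 keeps both 'o's\n", "toy_indices = [30, 30, 0, 25, 25, 0, 25, 0]\n", "print(ctc_greedy_collapse(toy_indices))  # too" ] }, { "cell_type": 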
"markdown", "metadata": { "collapsed": false }, "source": [ "**小测试:** 如果输入模型训练好的index,解码结果是否正确呢?" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "out: [('pain', 1.0)]\n" ] } ], "source": [ "# 替换模型预测好的结果\n", "right_pred_id = paddle.to_tensor([['xxxxxxxxxxxxx']])\n", "tmp_scores = paddle.ones(shape=right_pred_id.shape)\n", "out = decode(right_pred_id, tmp_scores)\n", "print(\"out:\",out)" ] }, { "cell_type": "markdown", "metadata": { "collapsed": false }, "source": [ "上述步骤完成了网络的搭建,也实现了一个简单的前向预测过程。\n", "\n", "没有经过训练的网络无法正确预测结果,因此需要定义损失函数、优化策略,将整个网络run起来,下面将详细介绍网络训练原理。\n", "\n", "\n", "## 3. 训练原理详解\n", "### 3.1 准备训练数据\n", "PaddleOCR 支持两种数据格式:\n", " - `lmdb` 用于训练以lmdb格式存储的数据集(LMDBDataSet);\n", " - `通用数据` 用于训练以文本文件存储的数据集(SimpleDataSet);\n", " \n", " 本次只介绍通用数据格式读取\n", "\n", "训练数据的默认存储路径是 `./train_data`, 执行以下命令解压数据:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "!cd /home/aistudio/work/train_data/ && tar xf ic15_data.tar " ] }, { "cell_type": "markdown", "metadata": { "collapsed": false }, "source": [ "解压完成后,训练图片都在同一个文件夹内,并有一个txt文件(rec_gt_train.txt)记录图片路径和标签,txt文件里的内容如下:\n", "\n", "```\n", "\" 图像文件名 图像标注信息 \"\n", "\n", "train/word_1.png\tGenaxis Theatre\n", "train/word_2.png\t[06]\n", "...\n", "```\n", "\n", "**注意:** txt文件中默认将图片路径和图片标签用 \\t 分割,如用其他方式分割将造成训练报错。\n", "\n", "\n", "数据集应有如下文件结构:\n", "```\n", "|-train_data\n", " |-ic15_data\n", " |- rec_gt_train.txt\n", " |- train\n", " |- word_001.png\n", " |- word_002.jpg\n", " |- word_003.jpg\n", " | ...\n", " |- rec_gt_test.txt\n", " |- test\n", " |- word_001.png\n", " |- word_002.jpg\n", " |- word_003.jpg\n", " | ...\n", "```\n" ] }, { "cell_type": "markdown", "metadata": { "collapsed": false }, "source": [ "确认配置文件中的数据路径是否正确,以 
[rec_icdar15_train.yml](https://github.com/PaddlePaddle/PaddleOCR/blob/release/2.3/configs/rec/rec_icdar15_train.yml)为例:\n", "\n", "```\n", "Train:\n", " dataset:\n", " name: SimpleDataSet\n", " # 训练数据根目录\n", " data_dir: ./train_data/ic15_data/\n", " # 训练数据标签\n", " label_file_list: [\"./train_data/ic15_data/rec_gt_train.txt\"]\n", " transforms:\n", " - DecodeImage: # load image\n", " img_mode: BGR\n", " channel_first: False\n", " - CTCLabelEncode: # Class handling label\n", " - RecResizeImg:\n", " image_shape: [3, 32, 100] # [3,32,320]\n", " - KeepKeys:\n", " keep_keys: ['image', 'label', 'length'] # dataloader will return list in this order\n", " loader:\n", " shuffle: True\n", " batch_size_per_card: 256\n", " drop_last: True\n", " num_workers: 8\n", " use_shared_memory: False\n", "\n", "Eval:\n", " dataset:\n", " name: SimpleDataSet\n", " # 评估数据根目录\n", " data_dir: ./train_data/ic15_data\n", " # 评估数据标签\n", " label_file_list: [\"./train_data/ic15_data/rec_gt_test.txt\"]\n", " transforms:\n", " - DecodeImage: # load image\n", " img_mode: BGR\n", " channel_first: False\n", " - CTCLabelEncode: # Class handling label\n", " - RecResizeImg:\n", " image_shape: [3, 32, 100]\n", " - KeepKeys:\n", " keep_keys: ['image', 'label', 'length'] # dataloader will return list in this order\n", " loader:\n", " shuffle: False\n", " drop_last: False\n", " batch_size_per_card: 256\n", " num_workers: 4\n", " use_shared_memory: False\n", " ```" ] }, { "cell_type": "markdown", "metadata": { "collapsed": false }, "source": [ "### 3.2 数据预处理\n", "\n", "送入网络的训练数据,需要保证一个batch内维度一致,同时为了不同维度之间的特征在数值上有一定的比较性,需要对数据做统一尺度**缩放**和**归一化**。\n", "\n", "为了增加模型的鲁棒性,抑制过拟合提升泛化性能,需要实现一定的**数据增广**。\n", "\n", "* 缩放和归一化\n", "\n", "第二节中已经介绍了相关内容,这是图片送入网络之前的最后一步操作。调用 `resize_norm_img` 完成图片缩放、padding和归一化。\n", "\n", "* 数据增广\n", "\n", "PaddleOCR中实现了多种数据增广方式,如:颜色反转、随机切割、仿射变化、随机噪声等等,这里以简单的随机切割为例,更多增广方式可参考:[rec_img_aug.py](https://github.com/PaddlePaddle/PaddleOCR/blob/release/2.3/ppocr/data/imaug/rec_img_aug.py)" ] }, { 
"cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "def get_crop(image):\n", " \"\"\"\n", " random crop\n", " \"\"\"\n", " import random\n", " h, w, _ = image.shape\n", " top_min = 1\n", " top_max = 8\n", " top_crop = int(random.randint(top_min, top_max))\n", " top_crop = min(top_crop, h - 1)\n", " crop_img = image.copy()\n", " ratio = random.randint(0, 1)\n", " if ratio:\n", " crop_img = crop_img[top_crop:h, :, :]\n", " else:\n", " crop_img = crop_img[0:h - top_crop, :, :]\n", " return crop_img\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [ { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "# 读图\n", "raw_img = cv2.imread(\"/home/aistudio/work/word_1.png\")\n", "plt.figure()\n", "plt.subplot(2,1,1)\n", "# 可视化原图\n", "plt.imshow(raw_img)\n", "# 随机切割\n", "crop_img = get_crop(raw_img)\n", "plt.subplot(2,1,2)\n", "# 可视化增广图\n", "plt.imshow(crop_img)\n", "plt.show()" ] }, { "cell_type": "markdown", "metadata": { "collapsed": false }, "source": [ "### 3.3 训练主程序\n", "\n", "模型训练的入口代码是 [train.py](https://github.com/PaddlePaddle/PaddleOCR/blob/release/2.3/tools/train.py),它展示了训练中所需的各个模块: `build dataloader`, `build post process`, `build model` , `build loss`, `build optim`, `build metric`,将各部分串联后即可开始训练:\n", "\n", "* 构建 dataloader\n", "\n", "训练模型需要将数据组成指定数目的 batch ,并在训练过程中依次 yield 出来,本例中调用了 PaddleOCR 中实现的 [SimpleDataSet](https://github.com/PaddlePaddle/PaddleOCR/blob/release/2.3/ppocr/data/simple_dataset.py)\n", "\n", "基于原始代码稍作修改,其返回单条数据的主要逻辑如下" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "def __getitem__(data_line, data_dir):\n", " import os\n", " mode = \"train\"\n", " delimiter = '\\t'\n", " try:\n", " substr = data_line.strip(\"\\n\").split(delimiter)\n", " file_name = substr[0]\n", " label = substr[1]\n", " img_path = os.path.join(data_dir, file_name)\n", " data = {'img_path': img_path, 'label': label}\n", " if not os.path.exists(img_path):\n", " raise Exception(\"{} does not exist!\".format(img_path))\n", " with open(data['img_path'], 'rb') as f:\n", " img = f.read()\n", " data['image'] = img\n", " # 预处理操作,先注释掉\n", " # outs = transform(data, self.ops)\n", " outs = data\n", " except Exception as e:\n", " print(\"When parsing line {}, error happened with msg: {}\".format(\n", " data_line, e))\n", " outs = None\n", " return outs" ] }, { "cell_type": "markdown", "metadata": { "collapsed": false }, "source": [ "假设当前输入的标签为 `train/word_1.png\tGenaxis Theatre`, 训练数据的路径为 `/home/aistudio/work/train_data/ic15_data/`, 
解析出的结果是一个字典,里面包含 `img_path` `label` `image` 三个字段:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "{'img_path': '/home/aistudio/work/train_data/ic15_data/train/word_1.png', 'label': 'Genaxis Theatre', 'image': b'\\x89PNG\\r\\n\\x1a\\n\\x00\\x00\\x00\\rIHDR\\x00\\x00\\x00Y\\x00\\x00\\x00\\x0e\\x08\\x02\\x00\\x00\\x00\\xcb\\xe2\\'\\xb7\\x00\\x00\\x00\\x01sRGB\\x00\\xae\\xce\\x1c\\xe9\\x00\\x00\\x00\\x04gAMA\\x00\\x00\\xb1\\x8f\\x0b\\xfca\\x05\\x00\\x00\\x00 cHRM\\x00\\x00z&\\x00\\x00\\x80\\x84\\x00\\x00\\xfa\\x00\\x00\\x00\\x80\\xe8\\x00\\x00u0\\x00\\x00\\xea`\\x00\\x00:\\x98\\x00\\x00\\x17p\\x9c\\xbaQ<\\x00\\x00\\x0bmIDATHK\\x8d\\x96\\xf9S[\\xd7\\x15\\x80\\x01\\xa7\\x93\\xa4\\xfd1\\x99L\\xea\\x80\\xc4\\xa2]B\\x0bb\\xdf\\x84\\x04\\x18\\x8c\\x01\\xb3\\x8aE\\xec\\x12\\x02\\t\\xb4KhC\\xfb\\xbe=\\xed\\xbb\\x04\\xc2l&N\\xd2\\xb4\\x93i\\x9bv\\xa6\\x7fL\\xdb\\xe9d\\xe2N\\xd3dy\\xf2\\xe4\\xeb\\xaf\\xbf\\xbe\\xbc\\xbc|\\xedWU\\xaf]\\xb7\\xab\\xab\\xab\\xca\\xca\\xab\\x8a\\x8a\\x8a\\xca\\xca\\xca\\xeb\\x08\\x97\\xb7\\xa0sY\\xf1\\xda\\xf7\\xdf?\\x83vUYQUU\\x05wa\\xd6w\\xcf\\x9e\\xdd\\xbau\\xeb\\xf5\\xd7_G\\xa1Po\\xfe\\xe6\\xd7\\x9f}\\xf6\\x19,\\x08+|\\xf9\\xe5\\x97U\\xe5\\xd9\\xcf\\xdb\\x1bo\\xbc\\xf1\\xf6\\xdbo\\xdf\\xbe}\\x1b\\xf6\\xf0\\xce;\\xef\\xc0\"\\xdf~\\xfb\\xedw\\xdf}\\xf7\\xf4\\xe9\\xd3\\xaf\\xbe\\xfa\\xea\\x8b/\\xbex\\xf2\\xe4\\xbf\\xffy\\xfc\\xaf[\\x95\\xcf`eh\\xcf\\xa7]\\x96\\xb7\\x01\\x8f\\x95\\xb7q\\xf5\\xe3~*\\xe16l\\xa6\\xb2\\xf2\\xcd\\xab\\x8a[\\xb0\\rh\\xd7\\xfb)?\\x00\\rV\\xae\\xac*O\\xfc\\xe1\\x08\\x97\\x15\\x15U\\x95W\\x97\\xf0W\\x05\\xf3LJ]&\\x14\\x07r\\xa1D\\xdc\\x1d\\xd2\\x8a\\x15\\x0b\\xe3\\\\J-\\x89\\x88&\\xe0\\xaa\\xb1\\xb8j\\x0c\\xbe\\x06K@\\xe1\\x08h\\x0c\\x1e\\xd5@@\\xd7\\x03xT\\x1d\\x1e\\x85\\xc6\\xd5\\xa0p\\xa8jlu\\r\\xb6\\x1aM\\xc1R\\xb0\\xb5x\\xf4\\xed\\xba\\x9aw\\xd1\\x10\\xd1\\xd5\\r5\\xb7\\xeb\\xde{\\x17M\\xc0R\\x18\\xf4\\xd6\\xd9\\
x99\\xc5\\x95UAw\\x0f\\xa7\\x91\\xca\\xa4\\xd2\\x9a\\xd1\\xb5\\xd8\\xea\\xf7jQ\\xa82ht\\x1d\\x99\\xdc8<<\\xa2P\\xa8\\xc2\\xe1\\xe8\\xe9\\xe9\\xf9\\xe1\\xe1Q:\\x9d\\x8dF\\xe3\\x1e\\x8f\\x0f\\x06gg\\xe7\\xfazY4\\n\\x9eN\\xc6\\xd2H\\x18\\x80Jl\\x00h\\x84z\\x80\\xd4P\\x03P\\x1aP@#\\x06\\rP\\x1b\\x80\\xba\\x16\\x12\\x9d\\x81\\xa7Q\\xeaI\\xb0\\xf3\\xf2\\xfe\\xd1\\x18r=\\xbe\\x11C\\xa4bIT,\\x81\\x86#\\xde@\\xc7\\x13\\xe8x\\x12\\x03Oh&\\x90\\x80\\n\\x97\\xc1\\x1a\\xb0\\xba5b\\xb9f[\\xa6\\x12J\\x16F\\xa79m\\xacNz\\x07\\xa7\\xbd\\x7f|\\xe0>orq}nmy\\x9a73256p\\xf7\\x1egh\\x84=0\\xd4\\xcb\\x1e\\xe8\\xeaf\\xb7w\\xf4\\xb40\\xbb\\x98\\x8c\\xf6\\xa6\\xa6\\xb1\\xe1\\xf1;\\xec\\xe1\\xde\\x8e\\xbe\\xae\\xd6\\x9e\\xaevVO\\'\\xbb\\xbb\\xa3\\xaf\\xb3\\xad\\xb7\\xbd\\xad\\xa7\\x8f5\\xb8\\xc1\\xdf\\xde\\x16I9\\xfd\\xc3\\xadm\\xdd}\\xec;\\x10{{8,\\x16\\xbb\\xb3\\xb3\\x9b\\xc1`R\\xa9\\xf4\\xbe>\\xce\\xc6\\x86\\xc0d\\xb2 H\\x04\\x14X\\xadv\\x8b\\xc5\\x06\\x972\\x99\\x82\\xcb\\x9d\\xe7@c\\xc3\\x84vVw[w\\x1b\\xa3\\x9dIke\\x90\\x9bi\\xe4\\x16*\\x81\\x8ak\\xa0\\xe2\\xeb\\xe88\\x0c\\x1d_\\x0f\\x91Ah`\\xe0\\xb1\\x0c\\x87K\\xabP\\xcdO\\xcd\\x88\\xf8|\\xbdJ\\x15\\xf2\\xf8b\\xa1\\x84\\xdb\\xea\\x97l)\\xf6\\xf7\\xac\\xd2\\x1d\\xb5\\x90\\xbf;9>\\'X\\x17\\xabU\\xc6M\\xc1\\xae\\\\\\xae\\xdb\\xda\\x92\\t\\x85R\\x91Ha\\xb7\\xfbuz\\xb3i\\xdff1;dR\\xd5\\xf8\\xd8T#\\x85\\x81\\xaa\\xa9\\xef\\xec\\xe8\\x15\\x8b\\xa4>oH.S//\\xadK%J$\\x14\\x8b\\xc7\\xd20\\xe2\\xf3\\x05\\x90P`\\x89\\xc7\\x9d\\x9d\\x990\\x194\\xbb\\xdb\\x9bF\\x9d\\xca\\xa8U\\x8b\\x04k3\\xf7GW\\x17\\xb9\\x06\\xa5\\xdc\\xaa\\xd3l\\xaf\\xadl,\\xce9\\x8c\\xfa\\xb3\\x83\\xc2I\\xbe\\xe0w:\\xb5J\\x99d{sG\\xc8WJ\\xc4R\\xd1&\\x8f;\\x05S\\xb4J\\xa9A#S\\xcbD\"\\xc1\\xca\\xf4\\xf8p?\\xab\\xf5No\\x17\\xb8`\\x82\\x0b\\xcf\\xbe\\x1d\\\\\\xec\\xac\\t\\xf9\\xdc\\xe5\\xd5\\xa9\\x85\\xe5\\x899!\\x8f\\xef4\\x94E\\xe4\\xe3\\xa5\\x93\\xc2E.V4i\\xacz\\xc5~,\\x10?H\\x1f\\x94\\xb2\\xa5\\\\\"\\x93\\x8a$T\\x12\\x85Eo\\n\\xfb\\x02\\xa5\\xdc\\xc1\\xf1\\xc1\\xe9\\xc9\\xc1C\\xc4\
\x1b\\x8b\\x87\\xd2\\x01W8\\xe0\\x8d\\x1a\\xb4V\\x9b\\xd9c6:\\x95\\n\\xbd\\xd3\\x110\\x1a\\x1dz\\xbdM\\xa7\\xb3\\xc6\\xe3\\xf9P(\\x19\\xf0G\\xcc&\\xfb\\x96pgrb\\xb6\\xa5\\xb9\\xa3\\xa6\\xba\\x8eAoY[\\x15\\xd8mn\\x10\\x04.\\x14rM&]8*\\x9d\\xa6S\\xf9L&\\x97\\xcf\\xe6\\xe42\\x89T\\xb2\\x93I\\'\\x11\\xbf\\xaf\\x90Ig\\x12q\\xa7\\xd5bP\\xab=6[>\\x99<+\\x1e\\xc4\\xfd\\x01\\x9dL\\xae\\x95\\xca\\x02\\x0e\\xdb\\x87\\x0fO\\x83\\x1e\\xe7\\xbeV\\xe5uX\\xd21$\\x9f\\x8a\\xa5\\xa2\\x81\\xa0\\xc7\\x96\\x8e\\x05\\x8f\\n\\xa9\\xd3\\xa3\\xccI)\\x9dM\\x06\\xf7\\x94\\xa2\\xd9\\xc9;c\\xc3\\x1cH\\x19&\\x81Zv\\x818|z\\xa9\\xda\\xb1g\\x8a\\xb9\\x82~\\x93\\xd3\\xa2\\xde\\xcf\"\\x85\\x83\\xd4\\xe9i\\xf1\\xfd\\x8b\\xa3\\x8f\\x8e\\xb2g\\x01G$\\xe8\\x8c\\x9e\\x1f\\xbe\\x0f9\\xb2\\xaf19M\\xae|\\xa2`T\\x1b\\xdd\\x16w.\\x9e;/\\x9d?:\\xfe\\xe8\\xcf\\x1f\\xff\\xf5({\\x02:\\x82\\xee\\xc8\\x83\\xc2Y!UJ\\xc5\\x0b\\xfbz\\xbbVm\\x8a\\x84R~_\\xd4a\\xf7[\\xcc\\xee|\\xeeA6S\\x8aG\\xd2&\\x83M\\xb0\\xbe=39\\x0f\\xd5\\x04\\xdf\\x17\\n\\x91\\xbe\\xc0]v\\xd9}\\n\\xa9fyq].Q\\xa7\\xe2\\xb9D4c\\xb7\\xb8\\xddN_<\\x9eT\\xa9T:\\x9d\\xee\\xe4\\xe4\\xa4\\x98/\\\\\\x9c?<:,\\x85C\\x88`m]\\xbe+1\\x1b\\x8c\\xc9p\\xf4(_\\x0c\\xb8<6\\xa3)\\x1b\\x8f?:;\\xd6\\xab\\xe5\\xfc\\x95E\\xbbI\\x9fKF\\x11\\x9f\\xcb\\xb8\\'\\xd7\\xa9$`$\\x9b@\\x8a\\xd9\\xc8\\xc5i\\xe1\\xfc8g7kxs\\xa3\\xf7G\\x06\\x9a\\x88\\x94&\\xe2\\xb5\\x8b\\xa0\\xcd\\xa3\\x10\\xee\\x82\\x8b\\x07\\xc9\\xc2a,\\x8b8\\x02\\xb9p1\\x1b-%C\\xb9\\x98?\\xed\\xb7\\x87M\\x1a\\xbb\\xd7\\x1a:?\\xfc\\xc0m\\x0e\\xac/\\x08w\\x05\\x8a\\\\\\xec\\xd0at\\x07\\x9d\\xe1\\xc3\\xf4\\x83\\x8b\\xa3\\x0f\\xc0Q1Yr\\x9b}|\\x1e\\xdc\\x95\\xa5\"9H\\x93b\\xfaH.V\\xef\\x8a\\x94>w\\xd8i\\xf3C\\xa6\\xc8eZ$\\x90\\x005\\xf1H\\xd6\\xebB\\xd4R\\xdd\\x02w\\xb5\\xa7\\x8d]\\x87\\xc2\\xd3\\x88\\xcc\\xd9)\\x9e\\xdf\\x89@\\x89-L/+\\xa4\\xdad8\\x0bSV\\x17\\xf8\\xbb\"y*\\x99\\xb3Y]n\\x97\\xff\\xe2\\xe1\\x87\\xb9L\\xb1\\x98/E\\x90\\xb8\\xdd\\xea\\xda\\\\\\x13*%*\\x9dJ\\x1f\\xf6G\\xe0\\x95 
\\x1eD\\xaf\\xd4G\\xfc\\xc8\\xe9\\xe1!\\x94\\xc6\\xd4\\xe8\\xa8A\\xa3LE\\x11\\x9dJ>?=\\xb18;Y\\xcc$R\\xd1P4\\xe8\\xcd%\\xc3\\x80F.\\x86\\xa4`w\\xb7\\x97]\\x10\\xe8\\x15P ^\\x93C\\xbc\\xbai\\x94j\\n\\xe1\\x14\\xb8\\x88{#p\\xfe\\xa0;\\xe1\\xb1\\x85\\xad\\x06\\xafFf\\x12\\xf1\\x15*\\x891\\x16\\xcc\\x01K\\\\\\xc1\\xea\\xfcv\"Tt\\xec\\x07=\\xd6H,\\x90/$\\x8e\\xddF\\x1f\\xa4\\x92\\xcf\\x1c\\x14\\xf2D:\\x89\\xf1A\\xee\\xfc\\xd3\\xdf\\xff\\xed\\x93\\x0f>U\\x8a\\xb5\\xbbB\\xa5\\xc7\\x16\\xb2\\xef{\\xb5\\n\\x93xK\\xe9s\\x84C\\x9ex1}|z\\xf8\\x08D\\x0bV\\xc4\\xdd\\xcdl\\\\-\\xa5\\x99\\xd21yo\\x0e\\xb2O\\xb1\\xa3]\\x9e\\xdd\\xb0\\x1a\\xdcg\\x07\\x8fLZ\\x07\\xf4\\x8d\\x1a\\x1b\\xa4R \\x18\\xf5\\x07\"\\xa5\\xa33\\x9f\\x17q\\xd8\\xbd*\\xa5nG,\\x97\\x88\\x15\\x0e\\xab\\'\\x12\\x88\\x9d\\x1c\\x9e}\\xf2\\xbb?F\\x83q\\xc9\\xb6L\\xab\\xd4\\xe7S\\x05\\xa9H2?\\xbd`1\\x98\\x0f\\xb2E\\x93n\\x7f~zne\\x91\\x97\\x8a\\xc5\\xa3\\xa1`\\xd0\\xebID\\xc2\\xf10\\xa2\\x94JF\\x87\\x87\\xd8=\\xddt\\x12\\x95N\\xa4W\\x98U\\xfa\\x98\\'\\x04:R\\xbeH)\\x9eCl\\xde}\\xa8\\xf0}_\\xc4\\x9fI\\x86\\x8ba_\\xdaf\\xf4I\\xb64\\n\\xb1\\xc1g\\x8f\\x85\\xdc\\xa9\\xb9\\x89\\x8d\\xa9{\\xcb\\x1eK\\xc2\\xbc\\x170i\\xfcvC8\\xee/\\xfa-\\xe1\\x847\\x15\\xb4\"R\\xbe\\xc2\\xa4\\xb4\\x06m\\xe1|\\xe4 \\x85\\xe4\\x1dF\\xafFb\\x00\\xb3N\\xb3\\x1f\\x8e\\x04\\xe7\\x0c\\xfb\\x92\\xf1P\\xf6A\\xee\\xe2\\xe3\\x87\\x7f\\xca\\xc5\\x8e\\xf8\\xab\\xb3\\x0fD\\x94]\\xf0\\xe7W@\\x07\\xfcg\\x85\\x8f\\x85\\x07jxn\\xe5\\xfe\\xc0\\xa8`I\\xa4W\\xd9 
/\\xdcVD+\\xb7n\\xf0v\\xf8KR\\xa3\\xda\\xad\\x91Z\\xc7\\x06\\x17\\xef\\xf4Nk$N\\xd9\\x96y{M+\\xde\\xd0{\\xcc\\xf1U\\xcb\\xb3kC\\xac\\x11\\x1c\\x8a\\xdc\\xc9\\xe0\\xac\\xcf\\xef\\x00Sw\\x17a{\\xd3#\\xbc\\x0e:\\xbb\\x85\\xd2C\\xc5\\xb6\\xc2\\x88pE\\xb2\\xbe\\xb0=w\\x7f\\x052\\xb1\\x95\\xdaC\\'\\xb44\\x11\\x98M\\x04F\\x05\\xa9\\x96\\xf0\"d4\\x11.\\xf1\\xb5\\xf8k\\xa0\\xf3\\x02h\\x12\\xbe\\x0c\\xe5\\x07\\xa8x\\xf4\\r\\x14\\x1c\\x8a\\x88C\\xe3_\\t\\xfc\\x1e}\\x19b\\x1d\\x85P\\xdb\\x08\\xdc,u\\xd3\\'\\xd6Qo:7\\xfd\\x1bH\\xf5\\xf0\\xf3\\xb1\\x91\\x8e\\xa1\\xd2\\xb1\\x8d\\xd7\\x90_\\x8c\\x0c\\x1c\\x05h\\xc2\\x91^\\x84\\x81\\xa3\\xd21\\xf0F[\\xaei\\xfe!B\\xa7\\x99\\x8emy\\x15\\xcc&\\\\S\\x13\\x8e^A\\xa8\\xfb\\xd1\\x05Xx\\xce\\xab]\\xfc\\xe8\\xe5\\xb9\\x94\\x1b5e0h\\x1c\\x06\\x8d}9^[\\xc0\\xbd\\x14\\xf1\\x84:\"X~n\\xf0\\xda#,\\x02\\x83\\x84Z2\\xb1\\x9e\\x04\\xa6 \\x92\\xea\\x1bI\\rd\\x88\\x94z2\\x03\\xd3\\xc8\\xc4\\x90\\x99\\x18\\xe2\\xab\\xc031?\\x83\\x81%2\\xb04\\x06\\x96\\xc1\\xc02\\x7fadb\\x9b\\x98/\\xba\\x80\\x8c\\xb8\\x01\\x8c\\xfc\\xfc\\xfd\\xdf$\\xc2O\\'\\x7f!5\\x9e\\xdf\\xc2\\xa0\\x08e\\x11/E,\\x9aX\\x16\\xf1R,\\x8b@\\xe3\\xb0(\\x00{\\x13\\xcb:j\\xf1e\\x11\\xf5\\xc4\\xb2\\x88\\x06\\x80L\\xc6\\x00\\x94\\xc6\\x06\\xc8\\x05\\xd0\\xf1\\n\\xe8\\r\\xa4\\x17\\xf9\\xe9\\x99\\xeb|\\x81\\x04\\xf9e\\x91\\n\"\\x80\\xff\\x03\\x99\\xa0+\\x94\\xbd\\xf0X\\xa1\\x00\\x00\\x00\\x00IEND\\xaeB`\\x82'}\n" ] } ], "source": [ "data_line = \"train/word_1.png\tGenaxis Theatre\"\n", "data_dir = \"/home/aistudio/work/train_data/ic15_data/\"\n", "\n", "item = __getitem__(data_line, data_dir)\n", "print(item)" ] }, { "cell_type": "markdown", "metadata": { "collapsed": false }, "source": [ "实现完单条数据返回逻辑后,调用 `padde.io.Dataloader` 即可把数据组合成batch,具体可参考 [build_dataloader]()\n", "\n", "\n", "* build model\n", "\n", " build model 即搭建主要网络结构,具体细节如《2.3 代码实现》所述,本节不做过多介绍,各模块代码可参考[modeling](https://github.com/PaddlePaddle/PaddleOCR/tree/release/2.3/ppocr/modeling)\n", "\n", "* build loss\n", " \n", " CRNN 
uses CTC loss as its loss function. PaddlePaddle already provides the common loss functions, so it only needs to be called:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "import paddle\n", "import paddle.nn as nn\n", "\n", "\n", "class CTCLoss(nn.Layer):\n", "    def __init__(self, use_focal_loss=False, **kwargs):\n", "        super(CTCLoss, self).__init__()\n", "        # blank is the index of the CTC blank (separator) symbol\n", "        self.loss_func = nn.CTCLoss(blank=0, reduction='none')\n", "\n", "    def forward(self, predicts, batch):\n", "        if isinstance(predicts, (list, tuple)):\n", "            predicts = predicts[-1]\n", "        # Transpose the head output to time-major order: [B, T, C] -> [T, B, C]\n", "        predicts = predicts.transpose((1, 0, 2))  # e.g. [80, 1, 37]\n", "        N, B, _ = predicts.shape\n", "        preds_lengths = paddle.to_tensor([N] * B, dtype='int64')\n", "        labels = batch[1].astype(\"int32\")\n", "        label_lengths = batch[2].astype('int64')\n", "        # Compute the loss\n", "        loss = self.loss_func(predicts, labels, preds_lengths, label_lengths)\n", "        loss = loss.mean()\n", "        return {'loss': loss}" ] }, { "cell_type": "markdown", "metadata": { "collapsed": false }, "source": [ "* build post process\n", "\n", "  The details are likewise covered in Section 2.3 (Code Implementation); the logic is the same as before." ] }, { "cell_type": "markdown", "metadata": { "collapsed": false }, "source": [ "* build optim\n", "\n", "The optimizer is `Adam`, again through the PaddlePaddle API: `paddle.optimizer.Adam`\n", "\n", "* build metric\n", "\n", "The metric step computes the model's evaluation metrics. In PaddleOCR text recognition, a prediction counts as correct only when the whole text line is predicted correctly, so the core of the accuracy computation is:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "def metric(preds, labels):\n", "    correct_num = 0\n", "    all_num = 0\n", "    for pred, target in zip(preds, labels):\n", "        # Ignore spaces when comparing a prediction with its label\n", "        pred = pred.replace(\" \", \"\")\n", "        target = target.replace(\" \", \"\")\n", "        # Only an exact match of the whole line counts as correct\n", "        if pred == target:\n", "            correct_num += 1\n", "        all_num += 1\n", "    return {\n", "        'acc': correct_num / all_num,\n", "    }" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [ { "name": 
"stdout", "output_type": "stream", "text": [ "acc: {'acc': 0.6}\n" ] } ], "source": [ "preds = [\"aaa\", \"bbb\", \"ccc\", \"123\", \"456\"]\n", "labels = [\"aaa\", \"bbb\", \"ddd\", \"123\", \"444\"]\n", "acc = metric(preds, labels)\n", "print(\"acc:\", acc)\n", "# Three of the five predictions are exactly right, so the accuracy should be 0.6" ] }, { "cell_type": "markdown", "metadata": { "collapsed": false }, "source": [ "Putting the parts above together gives the complete training flow:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "# Excerpt of main() from PaddleOCR tools/train.py; dist, the build_* helpers, load_model and program are imported there\n", "def main(config, device, logger, vdl_writer):\n", "    # init dist environment\n", "    if config['Global']['distributed']:\n", "        dist.init_parallel_env()\n", "\n", "    global_config = config['Global']\n", "\n", "    # build dataloader\n", "    train_dataloader = build_dataloader(config, 'Train', device, logger)\n", "    if len(train_dataloader) == 0:\n", "        logger.error(\n", "            \"No Images in train dataset, please ensure\\n\" +\n", "            \"\\t1. The images num in the train label_file_list should be larger than or equal with batch size.\\n\"\n", "            +\n", "            \"\\t2. 
The annotation file and path in the configuration file are provided normally.\"\n", "        )\n", "        return\n", "\n", "    if config['Eval']:\n", "        valid_dataloader = build_dataloader(config, 'Eval', device, logger)\n", "    else:\n", "        valid_dataloader = None\n", "\n", "    # build post process\n", "    post_process_class = build_post_process(config['PostProcess'],\n", "                                            global_config)\n", "\n", "    # build model\n", "    # for rec algorithm\n", "    if hasattr(post_process_class, 'character'):\n", "        char_num = len(getattr(post_process_class, 'character'))\n", "        if config['Architecture'][\"algorithm\"] in [\"Distillation\",\n", "                                                   ]:  # distillation model\n", "            for key in config['Architecture'][\"Models\"]:\n", "                config['Architecture'][\"Models\"][key][\"Head\"][\n", "                    'out_channels'] = char_num\n", "        else:  # base rec model\n", "            config['Architecture'][\"Head\"]['out_channels'] = char_num\n", "\n", "    model = build_model(config['Architecture'])\n", "    if config['Global']['distributed']:\n", "        model = paddle.DataParallel(model)\n", "\n", "    # build loss\n", "    loss_class = build_loss(config['Loss'])\n", "\n", "    # build optim\n", "    optimizer, lr_scheduler = build_optimizer(\n", "        config['Optimizer'],\n", "        epochs=config['Global']['epoch_num'],\n", "        step_each_epoch=len(train_dataloader),\n", "        parameters=model.parameters())\n", "\n", "    # build metric\n", "    eval_class = build_metric(config['Metric'])\n", "    # load pretrain model\n", "    pre_best_model_dict = load_model(config, model, optimizer)\n", "    logger.info('train dataloader has {} iters'.format(len(train_dataloader)))\n", "    if valid_dataloader is not None:\n", "        logger.info('valid dataloader has {} iters'.format(\n", "            len(valid_dataloader)))\n", "\n", "    use_amp = config[\"Global\"].get(\"use_amp\", False)\n", "    if use_amp:\n", "        AMP_RELATED_FLAGS_SETTING = {\n", "            'FLAGS_cudnn_batchnorm_spatial_persistent': 1,\n", "            'FLAGS_max_inplace_grad_add': 8,\n", "        }\n", "        paddle.fluid.set_flags(AMP_RELATED_FLAGS_SETTING)\n", "        scale_loss = 
config[\"Global\"].get(\"scale_loss\", 1.0)\n", "        use_dynamic_loss_scaling = config[\"Global\"].get(\n", "            \"use_dynamic_loss_scaling\", False)\n", "        scaler = paddle.amp.GradScaler(\n", "            init_loss_scaling=scale_loss,\n", "            use_dynamic_loss_scaling=use_dynamic_loss_scaling)\n", "    else:\n", "        scaler = None\n", "\n", "    # start train\n", "    program.train(config, train_dataloader, valid_dataloader, device, model,\n", "                  loss_class, optimizer, lr_scheduler, post_process_class,\n", "                  eval_class, pre_best_model_dict, logger, vdl_writer, scaler)" ] }, { "cell_type": "markdown", "metadata": { "collapsed": false }, "source": [ "## 4. Full Training Task\n", "\n", "### 4.1 Launching Training\n", "\n", "As with the detection task, the PaddleOCR recognition task takes its parameters from a configuration file.\n", "\n", "To run a full training job, first download the whole project and install its dependencies:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Looking in indexes: https://pypi.tuna.tsinghua.edu.cn/simple\n", "Requirement already satisfied: shapely in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from -r requirements.txt (line 1)) (1.8.0)\n", "Collecting scikit-image==0.17.2\n", " Downloading https://pypi.tuna.tsinghua.edu.cn/packages/d7/ee/753ea56fda5bc2a5516a1becb631bf5ada593a2dd44f21971a13a762d4db/scikit_image-0.17.2-cp37-cp37m-manylinux1_x86_64.whl (12.5 MB)\n", " |████████████████████████████████| 12.5 MB 8.4 MB/s \n", "\u001b[?25hRequirement already satisfied: imgaug==0.4.0 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from -r requirements.txt (line 3)) (0.4.0)\n", "Requirement already satisfied: pyclipper in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from -r requirements.txt (line 4)) (1.3.0.post2)\n", "Requirement already satisfied: lmdb in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from -r requirements.txt (line 5)) (1.2.1)\n", "Requirement already satisfied: tqdm in 
/opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from -r requirements.txt (line 6)) (4.36.1)\n", "Requirement already satisfied: numpy in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from -r requirements.txt (line 7)) (1.20.3)\n", "Requirement already satisfied: visualdl in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from -r requirements.txt (line 8)) (2.2.0)\n", "Requirement already satisfied: python-Levenshtein in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from -r requirements.txt (line 9)) (0.12.2)\n", "Requirement already satisfied: opencv-contrib-python==4.4.0.46 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from -r requirements.txt (line 10)) (4.4.0.46)\n", "Requirement already satisfied: lxml in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from -r requirements.txt (line 11)) (4.7.1)\n", "Requirement already satisfied: premailer in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from -r requirements.txt (line 12)) (3.10.0)\n", "Requirement already satisfied: openpyxl in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from -r requirements.txt (line 13)) (3.0.5)\n", "Requirement already satisfied: imageio>=2.3.0 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from scikit-image==0.17.2->-r requirements.txt (line 2)) (2.6.1)\n", "Requirement already satisfied: matplotlib!=3.0.0,>=2.0.0 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from scikit-image==0.17.2->-r requirements.txt (line 2)) (2.2.3)\n", "Requirement already satisfied: tifffile>=2019.7.26 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from scikit-image==0.17.2->-r requirements.txt (line 2)) (2021.11.2)\n", "Requirement already satisfied: PyWavelets>=1.1.1 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from scikit-image==0.17.2->-r 
requirements.txt (line 2)) (1.2.0)\n", "Requirement already satisfied: pillow!=7.1.0,!=7.1.1,>=4.3.0 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from scikit-image==0.17.2->-r requirements.txt (line 2)) (7.1.2)\n", "Requirement already satisfied: networkx>=2.0 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from scikit-image==0.17.2->-r requirements.txt (line 2)) (2.4)\n", "Requirement already satisfied: scipy>=1.0.1 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from scikit-image==0.17.2->-r requirements.txt (line 2)) (1.6.3)\n", "Requirement already satisfied: six in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from imgaug==0.4.0->-r requirements.txt (line 3)) (1.15.0)\n", "Requirement already satisfied: opencv-python in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from imgaug==0.4.0->-r requirements.txt (line 3)) (4.1.1.26)\n", "Requirement already satisfied: flask>=1.1.1 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from visualdl->-r requirements.txt (line 8)) (1.1.1)\n", "Requirement already satisfied: requests in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from visualdl->-r requirements.txt (line 8)) (2.22.0)\n", "Requirement already satisfied: pre-commit in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from visualdl->-r requirements.txt (line 8)) (1.21.0)\n", "Requirement already satisfied: shellcheck-py in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from visualdl->-r requirements.txt (line 8)) (0.7.1.1)\n", "Requirement already satisfied: pandas in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from visualdl->-r requirements.txt (line 8)) (1.1.5)\n", "Requirement already satisfied: Flask-Babel>=1.0.0 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from visualdl->-r requirements.txt (line 8)) (1.0.0)\n", 
"Requirement already satisfied: bce-python-sdk in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from visualdl->-r requirements.txt (line 8)) (0.8.53)\n", "Requirement already satisfied: flake8>=3.7.9 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from visualdl->-r requirements.txt (line 8)) (3.8.2)\n", "Requirement already satisfied: protobuf>=3.11.0 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from visualdl->-r requirements.txt (line 8)) (3.14.0)\n", "Requirement already satisfied: setuptools in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from python-Levenshtein->-r requirements.txt (line 9)) (56.2.0)\n", "Requirement already satisfied: cssutils in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from premailer->-r requirements.txt (line 12)) (2.3.0)\n", "Requirement already satisfied: cachetools in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from premailer->-r requirements.txt (line 12)) (4.0.0)\n", "Requirement already satisfied: cssselect in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from premailer->-r requirements.txt (line 12)) (1.1.0)\n", "Requirement already satisfied: jdcal in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from openpyxl->-r requirements.txt (line 13)) (1.4.1)\n", "Requirement already satisfied: et-xmlfile in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from openpyxl->-r requirements.txt (line 13)) (1.0.1)\n", "Requirement already satisfied: pycodestyle<2.7.0,>=2.6.0a1 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from flake8>=3.7.9->visualdl->-r requirements.txt (line 8)) (2.6.0)\n", "Requirement already satisfied: pyflakes<2.3.0,>=2.2.0 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from flake8>=3.7.9->visualdl->-r requirements.txt (line 8)) (2.2.0)\n", "Requirement already satisfied: 
importlib-metadata in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from flake8>=3.7.9->visualdl->-r requirements.txt (line 8)) (0.23)\n", "Requirement already satisfied: mccabe<0.7.0,>=0.6.0 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from flake8>=3.7.9->visualdl->-r requirements.txt (line 8)) (0.6.1)\n", "Requirement already satisfied: click>=5.1 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from flask>=1.1.1->visualdl->-r requirements.txt (line 8)) (7.0)\n", "Requirement already satisfied: Jinja2>=2.10.1 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from flask>=1.1.1->visualdl->-r requirements.txt (line 8)) (2.11.0)\n", "Requirement already satisfied: Werkzeug>=0.15 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from flask>=1.1.1->visualdl->-r requirements.txt (line 8)) (0.16.0)\n", "Requirement already satisfied: itsdangerous>=0.24 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from flask>=1.1.1->visualdl->-r requirements.txt (line 8)) (1.1.0)\n", "Requirement already satisfied: pytz in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from Flask-Babel>=1.0.0->visualdl->-r requirements.txt (line 8)) (2019.3)\n", "Requirement already satisfied: Babel>=2.3 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from Flask-Babel>=1.0.0->visualdl->-r requirements.txt (line 8)) (2.8.0)\n", "Requirement already satisfied: cycler>=0.10 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from matplotlib!=3.0.0,>=2.0.0->scikit-image==0.17.2->-r requirements.txt (line 2)) (0.10.0)\n", "Requirement already satisfied: pyparsing!=2.0.4,!=2.1.2,!=2.1.6,>=2.0.1 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from matplotlib!=3.0.0,>=2.0.0->scikit-image==0.17.2->-r requirements.txt (line 2)) (2.4.2)\n", "Requirement already satisfied: python-dateutil>=2.1 in 
/opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from matplotlib!=3.0.0,>=2.0.0->scikit-image==0.17.2->-r requirements.txt (line 2)) (2.8.0)\n", "Requirement already satisfied: kiwisolver>=1.0.1 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from matplotlib!=3.0.0,>=2.0.0->scikit-image==0.17.2->-r requirements.txt (line 2)) (1.1.0)\n", "Requirement already satisfied: decorator>=4.3.0 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from networkx>=2.0->scikit-image==0.17.2->-r requirements.txt (line 2)) (4.4.2)\n", "Requirement already satisfied: pycryptodome>=3.8.0 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from bce-python-sdk->visualdl->-r requirements.txt (line 8)) (3.9.9)\n", "Requirement already satisfied: future>=0.6.0 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from bce-python-sdk->visualdl->-r requirements.txt (line 8)) (0.18.0)\n", "Requirement already satisfied: cfgv>=2.0.0 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from pre-commit->visualdl->-r requirements.txt (line 8)) (2.0.1)\n", "Requirement already satisfied: virtualenv>=15.2 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from pre-commit->visualdl->-r requirements.txt (line 8)) (16.7.9)\n", "Requirement already satisfied: nodeenv>=0.11.1 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from pre-commit->visualdl->-r requirements.txt (line 8)) (1.3.4)\n", "Requirement already satisfied: toml in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from pre-commit->visualdl->-r requirements.txt (line 8)) (0.10.0)\n", "Requirement already satisfied: pyyaml in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from pre-commit->visualdl->-r requirements.txt (line 8)) (5.1.2)\n", "Requirement already satisfied: aspy.yaml in 
/opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from pre-commit->visualdl->-r requirements.txt (line 8)) (1.3.0)\n", "Requirement already satisfied: identify>=1.0.0 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from pre-commit->visualdl->-r requirements.txt (line 8)) (1.4.10)\n", "Requirement already satisfied: idna<2.9,>=2.5 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from requests->visualdl->-r requirements.txt (line 8)) (2.8)\n", "Requirement already satisfied: chardet<3.1.0,>=3.0.2 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from requests->visualdl->-r requirements.txt (line 8)) (3.0.4)\n", "Requirement already satisfied: certifi>=2017.4.17 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from requests->visualdl->-r requirements.txt (line 8)) (2019.9.11)\n", "Requirement already satisfied: urllib3!=1.25.0,!=1.25.1,<1.26,>=1.21.1 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from requests->visualdl->-r requirements.txt (line 8)) (1.25.6)\n", "Requirement already satisfied: MarkupSafe>=0.23 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from Jinja2>=2.10.1->flask>=1.1.1->visualdl->-r requirements.txt (line 8)) (1.1.1)\n", "Requirement already satisfied: zipp>=0.5 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from importlib-metadata->flake8>=3.7.9->visualdl->-r requirements.txt (line 8)) (3.6.0)\n", "Installing collected packages: scikit-image\n", " Attempting uninstall: scikit-image\n", " Found existing installation: scikit-image 0.19.1\n", " Uninstalling scikit-image-0.19.1:\n", " Successfully uninstalled scikit-image-0.19.1\n", "Successfully installed scikit-image-0.17.2\n" ] } ], "source": [ "# Clone the PaddleOCR code\n", "#!git clone https://gitee.com/paddlepaddle/PaddleOCR\n", "# Change the working directory to /home/aistudio/PaddleOCR\n", "import os\n", 
"os.chdir(\"/home/aistudio/PaddleOCR\")\n", "# Install PaddleOCR's third-party dependencies\n", "!pip install -r requirements.txt" ] }, { "cell_type": "markdown", "metadata": { "collapsed": false }, "source": [ "Create a symlink so that the training data sits under the PaddleOCR project:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "!ln -s /home/aistudio/work/train_data/ /home/aistudio/PaddleOCR/" ] }, { "cell_type": "markdown", "metadata": { "collapsed": false }, "source": [ "Download the pretrained model:\n", "\n", "To speed up convergence, we recommend downloading a trained model and fine-tuning it on the icdar2015 data." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "--2021-12-22 15:39:39-- https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/rec_mv3_none_bilstm_ctc_v2.0_train.tar\n", "Resolving paddleocr.bj.bcebos.com (paddleocr.bj.bcebos.com)... 182.61.200.195, 182.61.200.229, 2409:8c04:1001:1002:0:ff:b001:368a\n", "Connecting to paddleocr.bj.bcebos.com (paddleocr.bj.bcebos.com)|182.61.200.195|:443... connected.\n", "HTTP request sent, awaiting response... 
200 OK\n", "Length: 51200000 (49M) [application/x-tar]\n", "Saving to: ‘./pretrain_models/rec_mv3_none_bilstm_ctc_v2.0_train.tar’\n", "\n", "rec_mv3_none_bilstm 100%[===================>] 48.83M 15.5MB/s in 3.6s \n", "\n", "2021-12-22 15:39:42 (13.7 MB/s) - ‘./pretrain_models/rec_mv3_none_bilstm_ctc_v2.0_train.tar’ saved [51200000/51200000]\n", "\n" ] } ], "source": [ "# The working directory is already /home/aistudio/PaddleOCR (set by os.chdir above); a `!cd` here would only affect its own subshell\n", "# Download the MobileNetV3 pretrained model\n", "!wget -nc -P ./pretrain_models/ https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/rec_mv3_none_bilstm_ctc_v2.0_train.tar\n", "# Extract the model parameters\n", "!tar -xf pretrain_models/rec_mv3_none_bilstm_ctc_v2.0_train.tar && rm -rf pretrain_models/rec_mv3_none_bilstm_ctc_v2.0_train.tar" ] }, { "cell_type": "markdown", "metadata": { "collapsed": false }, "source": [ "Launching training is simple: just specify the configuration file. Parameter values in the configuration file can also be overridden on the command line with `-o`. The training command is shown below.\n", "\n", "Where:\n", "\n", "* `Global.pretrained_model`: path of the pretrained model to load\n", "* `Global.character_dict_path`: dictionary path (the dictionary used here only covers the 26 lowercase letters plus digits)\n", "* `Global.eval_batch_step`: evaluation schedule; `[0,200]` means an evaluation is run every 200 iterations, starting after iteration 0\n", "* `Global.epoch_num`: total number of training epochs\n", "\n" ] }, { "cell_type": "code", "execution_count": 4, "metadata": { "collapsed": false }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "/opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages/skimage/morphology/_skeletonize.py:241: DeprecationWarning: `np.bool` is a deprecated alias for the builtin `bool`. To silence this warning, use `bool` by itself. Doing this will not modify any behavior and is safe. If you specifically wanted the numpy scalar type, use `np.bool_` here.\n", "Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations\n", " 0, 1, 1, 0, 0, 1, 0, 0, 0], dtype=np.bool)\n", "/opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages/skimage/morphology/_skeletonize.py:256: DeprecationWarning: `np.bool` is a deprecated alias for the builtin `bool`. To silence this warning, use `bool` by itself. 
Doing this will not modify any behavior and is safe. If you specifically wanted the numpy scalar type, use `np.bool_` here.\n", "Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations\n", " 0, 0, 0, 0, 0, 0, 0, 0, 0], dtype=np.bool)\n", "[2021/12/23 20:28:15] root INFO: Architecture : \n", "[2021/12/23 20:28:15] root INFO: Backbone : \n", "[2021/12/23 20:28:15] root INFO: model_name : large\n", "[2021/12/23 20:28:15] root INFO: name : MobileNetV3\n", "[2021/12/23 20:28:15] root INFO: scale : 0.5\n", "[2021/12/23 20:28:15] root INFO: Head : \n", "[2021/12/23 20:28:15] root INFO: fc_decay : 0\n", "[2021/12/23 20:28:15] root INFO: name : CTCHead\n", "[2021/12/23 20:28:15] root INFO: Neck : \n", "[2021/12/23 20:28:15] root INFO: encoder_type : rnn\n", "[2021/12/23 20:28:15] root INFO: hidden_size : 96\n", "[2021/12/23 20:28:15] root INFO: name : SequenceEncoder\n", "[2021/12/23 20:28:15] root INFO: Transform : None\n", "[2021/12/23 20:28:15] root INFO: algorithm : CRNN\n", "[2021/12/23 20:28:15] root INFO: model_type : rec\n", "[2021/12/23 20:28:15] root INFO: Eval : \n", "[2021/12/23 20:28:15] root INFO: dataset : \n", "[2021/12/23 20:28:15] root INFO: data_dir : ./train_data/ic15_data\n", "[2021/12/23 20:28:15] root INFO: label_file_list : ['./train_data/ic15_data/rec_gt_test.txt']\n", "[2021/12/23 20:28:15] root INFO: name : SimpleDataSet\n", "[2021/12/23 20:28:15] root INFO: transforms : \n", "[2021/12/23 20:28:15] root INFO: DecodeImage : \n", "[2021/12/23 20:28:15] root INFO: channel_first : False\n", "[2021/12/23 20:28:15] root INFO: img_mode : BGR\n", "[2021/12/23 20:28:15] root INFO: CTCLabelEncode : None\n", "[2021/12/23 20:28:15] root INFO: RecResizeImg : \n", "[2021/12/23 20:28:15] root INFO: image_shape : [3, 32, 100]\n", "[2021/12/23 20:28:15] root INFO: KeepKeys : \n", "[2021/12/23 20:28:15] root INFO: keep_keys : ['image', 'label', 'length']\n", "[2021/12/23 20:28:15] root INFO: loader : 
\n", "[2021/12/23 20:28:15] root INFO: batch_size_per_card : 256\n", "[2021/12/23 20:28:15] root INFO: drop_last : False\n", "[2021/12/23 20:28:15] root INFO: num_workers : 4\n", "[2021/12/23 20:28:15] root INFO: shuffle : False\n", "[2021/12/23 20:28:15] root INFO: use_shared_memory : False\n", "[2021/12/23 20:28:15] root INFO: Global : \n", "[2021/12/23 20:28:15] root INFO: cal_metric_during_train : True\n", "[2021/12/23 20:28:15] root INFO: character_dict_path : ppocr/utils/ic15_dict.txt\n", "[2021/12/23 20:28:15] root INFO: character_type : EN\n", "[2021/12/23 20:28:15] root INFO: checkpoints : None\n", "[2021/12/23 20:28:15] root INFO: debug : False\n", "[2021/12/23 20:28:15] root INFO: distributed : False\n", "[2021/12/23 20:28:15] root INFO: epoch_num : 40\n", "[2021/12/23 20:28:15] root INFO: eval_batch_step : [0, 200]\n", "[2021/12/23 20:28:15] root INFO: infer_img : doc/imgs_words_en/word_19.png\n", "[2021/12/23 20:28:15] root INFO: infer_mode : False\n", "[2021/12/23 20:28:15] root INFO: log_smooth_window : 20\n", "[2021/12/23 20:28:15] root INFO: max_text_length : 25\n", "[2021/12/23 20:28:15] root INFO: pretrained_model : rec_mv3_none_bilstm_ctc_v2.0_train/best_accuracy\n", "[2021/12/23 20:28:15] root INFO: print_batch_step : 10\n", "[2021/12/23 20:28:15] root INFO: save_epoch_step : 3\n", "[2021/12/23 20:28:15] root INFO: save_inference_dir : ./\n", "[2021/12/23 20:28:15] root INFO: save_model_dir : ./output/rec/ic15/\n", "[2021/12/23 20:28:15] root INFO: save_res_path : ./output/rec/predicts_ic15.txt\n", "[2021/12/23 20:28:15] root INFO: use_gpu : True\n", "[2021/12/23 20:28:15] root INFO: use_space_char : False\n", "[2021/12/23 20:28:15] root INFO: use_visualdl : False\n", "[2021/12/23 20:28:15] root INFO: Loss : \n", "[2021/12/23 20:28:15] root INFO: name : CTCLoss\n", "[2021/12/23 20:28:15] root INFO: Metric : \n", "[2021/12/23 20:28:15] root INFO: main_indicator : acc\n", "[2021/12/23 20:28:15] root INFO: name : RecMetric\n", "[2021/12/23 
20:28:15] root INFO: Optimizer : \n", "[2021/12/23 20:28:15] root INFO: beta1 : 0.9\n", "[2021/12/23 20:28:15] root INFO: beta2 : 0.999\n", "[2021/12/23 20:28:15] root INFO: lr : \n", "[2021/12/23 20:28:15] root INFO: learning_rate : 0.0005\n", "[2021/12/23 20:28:15] root INFO: name : Adam\n", "[2021/12/23 20:28:15] root INFO: regularizer : \n", "[2021/12/23 20:28:15] root INFO: factor : 0\n", "[2021/12/23 20:28:15] root INFO: name : L2\n", "[2021/12/23 20:28:15] root INFO: PostProcess : \n", "[2021/12/23 20:28:15] root INFO: name : CTCLabelDecode\n", "[2021/12/23 20:28:15] root INFO: Train : \n", "[2021/12/23 20:28:15] root INFO: dataset : \n", "[2021/12/23 20:28:15] root INFO: data_dir : ./train_data/ic15_data/\n", "[2021/12/23 20:28:15] root INFO: label_file_list : ['./train_data/ic15_data/rec_gt_train.txt']\n", "[2021/12/23 20:28:15] root INFO: name : SimpleDataSet\n", "[2021/12/23 20:28:15] root INFO: transforms : \n", "[2021/12/23 20:28:15] root INFO: DecodeImage : \n", "[2021/12/23 20:28:15] root INFO: channel_first : False\n", "[2021/12/23 20:28:15] root INFO: img_mode : BGR\n", "[2021/12/23 20:28:15] root INFO: CTCLabelEncode : None\n", "[2021/12/23 20:28:15] root INFO: RecResizeImg : \n", "[2021/12/23 20:28:15] root INFO: image_shape : [3, 32, 100]\n", "[2021/12/23 20:28:15] root INFO: KeepKeys : \n", "[2021/12/23 20:28:15] root INFO: keep_keys : ['image', 'label', 'length']\n", "[2021/12/23 20:28:15] root INFO: loader : \n", "[2021/12/23 20:28:15] root INFO: batch_size_per_card : 256\n", "[2021/12/23 20:28:15] root INFO: drop_last : True\n", "[2021/12/23 20:28:15] root INFO: num_workers : 8\n", "[2021/12/23 20:28:15] root INFO: shuffle : True\n", "[2021/12/23 20:28:15] root INFO: use_shared_memory : False\n", "[2021/12/23 20:28:15] root INFO: train with paddle 2.1.2 and device CUDAPlace(0)\n", "[2021/12/23 20:28:15] root INFO: Initialize indexs of datasets:['./train_data/ic15_data/rec_gt_train.txt']\n", "[2021/12/23 20:28:15] root INFO: Initialize indexs 
of datasets:['./train_data/ic15_data/rec_gt_test.txt']\n", "W1223 20:28:15.851713 306 device_context.cc:404] Please NOTE: device: 0, GPU Compute Capability: 7.0, Driver API Version: 11.0, Runtime API Version: 10.1\n", "W1223 20:28:15.857080 306 device_context.cc:422] device: 0, cuDNN Version: 7.6.\n", "[2021/12/23 20:28:19] root INFO: loaded pretrained_model successful from rec_mv3_none_bilstm_ctc_v2.0_train/best_accuracy.pdparams\n", "[2021/12/23 20:28:19] root INFO: train dataloader has 17 iters\n", "[2021/12/23 20:28:19] root INFO: valid dataloader has 9 iters\n", "[2021/12/23 20:28:19] root INFO: During the training process, after the 0th iteration, an evaluation is run every 200 iterations\n", "[2021/12/23 20:28:19] root INFO: Initialize indexs of datasets:['./train_data/ic15_data/rec_gt_train.txt']\n", "[2021/12/23 20:28:23] root INFO: epoch: [1/40], iter: 10, lr: 0.000500, loss: 9.336592, acc: 0.203125, norm_edit_dis: 0.674909, reader_cost: 0.27284 s, batch_cost: 0.40185 s, samples: 2816, ips: 700.75290\n", "[2021/12/23 20:28:24] root INFO: epoch: [1/40], iter: 16, lr: 0.000500, loss: 6.955496, acc: 0.210938, norm_edit_dis: 0.678930, reader_cost: 0.00008 s, batch_cost: 0.05430 s, samples: 1536, ips: 2828.80514\n", "[2021/12/23 20:28:24] root INFO: save model in ./output/rec/ic15/latest\n", "[2021/12/23 20:28:24] root INFO: Initialize indexs of datasets:['./train_data/ic15_data/rec_gt_train.txt']\n", "[2021/12/23 20:28:28] root INFO: epoch: [2/40], iter: 20, lr: 0.000500, loss: 6.402417, acc: 0.246094, norm_edit_dis: 0.695874, reader_cost: 0.24180 s, batch_cost: 0.34361 s, samples: 1024, ips: 298.00945\n", "[2021/12/23 20:28:29] root INFO: epoch: [2/40], iter: 30, lr: 0.000500, loss: 4.007382, acc: 0.412109, norm_edit_dis: 0.743064, reader_cost: 0.00013 s, batch_cost: 0.08982 s, samples: 2560, ips: 2849.98954\n", "[2021/12/23 20:28:29] root INFO: epoch: [2/40], iter: 33, lr: 0.000500, loss: 3.906031, acc: 0.458984, norm_edit_dis: 0.770415, reader_cost: 
0.00004 s, batch_cost: 0.02684 s, samples: 768, ips: 2861.80304\n", "^C\n", "main proc 306 exit, kill process group 306\n" ] } ], "source": [ "!python3 tools/train.py -c configs/rec/rec_icdar15_train.yml \\\n", "    -o Global.pretrained_model=rec_mv3_none_bilstm_ctc_v2.0_train/best_accuracy \\\n", "    Global.character_dict_path=ppocr/utils/ic15_dict.txt \\\n", "    Global.eval_batch_step=[0,200] \\\n", "    Global.epoch_num=40" ] }, { "cell_type": "markdown", "metadata": { "collapsed": false }, "source": [ "The following files are saved under the directory given by the `save_model_dir` field of the configuration file:\n", "\n", "```\n", "output/rec/ic15\n", "├── best_accuracy.pdopt\n", "├── best_accuracy.pdparams\n", "├── best_accuracy.states\n", "├── config.yml\n", "├── iter_epoch_3.pdopt\n", "├── iter_epoch_3.pdparams\n", "├── iter_epoch_3.states\n", "├── latest.pdopt\n", "├── latest.pdparams\n", "├── latest.states\n", "└── train.log\n", "```\n", "Here best_accuracy.* is the best model on the evaluation set; iter_epoch_x.* are the models saved every `save_epoch_step` epochs; latest.* is the model from the most recent epoch.\n" ] }, { "cell_type": "markdown", "metadata": { "collapsed": false }, "source": [ "**Summary:**\n", "\n", "To train on your own data, you need to modify:\n", "\n", "1. Training and evaluation data paths (required)\n", "2. Dictionary path (required)\n", "3. Pretrained model (optional)\n", "4. 
Learning rate, image shape, network structure (optional)" ] }, { "cell_type": "markdown", "metadata": { "collapsed": false }, "source": [ "### 4.2 Model Evaluation\n", "\n", "\n", "The evaluation dataset is set through the `label_file_list` field under Eval in `configs/rec/rec_icdar15_train.yml`.\n", "\n", "Here we use the icdar2015 evaluation set by default and load the weights of the model that was just trained:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "[2021/12/23 14:27:51] root INFO: Architecture : \n", "[2021/12/23 14:27:51] root INFO: Backbone : \n", "[2021/12/23 14:27:51] root INFO: model_name : large\n", "[2021/12/23 14:27:51] root INFO: name : MobileNetV3\n", "[2021/12/23 14:27:51] root INFO: scale : 0.5\n", "[2021/12/23 14:27:51] root INFO: Head : \n", "[2021/12/23 14:27:51] root INFO: fc_decay : 0\n", "[2021/12/23 14:27:51] root INFO: name : CTCHead\n", "[2021/12/23 14:27:51] root INFO: Neck : \n", "[2021/12/23 14:27:51] root INFO: encoder_type : rnn\n", "[2021/12/23 14:27:51] root INFO: hidden_size : 96\n", "[2021/12/23 14:27:51] root INFO: name : SequenceEncoder\n", "[2021/12/23 14:27:51] root INFO: Transform : None\n", "[2021/12/23 14:27:51] root INFO: algorithm : CRNN\n", "[2021/12/23 14:27:51] root INFO: model_type : rec\n", "[2021/12/23 14:27:51] root INFO: Eval : \n", "[2021/12/23 14:27:51] root INFO: dataset : \n", "[2021/12/23 14:27:51] root INFO: data_dir : ./train_data/ic15_data\n", "[2021/12/23 14:27:51] root INFO: label_file_list : ['./train_data/ic15_data/rec_gt_test.txt']\n", "[2021/12/23 14:27:51] root INFO: name : SimpleDataSet\n", "[2021/12/23 14:27:51] root INFO: transforms : \n", "[2021/12/23 14:27:51] root INFO: DecodeImage : \n", "[2021/12/23 14:27:51] root INFO: channel_first : False\n", "[2021/12/23 14:27:51] root INFO: img_mode : BGR\n", "[2021/12/23 14:27:51] root INFO: CTCLabelEncode : None\n", "[2021/12/23 14:27:51] root INFO: RecResizeImg : \n", "[2021/12/23 14:27:51] root INFO: image_shape : [3, 32, 100]\n", "[2021/12/23 14:27:51] root INFO: KeepKeys : \n", 
"[2021/12/23 14:27:51] root INFO: keep_keys : ['image', 'label', 'length']\n", "[2021/12/23 14:27:51] root INFO: loader : \n", "[2021/12/23 14:27:51] root INFO: batch_size_per_card : 256\n", "[2021/12/23 14:27:51] root INFO: drop_last : False\n", "[2021/12/23 14:27:51] root INFO: num_workers : 4\n", "[2021/12/23 14:27:51] root INFO: shuffle : False\n", "[2021/12/23 14:27:51] root INFO: use_shared_memory : False\n", "[2021/12/23 14:27:51] root INFO: Global : \n", "[2021/12/23 14:27:51] root INFO: cal_metric_during_train : True\n", "[2021/12/23 14:27:51] root INFO: character_dict_path : ppocr/utils/ic15_dict.txt\n", "[2021/12/23 14:27:51] root INFO: character_type : EN\n", "[2021/12/23 14:27:51] root INFO: checkpoints : output/rec/ic15/best_accuracy\n", "[2021/12/23 14:27:51] root INFO: debug : False\n", "[2021/12/23 14:27:51] root INFO: distributed : False\n", "[2021/12/23 14:27:51] root INFO: epoch_num : 72\n", "[2021/12/23 14:27:51] root INFO: eval_batch_step : [0, 2000]\n", "[2021/12/23 14:27:51] root INFO: infer_img : doc/imgs_words_en/word_10.png\n", "[2021/12/23 14:27:51] root INFO: infer_mode : False\n", "[2021/12/23 14:27:51] root INFO: log_smooth_window : 20\n", "[2021/12/23 14:27:51] root INFO: max_text_length : 25\n", "[2021/12/23 14:27:51] root INFO: pretrained_model : None\n", "[2021/12/23 14:27:51] root INFO: print_batch_step : 10\n", "[2021/12/23 14:27:51] root INFO: save_epoch_step : 3\n", "[2021/12/23 14:27:51] root INFO: save_inference_dir : ./\n", "[2021/12/23 14:27:51] root INFO: save_model_dir : ./output/rec/ic15/\n", "[2021/12/23 14:27:51] root INFO: save_res_path : ./output/rec/predicts_ic15.txt\n", "[2021/12/23 14:27:51] root INFO: use_gpu : True\n", "[2021/12/23 14:27:51] root INFO: use_space_char : False\n", "[2021/12/23 14:27:51] root INFO: use_visualdl : False\n", "[2021/12/23 14:27:51] root INFO: Loss : \n", "[2021/12/23 14:27:51] root INFO: name : CTCLoss\n", "[2021/12/23 14:27:51] root INFO: Metric : \n", "[2021/12/23 14:27:51] root 
INFO: main_indicator : acc\n", "[2021/12/23 14:27:51] root INFO: name : RecMetric\n", "[2021/12/23 14:27:51] root INFO: Optimizer : \n", "[2021/12/23 14:27:51] root INFO: beta1 : 0.9\n", "[2021/12/23 14:27:51] root INFO: beta2 : 0.999\n", "[2021/12/23 14:27:51] root INFO: lr : \n", "[2021/12/23 14:27:51] root INFO: learning_rate : 0.0005\n", "[2021/12/23 14:27:51] root INFO: name : Adam\n", "[2021/12/23 14:27:51] root INFO: regularizer : \n", "[2021/12/23 14:27:51] root INFO: factor : 0\n", "[2021/12/23 14:27:51] root INFO: name : L2\n", "[2021/12/23 14:27:51] root INFO: PostProcess : \n", "[2021/12/23 14:27:51] root INFO: name : CTCLabelDecode\n", "[2021/12/23 14:27:51] root INFO: Train : \n", "[2021/12/23 14:27:51] root INFO: dataset : \n", "[2021/12/23 14:27:51] root INFO: data_dir : ./train_data/ic15_data/\n", "[2021/12/23 14:27:51] root INFO: label_file_list : ['./train_data/ic15_data/rec_gt_train.txt']\n", "[2021/12/23 14:27:51] root INFO: name : SimpleDataSet\n", "[2021/12/23 14:27:51] root INFO: transforms : \n", "[2021/12/23 14:27:51] root INFO: DecodeImage : \n", "[2021/12/23 14:27:51] root INFO: channel_first : False\n", "[2021/12/23 14:27:51] root INFO: img_mode : BGR\n", "[2021/12/23 14:27:51] root INFO: CTCLabelEncode : None\n", "[2021/12/23 14:27:51] root INFO: RecResizeImg : \n", "[2021/12/23 14:27:51] root INFO: image_shape : [3, 32, 100]\n", "[2021/12/23 14:27:51] root INFO: KeepKeys : \n", "[2021/12/23 14:27:51] root INFO: keep_keys : ['image', 'label', 'length']\n", "[2021/12/23 14:27:51] root INFO: loader : \n", "[2021/12/23 14:27:51] root INFO: batch_size_per_card : 256\n", "[2021/12/23 14:27:51] root INFO: drop_last : True\n", "[2021/12/23 14:27:51] root INFO: num_workers : 8\n", "[2021/12/23 14:27:51] root INFO: shuffle : True\n", "[2021/12/23 14:27:51] root INFO: use_shared_memory : False\n", "[2021/12/23 14:27:51] root INFO: train with paddle 2.1.2 and device CUDAPlace(0)\n", "[2021/12/23 14:27:51] root INFO: Initialize indexs of 
datasets:['./train_data/ic15_data/rec_gt_test.txt']\n", "W1223 14:27:51.861889 5192 device_context.cc:404] Please NOTE: device: 0, GPU Compute Capability: 7.0, Driver API Version: 10.1, Runtime API Version: 10.1\n", "W1223 14:27:51.865501 5192 device_context.cc:422] device: 0, cuDNN Version: 7.6.\n", "[2021/12/23 14:27:56] root INFO: resume from output/rec/ic15/best_accuracy\n", "[2021/12/23 14:27:56] root INFO: metric in ckpt ***************\n", "[2021/12/23 14:27:56] root INFO: acc:0.48531535869041886\n", "[2021/12/23 14:27:56] root INFO: norm_edit_dis:0.7895228681338454\n", "[2021/12/23 14:27:56] root INFO: fps:3266.1877400927865\n", "[2021/12/23 14:27:56] root INFO: best_epoch:24\n", "[2021/12/23 14:27:56] root INFO: start_epoch:25\n", "eval model:: 100%|████████████████████████████████| 9/9 [00:02<00:00, 3.32it/s]\n", "[2021/12/23 14:27:59] root INFO: metric eval ***************\n", "[2021/12/23 14:27:59] root INFO: acc:0.48531535869041886\n", "[2021/12/23 14:27:59] root INFO: norm_edit_dis:0.7895228681338454\n", "[2021/12/23 14:27:59] root INFO: fps:4491.015930181665\n" ] } ], "source": [ "!python tools/eval.py -c configs/rec/rec_icdar15_train.yml -o Global.checkpoints=output/rec/ic15/best_accuracy \\\n", " Global.character_dict_path=ppocr/utils/ic15_dict.txt\n" ] }, { "cell_type": "markdown", "metadata": { "collapsed": false }, "source": [ "After evaluation, the log reports the trained model's accuracy on the validation set (an `acc` of about 0.485 in the run above).\n", "\n", "PaddleOCR can alternate between training and evaluation. Set the evaluation frequency with `eval_batch_step` in `configs/rec/rec_icdar15_train.yml`; by default the model is evaluated every 2000 iterations. During evaluation, the model with the best acc is saved as `output/rec/ic15/best_accuracy` by default.\n", "\n", "If the validation set is large, evaluation takes a long time; consider evaluating less often, or only after training finishes." ] }, { "cell_type": "markdown", "metadata": { "collapsed": false }, "source": [ "### 4.3 Prediction\n", "\n", "A model trained with PaddleOCR can be used for quick prediction with the script below.\n", "\n", "Image to predict:\n", "![](https://raw.githubusercontent.com/PaddlePaddle/PaddleOCR/release/2.3/doc/imgs_words_en/word_19.png)\n", "\n", "The path of the image to predict is set by `infer_img` in the config; the trained parameter file is loaded via `-o Global.checkpoints`:" ] }, { "cell_type": "code", "execution_count": 
null, "metadata": { "collapsed": false }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "[2021/12/23 14:29:19] root INFO: Architecture : \n", "[2021/12/23 14:29:19] root INFO: Backbone : \n", "[2021/12/23 14:29:19] root INFO: model_name : large\n", "[2021/12/23 14:29:19] root INFO: name : MobileNetV3\n", "[2021/12/23 14:29:19] root INFO: scale : 0.5\n", "[2021/12/23 14:29:19] root INFO: Head : \n", "[2021/12/23 14:29:19] root INFO: fc_decay : 0\n", "[2021/12/23 14:29:19] root INFO: name : CTCHead\n", "[2021/12/23 14:29:19] root INFO: Neck : \n", "[2021/12/23 14:29:19] root INFO: encoder_type : rnn\n", "[2021/12/23 14:29:19] root INFO: hidden_size : 96\n", "[2021/12/23 14:29:19] root INFO: name : SequenceEncoder\n", "[2021/12/23 14:29:19] root INFO: Transform : None\n", "[2021/12/23 14:29:19] root INFO: algorithm : CRNN\n", "[2021/12/23 14:29:19] root INFO: model_type : rec\n", "[2021/12/23 14:29:19] root INFO: Eval : \n", "[2021/12/23 14:29:19] root INFO: dataset : \n", "[2021/12/23 14:29:19] root INFO: data_dir : ./train_data/ic15_data\n", "[2021/12/23 14:29:19] root INFO: label_file_list : ['./train_data/ic15_data/rec_gt_test.txt']\n", "[2021/12/23 14:29:19] root INFO: name : SimpleDataSet\n", "[2021/12/23 14:29:19] root INFO: transforms : \n", "[2021/12/23 14:29:19] root INFO: DecodeImage : \n", "[2021/12/23 14:29:19] root INFO: channel_first : False\n", "[2021/12/23 14:29:19] root INFO: img_mode : BGR\n", "[2021/12/23 14:29:19] root INFO: CTCLabelEncode : None\n", "[2021/12/23 14:29:19] root INFO: RecResizeImg : \n", "[2021/12/23 14:29:19] root INFO: image_shape : [3, 32, 100]\n", "[2021/12/23 14:29:19] root INFO: KeepKeys : \n", "[2021/12/23 14:29:19] root INFO: keep_keys : ['image', 'label', 'length']\n", "[2021/12/23 14:29:19] root INFO: loader : \n", "[2021/12/23 14:29:19] root INFO: batch_size_per_card : 256\n", "[2021/12/23 14:29:19] root INFO: drop_last : False\n", "[2021/12/23 14:29:19] root INFO: num_workers : 4\n", "[2021/12/23 
14:29:19] root INFO: shuffle : False\n", "[2021/12/23 14:29:19] root INFO: use_shared_memory : False\n", "[2021/12/23 14:29:19] root INFO: Global : \n", "[2021/12/23 14:29:19] root INFO: cal_metric_during_train : True\n", "[2021/12/23 14:29:19] root INFO: character_dict_path : ppocr/utils/ic15_dict.txt\n", "[2021/12/23 14:29:19] root INFO: character_type : EN\n", "[2021/12/23 14:29:19] root INFO: checkpoints : output/rec/ic15/best_accuracy\n", "[2021/12/23 14:29:19] root INFO: debug : False\n", "[2021/12/23 14:29:19] root INFO: distributed : False\n", "[2021/12/23 14:29:19] root INFO: epoch_num : 72\n", "[2021/12/23 14:29:19] root INFO: eval_batch_step : [0, 2000]\n", "[2021/12/23 14:29:19] root INFO: infer_img : doc/imgs_words_en/word_19.png\n", "[2021/12/23 14:29:19] root INFO: infer_mode : False\n", "[2021/12/23 14:29:19] root INFO: log_smooth_window : 20\n", "[2021/12/23 14:29:19] root INFO: max_text_length : 25\n", "[2021/12/23 14:29:19] root INFO: pretrained_model : None\n", "[2021/12/23 14:29:19] root INFO: print_batch_step : 10\n", "[2021/12/23 14:29:19] root INFO: save_epoch_step : 3\n", "[2021/12/23 14:29:19] root INFO: save_inference_dir : ./\n", "[2021/12/23 14:29:19] root INFO: save_model_dir : ./output/rec/ic15/\n", "[2021/12/23 14:29:19] root INFO: save_res_path : ./output/rec/predicts_ic15.txt\n", "[2021/12/23 14:29:19] root INFO: use_gpu : True\n", "[2021/12/23 14:29:19] root INFO: use_space_char : False\n", "[2021/12/23 14:29:19] root INFO: use_visualdl : False\n", "[2021/12/23 14:29:19] root INFO: Loss : \n", "[2021/12/23 14:29:19] root INFO: name : CTCLoss\n", "[2021/12/23 14:29:19] root INFO: Metric : \n", "[2021/12/23 14:29:19] root INFO: main_indicator : acc\n", "[2021/12/23 14:29:19] root INFO: name : RecMetric\n", "[2021/12/23 14:29:19] root INFO: Optimizer : \n", "[2021/12/23 14:29:19] root INFO: beta1 : 0.9\n", "[2021/12/23 14:29:19] root INFO: beta2 : 0.999\n", "[2021/12/23 14:29:19] root INFO: lr : \n", "[2021/12/23 14:29:19] root INFO: 
learning_rate : 0.0005\n", "[2021/12/23 14:29:19] root INFO: name : Adam\n", "[2021/12/23 14:29:19] root INFO: regularizer : \n", "[2021/12/23 14:29:19] root INFO: factor : 0\n", "[2021/12/23 14:29:19] root INFO: name : L2\n", "[2021/12/23 14:29:19] root INFO: PostProcess : \n", "[2021/12/23 14:29:19] root INFO: name : CTCLabelDecode\n", "[2021/12/23 14:29:19] root INFO: Train : \n", "[2021/12/23 14:29:19] root INFO: dataset : \n", "[2021/12/23 14:29:19] root INFO: data_dir : ./train_data/ic15_data/\n", "[2021/12/23 14:29:19] root INFO: label_file_list : ['./train_data/ic15_data/rec_gt_train.txt']\n", "[2021/12/23 14:29:19] root INFO: name : SimpleDataSet\n", "[2021/12/23 14:29:19] root INFO: transforms : \n", "[2021/12/23 14:29:19] root INFO: DecodeImage : \n", "[2021/12/23 14:29:19] root INFO: channel_first : False\n", "[2021/12/23 14:29:19] root INFO: img_mode : BGR\n", "[2021/12/23 14:29:19] root INFO: CTCLabelEncode : None\n", "[2021/12/23 14:29:19] root INFO: RecResizeImg : \n", "[2021/12/23 14:29:19] root INFO: image_shape : [3, 32, 100]\n", "[2021/12/23 14:29:19] root INFO: KeepKeys : \n", "[2021/12/23 14:29:19] root INFO: keep_keys : ['image', 'label', 'length']\n", "[2021/12/23 14:29:19] root INFO: loader : \n", "[2021/12/23 14:29:19] root INFO: batch_size_per_card : 256\n", "[2021/12/23 14:29:19] root INFO: drop_last : True\n", "[2021/12/23 14:29:19] root INFO: num_workers : 8\n", "[2021/12/23 14:29:19] root INFO: shuffle : True\n", "[2021/12/23 14:29:19] root INFO: use_shared_memory : False\n", "[2021/12/23 14:29:19] root INFO: train with paddle 2.1.2 and device CUDAPlace(0)\n", "W1223 14:29:19.803710 5290 device_context.cc:404] Please NOTE: device: 0, GPU Compute Capability: 7.0, Driver API Version: 10.1, Runtime API Version: 10.1\n", "W1223 14:29:19.807695 5290 device_context.cc:422] device: 0, cuDNN Version: 7.6.\n", "[2021/12/23 14:29:25] root INFO: resume from output/rec/ic15/best_accuracy\n", "[2021/12/23 14:29:25] root INFO: infer_img: 
doc/imgs_words_en/word_19.png\n", "pred idx: Tensor(shape=[1, 25], dtype=int64, place=CUDAPlace(0), stop_gradient=True,\n", " [[29, 0 , 0 , 0 , 22, 0 , 0 , 0 , 25, 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 33]])\n", "[2021/12/23 14:29:25] root INFO: \t result: slow\t0.8795223\n", "[2021/12/23 14:29:25] root INFO: success!\n" ] } ], "source": [ "!python tools/infer_rec.py -c configs/rec/rec_icdar15_train.yml -o Global.checkpoints=output/rec/ic15/best_accuracy Global.character_dict_path=ppocr/utils/ic15_dict.txt" ] }, { "cell_type": "markdown", "metadata": { "collapsed": false }, "source": [ "The prediction result for the input image:\n", "\n", "```\n", "infer_img: doc/imgs_words_en/word_19.png\n", " result: slow\t0.8795223\n", "```\n" ] }, { "cell_type": "markdown", "metadata": { "collapsed": false }, "source": [ "### Exercises\n", "\n", "**[Exercise 1]**\n", "\n", "Visualize the results of the [data augmentations](https://github.com/PaddlePaddle/PaddleOCR/blob/release/2.4/ppocr/data/imaug/rec_img_aug.py) implemented in PaddleOCR, specifically noise and jitter, and describe their effects in words.\n", "\n", "Optional test images:\n", "\n", "![](https://raw.githubusercontent.com/PaddlePaddle/PaddleOCR/release/2.4/doc/imgs_words/ch/word_1.jpg)\n", "\n", "![](https://raw.githubusercontent.com/PaddlePaddle/PaddleOCR/release/2.4/doc/imgs_words/ch/word_2.jpg)\n", "\n", "![](https://raw.githubusercontent.com/PaddlePaddle/PaddleOCR/release/2.4/doc/imgs_words/ch/word_3.jpg)\n", "\n", "\n", "**[Exercise 2]**\n", "\n", "Replace the backbone in the configs/rec/rec_icdar15_train.yml config with PaddleOCR's [ResNet34_vd](https://github.com/PaddlePaddle/PaddleOCR/blob/6ee301be36eb54d91dc437842f754593dce13967/ppocr/modeling/backbones/rec_resnet_vd.py#L176). When the input image shape is (3, 32, 100), what is the size of the final feature output by the Head layer?\n", "\n", "\n", "**[Exercise 3]**\n", "\n", "Download the 100,000-sample Chinese dataset [rec_data_lesson_demo](https://paddleocr.bj.bcebos.com/dataset/rec_data_lesson_demo.tar), modify the configs/rec/rec_icdar15_train.yml config file to train a recognition model, and provide the training log.\n", "\n", "A pretrained model is available to load: https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/rec_mv3_none_bilstm_ctc_v2.0_train.tar \n", "\n", "\n", "## Summary\n", "\n", 
"至此,一个基于CRNN的文本识别任务就全部完成了,更多功能和代码可以参考 [PaddleOCR](https://github.com/PaddlePaddle/PaddleOCR)。\n", "\n", "如果对项目任何问题或者疑问,欢迎在评论区留言提出" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "py35-paddle1.2.0" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.7.4" } }, "nbformat": 4, "nbformat_minor": 1 }