{ "cells": [ { "cell_type": "markdown", "metadata": { "collapsed": false }, "source": [ "# Text Recognition in Practice\n", "\n", "The theory part of the previous chapter introduced the main methods in text recognition. Among them, CRNN was proposed relatively early and is still one of the most widely used methods in industry. This chapter explains in detail how to build, train, evaluate, and run prediction with a CRNN text recognition model using PaddleOCR. The dataset is ICDAR 2015, with 4468 training images and 2077 test images.\n", "\n", "After working through this chapter, you will know:\n", "\n", "1. How to run text recognition quickly with the PaddleOCR whl package\n", "\n", "2. The basic principle and network structure of CRNN\n", "\n", "3. The essential steps of model training and how to tune its parameters\n", "\n", "4. How to train the network on a custom dataset\n", "\n", "Note: `paddleocr` refers to the `PaddleOCR whl package`" ] }, { "cell_type": "markdown", "metadata": { "collapsed": false }, "source": [ "## 1. Quick Start\n", "\n", "### 1.1 Install the dependencies and the whl package\n", "\n", "First make sure that paddle and paddleocr are installed. If they already are, skip this step." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Looking in indexes: https://pypi.tuna.tsinghua.edu.cn/simple\n", "Requirement already satisfied: paddlepaddle-gpu in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (2.1.2.post101)\n", "Requirement already satisfied: protobuf>=3.1.0 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from paddlepaddle-gpu) (3.14.0)\n", "Requirement already satisfied: numpy>=1.13; python_version >= \"3.5\" and platform_system != \"Windows\" in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from paddlepaddle-gpu) (1.20.3)\n", "Requirement already satisfied: Pillow in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from paddlepaddle-gpu) (7.1.2)\n", "Requirement already satisfied: six in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from paddlepaddle-gpu) (1.15.0)\n", "Requirement already satisfied: requests>=2.20.0 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from paddlepaddle-gpu) (2.22.0)\n", "Requirement already satisfied: astor in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from paddlepaddle-gpu) (0.8.1)\n", "Requirement already satisfied: decorator in
/opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from paddlepaddle-gpu) (4.4.2)\n", "Requirement already satisfied: gast<=0.4.0,>=0.3.3; platform_system != \"Windows\" in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from paddlepaddle-gpu) (0.3.3)\n", "Requirement already satisfied: urllib3!=1.25.0,!=1.25.1,<1.26,>=1.21.1 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from requests>=2.20.0->paddlepaddle-gpu) (1.25.6)\n", "Requirement already satisfied: certifi>=2017.4.17 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from requests>=2.20.0->paddlepaddle-gpu) (2019.9.11)\n", "Requirement already satisfied: idna<2.9,>=2.5 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from requests>=2.20.0->paddlepaddle-gpu) (2.8)\n", "Requirement already satisfied: chardet<3.1.0,>=3.0.2 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from requests>=2.20.0->paddlepaddle-gpu) (3.0.4)\n", "Looking in indexes: https://pypi.tuna.tsinghua.edu.cn/simple\n", "Collecting pip\n", "\u001b[?25l Downloading https://pypi.tuna.tsinghua.edu.cn/packages/a4/6d/6463d49a933f547439d6b5b98b46af8742cc03ae83543e4d7688c2420f8b/pip-21.3.1-py3-none-any.whl (1.7MB)\n", "\u001b[K |████████████████████████████████| 1.7MB 8.4MB/s eta 0:00:01\n", "\u001b[?25hInstalling collected packages: pip\n", " Found existing installation: pip 19.2.3\n", " Uninstalling pip-19.2.3:\n", " Successfully uninstalled pip-19.2.3\n", "Successfully installed pip-21.3.1\n", "Looking in indexes: https://pypi.tuna.tsinghua.edu.cn/simple\n", "Collecting paddleocr\n", " Downloading https://pypi.tuna.tsinghua.edu.cn/packages/e1/b6/5486e674ce096667dff247b58bf0fb789c2ce17a10e546c2686a2bb07aec/paddleocr-2.3.0.2-py3-none-any.whl (250 kB)\n", " |████████████████████████████████| 250 kB 3.3 MB/s \n", "\u001b[?25hCollecting lmdb\n", " Downloading 
https://pypi.tuna.tsinghua.edu.cn/packages/2e/dd/ada2fd91cd7832979069c556607903f274470c3d3d2274e0a848908272e8/lmdb-1.2.1-cp37-cp37m-manylinux2010_x86_64.whl (299 kB)\n", " |████████████████████████████████| 299 kB 12.8 MB/s \n", "\u001b[?25hCollecting lxml\n", " Downloading https://pypi.tuna.tsinghua.edu.cn/packages/7b/01/16a9b80c8ce4339294bb944f08e157dbfcfbb09ba9031bde4ddf7e3e5499/lxml-4.7.1-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.manylinux_2_24_x86_64.whl (6.4 MB)\n", " |████████████████████████████████| 6.4 MB 52.4 MB/s \n", "\u001b[?25hCollecting python-Levenshtein\n", " Downloading https://pypi.tuna.tsinghua.edu.cn/packages/2a/dc/97f2b63ef0fa1fd78dcb7195aca577804f6b2b51e712516cc0e902a9a201/python-Levenshtein-0.12.2.tar.gz (50 kB)\n", " |████████████████████████████████| 50 kB 1.6 MB/s \n", "\u001b[?25h Preparing metadata (setup.py) ... \u001b[?25ldone\n", "\u001b[?25hCollecting scikit-image\n", " Downloading https://pypi.tuna.tsinghua.edu.cn/packages/9a/44/8f8c7f9c9de7fde70587a656d7df7d056e6f05192a74491f7bc074a724d0/scikit_image-0.19.1-cp37-cp37m-manylinux_2_12_x86_64.manylinux2010_x86_64.whl (13.3 MB)\n", " |████████████████████████████████| 13.3 MB 56.1 MB/s \n", "\u001b[?25hRequirement already satisfied: numpy in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from paddleocr) (1.20.3)\n", "Collecting imgaug==0.4.0\n", " Downloading https://pypi.tuna.tsinghua.edu.cn/packages/66/b1/af3142c4a85cba6da9f4ebb5ff4e21e2616309552caca5e8acefe9840622/imgaug-0.4.0-py2.py3-none-any.whl (948 kB)\n", " |████████████████████████████████| 948 kB 62.9 MB/s \n", "\u001b[?25hCollecting opencv-contrib-python==4.4.0.46\n", " Downloading https://pypi.tuna.tsinghua.edu.cn/packages/08/51/1e0a206dd5c70fea91084e6f43979dc13e8eb175760cc7a105083ec3eb68/opencv_contrib_python-4.4.0.46-cp37-cp37m-manylinux2014_x86_64.whl (55.7 MB)\n", " |████████████████████████████████| 55.7 MB 44 kB/s 0:01\n", "\u001b[?25hCollecting premailer\n", " Downloading 
https://pypi.tuna.tsinghua.edu.cn/packages/b1/07/4e8d94f94c7d41ca5ddf8a9695ad87b888104e2fd41a35546c1dc9ca74ac/premailer-3.10.0-py2.py3-none-any.whl (19 kB)\n", "Collecting shapely\n", " Downloading https://pypi.tuna.tsinghua.edu.cn/packages/ae/20/33ce377bd24d122a4d54e22ae2c445b9b1be8240edb50040b40add950cd9/Shapely-1.8.0-cp37-cp37m-manylinux_2_5_x86_64.manylinux1_x86_64.whl (1.1 MB)\n", " |████████████████████████████████| 1.1 MB 14.5 MB/s \n", "\u001b[?25hRequirement already satisfied: visualdl in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from paddleocr) (2.2.0)\n", "Collecting fasttext==0.9.1\n", " Downloading https://pypi.tuna.tsinghua.edu.cn/packages/10/61/2e01f1397ec533756c1d893c22d9d5ed3fce3a6e4af1976e0d86bb13ea97/fasttext-0.9.1.tar.gz (57 kB)\n", " |████████████████████████████████| 57 kB 9.0 MB/s \n", "\u001b[?25h Preparing metadata (setup.py) ... \u001b[?25ldone\n", "\u001b[?25hRequirement already satisfied: cython in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from paddleocr) (0.29)\n", "Requirement already satisfied: openpyxl in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from paddleocr) (3.0.5)\n", "Collecting pyclipper\n", " Downloading https://pypi.tuna.tsinghua.edu.cn/packages/c5/fa/2c294127e4f88967149a68ad5b3e43636e94e3721109572f8f17ab15b772/pyclipper-1.3.0.post2-cp37-cp37m-manylinux_2_5_x86_64.manylinux1_x86_64.whl (603 kB)\n", " |████████████████████████████████| 603 kB 7.6 MB/s \n", "\u001b[?25hRequirement already satisfied: tqdm in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from paddleocr) (4.36.1)\n", "Collecting pybind11>=2.2\n", " Using cached https://pypi.tuna.tsinghua.edu.cn/packages/a8/3b/fc246e1d4c7547a7a07df830128e93c6215e9b93dcb118b2a47a70726153/pybind11-2.8.1-py2.py3-none-any.whl (208 kB)\n", "Requirement already satisfied: setuptools>=0.7.0 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from 
fasttext==0.9.1->paddleocr) (56.2.0)\n", "Requirement already satisfied: Pillow in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from imgaug==0.4.0->paddleocr) (7.1.2)\n", "Requirement already satisfied: imageio in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from imgaug==0.4.0->paddleocr) (2.6.1)\n", "Requirement already satisfied: scipy in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from imgaug==0.4.0->paddleocr) (1.6.3)\n", "Requirement already satisfied: opencv-python in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from imgaug==0.4.0->paddleocr) (4.1.1.26)\n", "Requirement already satisfied: matplotlib in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from imgaug==0.4.0->paddleocr) (2.2.3)\n", "Requirement already satisfied: six in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from imgaug==0.4.0->paddleocr) (1.15.0)\n", "Collecting PyWavelets>=1.1.1\n", " Downloading https://pypi.tuna.tsinghua.edu.cn/packages/a1/9c/564511b6e1c4e1d835ed2d146670436036960d09339a8fa2921fe42dad08/PyWavelets-1.2.0-cp37-cp37m-manylinux_2_5_x86_64.manylinux1_x86_64.whl (6.1 MB)\n", " |████████████████████████████████| 6.1 MB 3.8 MB/s \n", "\u001b[?25hRequirement already satisfied: packaging>=20.0 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from scikit-image->paddleocr) (20.9)\n", "Requirement already satisfied: networkx>=2.2 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from scikit-image->paddleocr) (2.4)\n", "Collecting tifffile>=2019.7.26\n", " Downloading https://pypi.tuna.tsinghua.edu.cn/packages/d8/38/85ae5ed77598ca90558c17a2f79ddaba33173b31cf8d8f545d34d9134f0d/tifffile-2021.11.2-py3-none-any.whl (178 kB)\n", " |████████████████████████████████| 178 kB 7.1 MB/s \n", "\u001b[?25hRequirement already satisfied: et-xmlfile in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from 
openpyxl->paddleocr) (1.0.1)\n", "Requirement already satisfied: jdcal in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from openpyxl->paddleocr) (1.4.1)\n", "Requirement already satisfied: requests in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from premailer->paddleocr) (2.22.0)\n", "Collecting cssselect\n", " Downloading https://pypi.tuna.tsinghua.edu.cn/packages/3b/d4/3b5c17f00cce85b9a1e6f91096e1cc8e8ede2e1be8e96b87ce1ed09e92c5/cssselect-1.1.0-py2.py3-none-any.whl (16 kB)\n", "Collecting cssutils\n", " Downloading https://pypi.tuna.tsinghua.edu.cn/packages/24/c4/9db28fe567612896d360ab28ad02ee8ae107d0e92a22db39affd3fba6212/cssutils-2.3.0-py3-none-any.whl (404 kB)\n", " |████████████████████████████████| 404 kB 134 kB/s \n", "\u001b[?25hRequirement already satisfied: cachetools in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from premailer->paddleocr) (4.0.0)\n", "Requirement already satisfied: pre-commit in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from visualdl->paddleocr) (1.21.0)\n", "Requirement already satisfied: Flask-Babel>=1.0.0 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from visualdl->paddleocr) (1.0.0)\n", "Requirement already satisfied: flask>=1.1.1 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from visualdl->paddleocr) (1.1.1)\n", "Requirement already satisfied: flake8>=3.7.9 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from visualdl->paddleocr) (3.8.2)\n", "Requirement already satisfied: shellcheck-py in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from visualdl->paddleocr) (0.7.1.1)\n", "Requirement already satisfied: pandas in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from visualdl->paddleocr) (1.1.5)\n", "Requirement already satisfied: bce-python-sdk in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from 
visualdl->paddleocr) (0.8.53)\n", "Requirement already satisfied: protobuf>=3.11.0 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from visualdl->paddleocr) (3.14.0)\n", "Requirement already satisfied: pyflakes<2.3.0,>=2.2.0 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from flake8>=3.7.9->visualdl->paddleocr) (2.2.0)\n", "Requirement already satisfied: pycodestyle<2.7.0,>=2.6.0a1 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from flake8>=3.7.9->visualdl->paddleocr) (2.6.0)\n", "Requirement already satisfied: importlib-metadata in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from flake8>=3.7.9->visualdl->paddleocr) (0.23)\n", "Requirement already satisfied: mccabe<0.7.0,>=0.6.0 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from flake8>=3.7.9->visualdl->paddleocr) (0.6.1)\n", "Requirement already satisfied: itsdangerous>=0.24 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from flask>=1.1.1->visualdl->paddleocr) (1.1.0)\n", "Requirement already satisfied: Jinja2>=2.10.1 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from flask>=1.1.1->visualdl->paddleocr) (2.11.0)\n", "Requirement already satisfied: click>=5.1 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from flask>=1.1.1->visualdl->paddleocr) (7.0)\n", "Requirement already satisfied: Werkzeug>=0.15 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from flask>=1.1.1->visualdl->paddleocr) (0.16.0)\n", "Requirement already satisfied: pytz in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from Flask-Babel>=1.0.0->visualdl->paddleocr) (2019.3)\n", "Requirement already satisfied: Babel>=2.3 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from Flask-Babel>=1.0.0->visualdl->paddleocr) (2.8.0)\n", "Requirement already satisfied: decorator>=4.3.0 in 
/opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from networkx>=2.2->scikit-image->paddleocr) (4.4.2)\n", "Requirement already satisfied: pyparsing>=2.0.2 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from packaging>=20.0->scikit-image->paddleocr) (2.4.2)\n", "Requirement already satisfied: pycryptodome>=3.8.0 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from bce-python-sdk->visualdl->paddleocr) (3.9.9)\n", "Requirement already satisfied: future>=0.6.0 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from bce-python-sdk->visualdl->paddleocr) (0.18.0)\n", "Requirement already satisfied: kiwisolver>=1.0.1 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from matplotlib->imgaug==0.4.0->paddleocr) (1.1.0)\n", "Requirement already satisfied: python-dateutil>=2.1 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from matplotlib->imgaug==0.4.0->paddleocr) (2.8.0)\n", "Requirement already satisfied: cycler>=0.10 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from matplotlib->imgaug==0.4.0->paddleocr) (0.10.0)\n", "Requirement already satisfied: aspy.yaml in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from pre-commit->visualdl->paddleocr) (1.3.0)\n", "Requirement already satisfied: virtualenv>=15.2 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from pre-commit->visualdl->paddleocr) (16.7.9)\n", "Requirement already satisfied: nodeenv>=0.11.1 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from pre-commit->visualdl->paddleocr) (1.3.4)\n", "Requirement already satisfied: pyyaml in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from pre-commit->visualdl->paddleocr) (5.1.2)\n", "Requirement already satisfied: cfgv>=2.0.0 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from pre-commit->visualdl->paddleocr) 
(2.0.1)\n", "Requirement already satisfied: identify>=1.0.0 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from pre-commit->visualdl->paddleocr) (1.4.10)\n", "Requirement already satisfied: toml in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from pre-commit->visualdl->paddleocr) (0.10.0)\n", "Requirement already satisfied: urllib3!=1.25.0,!=1.25.1,<1.26,>=1.21.1 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from requests->premailer->paddleocr) (1.25.6)\n", "Requirement already satisfied: chardet<3.1.0,>=3.0.2 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from requests->premailer->paddleocr) (3.0.4)\n", "Requirement already satisfied: idna<2.9,>=2.5 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from requests->premailer->paddleocr) (2.8)\n", "Requirement already satisfied: certifi>=2017.4.17 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from requests->premailer->paddleocr) (2019.9.11)\n", "Requirement already satisfied: MarkupSafe>=0.23 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from Jinja2>=2.10.1->flask>=1.1.1->visualdl->paddleocr) (1.1.1)\n", "Requirement already satisfied: zipp>=0.5 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from importlib-metadata->flake8>=3.7.9->visualdl->paddleocr) (3.6.0)\n", "Building wheels for collected packages: fasttext, python-Levenshtein\n", " Building wheel for fasttext (setup.py) ... \u001b[?25ldone\n", "\u001b[?25h Created wheel for fasttext: filename=fasttext-0.9.1-cp37-cp37m-linux_x86_64.whl size=2584156 sha256=acb4d4fde73d31c7dfdd2ae3de0da25a558c34c672d4904e6a5c4279185fe5af\n", " Stored in directory: /home/aistudio/.cache/pip/wheels/a1/cb/b3/a25a8ce16c1a4ff102c1e40d6eaa4dfc9d5695b92d57331b36\n", " Building wheel for python-Levenshtein (setup.py) ... 
\u001b[?25ldone\n", "\u001b[?25h Created wheel for python-Levenshtein: filename=python_Levenshtein-0.12.2-cp37-cp37m-linux_x86_64.whl size=171687 sha256=56b4a2de4349a05004121050df68b488ffd253dcc59187ca07b89b62d40c0218\n", " Stored in directory: /home/aistudio/.cache/pip/wheels/38/b9/a4/3729726160fb103833de468adb5ce019b58543ae41d0b0e446\n", "Successfully built fasttext python-Levenshtein\n", "Installing collected packages: tifffile, PyWavelets, shapely, scikit-image, pybind11, lxml, cssutils, cssselect, python-Levenshtein, pyclipper, premailer, opencv-contrib-python, lmdb, imgaug, fasttext, paddleocr\n", "Successfully installed PyWavelets-1.2.0 cssselect-1.1.0 cssutils-2.3.0 fasttext-0.9.1 imgaug-0.4.0 lmdb-1.2.1 lxml-4.7.1 opencv-contrib-python-4.4.0.46 paddleocr-2.3.0.2 premailer-3.10.0 pybind11-2.8.1 pyclipper-1.3.0.post2 python-Levenshtein-0.12.2 scikit-image-0.19.1 shapely-1.8.0 tifffile-2021.11.2\n" ] } ], "source": [ "# Install the GPU version of PaddlePaddle\n", "!pip install paddlepaddle-gpu\n", "# Upgrade pip, then install the PaddleOCR whl package\n", "! pip install -U pip\n", "! 
pip install paddleocr" ] }, { "cell_type": "markdown", "metadata": { "collapsed": false }, "source": [ "### 1.2 Quickly predict text content\n", "\n", "The PaddleOCR whl package automatically downloads the lightweight PP-OCR models and uses them as the default models.\n", "\n", "The following shows how to run recognition with the whl package.\n", "\n", "Test image:\n", "\n", "![](https://ai-studio-static-online.cdn.bcebos.com/531d9b3aff45449893b33bcb5dd13971057fcb4038f045578b3abd99fa3a96f2)" ] }, { "cell_type": "code", "execution_count": 5, "metadata": { "collapsed": false }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "[2021/12/23 20:28:44] root WARNING: version 2.1 not support cls models, use version 2.0 instead\n", "download https://paddleocr.bj.bcebos.com/PP-OCRv2/chinese/ch_PP-OCRv2_det_infer.tar to /home/aistudio/.paddleocr/2.2.1/ocr/det/ch/ch_PP-OCRv2_det_infer/ch_PP-OCRv2_det_infer.tar\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "/opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages/skimage/morphology/_skeletonize.py:241: DeprecationWarning: `np.bool` is a deprecated alias for the builtin `bool`. To silence this warning, use `bool` by itself. Doing this will not modify any behavior and is safe.
If you specifically wanted the numpy scalar type, use `np.bool_` here.\n", "Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations\n", " 0, 0, 0, 0, 0, 0, 0, 0, 0], dtype=np.bool)\n", " 0%| | 0.00/3.19M [00:00, ?iB/s]100%|██████████| 3.19M/3.19M [00:00<00:00, 7.80MiB/s]\n", " 14%|█▎ | 1.20M/8.88M [00:00<00:00, 11.7MiB/s]" ] }, { "name": "stdout", "output_type": "stream", "text": [ "download https://paddleocr.bj.bcebos.com/PP-OCRv2/chinese/ch_PP-OCRv2_rec_infer.tar to /home/aistudio/.paddleocr/2.2.1/ocr/rec/ch/ch_PP-OCRv2_rec_infer/ch_PP-OCRv2_rec_infer.tar\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ " 24%|██▍ | 2.15M/8.88M [00:00<00:00, 10.8MiB/s]100%|██████████| 8.88M/8.88M [00:01<00:00, 6.38MiB/s]\n", " 17%|█▋ | 249k/1.45M [00:00<00:00, 2.42MiB/s]" ] }, { "name": "stdout", "output_type": "stream", "text": [ "download https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_mobile_v2.0_cls_infer.tar to /home/aistudio/.paddleocr/2.2.1/ocr/cls/ch_ppocr_mobile_v2.0_cls_infer/ch_ppocr_mobile_v2.0_cls_infer.tar\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ " 90%|█████████ | 1.31M/1.45M [00:00<00:00, 3.32MiB/s]100%|██████████| 1.45M/1.45M [00:00<00:00, 4.53MiB/s]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Namespace(benchmark=False, cls_batch_num=6, cls_image_shape='3, 48, 192', cls_model_dir='/home/aistudio/.paddleocr/2.2.1/ocr/cls/ch_ppocr_mobile_v2.0_cls_infer', cls_thresh=0.9, cpu_threads=10, det=True, det_algorithm='DB', det_db_box_thresh=0.6, det_db_score_mode='fast', det_db_thresh=0.3, det_db_unclip_ratio=1.5, det_east_cover_thresh=0.1, det_east_nms_thresh=0.2, det_east_score_thresh=0.8, det_limit_side_len=960, det_limit_type='max', det_model_dir='/home/aistudio/.paddleocr/2.2.1/ocr/det/ch/ch_PP-OCRv2_det_infer', det_sast_nms_thresh=0.2, det_sast_polygon=False, det_sast_score_thresh=0.5, drop_score=0.5, e2e_algorithm='PGNet', 
e2e_char_dict_path='./ppocr/utils/ic15_dict.txt', e2e_limit_side_len=768, e2e_limit_type='max', e2e_model_dir=None, e2e_pgnet_mode='fast', e2e_pgnet_polygon=True, e2e_pgnet_score_thresh=0.5, e2e_pgnet_valid_set='totaltext', enable_mkldnn=False, gpu_mem=500, help='==SUPPRESS==', image_dir=None, ir_optim=True, label_list=['0', '180'], lang='ch', layout_path_model='lp://PubLayNet/ppyolov2_r50vd_dcn_365e_publaynet/config', max_batch_size=10, max_text_length=25, min_subgraph_size=15, output='./output/table', precision='fp32', process_id=0, rec=True, rec_algorithm='CRNN', rec_batch_num=6, rec_char_dict_path='/home/aistudio/PaddleOCR/ppocr/utils/ppocr_keys_v1.txt', rec_char_type='ch', rec_image_shape='3, 32, 320', rec_model_dir='/home/aistudio/.paddleocr/2.2.1/ocr/rec/ch/ch_PP-OCRv2_rec_infer', save_log_path='./log_output/', show_log=True, table_char_dict_path=None, table_char_type='en', table_max_len=488, table_model_dir=None, total_process_num=1, type='ocr', use_angle_cls=False, use_dilation=False, use_gpu=True, use_mp=False, use_pdserving=False, use_space_char=True, use_tensorrt=False, version='2.1', vis_font_path='./doc/fonts/simfang.ttf', warmup=True)\n", "[2021/12/23 20:28:48] root WARNING: Since the angle classifier is not initialized, the angle classifier will not be uesd during the forward process\n", "('SLOW', 0.9776376)\n" ] } ], "source": [ "from paddleocr import PaddleOCR\n", "\n", "ocr = PaddleOCR()  # need to run only once to download and load model into memory\n", "img_path = '/home/aistudio/work/word_19.png'\n", "# det=False skips text detection and runs recognition only, since the input is already a cropped word image\n", "result = ocr.ocr(img_path, det=False)\n", "for line in result:\n", "    print(line)" ] }, { "cell_type": "markdown", "metadata": { "collapsed": false }, "source": [ "Running the code block above returns the recognized text and its confidence score:\n", "\n", "```\n", "('SLOW', 0.9776376)\n", "```\n", "\n", "You now know how to run prediction with the paddleocr whl package. More test images are available under `./work/`; try them to see other results." ] }, { "cell_type": "markdown", "metadata": { "collapsed": false }, "source": [ "## 2. 
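How Prediction Works\n", "\n", "In Section 1, paddleocr loaded a trained CRNN recognition model to make predictions. This section explains the principle and workflow of CRNN in detail.\n", "\n", "### 2.1 Algorithm category\n", "\n", "CRNN is a CTC-based algorithm; in the taxonomy figure from the theory part, it occupies the position shown below. CRNN mainly targets regular text, and CTC-based algorithms predict quickly and handle long text well, which is why PPOCR chose CRNN as its Chinese recognition algorithm.\n", "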
'}\n" ] } ], "source": [ "data_line = \"train/word_1.png\tGenaxis Theatre\"\n", "data_dir = \"/home/aistudio/work/train_data/ic15_data/\"\n", "\n", "item = __getitem__(data_line, data_dir)\n", "print(item)" ] }, { "cell_type": "markdown", "metadata": { "collapsed": false }, "source": [ "Once the logic that returns a single sample is in place, `paddle.io.DataLoader` can combine the samples into batches; see [build_dataloader](https://github.com/PaddlePaddle/PaddleOCR/blob/95c670faf6cf4551c841764cde43a4f4d9d5e634/ppocr/data/__init__.py#L52) for details.\n", "\n", "* 
build model\n", "\n", "  build model sets up the main network structure. Its details are described in Section 2.3 (Code Implementation) and are not repeated here; see [modeling](https://github.com/PaddlePaddle/PaddleOCR/tree/release/2.4/ppocr/modeling) for the code of each module.\n", "\n", "* build loss\n", "\n", "  The loss function of CRNN is the CTC loss. PaddlePaddle already provides the common loss functions, so you only need to call the built-in implementation:" ] }, 
{ "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "import paddle\n", "import paddle.nn as nn\n", "\n", "\n", "class CTCLoss(nn.Layer):\n", "    def __init__(self, use_focal_loss=False, **kwargs):\n", "        super(CTCLoss, self).__init__()\n", "        # blank=0: index 0 is reserved for the CTC blank symbol, which stands for no character\n", "        self.loss_func = nn.CTCLoss(blank=0, reduction='none')\n", "\n", "    def forward(self, predicts, batch):\n", "        if isinstance(predicts, (list, tuple)):\n", "            predicts = predicts[-1]\n", "        # the head outputs [B, T, C]; paddle's CTCLoss expects [T, B, C], so swap the first two axes\n", "        predicts = predicts.transpose((1, 0, 2))  # [T, B, C], e.g. [80, 1, 37]\n", "        N, B, _ = predicts.shape\n", "        # every sample in the batch uses all N time steps of the prediction\n", "        preds_lengths = paddle.to_tensor([N] * B, dtype='int64')\n", "        labels = batch[1].astype(\"int32\")\n", "        label_lengths = batch[2].astype('int64')\n", "        # per-sample CTC loss, averaged over the batch\n", "        loss = self.loss_func(predicts, labels, preds_lengths, label_lengths)\n", "        loss = loss.mean()\n", "        return {'loss': loss}" ] }, 
{ "cell_type": "markdown", "metadata": { "collapsed": false }, "source": [ "* build post process\n", "\n", "  The details are also covered in Section 2.3 (Code Implementation); the implementation logic is the same as described there." ] }, { "cell_type": "markdown", "metadata": { "collapsed": false }, "source": [ "* build optim\n", "\n", "  The optimizer is `Adam`, again created through the PaddlePaddle API `paddle.optimizer.Adam`.\n", "\n", "* build metric\n", "\n", "  The metric part computes the model's evaluation metrics. In PaddleOCR text recognition, a prediction counts as correct only when the entire text line matches the label, so the core logic of the accuracy computation is:" ] }, 
{ "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "def metric(preds, labels):\n", "    correct_num = 0\n", "    all_num = 0\n", "    for pred, target in zip(preds, labels):\n", "        # spaces are ignored in the comparison\n", "        pred = pred.replace(\" \", \"\")\n", "        target = target.replace(\" \", \"\")\n", "        # a sample is correct only if the whole string matches\n", "        if pred == target:\n", "            correct_num += 1\n", "        all_num += 1\n", "    return {\n", "        'acc': correct_num / all_num,\n", "    }" ] }, 
{ "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "acc: {'acc': 0.6}\n" ] } ], "source": [ "preds = [\"aaa\", \"bbb\", \"ccc\", \"123\", \"456\"]\n", "labels = [\"aaa\", \"bbb\", \"ddd\", \"123\", \"444\"]\n", "acc = metric(preds, labels)\n", "print(\"acc:\", acc)\n", "# 3 of the 5 predictions match their labels exactly, so the accuracy should be 0.6" ] }, 
{ "cell_type": "markdown", "metadata": { "collapsed": false }, "source": [ "Putting all the parts above together gives the complete training pipeline:" ] }, 
{ "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "# build_dataloader, build_model, build_loss, program.train, etc. are provided by the PaddleOCR repository\n", "def main(config, device, logger, vdl_writer):\n", "    # init dist environment\n", "    if config['Global']['distributed']:\n", "        dist.init_parallel_env()\n", "\n", "    global_config = config['Global']\n", "\n", "    # build dataloader\n", "    train_dataloader = build_dataloader(config, 'Train', device, logger)\n", "    if len(train_dataloader) == 0:\n", "        logger.error(\n", "            \"No Images in train dataset, please ensure\\n\" +\n", "            \"\\t1. The images num in the train label_file_list should be larger than or equal with batch size.\\n\"\n", "            +\n", "            \"\\t2. The annotation file and path in the configuration file are provided normally.\"\n", "        )\n", "        return\n", "\n", "    if config['Eval']:\n", "        valid_dataloader = build_dataloader(config, 'Eval', device, logger)\n", "    else:\n", "        valid_dataloader = None\n", "\n", "    # build post process\n", "    post_process_class = build_post_process(config['PostProcess'],\n", "                                            global_config)\n", "\n", "    # build model\n", "    # for rec algorithm\n", "    if hasattr(post_process_class, 'character'):\n", "        char_num = len(getattr(post_process_class, 'character'))\n", "        if config['Architecture'][\"algorithm\"] in [\"Distillation\",\n", "                                                   ]:  # distillation model\n", "            for key in config['Architecture'][\"Models\"]:\n", "                config['Architecture'][\"Models\"][key][\"Head\"][\n", "                    'out_channels'] = char_num\n", "        else:  # base rec model\n", "            config['Architecture'][\"Head\"]['out_channels'] = char_num\n", "\n", "    model = build_model(config['Architecture'])\n", "    if config['Global']['distributed']:\n", "        model = paddle.DataParallel(model)\n", "\n", "    # build loss\n", "    loss_class = build_loss(config['Loss'])\n", "\n", "    # build optim\n", "    optimizer, lr_scheduler = build_optimizer(\n", "        config['Optimizer'],\n", "        epochs=config['Global']['epoch_num'],\n", "        step_each_epoch=len(train_dataloader),\n", "        parameters=model.parameters())\n", "\n", "    # build metric\n", "    eval_class = build_metric(config['Metric'])\n", "    # load pretrain model\n", "    pre_best_model_dict = load_model(config, model, optimizer)\n", "    logger.info('train dataloader has {} iters'.format(len(train_dataloader)))\n", "    if valid_dataloader is not None:\n", "        logger.info('valid dataloader has {} iters'.format(\n", "            len(valid_dataloader)))\n", "\n", "    use_amp = config[\"Global\"].get(\"use_amp\", False)\n", "    if use_amp:\n", "        AMP_RELATED_FLAGS_SETTING = {\n", "            'FLAGS_cudnn_batchnorm_spatial_persistent': 1,\n", "            'FLAGS_max_inplace_grad_add': 8,\n", "        }\n", "        paddle.fluid.set_flags(AMP_RELATED_FLAGS_SETTING)\n", "        scale_loss = config[\"Global\"].get(\"scale_loss\", 1.0)\n", "        use_dynamic_loss_scaling = config[\"Global\"].get(\n", "            \"use_dynamic_loss_scaling\", False)\n", "        scaler = paddle.amp.GradScaler(\n", "            init_loss_scaling=scale_loss,\n", "            use_dynamic_loss_scaling=use_dynamic_loss_scaling)\n", "    else:\n", "        scaler = None\n", "\n", "    # start train\n", "    program.train(config, train_dataloader, valid_dataloader, device, model,\n", "                  loss_class, optimizer, lr_scheduler, post_process_class,\n", "                  eval_class, pre_best_model_dict, logger, vdl_writer, scaler)" ] }, 
{ "cell_type": "markdown", "metadata": { "collapsed": false }, "source": [ "## 4. Full Training Task\n", "\n", "### 4.1 Start training\n", "\n", "As with text detection, the PaddleOCR recognition task receives its parameters through a configuration file.\n", "\n", "To run a complete training job, first download the whole project and install its dependencies:" ] }, 
{ "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Looking in indexes: https://pypi.tuna.tsinghua.edu.cn/simple\n", "Requirement already satisfied: shapely in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from -r requirements.txt (line 1)) (1.8.0)\n", "Collecting scikit-image==0.17.2\n", " Downloading https://pypi.tuna.tsinghua.edu.cn/packages/d7/ee/753ea56fda5bc2a5516a1becb631bf5ada593a2dd44f21971a13a762d4db/scikit_image-0.17.2-cp37-cp37m-manylinux1_x86_64.whl (12.5 MB)\n", " |████████████████████████████████| 12.5 MB 8.4 MB/s \n", "\u001b[?25hRequirement already satisfied: imgaug==0.4.0 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from -r requirements.txt (line 3)) (0.4.0)\n", "Requirement already satisfied: pyclipper in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from -r requirements.txt (line 4)) (1.3.0.post2)\n", "Requirement already satisfied: lmdb in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from -r requirements.txt (line 5)) (1.2.1)\n", "Requirement already satisfied: tqdm in
/opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from -r requirements.txt (line 6)) (4.36.1)\n", "Requirement already satisfied: numpy in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from -r requirements.txt (line 7)) (1.20.3)\n", "Requirement already satisfied: visualdl in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from -r requirements.txt (line 8)) (2.2.0)\n", "Requirement already satisfied: python-Levenshtein in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from -r requirements.txt (line 9)) (0.12.2)\n", "Requirement already satisfied: opencv-contrib-python==4.4.0.46 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from -r requirements.txt (line 10)) (4.4.0.46)\n", "Requirement already satisfied: lxml in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from -r requirements.txt (line 11)) (4.7.1)\n", "Requirement already satisfied: premailer in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from -r requirements.txt (line 12)) (3.10.0)\n", "Requirement already satisfied: openpyxl in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from -r requirements.txt (line 13)) (3.0.5)\n", "Requirement already satisfied: imageio>=2.3.0 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from scikit-image==0.17.2->-r requirements.txt (line 2)) (2.6.1)\n", "Requirement already satisfied: matplotlib!=3.0.0,>=2.0.0 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from scikit-image==0.17.2->-r requirements.txt (line 2)) (2.2.3)\n", "Requirement already satisfied: tifffile>=2019.7.26 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from scikit-image==0.17.2->-r requirements.txt (line 2)) (2021.11.2)\n", "Requirement already satisfied: PyWavelets>=1.1.1 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from scikit-image==0.17.2->-r 
requirements.txt (line 2)) (1.2.0)\n", "Requirement already satisfied: pillow!=7.1.0,!=7.1.1,>=4.3.0 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from scikit-image==0.17.2->-r requirements.txt (line 2)) (7.1.2)\n", "Requirement already satisfied: networkx>=2.0 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from scikit-image==0.17.2->-r requirements.txt (line 2)) (2.4)\n", "Requirement already satisfied: scipy>=1.0.1 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from scikit-image==0.17.2->-r requirements.txt (line 2)) (1.6.3)\n", "Requirement already satisfied: six in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from imgaug==0.4.0->-r requirements.txt (line 3)) (1.15.0)\n", "Requirement already satisfied: opencv-python in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from imgaug==0.4.0->-r requirements.txt (line 3)) (4.1.1.26)\n", "Requirement already satisfied: flask>=1.1.1 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from visualdl->-r requirements.txt (line 8)) (1.1.1)\n", "Requirement already satisfied: requests in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from visualdl->-r requirements.txt (line 8)) (2.22.0)\n", "Requirement already satisfied: pre-commit in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from visualdl->-r requirements.txt (line 8)) (1.21.0)\n", "Requirement already satisfied: shellcheck-py in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from visualdl->-r requirements.txt (line 8)) (0.7.1.1)\n", "Requirement already satisfied: pandas in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from visualdl->-r requirements.txt (line 8)) (1.1.5)\n", "Requirement already satisfied: Flask-Babel>=1.0.0 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from visualdl->-r requirements.txt (line 8)) (1.0.0)\n", 
"Requirement already satisfied: bce-python-sdk in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from visualdl->-r requirements.txt (line 8)) (0.8.53)\n", "Requirement already satisfied: flake8>=3.7.9 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from visualdl->-r requirements.txt (line 8)) (3.8.2)\n", "Requirement already satisfied: protobuf>=3.11.0 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from visualdl->-r requirements.txt (line 8)) (3.14.0)\n", "Requirement already satisfied: setuptools in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from python-Levenshtein->-r requirements.txt (line 9)) (56.2.0)\n", "Requirement already satisfied: cssutils in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from premailer->-r requirements.txt (line 12)) (2.3.0)\n", "Requirement already satisfied: cachetools in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from premailer->-r requirements.txt (line 12)) (4.0.0)\n", "Requirement already satisfied: cssselect in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from premailer->-r requirements.txt (line 12)) (1.1.0)\n", "Requirement already satisfied: jdcal in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from openpyxl->-r requirements.txt (line 13)) (1.4.1)\n", "Requirement already satisfied: et-xmlfile in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from openpyxl->-r requirements.txt (line 13)) (1.0.1)\n", "Requirement already satisfied: pycodestyle<2.7.0,>=2.6.0a1 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from flake8>=3.7.9->visualdl->-r requirements.txt (line 8)) (2.6.0)\n", "Requirement already satisfied: pyflakes<2.3.0,>=2.2.0 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from flake8>=3.7.9->visualdl->-r requirements.txt (line 8)) (2.2.0)\n", "Requirement already satisfied: 
importlib-metadata in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from flake8>=3.7.9->visualdl->-r requirements.txt (line 8)) (0.23)\n", "Requirement already satisfied: mccabe<0.7.0,>=0.6.0 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from flake8>=3.7.9->visualdl->-r requirements.txt (line 8)) (0.6.1)\n", "Requirement already satisfied: click>=5.1 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from flask>=1.1.1->visualdl->-r requirements.txt (line 8)) (7.0)\n", "Requirement already satisfied: Jinja2>=2.10.1 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from flask>=1.1.1->visualdl->-r requirements.txt (line 8)) (2.11.0)\n", "Requirement already satisfied: Werkzeug>=0.15 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from flask>=1.1.1->visualdl->-r requirements.txt (line 8)) (0.16.0)\n", "Requirement already satisfied: itsdangerous>=0.24 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from flask>=1.1.1->visualdl->-r requirements.txt (line 8)) (1.1.0)\n", "Requirement already satisfied: pytz in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from Flask-Babel>=1.0.0->visualdl->-r requirements.txt (line 8)) (2019.3)\n", "Requirement already satisfied: Babel>=2.3 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from Flask-Babel>=1.0.0->visualdl->-r requirements.txt (line 8)) (2.8.0)\n", "Requirement already satisfied: cycler>=0.10 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from matplotlib!=3.0.0,>=2.0.0->scikit-image==0.17.2->-r requirements.txt (line 2)) (0.10.0)\n", "Requirement already satisfied: pyparsing!=2.0.4,!=2.1.2,!=2.1.6,>=2.0.1 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from matplotlib!=3.0.0,>=2.0.0->scikit-image==0.17.2->-r requirements.txt (line 2)) (2.4.2)\n", "Requirement already satisfied: python-dateutil>=2.1 in 
/opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from matplotlib!=3.0.0,>=2.0.0->scikit-image==0.17.2->-r requirements.txt (line 2)) (2.8.0)\n", "Requirement already satisfied: kiwisolver>=1.0.1 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from matplotlib!=3.0.0,>=2.0.0->scikit-image==0.17.2->-r requirements.txt (line 2)) (1.1.0)\n", "Requirement already satisfied: decorator>=4.3.0 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from networkx>=2.0->scikit-image==0.17.2->-r requirements.txt (line 2)) (4.4.2)\n", "Requirement already satisfied: pycryptodome>=3.8.0 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from bce-python-sdk->visualdl->-r requirements.txt (line 8)) (3.9.9)\n", "Requirement already satisfied: future>=0.6.0 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from bce-python-sdk->visualdl->-r requirements.txt (line 8)) (0.18.0)\n", "Requirement already satisfied: cfgv>=2.0.0 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from pre-commit->visualdl->-r requirements.txt (line 8)) (2.0.1)\n", "Requirement already satisfied: virtualenv>=15.2 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from pre-commit->visualdl->-r requirements.txt (line 8)) (16.7.9)\n", "Requirement already satisfied: nodeenv>=0.11.1 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from pre-commit->visualdl->-r requirements.txt (line 8)) (1.3.4)\n", "Requirement already satisfied: toml in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from pre-commit->visualdl->-r requirements.txt (line 8)) (0.10.0)\n", "Requirement already satisfied: pyyaml in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from pre-commit->visualdl->-r requirements.txt (line 8)) (5.1.2)\n", "Requirement already satisfied: aspy.yaml in 
/opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from pre-commit->visualdl->-r requirements.txt (line 8)) (1.3.0)\n", "Requirement already satisfied: identify>=1.0.0 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from pre-commit->visualdl->-r requirements.txt (line 8)) (1.4.10)\n", "Requirement already satisfied: idna<2.9,>=2.5 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from requests->visualdl->-r requirements.txt (line 8)) (2.8)\n", "Requirement already satisfied: chardet<3.1.0,>=3.0.2 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from requests->visualdl->-r requirements.txt (line 8)) (3.0.4)\n", "Requirement already satisfied: certifi>=2017.4.17 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from requests->visualdl->-r requirements.txt (line 8)) (2019.9.11)\n", "Requirement already satisfied: urllib3!=1.25.0,!=1.25.1,<1.26,>=1.21.1 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from requests->visualdl->-r requirements.txt (line 8)) (1.25.6)\n", "Requirement already satisfied: MarkupSafe>=0.23 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from Jinja2>=2.10.1->flask>=1.1.1->visualdl->-r requirements.txt (line 8)) (1.1.1)\n", "Requirement already satisfied: zipp>=0.5 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from importlib-metadata->flake8>=3.7.9->visualdl->-r requirements.txt (line 8)) (3.6.0)\n", "Installing collected packages: scikit-image\n", " Attempting uninstall: scikit-image\n", " Found existing installation: scikit-image 0.19.1\n", " Uninstalling scikit-image-0.19.1:\n", " Successfully uninstalled scikit-image-0.19.1\n", "Successfully installed scikit-image-0.17.2\n" ] } ], "source": [ "# Clone the PaddleOCR code (uncomment if the repository is not present yet)\n", "#!git clone https://gitee.com/paddlepaddle/PaddleOCR\n", "# Change the working directory to 
/home/aistudio/PaddleOCR\n", "import os\n", 
"os.chdir(\"/home/aistudio/PaddleOCR\")\n", "# Install PaddleOCR's third-party dependencies\n", "!pip install -r requirements.txt" ] }, { "cell_type": "markdown", "metadata": { "collapsed": false }, "source": [ "Create a symbolic link so that the training data is available under the PaddleOCR project directory:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "!ln -s /home/aistudio/work/train_data/ /home/aistudio/PaddleOCR/" ] }, { "cell_type": "markdown", "metadata": { "collapsed": false }, "source": [ "Download the pretrained model:\n", "\n", "To speed up convergence, we recommend fine-tuning an already trained model on the icdar2015 data." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "--2021-12-22 15:39:39-- https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/rec_mv3_none_bilstm_ctc_v2.0_train.tar\n", "Resolving paddleocr.bj.bcebos.com (paddleocr.bj.bcebos.com)... 182.61.200.195, 182.61.200.229, 2409:8c04:1001:1002:0:ff:b001:368a\n", "Connecting to paddleocr.bj.bcebos.com (paddleocr.bj.bcebos.com)|182.61.200.195|:443... connected.\n", "HTTP request sent, awaiting response... 
200 OK\n", "Length: 51200000 (49M) [application/x-tar]\n", "Saving to: ‘./pretrain_models/rec_mv3_none_bilstm_ctc_v2.0_train.tar’\n", "\n", "rec_mv3_none_bilstm 100%[===================>] 48.83M 15.5MB/s in 3.6s \n", "\n", "2021-12-22 15:39:42 (13.7 MB/s) - ‘./pretrain_models/rec_mv3_none_bilstm_ctc_v2.0_train.tar’ saved [51200000/51200000]\n", "\n" ] } ], "source": [ "# Download the MobileNetV3 pretrained model\n", "!wget -nc -P ./pretrain_models/ https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/rec_mv3_none_bilstm_ctc_v2.0_train.tar\n", "# Extract the model parameters into the current directory and delete the archive\n", "!tar -xf pretrain_models/rec_mv3_none_bilstm_ctc_v2.0_train.tar && rm -rf pretrain_models/rec_mv3_none_bilstm_ctc_v2.0_train.tar" ] }, { "cell_type": "markdown", "metadata": { "collapsed": false }, "source": [ "Starting training is simple: just specify the configuration file. You can also override parameter values from the configuration file on the command line with `-o`. The training command is shown below.\n", "\n", "The parameters overridden here are:\n", "\n", "* `Global.pretrained_model`: path of the pretrained model to load\n", "* `Global.character_dict_path`: path of the character dictionary (this one only covers the 26 lowercase letters and the digits)\n", "* `Global.eval_batch_step`: evaluation frequency, given as `[start_iter, interval]`; `[0,200]` runs an evaluation every 200 iterations after iteration 0\n", "* `Global.epoch_num`: total number of training epochs\n", "\n" ] }, { "cell_type": "code", "execution_count": 4, "metadata": { "collapsed": false }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "/opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages/skimage/morphology/_skeletonize.py:241: DeprecationWarning: `np.bool` is a deprecated alias for the builtin `bool`. To silence this warning, use `bool` by itself. Doing this will not modify any behavior and is safe. 
If you specifically wanted the numpy scalar type, use `np.bool_` here.\n", "Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations\n", " 0, 1, 1, 0, 0, 1, 0, 0, 0], dtype=np.bool)\n", "/opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages/skimage/morphology/_skeletonize.py:256: DeprecationWarning: `np.bool` is a deprecated alias for the builtin `bool`. To silence this warning, use `bool` by itself. 
Doing this will not modify any behavior and is safe. If you specifically wanted the numpy scalar type, use `np.bool_` here.\n", "Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations\n", " 0, 0, 0, 0, 0, 0, 0, 0, 0], dtype=np.bool)\n", "[2021/12/23 20:28:15] root INFO: Architecture : \n", "[2021/12/23 20:28:15] root INFO: Backbone : \n", "[2021/12/23 20:28:15] root INFO: model_name : large\n", "[2021/12/23 20:28:15] root INFO: name : MobileNetV3\n", "[2021/12/23 20:28:15] root INFO: scale : 0.5\n", "[2021/12/23 20:28:15] root INFO: Head : \n", "[2021/12/23 20:28:15] root INFO: fc_decay : 0\n", "[2021/12/23 20:28:15] root INFO: name : CTCHead\n", "[2021/12/23 20:28:15] root INFO: Neck : \n", "[2021/12/23 20:28:15] root INFO: encoder_type : rnn\n", "[2021/12/23 20:28:15] root INFO: hidden_size : 96\n", "[2021/12/23 20:28:15] root INFO: name : SequenceEncoder\n", "[2021/12/23 20:28:15] root INFO: Transform : None\n", "[2021/12/23 20:28:15] root INFO: algorithm : CRNN\n", "[2021/12/23 20:28:15] root INFO: model_type : rec\n", "[2021/12/23 20:28:15] root INFO: Eval : \n", "[2021/12/23 20:28:15] root INFO: dataset : \n", "[2021/12/23 20:28:15] root INFO: data_dir : ./train_data/ic15_data\n", "[2021/12/23 20:28:15] root INFO: label_file_list : ['./train_data/ic15_data/rec_gt_test.txt']\n", "[2021/12/23 20:28:15] root INFO: name : SimpleDataSet\n", "[2021/12/23 20:28:15] root INFO: transforms : \n", "[2021/12/23 20:28:15] root INFO: DecodeImage : \n", "[2021/12/23 20:28:15] root INFO: channel_first : False\n", "[2021/12/23 20:28:15] root INFO: img_mode : BGR\n", "[2021/12/23 20:28:15] root INFO: CTCLabelEncode : None\n", "[2021/12/23 20:28:15] root INFO: RecResizeImg : \n", "[2021/12/23 20:28:15] root INFO: image_shape : [3, 32, 100]\n", "[2021/12/23 20:28:15] root INFO: KeepKeys : \n", "[2021/12/23 20:28:15] root INFO: keep_keys : ['image', 'label', 'length']\n", "[2021/12/23 20:28:15] root INFO: loader : 
\n", "[2021/12/23 20:28:15] root INFO: batch_size_per_card : 256\n", "[2021/12/23 20:28:15] root INFO: drop_last : False\n", "[2021/12/23 20:28:15] root INFO: num_workers : 4\n", "[2021/12/23 20:28:15] root INFO: shuffle : False\n", "[2021/12/23 20:28:15] root INFO: use_shared_memory : False\n", "[2021/12/23 20:28:15] root INFO: Global : \n", "[2021/12/23 20:28:15] root INFO: cal_metric_during_train : True\n", "[2021/12/23 20:28:15] root INFO: character_dict_path : ppocr/utils/ic15_dict.txt\n", "[2021/12/23 20:28:15] root INFO: character_type : EN\n", "[2021/12/23 20:28:15] root INFO: checkpoints : None\n", "[2021/12/23 20:28:15] root INFO: debug : False\n", "[2021/12/23 20:28:15] root INFO: distributed : False\n", "[2021/12/23 20:28:15] root INFO: epoch_num : 40\n", "[2021/12/23 20:28:15] root INFO: eval_batch_step : [0, 200]\n", "[2021/12/23 20:28:15] root INFO: infer_img : doc/imgs_words_en/word_19.png\n", "[2021/12/23 20:28:15] root INFO: infer_mode : False\n", "[2021/12/23 20:28:15] root INFO: log_smooth_window : 20\n", "[2021/12/23 20:28:15] root INFO: max_text_length : 25\n", "[2021/12/23 20:28:15] root INFO: pretrained_model : rec_mv3_none_bilstm_ctc_v2.0_train/best_accuracy\n", "[2021/12/23 20:28:15] root INFO: print_batch_step : 10\n", "[2021/12/23 20:28:15] root INFO: save_epoch_step : 3\n", "[2021/12/23 20:28:15] root INFO: save_inference_dir : ./\n", "[2021/12/23 20:28:15] root INFO: save_model_dir : ./output/rec/ic15/\n", "[2021/12/23 20:28:15] root INFO: save_res_path : ./output/rec/predicts_ic15.txt\n", "[2021/12/23 20:28:15] root INFO: use_gpu : True\n", "[2021/12/23 20:28:15] root INFO: use_space_char : False\n", "[2021/12/23 20:28:15] root INFO: use_visualdl : False\n", "[2021/12/23 20:28:15] root INFO: Loss : \n", "[2021/12/23 20:28:15] root INFO: name : CTCLoss\n", "[2021/12/23 20:28:15] root INFO: Metric : \n", "[2021/12/23 20:28:15] root INFO: main_indicator : acc\n", "[2021/12/23 20:28:15] root INFO: name : RecMetric\n", "[2021/12/23 
20:28:15] root INFO: Optimizer : \n", "[2021/12/23 20:28:15] root INFO: beta1 : 0.9\n", "[2021/12/23 20:28:15] root INFO: beta2 : 0.999\n", "[2021/12/23 20:28:15] root INFO: lr : \n", "[2021/12/23 20:28:15] root INFO: learning_rate : 0.0005\n", "[2021/12/23 20:28:15] root INFO: name : Adam\n", "[2021/12/23 20:28:15] root INFO: regularizer : \n", "[2021/12/23 20:28:15] root INFO: factor : 0\n", "[2021/12/23 20:28:15] root INFO: name : L2\n", "[2021/12/23 20:28:15] root INFO: PostProcess : \n", "[2021/12/23 20:28:15] root INFO: name : CTCLabelDecode\n", "[2021/12/23 20:28:15] root INFO: Train : \n", "[2021/12/23 20:28:15] root INFO: dataset : \n", "[2021/12/23 20:28:15] root INFO: data_dir : ./train_data/ic15_data/\n", "[2021/12/23 20:28:15] root INFO: label_file_list : ['./train_data/ic15_data/rec_gt_train.txt']\n", "[2021/12/23 20:28:15] root INFO: name : SimpleDataSet\n", "[2021/12/23 20:28:15] root INFO: transforms : \n", "[2021/12/23 20:28:15] root INFO: DecodeImage : \n", "[2021/12/23 20:28:15] root INFO: channel_first : False\n", "[2021/12/23 20:28:15] root INFO: img_mode : BGR\n", "[2021/12/23 20:28:15] root INFO: CTCLabelEncode : None\n", "[2021/12/23 20:28:15] root INFO: RecResizeImg : \n", "[2021/12/23 20:28:15] root INFO: image_shape : [3, 32, 100]\n", "[2021/12/23 20:28:15] root INFO: KeepKeys : \n", "[2021/12/23 20:28:15] root INFO: keep_keys : ['image', 'label', 'length']\n", "[2021/12/23 20:28:15] root INFO: loader : \n", "[2021/12/23 20:28:15] root INFO: batch_size_per_card : 256\n", "[2021/12/23 20:28:15] root INFO: drop_last : True\n", "[2021/12/23 20:28:15] root INFO: num_workers : 8\n", "[2021/12/23 20:28:15] root INFO: shuffle : True\n", "[2021/12/23 20:28:15] root INFO: use_shared_memory : False\n", "[2021/12/23 20:28:15] root INFO: train with paddle 2.1.2 and device CUDAPlace(0)\n", "[2021/12/23 20:28:15] root INFO: Initialize indexs of datasets:['./train_data/ic15_data/rec_gt_train.txt']\n", "[2021/12/23 20:28:15] root INFO: Initialize indexs 
of datasets:['./train_data/ic15_data/rec_gt_test.txt']\n", "W1223 20:28:15.851713 306 device_context.cc:404] Please NOTE: device: 0, GPU Compute Capability: 7.0, Driver API Version: 11.0, Runtime API Version: 10.1\n", "W1223 20:28:15.857080 306 device_context.cc:422] device: 0, cuDNN Version: 7.6.\n", "[2021/12/23 20:28:19] root INFO: loaded pretrained_model successful from rec_mv3_none_bilstm_ctc_v2.0_train/best_accuracy.pdparams\n", "[2021/12/23 20:28:19] root INFO: train dataloader has 17 iters\n", "[2021/12/23 20:28:19] root INFO: valid dataloader has 9 iters\n", "[2021/12/23 20:28:19] root INFO: During the training process, after the 0th iteration, an evaluation is run every 200 iterations\n", "[2021/12/23 20:28:19] root INFO: Initialize indexs of datasets:['./train_data/ic15_data/rec_gt_train.txt']\n", "[2021/12/23 20:28:23] root INFO: epoch: [1/40], iter: 10, lr: 0.000500, loss: 9.336592, acc: 0.203125, norm_edit_dis: 0.674909, reader_cost: 0.27284 s, batch_cost: 0.40185 s, samples: 2816, ips: 700.75290\n", "[2021/12/23 20:28:24] root INFO: epoch: [1/40], iter: 16, lr: 0.000500, loss: 6.955496, acc: 0.210938, norm_edit_dis: 0.678930, reader_cost: 0.00008 s, batch_cost: 0.05430 s, samples: 1536, ips: 2828.80514\n", "[2021/12/23 20:28:24] root INFO: save model in ./output/rec/ic15/latest\n", "[2021/12/23 20:28:24] root INFO: Initialize indexs of datasets:['./train_data/ic15_data/rec_gt_train.txt']\n", "[2021/12/23 20:28:28] root INFO: epoch: [2/40], iter: 20, lr: 0.000500, loss: 6.402417, acc: 0.246094, norm_edit_dis: 0.695874, reader_cost: 0.24180 s, batch_cost: 0.34361 s, samples: 1024, ips: 298.00945\n", "[2021/12/23 20:28:29] root INFO: epoch: [2/40], iter: 30, lr: 0.000500, loss: 4.007382, acc: 0.412109, norm_edit_dis: 0.743064, reader_cost: 0.00013 s, batch_cost: 0.08982 s, samples: 2560, ips: 2849.98954\n", "[2021/12/23 20:28:29] root INFO: epoch: [2/40], iter: 33, lr: 0.000500, loss: 3.906031, acc: 0.458984, norm_edit_dis: 0.770415, reader_cost: 
0.00004 s, batch_cost: 0.02684 s, samples: 768, ips: 2861.80304\n", "^C\n", "main proc 306 exit, kill process group 306\n" ] } ], "source": [ "!python3 tools/train.py -c configs/rec/rec_icdar15_train.yml \\\n", " -o Global.pretrained_model=rec_mv3_none_bilstm_ctc_v2.0_train/best_accuracy \\\n", " Global.character_dict_path=ppocr/utils/ic15_dict.txt \\\n", " Global.eval_batch_step=[0,200] \\\n", " Global.epoch_num=40" ] }, { "cell_type": "markdown", "metadata": { "collapsed": false }, "source": [ "Training saves the following files to the directory given by the `save_model_dir` field of the configuration file:\n", "\n", "```\n", "output/rec/ic15\n", "├── best_accuracy.pdopt \n", "├── best_accuracy.pdparams \n", "├── best_accuracy.states \n", "├── config.yml \n", "├── iter_epoch_3.pdopt \n", "├── iter_epoch_3.pdparams \n", "├── iter_epoch_3.states \n", "├── latest.pdopt \n", "├── latest.pdparams \n", "├── latest.states \n", "└── train.log\n", "```\n", "Here, best_accuracy.* is the best model on the evaluation set; iter_epoch_x.* are models saved every `save_epoch_step` epochs; latest.* is the model from the most recent epoch.\n" ] }, { "cell_type": "markdown", "metadata": { "collapsed": false }, "source": [ "**Summary:**\n", "\n", "To train on your own data, modify the following:\n", "\n", "1. Training and evaluation data paths (required)\n", "2. Dictionary path (required)\n", "3. Pretrained model (optional)\n", "4. 
Learning rate, image shape, and network structure (optional)" ] }, { "cell_type": "markdown", "metadata": { "collapsed": false }, "source": [ "### 4.2 Model Evaluation\n", "\n", "The evaluation dataset can be changed through `label_file_list` under `Eval.dataset` in `configs/rec/rec_icdar15_train.yml`.\n", "\n", "Here the icdar2015 evaluation set is used by default, and the model weights just trained are loaded:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "[2021/12/23 14:27:51] root INFO: Architecture : \n", "[2021/12/23 14:27:51] root INFO: Backbone : \n", "[2021/12/23 14:27:51] root INFO: model_name : large\n", "[2021/12/23 14:27:51] root INFO: name : MobileNetV3\n", "[2021/12/23 14:27:51] root INFO: scale : 0.5\n", "[2021/12/23 14:27:51] root INFO: Head : \n", "[2021/12/23 14:27:51] root INFO: fc_decay : 0\n", "[2021/12/23 14:27:51] root INFO: name : CTCHead\n", "[2021/12/23 14:27:51] root INFO: Neck : \n", "[2021/12/23 14:27:51] root INFO: encoder_type : rnn\n", "[2021/12/23 14:27:51] root INFO: hidden_size : 96\n", "[2021/12/23 14:27:51] root INFO: name : SequenceEncoder\n", "[2021/12/23 14:27:51] root INFO: Transform : None\n", "[2021/12/23 14:27:51] root INFO: algorithm : CRNN\n", "[2021/12/23 14:27:51] root INFO: model_type : rec\n", "[2021/12/23 14:27:51] root INFO: Eval : \n", "[2021/12/23 14:27:51] root INFO: dataset : \n", "[2021/12/23 14:27:51] root INFO: data_dir : ./train_data/ic15_data\n", "[2021/12/23 14:27:51] root INFO: label_file_list : ['./train_data/ic15_data/rec_gt_test.txt']\n", "[2021/12/23 14:27:51] root INFO: name : SimpleDataSet\n", "[2021/12/23 14:27:51] root INFO: transforms : \n", "[2021/12/23 14:27:51] root INFO: DecodeImage : \n", "[2021/12/23 14:27:51] root INFO: channel_first : False\n", "[2021/12/23 14:27:51] root INFO: img_mode : BGR\n", "[2021/12/23 14:27:51] root INFO: CTCLabelEncode : None\n", "[2021/12/23 14:27:51] root INFO: 
RecResizeImg : \n", "[2021/12/23 14:27:51] root INFO: image_shape : [3, 32, 100]\n", "[2021/12/23 14:27:51] root INFO: KeepKeys : \n", 
"[2021/12/23 14:27:51] root INFO: keep_keys : ['image', 'label', 'length']\n", "[2021/12/23 14:27:51] root INFO: loader : \n", "[2021/12/23 14:27:51] root INFO: batch_size_per_card : 256\n", "[2021/12/23 14:27:51] root INFO: drop_last : False\n", "[2021/12/23 14:27:51] root INFO: num_workers : 4\n", "[2021/12/23 14:27:51] root INFO: shuffle : False\n", "[2021/12/23 14:27:51] root INFO: use_shared_memory : False\n", "[2021/12/23 14:27:51] root INFO: Global : \n", "[2021/12/23 14:27:51] root INFO: cal_metric_during_train : True\n", "[2021/12/23 14:27:51] root INFO: character_dict_path : ppocr/utils/ic15_dict.txt\n", "[2021/12/23 14:27:51] root INFO: character_type : EN\n", "[2021/12/23 14:27:51] root INFO: checkpoints : output/rec/ic15/best_accuracy\n", "[2021/12/23 14:27:51] root INFO: debug : False\n", "[2021/12/23 14:27:51] root INFO: distributed : False\n", "[2021/12/23 14:27:51] root INFO: epoch_num : 72\n", "[2021/12/23 14:27:51] root INFO: eval_batch_step : [0, 2000]\n", "[2021/12/23 14:27:51] root INFO: infer_img : doc/imgs_words_en/word_10.png\n", "[2021/12/23 14:27:51] root INFO: infer_mode : False\n", "[2021/12/23 14:27:51] root INFO: log_smooth_window : 20\n", "[2021/12/23 14:27:51] root INFO: max_text_length : 25\n", "[2021/12/23 14:27:51] root INFO: pretrained_model : None\n", "[2021/12/23 14:27:51] root INFO: print_batch_step : 10\n", "[2021/12/23 14:27:51] root INFO: save_epoch_step : 3\n", "[2021/12/23 14:27:51] root INFO: save_inference_dir : ./\n", "[2021/12/23 14:27:51] root INFO: save_model_dir : ./output/rec/ic15/\n", "[2021/12/23 14:27:51] root INFO: save_res_path : ./output/rec/predicts_ic15.txt\n", "[2021/12/23 14:27:51] root INFO: use_gpu : True\n", "[2021/12/23 14:27:51] root INFO: use_space_char : False\n", "[2021/12/23 14:27:51] root INFO: use_visualdl : False\n", "[2021/12/23 14:27:51] root INFO: Loss : \n", "[2021/12/23 14:27:51] root INFO: name : CTCLoss\n", "[2021/12/23 14:27:51] root INFO: Metric : \n", "[2021/12/23 14:27:51] root 
INFO: main_indicator : acc\n", "[2021/12/23 14:27:51] root INFO: name : RecMetric\n", "[2021/12/23 14:27:51] root INFO: Optimizer : \n", "[2021/12/23 14:27:51] root INFO: beta1 : 0.9\n", "[2021/12/23 14:27:51] root INFO: beta2 : 0.999\n", "[2021/12/23 14:27:51] root INFO: lr : \n", "[2021/12/23 14:27:51] root INFO: learning_rate : 0.0005\n", "[2021/12/23 14:27:51] root INFO: name : Adam\n", "[2021/12/23 14:27:51] root INFO: regularizer : \n", "[2021/12/23 14:27:51] root INFO: factor : 0\n", "[2021/12/23 14:27:51] root INFO: name : L2\n", "[2021/12/23 14:27:51] root INFO: PostProcess : \n", "[2021/12/23 14:27:51] root INFO: name : CTCLabelDecode\n", "[2021/12/23 14:27:51] root INFO: Train : \n", "[2021/12/23 14:27:51] root INFO: dataset : \n", "[2021/12/23 14:27:51] root INFO: data_dir : ./train_data/ic15_data/\n", "[2021/12/23 14:27:51] root INFO: label_file_list : ['./train_data/ic15_data/rec_gt_train.txt']\n", "[2021/12/23 14:27:51] root INFO: name : SimpleDataSet\n", "[2021/12/23 14:27:51] root INFO: transforms : \n", "[2021/12/23 14:27:51] root INFO: DecodeImage : \n", "[2021/12/23 14:27:51] root INFO: channel_first : False\n", "[2021/12/23 14:27:51] root INFO: img_mode : BGR\n", "[2021/12/23 14:27:51] root INFO: CTCLabelEncode : None\n", "[2021/12/23 14:27:51] root INFO: RecResizeImg : \n", "[2021/12/23 14:27:51] root INFO: image_shape : [3, 32, 100]\n", "[2021/12/23 14:27:51] root INFO: KeepKeys : \n", "[2021/12/23 14:27:51] root INFO: keep_keys : ['image', 'label', 'length']\n", "[2021/12/23 14:27:51] root INFO: loader : \n", "[2021/12/23 14:27:51] root INFO: batch_size_per_card : 256\n", "[2021/12/23 14:27:51] root INFO: drop_last : True\n", "[2021/12/23 14:27:51] root INFO: num_workers : 8\n", "[2021/12/23 14:27:51] root INFO: shuffle : True\n", "[2021/12/23 14:27:51] root INFO: use_shared_memory : False\n", "[2021/12/23 14:27:51] root INFO: train with paddle 2.1.2 and device CUDAPlace(0)\n", "[2021/12/23 14:27:51] root INFO: Initialize indexs of 
datasets:['./train_data/ic15_data/rec_gt_test.txt']\n", "W1223 14:27:51.861889 5192 device_context.cc:404] Please NOTE: device: 0, GPU Compute Capability: 7.0, Driver API Version: 10.1, Runtime API Version: 10.1\n", "W1223 14:27:51.865501 5192 device_context.cc:422] device: 0, cuDNN Version: 7.6.\n", "[2021/12/23 14:27:56] root INFO: resume from output/rec/ic15/best_accuracy\n", "[2021/12/23 14:27:56] root INFO: metric in ckpt ***************\n", "[2021/12/23 14:27:56] root INFO: acc:0.48531535869041886\n", "[2021/12/23 14:27:56] root INFO: norm_edit_dis:0.7895228681338454\n", "[2021/12/23 14:27:56] root INFO: fps:3266.1877400927865\n", "[2021/12/23 14:27:56] root INFO: best_epoch:24\n", "[2021/12/23 14:27:56] root INFO: start_epoch:25\n", "eval model:: 100%|████████████████████████████████| 9/9 [00:02<00:00, 3.32it/s]\n", "[2021/12/23 14:27:59] root INFO: metric eval ***************\n", "[2021/12/23 14:27:59] root INFO: acc:0.48531535869041886\n", "[2021/12/23 14:27:59] root INFO: norm_edit_dis:0.7895228681338454\n", "[2021/12/23 14:27:59] root INFO: fps:4491.015930181665\n" ] } ], "source": [ "!python tools/eval.py -c configs/rec/rec_icdar15_train.yml -o Global.checkpoints=output/rec/ic15/best_accuracy \\\n", " Global.character_dict_path=ppocr/utils/ic15_dict.txt\n" ] }, { "cell_type": "markdown", "metadata": { "collapsed": false }, "source": [ "After evaluation, the log reports the trained model's accuracy on the validation set (`acc` is about 0.485 in the run above).\n", "\n", "PaddleOCR supports alternating between training and evaluation. You can set the evaluation frequency by changing `eval_batch_step` in `configs/rec/rec_icdar15_train.yml`; by default, an evaluation runs every 2000 iterations. During evaluation, the model with the best acc is saved as `output/rec/ic15/best_accuracy` by default.\n", "\n", "If the validation set is large, evaluation will take a long time, so consider evaluating less often or only after training finishes." 
] }, { "cell_type": "markdown", "metadata": { "collapsed": false }, "source": [ "### 4.3 Prediction\n", "\n", "A model trained with PaddleOCR can be used for quick prediction with the following script.\n", "\n", "Image to predict:\n", "![](https://raw.githubusercontent.com/PaddlePaddle/PaddleOCR/release/2.3/doc/imgs_words_en/word_19.png)\n", "\n", "By default, the image to predict is given by `Global.infer_img`; load the trained parameter files with `-o Global.checkpoints`:" ] }, { "cell_type": "code", "execution_count": 
null, "metadata": { "collapsed": false }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "[2021/12/23 14:29:19] root INFO: Architecture : \n", "[2021/12/23 14:29:19] root INFO: Backbone : \n", "[2021/12/23 14:29:19] root INFO: model_name : large\n", "[2021/12/23 14:29:19] root INFO: name : MobileNetV3\n", "[2021/12/23 14:29:19] root INFO: scale : 0.5\n", "[2021/12/23 14:29:19] root INFO: Head : \n", "[2021/12/23 14:29:19] root INFO: fc_decay : 0\n", "[2021/12/23 14:29:19] root INFO: name : CTCHead\n", "[2021/12/23 14:29:19] root INFO: Neck : \n", "[2021/12/23 14:29:19] root INFO: encoder_type : rnn\n", "[2021/12/23 14:29:19] root INFO: hidden_size : 96\n", "[2021/12/23 14:29:19] root INFO: name : SequenceEncoder\n", "[2021/12/23 14:29:19] root INFO: Transform : None\n", "[2021/12/23 14:29:19] root INFO: algorithm : CRNN\n", "[2021/12/23 14:29:19] root INFO: model_type : rec\n", "[2021/12/23 14:29:19] root INFO: Eval : \n", "[2021/12/23 14:29:19] root INFO: dataset : \n", "[2021/12/23 14:29:19] root INFO: data_dir : ./train_data/ic15_data\n", "[2021/12/23 14:29:19] root INFO: label_file_list : ['./train_data/ic15_data/rec_gt_test.txt']\n", "[2021/12/23 14:29:19] root INFO: name : SimpleDataSet\n", "[2021/12/23 14:29:19] root INFO: transforms : \n", "[2021/12/23 14:29:19] root INFO: DecodeImage : \n", "[2021/12/23 14:29:19] root INFO: channel_first : False\n", "[2021/12/23 14:29:19] root INFO: img_mode : BGR\n", "[2021/12/23 14:29:19] root INFO: CTCLabelEncode : None\n", "[2021/12/23 14:29:19] root INFO: RecResizeImg : \n", "[2021/12/23 14:29:19] root INFO: image_shape : [3, 32, 100]\n", "[2021/12/23 14:29:19] root INFO: KeepKeys : \n", "[2021/12/23 14:29:19] root INFO: keep_keys : ['image', 'label', 'length']\n", "[2021/12/23 14:29:19] root INFO: loader : \n", "[2021/12/23 14:29:19] root INFO: batch_size_per_card : 256\n", "[2021/12/23 14:29:19] root INFO: drop_last : False\n", "[2021/12/23 14:29:19] root INFO: num_workers : 4\n", "[2021/12/23 
14:29:19] root INFO: shuffle : False\n", "[2021/12/23 14:29:19] root INFO: use_shared_memory : False\n", "[2021/12/23 14:29:19] root INFO: Global : \n", "[2021/12/23 14:29:19] root INFO: cal_metric_during_train : True\n", "[2021/12/23 14:29:19] root INFO: character_dict_path : ppocr/utils/ic15_dict.txt\n", "[2021/12/23 14:29:19] root INFO: character_type : EN\n", "[2021/12/23 14:29:19] root INFO: checkpoints : output/rec/ic15/best_accuracy\n", "[2021/12/23 14:29:19] root INFO: debug : False\n", "[2021/12/23 14:29:19] root INFO: distributed : False\n", "[2021/12/23 14:29:19] root INFO: epoch_num : 72\n", "[2021/12/23 14:29:19] root INFO: eval_batch_step : [0, 2000]\n", "[2021/12/23 14:29:19] root INFO: infer_img : doc/imgs_words_en/word_19.png\n", "[2021/12/23 14:29:19] root INFO: infer_mode : False\n", "[2021/12/23 14:29:19] root INFO: log_smooth_window : 20\n", "[2021/12/23 14:29:19] root INFO: max_text_length : 25\n", "[2021/12/23 14:29:19] root INFO: pretrained_model : None\n", "[2021/12/23 14:29:19] root INFO: print_batch_step : 10\n", "[2021/12/23 14:29:19] root INFO: save_epoch_step : 3\n", "[2021/12/23 14:29:19] root INFO: save_inference_dir : ./\n", "[2021/12/23 14:29:19] root INFO: save_model_dir : ./output/rec/ic15/\n", "[2021/12/23 14:29:19] root INFO: save_res_path : ./output/rec/predicts_ic15.txt\n", "[2021/12/23 14:29:19] root INFO: use_gpu : True\n", "[2021/12/23 14:29:19] root INFO: use_space_char : False\n", "[2021/12/23 14:29:19] root INFO: use_visualdl : False\n", "[2021/12/23 14:29:19] root INFO: Loss : \n", "[2021/12/23 14:29:19] root INFO: name : CTCLoss\n", "[2021/12/23 14:29:19] root INFO: Metric : \n", "[2021/12/23 14:29:19] root INFO: main_indicator : acc\n", "[2021/12/23 14:29:19] root INFO: name : RecMetric\n", "[2021/12/23 14:29:19] root INFO: Optimizer : \n", "[2021/12/23 14:29:19] root INFO: beta1 : 0.9\n", "[2021/12/23 14:29:19] root INFO: beta2 : 0.999\n", "[2021/12/23 14:29:19] root INFO: lr : \n", "[2021/12/23 14:29:19] root INFO: 
learning_rate : 0.0005\n", "[2021/12/23 14:29:19] root INFO: name : Adam\n", "[2021/12/23 14:29:19] root INFO: regularizer : \n", "[2021/12/23 14:29:19] root INFO: factor : 0\n", "[2021/12/23 14:29:19] root INFO: name : L2\n", "[2021/12/23 14:29:19] root INFO: PostProcess : \n", "[2021/12/23 14:29:19] root INFO: name : CTCLabelDecode\n", "[2021/12/23 14:29:19] root INFO: Train : \n", "[2021/12/23 14:29:19] root INFO: dataset : \n", "[2021/12/23 14:29:19] root INFO: data_dir : ./train_data/ic15_data/\n", "[2021/12/23 14:29:19] root INFO: label_file_list : ['./train_data/ic15_data/rec_gt_train.txt']\n", "[2021/12/23 14:29:19] root INFO: name : SimpleDataSet\n", "[2021/12/23 14:29:19] root INFO: transforms : \n", "[2021/12/23 14:29:19] root INFO: DecodeImage : \n", "[2021/12/23 14:29:19] root INFO: channel_first : False\n", "[2021/12/23 14:29:19] root INFO: img_mode : BGR\n", "[2021/12/23 14:29:19] root INFO: CTCLabelEncode : None\n", "[2021/12/23 14:29:19] root INFO: RecResizeImg : \n", "[2021/12/23 14:29:19] root INFO: image_shape : [3, 32, 100]\n", "[2021/12/23 14:29:19] root INFO: KeepKeys : \n", "[2021/12/23 14:29:19] root INFO: keep_keys : ['image', 'label', 'length']\n", "[2021/12/23 14:29:19] root INFO: loader : \n", "[2021/12/23 14:29:19] root INFO: batch_size_per_card : 256\n", "[2021/12/23 14:29:19] root INFO: drop_last : True\n", "[2021/12/23 14:29:19] root INFO: num_workers : 8\n", "[2021/12/23 14:29:19] root INFO: shuffle : True\n", "[2021/12/23 14:29:19] root INFO: use_shared_memory : False\n", "[2021/12/23 14:29:19] root INFO: train with paddle 2.1.2 and device CUDAPlace(0)\n", "W1223 14:29:19.803710 5290 device_context.cc:404] Please NOTE: device: 0, GPU Compute Capability: 7.0, Driver API Version: 10.1, Runtime API Version: 10.1\n", "W1223 14:29:19.807695 5290 device_context.cc:422] device: 0, cuDNN Version: 7.6.\n", "[2021/12/23 14:29:25] root INFO: resume from output/rec/ic15/best_accuracy\n", "[2021/12/23 14:29:25] root INFO: infer_img: 
doc/imgs_words_en/word_19.png\n", "pred idx: Tensor(shape=[1, 25], dtype=int64, place=CUDAPlace(0), stop_gradient=True,\n", " [[29, 0 , 0 , 0 , 22, 0 , 0 , 0 , 25, 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 33]])\n", "[2021/12/23 14:29:25] root INFO: \t result: slow\t0.8795223\n", "[2021/12/23 14:29:25] root INFO: success!\n" ] } ], "source": [ "!python tools/infer_rec.py -c configs/rec/rec_icdar15_train.yml -o Global.checkpoints=output/rec/ic15/best_accuracy Global.character_dict_path=ppocr/utils/ic15_dict.txt" ] }, { "cell_type": "markdown", "metadata": { "collapsed": false }, "source": [ "The prediction result for the input image is:\n", "\n", "```\n", "infer_img: doc/imgs_words_en/word_19.png\n", " result: slow\t0.8795223\n", "```\n", "\n", "Here `slow` is the recognized text and `0.8795223` is its confidence score. The `pred idx` line in the log holds the most likely character index at each of the 25 time steps. CTC decoding (`CTCLabelDecode`) merges consecutive repeated indices, drops the blank (index 0), and maps the remaining indices 29, 22, 25 and 33 through `ppocr/utils/ic15_dict.txt` to the characters `s`, `l`, `o` and `w`.\n" ] }, { "cell_type": "markdown", "metadata": { "collapsed": false }, "source": [ "### Exercises\n", "\n", "**[Exercise 1]**\n", "\n", "Visualize the results of the noise and jitter [data augmentations](https://github.com/PaddlePaddle/PaddleOCR/blob/release/2.4/ppocr/data/imaug/rec_img_aug.py) implemented in PaddleOCR, and describe their effects in words.\n", "\n", "Optional test images:\n", "\n", "![](https://raw.githubusercontent.com/PaddlePaddle/PaddleOCR/release/2.4/doc/imgs_words/ch/word_1.jpg)\n", "\n", "![](https://raw.githubusercontent.com/PaddlePaddle/PaddleOCR/release/2.4/doc/imgs_words/ch/word_2.jpg)\n", "\n", "![](https://raw.githubusercontent.com/PaddlePaddle/PaddleOCR/release/2.4/doc/imgs_words/ch/word_3.jpg)\n", "\n", "\n", "**[Exercise 2]**\n", "\n", "Replace the backbone in `configs/rec/rec_icdar15_train.yml` with PaddleOCR's [ResNet34_vd](https://github.com/PaddlePaddle/PaddleOCR/blob/6ee301be36eb54d91dc437842f754593dce13967/ppocr/modeling/backbones/rec_resnet_vd.py#L176). When the input image shape is (3, 32, 100), what is the shape of the feature finally output by the Head layer?\n", "\n", "\n", "**[Exercise 3]**\n", "\n", "Download the 100,000-sample Chinese dataset [rec_data_lesson_demo](https://paddleocr.bj.bcebos.com/dataset/rec_data_lesson_demo.tar), modify the `configs/rec/rec_icdar15_train.yml` config file to train a recognition model, and provide the training log.\n", "\n", "You can load this pretrained model: https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/rec_mv3_none_bilstm_ctc_v2.0_train.tar\n", "\n", "\n", "## Summary\n", "\n", 
"This completes the CRNN-based text recognition task. For more features and code, see [PaddleOCR](https://github.com/PaddlePaddle/PaddleOCR).\n", "\n", "If you have any questions about the project, feel free to leave them in the comments section." ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "py35-paddle1.2.0" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.7.4" } }, "nbformat": 4, "nbformat_minor": 1 }