未验证 提交 cf613b13 编写于 作者: B Bin Lu 提交者: GitHub

Merge branch 'PaddlePaddle:dygraph' into dygraph

此差异已折叠。
...@@ -8,6 +8,8 @@ PPOCRLabel is a semi-automatic graphic annotation tool suitable for OCR field, w ...@@ -8,6 +8,8 @@ PPOCRLabel is a semi-automatic graphic annotation tool suitable for OCR field, w
### Recent Update ### Recent Update
- 2022.01:(by [PeterH0323](https://github.com/peterh0323)
- Improve user experience: prompt for the number of files and labels, optimize interaction, and fix bugs such as only use CPU when inference
- 2021.11.17: - 2021.11.17:
- Support install and start PPOCRLabel through the whl package (by [d2623587501](https://github.com/d2623587501)) - Support install and start PPOCRLabel through the whl package (by [d2623587501](https://github.com/d2623587501))
- Dataset segmentation: Divide the annotation file into training, verification and testing parts (refer to section 3.5 below, by [MrCuiHao](https://github.com/MrCuiHao)) - Dataset segmentation: Divide the annotation file into training, verification and testing parts (refer to section 3.5 below, by [MrCuiHao](https://github.com/MrCuiHao))
...@@ -110,7 +112,7 @@ python PPOCRLabel.py ...@@ -110,7 +112,7 @@ python PPOCRLabel.py
6. Click 're-Recognition', model will rewrite ALL recognition results in ALL detection box<sup>[3]</sup>. 6. Click 're-Recognition', model will rewrite ALL recognition results in ALL detection box<sup>[3]</sup>.
7. Double click the result in 'recognition result' list to manually change inaccurate recognition results. 7. Single click the result in 'recognition result' list to manually change inaccurate recognition results.
8. **Click "Check", the image status will switch to "√",then the program automatically jump to the next.** 8. **Click "Check", the image status will switch to "√",then the program automatically jump to the next.**
...@@ -143,15 +145,17 @@ python PPOCRLabel.py ...@@ -143,15 +145,17 @@ python PPOCRLabel.py
### 3.1 Shortcut keys ### 3.1 Shortcut keys
| Shortcut keys | Description | | Shortcut keys | Description |
| ------------------------ | ------------------------------------------------ | |--------------------------|--------------------------------------------------|
| Ctrl + Shift + R | Re-recognize all the labels of the current image | | Ctrl + Shift + R | Re-recognize all the labels of the current image |
| W | Create a rect box | | W | Create a rect box |
| Q | Create a four-points box | | Q | Create a four-points box |
| X | Rotate the box anti-clockwise |
| C | Rotate the box clockwise |
| Ctrl + E | Edit label of the selected box | | Ctrl + E | Edit label of the selected box |
| Ctrl + R | Re-recognize the selected box | | Ctrl + R | Re-recognize the selected box |
| Ctrl + C | Copy and paste the selected box | | Ctrl + C | Copy and paste the selected box |
| Ctrl + Left Mouse Button | Multi select the label box | | Ctrl + Left Mouse Button | Multi select the label box |
| Backspace | Delete the selected box | | Alt + X | Delete the selected box |
| Ctrl + V | Check image | | Ctrl + V | Check image |
| Ctrl + Shift + d | Delete image | | Ctrl + Shift + d | Delete image |
| D | Next image | | D | Next image |
......
...@@ -8,6 +8,8 @@ PPOCRLabel是一款适用于OCR领域的半自动化图形标注工具,内置P ...@@ -8,6 +8,8 @@ PPOCRLabel是一款适用于OCR领域的半自动化图形标注工具,内置P
#### 近期更新 #### 近期更新
- 2022.01:(by [PeterH0323](https://github.com/peterh0323)
- 提升用户体验:新增文件与标记数目提示、优化交互、修复gpu使用等问题
- 2021.11.17: - 2021.11.17:
- 新增支持通过whl包安装和启动PPOCRLabel(by [d2623587501](https://github.com/d2623587501) - 新增支持通过whl包安装和启动PPOCRLabel(by [d2623587501](https://github.com/d2623587501)
- 标注数据集切分:对标注数据进行训练、验证与测试集划分(参考下方3.5节,by [MrCuiHao](https://github.com/MrCuiHao) - 标注数据集切分:对标注数据进行训练、验证与测试集划分(参考下方3.5节,by [MrCuiHao](https://github.com/MrCuiHao)
...@@ -102,7 +104,7 @@ python PPOCRLabel.py --lang ch ...@@ -102,7 +104,7 @@ python PPOCRLabel.py --lang ch
4. 手动标注:点击 “矩形标注”(推荐直接在英文模式下点击键盘中的 “W”),用户可对当前图片中模型未检出的部分进行手动绘制标记框。点击键盘Q,则使用四点标注模式(或点击“编辑” - “四点标注”),用户依次点击4个点后,双击左键表示标注完成。 4. 手动标注:点击 “矩形标注”(推荐直接在英文模式下点击键盘中的 “W”),用户可对当前图片中模型未检出的部分进行手动绘制标记框。点击键盘Q,则使用四点标注模式(或点击“编辑” - “四点标注”),用户依次点击4个点后,双击左键表示标注完成。
5. 标记框绘制完成后,用户点击 “确认”,检测框会先被预分配一个 “待识别” 标签。 5. 标记框绘制完成后,用户点击 “确认”,检测框会先被预分配一个 “待识别” 标签。
6. 重新识别:将图片中的所有检测画绘制/调整完成后,点击 “重新识别”,PPOCR模型会对当前图片中的**所有检测框**重新识别<sup>[3]</sup> 6. 重新识别:将图片中的所有检测画绘制/调整完成后,点击 “重新识别”,PPOCR模型会对当前图片中的**所有检测框**重新识别<sup>[3]</sup>
7. 内容更改:击识别结果,对不准确的识别结果进行手动更改。 7. 内容更改:击识别结果,对不准确的识别结果进行手动更改。
8. **确认标记:点击 “确认”,图片状态切换为 “√”,跳转至下一张。** 8. **确认标记:点击 “确认”,图片状态切换为 “√”,跳转至下一张。**
9. 删除:点击 “删除图像”,图片将会被删除至回收站。 9. 删除:点击 “删除图像”,图片将会被删除至回收站。
10. 导出结果:用户可以通过菜单中“文件-导出标记结果”手动导出,同时也可以点击“文件 - 自动导出标记结果”开启自动导出。手动确认过的标记将会被存放在所打开图片文件夹下的*Label.txt*中。在菜单栏点击 “文件” - "导出识别结果"后,会将此类图片的识别训练数据保存在*crop_img*文件夹下,识别标签保存在*rec_gt.txt*<sup>[4]</sup> 10. 导出结果:用户可以通过菜单中“文件-导出标记结果”手动导出,同时也可以点击“文件 - 自动导出标记结果”开启自动导出。手动确认过的标记将会被存放在所打开图片文件夹下的*Label.txt*中。在菜单栏点击 “文件” - "导出识别结果"后,会将此类图片的识别训练数据保存在*crop_img*文件夹下,识别标签保存在*rec_gt.txt*<sup>[4]</sup>
...@@ -132,15 +134,17 @@ python PPOCRLabel.py --lang ch ...@@ -132,15 +134,17 @@ python PPOCRLabel.py --lang ch
### 3.1 快捷键 ### 3.1 快捷键
| 快捷键 | 说明 | | 快捷键 | 说明 |
| ---------------- | ---------------------------- | |------------------|----------------|
| Ctrl + shift + R | 对当前图片的所有标记重新识别 | | Ctrl + shift + R | 对当前图片的所有标记重新识别 |
| W | 新建矩形框 | | W | 新建矩形框 |
| Q | 新建四点框 | | Q | 新建四点框 |
| X | 框逆时针旋转 |
| C | 框顺时针旋转 |
| Ctrl + E | 编辑所选框标签 | | Ctrl + E | 编辑所选框标签 |
| Ctrl + R | 重新识别所选标记 | | Ctrl + R | 重新识别所选标记 |
| Ctrl + C | 复制并粘贴选中的标记框 | | Ctrl + C | 复制并粘贴选中的标记框 |
| Ctrl + 鼠标左键 | 多选标记框 | | Ctrl + 鼠标左键 | 多选标记框 |
| Backspace | 删除所选框 | | Alt + X | 删除所选框 |
| Ctrl + V | 确认本张图片标记 | | Ctrl + V | 确认本张图片标记 |
| Ctrl + Shift + d | 删除本张图片 | | Ctrl + Shift + d | 删除本张图片 |
| D | 下一张图片 | | D | 下一张图片 |
......
# Copyright (c) <2015-Present> Tzutalin
# Copyright (C) 2013 MIT, Computer Science and Artificial Intelligence Laboratory. Bryan Russell, Antonio Torralba,
# William T. Freeman. Permission is hereby granted, free of charge, to any person obtaining a copy of this software and
# associated documentation files (the "Software"), to deal in the Software without restriction, including without
# limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the
# Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:
# The above copyright notice and this permission notice shall be included in all copies or substantial portions of
# the Software. THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT
# NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT
# SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF
# CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
# THE SOFTWARE.
import sys
try:
from PyQt5.QtWidgets import QWidget, QHBoxLayout, QComboBox
except ImportError:
# needed for py3+qt4
# Ref:
# http://pyqt.sourceforge.net/Docs/PyQt4/incompatible_apis.html
# http://stackoverflow.com/questions/21217399/pyqt4-qtcore-qvariant-object-instead-of-a-string
if sys.version_info.major >= 3:
import sip
sip.setapi('QVariant', 2)
from PyQt4.QtGui import QWidget, QHBoxLayout, QComboBox
class ComboBox(QWidget):
def __init__(self, parent=None, items=[]):
super(ComboBox, self).__init__(parent)
layout = QHBoxLayout()
self.cb = QComboBox()
self.items = items
self.cb.addItems(self.items)
self.cb.currentIndexChanged.connect(parent.comboSelectionChanged)
layout.addWidget(self.cb)
self.setLayout(layout)
def update_items(self, items):
self.items = items
self.cb.clear()
self.cb.addItems(self.items)
...@@ -6,6 +6,8 @@ except ImportError: ...@@ -6,6 +6,8 @@ except ImportError:
from PyQt4.QtGui import * from PyQt4.QtGui import *
from PyQt4.QtCore import * from PyQt4.QtCore import *
import time
import datetime
import json import json
import cv2 import cv2
import numpy as np import numpy as np
...@@ -80,8 +82,9 @@ class AutoDialog(QDialog): ...@@ -80,8 +82,9 @@ class AutoDialog(QDialog):
self.parent = parent self.parent = parent
self.ocr = ocr self.ocr = ocr
self.mImgList = mImgList self.mImgList = mImgList
self.lender = lenbar
self.pb = QProgressBar() self.pb = QProgressBar()
self.pb.setRange(0, lenbar) self.pb.setRange(0, self.lender)
self.pb.setValue(0) self.pb.setValue(0)
layout = QVBoxLayout() layout = QVBoxLayout()
...@@ -108,10 +111,16 @@ class AutoDialog(QDialog): ...@@ -108,10 +111,16 @@ class AutoDialog(QDialog):
self.thread_1.progressBarValue.connect(self.handleProgressBarSingal) self.thread_1.progressBarValue.connect(self.handleProgressBarSingal)
self.thread_1.listValue.connect(self.handleListWidgetSingal) self.thread_1.listValue.connect(self.handleListWidgetSingal)
self.thread_1.endsignal.connect(self.handleEndsignalSignal) self.thread_1.endsignal.connect(self.handleEndsignalSignal)
self.time_start = time.time() # save start time
def handleProgressBarSingal(self, i): def handleProgressBarSingal(self, i):
self.pb.setValue(i) self.pb.setValue(i)
# calculate time left of auto labeling
avg_time = (time.time() - self.time_start) / i # Use average time to prevent time fluctuations
time_left = str(datetime.timedelta(seconds=avg_time * (self.lender - i))).split(".")[0] # Remove microseconds
self.setWindowTitle("PPOCRLabel -- " + f"Time Left: {time_left}") # show
def handleListWidgetSingal(self, i): def handleListWidgetSingal(self, i):
self.listWidget.addItem(i) self.listWidget.addItem(i)
titem = self.listWidget.item(self.listWidget.count() - 1) titem = self.listWidget.item(self.listWidget.count() - 1)
......
...@@ -11,19 +11,13 @@ ...@@ -11,19 +11,13 @@
# CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN # CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
# THE SOFTWARE. # THE SOFTWARE.
try: import copy
from PyQt5.QtGui import *
from PyQt5.QtCore import *
from PyQt5.QtWidgets import *
except ImportError:
from PyQt4.QtGui import *
from PyQt4.QtCore import *
#from PyQt4.QtOpenGL import *
from PyQt5.QtCore import Qt, pyqtSignal, QPointF, QPoint
from PyQt5.QtGui import QPainter, QBrush, QColor, QPixmap
from PyQt5.QtWidgets import QWidget, QMenu, QApplication
from libs.shape import Shape from libs.shape import Shape
from libs.utils import distance from libs.utils import distance
import copy
CURSOR_DEFAULT = Qt.ArrowCursor CURSOR_DEFAULT = Qt.ArrowCursor
CURSOR_POINT = Qt.PointingHandCursor CURSOR_POINT = Qt.PointingHandCursor
...@@ -31,8 +25,6 @@ CURSOR_DRAW = Qt.CrossCursor ...@@ -31,8 +25,6 @@ CURSOR_DRAW = Qt.CrossCursor
CURSOR_MOVE = Qt.ClosedHandCursor CURSOR_MOVE = Qt.ClosedHandCursor
CURSOR_GRAB = Qt.OpenHandCursor CURSOR_GRAB = Qt.OpenHandCursor
# class Canvas(QGLWidget):
class Canvas(QWidget): class Canvas(QWidget):
zoomRequest = pyqtSignal(int) zoomRequest = pyqtSignal(int)
...@@ -129,7 +121,6 @@ class Canvas(QWidget): ...@@ -129,7 +121,6 @@ class Canvas(QWidget):
def selectedVertex(self): def selectedVertex(self):
return self.hVertex is not None return self.hVertex is not None
def mouseMoveEvent(self, ev): def mouseMoveEvent(self, ev):
"""Update line with last point and current coordinates.""" """Update line with last point and current coordinates."""
pos = self.transformPos(ev.pos()) pos = self.transformPos(ev.pos())
...@@ -333,7 +324,6 @@ class Canvas(QWidget): ...@@ -333,7 +324,6 @@ class Canvas(QWidget):
self.movingShape = False self.movingShape = False
def endMove(self, copy=False): def endMove(self, copy=False):
assert self.selectedShapes and self.selectedShapesCopy assert self.selectedShapes and self.selectedShapesCopy
assert len(self.selectedShapesCopy) == len(self.selectedShapes) assert len(self.selectedShapesCopy) == len(self.selectedShapes)
...@@ -410,7 +400,6 @@ class Canvas(QWidget): ...@@ -410,7 +400,6 @@ class Canvas(QWidget):
self.selectionChanged.emit(shapes) self.selectionChanged.emit(shapes)
self.update() self.update()
def selectShapePoint(self, point, multiple_selection_mode): def selectShapePoint(self, point, multiple_selection_mode):
"""Select the first shape created which contains this point.""" """Select the first shape created which contains this point."""
if self.selectedVertex(): # A vertex is marked for selection. if self.selectedVertex(): # A vertex is marked for selection.
...@@ -494,7 +483,6 @@ class Canvas(QWidget): ...@@ -494,7 +483,6 @@ class Canvas(QWidget):
else: else:
shape.moveVertexBy(index, shiftPos) shape.moveVertexBy(index, shiftPos)
def boundedMoveShape(self, shapes, pos): def boundedMoveShape(self, shapes, pos):
if type(shapes).__name__ != 'list': shapes = [shapes] if type(shapes).__name__ != 'list': shapes = [shapes]
if self.outOfPixmap(pos): if self.outOfPixmap(pos):
...@@ -515,6 +503,7 @@ class Canvas(QWidget): ...@@ -515,6 +503,7 @@ class Canvas(QWidget):
if dp: if dp:
for shape in shapes: for shape in shapes:
shape.moveBy(dp) shape.moveBy(dp)
shape.close()
self.prevPoint = pos self.prevPoint = pos
return True return True
return False return False
...@@ -728,6 +717,31 @@ class Canvas(QWidget): ...@@ -728,6 +717,31 @@ class Canvas(QWidget):
self.moveOnePixel('Up') self.moveOnePixel('Up')
elif key == Qt.Key_Down and self.selectedShapes: elif key == Qt.Key_Down and self.selectedShapes:
self.moveOnePixel('Down') self.moveOnePixel('Down')
elif key == Qt.Key_X and self.selectedShapes:
for i in range(len(self.selectedShapes)):
self.selectedShape = self.selectedShapes[i]
if self.rotateOutOfBound(0.01):
continue
self.selectedShape.rotate(0.01)
self.shapeMoved.emit()
self.update()
elif key == Qt.Key_C and self.selectedShapes:
for i in range(len(self.selectedShapes)):
self.selectedShape = self.selectedShapes[i]
if self.rotateOutOfBound(-0.01):
continue
self.selectedShape.rotate(-0.01)
self.shapeMoved.emit()
self.update()
def rotateOutOfBound(self, angle):
for shape in range(len(self.selectedShapes)):
self.selectedShape = self.selectedShapes[shape]
for i, p in enumerate(self.selectedShape.points):
if self.outOfPixmap(self.selectedShape.rotatePoint(p, angle)):
return True
return False
def moveOnePixel(self, direction): def moveOnePixel(self, direction):
# print(self.selectedShape.points) # print(self.selectedShape.points)
......
import sys, time # !/usr/bin/env python
from PyQt5 import QtWidgets # -*- coding: utf-8 -*-
from PyQt5.QtGui import * from PyQt5.QtCore import QModelIndex
from PyQt5.QtCore import * from PyQt5.QtWidgets import QListWidget
from PyQt5.QtWidgets import *
class EditInList(QListWidget): class EditInList(QListWidget):
def __init__(self): def __init__(self):
super(EditInList,self).__init__() super(EditInList, self).__init__()
# click to edit self.edited_item = None
self.clicked.connect(self.item_clicked)
def item_clicked(self, modelindex: QModelIndex) -> None: def item_clicked(self, modelindex: QModelIndex):
self.edited_item = self.currentItem() try:
if self.edited_item is not None:
self.closePersistentEditor(self.edited_item) self.closePersistentEditor(self.edited_item)
item = self.item(modelindex.row()) except:
# time.sleep(0.2) self.edited_item = self.currentItem()
self.edited_item = item
self.openPersistentEditor(item) self.edited_item = self.item(modelindex.row())
# time.sleep(0.2) self.openPersistentEditor(self.edited_item)
self.editItem(item) self.editItem(self.edited_item)
def mouseDoubleClickEvent(self, event): def mouseDoubleClickEvent(self, event):
# close edit pass
for i in range(self.count()):
self.closePersistentEditor(self.item(i))
def leaveEvent(self, event): def leaveEvent(self, event):
# close edit # close edit
......
...@@ -10,19 +10,14 @@ ...@@ -10,19 +10,14 @@
# SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF # SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF
# CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN # CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
# THE SOFTWARE. # THE SOFTWARE.
#!/usr/bin/python # !/usr/bin/python
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
import math
import sys
from PyQt5.QtCore import QPointF
try: from PyQt5.QtGui import QColor, QPen, QPainterPath, QFont
from PyQt5.QtGui import *
from PyQt5.QtCore import *
except ImportError:
from PyQt4.QtGui import *
from PyQt4.QtCore import *
from libs.utils import distance from libs.utils import distance
import sys
DEFAULT_LINE_COLOR = QColor(0, 255, 0, 128) DEFAULT_LINE_COLOR = QColor(0, 255, 0, 128)
DEFAULT_FILL_COLOR = QColor(255, 0, 0, 128) DEFAULT_FILL_COLOR = QColor(255, 0, 0, 128)
...@@ -59,6 +54,8 @@ class Shape(object): ...@@ -59,6 +54,8 @@ class Shape(object):
self.difficult = difficult self.difficult = difficult
self.paintLabel = paintLabel self.paintLabel = paintLabel
self.locked = False self.locked = False
self.direction = 0
self.center = None
self._highlightIndex = None self._highlightIndex = None
self._highlightMode = self.NEAR_VERTEX self._highlightMode = self.NEAR_VERTEX
self._highlightSettings = { self._highlightSettings = {
...@@ -74,7 +71,24 @@ class Shape(object): ...@@ -74,7 +71,24 @@ class Shape(object):
# is used for drawing the pending line a different color. # is used for drawing the pending line a different color.
self.line_color = line_color self.line_color = line_color
def rotate(self, theta):
for i, p in enumerate(self.points):
self.points[i] = self.rotatePoint(p, theta)
self.direction -= theta
self.direction = self.direction % (2 * math.pi)
def rotatePoint(self, p, theta):
order = p - self.center
cosTheta = math.cos(theta)
sinTheta = math.sin(theta)
pResx = cosTheta * order.x() + sinTheta * order.y()
pResy = - sinTheta * order.x() + cosTheta * order.y()
pRes = QPointF(self.center.x() + pResx, self.center.y() + pResy)
return pRes
def close(self): def close(self):
self.center = QPointF((self.points[0].x() + self.points[2].x()) / 2,
(self.points[0].y() + self.points[2].y()) / 2)
self._closed = True self._closed = True
def reachMaxPoints(self): def reachMaxPoints(self):
...@@ -83,7 +97,9 @@ class Shape(object): ...@@ -83,7 +97,9 @@ class Shape(object):
return False return False
def addPoint(self, point): def addPoint(self, point):
if not self.reachMaxPoints(): # 4个点时发出close信号 if self.reachMaxPoints():
self.close()
else:
self.points.append(point) self.points.append(point)
def popPoint(self): def popPoint(self):
...@@ -112,7 +128,7 @@ class Shape(object): ...@@ -112,7 +128,7 @@ class Shape(object):
# Uncommenting the following line will draw 2 paths # Uncommenting the following line will draw 2 paths
# for the 1st vertex, and make it non-filled, which # for the 1st vertex, and make it non-filled, which
# may be desirable. # may be desirable.
#self.drawVertex(vrtx_path, 0) # self.drawVertex(vrtx_path, 0)
for i, p in enumerate(self.points): for i, p in enumerate(self.points):
line_path.lineTo(p) line_path.lineTo(p)
...@@ -136,9 +152,9 @@ class Shape(object): ...@@ -136,9 +152,9 @@ class Shape(object):
font.setPointSize(8) font.setPointSize(8)
font.setBold(True) font.setBold(True)
painter.setFont(font) painter.setFont(font)
if(self.label == None): if self.label is None:
self.label = "" self.label = ""
if(min_y < MIN_Y_LABEL): if min_y < MIN_Y_LABEL:
min_y += MIN_Y_LABEL min_y += MIN_Y_LABEL
painter.drawText(min_x, min_y, self.label) painter.drawText(min_x, min_y, self.label)
...@@ -198,6 +214,8 @@ class Shape(object): ...@@ -198,6 +214,8 @@ class Shape(object):
def copy(self): def copy(self):
shape = Shape("%s" % self.label) shape = Shape("%s" % self.label)
shape.points = [p for p in self.points] shape.points = [p for p in self.points]
shape.center = self.center
shape.direction = self.direction
shape.fill = self.fill shape.fill = self.fill
shape.selected = self.selected shape.selected = self.selected
shape._closed = self._closed shape._closed = self._closed
......
...@@ -50,7 +50,7 @@ PaddleOCR提供的可下载模型包括`推理模型`、`训练模型`、`预训 ...@@ -50,7 +50,7 @@ PaddleOCR提供的可下载模型包括`推理模型`、`训练模型`、`预训
|模型名称|模型简介|配置文件|推理模型大小|下载地址| |模型名称|模型简介|配置文件|推理模型大小|下载地址|
| --- | --- | --- | --- | --- | | --- | --- | --- | --- | --- |
|ch_PP-OCRv2_rec_slim|【最新】slim量化版超轻量模型,支持中英文、数字识别|[ch_PP-OCRv2_rec.yml](../../configs/rec/ch_PP-OCRv2/ch_PP-OCRv2_rec.yml)| 9M |[推理模型](https://paddleocr.bj.bcebos.com/PP-OCRv2/chinese/ch_PP-OCRv2_rec_slim_quant_infer.tar) / [训练模型](https://paddleocr.bj.bcebos.com/PP-OCRv2/chinese/ch_PP-OCRv2_rec_slim_quant_train.tar) | |ch_PP-OCRv2_rec_slim|【最新】slim量化版超轻量模型,支持中英文、数字识别|[ch_PP-OCRv2_rec.yml](../../configs/rec/ch_PP-OCRv2/ch_PP-OCRv2_rec.yml)| 9M |[推理模型](https://paddleocr.bj.bcebos.com/PP-OCRv2/chinese/ch_PP-OCRv2_rec_slim_quant_infer.tar) / [训练模型](https://paddleocr.bj.bcebos.com/PP-OCRv2/chinese/ch_PP-OCRv2_rec_slim_quant_train.tar) |
|ch_PP-OCRv2_rec|【最新】原始超轻量模型,支持中英文、数字识别|[ch_PP-OCRv2_rec.yml](../../configs/rec/ch_PP-OCRv2/ch_PP-OCRv2_rec.yml)|8.5M|[推理模型](https://paddleocr.bj.bcebos.com/PP-OCRv2/chinese/ch_PP-OCRv2_rec_infer.tar) / [训练模型](https://paddleocr.bj.bcebos.com/PP-OCRv2/chinese/ch_PP-OCRv2_rec_train.tar) | |ch_PP-OCRv2_rec|【最新】原始超轻量模型,支持中英文、数字识别|[ch_PP-OCRv2_rec_distillation.yml](../../configs/rec/ch_PP-OCRv2/ch_PP-OCRv2_rec_distillation.yml)|8.5M|[推理模型](https://paddleocr.bj.bcebos.com/PP-OCRv2/chinese/ch_PP-OCRv2_rec_infer.tar) / [训练模型](https://paddleocr.bj.bcebos.com/PP-OCRv2/chinese/ch_PP-OCRv2_rec_train.tar) |
|ch_ppocr_mobile_slim_v2.0_rec|slim裁剪量化版超轻量模型,支持中英文、数字识别|[rec_chinese_lite_train_v2.0.yml](../../configs/rec/ch_ppocr_v2.0/rec_chinese_lite_train_v2.0.yml)| 6M |[推理模型](https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_mobile_v2.0_rec_slim_infer.tar) / [训练模型](https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_mobile_v2.0_rec_slim_train.tar) | |ch_ppocr_mobile_slim_v2.0_rec|slim裁剪量化版超轻量模型,支持中英文、数字识别|[rec_chinese_lite_train_v2.0.yml](../../configs/rec/ch_ppocr_v2.0/rec_chinese_lite_train_v2.0.yml)| 6M |[推理模型](https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_mobile_v2.0_rec_slim_infer.tar) / [训练模型](https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_mobile_v2.0_rec_slim_train.tar) |
|ch_ppocr_mobile_v2.0_rec|原始超轻量模型,支持中英文、数字识别|[rec_chinese_lite_train_v2.0.yml](../../configs/rec/ch_ppocr_v2.0/rec_chinese_lite_train_v2.0.yml)|5.2M|[推理模型](https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_mobile_v2.0_rec_infer.tar) / [训练模型](https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_mobile_v2.0_rec_train.tar) / [预训练模型](https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_mobile_v2.0_rec_pre.tar) | |ch_ppocr_mobile_v2.0_rec|原始超轻量模型,支持中英文、数字识别|[rec_chinese_lite_train_v2.0.yml](../../configs/rec/ch_ppocr_v2.0/rec_chinese_lite_train_v2.0.yml)|5.2M|[推理模型](https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_mobile_v2.0_rec_infer.tar) / [训练模型](https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_mobile_v2.0_rec_train.tar) / [预训练模型](https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_mobile_v2.0_rec_pre.tar) |
|ch_ppocr_server_v2.0_rec|通用模型,支持中英文、数字识别|[rec_chinese_common_train_v2.0.yml](../../configs/rec/ch_ppocr_v2.0/rec_chinese_common_train_v2.0.yml)|94.8M|[推理模型](https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_server_v2.0_rec_infer.tar) / [训练模型](https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_server_v2.0_rec_train.tar) / [预训练模型](https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_server_v2.0_rec_pre.tar) | |ch_ppocr_server_v2.0_rec|通用模型,支持中英文、数字识别|[rec_chinese_common_train_v2.0.yml](../../configs/rec/ch_ppocr_v2.0/rec_chinese_common_train_v2.0.yml)|94.8M|[推理模型](https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_server_v2.0_rec_infer.tar) / [训练模型](https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_server_v2.0_rec_train.tar) / [预训练模型](https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_server_v2.0_rec_pre.tar) |
......
...@@ -16,22 +16,24 @@ PaddleOCR希望可以通过AI的力量助力任何一位有梦想的开发者实 ...@@ -16,22 +16,24 @@ PaddleOCR希望可以通过AI的力量助力任何一位有梦想的开发者实
### 1.1 基于PaddleOCR的社区项目 ### 1.1 基于PaddleOCR的社区项目
- 【最新】 [FastOCRLabel](https://gitee.com/BaoJianQiang/FastOCRLabel):完整的C#版本标注工具 (@ [包建强](https://gitee.com/BaoJianQiang) ) | 类别 | 项目 | 描述 | 开发者 |
| -------- | ------------------------------------------------------------ | -------------------------- | ------------------------------------------------------------ |
#### 1.1.1 通用工具 | 通用工具 | [FastOCRLabel](https://gitee.com/BaoJianQiang/FastOCRLabel) | 完整的C#版本标注GUI | [包建强](https://gitee.com/BaoJianQiang) |
| 通用工具 | [DangoOCR离线版](https://github.com/PantsuDango/DangoOCR) | 通用型桌面级即时翻译GUI | [PantsuDango](https://github.com/PantsuDango) |
- [DangoOCR离线版](https://github.com/PantsuDango/DangoOCR):通用型桌面级即时翻译工具 (@ [PantsuDango](https://github.com/PantsuDango)) | 通用工具 | [scr2txt](https://github.com/lstwzd/scr2txt) | 截屏转文字GUI | [lstwzd](https://github.com/lstwzd) |
- [scr2txt](https://github.com/lstwzd/scr2txt):截屏转文字工具 (@ [lstwzd](https://github.com/lstwzd)) | 通用工具 | [ocr_sdk](https://github.com/mymagicpower/AIAS/blob/main/1_image_sdks/text_recognition/ocr_sdk) | OCR java SDK工具箱 | [Calvin](https://github.com/mymagicpower) |
- [AI Studio项目](https://aistudio.baidu.com/aistudio/projectdetail/1054614?channelType=0&channel=0):英文视频自动生成字幕( @ [叶月水狐](https://aistudio.baidu.com/aistudio/personalcenter/thirdview/322052)) | 通用工具 | [iocr](https://github.com/mymagicpower/AIAS/blob/main/8_suite_hub/iocr) | IOCR 自定义模板识别(支持表格识别) | [Calvin](https://github.com/mymagicpower) |
| 通用工具 | [Lmdb Dataset Format Conversion Tool](https://github.com/OneYearIsEnough/PaddleOCR-Recog-LmdbDataset-Conversion) | 文本识别任务中lmdb数据格式转换工具 | [OneYearIsEnough](https://github.com/OneYearIsEnough) |
#### 1.1.2 垂类场景工具 | 垂类工具 | [AI Studio项目](https://aistudio.baidu.com/aistudio/projectdetail/1054614?channelType=0&channel=0) | 英文视频自动生成字幕 | [叶月水狐](https://aistudio.baidu.com/aistudio/personalcenter/thirdview/322052) |
| 垂类工具 | [id_card_ocr](https://github.com/baseli/id_card_ocr) | 身份证复印件识别 | [baseli](https://github.com/baseli) |
- [id_card_ocr](https://github.com/baseli/id_card_ocr):身份证复印件识别(@ [baseli](https://github.com/baseli)) | 垂类工具 | [Paddle_Table_Image_Reader](https://github.com/thunder95/Paddle_Table_Image_Reader) | 能看懂表格图片的数据助手 | [thunder95](https://github.com/thunder95]) |
- [Paddle_Table_Image_Reader](https://github.com/thunder95/Paddle_Table_Image_Reader):能看懂表格图片的数据助手(@ [thunder95](https://github.com/thunder95])) | 垂类工具 | [AI Studio项目](https://aistudio.baidu.com/aistudio/projectdetail/3382897) | OCR流程中对手写体进行过滤 | [daassh](https://github.com/daassh) |
| 垂类工具 | [AI Studio项目](https://aistudio.baidu.com/aistudio/projectdetail/2803693) | 电表读数和编号识别 | [深渊上的坑](https://github.com/edencfc) |
#### 1.1.3 前后处理 | 前后处理 | [paddleOCRCorrectOutputs](https://github.com/yuranusduke/paddleOCRCorrectOutputs) | 获取OCR识别结果的key-value | [yuranusduke](https://github.com/yuranusduke) |
|前处理| [optlab](https://github.com/GreatV/optlab) |OCR前处理工具箱,基于Qt和Leptonica。|[GreatV](https://github.com/GreatV)|
- [paddleOCRCorrectOutputs](https://github.com/yuranusduke/paddleOCRCorrectOutputs):获取OCR识别结果的key-value(@ [yuranusduke](https://github.com/yuranusduke)) |应用部署| [PaddleOCRSharp](https://github.com/raoyutian/PaddleOCRSharp) |PaddleOCR的.NET封装与应用部署。|[raoyutian](https://github.com/raoyutian/PaddleOCRSharp)|
|应用部署| [PaddleSharp](https://github.com/sdcb/PaddleSharp) |PaddleOCR的.NET封装与应用部署,支持跨平台、GPU|[sdcb](https://github.com/sdcb)|
| 学术前沿模型训练与推理 | [AI Studio项目](https://aistudio.baidu.com/aistudio/projectdetail/3397137) | StarNet-MobileNetV3算法–中文训练 | [xiaoyangyang2](https://github.com/xiaoyangyang2) |
### 1.2 为PaddleOCR新增功能 ### 1.2 为PaddleOCR新增功能
......
...@@ -43,8 +43,8 @@ Relationship of the above models is as follows. ...@@ -43,8 +43,8 @@ Relationship of the above models is as follows.
|model name|description|config|model size|download| |model name|description|config|model size|download|
| --- | --- | --- | --- | --- | | --- | --- | --- | --- | --- |
|ch_PP-OCRv2_rec_slim|[New] Slim qunatization with distillation lightweight model, supporting Chinese, English, multilingual text detection|[ch_PP-OCRv2_rec.yml](../../configs/rec/ch_PP-OCRv2/ch_PP-OCRv2_rec.yml)| 9M |[inference model](https://paddleocr.bj.bcebos.com/PP-OCRv2/chinese/ch_PP-OCRv2_rec_slim_quant_infer.tar) / [trained model](https://paddleocr.bj.bcebos.com/PP-OCRv2/chinese/ch_PP-OCRv2_rec_slim_quant_train.tar) | |ch_PP-OCRv2_rec_slim|[New] Slim qunatization with distillation lightweight model, supporting Chinese, English, multilingual text recognition|[ch_PP-OCRv2_rec.yml](../../configs/rec/ch_PP-OCRv2/ch_PP-OCRv2_rec.yml)| 9M |[inference model](https://paddleocr.bj.bcebos.com/PP-OCRv2/chinese/ch_PP-OCRv2_rec_slim_quant_infer.tar) / [trained model](https://paddleocr.bj.bcebos.com/PP-OCRv2/chinese/ch_PP-OCRv2_rec_slim_quant_train.tar) |
|ch_PP-OCRv2_rec|[New] Original lightweight model, supporting Chinese, English, multilingual text detection|[ch_PP-OCRv2_rec.yml](../../configs/rec/ch_PP-OCRv2/ch_PP-OCRv2_rec.yml)|8.5M|[inference model](https://paddleocr.bj.bcebos.com/PP-OCRv2/chinese/ch_PP-OCRv2_infer.tar) / [trained model](https://paddleocr.bj.bcebos.com/PP-OCRv2/chinese/ch_PP-OCRv2_rec_train.tar) | |ch_PP-OCRv2_rec|[New] Original lightweight model, supporting Chinese, English, multilingual text recognition|[ch_PP-OCRv2_rec_distillation.yml](../../configs/rec/ch_PP-OCRv2/ch_PP-OCRv2_rec_distillation.yml)|8.5M|[inference model](https://paddleocr.bj.bcebos.com/PP-OCRv2/chinese/ch_PP-OCRv2_infer.tar) / [trained model](https://paddleocr.bj.bcebos.com/PP-OCRv2/chinese/ch_PP-OCRv2_rec_train.tar) |
|ch_ppocr_mobile_slim_v2.0_rec|Slim pruned and quantized lightweight model, supporting Chinese, English and number recognition|[rec_chinese_lite_train_v2.0.yml](../../configs/rec/ch_ppocr_v2.0/rec_chinese_lite_train_v2.0.yml)| 6M | [inference model](https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_mobile_v2.0_rec_slim_infer.tar) / [trained model](https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_mobile_v2.0_rec_slim_train.tar) | |ch_ppocr_mobile_slim_v2.0_rec|Slim pruned and quantized lightweight model, supporting Chinese, English and number recognition|[rec_chinese_lite_train_v2.0.yml](../../configs/rec/ch_ppocr_v2.0/rec_chinese_lite_train_v2.0.yml)| 6M | [inference model](https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_mobile_v2.0_rec_slim_infer.tar) / [trained model](https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_mobile_v2.0_rec_slim_train.tar) |
|ch_ppocr_mobile_v2.0_rec|Original lightweight model, supporting Chinese, English and number recognition|[rec_chinese_lite_train_v2.0.yml](../../configs/rec/ch_ppocr_v2.0/rec_chinese_lite_train_v2.0.yml)|5.2M|[inference model](https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_mobile_v2.0_rec_infer.tar) / [trained model](https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_mobile_v2.0_rec_train.tar) / [pre-trained model](https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_mobile_v2.0_rec_pre.tar) | |ch_ppocr_mobile_v2.0_rec|Original lightweight model, supporting Chinese, English and number recognition|[rec_chinese_lite_train_v2.0.yml](../../configs/rec/ch_ppocr_v2.0/rec_chinese_lite_train_v2.0.yml)|5.2M|[inference model](https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_mobile_v2.0_rec_infer.tar) / [trained model](https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_mobile_v2.0_rec_train.tar) / [pre-trained model](https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_mobile_v2.0_rec_pre.tar) |
|ch_ppocr_server_v2.0_rec|General model, supporting Chinese, English and number recognition|[rec_chinese_common_train_v2.0.yml](../../configs/rec/ch_ppocr_v2.0/rec_chinese_common_train_v2.0.yml)|94.8M|[inference model](https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_server_v2.0_rec_infer.tar) / [trained model](https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_server_v2.0_rec_train.tar) / [pre-trained model](https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_server_v2.0_rec_pre.tar) | |ch_ppocr_server_v2.0_rec|General model, supporting Chinese, English and number recognition|[rec_chinese_common_train_v2.0.yml](../../configs/rec/ch_ppocr_v2.0/rec_chinese_common_train_v2.0.yml)|94.8M|[inference model](https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_server_v2.0_rec_infer.tar) / [trained model](https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_server_v2.0_rec_train.tar) / [pre-trained model](https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_server_v2.0_rec_pre.tar) |
......
doc/joinus.PNG

189.3 KB | W: | H:

doc/joinus.PNG

198.7 KB | W: | H:

doc/joinus.PNG
doc/joinus.PNG
doc/joinus.PNG
doc/joinus.PNG
  • 2-up
  • Swipe
  • Onion skin
...@@ -81,7 +81,7 @@ ...@@ -81,7 +81,7 @@
"\n", "\n",
"如果对某些层使用更小的学习率学习,静态图里还不是很方便,一个方法是在参数初始化的时候,给权重的属性设置固定的学习率,参考:https://www.paddlepaddle.org.cn/documentation/docs/zh/develop/api/paddle/fluid/param_attr/ParamAttr_cn.html#paramattr\n", "如果对某些层使用更小的学习率学习,静态图里还不是很方便,一个方法是在参数初始化的时候,给权重的属性设置固定的学习率,参考:https://www.paddlepaddle.org.cn/documentation/docs/zh/develop/api/paddle/fluid/param_attr/ParamAttr_cn.html#paramattr\n",
"\n", "\n",
"实际上我们实验发现,直接加载模型去fine-tune,不设置某些层不同学习率,效果也都不错\n", "实际上我们实验发现,直接加载模型去fine-tune,不设置某些层不同学习率,效果也都不错\n",
"\n", "\n",
"**1.11 DB的预处理部分,图片的长和宽为什么要处理成32的倍数?**\n", "**1.11 DB的预处理部分,图片的长和宽为什么要处理成32的倍数?**\n",
"\n", "\n",
...@@ -95,7 +95,7 @@ ...@@ -95,7 +95,7 @@
"\n", "\n",
"**1.13 PP-OCR检测效果不好,该如何优化?**\n", "**1.13 PP-OCR检测效果不好,该如何优化?**\n",
"\n", "\n",
"A: 具体问题具体分析:\n", "**A**: 具体问题具体分析:\n",
"- 如果在你的场景上检测效果不可用,首选是在你的数据上做finetune训练;\n", "- 如果在你的场景上检测效果不可用,首选是在你的数据上做finetune训练;\n",
"- 如果图像过大,文字过于密集,建议不要过度压缩图像,可以尝试修改检测预处理的resize逻辑,防止图像被过度压缩;\n", "- 如果图像过大,文字过于密集,建议不要过度压缩图像,可以尝试修改检测预处理的resize逻辑,防止图像被过度压缩;\n",
"- 检测框大小过于紧贴文字或检测框过大,可以调整db_unclip_ratio这个参数,加大参数可以扩大检测框,减小参数可以减小检测框大小;\n", "- 检测框大小过于紧贴文字或检测框过大,可以调整db_unclip_ratio这个参数,加大参数可以扩大检测框,减小参数可以减小检测框大小;\n",
...@@ -123,8 +123,8 @@ ...@@ -123,8 +123,8 @@
"\n", "\n",
"**A**:GPU加速预测推荐使用TensorRT。\n", "**A**:GPU加速预测推荐使用TensorRT。\n",
"- 1. 从[链接](https://paddleinference.paddlepaddle.org.cn/master/user_guides/download_lib.html)下载带TensorRT的Paddle安装包或者预测库。\n", "- 1. 从[链接](https://paddleinference.paddlepaddle.org.cn/master/user_guides/download_lib.html)下载带TensorRT的Paddle安装包或者预测库。\n",
"- 2. 从Nvidia官网下载TensorRT版本,注意下载的TensorRT版本与paddle安装包中编译的TensorRT版本一致。\n", "- 2. 从Nvidia官网下载[TensorRT](https://developer.nvidia.com/tensorrt),注意下载的TensorRT版本与paddle安装包中编译的TensorRT版本一致。\n",
"- 3. 设置环境变量LD_LIBRARY_PATH,指向TensorRT的lib文件夹\n", "- 3. 设置环境变量`LD_LIBRARY_PATH`,指向TensorRT的lib文件夹\n",
"```\n", "```\n",
"export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:<TensorRT-${version}/lib>\n", "export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:<TensorRT-${version}/lib>\n",
"```\n", "```\n",
......
...@@ -6,7 +6,7 @@ ...@@ -6,7 +6,7 @@
"collapsed": false "collapsed": false
}, },
"source": [ "source": [
"# OCR七日课之文本检测综述\n" "# 文本检测算法理论\n"
] ]
}, },
{ {
...@@ -15,11 +15,11 @@ ...@@ -15,11 +15,11 @@
"collapsed": false "collapsed": false
}, },
"source": [ "source": [
"## 1. 文本检测\n", "## 1 文本检测\n",
"\n", "\n",
"文本检测任务是找出图像或视频中的文字位置。不同于目标检测任务,目标检测不仅要解决定位问题,还要解决目标分类问题。\n", "文本检测任务是找出图像或视频中的文字位置。不同于目标检测任务,目标检测不仅要解决定位问题,还要解决目标分类问题。\n",
"\n", "\n",
"文本在图像中的表现形式可以视为一种‘目标,通用的目标检测的方法也适用于文本检测,从任务本身上来看:\n", "文本在图像中的表现形式可以视为一种‘目标,通用的目标检测的方法也适用于文本检测,从任务本身上来看:\n",
"\n", "\n",
"- 目标检测:给定图像或者视频,找出目标的位置(box),并给出目标的类别;\n", "- 目标检测:给定图像或者视频,找出目标的位置(box),并给出目标的类别;\n",
"- 文本检测:给定输入图像或者视频,找出文本的区域,可以是单字符位置或者整个文本行位置;\n", "- 文本检测:给定输入图像或者视频,找出文本的区域,可以是单字符位置或者整个文本行位置;\n",
...@@ -41,14 +41,14 @@ ...@@ -41,14 +41,14 @@
"1. 自然场景中文本具有多样性:文本检测受到文字颜色、大小、字体、形状、方向、语言、以及文本长度的影响;\n", "1. 自然场景中文本具有多样性:文本检测受到文字颜色、大小、字体、形状、方向、语言、以及文本长度的影响;\n",
"2. 复杂的背景和干扰;文本检测受到图像失真,模糊,低分辨率,阴影,亮度等因素的影响;\n", "2. 复杂的背景和干扰;文本检测受到图像失真,模糊,低分辨率,阴影,亮度等因素的影响;\n",
"3. 文本密集甚至重叠会影响文字的检测;\n", "3. 文本密集甚至重叠会影响文字的检测;\n",
"4. 文字存在局部一致性,文本行的一小部分,也可视为是独立的文本;\n", "4. 文字存在局部一致性:文本行的一小部分,也可视为是独立的文本。\n",
"\n", "\n",
"<center><img src=\"https://ai-studio-static-online.cdn.bcebos.com/072f208f2aff47e886cf2cf1378e23c648356686cf1349c799b42f662d8ced00\"\n", "<center><img src=\"https://ai-studio-static-online.cdn.bcebos.com/072f208f2aff47e886cf2cf1378e23c648356686cf1349c799b42f662d8ced00\"\n",
"width=\"1000\" ></center>\n", "width=\"1000\" ></center>\n",
"\n", "\n",
"<br><center>图3 文本检测场景</center>\n", "<br><center>图3 文本检测场景</center>\n",
"\n", "\n",
"针对以上问题,衍生了很多基于深度学习的文本检测算法,解决自然场景文字检测问题,这些方法可以分为基于回归和基于分割的文本检测方法。\n", "针对以上问题,衍生出了很多基于深度学习的文本检测算法,用于解决自然场景文字检测问题。这些方法可以分为基于回归和基于分割的文本检测方法。\n",
"\n", "\n",
"下一节将简要介绍基于深度学习技术的经典文字检测算法。" "下一节将简要介绍基于深度学习技术的经典文字检测算法。"
] ]
...@@ -59,7 +59,7 @@ ...@@ -59,7 +59,7 @@
"collapsed": false "collapsed": false
}, },
"source": [ "source": [
"## 2. 文本检测方法介绍\n", "## 2 文本检测方法介绍\n",
"\n", "\n",
"\n", "\n",
"近些年来基于深度学习的文本检测算法层出不穷,这些方法大致可以分为两类:\n", "近些年来基于深度学习的文本检测算法层出不穷,这些方法大致可以分为两类:\n",
...@@ -134,7 +134,7 @@ ...@@ -134,7 +134,7 @@
"\n", "\n",
"\n", "\n",
"\n", "\n",
"LOMO[19]针对长文本和弯曲文本问题,提出迭代的优化文本定位特征获取更精细的文本定位,该方法包括三个部分,坐标回归模块DR,迭代优化模块IRM以及任意形状表达模块SEM。分别用于生成文本大致区域,迭代优化文本定位特征,预测文本区域、文本中心线以及文本边界。迭代的优化文本特征可以更好的解决长文本定位问题以及获得更精确的文本区域定位。\n", "LOMO[19]针对长文本和弯曲文本问题,提出迭代的优化文本定位特征获取更精细的文本定位。该方法包括三个部分:坐标回归模块DR,迭代优化模块IRM以及任意形状表达模块SEM。它们分别用于生成文本大致区域,迭代优化文本定位特征,预测文本区域、文本中心线以及文本边界。迭代的优化文本特征可以更好的解决长文本定位问题以及获得更精确的文本区域定位。\n",
"<center><img src=\"https://ai-studio-static-online.cdn.bcebos.com/e90adf3ca25a45a0af0b84a181fbe2c4954be1fcca8f4049957128548b7131ef\"\n", "<center><img src=\"https://ai-studio-static-online.cdn.bcebos.com/e90adf3ca25a45a0af0b84a181fbe2c4954be1fcca8f4049957128548b7131ef\"\n",
"width=\"1000\" ></center>\n", "width=\"1000\" ></center>\n",
"<br><center>图11 LOMO框架图</center>\n", "<br><center>图11 LOMO框架图</center>\n",
...@@ -228,7 +228,7 @@ ...@@ -228,7 +228,7 @@
"collapsed": false "collapsed": false
}, },
"source": [ "source": [
"## 3. 总结\n", "## 3 总结\n",
"\n", "\n",
"本节介绍了近几年来文本检测领域的发展,包括基于回归、分割的文本检测方法,并分别列举并介绍了一些经典论文的方法思路。下一节以PaddleOCR开源库为例,详细介绍DBNet的算法原理以及核心代码实现。" "本节介绍了近几年来文本检测领域的发展,包括基于回归、分割的文本检测方法,并分别列举并介绍了一些经典论文的方法思路。下一节以PaddleOCR开源库为例,详细介绍DBNet的算法原理以及核心代码实现。"
] ]
......
...@@ -42,7 +42,7 @@ ...@@ -42,7 +42,7 @@
"\n", "\n",
"然后安装第三方库:\n", "然后安装第三方库:\n",
"\n", "\n",
"```\n", "```bash\n",
"cd PaddleOCR\n", "cd PaddleOCR\n",
"pip3 install -r requirements.txt\n", "pip3 install -r requirements.txt\n",
"```\n", "```\n",
......
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Text detection FAQ\n",
"\n",
"This section lists some of the problems that developers often encounter when using PaddleOCR's text detection model, and gives corresponding solutions or suggestions.\n",
"\n",
"The FAQ is introduced in two parts, namely:\n",
" -Text detection training related\n",
" -Text detection and prediction related"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 1. FAQ about Text Detection Training\n",
"\n",
"**1.1 What are the text detection algorithms provided by PaddleOCR?**\n",
"\n",
"**A**: PaddleOCR contains a variety of text detection models, including regression-based text detection methods EAST and SAST, and segmentation-based text detection methods DB, PSENet.\n",
"\n",
"\n",
"**1.2: What data sets are used in the Chinese ultra-lightweight and general models in the PaddleOCR project? How many samples were trained, what configuration of GPUs, how many epochs were run, and how long did they run?**\n",
"\n",
"**A**: For the ultra-lightweight DB detection model, the training data includes open source data sets lsvt, rctw, CASIA, CCPD, MSRA, MLT, BornDigit, iflytek, SROIE and synthetic data sets, etc. The total data volume is 10W, The data set is divided into 5 parts. A random sampling strategy is used during training. The training takes about 500 epochs on a 4-card V100GPU, which takes 3 days.\n",
"\n",
"\n",
"**1.3 Does the text detection training label require specific text labeling? What does the \"###\" in the label mean?**\n",
"\n",
"**A**: Text detection training only needs the coordinates of the text area. The label can be four or fourteen points, arranged in the order of upper left, upper right, lower right, and lower left. The label file provided by PaddleOCR contains text fields. For unclear text in the text area, ### will be used instead. When training the detection model, the text field in the label will not be used.\n",
" \n",
"**1.4 Is the effect of the text detection model trained when the text lines are tight?**\n",
"\n",
"**A**: When using segmentation-based methods, such as DB, to detect dense text lines, it is best to collect a batch of data for training, and during training, a binary image will be generated [shrink_ratio](https://github.com/PaddlePaddle/PaddleOCR/blob/8b656a3e13631dfb1ac21d2095d4d4a4993ef710/ppocr/data/imaug/make_shrink_map.py?_pjax=%23js-repo-pjax-container%2C%20div%5Bitemtype%3D%22http%3A%2F%2Fschema.org%2FSoftwareSourceCode%22%5D%20main%2C%20%5Bdata-pjax-container%5D#L37)Turn down the parameter. In addition, when forecasting, you can appropriately reduce [unclip_ratio](https://github.com/PaddlePaddle/PaddleOCR/blob/8b656a3e13631dfb1ac21d2095d4d4a4993ef710/configs/det/ch_ppocr_v2.0/ch_det_mv3_db_v2.0.yml?_pjax=%23js-repo-pjax-container%2C%20div%5Bitemtype%3D%22http%3A%2F%2Fschema.org%2FSoftwareSourceCode%22%5D%20main%2C%20%5Bdata-pjax-container%5D#L59) parameter, the larger the unclip_ratio parameter value, the larger the detection frame.\n",
"\n",
"\n",
"**1.5 For some large-sized document images, DB will have more missed inspections during inspection. How to avoid this kind of missed inspections?**\n",
"\n",
"**A**: First of all, you need to determine whether the model is not well-trained or is the problem handled during prediction. If the model is not well trained, it is recommended to add more data for training, or add more data to enhance it during training.\n",
"If the problem is that the predicted image is too large, you can increase the longest side setting parameter [det_limit_side_len] entered during prediction [det_limit_side_len](https://github.com/PaddlePaddle/PaddleOCR/blob/8b656a3e13631dfb1ac21d2095d4d4a4993ef710/tools/infer/utility.py?_pjax=%23js-repo-pjax-container%2C%20div%5Bitemtype%3D%22http%3A%2F%2Fschema.org%2FSoftwareSourceCode%22%5D%20main%2C%20%5Bdata-pjax-container%5D#L47), which is 960 by default.\n",
"Secondly, you can observe whether the missed text has segmentation results by visualizing the post-processed segmentation map. If there is no segmentation result, the model is not well trained. If there is a complete segmentation area, it means that it is a problem of post-prediction processing. In this case, it is recommended to adjust [DB post-processing parameters](https://github.com/PaddlePaddle/PaddleOCR/blob/8b656a3e13631dfb1ac21d2095d4d4a4993ef710/tools/infer/utility.py?_pjax=%23js-repo-pjax-container%2C%20div%5Bitemtype%3D%22http%3A%2F%2Fschema.org%2FSoftwareSourceCode%22%5D%20main%2C%20%5Bdata-pjax-container%5D#L51-L53)。\n",
"\n",
"\n",
"**1.6 The problem of missed detection of DB model bending text (such as a slightly deformed document image)?**\n",
"\n",
"**A**: When calculating the average score of the text box in the DB post-processing, it is the average score of the rectangle area, which is easy to cause the missed detection of the curved text. The average score of the polygon area has been added, which will be more accurate, but the speed is somewhat different. Decrease, can be selected as needed, and you can view the [Visual Contrast Effect] (https://github.com/PaddlePaddle/PaddleOCR/pull/2604) in the relevant pr. This function is selected by the parameter [det_db_score_mode](https://github.com/PaddlePaddle/PaddleOCR/blob/release/2.1/tools/infer/utility.py#L51), the parameter value is optional [`fast` (default) , `slow`], `fast` corresponds to the original rectangle mode, and `slow` corresponds to the polygon mode. Thanks to the user [buptlihang](https://github.com/buptlihang) for mentioning [pr](https://github.com/PaddlePaddle/PaddleOCR/pull/2574) to help solve this problem.\n",
"\n",
"\n",
"**1.7 For simple OCR tasks with low accuracy requirements, how many data sets do I need to prepare?**\n",
"\n",
"**A**: (1) The amount of training data is related to the complexity of the problem to be solved. The greater the difficulty and the higher the accuracy requirements, the greater the data set requirements, and in general, the more training data in practice, the better the effect.\n",
"\n",
"(2) For scenes with low accuracy requirements, the amount of data required for detection tasks and recognition tasks is different. For inspection tasks, 500 images can guarantee the basic inspection results. For recognition tasks, it is necessary to ensure that the number of line text images in which each character in the recognition dictionary appears in different scenes needs to be greater than 200 (for example, if there are 5 words in the dictionary, each word needs to appear in more than 200 pictures, then The minimum required number of images should be between 200-1000), so that the basic recognition effect can be guaranteed.\n",
"\n",
"\n",
"**1.8 How to get more data when the amount of training data is small?**\n",
"\n",
"**A**: When the amount of training data is small, you can try the following three ways to get more data: (1) Collect more training data manually, the most direct and effective way. (2) Basic image processing or transformation based on PIL and opencv. For example, the three modules of ImageFont, Image, ImageDraw in PIL write text into the background, opencv's rotating affine transformation, Gaussian filtering and so on. (3) Synthesize data using data generation algorithms, such as algorithms such as pix2pix.\n",
"\n",
"\n",
"**1.9 How to replace the backbone of text detection/recognition?**\n",
"\n",
"A: Whether it is text detection or text recognition, the choice of backbone network is a trade-off between prediction effect and prediction efficiency. Generally, if you choose a larger-scale backbone network, such as ResNet101_vd, the detection or recognition will be more accurate, but the prediction time will increase accordingly. However, choosing a smaller-scale backbone network, such as MobileNetV3_small_x0_35, will predict faster, but the accuracy of detection or recognition will be greatly reduced. Fortunately, the detection or recognition effects of different backbone networks are positively correlated with the image 1000 classification task in the ImageNet dataset. PaddleClas, a flying paddle image classification suite, summarizes 23 series of classification network structures such as ResNet_vd, Res2Net, HRNet, MobileNetV3, GhostNet, etc. The top1 recognition accuracy rate of the above image classification task, GPU (V100 and T4) and CPU (Snapdragon 855) The prediction time-consuming and the corresponding 117 pre-training model download addresses.\n",
"\n",
"(1) The replacement of the text detection backbone network is mainly to determine 4 stages similar to ResNet to facilitate the integration of subsequent detection heads similar to FPN. In addition, for the text detection problem, the classification pre-training model trained by ImageNet can accelerate the convergence and improve the effect.\n",
"\n",
"(2) The replacement of the backbone network for text recognition requires attention to the drop position of the network width and height stride. Since text recognition generally has a large ratio of width to height, the frequency of height reduction is less, and the frequency of width reduction is more. You can refer to [Changes to the MobileNetV3 backbone network in PaddleOCR](https://github.com/PaddlePaddle/PaddleOCR/blob/release%2F2.3/ppocr/modeling/backbones/rec_mobilenet_v3.py)。\n",
"\n",
"\n",
"**1.10 How to finetune the detection model, such as freezing the previous layer or learning with a small learning rate for some layers?**\n",
"\n",
"**A**: If you freeze certain layers, you can set the stop_gradient property of the variable to True, so that all the parameters before calculating this variable will not be updated, refer to: https://www.paddlepaddle.org.cn/documentation/docs/zh/develop/faq/train_cn.html#id4\n",
"\n",
"If learning with a smaller learning rate for some layers is not very convenient in the static graph, one method is to set a fixed learning rate for the weight attribute when the parameters are initialized, refer to: https://www.paddlepaddle.org.cn/documentation/docs/en/develop/api/paddle/fluid/param_attr/ParamAttr_cn.html#paramattr\n",
"\n",
"In fact, our experiment found that directly loading the model to fine-tune without setting different learning rates of certain layers, the effect is also good.\n",
"\n",
"**1.11 In the preprocessing part of DB, why should the length and width of the picture be processed into multiples of 32?**\n",
"\n",
"**A**: It is related to the stride of the network downsampling. Take the resnet backbone network under inspection as an example. After the image is input to the network, it needs to be downsampled by 2 times for 5 times, a total of 32 times. Therefore, it is recommended that the input image size be a multiple of 32.\n",
"\n",
"\n",
"**1.12 In the PP-OCR series models, why does the backbone network for text detection not use SEBlock?**\n",
"\n",
"**A**: The SE module is an important module of the MobileNetV3 network. Its purpose is to estimate the importance of each feature channel of the feature map, assign weights to each feature of the feature map, and improve the expressive ability of the network. However, for text detection, the resolution of the input network is relatively large, generally 640\\*640. It is difficult to use the SE module to estimate the importance of each feature channel of the feature map. The network improvement ability is limited, but the module is relatively time-consuming. In the PP-OCR system, the backbone network for text detection does not use the SE module. Experiments also show that when the SE module is removed, the size of the ultra-lightweight model can be reduced by 40%, and the text detection effect is basically not affected. For details, please refer to the PP-OCR technical article, https://arxiv.org/abs/2009.09941.\n",
"\n",
"\n",
"**1.13 The PP-OCR detection effect is not good, how to optimize it?**\n",
"\n",
"**A**: Specific analysis of specific issues:\n",
"- If the detection effect is not available on your scene, the first choice is to do finetune training on your data;\n",
"- If the image is too large and the text is too dense, it is recommended not to over-compress the image. You can try to modify the resize logic of the detection preprocessing to prevent the image from being over-compressed;\n",
"- The size of the detection frame is too close to the text or the detection frame is too large, you can adjust the db_unclip_ratio parameter, increasing the parameter can enlarge the detection frame, and reducing the parameter can reduce the size of the detection frame;\n",
"- There are many missed detection problems in the detection frame, which can reduce the threshold parameter det_db_box_thresh for DB detection to prevent some detection frames from being filtered out. You can also try to set det_db_score_mode to'slow';\n",
"- Other methods can choose use_dilation as True to expand the feature map of the detection output. In general, the effect will be improved.\n",
"\n",
"\n",
"## 2. FAQ about Text Detection and Prediction\n",
"\n",
"**2.1 In DB, some boxes are too pasted with text, but some corners of the text are removed to affect the recognition. Is there any way to alleviate this problem?**\n",
"\n",
"**A**: The post-processing parameter [unclip_ratio](https://github.com/PaddlePaddle/PaddleOCR/blob/d80afce9b51f09fd3d90e539c40eba8eb5e50dd6/tools/infer/utility.py?_pjax=%23js-repo-pjax-container%2C%20div%5Bitemtype%3D%22http%3A%2F%2Fschema.org%2FSoftwareSourceCode%22%5D%20main%2C%20%5Bdata-pjax-container%5D#L52) can be appropriately increased. the larger the parameter, the larger the text box.\n",
"\n",
"\n",
"**2.2 Why does the PaddleOCR detection prediction only support one image test? That is, test_batch_size_per_card=1**\n",
"\n",
"**A**: When predicting, the image is scaled in equal proportions, the longest side is 960, and the length and width of different images after scaling in equal proportions are inconsistent, and they cannot form a batch, so set test_batch_size to 1.\n",
"\n",
"\n",
"**2.3 Accelerate PaddleOCR's text detection model prediction on the CPU?**\n",
"\n",
"**A**: x86 CPU can use mkldnn (OneDNN) for acceleration; enable [enable_mkldnn](https://github.com/PaddlePaddle/PaddleOCR/blob/8b656a3e13631dfb1ac21d2095d4d4a4993ef710/tools/infer/utility.py#L105) Parameters. In addition, in conjunction with increasing the number of threads used for prediction on the CPU, [num_threads](https://github.com/PaddlePaddle/PaddleOCR/blob/8b656a3e13631dfb1ac21d2095d4d4a4993ef710/tools/infer/utility.py#L106) can effectively speed up the prediction speed on the CPU.\n",
"\n",
"**2.4 Accelerate PaddleOCR's text detection model prediction on GPU?**\n",
"\n",
"**A**: TensorRT is recommended for GPU accelerated prediction.\n",
"- 1. Download the Paddle installation package or prediction library with TensorRT from [link](https://paddleinference.paddlepaddle.org.cn/master/user_guides/download_lib.html).\n",
"- 2. Download the [TensorRT](https://developer.nvidia.com/tensorrt) from the Nvidia official website. Note that the downloaded TensorRT version is consistent with the TensorRT version compiled in the paddle installation package.\n",
"- 3. Set the environment variable `LD_LIBRARY_PATH` to point to the lib folder of TensorRT\n",
"```\n",
"export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:<TensorRT-${version}/lib>\n",
"```\n",
"- 4. Enable [tensorrt option](https://github.com/PaddlePaddle/PaddleOCR/blob/8b656a3e13631dfb1ac21d2095d4d4a4993ef710/tools/infer/utility.py?_pjax=%23js-repo-pjax-container%2%5Bitemtype%3D%22http%3A%2F%2Fschema.org%2FSoftwareSourceCode%22%5D%20main%2C%20%5Bdata-pjax-container%5D#L38).\n",
"\n",
"**2.5 How to deploy PaddleOCR model on the mobile terminal?**\n",
"\n",
"**A**: Flying Oar Paddle has a special tool for mobile deployment [PaddleLite](https://github.com/PaddlePaddle/Paddle-Lite), and PaddleOCR provides DB+CRNN as the demo android arm deployment code , Refer to [link](https://github.com/PaddlePaddle/PaddleOCR/blob/release%2F2.3/deploy/lite/readme.md).\n",
"\n",
"\n",
"**2.6 How to use PaddleOCR multi-process prediction?**\n",
"\n",
"**A**: PaddleOCR recently added [Multi-Process Predictive Control Parameters](https://github.com/PaddlePaddle/PaddleOCR/blob/8b656a3e13631dfb1ac21d2095d4d4a4993ef710/tools/infer/utility.py?_pjax=%23js-repo-pjax-container%2C%20div%5Bitemtype%3D%22http%3A%2F%2Fschema.org%2FSoftwareSourceCode%22%5D%20main%2C%20%5Bdata-pjax-container%5D#L111), `use_mp` indicates whether When using multiple processes, `total_process_num` indicates the number of processes when using multiple processes. For specific usage, please refer to [document](https://github.com/PaddlePaddle/PaddleOCR/blob/release%2F2.3/doc/doc_ch/inference.md#1-%E8%B6%85%E8%BD%BB%E9%87%8F%E4%B8%AD%E6%96%87ocr%E6%A8%A1%E5%9E%8B%E6%8E%A8%E7%90%86).\n",
"\n",
"**2.7 Video memory explosion and memory leak during prediction?**\n",
"\n",
"**A**: If it is the prediction of the training model, the video memory is not enough because the model is too large or the input image is too large, you can refer to the code and add paddle.no_grad() before the main function runs to reduce the video memory usage. If the memory usage of the inference model is too high, you can add [config.enable_memory_optim()](https://github.com/PaddlePaddle/PaddleOCR/blob/8b656a3e13631dfb1ac21d2095d4d4a4993ef710/tools/infer/utility.py?_pjax=%23js-repo-pjax-container%2C%20div%5Bitemtype%3D%22http%3A%2F%2Fschema.org%2FSoftwareSourceCode%22%5D%20main%2C%20%5Bdata-pjax-container%5D#L267) to reduce the memory usage when configuring Config.\n",
"\n",
"In addition, regarding the memory leak when using Paddle to predict, it is recommended to install the latest version of paddle. The memory leak has been fixed."
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.5"
}
},
"nbformat": 4,
"nbformat_minor": 4
}
{
"cells": [
{
"cell_type": "markdown",
"metadata": {
"tags": []
},
"source": [
"# 1. Course Prerequisites\n",
"\n",
"The OCR model involved in this course is based on deep learning, so its related basic knowledge, environment configuration, project engineering and other materials will be introduced in this section, especially for readers who are not familiar with deep learning. content.\n",
"\n",
"### 1.1 Preliminary Knowledge\n",
"\n",
"The \"learning\" of deep learning has been developed from the content of neurons, perceptrons, and multilayer neural networks in machine learning. Therefore, understanding the basic machine learning algorithms is of great help to the understanding and application of deep learning. The \"deepness\" of deep learning is embodied in a series of vector-based mathematical operations such as convolution and pooling used in the process of processing a large amount of information. If you lack the theoretical foundation of the two, you can learn from teacher Li Hongyi's [Linear Algebra](https://aistudio.baidu.com/aistudio/course/introduce/2063) and [Machine Learning](https://aistudio.baidu.com/aistudio/course/introduce/1978) courses.\n",
"\n",
"For the understanding of deep learning itself, you can refer to the zero-based course of Bai Ran, an outstanding architect of Baidu: [Baidu architects take you hands-on with zero-based practice deep learning](https://aistudio.baidu.com/aistudio/course/introduce/1297), which covers the development history of deep learning and introduces the complete components of deep learning through a classic case. It is a set of practice-oriented deep learning courses.\n",
"\n",
"For the practice of theoretical knowledge, [Python basic knowledge](https://aistudio.baidu.com/aistudio/course/introduce/1224) is essential. At the same time, in order to quickly reproduce the deep learning model, the deep learning framework used in this course For: Flying PaddlePaddle. If you have used other frameworks, you can quickly learn how to use flying paddles through [Quick Start Document](https://www.paddlepaddle.org.cn/documentation/docs/zh/practices/quick_start/hello_paddle.html).\n",
"\n",
"### 1.2 Basic Environment Preparation\n",
"\n",
"If you want to run the code of this course in a local environment and have not built a Python environment before, you can follow the [zero-base operating environment preparation](https://github.com/PaddlePaddle/PaddleOCR/blob/release/2.3/doc/doc_ch/environment.md), install Anaconda or docker environment according to your operating system.\n",
"\n",
"If you don't have local resources, you can run the code through the AI Studio training platform. Each item in it is presented in a notebook, which is convenient for developers to learn. If you are not familiar with the related operations of Notebook, you can refer to [AI Studio Project Description](https://ai.baidu.com/ai-doc/AISTUDIO/0k3e2tfzm).\n",
"\n",
"### 1.3 Get and Run the Code\n",
"\n",
"This course relies on the formation of PaddleOCR's code repository. First, clone the complete project of PaddleOCR:\n",
"\n",
"```bash\n",
"# [recommend]\n",
"git clone https://github.com/PaddlePaddle/PaddleOCR\n",
"\n",
"# If you cannot pull successfully due to network problems, you can also choose to use the hosting on Code Cloud:\n",
"git clone https://gitee.com/paddlepaddle/PaddleOCR\n",
"```\n",
"\n",
"> Note: The code cloud hosted code may not be able to synchronize the update of this github project in real time, there is a delay of 3~5 days, please use the recommended method first.\n",
">\n",
"> If you are not familiar with git operations, you can download the compressed package directly from the `Code` on the homepage of PaddleOCR\n",
"\n",
"Then install third-party libraries:\n",
"\n",
"```bash\n",
"cd PaddleOCR\n",
"pip3 install -r requirements.txt\n",
"```\n",
"\n",
"\n",
"\n",
"### 1.4 Access to Information\n",
"\n",
"[PaddleOCR Usage Document](https://github.com/PaddlePaddle/PaddleOCR/blob/release/2.3/README.md) describes in detail how to use PaddleOCR to complete model application, training and deployment. The document is rich in content, most of the user’s questions are described in the document or FAQ, especially in [FAQ](https://github.com/PaddlePaddle/PaddleOCR/blob/release%2F2.3/doc/doc_en/FAQ_en.md), in accordance with the application process of deep learning, has precipitated the user's common questions, it is recommended that you read it carefully.\n",
"\n",
"### 1.5 Ask for Help\n",
"\n",
"If you encounter BUG, ease of use or documentation related issues while using PaddleOCR, you can contact the official via [Github issue](https://github.com/PaddlePaddle/PaddleOCR/issues), please follow the issue template Provide as much information as possible so that official personnel can quickly locate the problem. At the same time, the WeChat group is the daily communication position for the majority of PaddleOCR users, and it is more suitable for asking some consulting questions. In addition to the PaddleOCR team members, there will also be enthusiastic developers answering your questions."
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "py35-paddle1.2.0"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.4"
}
},
"nbformat": 4,
"nbformat_minor": 4
}
...@@ -31,7 +31,8 @@ class CTCLoss(nn.Layer): ...@@ -31,7 +31,8 @@ class CTCLoss(nn.Layer):
predicts = predicts[-1] predicts = predicts[-1]
predicts = predicts.transpose((1, 0, 2)) predicts = predicts.transpose((1, 0, 2))
N, B, _ = predicts.shape N, B, _ = predicts.shape
preds_lengths = paddle.to_tensor([N] * B, dtype='int64') preds_lengths = paddle.to_tensor(
[N] * B, dtype='int64', place=paddle.CPUPlace())
labels = batch[1].astype("int32") labels = batch[1].astype("int32")
label_lengths = batch[2].astype('int64') label_lengths = batch[2].astype('int64')
loss = self.loss_func(predicts, labels, preds_lengths, label_lengths) loss = self.loss_func(predicts, labels, preds_lengths, label_lengths)
......
...@@ -16,6 +16,7 @@ ...@@ -16,6 +16,7 @@
class ClsMetric(object): class ClsMetric(object):
def __init__(self, main_indicator='acc', **kwargs): def __init__(self, main_indicator='acc', **kwargs):
self.main_indicator = main_indicator self.main_indicator = main_indicator
self.eps = 1e-5
self.reset() self.reset()
def __call__(self, pred_label, *args, **kwargs): def __call__(self, pred_label, *args, **kwargs):
...@@ -28,7 +29,7 @@ class ClsMetric(object): ...@@ -28,7 +29,7 @@ class ClsMetric(object):
all_num += 1 all_num += 1
self.correct_num += correct_num self.correct_num += correct_num
self.all_num += all_num self.all_num += all_num
return {'acc': correct_num / all_num, } return {'acc': correct_num / (all_num + self.eps), }
def get_metric(self): def get_metric(self):
""" """
...@@ -36,7 +37,7 @@ class ClsMetric(object): ...@@ -36,7 +37,7 @@ class ClsMetric(object):
'acc': 0 'acc': 0
} }
""" """
acc = self.correct_num / self.all_num acc = self.correct_num / (self.all_num + self.eps)
self.reset() self.reset()
return {'acc': acc} return {'acc': acc}
......
...@@ -20,6 +20,7 @@ class RecMetric(object): ...@@ -20,6 +20,7 @@ class RecMetric(object):
def __init__(self, main_indicator='acc', is_filter=False, **kwargs): def __init__(self, main_indicator='acc', is_filter=False, **kwargs):
self.main_indicator = main_indicator self.main_indicator = main_indicator
self.is_filter = is_filter self.is_filter = is_filter
self.eps = 1e-5
self.reset() self.reset()
def _normalize_text(self, text): def _normalize_text(self, text):
...@@ -47,8 +48,8 @@ class RecMetric(object): ...@@ -47,8 +48,8 @@ class RecMetric(object):
self.all_num += all_num self.all_num += all_num
self.norm_edit_dis += norm_edit_dis self.norm_edit_dis += norm_edit_dis
return { return {
'acc': correct_num / all_num, 'acc': correct_num / (all_num + self.eps),
'norm_edit_dis': 1 - norm_edit_dis / (all_num + 1e-3) 'norm_edit_dis': 1 - norm_edit_dis / (all_num + self.eps)
} }
def get_metric(self): def get_metric(self):
...@@ -58,8 +59,8 @@ class RecMetric(object): ...@@ -58,8 +59,8 @@ class RecMetric(object):
'norm_edit_dis': 0, 'norm_edit_dis': 0,
} }
""" """
acc = 1.0 * self.correct_num / (self.all_num + 1e-3) acc = 1.0 * self.correct_num / (self.all_num + self.eps)
norm_edit_dis = 1 - self.norm_edit_dis / (self.all_num + 1e-3) norm_edit_dis = 1 - self.norm_edit_dis / (self.all_num + self.eps)
self.reset() self.reset()
return {'acc': acc, 'norm_edit_dis': norm_edit_dis} return {'acc': acc, 'norm_edit_dis': norm_edit_dis}
......
...@@ -12,9 +12,12 @@ ...@@ -12,9 +12,12 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import numpy as np import numpy as np
class TableMetric(object): class TableMetric(object):
def __init__(self, main_indicator='acc', **kwargs): def __init__(self, main_indicator='acc', **kwargs):
self.main_indicator = main_indicator self.main_indicator = main_indicator
self.eps = 1e-5
self.reset() self.reset()
def __call__(self, pred, batch, *args, **kwargs): def __call__(self, pred, batch, *args, **kwargs):
...@@ -31,9 +34,7 @@ class TableMetric(object): ...@@ -31,9 +34,7 @@ class TableMetric(object):
correct_num += 1 correct_num += 1
self.correct_num += correct_num self.correct_num += correct_num
self.all_num += all_num self.all_num += all_num
return { return {'acc': correct_num * 1.0 / (all_num + self.eps), }
'acc': correct_num * 1.0 / all_num,
}
def get_metric(self): def get_metric(self):
""" """
...@@ -41,7 +42,7 @@ class TableMetric(object): ...@@ -41,7 +42,7 @@ class TableMetric(object):
'acc': 0, 'acc': 0,
} }
""" """
acc = 1.0 * self.correct_num / self.all_num acc = 1.0 * self.correct_num / (self.all_num + self.eps)
self.reset() self.reset()
return {'acc': acc} return {'acc': acc}
......
...@@ -105,3 +105,22 @@ def set_seed(seed=1024): ...@@ -105,3 +105,22 @@ def set_seed(seed=1024):
random.seed(seed) random.seed(seed)
np.random.seed(seed) np.random.seed(seed)
paddle.seed(seed) paddle.seed(seed)
class AverageMeter:
def __init__(self):
self.reset()
def reset(self):
"""reset"""
self.val = 0
self.avg = 0
self.sum = 0
self.count = 0
def update(self, val, n=1):
"""update"""
self.val = val
self.sum += val * n
self.count += n
self.avg = self.sum / self.count
...@@ -133,16 +133,16 @@ cd ppstructure ...@@ -133,16 +133,16 @@ cd ppstructure
# 下载模型 # 下载模型
mkdir inference && cd inference mkdir inference && cd inference
# 下载超轻量级中文OCR模型的检测模型并解压 # 下载PP-OCRv2文本检测模型并解压
wget https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_mobile_v2.0_det_infer.tar && tar xf ch_ppocr_mobile_v2.0_det_infer.tar wget https://paddleocr.bj.bcebos.com/PP-OCRv2/chinese/ch_PP-OCRv2_det_slim_quant_infer.tar && tar xf ch_PP-OCRv2_det_slim_quant_infer.tar
# 下载超轻量级中文OCR模型的识别模型并解压 # 下载PP-OCRv2文本识别模型并解压
wget https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_mobile_v2.0_rec_infer.tar && tar xf ch_ppocr_mobile_v2.0_rec_infer.tar wget https://paddleocr.bj.bcebos.com/PP-OCRv2/chinese/ch_PP-OCRv2_rec_slim_quant_infer.tar && tar xf ch_PP-OCRv2_rec_slim_quant_infer.tar
# 下载超轻量级英文表格英寸模型并解压 # 下载超轻量级英文表格预测模型并解压
wget https://paddleocr.bj.bcebos.com/dygraph_v2.0/table/en_ppocr_mobile_v2.0_table_structure_infer.tar && tar xf en_ppocr_mobile_v2.0_table_structure_infer.tar wget https://paddleocr.bj.bcebos.com/dygraph_v2.0/table/en_ppocr_mobile_v2.0_table_structure_infer.tar && tar xf en_ppocr_mobile_v2.0_table_structure_infer.tar
cd .. cd ..
python3 predict_system.py --det_model_dir=inference/ch_ppocr_mobile_v2.0_det_infer \ python3 predict_system.py --det_model_dir=inference/ch_PP-OCRv2_det_slim_quant_infer \
--rec_model_dir=inference/ch_ppocr_mobile_v2.0_rec_infer \ --rec_model_dir=inference/ch_PP-OCRv2_rec_slim_quant_infer \
--table_model_dir=inference/en_ppocr_mobile_v2.0_table_structure_infer \ --table_model_dir=inference/en_ppocr_mobile_v2.0_table_structure_infer \
--image_dir=../doc/table/1.png \ --image_dir=../doc/table/1.png \
--rec_char_dict_path=../ppocr/utils/ppocr_keys_v1.txt \ --rec_char_dict_path=../ppocr/utils/ppocr_keys_v1.txt \
......
...@@ -41,7 +41,7 @@ wget https://paddleocr.bj.bcebos.com/dygraph_v2.0/table/en_ppocr_mobile_v2.0_tab ...@@ -41,7 +41,7 @@ wget https://paddleocr.bj.bcebos.com/dygraph_v2.0/table/en_ppocr_mobile_v2.0_tab
wget https://paddleocr.bj.bcebos.com/dygraph_v2.0/table/en_ppocr_mobile_v2.0_table_structure_infer.tar && tar xf en_ppocr_mobile_v2.0_table_structure_infer.tar wget https://paddleocr.bj.bcebos.com/dygraph_v2.0/table/en_ppocr_mobile_v2.0_table_structure_infer.tar && tar xf en_ppocr_mobile_v2.0_table_structure_infer.tar
cd .. cd ..
# run # run
python3 table/predict_table.py --det_model_dir=inference/en_ppocr_mobile_v2.0_table_det_infer --rec_model_dir=inference/en_ppocr_mobile_v2.0_table_rec_infer --table_model_dir=inference/en_ppocr_mobile_v2.0_table_structure_infer --image_dir=../doc/table/table.jpg --rec_char_dict_path=../ppocr/utils/dict/table_dict.txt --table_char_dict_path=../ppocr/utils/dict/table_structure_dict.txt --rec_char_dict_path=../ppocr/utils/dict/en_dict.txt --det_limit_side_len=736 --det_limit_type=min --output ../output/table python3 table/predict_table.py --det_model_dir=inference/en_ppocr_mobile_v2.0_table_det_infer --rec_model_dir=inference/en_ppocr_mobile_v2.0_table_rec_infer --table_model_dir=inference/en_ppocr_mobile_v2.0_table_structure_infer --image_dir=../doc/table/table.jpg --rec_char_dict_path=../ppocr/utils/dict/table_dict.txt --table_char_dict_path=../ppocr/utils/dict/table_structure_dict.txt --det_limit_side_len=736 --det_limit_type=min --output ../output/table
``` ```
Note: The above model is trained on the PubLayNet dataset and only supports English scanning scenarios. If you need to identify other scenarios, you need to train the model yourself and replace the three fields `det_model_dir`, `rec_model_dir`, `table_model_dir`. Note: The above model is trained on the PubLayNet dataset and only supports English scanning scenarios. If you need to identify other scenarios, you need to train the model yourself and replace the three fields `det_model_dir`, `rec_model_dir`, `table_model_dir`.
......
...@@ -56,7 +56,7 @@ wget https://paddleocr.bj.bcebos.com/dygraph_v2.0/table/en_ppocr_mobile_v2.0_tab ...@@ -56,7 +56,7 @@ wget https://paddleocr.bj.bcebos.com/dygraph_v2.0/table/en_ppocr_mobile_v2.0_tab
wget https://paddleocr.bj.bcebos.com/dygraph_v2.0/table/en_ppocr_mobile_v2.0_table_structure_infer.tar && tar xf en_ppocr_mobile_v2.0_table_structure_infer.tar wget https://paddleocr.bj.bcebos.com/dygraph_v2.0/table/en_ppocr_mobile_v2.0_table_structure_infer.tar && tar xf en_ppocr_mobile_v2.0_table_structure_infer.tar
cd .. cd ..
# 执行预测 # 执行预测
python3 table/predict_table.py --det_model_dir=inference/en_ppocr_mobile_v2.0_table_det_infer --rec_model_dir=inference/en_ppocr_mobile_v2.0_table_rec_infer --table_model_dir=inference/en_ppocr_mobile_v2.0_table_structure_infer --image_dir=../doc/table/table.jpg --rec_char_dict_path=../ppocr/utils/dict/table_dict.txt --table_char_dict_path=../ppocr/utils/dict/table_structure_dict.txt --rec_char_dict_path=../ppocr/utils/dict/en_dict.txt --det_limit_side_len=736 --det_limit_type=min --output ../output/table python3 table/predict_table.py --det_model_dir=inference/en_ppocr_mobile_v2.0_table_det_infer --rec_model_dir=inference/en_ppocr_mobile_v2.0_table_rec_infer --table_model_dir=inference/en_ppocr_mobile_v2.0_table_structure_infer --image_dir=../doc/table/table.jpg --rec_char_dict_path=../ppocr/utils/dict/table_dict.txt --table_char_dict_path=../ppocr/utils/dict/table_structure_dict.txt --det_limit_side_len=736 --det_limit_type=min --output ../output/table
``` ```
运行完成后,每张图片的excel表格会保存到output字段指定的目录下 运行完成后,每张图片的excel表格会保存到output字段指定的目录下
......
...@@ -3,7 +3,7 @@ ...@@ -3,7 +3,7 @@
Linux端基础训练预测功能测试的主程序为test_train_python.sh,可以测试基于Python的模型训练、评估等基本功能,包括裁剪、量化、蒸馏训练。 Linux端基础训练预测功能测试的主程序为test_train_python.sh,可以测试基于Python的模型训练、评估等基本功能,包括裁剪、量化、蒸馏训练。
![](./tipc_train.png) ![](./test_tipc/tipc_train.png)
测试链条如上图所示,主要测试内容有带共享权重,自定义OP的模型的正常训练和slim相关功能训练流程是否正常。 测试链条如上图所示,主要测试内容有带共享权重,自定义OP的模型的正常训练和slim相关功能训练流程是否正常。
...@@ -28,23 +28,23 @@ pip3 install -r requirements.txt ...@@ -28,23 +28,23 @@ pip3 install -r requirements.txt
- 模式1:lite_train_lite_infer,使用少量数据训练,用于快速验证训练到预测的走通流程,不验证精度和速度; - 模式1:lite_train_lite_infer,使用少量数据训练,用于快速验证训练到预测的走通流程,不验证精度和速度;
``` ```
bash test_tipc/test_train_python.sh ./test_tipc/ch_ppocr_mobile_v2.0_det/train_infer_python.txt 'lite_train_lite_infer' bash test_tipc/test_train_python.sh ./test_tipc/train_infer_python.txt 'lite_train_lite_infer'
``` ```
- 模式2:whole_train_whole_infer,使用全量数据训练,用于快速验证训练到预测的走通流程,验证模型最终训练精度; - 模式2:whole_train_whole_infer,使用全量数据训练,用于快速验证训练到预测的走通流程,验证模型最终训练精度;
``` ```
bash test_tipc/test_train_python.sh ./test_tipc/ch_ppocr_mobile_v2.0_det/train_infer_python.txt 'whole_train_whole_infer' bash test_tipc/test_train_python.sh ./test_tipc/train_infer_python.txt 'whole_train_whole_infer'
``` ```
如果是运行量化裁剪等训练方式,需要使用不同的配置文件。量化训练的测试指令如下: 如果是运行量化裁剪等训练方式,需要使用不同的配置文件。量化训练的测试指令如下:
``` ```
bash test_tipc/test_train_python.sh ./test_tipc/ch_ppocr_mobile_v2.0_det/train_infer_python_PACT.txt 'lite_train_lite_infer' bash test_tipc/test_train_python.sh ./test_tipc/train_infer_python_PACT.txt 'lite_train_lite_infer'
``` ```
同理,FPGM裁剪的运行方式如下: 同理,FPGM裁剪的运行方式如下:
``` ```
bash test_tipc/test_train_python.sh ./test_tipc/ch_ppocr_mobile_v2.0_det/train_infer_python_FPGM.txt 'lite_train_lite_infer' bash test_tipc/test_train_python.sh ./test_tipc/train_infer_python_FPGM.txt 'lite_train_lite_infer'
``` ```
运行相应指令后,在`test_tipc/output`文件夹下自动会保存运行日志。如'lite_train_lite_infer'模式运行后,在test_tipc/extra_output文件夹有以下文件: 运行相应指令后,在`test_tipc/output`文件夹下自动会保存运行日志。如'lite_train_lite_infer'模式运行后,在test_tipc/extra_output文件夹有以下文件:
......
...@@ -284,7 +284,6 @@ else ...@@ -284,7 +284,6 @@ else
set_amp_config=" " set_amp_config=" "
fi fi
for trainer in ${trainer_list[*]}; do for trainer in ${trainer_list[*]}; do
eval ${env}
flag_quant=False flag_quant=False
if [ ${trainer} = ${pact_key} ]; then if [ ${trainer} = ${pact_key} ]; then
run_train=${pact_trainer} run_train=${pact_trainer}
...@@ -344,6 +343,7 @@ else ...@@ -344,6 +343,7 @@ else
# run eval # run eval
if [ ${eval_py} != "null" ]; then if [ ${eval_py} != "null" ]; then
eval ${env}
set_eval_params1=$(func_set_params "${eval_key1}" "${eval_value1}") set_eval_params1=$(func_set_params "${eval_key1}" "${eval_value1}")
eval_cmd="${python} ${eval_py} ${set_eval_pretrain} ${set_use_gpu} ${set_eval_params1}" eval_cmd="${python} ${eval_py} ${set_eval_pretrain} ${set_use_gpu} ${set_eval_params1}"
eval $eval_cmd eval $eval_cmd
......
...@@ -24,6 +24,7 @@ os.environ["FLAGS_allocator_strategy"] = 'auto_growth' ...@@ -24,6 +24,7 @@ os.environ["FLAGS_allocator_strategy"] = 'auto_growth'
import cv2 import cv2
import copy import copy
import numpy as np import numpy as np
import json
import time import time
import logging import logging
from PIL import Image from PIL import Image
...@@ -92,11 +93,11 @@ class TextSystem(object): ...@@ -92,11 +93,11 @@ class TextSystem(object):
self.draw_crop_rec_res(self.args.crop_res_save_dir, img_crop_list, self.draw_crop_rec_res(self.args.crop_res_save_dir, img_crop_list,
rec_res) rec_res)
filter_boxes, filter_rec_res = [], [] filter_boxes, filter_rec_res = [], []
for box, rec_reuslt in zip(dt_boxes, rec_res): for box, rec_result in zip(dt_boxes, rec_res):
text, score = rec_reuslt text, score = rec_result
if score >= self.drop_score: if score >= self.drop_score:
filter_boxes.append(box) filter_boxes.append(box)
filter_rec_res.append(rec_reuslt) filter_rec_res.append(rec_result)
return filter_boxes, filter_rec_res return filter_boxes, filter_rec_res
...@@ -128,6 +129,9 @@ def main(args): ...@@ -128,6 +129,9 @@ def main(args):
is_visualize = True is_visualize = True
font_path = args.vis_font_path font_path = args.vis_font_path
drop_score = args.drop_score drop_score = args.drop_score
draw_img_save_dir = args.draw_img_save_dir
os.makedirs(draw_img_save_dir, exist_ok=True)
save_results = []
# warm up 10 times # warm up 10 times
if args.warmup: if args.warmup:
...@@ -157,6 +161,14 @@ def main(args): ...@@ -157,6 +161,14 @@ def main(args):
for text, score in rec_res: for text, score in rec_res:
logger.debug("{}, {:.3f}".format(text, score)) logger.debug("{}, {:.3f}".format(text, score))
res = [{
"transcription": rec_res[idx][0],
"points": np.array(dt_boxes[idx]).astype(np.int32).tolist(),
} for idx in range(len(dt_boxes))]
save_pred = os.path.basename(image_file) + "\t" + json.dumps(
res, ensure_ascii=False) + "\n"
save_results.append(save_pred)
if is_visualize: if is_visualize:
image = Image.fromarray(cv2.cvtColor(img, cv2.COLOR_BGR2RGB)) image = Image.fromarray(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))
boxes = dt_boxes boxes = dt_boxes
...@@ -170,8 +182,6 @@ def main(args): ...@@ -170,8 +182,6 @@ def main(args):
scores, scores,
drop_score=drop_score, drop_score=drop_score,
font_path=font_path) font_path=font_path)
draw_img_save_dir = args.draw_img_save_dir
os.makedirs(draw_img_save_dir, exist_ok=True)
if flag: if flag:
image_file = image_file[:-3] + "png" image_file = image_file[:-3] + "png"
cv2.imwrite( cv2.imwrite(
...@@ -185,6 +195,9 @@ def main(args): ...@@ -185,6 +195,9 @@ def main(args):
text_sys.text_detector.autolog.report() text_sys.text_detector.autolog.report()
text_sys.text_recognizer.autolog.report() text_sys.text_recognizer.autolog.report()
with open(os.path.join(draw_img_save_dir, "system_results.txt"), 'w') as f:
f.writelines(save_results)
if __name__ == "__main__": if __name__ == "__main__":
args = utility.parse_args() args = utility.parse_args()
......
...@@ -73,8 +73,8 @@ def main(): ...@@ -73,8 +73,8 @@ def main():
images = paddle.to_tensor(images) images = paddle.to_tensor(images)
preds = model(images) preds = model(images)
post_result = post_process_class(preds) post_result = post_process_class(preds)
for rec_reuslt in post_result: for rec_result in post_result:
logger.info('\t result: {}'.format(rec_reuslt)) logger.info('\t result: {}'.format(rec_result))
logger.info("success!") logger.info("success!")
......
...@@ -21,7 +21,7 @@ import sys ...@@ -21,7 +21,7 @@ import sys
import platform import platform
import yaml import yaml
import time import time
import shutil import datetime
import paddle import paddle
import paddle.distributed as dist import paddle.distributed as dist
from tqdm import tqdm from tqdm import tqdm
...@@ -29,11 +29,10 @@ from argparse import ArgumentParser, RawDescriptionHelpFormatter ...@@ -29,11 +29,10 @@ from argparse import ArgumentParser, RawDescriptionHelpFormatter
from ppocr.utils.stats import TrainingStats from ppocr.utils.stats import TrainingStats
from ppocr.utils.save_load import save_model from ppocr.utils.save_load import save_model
from ppocr.utils.utility import print_dict from ppocr.utils.utility import print_dict, AverageMeter
from ppocr.utils.logging import get_logger from ppocr.utils.logging import get_logger
from ppocr.utils import profiler from ppocr.utils import profiler
from ppocr.data import build_dataloader from ppocr.data import build_dataloader
import numpy as np
class ArgsParser(ArgumentParser): class ArgsParser(ArgumentParser):
...@@ -48,7 +47,8 @@ class ArgsParser(ArgumentParser): ...@@ -48,7 +47,8 @@ class ArgsParser(ArgumentParser):
'--profiler_options', '--profiler_options',
type=str, type=str,
default=None, default=None,
help='The option of profiler, which should be in format \"key1=value1;key2=value2;key3=value3\".' help='The option of profiler, which should be in format ' \
'\"key1=value1;key2=value2;key3=value3\".'
) )
def parse_args(self, argv=None): def parse_args(self, argv=None):
...@@ -99,7 +99,8 @@ def merge_config(config, opts): ...@@ -99,7 +99,8 @@ def merge_config(config, opts):
sub_keys = key.split('.') sub_keys = key.split('.')
assert ( assert (
sub_keys[0] in config sub_keys[0] in config
), "the sub_keys can only be one of global_config: {}, but get: {}, please check your running command".format( ), "the sub_keys can only be one of global_config: {}, but get: " \
"{}, please check your running command".format(
config.keys(), sub_keys[0]) config.keys(), sub_keys[0])
cur = config[sub_keys[0]] cur = config[sub_keys[0]]
for idx, sub_key in enumerate(sub_keys[1:]): for idx, sub_key in enumerate(sub_keys[1:]):
...@@ -145,6 +146,7 @@ def train(config, ...@@ -145,6 +146,7 @@ def train(config,
scaler=None): scaler=None):
cal_metric_during_train = config['Global'].get('cal_metric_during_train', cal_metric_during_train = config['Global'].get('cal_metric_during_train',
False) False)
calc_epoch_interval = config['Global'].get('calc_epoch_interval', 1)
log_smooth_window = config['Global']['log_smooth_window'] log_smooth_window = config['Global']['log_smooth_window']
epoch_num = config['Global']['epoch_num'] epoch_num = config['Global']['epoch_num']
print_batch_step = config['Global']['print_batch_step'] print_batch_step = config['Global']['print_batch_step']
...@@ -160,11 +162,13 @@ def train(config, ...@@ -160,11 +162,13 @@ def train(config,
eval_batch_step = eval_batch_step[1] eval_batch_step = eval_batch_step[1]
if len(valid_dataloader) == 0: if len(valid_dataloader) == 0:
logger.info( logger.info(
'No Images in eval dataset, evaluation during training will be disabled' 'No Images in eval dataset, evaluation during training ' \
'will be disabled'
) )
start_eval_step = 1e111 start_eval_step = 1e111
logger.info( logger.info(
"During the training process, after the {}th iteration, an evaluation is run every {} iterations". "During the training process, after the {}th iteration, " \
"an evaluation is run every {} iterations".
format(start_eval_step, eval_batch_step)) format(start_eval_step, eval_batch_step))
save_epoch_step = config['Global']['save_epoch_step'] save_epoch_step = config['Global']['save_epoch_step']
save_model_dir = config['Global']['save_model_dir'] save_model_dir = config['Global']['save_model_dir']
...@@ -189,10 +193,11 @@ def train(config, ...@@ -189,10 +193,11 @@ def train(config,
start_epoch = best_model_dict[ start_epoch = best_model_dict[
'start_epoch'] if 'start_epoch' in best_model_dict else 1 'start_epoch'] if 'start_epoch' in best_model_dict else 1
train_reader_cost = 0.0
train_run_cost = 0.0
total_samples = 0 total_samples = 0
train_reader_cost = 0.0
train_batch_cost = 0.0
reader_start = time.time() reader_start = time.time()
eta_meter = AverageMeter()
max_iter = len(train_dataloader) - 1 if platform.system( max_iter = len(train_dataloader) - 1 if platform.system(
) == "Windows" else len(train_dataloader) ) == "Windows" else len(train_dataloader)
...@@ -203,7 +208,6 @@ def train(config, ...@@ -203,7 +208,6 @@ def train(config,
config, 'Train', device, logger, seed=epoch) config, 'Train', device, logger, seed=epoch)
max_iter = len(train_dataloader) - 1 if platform.system( max_iter = len(train_dataloader) - 1 if platform.system(
) == "Windows" else len(train_dataloader) ) == "Windows" else len(train_dataloader)
for idx, batch in enumerate(train_dataloader): for idx, batch in enumerate(train_dataloader):
profiler.add_profiler_step(profiler_options) profiler.add_profiler_step(profiler_options)
train_reader_cost += time.time() - reader_start train_reader_cost += time.time() - reader_start
...@@ -214,7 +218,6 @@ def train(config, ...@@ -214,7 +218,6 @@ def train(config,
if use_srn: if use_srn:
model_average = True model_average = True
train_start = time.time()
# use amp # use amp
if scaler: if scaler:
with paddle.amp.auto_cast(): with paddle.amp.auto_cast():
...@@ -242,7 +245,19 @@ def train(config, ...@@ -242,7 +245,19 @@ def train(config,
optimizer.step() optimizer.step()
optimizer.clear_grad() optimizer.clear_grad()
train_run_cost += time.time() - train_start if cal_metric_during_train and epoch % calc_epoch_interval == 0: # only rec and cls need
batch = [item.numpy() for item in batch]
if model_type in ['table', 'kie']:
eval_class(preds, batch)
else:
post_result = post_process_class(preds, batch[1])
eval_class(post_result, batch)
metric = eval_class.get_metric()
train_stats.update(metric)
train_batch_time = time.time() - reader_start
train_batch_cost += train_batch_time
eta_meter.update(train_batch_time)
global_step += 1 global_step += 1
total_samples += len(images) total_samples += len(images)
...@@ -254,16 +269,6 @@ def train(config, ...@@ -254,16 +269,6 @@ def train(config,
stats['lr'] = lr stats['lr'] = lr
train_stats.update(stats) train_stats.update(stats)
if cal_metric_during_train: # only rec and cls need
batch = [item.numpy() for item in batch]
if model_type in ['table', 'kie']:
eval_class(preds, batch)
else:
post_result = post_process_class(preds, batch[1])
eval_class(post_result, batch)
metric = eval_class.get_metric()
train_stats.update(metric)
if vdl_writer is not None and dist.get_rank() == 0: if vdl_writer is not None and dist.get_rank() == 0:
for k, v in train_stats.get().items(): for k, v in train_stats.get().items():
vdl_writer.add_scalar('TRAIN/{}'.format(k), v, global_step) vdl_writer.add_scalar('TRAIN/{}'.format(k), v, global_step)
...@@ -273,19 +278,26 @@ def train(config, ...@@ -273,19 +278,26 @@ def train(config,
(global_step > 0 and global_step % print_batch_step == 0) or (global_step > 0 and global_step % print_batch_step == 0) or
(idx >= len(train_dataloader) - 1)): (idx >= len(train_dataloader) - 1)):
logs = train_stats.log() logs = train_stats.log()
strs = 'epoch: [{}/{}], global_step: {}, {}, avg_reader_cost: {:.5f} s, avg_batch_cost: {:.5f} s, avg_samples: {}, ips: {:.5f}'.format( eta_sec = ((epoch_num + 1 - epoch) * \
epoch, epoch_num, global_step, logs, train_reader_cost / len(train_dataloader) - idx - 1) * eta_meter.avg
print_batch_step, (train_reader_cost + train_run_cost) / eta_sec_format = str(datetime.timedelta(seconds=int(eta_sec)))
print_batch_step, total_samples / print_batch_step, strs = 'epoch: [{}/{}], global_step: {}, {}, avg_reader_cost: ' \
total_samples / (train_reader_cost + train_run_cost)) '{:.5f} s, avg_batch_cost: {:.5f} s, avg_samples: {}, ' \
'ips: {:.5f}, eta: {}'.format(
epoch, epoch_num, global_step, logs,
train_reader_cost / print_batch_step,
train_batch_cost / print_batch_step,
total_samples / print_batch_step,
total_samples / train_batch_cost, eta_sec_format)
logger.info(strs) logger.info(strs)
train_reader_cost = 0.0
train_run_cost = 0.0
total_samples = 0 total_samples = 0
train_reader_cost = 0.0
train_batch_cost = 0.0
# eval # eval
if global_step > start_eval_step and \ if global_step > start_eval_step and \
(global_step - start_eval_step) % eval_batch_step == 0 and dist.get_rank() == 0: (global_step - start_eval_step) % eval_batch_step == 0 \
and dist.get_rank() == 0:
if model_average: if model_average:
Model_Average = paddle.incubate.optimizer.ModelAverage( Model_Average = paddle.incubate.optimizer.ModelAverage(
0.15, 0.15,
...@@ -511,7 +523,7 @@ def preprocess(is_train=False): ...@@ -511,7 +523,7 @@ def preprocess(is_train=False):
config['Global']['distributed'] = dist.get_world_size() != 1 config['Global']['distributed'] = dist.get_world_size() != 1
if config['Global']['use_visualdl']: if config['Global']['use_visualdl'] and dist.get_rank() == 0:
from visualdl import LogWriter from visualdl import LogWriter
save_model_dir = config['Global']['save_model_dir'] save_model_dir = config['Global']['save_model_dir']
vdl_writer_path = '{}/vdl/'.format(save_model_dir) vdl_writer_path = '{}/vdl/'.format(save_model_dir)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册