Commit 5e690264 authored by qq_25193841

Merge branch 'dygraph' into dy1

......@@ -28,7 +28,7 @@ from PyQt5.QtCore import QSize, Qt, QPoint, QByteArray, QTimer, QFileInfo, QPoin
from PyQt5.QtGui import QImage, QCursor, QPixmap, QImageReader
from PyQt5.QtWidgets import QMainWindow, QListWidget, QVBoxLayout, QToolButton, QHBoxLayout, QDockWidget, QWidget, \
QSlider, QGraphicsOpacityEffect, QMessageBox, QListView, QScrollArea, QWidgetAction, QApplication, QLabel, QGridLayout, \
QFileDialog, QListWidgetItem, QComboBox, QDialog
QFileDialog, QListWidgetItem, QComboBox, QDialog, QAbstractItemView
__dir__ = os.path.dirname(os.path.abspath(__file__))
......@@ -242,6 +242,20 @@ class MainWindow(QMainWindow):
self.labelListDock.setFeatures(QDockWidget.NoDockWidgetFeatures)
listLayout.addWidget(self.labelListDock)
# enable labelList drag_drop to adjust bbox order
# set the selection mode to single selection
self.labelList.setSelectionMode(QAbstractItemView.SingleSelection)
# enable dragging
self.labelList.setDragEnabled(True)
# accept drops on the viewport
self.labelList.viewport().setAcceptDrops(True)
# show the indicator for the drop position
self.labelList.setDropIndicatorShown(True)
# set the drag-drop mode to move items; the default is to copy them
self.labelList.setDragDropMode(QAbstractItemView.InternalMove)
# react to the drop
self.labelList.model().rowsMoved.connect(self.drag_drop_happened)
# ================== Detection Box ==================
self.BoxList = QListWidget()
......@@ -589,15 +603,23 @@ class MainWindow(QMainWindow):
self.displayLabelOption.setChecked(settings.get(SETTING_PAINT_LABEL, False))
self.displayLabelOption.triggered.connect(self.togglePaintLabelsOption)
# Add option to enable/disable box index being displayed at the top of bounding boxes
self.displayIndexOption = QAction(getStr('displayIndex'), self)
self.displayIndexOption.setCheckable(True)
self.displayIndexOption.setChecked(settings.get(SETTING_PAINT_INDEX, False))
self.displayIndexOption.triggered.connect(self.togglePaintIndexOption)
self.labelDialogOption = QAction(getStr('labelDialogOption'), self)
self.labelDialogOption.setShortcut("Ctrl+Shift+L")
self.labelDialogOption.setCheckable(True)
self.labelDialogOption.setChecked(settings.get(SETTING_PAINT_LABEL, False))
self.displayIndexOption.setChecked(settings.get(SETTING_PAINT_INDEX, False))
self.labelDialogOption.triggered.connect(self.speedChoose)
self.autoSaveOption = QAction(getStr('autoSaveMode'), self)
self.autoSaveOption.setCheckable(True)
self.autoSaveOption.setChecked(settings.get(SETTING_PAINT_LABEL, False))
self.displayIndexOption.setChecked(settings.get(SETTING_PAINT_INDEX, False))
self.autoSaveOption.triggered.connect(self.autoSaveFunc)
addActions(self.menus.file,
......@@ -606,7 +628,7 @@ class MainWindow(QMainWindow):
addActions(self.menus.help, (showKeys, showSteps, showInfo))
addActions(self.menus.view, (
self.displayLabelOption, self.labelDialogOption,
self.displayLabelOption, self.displayIndexOption, self.labelDialogOption,
None,
hideAll, showAll, None,
zoomIn, zoomOut, zoomOrg, None,
......@@ -964,9 +986,10 @@ class MainWindow(QMainWindow):
else:
self.canvas.selectedShapes_hShape = self.canvas.selectedShapes
for shape in self.canvas.selectedShapes_hShape:
item = self.shapesToItemsbox[shape] # listitem
text = [(int(p.x()), int(p.y())) for p in shape.points]
item.setText(str(text))
if shape in self.shapesToItemsbox.keys():
item = self.shapesToItemsbox[shape] # listitem
text = [(int(p.x()), int(p.y())) for p in shape.points]
item.setText(str(text))
self.actions.undo.setEnabled(True)
self.setDirty()
......@@ -1040,6 +1063,8 @@ class MainWindow(QMainWindow):
def addLabel(self, shape):
shape.paintLabel = self.displayLabelOption.isChecked()
shape.paintIdx = self.displayIndexOption.isChecked()
item = HashableQListWidgetItem(shape.label)
item.setFlags(item.flags() | Qt.ItemIsUserCheckable)
item.setCheckState(Qt.Unchecked) if shape.difficult else item.setCheckState(Qt.Checked)
......@@ -1083,6 +1108,7 @@ class MainWindow(QMainWindow):
def loadLabels(self, shapes):
s = []
shape_index = 0
for label, points, line_color, key_cls, difficult in shapes:
shape = Shape(label=label, line_color=line_color, key_cls=key_cls)
for x, y in points:
......@@ -1094,6 +1120,8 @@ class MainWindow(QMainWindow):
shape.addPoint(QPointF(x, y))
shape.difficult = difficult
shape.idx = shape_index
shape_index += 1
# shape.locked = False
shape.close()
s.append(shape)
......@@ -1209,18 +1237,54 @@ class MainWindow(QMainWindow):
self.canvas.deSelectShape()
def labelItemChanged(self, item):
shape = self.itemsToShapes[item]
label = item.text()
if label != shape.label:
shape.label = item.text()
# shape.line_color = generateColorByText(shape.label)
self.setDirty()
elif not ((item.checkState() == Qt.Unchecked) ^ (not shape.difficult)):
shape.difficult = True if item.checkState() == Qt.Unchecked else False
self.setDirty()
else: # User probably changed item visibility
self.canvas.setShapeVisible(shape, True) # item.checkState() == Qt.Checked
# self.actions.save.setEnabled(True)
# avoid accidentally triggering the itemChanged signal with an unhashable item
# Unknown trigger condition
if type(item) == HashableQListWidgetItem:
shape = self.itemsToShapes[item]
label = item.text()
if label != shape.label:
shape.label = item.text()
# shape.line_color = generateColorByText(shape.label)
self.setDirty()
elif not ((item.checkState() == Qt.Unchecked) ^ (not shape.difficult)):
shape.difficult = True if item.checkState() == Qt.Unchecked else False
self.setDirty()
else: # User probably changed item visibility
self.canvas.setShapeVisible(shape, True) # item.checkState() == Qt.Checked
# self.actions.save.setEnabled(True)
else:
print('enter labelItemChanged slot with unhashable item: ', item, item.text())
def drag_drop_happened(self):
'''
label list drag drop signal slot
'''
# print('___________________drag_drop_happened_______________')
# should only select single item
for item in self.labelList.selectedItems():
newIndex = self.labelList.indexFromItem(item).row()
# only support drag_drop one item
assert len(self.canvas.selectedShapes) > 0
for shape in self.canvas.selectedShapes:
selectedShapeIndex = shape.idx
if newIndex == selectedShapeIndex:
return
# move corresponding item in shape list
shape = self.canvas.shapes.pop(selectedShapeIndex)
self.canvas.shapes.insert(newIndex, shape)
# update bbox index
self.canvas.updateShapeIndex()
# boxList update simultaneously
item = self.BoxList.takeItem(selectedShapeIndex)
self.BoxList.insertItem(newIndex, item)
# changes happen
self.setDirty()
# Callback functions:
def newShape(self, value=True):
......@@ -1560,6 +1624,7 @@ class MainWindow(QMainWindow):
settings[SETTING_LAST_OPEN_DIR] = ''
settings[SETTING_PAINT_LABEL] = self.displayLabelOption.isChecked()
settings[SETTING_PAINT_INDEX] = self.displayIndexOption.isChecked()
settings[SETTING_DRAW_SQUARE] = self.drawSquaresOption.isChecked()
settings.save()
try:
......@@ -1946,8 +2011,16 @@ class MainWindow(QMainWindow):
self.labelHist.append(line)
def togglePaintLabelsOption(self):
self.displayIndexOption.setChecked(False)
for shape in self.canvas.shapes:
shape.paintLabel = self.displayLabelOption.isChecked()
shape.paintIdx = self.displayIndexOption.isChecked()
def togglePaintIndexOption(self):
self.displayLabelOption.setChecked(False)
for shape in self.canvas.shapes:
shape.paintLabel = self.displayLabelOption.isChecked()
shape.paintIdx = self.displayIndexOption.isChecked()
def toogleDrawSquare(self):
self.canvas.setDrawingShapeToSquare(self.drawSquaresOption.isChecked())
......@@ -2187,6 +2260,7 @@ class MainWindow(QMainWindow):
shapes = []
result_len = len(region['res']['boxes'])
order_index = 0
for i in range(result_len):
bbox = np.array(region['res']['boxes'][i])
rec_text = region['res']['rec_res'][i][0]
......@@ -2205,6 +2279,8 @@ class MainWindow(QMainWindow):
x, y, snapped = self.canvas.snapPointToCanvas(x, y)
shape.addPoint(QPointF(x, y))
shape.difficult = False
shape.idx = order_index
order_index += 1
# shape.locked = False
shape.close()
self.addLabel(shape)
......
......@@ -2,7 +2,7 @@ English | [简体中文](README_ch.md)
# PPOCRLabel
PPOCRLabel is a semi-automatic graphic annotation tool suitable for OCR field, with built-in PPOCR model to automatically detect and re-recognize data. It is written in python3 and pyqt5, supporting rectangular box, table and multi-point annotation modes. Annotations can be directly used for the training of PPOCR detection and recognition models.
PPOCRLabel is a semi-automatic graphic annotation tool suitable for OCR field, with built-in PP-OCR model to automatically detect and re-recognize data. It is written in python3 and pyqt5, supporting rectangular box, table and multi-point annotation modes. Annotations can be directly used for the training of PP-OCR detection and recognition models.
<img src="./data/gif/steps_en.gif" width="100%"/>
......@@ -142,14 +142,18 @@ In PPOCRLabel, complete the text information labeling (text and position), compl
labeling in the Excel file, the recommended steps are:
1. Table annotation: After opening the table picture, click on the `Table Recognition` button in the upper right corner of PPOCRLabel, which will call the table recognition model in PP-Structure to automatically label
the table and pop up Excel at the same time.
the table and pop up Excel at the same time.
2. Change the recognition result: **label each cell** (i.e. the text in a cell is marked as a box). Right click on the box and click on `Cell Re-recognition`.
You can use the model to automatically recognize the text within a cell.
3. Mark the table structure: for each cell that contains text, **mark it with any identifier (such as `1`) in Excel** to ensure that the merged-cell structure is the same as in the original picture.
4. Export JSON format annotation: close all Excel files corresponding to table images, click `File`-`Export table JSON annotation` to obtain JSON annotation results.
> Note: If there are blank cells in the table, you also need to mark them with a bounding box so that the total number of cells is the same as in the image.
4. ***Adjust cell order:*** Click on the menu `View` - `Show Box Number` to show the box ordinal numbers, and drag all the results under the 'Recognition Results' column on the right side of the software interface so that the box numbers are arranged from left to right and top to bottom.
5. Export JSON format annotation: close all Excel files corresponding to table images, click `File`-`Export table JSON annotation` to obtain JSON annotation results.
### 2.3 Note
......@@ -219,14 +223,7 @@ PPOCRLabel supports three ways to export Label.txt
- Close application export
### 3.4 Export Partial Recognition Results
For some data that are difficult to recognize, **uncheck** the corresponding tag in the recognition-results checkbox and the recognition result will not be exported. The unchecked recognition result is saved as `True` in the `difficult` variable in the label file `label.txt`.
> *Note: The status of the checkboxes in the recognition results still needs to be saved manually by clicking Save Button.*
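For reference, a hypothetical `Label.txt` entry with this flag set might look like the following (the path, text, and coordinates are illustrative only):

```
train/0001.jpg	[{"transcription": "MASA", "points": [[310, 104], [416, 141], [418, 216], [312, 179]], "difficult": true}]
```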
### 3.5 Dataset division
### 3.4 Dataset division
- Enter the following command in the terminal to execute the dataset division script:
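As in the Chinese documentation, a typical invocation looks like the following (the split ratio and dataset root path are examples to adjust for your data):

```bash
python gen_ocr_train_val_test.py --trainValTestRatio 6:2:2 --datasetRootPath ../train_data
```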
......@@ -255,7 +252,7 @@ For some data that are difficult to recognize, the recognition results will not
| ...
```
### 3.6 Error message
### 3.5 Error message
- If paddleocr is installed with whl, it has a higher priority than calling PaddleOCR class with paddleocr.py, which may cause an exception if whl package is not updated.
......
......@@ -7,8 +7,8 @@ PPOCRLabel是一款适用于OCR领域的半自动化图形标注工具,内置P
<img src="./data/gif/steps.gif" width="100%"/>
#### 近期更新
- 2022.05:新增表格标注,使用方法见下方`2.2 表格标注`(by [whjdark](https://github.com/peterh0323); [Evezerest](https://github.com/Evezerest))
- 2022.02:新增关键信息标注、优化标注体验(by [PeterH0323](https://github.com/peterh0323))
- 2022.05:**新增表格标注**,使用方法见下方`2.2 表格标注`(by [whjdark](https://github.com/peterh0323); [Evezerest](https://github.com/Evezerest))
- 2022.02:**新增关键信息标注**、优化标注体验(by [PeterH0323](https://github.com/peterh0323))
- 新增:使用 `--kie` 进入 KIE 功能,用于打【检测+识别+关键字提取】的标签
- 提升用户体验:新增文件与标记数目提示、优化交互、修复gpu使用等问题。
- 新增功能:使用 `C` 和 `X` 对标记框进行旋转。
......@@ -113,23 +113,29 @@ pip3 install dist/PPOCRLabel-1.0.2-py2.py3-none-any.whl -i https://mirror.baidu.
1. 安装与运行:使用上述命令安装与运行程序。
2. 打开文件夹:在菜单栏点击 “文件” - "打开目录" 选择待标记图片的文件夹<sup>[1]</sup>.
3. 自动标注:点击 ”自动标注“,使用PPOCR超轻量模型对图片文件名前图片状态<sup>[2]</sup>为 “X” 的图片进行自动标注。
3. 自动标注:点击 ”自动标注“,使用PP-OCR超轻量模型对图片文件名前图片状态<sup>[2]</sup>为 “X” 的图片进行自动标注。
4. 手动标注:点击 “矩形标注”(推荐直接在英文模式下点击键盘中的 “W”),用户可对当前图片中模型未检出的部分进行手动绘制标记框。点击键盘Q,则使用四点标注模式(或点击“编辑” - “四点标注”),用户依次点击4个点后,双击左键表示标注完成。
5. 标记框绘制完成后,用户点击 “确认”,检测框会先被预分配一个 “待识别” 标签。
6. 重新识别:将图片中的所有检测框绘制/调整完成后,点击 “重新识别”,PPOCR模型会对当前图片中的**所有检测框**重新识别<sup>[3]</sup>。
6. 重新识别:将图片中的所有检测框绘制/调整完成后,点击 “重新识别”,PP-OCR模型会对当前图片中的**所有检测框**重新识别<sup>[3]</sup>。
7. 内容更改:单击识别结果,对不准确的识别结果进行手动更改。
8. **确认标记:点击 “确认”,图片状态切换为 “√”,跳转至下一张。**
9. 删除:点击 “删除图像”,图片将会被删除至回收站。
10. 导出结果:用户可以通过菜单中“文件-导出标记结果”手动导出,同时也可以点击“文件 - 自动导出标记结果”开启自动导出。手动确认过的标记将会被存放在所打开图片文件夹下的*Label.txt*中。在菜单栏点击 “文件” - "导出识别结果"后,会将此类图片的识别训练数据保存在*crop_img*文件夹下,识别标签保存在*rec_gt.txt*<sup>[4]</sup>
### 2.2 表格标注
表格标注针对表格的结构化提取,将图片中的表格转换为Excel格式,因此标注时需要配合外部软件打开Excel同时完成。
在PPOCRLabel软件中完成表格中的文字信息标注(文字与位置)、在Excel文件中完成表格结构信息标注,推荐的步骤为:
表格标注针对表格的结构化提取,将图片中的表格转换为Excel格式,因此标注时需要配合外部软件打开Excel同时完成。在PPOCRLabel软件中完成表格中的文字信息标注(文字与位置)、在Excel文件中完成表格结构信息标注,推荐的步骤为:
1. 表格识别:打开表格图片后,点击软件右上角 `表格识别` 按钮,软件调用PP-Structure中的表格识别模型,自动为表格打标签,同时弹出Excel
2. 更改识别结果:**以表格中的单元格为单位增加标注框**(即一个单元格内的文字都标记为一个框)。标注框上鼠标右键后点击 `单元格重识别`
2. 更改标注结果:**以表格中的单元格为单位增加标注框**(即一个单元格内的文字都标记为一个框)。标注框上鼠标右键后点击 `单元格重识别`
可利用模型自动识别单元格内的文字。
3. 标注表格结构:将表格图像中有文字的单元格,**在Excel中标记为任意标识符(如`1`)**,保证Excel中的单元格合并情况与原图相同即可。
4. 导出JSON格式:关闭所有表格图像对应的Excel,点击 `文件`-`导出表格JSON标注` 获得JSON标注结果。
> 注意:如果表格中存在空白单元格,同样需要使用一个标注框将其标出,使得单元格总数与图像中保持一致。
3. **调整单元格顺序:**点击软件`视图-显示框编号` 打开标注框序号,在软件界面右侧拖动 `识别结果` 一栏下的所有结果,使得标注框编号按照从左到右,从上到下的顺序排列
4. 标注表格结构:**在外部Excel软件中,将存在文字的单元格标记为任意标识符(如 `1` )**,保证Excel中的单元格合并情况与原图相同即可(即不需要Excel中的单元格文字与图片中的文字完全相同)
5. 导出JSON格式:关闭所有表格图像对应的Excel,点击 `文件`-`导出表格JSON标注` 获得JSON标注结果。
### 2.3 注意
......@@ -197,13 +203,7 @@ PPOCRLabel支持三种导出方式:
- 关闭应用程序导出
### 3.4 导出部分识别结果
针对部分难以识别的数据,通过在识别结果的复选框中**取消勾选**相应的标记,其识别结果不会被导出。被取消勾选的识别结果在标记文件 `label.txt` 中的 `difficult` 变量保存为 `True`
> *注意:识别结果中的复选框状态仍需用户手动点击确认后才能保留*
### 3.5 数据集划分
### 3.4 数据集划分
在终端中输入以下命令执行数据集划分脚本:
......@@ -232,7 +232,7 @@ python gen_ocr_train_val_test.py --trainValTestRatio 6:2:2 --datasetRootPath ../
| ...
```
### 3.6 错误提示
### 3.5 错误提示
- 如果同时使用whl包安装了paddleocr,其优先级大于通过paddleocr.py调用PaddleOCR类,whl包未更新时会导致程序异常。
......
......@@ -314,21 +314,23 @@ class Canvas(QWidget):
QApplication.restoreOverrideCursor() # ?
if self.movingShape and self.hShape:
index = self.shapes.index(self.hShape)
if (
self.shapesBackups[-1][index].points
!= self.shapes[index].points
):
self.storeShapes()
self.shapeMoved.emit() # connect to updateBoxlist in PPOCRLabel.py
if self.hShape in self.shapes:
index = self.shapes.index(self.hShape)
if (
self.shapesBackups[-1][index].points
!= self.shapes[index].points
):
self.storeShapes()
self.shapeMoved.emit() # connect to updateBoxlist in PPOCRLabel.py
self.movingShape = False
self.movingShape = False
def endMove(self, copy=False):
assert self.selectedShapes and self.selectedShapesCopy
assert len(self.selectedShapesCopy) == len(self.selectedShapes)
if copy:
for i, shape in enumerate(self.selectedShapesCopy):
shape.idx = len(self.shapes) # add current box index
self.shapes.append(shape)
self.selectedShapes[i].selected = False
self.selectedShapes[i] = shape
......@@ -524,6 +526,9 @@ class Canvas(QWidget):
self.storeShapes()
self.selectedShapes = []
self.update()
self.updateShapeIndex()
return deleted_shapes
def storeShapes(self):
......@@ -619,6 +624,13 @@ class Canvas(QWidget):
pal.setColor(self.backgroundRole(), QColor(232, 232, 232, 255))
self.setPalette(pal)
# adaptive BBOX label & index font size
if self.pixmap:
h, w = self.pixmap.size().height(), self.pixmap.size().width()
fontsize = int(max(h, w) / 48)
for s in self.shapes:
s.fontsize = fontsize
p.end()
def fillDrawing(self):
......@@ -651,7 +663,8 @@ class Canvas(QWidget):
return
self.current.close()
self.shapes.append(self.current)
self.current.idx = len(self.shapes) # add current box index
self.shapes.append(self.current)
self.current = None
self.setHiding(False)
self.newShape.emit()
......@@ -842,6 +855,7 @@ class Canvas(QWidget):
self.hVertex = None
# self.hEdge = None
self.storeShapes()
self.updateShapeIndex()
self.repaint()
def setShapeVisible(self, shape, value):
......@@ -883,10 +897,16 @@ class Canvas(QWidget):
self.selectedShapes = []
for shape in self.shapes:
shape.selected = False
self.updateShapeIndex()
self.repaint()
@property
def isShapeRestorable(self):
if len(self.shapesBackups) < 2:
return False
return True
\ No newline at end of file
return True
def updateShapeIndex(self):
for i in range(len(self.shapes)):
self.shapes[i].idx = i
self.update()
\ No newline at end of file
......@@ -21,6 +21,7 @@ SETTING_ADVANCE_MODE = 'advanced'
SETTING_WIN_STATE = 'window/state'
SETTING_SAVE_DIR = 'savedir'
SETTING_PAINT_LABEL = 'paintlabel'
SETTING_PAINT_INDEX = 'paintindex'
SETTING_LAST_OPEN_DIR = 'lastOpenDir'
SETTING_AUTO_SAVE = 'autosave'
SETTING_SINGLE_CLASS = 'singleclass'
......
......@@ -26,4 +26,4 @@ class EditInList(QListWidget):
def leaveEvent(self, event):
# close edit
for i in range(self.count()):
self.closePersistentEditor(self.item(i))
self.closePersistentEditor(self.item(i))
\ No newline at end of file
......@@ -46,15 +46,16 @@ class Shape(object):
point_size = 8
scale = 1.0
def __init__(self, label=None, line_color=None, difficult=False, key_cls="None", paintLabel=False):
def __init__(self, label=None, line_color=None, difficult=False, key_cls="None", paintLabel=False, paintIdx=False):
self.label = label
self.idx = 0
self.idx = None # bbox order, only for table annotation
self.points = []
self.fill = False
self.selected = False
self.difficult = difficult
self.key_cls = key_cls
self.paintLabel = paintLabel
self.paintIdx = paintIdx
self.locked = False
self.direction = 0
self.center = None
......@@ -65,6 +66,7 @@ class Shape(object):
self.NEAR_VERTEX: (4, self.P_ROUND),
self.MOVE_VERTEX: (1.5, self.P_SQUARE),
}
self.fontsize = 8
self._closed = False
......@@ -155,7 +157,7 @@ class Shape(object):
min_y = min(min_y, point.y())
if min_x != sys.maxsize and min_y != sys.maxsize:
font = QFont()
font.setPointSize(8)
font.setPointSize(self.fontsize)
font.setBold(True)
painter.setFont(font)
if self.label is None:
......@@ -164,6 +166,25 @@ class Shape(object):
min_y += MIN_Y_LABEL
painter.drawText(min_x, min_y, self.label)
# Draw the box index at the top-left corner
if self.paintIdx:
min_x = sys.maxsize
min_y = sys.maxsize
for point in self.points:
min_x = min(min_x, point.x())
min_y = min(min_y, point.y())
if min_x != sys.maxsize and min_y != sys.maxsize:
font = QFont()
font.setPointSize(self.fontsize)
font.setBold(True)
painter.setFont(font)
text = ''
if self.idx is not None:
text = str(self.idx)
if min_y < MIN_Y_LABEL:
min_y += MIN_Y_LABEL
painter.drawText(min_x, min_y, text)
if self.fill:
color = self.select_fill_color if self.selected else self.fill_color
painter.fillPath(line_path, color)
......
......@@ -61,6 +61,7 @@ labels=Labels
autoSaveMode=Auto Save mode
singleClsMode=Single Class Mode
displayLabel=Display Labels
displayIndex=Display box index
fileList=File List
files=Files
advancedMode=Advanced Mode
......
......@@ -61,6 +61,7 @@ labels=标签
autoSaveMode=自动保存模式
singleClsMode=单一类别模式
displayLabel=显示类别
displayIndex=显示box序号
fileList=文件列表
files=文件
advancedMode=专家模式
......
......@@ -72,6 +72,7 @@ PaddleOCR support a variety of cutting-edge algorithms related to OCR, and devel
<img src="https://raw.githubusercontent.com/PaddlePaddle/PaddleOCR/dygraph/doc/joinus.PNG" width = "200" height = "200" />
</div>
<a name="Supported-Chinese-model-list"></a>
## PP-OCR Series Model List (updated on September 8th)
| Model introduction | Model name | Recommended scene | Detection model | Direction classifier | Recognition model |
......
......@@ -71,6 +71,8 @@ PaddleOCR旨在打造一套丰富、领先、且实用的OCR工具库,助力
## 《动手学OCR》电子书
- [《动手学OCR》电子书📚](./doc/doc_ch/ocr_book.md)
## 场景应用
- PaddleOCR场景应用覆盖通用、制造、金融、交通等行业的主要OCR垂类应用,在PP-OCR、PP-Structure的通用能力基础之上,以notebook的形式展示利用场景数据微调、模型优化方法、数据增广等内容,为开发者快速落地OCR应用提供示范与启发。详情可查看[README](./applications)。
<a name="开源社区"></a>
## 开源社区
......
# 一种基于PaddleOCR的产品包装生产日期识别模型
- [1. 项目介绍](#1-项目介绍)
- [2. 环境搭建](#2-环境搭建)
- [3. 数据准备](#3-数据准备)
- [4. 直接使用PP-OCRv3模型评估](#4-直接使用PPOCRv3模型评估)
- [5. 基于合成数据finetune](#5-基于合成数据finetune)
- [5.1 Text Renderer数据合成方法](#51-TextRenderer数据合成方法)
- [5.1.1 下载Text Renderer代码](#511-下载TextRenderer代码)
- [5.1.2 准备背景图片](#512-准备背景图片)
- [5.1.3 准备语料](#513-准备语料)
- [5.1.4 下载字体](#514-下载字体)
- [5.1.5 运行数据合成命令](#515-运行数据合成命令)
- [5.2 模型训练](#52-模型训练)
- [6. 基于真实数据finetune](#6-基于真实数据finetune)
- [6.1 python爬虫获取数据](#61-python爬虫获取数据)
- [6.2 数据挖掘](#62-数据挖掘)
- [6.3 模型训练](#63-模型训练)
- [7. 基于合成+真实数据finetune](#7-基于合成+真实数据finetune)
## 1. 项目介绍
产品包装生产日期是计算机视觉图像识别技术在工业场景中的一种应用。产品包装生产日期识别技术要求能够将产品生产日期从复杂背景中提取并识别出来,在物流管理、物资管理中得到广泛应用。
![](https://ai-studio-static-online.cdn.bcebos.com/d9e0533cc1df47ffa3bbe99de9e42639a3ebfa5bce834bafb1ca4574bf9db684)
- 项目难点
1. 没有训练数据
2. 图像质量参差不齐: 角度倾斜、图片模糊、光照不足、过曝等问题严重
针对以上问题, 本例选用PP-OCRv3这一开源超轻量OCR系统进行包装产品生产日期识别系统的开发。直接使用PP-OCRv3进行评估的精度为62.99%。为提升识别精度,我们首先使用数据合成工具合成了3k数据,基于这部分数据进行finetune,识别精度提升至73.66%。由于合成数据与真实数据之间的分布存在差异,为进一步提升精度,我们使用网络爬虫配合数据挖掘策略得到了1k带标签的真实数据,基于真实数据finetune的精度为71.33%。最后,我们综合使用合成数据和真实数据进行finetune,将识别精度提升至86.99%。各策略的精度提升效果如下:
| 策略 | 精度|
| :--------------- | :-------- |
| PP-OCRv3评估 | 62.99|
| 合成数据finetune | 73.66|
| 真实数据finetune | 71.33|
| 真实+合成数据finetune | 86.99|
AIStudio项目链接: [一种基于PaddleOCR的包装生产日期识别方法](https://aistudio.baidu.com/aistudio/projectdetail/4287736)
## 2. 环境搭建
本任务基于Aistudio完成, 具体环境如下:
- 操作系统: Linux
- PaddlePaddle: 2.3
- PaddleOCR: Release/2.5
- text_renderer: master
下载PaddleOCR代码并安装依赖库:
```bash
git clone -b dygraph https://gitee.com/paddlepaddle/PaddleOCR
# 安装依赖库
cd PaddleOCR
pip install -r requirements.txt
```
## 3. 数据准备
本项目使用人工预标注的300张图像作为测试集。
部分数据示例如下:
![](https://ai-studio-static-online.cdn.bcebos.com/39ff30e0ab0442579712255e6a9ea6b5271169c98e624e6eb2b8781f003bfea0)
标签文件格式如下:
```txt
数据路径 标签(中间以制表符分隔)
```
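例如,一个假设的标注文件内容如下(图片路径与标签内容仅为示意,两列之间以制表符分隔):

```txt
data/val/0001.jpg	2022年05月03日
data/val/0002.jpg	20220503 A(1)B
```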
|数据集类型|数量|
|---|---|
|测试集| 300|
数据集[下载链接](https://aistudio.baidu.com/aistudio/datasetdetail/149770),下载后可以通过下方命令解压:
```bash
tar -xvf data.tar
mv data ${PaddleOCR_root}
```
数据解压后的文件结构如下:
```shell
PaddleOCR
├── data
│ ├── mining_images # 挖掘的真实数据示例
│ ├── mining_train.list # 挖掘的真实数据文件列表
│ ├── render_images # 合成数据示例
│ ├── render_train.list # 合成数据文件列表
│ ├── val # 测试集数据
│ └── val.list # 测试集数据文件列表
│   ├── bg               # 合成数据所需背景图像
│ └── corpus # 合成数据所需语料
```
## 4. 直接使用PP-OCRv3模型评估
准备好测试数据后,可以使用PaddleOCR的PP-OCRv3模型进行识别。
- 下载预训练模型
首先需要下载PP-OCR v3中英文识别模型文件,下载链接可以在https://github.com/PaddlePaddle/PaddleOCR/blob/release/2.5/doc/doc_ch/ppocr_introduction.md#6 获取,下载命令:
```bash
cd ${PaddleOCR_root}
mkdir ckpt
wget -nc -P ckpt https://paddleocr.bj.bcebos.com/PP-OCRv3/chinese/ch_PP-OCRv3_rec_train.tar
pushd ckpt/
tar -xvf ch_PP-OCRv3_rec_train.tar
popd
```
- 模型评估
使用以下命令进行PP-OCRv3评估:
```bash
python tools/eval.py -c configs/rec/PP-OCRv3/ch_PP-OCRv3_rec_distillation.yml \
-o Global.checkpoints=ckpt/ch_PP-OCRv3_rec_train/best_accuracy \
Eval.dataset.data_dir=./data \
Eval.dataset.label_file_list=["./data/val.list"]
```
其中各参数含义如下:
```bash
-c: 指定使用的配置文件,ch_PP-OCRv3_rec_distillation.yml对应于OCRv3识别模型。
-o: 覆盖配置文件中参数
Global.checkpoints: 指定评估使用的模型文件路径
Eval.dataset.data_dir: 指定评估数据集路径
Eval.dataset.label_file_list: 指定评估数据集文件列表
```
## 5. 基于合成数据finetune
### 5.1 Text Renderer数据合成方法
#### 5.1.1 下载Text Renderer代码
首先从github或gitee下载Text Renderer代码,并安装相关依赖。
```bash
git clone https://gitee.com/wowowoll/text_renderer.git
# 安装依赖库
cd text_renderer
pip install -r requirements.txt
```
使用text renderer合成数据之前需要准备好背景图片、语料以及字体库,下面将逐一介绍各个步骤。
#### 5.1.2 准备背景图片
观察日常生活中常见的包装生产日期图片,我们可以发现其背景相对简单。为此我们可以从网上找一些图片,截取部分图像块作为背景图像。
本项目已准备了部分图像作为背景图片,在第3部分完成数据准备后,可以得到我们准备好的背景图像,示例如下:
![](https://ai-studio-static-online.cdn.bcebos.com/456ae2acb27d4a94896c478812aee0bc3551c703d7bd40c9be4dc983c7b3fc8a)
背景图像存放于如下位置:
```shell
PaddleOCR
├── data
| ├── bg # 合成数据所需背景图像
```
#### 5.1.3 准备语料
观察测试集生产日期图像,我们可以发现数据有如下特点:
1. 由年月日组成,中间可能以“/”、“-”、“:”、“.”或者空格间隔,也可能以汉字年月日分隔
2. 有些生产日期包含在产品批号中,此时可能包含具体时间、英文字母或数字标识
基于以上两点,我们编写语料生成脚本:
```python
import random
from random import choice
import os

cropus_num = 2000  # 设置语料数量

def get_cropus(f):
    # 随机生成年份(2000-2022)
    year = random.randint(0, 22)
    # 随机生成月份
    month = random.randint(1, 12)
    # 随机生成日期
    day_dict = {31: [1, 3, 5, 7, 8, 10, 12], 30: [4, 6, 9, 11], 28: [2]}
    for item in day_dict:
        if month in day_dict[item]:
            day = random.randint(1, item)
    # 随机生成小时
    hours = random.randint(0, 23)
    # 随机生成分钟
    minute = random.randint(0, 59)
    # 随机生成秒数
    second = random.randint(0, 59)

    # 随机生成产品标识字符
    length = random.randint(0, 6)
    file_id = []
    flag = 0
    my_dict = [i for i in range(48, 58)] + [j for j in range(40, 42)] + [k for k in range(65, 90)]  # 数字 + 括号 + 大写字母
    for i in range(1, length):
        if flag:
            if i == flag + 2:  # 括号匹配
                file_id.append(')')
                flag = 0
                continue
        sel = choice(my_dict)
        if sel == 41:
            continue
        if sel == 40:
            if i == 1 or i > length - 3:
                continue
            flag = i
        my_ascii = chr(sel)
        file_id.append(my_ascii)
    file_id_str = ''.join(file_id)

    # 随机生成单个数字标识
    file_id2 = random.randint(0, 9)
    rad = random.random()
    if rad < 0.3:
        f.write('20{:02d}{:02d}{:02d} {}'.format(year, month, day, file_id_str))
    elif 0.3 < rad < 0.5:
        f.write('20{:02d}年{:02d}月{:02d}日'.format(year, month, day))
    elif 0.5 < rad < 0.7:
        f.write('20{:02d}/{:02d}/{:02d}'.format(year, month, day))
    elif 0.7 < rad < 0.8:
        f.write('20{:02d}-{:02d}-{:02d}'.format(year, month, day))
    elif 0.8 < rad < 0.9:
        f.write('20{:02d}.{:02d}.{:02d}'.format(year, month, day))
    else:
        f.write('{:02d}:{:02d}:{:02d} {:02d}'.format(hours, minute, second, file_id2))

if __name__ == "__main__":
    file_path = '/home/aistudio/text_renderer/my_data/cropus'
    if not os.path.exists(file_path):
        os.makedirs(file_path)
    file_name = os.path.join(file_path, 'books.txt')
    f = open(file_name, 'w')
    for i in range(cropus_num):
        get_cropus(f)
        if i < cropus_num - 1:
            f.write('\n')
    f.close()
```
本项目已准备了部分语料,在第3部分完成数据准备后,可以得到我们准备好的语料库,默认位置如下:
```shell
PaddleOCR
├── data
│ └── corpus #合成数据所需语料
```
#### 5.1.4 下载字体
观察包装生产日期,我们可以发现其使用的字体为点阵体。字体可以在如下网址下载:
https://www.fonts.net.cn/fonts-en/tag-dianzhen-1.html
本项目已准备了部分字体,在第3部分完成数据准备后,可以得到我们准备好的字体,默认位置如下:
```shell
PaddleOCR
├── data
│ └── fonts #合成数据所需字体
```
下载好字体后,还需要在list文件中指定字体文件存放路径,脚本如下:
```bash
cd text_renderer/my_data/
touch fonts.list
ls /home/aistudio/PaddleOCR/data/fonts/* > fonts.list
```
#### 5.1.5 运行数据合成命令
完成数据准备后,my_data文件结构如下:
```shell
my_data/
├── cropus
│ └── books.txt #语料库
├── eng.txt #字符列表
└── fonts.list #字体列表
```
在运行合成数据命令之前,还有两处细节需要手动修改:
1. 将默认配置文件`text_renderer/configs/default.yaml`中第9行enable的值设为`true`,即允许合成彩色图像。否则合成的都是灰度图。
```yaml
# color boundary is in R,G,B format
font_color:
+ enable: true #false
```
2. `text_renderer/textrenderer/renderer.py` 第184行作如下修改,取消padding,否则图片两端会有一些空白。
```python
padding = random.randint(s_bbox_width // 10, s_bbox_width // 8) #修改前
padding = 0 #修改后
```
运行数据合成命令:
```bash
cd /home/aistudio/text_renderer/
python main.py --num_img=3000 \
--fonts_list='./my_data/fonts.list' \
--corpus_dir "./my_data/cropus" \
--corpus_mode "list" \
--bg_dir "/home/aistudio/PaddleOCR/data/bg/" \
--img_width 0
```
合成好的数据默认保存在`text_renderer/output`目录下,可进入该目录查看合成的数据。
合成数据示例如下
![](https://ai-studio-static-online.cdn.bcebos.com/d686a48d465a43d09fbee51924fdca42ee21c50e676646da8559fb9967b94185)
数据合成好后,还需要生成如下格式的训练所需的标注文件:
```
图像路径 标签
```
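例如,生成的标注文件中每行形如下面的假设示例(路径与标签仅为示意,中间以制表符分隔):

```
/home/aistudio/text_renderer/output/default/00000000.jpg	20220503
/home/aistudio/text_renderer/output/default/00000001.jpg	2022年05月03日
```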
使用如下脚本即可生成标注文件:
```python
abspath = '/home/aistudio/text_renderer/output/default/'

# 标注文件生成路径
fout = open('./render_train.list', 'w', encoding='utf-8')
with open('./output/default/tmp_labels.txt', 'r') as f:
    lines = f.readlines()
    for item in lines:
        label = item[9:]
        filename = item[:8] + '.jpg'
        fout.write(abspath + filename + '\t' + label)
fout.close()
```
经过以上步骤,我们便完成了包装生产日期数据合成。
数据位于`text_renderer/output`,标注文件位于`text_renderer/render_train.list`
本项目提供了生成好的数据供大家体验,完成步骤3的数据准备后,即可在如下路径获取数据:
```shell
PaddleOCR
├── data
│ ├── render_images # 合成数据示例
│ ├── render_train.list #合成数据文件列表
```
### 5.2 模型训练
准备好合成数据后,我们可以使用以下命令,利用合成数据进行finetune:
```bash
cd ${PaddleOCR_root}
python tools/train.py -c configs/rec/PP-OCRv3/ch_PP-OCRv3_rec_distillation.yml \
-o Global.pretrained_model=./ckpt/ch_PP-OCRv3_rec_train/best_accuracy \
Global.epoch_num=20 \
Global.eval_batch_step='[0, 20]' \
Train.dataset.data_dir=./data \
Train.dataset.label_file_list=['./data/render_train.list'] \
Train.loader.batch_size_per_card=64 \
Eval.dataset.data_dir=./data \
Eval.dataset.label_file_list=["./data/val.list"] \
Eval.loader.batch_size_per_card=64
```
其中各参数含义如下:
```txt
-c: 指定使用的配置文件,ch_PP-OCRv3_rec_distillation.yml对应于OCRv3识别模型。
-o: 覆盖配置文件中参数
Global.pretrained_model: 指定finetune使用的预训练模型
Global.epoch_num: 指定训练的epoch数
Global.eval_batch_step: 间隔多少step做一次评估
Train.dataset.data_dir: 训练数据集路径
Train.dataset.label_file_list: 训练集文件列表
Train.loader.batch_size_per_card: 训练单卡batch size
Eval.dataset.data_dir: 评估数据集路径
Eval.dataset.label_file_list: 评估数据集文件列表
Eval.loader.batch_size_per_card: 评估单卡batch size
```
## 6. 基于真实数据finetune
使用合成数据finetune能提升我们模型的识别精度,但由于合成数据和真实数据之间的分布可能有一定差异,因此作用有限。为进一步提高识别精度,本节介绍如何挖掘真实数据进行模型finetune。
数据挖掘的整体思路如下:
1. 使用python爬虫从网上获取大量无标签数据
2. 使用模型从大量无标签数据中构建出有效训练集
### 6.1 python爬虫获取数据
- 推荐使用[爬虫工具](https://github.com/Joeclinton1/google-images-download)获取无标签图片。
图片获取后,可按如下目录格式组织:
```txt
sprider
├── file.list
├── data
│ ├── 00000.jpg
│ ├── 00001.jpg
...
```
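其中 `file.list` 为图片路径列表,可以用类似下面的命令生成(仅为一种假设的做法,按实际目录调整):

```bash
cd sprider
ls data/*.jpg > file.list
```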
### 6.2 数据挖掘
我们使用PaddleOCR对获取到的图片进行挖掘,具体步骤如下:
1. 使用 PP-OCRv3检测模型+svtr-tiny识别模型,对每张图片进行预测。
2. 使用数据挖掘策略,得到有效图片。
3. 将有效图片对应的图像区域和标签提取出来,构建训练集。
首先下载预训练模型,PP-OCRv3检测模型下载链接:https://paddleocr.bj.bcebos.com/PP-OCRv3/chinese/ch_PP-OCRv3_det_infer.tar
如需获取svtr-tiny高精度中文识别预训练模型,请扫码填写问卷,加入PaddleOCR官方交流群获取全部OCR垂类模型下载链接、《动手学OCR》电子书等全套OCR学习资料🎁
<div align="left">
<img src="https://ai-studio-static-online.cdn.bcebos.com/dd721099bd50478f9d5fb13d8dd00fad69c22d6848244fd3a1d3980d7fefc63e" width = "150" height = "150" />
</div>
完成下载后,可将模型存储于如下位置:
```shell
PaddleOCR
├── data
│ ├── rec_vit_sub_64_363_all/ # svtr_tiny高精度识别模型
```
```bash
# 下载解压PP-OCRv3检测模型
cd ${PaddleOCR_root}
wget -nc -P ckpt https://paddleocr.bj.bcebos.com/PP-OCRv3/chinese/ch_PP-OCRv3_det_infer.tar
pushd ckpt
tar -xvf ch_PP-OCRv3_det_infer.tar
popd
```
在使用PP-OCRv3检测模型+svtr-tiny识别模型进行预测之前,有如下两处细节需要手动修改:
1. 将 `tools/infer/predict_rec.py` 中第110行的 `imgW` 修改为 `320`:
```python
#imgW = int((imgH * max_wh_ratio))
imgW = 320
```
2. 在 `tools/infer/predict_system.py` 第169行添加如下一行,将预测分数也写入结果文件中:
```python
"scores": rec_res[idx][1],
```
模型预测命令:
```bash
python tools/infer/predict_system.py \
--image_dir="/home/aistudio/sprider/data" \
--det_model_dir="./ckpt/ch_PP-OCRv3_det_infer/" \
--rec_model_dir="/home/aistudio/PaddleOCR/data/rec_vit_sub_64_363_all/" \
--rec_image_shape="3,32,320"
```
获得预测结果后,我们使用数据挖掘策略得到有效图片。具体挖掘策略如下:
1. 预测置信度高于95%
2. 识别结果包含字符‘20’,即年份
3. 没有中文,或者有中文并且‘日’和'月'同时在识别结果中
```python
# 获取有效预测
import json
import re

zh_pattern = re.compile(u'[\u4e00-\u9fa5]+')  # 正则表达式,筛选字符是否包含中文

file_path = '/home/aistudio/PaddleOCR/inference_results/system_results.txt'
out_path = '/home/aistudio/PaddleOCR/selected_results.txt'
f_out = open(out_path, 'w')
with open(file_path, "r", encoding='utf-8') as fin:
    lines = fin.readlines()

for line in lines:
    flag = False
    # 读取文件内容
    file_name, json_file = line.strip().split('\t')
    preds = json.loads(json_file)
    res = []
    for item in preds:
        transcription = item['transcription']  # 获取识别结果
        scores = item['scores']  # 获取识别得分
        # 挖掘策略
        if scores > 0.95:
            if '20' in transcription and len(transcription) > 4 and len(transcription) < 12:
                word = transcription
                if not (zh_pattern.search(word) and ('日' not in word or '月' not in word)):
                    flag = True
                    res.append(item)
    save_pred = file_name + "\t" + json.dumps(res, ensure_ascii=False) + "\n"
    if flag:
        f_out.write(save_pred)
f_out.close()
```
然后将有效预测对应的图像区域和标签提取出来,构建训练集。具体实现脚本如下:
```python
import os
import cv2
import json
import numpy as np

PATH = '/home/aistudio/PaddleOCR/inference_results/'  # 数据原始路径
SAVE_PATH = '/home/aistudio/mining_images/'  # 裁剪后数据保存路径
file_list = '/home/aistudio/PaddleOCR/selected_results.txt'  # 数据预测结果
label_file = '/home/aistudio/mining_images/mining_train.list'  # 输出真实数据训练集标签list

if not os.path.exists(SAVE_PATH):
    os.mkdir(SAVE_PATH)

f_label = open(label_file, 'w')

def get_rotate_crop_image(img, points):
    """
    根据检测结果points,从输入图像img中裁剪出相应的区域
    """
    assert len(points) == 4, "shape of points must be 4*2"
    img_crop_width = int(
        max(
            np.linalg.norm(points[0] - points[1]),
            np.linalg.norm(points[2] - points[3])))
    img_crop_height = int(
        max(
            np.linalg.norm(points[0] - points[3]),
            np.linalg.norm(points[1] - points[2])))
    pts_std = np.float32([[0, 0], [img_crop_width, 0],
                          [img_crop_width, img_crop_height],
                          [0, img_crop_height]])
    M = cv2.getPerspectiveTransform(points, pts_std)
    # 形变或倾斜,会做透视变换,reshape成矩形
    dst_img = cv2.warpPerspective(
        img,
        M, (img_crop_width, img_crop_height),
        borderMode=cv2.BORDER_REPLICATE,
        flags=cv2.INTER_CUBIC)
    dst_img_height, dst_img_width = dst_img.shape[0:2]
    if dst_img_height * 1.0 / dst_img_width >= 1.5:
        dst_img = np.rot90(dst_img)
    return dst_img

def crop_and_get_filelist(file_list):
    with open(file_list, "r", encoding='utf-8') as fin:
        lines = fin.readlines()
        img_num = 0
        for line in lines:
            img_name, json_file = line.strip().split('\t')
            preds = json.loads(json_file)
            for item in preds:
                transcription = item['transcription']
                points = item['points']
                points = np.array(points).astype('float32')
                # print('processing {}...'.format(img_name))

                img = cv2.imread(PATH + img_name)
                dst_img = get_rotate_crop_image(img, points)
                h, w, c = dst_img.shape
                newWidth = int((32. / h) * w)
                newImg = cv2.resize(dst_img, (newWidth, 32))
                new_img_name = '{:05d}.jpg'.format(img_num)
                cv2.imwrite(SAVE_PATH + new_img_name, dst_img)
                f_label.write(SAVE_PATH + new_img_name + '\t' + transcription + '\n')
                img_num += 1

crop_and_get_filelist(file_list)
f_label.close()
```
### 6.3 模型训练
通过数据挖掘,我们得到了真实场景数据和对应的标签。接下来使用真实数据finetune,观察精度提升效果。
利用真实数据进行finetune:
```bash
cd ${PaddleOCR_root}
python tools/train.py -c configs/rec/PP-OCRv3/ch_PP-OCRv3_rec_distillation.yml \
-o Global.pretrained_model=./ckpt/ch_PP-OCRv3_rec_train/best_accuracy \
Global.epoch_num=20 \
Global.eval_batch_step='[0, 20]' \
Train.dataset.data_dir=./data \
Train.dataset.label_file_list=['./data/mining_train.list'] \
Train.loader.batch_size_per_card=64 \
Eval.dataset.data_dir=./data \
Eval.dataset.label_file_list=["./data/val.list"] \
Eval.loader.batch_size_per_card=64
```
各参数含义参考第5部分合成数据finetune,只需要对训练数据路径做相应的修改:
```txt
Train.dataset.data_dir: 训练数据集路径
Train.dataset.label_file_list: 训练集文件列表
```
示例使用我们提供的真实数据进行finetune,如想换成自己的数据,只需要相应地修改 `Train.dataset.data_dir` 和 `Train.dataset.label_file_list` 参数即可。
由于数据量不大,这里仅训练20个epoch即可。训练完成后,可以得到真实数据finetune后的精度为best acc=**71.33%**。
由于数据量比较少,精度会比合成数据finetune的略低。
## 7. 基于合成+真实数据finetune
为进一步提升模型精度,我们结合使用合成数据和挖掘到的真实数据进行finetune。
利用合成+真实数据进行finetune,各参数含义参考第5部分合成数据finetune,只需要对训练数据路径做相应的修改:
```txt
Train.dataset.data_dir: 训练数据集路径
Train.dataset.label_file_list: 训练集文件列表
```
生成训练list文件:
```bash
# 生成训练集文件list
cat /home/aistudio/PaddleOCR/data/render_train.list /home/aistudio/PaddleOCR/data/mining_train.list > /home/aistudio/PaddleOCR/data/render_mining_train.list
```
启动训练:
```bash
cd ${PaddleOCR_root}
python tools/train.py -c configs/rec/PP-OCRv3/ch_PP-OCRv3_rec_distillation.yml \
-o Global.pretrained_model=./ckpt/ch_PP-OCRv3_rec_train/best_accuracy \
Global.epoch_num=40 \
Global.eval_batch_step='[0, 20]' \
Train.dataset.data_dir=./data \
Train.dataset.label_file_list=['./data/render_mining_train.list'] \
Train.loader.batch_size_per_card=64 \
Eval.dataset.data_dir=./data \
Eval.dataset.label_file_list=["./data/val.list"] \
Eval.loader.batch_size_per_card=64
```
示例使用我们提供的真实+合成数据进行finetune,如想换成自己的数据,只需要相应的修改Train.dataset.data_dir和Train.dataset.label_file_list参数即可。
由于数据量不大,这里仅训练40个epoch即可。训练完成后,可以得到合成+真实数据finetune后的精度为best acc=**86.99%**。
可以看到,相较于原始PP-OCRv3的识别精度62.99%,使用合成数据+真实数据finetune后,识别精度能提升24%。
如需获取已训练模型,可以同样扫描上方二维码下载,将下载或训练完成的模型放置在对应目录下即可完成模型推理。
模型的推理部署方法可以参考repo文档: https://github.com/PaddlePaddle/PaddleOCR/blob/release/2.5/deploy/README_ch.md
# 基于PP-OCRv3的液晶屏读数识别
- [1. 项目背景及意义](#1-项目背景及意义)
- [2. 项目内容](#2-项目内容)
- [3. 安装环境](#3-安装环境)
- [4. 文字检测](#4-文字检测)
- [4.1 PP-OCRv3检测算法介绍](#41-PP-OCRv3检测算法介绍)
- [4.2 数据准备](#42-数据准备)
- [4.3 模型训练](#43-模型训练)
- [4.3.1 预训练模型直接评估](#431-预训练模型直接评估)
- [4.3.2 预训练模型直接finetune](#432-预训练模型直接finetune)
- [4.3.3 基于预训练模型Finetune_student模型](#433-基于预训练模型Finetune_student模型)
- [4.3.4 基于预训练模型Finetune_teacher模型](#434-基于预训练模型Finetune_teacher模型)
- [4.3.5 采用CML蒸馏进一步提升student模型精度](#435-采用CML蒸馏进一步提升student模型精度)
- [4.3.6 模型导出推理](#436-4.3.6-模型导出推理)
- [5. 文字识别](#5-文字识别)
- [5.1 PP-OCRv3识别算法介绍](#51-PP-OCRv3识别算法介绍)
- [5.2 数据准备](#52-数据准备)
- [5.3 模型训练](#53-模型训练)
- [5.4 模型导出推理](#54-模型导出推理)
- [6. 系统串联](#6-系统串联)
- [6.1 后处理](#61-后处理)
- [7. PaddleServing部署](#7-PaddleServing部署)
## 1. 项目背景及意义
目前光学字符识别(OCR)技术在我们的生活当中被广泛使用,但是大多数模型在通用场景下的准确性还有待提高。针对于此,我们借助飞桨提供的PaddleOCR套件,较容易地实现了OCR在垂类场景下的应用。
该项目以国家质量基础(NQI)为准绳,充分利用大数据、云计算、物联网等高新技术,构建覆盖计量端、实验室端、数据端和硬件端的完整计量解决方案,解决传统计量校准中存在的难题,拓宽计量检测服务体系和服务领域;解决无数据传输接口、或数传接口不统一且不公开的计量设备的数据读取难题,以及计量设备所处环境恶劣、不适合人工读取数据的问题。通过OCR技术实现远程计量,引领计量行业向智慧计量转型和发展。
## 2. 项目内容
本项目基于PaddleOCR开源套件,以PP-OCRv3检测和识别模型为基础,针对液晶屏读数识别场景进行优化。
Aistudio项目链接:[OCR液晶屏读数识别](https://aistudio.baidu.com/aistudio/projectdetail/4080130)
## 3. 安装环境
```python
# 首先 clone 官方的PaddleOCR项目,并安装需要的依赖
# 第一次运行打开该注释
# git clone https://gitee.com/PaddlePaddle/PaddleOCR.git
cd PaddleOCR
pip install -r requirements.txt
```
## 4. 文字检测
文本检测的任务是定位出输入图像中的文字区域。近年来学术界关于文本检测的研究非常丰富,一类方法将文本检测视为目标检测中的一个特定场景,基于通用目标检测算法进行改进适配,如TextBoxes[1]基于一阶段目标检测器SSD[2]算法,调整目标框使之适合极端长宽比的文本行,CTPN[3]则是基于Faster RCNN[4]架构改进而来。但是文本检测与目标检测在目标信息以及任务本身上仍存在一些区别,如文本一般长宽比较大,往往呈“条状”,文本行之间可能比较密集,弯曲文本等,因此又衍生了很多专用于文本检测的算法。本项目基于PP-OCRv3算法进行优化。
### 4.1 PP-OCRv3检测算法介绍
PP-OCRv3检测模型是对PP-OCRv2中的CML(Collaborative Mutual Learning) 协同互学习文本检测蒸馏策略进行了升级。如下图所示,CML的核心思想结合了①传统的Teacher指导Student的标准蒸馏与 ②Students网络之间的DML互学习,可以让Students网络互学习的同时,Teacher网络予以指导。PP-OCRv3分别针对教师模型和学生模型进行进一步效果优化。其中,在对教师模型优化时,提出了大感受野的PAN结构LK-PAN和引入了DML(Deep Mutual Learning)蒸馏策略;在对学生模型优化时,提出了残差注意力机制的FPN结构RSE-FPN。
![](https://ai-studio-static-online.cdn.bcebos.com/c306b2f028364805a55494d435ab553a76cf5ae5dd3f4649a948ea9aeaeb28b8)
详细优化策略描述请参考[PP-OCRv3优化策略](https://github.com/PaddlePaddle/PaddleOCR/blob/release/2.5/doc/doc_ch/PP-OCRv3_introduction.md#2)
### 4.2 数据准备
[计量设备屏幕字符检测数据集](https://aistudio.baidu.com/aistudio/datasetdetail/127845)数据来源于实际项目中各种计量设备的数显屏,以及在网上搜集的一些其他数显屏,包含训练集755张,测试集355张。
```python
# 在PaddleOCR下创建新的文件夹train_data
mkdir train_data
# 下载数据集并解压到指定路径下
unzip icdar2015.zip -d train_data
```
```python
# 随机查看文字检测数据集图片
from PIL import Image
import matplotlib.pyplot as plt
import numpy as np
import os

train = './train_data/icdar2015/text_localization/test'

# 从指定目录中选取一张图片
def get_one_image(train):
    plt.figure()
    files = os.listdir(train)
    n = len(files)
    ind = np.random.randint(0, n)
    img_dir = os.path.join(train, files[ind])
    image = Image.open(img_dir)
    plt.imshow(image)
    plt.show()
    image = image.resize([208, 208])

get_one_image(train)
```
![det_png](https://ai-studio-static-online.cdn.bcebos.com/0639da09b774458096ae577e82b2c59e89ced6a00f55458f946997ab7472a4f8)
### 4.3 模型训练
#### 4.3.1 预训练模型直接评估
下载我们需要的PP-OCRv3检测预训练模型,更多选择请自行选择其他的[文字检测模型](https://github.com/PaddlePaddle/PaddleOCR/blob/release/2.5/doc/doc_ch/models_list.md#1-%E6%96%87%E6%9C%AC%E6%A3%80%E6%B5%8B%E6%A8%A1%E5%9E%8B)
```python
#使用该指令下载需要的预训练模型
wget -P ./pretrained_models/ https://paddleocr.bj.bcebos.com/PP-OCRv3/chinese/ch_PP-OCRv3_det_distill_train.tar
# 解压预训练模型文件
tar -xf ./pretrained_models/ch_PP-OCRv3_det_distill_train.tar -C pretrained_models
```
在训练之前,我们可以直接使用下面命令来评估预训练模型的效果:
```python
# 评估预训练模型
python tools/eval.py -c configs/det/ch_PP-OCRv3/ch_PP-OCRv3_det_cml.yml -o Global.pretrained_model="./pretrained_models/ch_PP-OCRv3_det_distill_train/best_accuracy"
```
结果如下:
| | 方案 |hmeans|
|---|---------------------------|---|
| 0 | PP-OCRv3中英文超轻量检测预训练模型直接预测 |47.5%|
#### 4.3.2 预训练模型直接finetune
##### 修改配置文件
我们使用configs/det/ch_PP-OCRv3/ch_PP-OCRv3_det_cml.yml,主要修改训练轮数和学习率等相关参数,并设置预训练模型路径和数据集路径。另外,batch_size可根据自己机器显存大小进行调整。具体修改如下几个地方:
```
epoch:100
save_epoch_step:10
eval_batch_step:[0, 50]
save_model_dir: ./output/ch_PP-OCR_v3_det/
pretrained_model: ./pretrained_models/ch_PP-OCRv3_det_distill_train/best_accuracy
learning_rate: 0.00025
num_workers: 0 # 如果单卡训练,建议将Train和Eval的loader部分的num_workers设置为0,否则会出现`/dev/shm insufficient`的报错
```
##### 开始训练
使用我们上面修改的配置文件configs/det/ch_PP-OCRv3/ch_PP-OCRv3_det_cml.yml,训练命令如下:
```python
# 开始训练模型
python tools/train.py -c configs/det/ch_PP-OCRv3/ch_PP-OCRv3_det_cml.yml -o Global.pretrained_model=./pretrained_models/ch_PP-OCRv3_det_distill_train/best_accuracy
```
评估训练好的模型:
```python
# 评估训练好的模型
python tools/eval.py -c configs/det/ch_PP-OCRv3/ch_PP-OCRv3_det_cml.yml -o Global.pretrained_model="./output/ch_PP-OCR_v3_det/best_accuracy"
```
结果如下:
| | 方案 |hmeans|
|---|---------------------------|---|
| 0 | PP-OCRv3中英文超轻量检测预训练模型直接预测 |47.5%|
| 1 | PP-OCRv3中英文超轻量检测预训练模型finetune |65.2%|
#### 4.3.3 基于预训练模型Finetune_student模型
我们使用configs/det/ch_PP-OCRv3/ch_PP-OCRv3_det_student.yml,主要修改训练轮数和学习率等相关参数,并设置预训练模型路径和数据集路径。另外,batch_size可根据自己机器显存大小进行调整。具体修改如下几个地方:
```
epoch:100
save_epoch_step:10
eval_batch_step:[0, 50]
save_model_dir: ./output/ch_PP-OCR_v3_det_student/
pretrained_model: ./pretrained_models/ch_PP-OCRv3_det_distill_train/student
learning_rate: 0.00025
num_workers: 0 # 如果单卡训练,建议将Train和Eval的loader部分的num_workers设置为0,否则会出现`/dev/shm insufficient`的报错
```
训练命令如下:
```python
python tools/train.py -c configs/det/ch_PP-OCRv3/ch_PP-OCRv3_det_student.yml -o Global.pretrained_model=./pretrained_models/ch_PP-OCRv3_det_distill_train/student
```
评估训练好的模型:
```python
# 评估训练好的模型
python tools/eval.py -c configs/det/ch_PP-OCRv3/ch_PP-OCRv3_det_student.yml -o Global.pretrained_model="./output/ch_PP-OCR_v3_det_student/best_accuracy"
```
结果如下:
| | 方案 |hmeans|
|---|---------------------------|---|
| 0 | PP-OCRv3中英文超轻量检测预训练模型直接预测 |47.5%|
| 1 | PP-OCRv3中英文超轻量检测预训练模型finetune |65.2%|
| 2 | PP-OCRv3中英文超轻量检测预训练模型finetune学生模型 |80.0%|
#### 4.3.4 基于预训练模型Finetune_teacher模型
首先需要从提供的预训练模型best_accuracy.pdparams中提取teacher参数,组合成适合dml训练的初始化模型,提取代码如下:
```python
# 以下脚本在 ./pretrained_models/ 目录下运行
# transform teacher params in best_accuracy.pdparams into teacher_dml.pdparams
import paddle

# load pretrained model
all_params = paddle.load("ch_PP-OCRv3_det_distill_train/best_accuracy.pdparams")
# print(all_params.keys())

# keep teacher params
t_params = {key[len("Teacher."):]: all_params[key] for key in all_params if "Teacher." in key}
# print(t_params.keys())

s_params = {"Student." + key: t_params[key] for key in t_params}
s2_params = {"Student2." + key: t_params[key] for key in t_params}
s_params = {**s_params, **s2_params}
# print(s_params.keys())

paddle.save(s_params, "ch_PP-OCRv3_det_distill_train/teacher_dml.pdparams")
```
我们使用configs/det/ch_PP-OCRv3/ch_PP-OCRv3_det_dml.yml,主要修改训练轮数和学习率等相关参数,并设置预训练模型路径和数据集路径。另外,batch_size可根据自己机器显存大小进行调整。具体修改如下几个地方:
```
epoch:100
save_epoch_step:10
eval_batch_step:[0, 50]
save_model_dir: ./output/ch_PP-OCR_v3_det_teacher/
pretrained_model: ./pretrained_models/ch_PP-OCRv3_det_distill_train/teacher_dml
learning_rate: 0.00025
num_workers: 0 # 如果单卡训练,建议将Train和Eval的loader部分的num_workers设置为0,否则会出现`/dev/shm insufficient`的报错
```
训练命令如下:
```python
python tools/train.py -c configs/det/ch_PP-OCRv3/ch_PP-OCRv3_det_dml.yml -o Global.pretrained_model=./pretrained_models/ch_PP-OCRv3_det_distill_train/teacher_dml
```
评估训练好的模型:
```python
# 评估训练好的模型
python tools/eval.py -c configs/det/ch_PP-OCRv3/ch_PP-OCRv3_det_dml.yml -o Global.pretrained_model="./output/ch_PP-OCR_v3_det_teacher/best_accuracy"
```
结果如下:
| | 方案 |hmeans|
|---|---------------------------|---|
| 0 | PP-OCRv3中英文超轻量检测预训练模型直接预测 |47.5%|
| 1 | PP-OCRv3中英文超轻量检测预训练模型finetune |65.2%|
| 2 | PP-OCRv3中英文超轻量检测预训练模型finetune学生模型 |80.0%|
| 3 | PP-OCRv3中英文超轻量检测预训练模型finetune教师模型 |84.8%|
#### 4.3.5 采用CML蒸馏进一步提升student模型精度
需要从4.3.3和4.3.4训练得到的best_accuracy.pdparams中提取各自代表student和teacher的参数,组合成适合cml训练的初始化模型,提取代码如下:
```python
# transform teacher params and student parameters into cml model
import paddle

all_params = paddle.load("./pretrained_models/ch_PP-OCRv3_det_distill_train/best_accuracy.pdparams")
# print(all_params.keys())

t_params = paddle.load("./output/ch_PP-OCR_v3_det_teacher/best_accuracy.pdparams")
# print(t_params.keys())

s_params = paddle.load("./output/ch_PP-OCR_v3_det_student/best_accuracy.pdparams")
# print(s_params.keys())

for key in all_params:
    # teacher is OK
    if "Teacher." in key:
        new_key = key.replace("Teacher", "Student")
        # print("{} >> {}\n".format(key, new_key))
        assert all_params[key].shape == t_params[new_key].shape
        all_params[key] = t_params[new_key]

    if "Student." in key:
        new_key = key.replace("Student.", "")
        # print("{} >> {}\n".format(key, new_key))
        assert all_params[key].shape == s_params[new_key].shape
        all_params[key] = s_params[new_key]

    if "Student2." in key:
        new_key = key.replace("Student2.", "")
        print("{} >> {}\n".format(key, new_key))
        assert all_params[key].shape == s_params[new_key].shape
        all_params[key] = s_params[new_key]

paddle.save(all_params, "./pretrained_models/ch_PP-OCRv3_det_distill_train/teacher_cml_student.pdparams")
```
训练命令如下:
```python
python tools/train.py -c configs/det/ch_PP-OCRv3/ch_PP-OCRv3_det_cml.yml -o Global.pretrained_model=./pretrained_models/ch_PP-OCRv3_det_distill_train/teacher_cml_student Global.save_model_dir=./output/ch_PP-OCR_v3_det_finetune/
```
评估训练好的模型:
```python
# 评估训练好的模型
python tools/eval.py -c configs/det/ch_PP-OCRv3/ch_PP-OCRv3_det_cml.yml -o Global.pretrained_model="./output/ch_PP-OCR_v3_det_finetune/best_accuracy"
```
结果如下:
| | 方案 |hmeans|
|---|---------------------------|---|
| 0 | PP-OCRv3中英文超轻量检测预训练模型直接预测 |47.5%|
| 1 | PP-OCRv3中英文超轻量检测预训练模型finetune |65.2%|
| 2 | PP-OCRv3中英文超轻量检测预训练模型finetune学生模型 |80.0%|
| 3 | PP-OCRv3中英文超轻量检测预训练模型finetune教师模型 |84.8%|
| 4 | 基于2和3训练好的模型finetune |82.7%|
如需获取已训练模型,请扫码填写问卷,加入PaddleOCR官方交流群获取全部OCR垂类模型下载链接、《动手学OCR》电子书等全套OCR学习资料🎁
<div align="left">
<img src="https://ai-studio-static-online.cdn.bcebos.com/dd721099bd50478f9d5fb13d8dd00fad69c22d6848244fd3a1d3980d7fefc63e" width = "150" height = "150" />
</div>
将下载或训练完成的模型放置在对应目录下即可完成模型推理。
#### 4.3.6 模型导出推理
训练完成后,可以将训练模型转换成inference模型。inference 模型会额外保存模型的结构信息,在预测部署、加速推理上性能优越,灵活方便,适合于实际系统集成。
##### 4.3.6.1 模型导出
导出命令如下:
```python
# 转化为推理模型
python tools/export_model.py \
-c configs/det/ch_PP-OCRv3/ch_PP-OCRv3_det_cml.yml \
-o Global.pretrained_model=./output/ch_PP-OCR_v3_det_finetune/best_accuracy \
-o Global.save_inference_dir="./inference/det_ppocrv3"
```
##### 4.3.6.2 模型推理
导出模型后,可以使用如下命令进行推理预测:
```python
# 推理预测
python tools/infer/predict_det.py --image_dir="train_data/icdar2015/text_localization/test/1.jpg" --det_model_dir="./inference/det_ppocrv3/Student"
```
## 5. 文字识别
文本识别的任务是识别出图像中的文字内容,一般输入来自于文本检测得到的文本框截取出的图像文字区域。文本识别一般可以根据待识别文本形状分为规则文本识别和不规则文本识别两大类。规则文本主要指印刷字体、扫描文本等,文本大致处在水平线位置;不规则文本往往不在水平位置,存在弯曲、遮挡、模糊等问题。不规则文本场景具有很大的挑战性,也是目前文本识别领域的主要研究方向。本项目基于PP-OCRv3算法进行优化。
### 5.1 PP-OCRv3识别算法介绍
PP-OCRv3的识别模块是基于文本识别算法[SVTR](https://arxiv.org/abs/2205.00159)优化。SVTR不再采用RNN结构,通过引入Transformers结构更加有效地挖掘文本行图像的上下文信息,从而提升文本识别能力。如下图所示,PP-OCRv3采用了6个优化策略。
![](https://ai-studio-static-online.cdn.bcebos.com/d4f5344b5b854d50be738671598a89a45689c6704c4d481fb904dd7cf72f2a1a)
优化策略汇总如下:
* SVTR_LCNet:轻量级文本识别网络
* GTC:Attention指导CTC训练策略
* TextConAug:挖掘文字上下文信息的数据增广策略
* TextRotNet:自监督的预训练模型
* UDML:联合互学习策略
* UIM:无标注数据挖掘方案
详细优化策略描述请参考[PP-OCRv3优化策略](https://github.com/PaddlePaddle/PaddleOCR/blob/release/2.5/doc/doc_ch/PP-OCRv3_introduction.md#3-%E8%AF%86%E5%88%AB%E4%BC%98%E5%8C%96)
### 5.2 数据准备
[计量设备屏幕字符识别数据集](https://aistudio.baidu.com/aistudio/datasetdetail/128714)数据来源于实际项目中各种计量设备的数显屏,以及在网上搜集的一些其他数显屏,包含训练集19912张,测试集4099张。
```python
# 解压下载的数据集到指定路径下
unzip ic15_data.zip -d train_data
```
```python
# 随机查看文字识别数据集图片
from PIL import Image
import matplotlib.pyplot as plt
import numpy as np
import os

train = './train_data/ic15_data/train'

# 从指定目录中选取一张图片
def get_one_image(train):
    plt.figure()
    files = os.listdir(train)
    n = len(files)
    ind = np.random.randint(0, n)
    img_dir = os.path.join(train, files[ind])
    image = Image.open(img_dir)
    plt.imshow(image)
    plt.show()
    image = image.resize([208, 208])

get_one_image(train)
```
![rec_png](https://ai-studio-static-online.cdn.bcebos.com/3de0d475c69746d0a184029001ef07c85fd68816d66d4beaa10e6ef60030f9b4)
### 5.3 模型训练
#### 下载预训练模型
下载我们需要的PP-OCRv3识别预训练模型,更多选择请自行选择其他的[文字识别模型](https://github.com/PaddlePaddle/PaddleOCR/blob/release/2.5/doc/doc_ch/models_list.md#2-%E6%96%87%E6%9C%AC%E8%AF%86%E5%88%AB%E6%A8%A1%E5%9E%8B)
```python
# 使用该指令下载需要的预训练模型
wget -P ./pretrained_models/ https://paddleocr.bj.bcebos.com/PP-OCRv3/chinese/ch_PP-OCRv3_rec_train.tar
# 解压预训练模型文件
tar -xf ./pretrained_models/ch_PP-OCRv3_rec_train.tar -C pretrained_models
```
#### 修改配置文件
我们使用configs/rec/PP-OCRv3/ch_PP-OCRv3_rec_distillation.yml,主要修改训练轮数和学习率等相关参数,并设置预训练模型路径和数据集路径。另外,batch_size可根据自己机器显存大小进行调整。具体修改如下几个地方:
```
epoch_num: 100 # 训练epoch数
save_model_dir: ./output/ch_PP-OCR_v3_rec
save_epoch_step: 10
eval_batch_step: [0, 100] # 评估间隔,每隔100step评估一次
cal_metric_during_train: true
pretrained_model: ./pretrained_models/ch_PP-OCRv3_rec_train/best_accuracy # 预训练模型路径
character_dict_path: ppocr/utils/ppocr_keys_v1.txt
use_space_char: true # 使用空格

lr:
  name: Cosine # 修改学习率衰减策略为Cosine
  learning_rate: 0.0002 # 修改fine-tune的学习率
  warmup_epoch: 2 # 修改warmup轮数

Train:
  dataset:
    name: SimpleDataSet
    data_dir: ./train_data/ic15_data/ # 训练集图片路径
    ext_op_transform_idx: 1
    label_file_list:
    - ./train_data/ic15_data/rec_gt_train.txt # 训练集标签
    ratio_list:
    - 1.0
  loader:
    shuffle: true
    batch_size_per_card: 64
    drop_last: true
    num_workers: 4

Eval:
  dataset:
    name: SimpleDataSet
    data_dir: ./train_data/ic15_data/ # 测试集图片路径
    label_file_list:
    - ./train_data/ic15_data/rec_gt_test.txt # 测试集标签
    ratio_list:
    - 1.0
  loader:
    shuffle: false
    drop_last: false
    batch_size_per_card: 64
    num_workers: 4
```
在训练之前,我们可以直接使用下面命令来评估预训练模型的效果:
```python
# 评估预训练模型
python tools/eval.py -c configs/rec/PP-OCRv3/ch_PP-OCRv3_rec_distillation.yml -o Global.pretrained_model="./pretrained_models/ch_PP-OCRv3_rec_train/best_accuracy"
```
结果如下:
| | 方案 |accuracy|
|---|---------------------------|---|
| 0 | PP-OCRv3中英文超轻量识别预训练模型直接预测 |70.4%|
#### 开始训练
我们使用上面修改好的配置文件configs/rec/PP-OCRv3/ch_PP-OCRv3_rec_distillation.yml,其中预训练模型、数据集路径、学习率、训练轮数等均已设置完毕,可以使用下面命令开始训练。
```python
# 开始训练识别模型
python tools/train.py -c configs/rec/PP-OCRv3/ch_PP-OCRv3_rec_distillation.yml
```
训练完成后,可以对训练得到的最优模型进行测试,评估命令如下:
```python
# 评估finetune效果
python tools/eval.py -c configs/rec/PP-OCRv3/ch_PP-OCRv3_rec_distillation.yml -o Global.checkpoints="./output/ch_PP-OCR_v3_rec/best_accuracy"
```
结果如下:
| | 方案 |accuracy|
|---|---------------------------|---|
| 0 | PP-OCRv3中英文超轻量识别预训练模型直接预测 |70.4%|
| 1 | PP-OCRv3中英文超轻量识别预训练模型finetune |82.2%|
如需获取已训练模型,请扫码填写问卷,加入PaddleOCR官方交流群获取全部OCR垂类模型下载链接、《动手学OCR》电子书等全套OCR学习资料🎁
<div align="left">
<img src="https://ai-studio-static-online.cdn.bcebos.com/dd721099bd50478f9d5fb13d8dd00fad69c22d6848244fd3a1d3980d7fefc63e" width = "150" height = "150" />
</div>
将下载或训练完成的模型放置在对应目录下即可完成模型推理。
### 5.4 模型导出推理
训练完成后,可以将训练模型转换成inference模型。inference 模型会额外保存模型的结构信息,在预测部署、加速推理上性能优越,灵活方便,适合于实际系统集成。
#### 模型导出
导出命令如下:
```python
# 转化为推理模型
python tools/export_model.py -c configs/rec/PP-OCRv3/ch_PP-OCRv3_rec_distillation.yml -o Global.pretrained_model="./output/ch_PP-OCR_v3_rec/best_accuracy" Global.save_inference_dir="./inference/rec_ppocrv3/"
```
#### 模型推理
导出模型后,可以使用如下命令进行推理预测
```python
# 推理预测
python tools/infer/predict_rec.py --image_dir="train_data/ic15_data/test/1_crop_0.jpg" --rec_model_dir="./inference/rec_ppocrv3/Student"
```
## 6. 系统串联
我们将上面训练好的检测和识别模型进行系统串联测试,命令如下:
```python
#串联测试
python3 tools/infer/predict_system.py --image_dir="./train_data/icdar2015/text_localization/test/142.jpg" --det_model_dir="./inference/det_ppocrv3/Student" --rec_model_dir="./inference/rec_ppocrv3/Student"
```
测试结果保存在`./inference_results/`目录下,可以用下面代码进行可视化
```python
%cd /home/aistudio/PaddleOCR
# 显示结果
import matplotlib.pyplot as plt
from PIL import Image
img_path= "./inference_results/142.jpg"
img = Image.open(img_path)
plt.figure("test_img", figsize=(30,30))
plt.imshow(img)
plt.show()
```
![sys_res_png](https://ai-studio-static-online.cdn.bcebos.com/901ab741cb46441ebec510b37e63b9d8d1b7c95f63cc4e5e8757f35179ae6373)
### 6.1 后处理
如果需要获取key-value信息,可以基于启发式的规则,将识别结果与关键字库进行匹配;如果匹配上了,则取该字段为key, 后面一个字段为value。
```python
def postprocess(rec_res):
    keys = ["型号", "厂家", "版本号", "检定校准分类", "计量器具编号", "烟尘流量",
            "累积体积", "烟气温度", "动压", "静压", "时间", "试验台编号", "预测流速",
            "全压", "烟温", "流速", "工况流量", "标杆流量", "烟尘直读嘴", "烟尘采样嘴",
            "大气压", "计前温度", "计前压力", "干球温度", "湿球温度", "流量", "含湿量"]
    key_value = []
    if len(rec_res) > 1:
        for i in range(len(rec_res) - 1):
            rec_str, _ = rec_res[i]
            for key in keys:
                if rec_str in key:
                    key_value.append([rec_str, rec_res[i + 1][0]])
                    break
    return key_value

key_value = postprocess(filter_rec_res)
```
## 7. PaddleServing部署
首先需要安装PaddleServing部署相关的环境
```python
python -m pip install paddle-serving-server-gpu
python -m pip install paddle_serving_client
python -m pip install paddle-serving-app
```
### 7.1 转化检测模型
```python
cd deploy/pdserving/
python -m paddle_serving_client.convert --dirname ../../inference/det_ppocrv3/Student/ \
--model_filename inference.pdmodel \
--params_filename inference.pdiparams \
--serving_server ./ppocr_det_v3_serving/ \
--serving_client ./ppocr_det_v3_client/
```
### 7.2 转化识别模型
```python
python -m paddle_serving_client.convert --dirname ../../inference/rec_ppocrv3/Student \
--model_filename inference.pdmodel \
--params_filename inference.pdiparams \
--serving_server ./ppocr_rec_v3_serving/ \
--serving_client ./ppocr_rec_v3_client/
```
### 7.3 启动服务
首先可以将后处理代码加入到web_service.py中,具体修改如下:
```
# 代码153行后面增加下面代码
def _postprocess(rec_res):
    keys = ["型号", "厂家", "版本号", "检定校准分类", "计量器具编号", "烟尘流量",
            "累积体积", "烟气温度", "动压", "静压", "时间", "试验台编号", "预测流速",
            "全压", "烟温", "流速", "工况流量", "标杆流量", "烟尘直读嘴", "烟尘采样嘴",
            "大气压", "计前温度", "计前压力", "干球温度", "湿球温度", "流量", "含湿量"]
    key_value = []
    if len(rec_res) > 1:
        for i in range(len(rec_res) - 1):
            rec_str, _ = rec_res[i]
            for key in keys:
                if rec_str in key:
                    key_value.append([rec_str, rec_res[i + 1][0]])
                    break
    return key_value

key_value = _postprocess(rec_list)
res = {"result": str(key_value)}
# res = {"result": str(result_list)}
```
启动服务端
```python
python web_service.py >log.txt 2>&1
```
### 7.4 发送请求
然后再开启一个新的终端,运行下面的客户端代码
```python
python pipeline_http_client.py --image_dir ../../train_data/icdar2015/text_localization/test/142.jpg
```
可以获取到最终的key-value结果:
```
大气压, 100.07kPa
干球温度, 0000℃
计前温度, 0000℃
湿球温度, 0000℃
计前压力, -0000kPa
流量, 00.0L/min
静压, 00000kPa
含湿量, 00.0 %
```
......@@ -28,7 +28,7 @@ Architecture:
algorithm: DB
Transform:
Backbone:
name: ResNet
name: ResNet_vd
layers: 18
Neck:
name: DBFPN
......
......@@ -45,7 +45,7 @@ Architecture:
algorithm: DB
Transform:
Backbone:
name: ResNet
name: ResNet_vd
layers: 18
Neck:
name: DBFPN
......
......@@ -65,7 +65,7 @@ Loss:
- ["Student", "Teacher"]
maps_name: "thrink_maps"
weight: 1.0
act: "softmax"
# act: None
model_name_pairs: ["Student", "Teacher"]
key: maps
- DistillationDBLoss:
......
......@@ -61,7 +61,7 @@ Architecture:
model_type: det
algorithm: DB
Backbone:
name: ResNet
name: ResNet_vd
in_channels: 3
layers: 50
Neck:
......
......@@ -25,7 +25,7 @@ Architecture:
model_type: det
algorithm: DB
Backbone:
name: ResNet
name: ResNet_vd
in_channels: 3
layers: 50
Neck:
......@@ -40,7 +40,7 @@ Architecture:
model_type: det
algorithm: DB
Backbone:
name: ResNet
name: ResNet_vd
in_channels: 3
layers: 50
Neck:
......@@ -60,7 +60,7 @@ Loss:
- ["Student", "Student2"]
maps_name: "thrink_maps"
weight: 1.0
act: "softmax"
# act: None
model_name_pairs: ["Student", "Student2"]
key: maps
- DistillationDBLoss:
......
......@@ -20,7 +20,7 @@ Architecture:
algorithm: DB
Transform:
Backbone:
name: ResNet
name: ResNet_vd
layers: 18
disable_se: True
Neck:
......
Global:
debug: false
use_gpu: true
epoch_num: 1000
log_smooth_window: 20
print_batch_step: 10
save_model_dir: ./output/det_r50_icdar15/
save_epoch_step: 200
eval_batch_step:
- 0
- 2000
cal_metric_during_train: false
pretrained_model: ./pretrain_models/ResNet50_dcn_asf_synthtext_pretrained
checkpoints: null
save_inference_dir: null
use_visualdl: false
infer_img: doc/imgs_en/img_10.jpg
save_res_path: ./checkpoints/det_db/predicts_db.txt
Architecture:
model_type: det
algorithm: DB++
Transform: null
Backbone:
name: ResNet
layers: 50
dcn_stage: [False, True, True, True]
Neck:
name: DBFPN
out_channels: 256
use_asf: True
Head:
name: DBHead
k: 50
Loss:
name: DBLoss
balance_loss: true
main_loss_type: BCELoss
alpha: 5
beta: 10
ohem_ratio: 3
Optimizer:
name: Momentum
momentum: 0.9
lr:
name: DecayLearningRate
learning_rate: 0.007
epochs: 1000
factor: 0.9
end_lr: 0
weight_decay: 0.0001
PostProcess:
name: DBPostProcess
thresh: 0.3
box_thresh: 0.6
max_candidates: 1000
unclip_ratio: 1.5
Metric:
name: DetMetric
main_indicator: hmean
Train:
dataset:
name: SimpleDataSet
data_dir: ./train_data/icdar2015/text_localization/
label_file_list:
- ./train_data/icdar2015/text_localization/train_icdar2015_label.txt
ratio_list:
- 1.0
transforms:
- DecodeImage:
img_mode: BGR
channel_first: false
- DetLabelEncode: null
- IaaAugment:
augmenter_args:
- type: Fliplr
args:
p: 0.5
- type: Affine
args:
rotate:
- -10
- 10
- type: Resize
args:
size:
- 0.5
- 3
- EastRandomCropData:
size:
- 640
- 640
max_tries: 10
keep_ratio: true
- MakeShrinkMap:
shrink_ratio: 0.4
min_text_size: 8
- MakeBorderMap:
shrink_ratio: 0.4
thresh_min: 0.3
thresh_max: 0.7
- NormalizeImage:
scale: 1./255.
mean:
- 0.48109378172549
- 0.45752457890196
- 0.40787054090196
std:
- 1.0
- 1.0
- 1.0
order: hwc
- ToCHWImage: null
- KeepKeys:
keep_keys:
- image
- threshold_map
- threshold_mask
- shrink_map
- shrink_mask
loader:
shuffle: true
drop_last: false
batch_size_per_card: 4
num_workers: 8
Eval:
dataset:
name: SimpleDataSet
data_dir: ./train_data/icdar2015/text_localization
label_file_list:
- ./train_data/icdar2015/text_localization/test_icdar2015_label.txt
transforms:
- DecodeImage:
img_mode: BGR
channel_first: false
- DetLabelEncode: null
- DetResizeForTest:
image_shape:
- 1152
- 2048
- NormalizeImage:
scale: 1./255.
mean:
- 0.48109378172549
- 0.45752457890196
- 0.40787054090196
std:
- 1.0
- 1.0
- 1.0
order: hwc
- ToCHWImage: null
- KeepKeys:
keep_keys:
- image
- shape
- polys
- ignore_tags
loader:
shuffle: false
drop_last: false
batch_size_per_card: 1
num_workers: 2
profiler_options: null
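A note on the schedule used above: `DecayLearningRate` with `factor: 0.9` is the polynomial ("poly") decay commonly used for DB training. The sketch below is my reading of the config fields, not an excerpt from PaddleOCR:

```python
# Hedged sketch of a polynomial LR decay driven by the config fields above:
# lr(epoch) = (learning_rate - end_lr) * (1 - epoch / epochs) ** factor + end_lr
def poly_decay(epoch, learning_rate=0.007, epochs=1000, factor=0.9, end_lr=0.0):
    return (learning_rate - end_lr) * (1.0 - epoch / epochs) ** factor + end_lr

for e in (0, 250, 500, 999):
    print(e, round(poly_decay(e), 6))  # decays smoothly from 0.007 toward end_lr
```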
Global:
debug: false
use_gpu: true
epoch_num: 1000
log_smooth_window: 20
print_batch_step: 10
save_model_dir: ./output/det_r50_td_tr/
save_epoch_step: 200
eval_batch_step:
- 0
- 2000
cal_metric_during_train: false
pretrained_model: ./pretrain_models/ResNet50_dcn_asf_synthtext_pretrained
checkpoints: null
save_inference_dir: null
use_visualdl: false
infer_img: doc/imgs_en/img_10.jpg
save_res_path: ./checkpoints/det_db/predicts_db.txt
Architecture:
model_type: det
algorithm: DB++
Transform: null
Backbone:
name: ResNet
layers: 50
dcn_stage: [False, True, True, True]
Neck:
name: DBFPN
out_channels: 256
use_asf: True
Head:
name: DBHead
k: 50
Loss:
name: DBLoss
balance_loss: true
main_loss_type: BCELoss
alpha: 5
beta: 10
ohem_ratio: 3
Optimizer:
name: Momentum
momentum: 0.9
lr:
name: DecayLearningRate
learning_rate: 0.007
epochs: 1000
factor: 0.9
end_lr: 0
weight_decay: 0.0001
PostProcess:
name: DBPostProcess
thresh: 0.3
box_thresh: 0.5
max_candidates: 1000
unclip_ratio: 1.5
Metric:
name: DetMetric
main_indicator: hmean
Train:
dataset:
name: SimpleDataSet
data_dir: ./train_data/
label_file_list:
- ./train_data/TD_TR/TD500/train_gt_labels.txt
- ./train_data/TD_TR/TR400/gt_labels.txt
ratio_list:
- 1.0
- 1.0
transforms:
- DecodeImage:
img_mode: BGR
channel_first: false
- DetLabelEncode: null
- IaaAugment:
augmenter_args:
- type: Fliplr
args:
p: 0.5
- type: Affine
args:
rotate:
- -10
- 10
- type: Resize
args:
size:
- 0.5
- 3
- EastRandomCropData:
size:
- 640
- 640
max_tries: 10
keep_ratio: true
- MakeShrinkMap:
shrink_ratio: 0.4
min_text_size: 8
- MakeBorderMap:
shrink_ratio: 0.4
thresh_min: 0.3
thresh_max: 0.7
- NormalizeImage:
scale: 1./255.
mean:
- 0.48109378172549
- 0.45752457890196
- 0.40787054090196
std:
- 1.0
- 1.0
- 1.0
order: hwc
- ToCHWImage: null
- KeepKeys:
keep_keys:
- image
- threshold_map
- threshold_mask
- shrink_map
- shrink_mask
loader:
shuffle: true
drop_last: false
batch_size_per_card: 4
num_workers: 8
Eval:
dataset:
name: SimpleDataSet
data_dir: ./train_data/
label_file_list:
- ./train_data/TD_TR/TD500/test_gt_labels.txt
transforms:
- DecodeImage:
img_mode: BGR
channel_first: false
- DetLabelEncode: null
- DetResizeForTest:
image_shape:
- 736
- 736
keep_ratio: True
- NormalizeImage:
scale: 1./255.
mean:
- 0.48109378172549
- 0.45752457890196
- 0.40787054090196
std:
- 1.0
- 1.0
- 1.0
order: hwc
- ToCHWImage: null
- KeepKeys:
keep_keys:
- image
- shape
- polys
- ignore_tags
loader:
shuffle: false
drop_last: false
batch_size_per_card: 1
num_workers: 2
profiler_options: null
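This config mixes TD500 and TR400 via `label_file_list`, with `ratio_list` controlling what fraction of each label file is sampled. A minimal sketch of how such ratio-based sampling can be realized (an assumption about `SimpleDataSet` semantics, for illustration only):

```python
import random

def build_sample_list(label_files, ratio_list, seed=0):
    """Sample each label file at its ratio, then shuffle the merged list."""
    rng = random.Random(seed)
    lines = []
    for path, ratio in zip(label_files, ratio_list):
        with open(path, "r", encoding="utf-8") as f:
            data = f.readlines()
        lines.extend(rng.sample(data, round(len(data) * ratio)))
    rng.shuffle(lines)
    return lines

# Usage mirroring the config above (both files fully sampled):
# build_sample_list(["./train_data/TD_TR/TD500/train_gt_labels.txt",
#                    "./train_data/TD_TR/TR400/gt_labels.txt"], [1.0, 1.0])
```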
......@@ -20,7 +20,7 @@ Architecture:
algorithm: DB
Transform:
Backbone:
name: ResNet
name: ResNet_vd
layers: 50
Neck:
name: DBFPN
......
......@@ -21,7 +21,7 @@ Architecture:
algorithm: FCE
Transform:
Backbone:
name: ResNet
name: ResNet_vd
layers: 50
dcn_stage: [False, True, True, True]
out_indices: [1,2,3]
......
......@@ -20,7 +20,7 @@ Architecture:
algorithm: EAST
Transform:
Backbone:
name: ResNet
name: ResNet_vd
layers: 50
Neck:
name: EASTFPN
......
......@@ -20,7 +20,7 @@ Architecture:
algorithm: PSE
Transform:
Backbone:
name: ResNet
name: ResNet_vd
layers: 50
Neck:
name: FPN
......
......@@ -20,7 +20,7 @@ Architecture:
algorithm: DB
Transform:
Backbone:
name: ResNet
name: ResNet_vd
layers: 18
disable_se: True
Neck:
......
......@@ -17,7 +17,7 @@ Global:
checkpoints:
save_inference_dir:
use_visualdl: False
class_path: ./train_data/wildreceipt/class_list.txt
class_path: &class_path ./train_data/wildreceipt/class_list.txt
infer_img: ./train_data/wildreceipt/1.txt
save_res_path: ./output/sdmgr_kie/predicts_kie.txt
img_scale: [ 1024, 512 ]
......@@ -72,6 +72,7 @@ Train:
order: 'hwc'
- KieLabelEncode: # Class handling label
character_dict_path: ./train_data/wildreceipt/dict.txt
class_path: *class_path
- KieResize:
- ToCHWImage:
- KeepKeys:
......@@ -88,7 +89,6 @@ Eval:
data_dir: ./train_data/wildreceipt
label_file_list:
- ./train_data/wildreceipt/wildreceipt_test.txt
# - /paddle/data/PaddleOCR/train_data/wildreceipt/1.txt
transforms:
- DecodeImage: # load image
img_mode: RGB
......
......@@ -82,7 +82,7 @@ Train:
Eval:
dataset:
name: LMDBDataSet
data_dir: ./train_data/data_lmdb_release/validation/
data_dir: ./train_data/data_lmdb_release/evaluation/
transforms:
- DecodeImage: # load image
img_mode: BGR
......
......@@ -8,7 +8,7 @@ Global:
# evaluation is run every 2000 iterations
eval_batch_step: [0, 2000]
cal_metric_during_train: True
pretrained_model:
pretrained_model: ./pretrain_models/abinet_vl_pretrained
checkpoints:
save_inference_dir:
use_visualdl: False
......@@ -82,7 +82,7 @@ Train:
Eval:
dataset:
name: LMDBDataSet
data_dir: ./train_data/data_lmdb_release/validation/
data_dir: ./train_data/data_lmdb_release/evaluation/
transforms:
- DecodeImage: # load image
img_mode: RGB
......
......@@ -77,7 +77,7 @@ Metric:
Train:
dataset:
name: LMDBDataSet
data_dir: ./train_data/data_lmdb_release/training
data_dir: ./train_data/data_lmdb_release/training/
transforms:
- DecodeImage: # load image
img_mode: BGR
......@@ -97,7 +97,7 @@ Train:
Eval:
dataset:
name: LMDBDataSet
data_dir: ./train_data/data_lmdb_release/validation
data_dir: ./train_data/data_lmdb_release/evaluation/
transforms:
- DecodeImage: # load image
img_mode: BGR
......
......@@ -81,7 +81,7 @@ Train:
Eval:
dataset:
name: LMDBDataSet
data_dir: ./train_data/data_lmdb_release/validation/
data_dir: ./train_data/data_lmdb_release/evaluation/
transforms:
- DecodeImage: # load image
img_mode: BGR
......
Global:
use_gpu: true
epoch_num: 17
log_smooth_window: 20
print_batch_step: 100
save_model_dir: ./output/table_master/
save_epoch_step: 17
eval_batch_step: [0, 6259]
cal_metric_during_train: true
pretrained_model: null
checkpoints:
save_inference_dir: output/table_master/infer
use_visualdl: false
infer_img: ppstructure/docs/table/table.jpg
save_res_path: ./output/table_master
character_dict_path: ppocr/utils/dict/table_master_structure_dict.txt
infer_mode: false
max_text_length: 500
process_total_num: 0
process_cut_num: 0
Optimizer:
name: Adam
beta1: 0.9
beta2: 0.999
lr:
name: MultiStepDecay
learning_rate: 0.001
milestones: [12, 15]
gamma: 0.1
warmup_epoch: 0.02
regularizer:
name: L2
factor: 0.0
Architecture:
model_type: table
algorithm: TableMaster
Backbone:
name: TableResNetExtra
gcb_config:
ratio: 0.0625
headers: 1
att_scale: False
fusion_type: channel_add
layers: [False, True, True, True]  # gcb_config: stages where the global context block is applied
layers: [1,2,5,3]  # backbone: number of residual blocks per stage
Head:
name: TableMasterHead
hidden_size: 512
headers: 8
dropout: 0
d_ff: 2024
max_text_length: 500
Loss:
name: TableMasterLoss
ignore_index: 42 # set to len of dict + 3
PostProcess:
name: TableMasterLabelDecode
box_shape: pad
Metric:
name: TableMetric
main_indicator: acc
compute_bbox_metric: False
Train:
dataset:
name: PubTabDataSet
data_dir: train_data/table/pubtabnet/train/
label_file_list: [train_data/table/pubtabnet/PubTabNet_2.0.0_train.jsonl]
transforms:
- DecodeImage:
img_mode: BGR
channel_first: False
- TableMasterLabelEncode:
learn_empty_box: False
merge_no_span_structure: True
replace_empty_cell_token: True
- ResizeTableImage:
max_len: 480
resize_bboxes: True
- PaddingTableImage:
size: [480, 480]
- TableBoxEncode:
use_xywh: True
- NormalizeImage:
scale: 1./255.
mean: [0.5, 0.5, 0.5]
std: [0.5, 0.5, 0.5]
order: hwc
- ToCHWImage: null
- KeepKeys:
keep_keys: [image, structure, bboxes, bbox_masks, shape]
loader:
shuffle: True
batch_size_per_card: 10
drop_last: True
num_workers: 8
Eval:
dataset:
name: PubTabDataSet
data_dir: train_data/table/pubtabnet/train/
label_file_list: [train_data/table/pubtabnet/PubTabNet_2.0.0_val.jsonl]
transforms:
- DecodeImage:
img_mode: BGR
channel_first: False
- TableMasterLabelEncode:
learn_empty_box: False
merge_no_span_structure: True
replace_empty_cell_token: True
- ResizeTableImage:
max_len: 480
resize_bboxes: True
- PaddingTableImage:
size: [480, 480]
- TableBoxEncode:
use_xywh: True
- NormalizeImage:
scale: 1./255.
mean: [0.5, 0.5, 0.5]
std: [0.5, 0.5, 0.5]
order: hwc
- ToCHWImage: null
- KeepKeys:
keep_keys: [image, structure, bboxes, bbox_masks, shape]
loader:
shuffle: False
drop_last: False
batch_size_per_card: 10
num_workers: 8
\ No newline at end of file
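The comment on `ignore_index` above ("set to len of dict + 3") can be verified against the structure dictionary referenced in `Global`; the `+ 3` presumably accounts for special tokens such as start/end/pad (an assumption on my part):

```python
# Recompute ignore_index from the dictionary used by this config.
dict_path = "ppocr/utils/dict/table_master_structure_dict.txt"
with open(dict_path, "r", encoding="utf-8") as f:
    num_entries = len(f.readlines())
print(num_entries + 3)  # should print 42, matching ignore_index above
```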
......@@ -4,21 +4,20 @@ Global:
log_smooth_window: 20
print_batch_step: 5
save_model_dir: ./output/table_mv3/
save_epoch_step: 3
save_epoch_step: 400
# evaluation is run every 400 iterations after the 0th iteration
eval_batch_step: [0, 400]
cal_metric_during_train: True
pretrained_model:
checkpoints:
save_inference_dir:
use_visualdl: False
infer_img: doc/table/table.jpg
infer_img: ppstructure/docs/table/table.jpg
save_res_path: output/table_mv3
# for data or label process
character_dict_path: ppocr/utils/dict/table_structure_dict.txt
character_type: en
max_text_length: 100
max_elem_length: 800
max_cell_num: 500
max_text_length: 800
infer_mode: False
process_total_num: 0
process_cut_num: 0
......@@ -44,11 +43,8 @@ Architecture:
Head:
name: TableAttentionHead
hidden_size: 256
l2_decay: 0.00001
loc_type: 2
max_text_length: 100
max_elem_length: 800
max_cell_num: 500
max_text_length: 800
Loss:
name: TableAttentionLoss
......@@ -61,28 +57,34 @@ PostProcess:
Metric:
name: TableMetric
main_indicator: acc
compute_bbox_metric: false # costs a lot of time; set to False for training
Train:
dataset:
name: PubTabDataSet
data_dir: train_data/table/pubtabnet/train/
label_file_path: train_data/table/pubtabnet/PubTabNet_2.0.0_train.jsonl
label_file_list: [train_data/table/pubtabnet/PubTabNet_2.0.0_train.jsonl]
transforms:
- DecodeImage: # load image
img_mode: BGR
channel_first: False
- TableLabelEncode:
learn_empty_box: False
merge_no_span_structure: False
replace_empty_cell_token: False
- TableBoxEncode:
- ResizeTableImage:
max_len: 488
- TableLabelEncode:
- NormalizeImage:
scale: 1./255.
mean: [0.485, 0.456, 0.406]
std: [0.229, 0.224, 0.225]
order: 'hwc'
- PaddingTableImage:
size: [488, 488]
- ToCHWImage:
- KeepKeys:
keep_keys: ['image', 'structure', 'bbox_list', 'sp_tokens', 'bbox_list_mask']
keep_keys: [ 'image', 'structure', 'bboxes', 'bbox_masks', 'shape' ]
loader:
shuffle: True
batch_size_per_card: 32
......@@ -92,24 +94,29 @@ Train:
Eval:
dataset:
name: PubTabDataSet
data_dir: train_data/table/pubtabnet/val/
label_file_path: train_data/table/pubtabnet/PubTabNet_2.0.0_val.jsonl
data_dir: train_data/table/pubtabnet/val/
label_file_list: [train_data/table/pubtabnet/PubTabNet_2.0.0_val.jsonl]
transforms:
- DecodeImage: # load image
img_mode: BGR
channel_first: False
- TableLabelEncode:
learn_empty_box: False
merge_no_span_structure: False
replace_empty_cell_token: False
- TableBoxEncode:
- ResizeTableImage:
max_len: 488
- TableLabelEncode:
- NormalizeImage:
scale: 1./255.
mean: [0.485, 0.456, 0.406]
std: [0.229, 0.224, 0.225]
order: 'hwc'
- PaddingTableImage:
size: [488, 488]
- ToCHWImage:
- KeepKeys:
keep_keys: ['image', 'structure', 'bbox_list', 'sp_tokens', 'bbox_list_mask']
keep_keys: [ 'image', 'structure', 'bboxes', 'bbox_masks', 'shape' ]
loader:
shuffle: False
drop_last: False
......
Global:
use_gpu: True
epoch_num: &epoch_num 200
log_smooth_window: 10
print_batch_step: 10
save_model_dir: ./output/re_layoutlmv2_funsd
save_epoch_step: 2000
# evaluation is run every 10 iterations after the 0th iteration
eval_batch_step: [ 0, 57 ]
cal_metric_during_train: False
save_inference_dir:
use_visualdl: False
seed: 2022
infer_img: train_data/FUNSD/testing_data/images/83624198.png
save_res_path: ./output/re_layoutlmv2_funsd/res/
Architecture:
model_type: vqa
algorithm: &algorithm "LayoutLMv2"
Transform:
Backbone:
name: LayoutLMv2ForRe
pretrained: True
checkpoints:
Loss:
name: LossFromOutput
key: loss
reduction: mean
Optimizer:
name: AdamW
beta1: 0.9
beta2: 0.999
clip_norm: 10
lr:
learning_rate: 0.00005
warmup_epoch: 10
regularizer:
name: L2
factor: 0.00000
PostProcess:
name: VQAReTokenLayoutLMPostProcess
Metric:
name: VQAReTokenMetric
main_indicator: hmean
Train:
dataset:
name: SimpleDataSet
data_dir: ./train_data/FUNSD/training_data/images/
label_file_list:
- ./train_data/FUNSD/train.json
ratio_list: [ 1.0 ]
transforms:
- DecodeImage: # load image
img_mode: RGB
channel_first: False
- VQATokenLabelEncode: # Class handling label
contains_re: True
algorithm: *algorithm
class_path: &class_path train_data/FUNSD/class_list.txt
- VQATokenPad:
max_seq_len: &max_seq_len 512
return_attention_mask: True
- VQAReTokenRelation:
- VQAReTokenChunk:
max_seq_len: *max_seq_len
- Resize:
size: [224,224]
- NormalizeImage:
scale: 1./255.
mean: [0.485, 0.456, 0.406]
std: [0.229, 0.224, 0.225]
order: 'hwc'
- ToCHWImage:
- KeepKeys:
# dataloader will return list in this order
keep_keys: [ 'input_ids', 'bbox', 'attention_mask', 'token_type_ids', 'image', 'entities', 'relations']
loader:
shuffle: True
drop_last: False
batch_size_per_card: 8
num_workers: 8
collate_fn: ListCollator
Eval:
dataset:
name: SimpleDataSet
data_dir: ./train_data/FUNSD/testing_data/images/
label_file_list:
- ./train_data/FUNSD/test.json
transforms:
- DecodeImage: # load image
img_mode: RGB
channel_first: False
- VQATokenLabelEncode: # Class handling label
contains_re: True
algorithm: *algorithm
class_path: *class_path
- VQATokenPad:
max_seq_len: *max_seq_len
return_attention_mask: True
- VQAReTokenRelation:
- VQAReTokenChunk:
max_seq_len: *max_seq_len
- Resize:
size: [224,224]
- NormalizeImage:
scale: 1./255.
mean: [0.485, 0.456, 0.406]
std: [0.229, 0.224, 0.225]
order: 'hwc'
- ToCHWImage:
- KeepKeys:
# dataloader will return list in this order
keep_keys: [ 'input_ids', 'bbox', 'attention_mask', 'token_type_ids', 'image', 'entities', 'relations']
loader:
shuffle: False
drop_last: False
batch_size_per_card: 8
num_workers: 8
collate_fn: ListCollator
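These VQA configs rely heavily on YAML anchors (`&epoch_num`, `&algorithm`, `&max_seq_len`, `&class_path`) and aliases (`*...`) so that a value is declared once and reused. A minimal illustration of how any YAML loader resolves them (requires PyYAML):

```python
import yaml  # pip install pyyaml

snippet = """
Global:
  epoch_num: &epoch_num 200
Optimizer:
  lr:
    epochs: *epoch_num
"""
cfg = yaml.safe_load(snippet)
assert cfg["Optimizer"]["lr"]["epochs"] == cfg["Global"]["epoch_num"] == 200
```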
......@@ -3,16 +3,16 @@ Global:
epoch_num: &epoch_num 200
log_smooth_window: 10
print_batch_step: 10
save_model_dir: ./output/re_layoutlmv2/
save_model_dir: ./output/re_layoutlmv2_xfund_zh
save_epoch_step: 2000
# evaluation is run every 10 iterations after the 0th iteration
eval_batch_step: [ 0, 19 ]
eval_batch_step: [ 0, 57 ]
cal_metric_during_train: False
save_inference_dir:
use_visualdl: False
seed: 2048
infer_img: doc/vqa/input/zh_val_21.jpg
save_res_path: ./output/re/
infer_img: ppstructure/docs/vqa/input/zh_val_21.jpg
save_res_path: ./output/re_layoutlmv2_xfund_zh/res/
Architecture:
model_type: vqa
......@@ -21,7 +21,7 @@ Architecture:
Backbone:
name: LayoutLMv2ForRe
pretrained: True
checkpoints:
Loss:
name: LossFromOutput
......@@ -52,7 +52,7 @@ Train:
name: SimpleDataSet
data_dir: train_data/XFUND/zh_train/image
label_file_list:
- train_data/XFUND/zh_train/xfun_normalize_train.json
- train_data/XFUND/zh_train/train.json
ratio_list: [ 1.0 ]
transforms:
- DecodeImage: # load image
......@@ -61,7 +61,7 @@ Train:
- VQATokenLabelEncode: # Class handling label
contains_re: True
algorithm: *algorithm
class_path: &class_path ppstructure/vqa/labels/labels_ser.txt
class_path: &class_path train_data/XFUND/class_list_xfun.txt
- VQATokenPad:
max_seq_len: &max_seq_len 512
return_attention_mask: True
......@@ -77,7 +77,7 @@ Train:
order: 'hwc'
- ToCHWImage:
- KeepKeys:
keep_keys: [ 'input_ids', 'bbox', 'image', 'attention_mask', 'token_type_ids','entities', 'relations'] # dataloader will return list in this order
keep_keys: [ 'input_ids', 'bbox', 'attention_mask', 'token_type_ids','image', 'entities', 'relations'] # dataloader will return list in this order
loader:
shuffle: True
drop_last: False
......@@ -90,7 +90,7 @@ Eval:
name: SimpleDataSet
data_dir: train_data/XFUND/zh_val/image
label_file_list:
- train_data/XFUND/zh_val/xfun_normalize_val.json
- train_data/XFUND/zh_val/val.json
transforms:
- DecodeImage: # load image
img_mode: RGB
......@@ -114,7 +114,7 @@ Eval:
order: 'hwc'
- ToCHWImage:
- KeepKeys:
keep_keys: [ 'input_ids', 'bbox', 'image', 'attention_mask', 'token_type_ids','entities', 'relations'] # dataloader will return list in this order
keep_keys: [ 'input_ids', 'bbox', 'attention_mask', 'token_type_ids', 'image','entities', 'relations'] # dataloader will return list in this order
loader:
shuffle: False
drop_last: False
......
Global:
use_gpu: True
epoch_num: &epoch_num 200
log_smooth_window: 10
print_batch_step: 10
save_model_dir: ./output/re_layoutxlm_funsd
save_epoch_step: 2000
# evaluation is run every 10 iterations after the 0th iteration
eval_batch_step: [ 0, 57 ]
cal_metric_during_train: False
save_inference_dir:
use_visualdl: False
seed: 2022
infer_img: train_data/FUNSD/testing_data/images/83624198.png
save_res_path: ./output/re_layoutxlm_funsd/res/
Architecture:
model_type: vqa
algorithm: &algorithm "LayoutXLM"
Transform:
Backbone:
name: LayoutXLMForRe
pretrained: True
checkpoints:
Loss:
name: LossFromOutput
key: loss
reduction: mean
Optimizer:
name: AdamW
beta1: 0.9
beta2: 0.999
clip_norm: 10
lr:
learning_rate: 0.00005
warmup_epoch: 10
regularizer:
name: L2
factor: 0.00000
PostProcess:
name: VQAReTokenLayoutLMPostProcess
Metric:
name: VQAReTokenMetric
main_indicator: hmean
Train:
dataset:
name: SimpleDataSet
data_dir: ./train_data/FUNSD/training_data/images/
label_file_list:
- ./train_data/FUNSD/train_v4.json
# - ./train_data/FUNSD/train.json
ratio_list: [ 1.0 ]
transforms:
- DecodeImage: # load image
img_mode: RGB
channel_first: False
- VQATokenLabelEncode: # Class handling label
contains_re: True
algorithm: *algorithm
class_path: &class_path ./train_data/FUNSD/class_list.txt
use_textline_bbox_info: &use_textline_bbox_info True
- VQATokenPad:
max_seq_len: &max_seq_len 512
return_attention_mask: True
- VQAReTokenRelation:
- VQAReTokenChunk:
max_seq_len: *max_seq_len
- Resize:
size: [224,224]
- NormalizeImage:
scale: 1
mean: [ 123.675, 116.28, 103.53 ]
std: [ 58.395, 57.12, 57.375 ]
order: 'hwc'
- ToCHWImage:
- KeepKeys:
# dataloader will return list in this order
keep_keys: [ 'input_ids', 'bbox', 'attention_mask', 'token_type_ids', 'image', 'entities', 'relations']
loader:
shuffle: False
drop_last: False
batch_size_per_card: 8
num_workers: 16
collate_fn: ListCollator
Eval:
dataset:
name: SimpleDataSet
data_dir: ./train_data/FUNSD/testing_data/images/
label_file_list:
- ./train_data/FUNSD/test_v4.json
# - ./train_data/FUNSD/test.json
transforms:
- DecodeImage: # load image
img_mode: RGB
channel_first: False
- VQATokenLabelEncode: # Class handling label
contains_re: True
algorithm: *algorithm
class_path: *class_path
use_textline_bbox_info: *use_textline_bbox_info
- VQATokenPad:
max_seq_len: *max_seq_len
return_attention_mask: True
- VQAReTokenRelation:
- VQAReTokenChunk:
max_seq_len: *max_seq_len
- Resize:
size: [224,224]
- NormalizeImage:
scale: 1
mean: [ 123.675, 116.28, 103.53 ]
std: [ 58.395, 57.12, 57.375 ]
order: 'hwc'
- ToCHWImage:
- KeepKeys:
# dataloader will return list in this order
keep_keys: [ 'input_ids', 'bbox', 'attention_mask', 'token_type_ids', 'image', 'entities', 'relations']
loader:
shuffle: False
drop_last: False
batch_size_per_card: 8
num_workers: 8
collate_fn: ListCollator
......@@ -11,7 +11,7 @@ Global:
save_inference_dir:
use_visualdl: False
seed: 2022
infer_img: doc/vqa/input/zh_val_21.jpg
infer_img: ppstructure/docs/vqa/input/zh_val_21.jpg
save_res_path: ./output/re/
Architecture:
......@@ -52,7 +52,7 @@ Train:
name: SimpleDataSet
data_dir: train_data/XFUND/zh_train/image
label_file_list:
- train_data/XFUND/zh_train/xfun_normalize_train.json
- train_data/XFUND/zh_train/train.json
ratio_list: [ 1.0 ]
transforms:
- DecodeImage: # load image
......@@ -61,7 +61,7 @@ Train:
- VQATokenLabelEncode: # Class handling label
contains_re: True
algorithm: *algorithm
class_path: &class_path ppstructure/vqa/labels/labels_ser.txt
class_path: &class_path train_data/XFUND/class_list_xfun.txt
- VQATokenPad:
max_seq_len: &max_seq_len 512
return_attention_mask: True
......@@ -77,7 +77,7 @@ Train:
order: 'hwc'
- ToCHWImage:
- KeepKeys:
keep_keys: [ 'input_ids', 'bbox', 'image', 'attention_mask', 'token_type_ids','entities', 'relations'] # dataloader will return list in this order
keep_keys: [ 'input_ids', 'bbox','attention_mask', 'token_type_ids', 'image', 'entities', 'relations'] # dataloader will return list in this order
loader:
shuffle: True
drop_last: False
......@@ -90,7 +90,7 @@ Eval:
name: SimpleDataSet
data_dir: train_data/XFUND/zh_val/image
label_file_list:
- train_data/XFUND/zh_val/xfun_normalize_val.json
- train_data/XFUND/zh_val/val.json
transforms:
- DecodeImage: # load image
img_mode: RGB
......@@ -114,7 +114,7 @@ Eval:
order: 'hwc'
- ToCHWImage:
- KeepKeys:
keep_keys: [ 'input_ids', 'bbox', 'image', 'attention_mask', 'token_type_ids','entities', 'relations'] # dataloader will return list in this order
keep_keys: [ 'input_ids', 'bbox', 'attention_mask', 'token_type_ids', 'image', 'entities', 'relations'] # dataloader will return list in this order
loader:
shuffle: False
drop_last: False
......
Global:
use_gpu: True
epoch_num: &epoch_num 200
log_smooth_window: 10
print_batch_step: 10
save_model_dir: ./output/ser_layoutlm_funsd
save_epoch_step: 2000
# evaluation is run every 10 iterations after the 0th iteration
eval_batch_step: [ 0, 57 ]
cal_metric_during_train: False
save_inference_dir:
use_visualdl: False
seed: 2022
infer_img: train_data/FUNSD/testing_data/images/83624198.png
save_res_path: ./output/ser_layoutlm_funsd/res/
Architecture:
model_type: vqa
algorithm: &algorithm "LayoutLM"
Transform:
Backbone:
name: LayoutLMForSer
pretrained: True
checkpoints:
num_classes: &num_classes 7
Loss:
name: VQASerTokenLayoutLMLoss
num_classes: *num_classes
Optimizer:
name: AdamW
beta1: 0.9
beta2: 0.999
lr:
name: Linear
learning_rate: 0.00005
epochs: *epoch_num
warmup_epoch: 2
regularizer:
name: L2
factor: 0.00000
PostProcess:
name: VQASerTokenLayoutLMPostProcess
class_path: &class_path ./train_data/FUNSD/class_list.txt
Metric:
name: VQASerTokenMetric
main_indicator: hmean
Train:
dataset:
name: SimpleDataSet
data_dir: ./train_data/FUNSD/training_data/images/
label_file_list:
- ./train_data/FUNSD/train.json
transforms:
- DecodeImage: # load image
img_mode: RGB
channel_first: False
- VQATokenLabelEncode: # Class handling label
contains_re: False
algorithm: *algorithm
class_path: *class_path
use_textline_bbox_info: &use_textline_bbox_info True
- VQATokenPad:
max_seq_len: &max_seq_len 512
return_attention_mask: True
- VQASerTokenChunk:
max_seq_len: *max_seq_len
- Resize:
size: [224,224]
- NormalizeImage:
scale: 1
mean: [ 123.675, 116.28, 103.53 ]
std: [ 58.395, 57.12, 57.375 ]
order: 'hwc'
- ToCHWImage:
- KeepKeys:
# dataloader will return list in this order
keep_keys: [ 'input_ids', 'bbox', 'attention_mask', 'token_type_ids', 'image', 'labels']
loader:
shuffle: True
drop_last: False
batch_size_per_card: 8
num_workers: 4
Eval:
dataset:
name: SimpleDataSet
data_dir: train_data/FUNSD/testing_data/images/
label_file_list:
- ./train_data/FUNSD/test.json
transforms:
- DecodeImage: # load image
img_mode: RGB
channel_first: False
- VQATokenLabelEncode: # Class handling label
contains_re: False
algorithm: *algorithm
class_path: *class_path
use_textline_bbox_info: *use_textline_bbox_info
- VQATokenPad:
max_seq_len: *max_seq_len
return_attention_mask: True
- VQASerTokenChunk:
max_seq_len: *max_seq_len
- Resize:
size: [224,224]
- NormalizeImage:
scale: 1
mean: [ 123.675, 116.28, 103.53 ]
std: [ 58.395, 57.12, 57.375 ]
order: 'hwc'
- ToCHWImage:
- KeepKeys:
# dataloader will return list in this order
keep_keys: [ 'input_ids', 'bbox', 'attention_mask', 'token_type_ids', 'image', 'labels']
loader:
shuffle: False
drop_last: False
batch_size_per_card: 8
num_workers: 4
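Note that `NormalizeImage` here uses `scale: 1` with mean/std expressed on the 0–255 pixel scale; these are simply the standard ImageNet statistics multiplied by 255, which a quick check confirms:

```python
imagenet_mean = [0.485, 0.456, 0.406]
imagenet_std = [0.229, 0.224, 0.225]
print([round(m * 255, 3) for m in imagenet_mean])  # [123.675, 116.28, 103.53]
print([round(s * 255, 3) for s in imagenet_std])   # [58.395, 57.12, 57.375]
```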
Global:
use_gpu: True
epoch_num: &epoch_num 200
log_smooth_window: 10
print_batch_step: 10
save_model_dir: ./output/ser_layoutlm_sroie
save_epoch_step: 2000
# evaluation is run every 10 iterations after the 0th iteration
eval_batch_step: [ 0, 200 ]
cal_metric_during_train: False
save_inference_dir:
use_visualdl: False
seed: 2022
infer_img: train_data/SROIE/test/X00016469670.jpg
save_res_path: ./output/ser_layoutlm_sroie/res/
Architecture:
model_type: vqa
algorithm: &algorithm "LayoutLM"
Transform:
Backbone:
name: LayoutLMForSer
pretrained: True
checkpoints:
num_classes: &num_classes 9
Loss:
name: VQASerTokenLayoutLMLoss
num_classes: *num_classes
Optimizer:
name: AdamW
beta1: 0.9
beta2: 0.999
lr:
name: Linear
learning_rate: 0.00005
epochs: *epoch_num
warmup_epoch: 2
regularizer:
name: L2
factor: 0.00000
PostProcess:
name: VQASerTokenLayoutLMPostProcess
class_path: &class_path ./train_data/SROIE/class_list.txt
Metric:
name: VQASerTokenMetric
main_indicator: hmean
Train:
dataset:
name: SimpleDataSet
data_dir: ./train_data/SROIE/train
label_file_list:
- ./train_data/SROIE/train.txt
transforms:
- DecodeImage: # load image
img_mode: RGB
channel_first: False
- VQATokenLabelEncode: # Class handling label
contains_re: False
algorithm: *algorithm
class_path: *class_path
use_textline_bbox_info: &use_textline_bbox_info True
- VQATokenPad:
max_seq_len: &max_seq_len 512
return_attention_mask: True
- VQASerTokenChunk:
max_seq_len: *max_seq_len
- Resize:
size: [224,224]
- NormalizeImage:
scale: 1
mean: [ 123.675, 116.28, 103.53 ]
std: [ 58.395, 57.12, 57.375 ]
order: 'hwc'
- ToCHWImage:
- KeepKeys:
# dataloader will return list in this order
keep_keys: [ 'input_ids', 'bbox', 'attention_mask', 'token_type_ids', 'image', 'labels']
loader:
shuffle: True
drop_last: False
batch_size_per_card: 8
num_workers: 4
Eval:
dataset:
name: SimpleDataSet
data_dir: ./train_data/SROIE/test
label_file_list:
- ./train_data/SROIE/test.txt
transforms:
- DecodeImage: # load image
img_mode: RGB
channel_first: False
- VQATokenLabelEncode: # Class handling label
contains_re: False
algorithm: *algorithm
class_path: *class_path
use_textline_bbox_info: *use_textline_bbox_info
- VQATokenPad:
max_seq_len: *max_seq_len
return_attention_mask: True
- VQASerTokenChunk:
max_seq_len: *max_seq_len
- Resize:
size: [224,224]
- NormalizeImage:
scale: 1
mean: [ 123.675, 116.28, 103.53 ]
std: [ 58.395, 57.12, 57.375 ]
order: 'hwc'
- ToCHWImage:
- KeepKeys:
# dataloader will return list in this order
keep_keys: [ 'input_ids', 'bbox', 'attention_mask', 'token_type_ids', 'image', 'labels']
loader:
shuffle: False
drop_last: False
batch_size_per_card: 8
num_workers: 4
......@@ -3,16 +3,16 @@ Global:
epoch_num: &epoch_num 200
log_smooth_window: 10
print_batch_step: 10
save_model_dir: ./output/ser_layoutlm/
save_model_dir: ./output/ser_layoutlm_xfund_zh
save_epoch_step: 2000
# evaluation is run every 10 iterations after the 0th iteration
eval_batch_step: [ 0, 19 ]
eval_batch_step: [ 0, 57 ]
cal_metric_during_train: False
save_inference_dir:
use_visualdl: False
seed: 2022
infer_img: doc/vqa/input/zh_val_0.jpg
save_res_path: ./output/ser/
infer_img: ppstructure/docs/vqa/input/zh_val_42.jpg
save_res_path: ./output/ser_layoutlm_xfund_zh/res/
Architecture:
model_type: vqa
......@@ -43,7 +43,7 @@ Optimizer:
PostProcess:
name: VQASerTokenLayoutLMPostProcess
class_path: &class_path ppstructure/vqa/labels/labels_ser.txt
class_path: &class_path train_data/XFUND/class_list_xfun.txt
Metric:
name: VQASerTokenMetric
......@@ -54,7 +54,7 @@ Train:
name: SimpleDataSet
data_dir: train_data/XFUND/zh_train/image
label_file_list:
- train_data/XFUND/zh_train/xfun_normalize_train.json
- train_data/XFUND/zh_train/train.json
transforms:
- DecodeImage: # load image
img_mode: RGB
......@@ -77,7 +77,7 @@ Train:
order: 'hwc'
- ToCHWImage:
- KeepKeys:
keep_keys: [ 'input_ids','labels', 'bbox', 'image', 'attention_mask', 'token_type_ids'] # dataloader will return list in this order
keep_keys: [ 'input_ids', 'bbox', 'attention_mask', 'token_type_ids', 'image', 'labels'] # dataloader will return list in this order
loader:
shuffle: True
drop_last: False
......@@ -89,7 +89,7 @@ Eval:
name: SimpleDataSet
data_dir: train_data/XFUND/zh_val/image
label_file_list:
- train_data/XFUND/zh_val/xfun_normalize_val.json
- train_data/XFUND/zh_val/val.json
transforms:
- DecodeImage: # load image
img_mode: RGB
......@@ -112,7 +112,7 @@ Eval:
order: 'hwc'
- ToCHWImage:
- KeepKeys:
keep_keys: [ 'input_ids', 'labels', 'bbox', 'image', 'attention_mask', 'token_type_ids'] # dataloader will return list in this order
keep_keys: [ 'input_ids', 'bbox', 'attention_mask', 'token_type_ids', 'image', 'labels'] # dataloader will return list in this order
loader:
shuffle: False
drop_last: False
......
Global:
use_gpu: True
epoch_num: &epoch_num 200
log_smooth_window: 10
print_batch_step: 10
save_model_dir: ./output/ser_layoutlmv2_funsd
save_epoch_step: 2000
# evaluation is run every 10 iterations after the 0th iteration
eval_batch_step: [ 0, 100 ]
cal_metric_during_train: False
save_inference_dir:
use_visualdl: False
seed: 2022
infer_img: train_data/FUNSD/testing_data/images/83624198.png
save_res_path: ./output/ser_layoutlmv2_funsd/res/
Architecture:
model_type: vqa
algorithm: &algorithm "LayoutLMv2"
Transform:
Backbone:
name: LayoutLMv2ForSer
pretrained: True
checkpoints:
num_classes: &num_classes 7
Loss:
name: VQASerTokenLayoutLMLoss
num_classes: *num_classes
Optimizer:
name: AdamW
beta1: 0.9
beta2: 0.999
lr:
name: Linear
learning_rate: 0.00005
epochs: *epoch_num
warmup_epoch: 2
regularizer:
name: L2
factor: 0.00000
PostProcess:
name: VQASerTokenLayoutLMPostProcess
class_path: &class_path train_data/FUNSD/class_list.txt
Metric:
name: VQASerTokenMetric
main_indicator: hmean
Train:
dataset:
name: SimpleDataSet
data_dir: ./train_data/FUNSD/training_data/images/
label_file_list:
- ./train_data/FUNSD/train.json
transforms:
- DecodeImage: # load image
img_mode: RGB
channel_first: False
- VQATokenLabelEncode: # Class handling label
contains_re: False
algorithm: *algorithm
class_path: *class_path
- VQATokenPad:
max_seq_len: &max_seq_len 512
return_attention_mask: True
- VQASerTokenChunk:
max_seq_len: *max_seq_len
- Resize:
size: [224,224]
- NormalizeImage:
scale: 1
mean: [ 123.675, 116.28, 103.53 ]
std: [ 58.395, 57.12, 57.375 ]
order: 'hwc'
- ToCHWImage:
- KeepKeys:
# dataloader will return list in this order
keep_keys: [ 'input_ids', 'bbox', 'attention_mask', 'token_type_ids', 'image', 'labels']
loader:
shuffle: True
drop_last: False
batch_size_per_card: 8
num_workers: 4
Eval:
dataset:
name: SimpleDataSet
data_dir: ./train_data/FUNSD/testing_data/images/
label_file_list:
- ./train_data/FUNSD/test.json
transforms:
- DecodeImage: # load image
img_mode: RGB
channel_first: False
- VQATokenLabelEncode: # Class handling label
contains_re: False
algorithm: *algorithm
class_path: *class_path
- VQATokenPad:
max_seq_len: *max_seq_len
return_attention_mask: True
- VQASerTokenChunk:
max_seq_len: *max_seq_len
- Resize:
size: [224,224]
- NormalizeImage:
scale: 1
mean: [ 123.675, 116.28, 103.53 ]
std: [ 58.395, 57.12, 57.375 ]
order: 'hwc'
- ToCHWImage:
- KeepKeys:
# dataloader will return list in this order
keep_keys: [ 'input_ids', 'bbox', 'attention_mask', 'token_type_ids', 'image', 'labels']
loader:
shuffle: False
drop_last: False
batch_size_per_card: 8
num_workers: 4
Global:
use_gpu: True
epoch_num: &epoch_num 200
log_smooth_window: 10
print_batch_step: 10
save_model_dir: ./output/ser_layoutlmv2_sroie
save_epoch_step: 2000
# evaluation is run every 10 iterations after the 0th iteration
eval_batch_step: [ 0, 200 ]
cal_metric_during_train: False
save_inference_dir:
use_visualdl: False
seed: 2022
infer_img: train_data/SROIE/test/X00016469670.jpg
save_res_path: ./output/ser_layoutlmv2_sroie/res/
Architecture:
model_type: vqa
algorithm: &algorithm "LayoutLMv2"
Transform:
Backbone:
name: LayoutLMv2ForSer
pretrained: True
checkpoints:
num_classes: &num_classes 9
Loss:
name: VQASerTokenLayoutLMLoss
num_classes: *num_classes
Optimizer:
name: AdamW
beta1: 0.9
beta2: 0.999
lr:
name: Linear
learning_rate: 0.00005
epochs: *epoch_num
warmup_epoch: 2
regularizer:
name: L2
factor: 0.00000
PostProcess:
name: VQASerTokenLayoutLMPostProcess
class_path: &class_path ./train_data/SROIE/class_list.txt
Metric:
name: VQASerTokenMetric
main_indicator: hmean
Train:
dataset:
name: SimpleDataSet
data_dir: ./train_data/SROIE/train
label_file_list:
- ./train_data/SROIE/train.txt
transforms:
- DecodeImage: # load image
img_mode: RGB
channel_first: False
- VQATokenLabelEncode: # Class handling label
contains_re: False
algorithm: *algorithm
class_path: *class_path
- VQATokenPad:
max_seq_len: &max_seq_len 512
return_attention_mask: True
- VQASerTokenChunk:
max_seq_len: *max_seq_len
- Resize:
size: [224,224]
- NormalizeImage:
scale: 1
mean: [ 123.675, 116.28, 103.53 ]
std: [ 58.395, 57.12, 57.375 ]
order: 'hwc'
- ToCHWImage:
- KeepKeys:
# dataloader will return list in this order
keep_keys: [ 'input_ids', 'bbox', 'attention_mask', 'token_type_ids', 'image', 'labels']
loader:
shuffle: True
drop_last: False
batch_size_per_card: 8
num_workers: 4
Eval:
dataset:
name: SimpleDataSet
data_dir: ./train_data/SROIE/test
label_file_list:
- ./train_data/SROIE/test.txt
transforms:
- DecodeImage: # load image
img_mode: RGB
channel_first: False
- VQATokenLabelEncode: # Class handling label
contains_re: False
algorithm: *algorithm
class_path: *class_path
- VQATokenPad:
max_seq_len: *max_seq_len
return_attention_mask: True
- VQASerTokenChunk:
max_seq_len: *max_seq_len
- Resize:
size: [224,224]
- NormalizeImage:
scale: 1
mean: [ 123.675, 116.28, 103.53 ]
std: [ 58.395, 57.12, 57.375 ]
order: 'hwc'
- ToCHWImage:
- KeepKeys:
# dataloader will return list in this order
keep_keys: [ 'input_ids', 'bbox', 'attention_mask', 'token_type_ids', 'image', 'labels']
loader:
shuffle: False
drop_last: False
batch_size_per_card: 8
num_workers: 4
......@@ -3,7 +3,7 @@ Global:
epoch_num: &epoch_num 200
log_smooth_window: 10
print_batch_step: 10
save_model_dir: ./output/ser_layoutlmv2/
save_model_dir: ./output/ser_layoutlmv2_xfund_zh/
save_epoch_step: 2000
# evaluation is run every 10 iterations after the 0th iteration
eval_batch_step: [ 0, 19 ]
......@@ -11,8 +11,8 @@ Global:
save_inference_dir:
use_visualdl: False
seed: 2022
infer_img: doc/vqa/input/zh_val_0.jpg
save_res_path: ./output/ser/
infer_img: ppstructure/docs/vqa/input/zh_val_42.jpg
save_res_path: ./output/ser_layoutlmv2_xfund_zh/res/
Architecture:
model_type: vqa
......@@ -44,7 +44,7 @@ Optimizer:
PostProcess:
name: VQASerTokenLayoutLMPostProcess
class_path: &class_path ppstructure/vqa/labels/labels_ser.txt
class_path: &class_path train_data/XFUND/class_list_xfun.txt
Metric:
name: VQASerTokenMetric
......@@ -55,7 +55,7 @@ Train:
name: SimpleDataSet
data_dir: train_data/XFUND/zh_train/image
label_file_list:
- train_data/XFUND/zh_train/xfun_normalize_train.json
- train_data/XFUND/zh_train/train.json
transforms:
- DecodeImage: # load image
img_mode: RGB
......@@ -78,7 +78,7 @@ Train:
order: 'hwc'
- ToCHWImage:
- KeepKeys:
keep_keys: [ 'input_ids','labels', 'bbox', 'image', 'attention_mask', 'token_type_ids'] # dataloader will return list in this order
keep_keys: [ 'input_ids', 'bbox', 'attention_mask', 'token_type_ids', 'image', 'labels'] # dataloader will return list in this order
loader:
shuffle: True
drop_last: False
......@@ -90,7 +90,7 @@ Eval:
name: SimpleDataSet
data_dir: train_data/XFUND/zh_val/image
label_file_list:
- train_data/XFUND/zh_val/xfun_normalize_val.json
- train_data/XFUND/zh_val/val.json
transforms:
- DecodeImage: # load image
img_mode: RGB
......@@ -113,7 +113,7 @@ Eval:
order: 'hwc'
- ToCHWImage:
- KeepKeys:
keep_keys: [ 'input_ids', 'labels', 'bbox', 'image', 'attention_mask', 'token_type_ids'] # dataloader will return list in this order
keep_keys: [ 'input_ids', 'bbox', 'attention_mask', 'token_type_ids', 'image', 'labels'] # dataloader will return list in this order
loader:
shuffle: False
drop_last: False
......
Global:
use_gpu: True
epoch_num: &epoch_num 200
log_smooth_window: 10
print_batch_step: 10
save_model_dir: ./output/ser_layoutxlm_funsd
save_epoch_step: 2000
# evaluation is run every 10 iterations after the 0th iteration
eval_batch_step: [ 0, 57 ]
cal_metric_during_train: False
save_inference_dir:
use_visualdl: False
seed: 2022
infer_img: train_data/FUNSD/testing_data/images/83624198.png
save_res_path: output/ser_layoutxlm_funsd/res/
Architecture:
model_type: vqa
algorithm: &algorithm "LayoutXLM"
Transform:
Backbone:
name: LayoutXLMForSer
pretrained: True
checkpoints:
num_classes: &num_classes 7
Loss:
name: VQASerTokenLayoutLMLoss
num_classes: *num_classes
Optimizer:
name: AdamW
beta1: 0.9
beta2: 0.999
lr:
name: Linear
learning_rate: 0.00005
epochs: *epoch_num
warmup_epoch: 2
regularizer:
name: L2
factor: 0.00000
PostProcess:
name: VQASerTokenLayoutLMPostProcess
class_path: &class_path ./train_data/FUNSD/class_list.txt
Metric:
name: VQASerTokenMetric
main_indicator: hmean
Train:
dataset:
name: SimpleDataSet
data_dir: ./train_data/FUNSD/training_data/images/
label_file_list:
- ./train_data/FUNSD/train.json
ratio_list: [ 1.0 ]
transforms:
- DecodeImage: # load image
img_mode: RGB
channel_first: False
- VQATokenLabelEncode: # Class handling label
contains_re: False
algorithm: *algorithm
class_path: *class_path
- VQATokenPad:
max_seq_len: &max_seq_len 512
return_attention_mask: True
- VQASerTokenChunk:
max_seq_len: *max_seq_len
- Resize:
size: [224,224]
- NormalizeImage:
scale: 1
mean: [ 123.675, 116.28, 103.53 ]
std: [ 58.395, 57.12, 57.375 ]
order: 'hwc'
- ToCHWImage:
- KeepKeys:
# dataloader will return list in this order
keep_keys: [ 'input_ids', 'bbox', 'attention_mask', 'token_type_ids', 'image', 'labels']
loader:
shuffle: True
drop_last: False
batch_size_per_card: 8
num_workers: 4
Eval:
dataset:
name: SimpleDataSet
data_dir: train_data/FUNSD/testing_data/images/
label_file_list:
- ./train_data/FUNSD/test.json
transforms:
- DecodeImage: # load image
img_mode: RGB
channel_first: False
- VQATokenLabelEncode: # Class handling label
contains_re: False
algorithm: *algorithm
class_path: *class_path
- VQATokenPad:
max_seq_len: *max_seq_len
return_attention_mask: True
- VQASerTokenChunk:
max_seq_len: *max_seq_len
- Resize:
size: [224,224]
- NormalizeImage:
scale: 1
mean: [ 123.675, 116.28, 103.53 ]
std: [ 58.395, 57.12, 57.375 ]
order: 'hwc'
- ToCHWImage:
- KeepKeys:
# dataloader will return list in this order
keep_keys: [ 'input_ids', 'bbox', 'attention_mask', 'token_type_ids', 'image', 'labels']
loader:
shuffle: False
drop_last: False
batch_size_per_card: 8
num_workers: 4
Global:
use_gpu: True
epoch_num: &epoch_num 200
log_smooth_window: 10
print_batch_step: 10
save_model_dir: ./output/ser_layoutxlm_sroie
save_epoch_step: 2000
# evaluation is run every 10 iterations after the 0th iteration
eval_batch_step: [ 0, 200 ]
cal_metric_during_train: False
save_inference_dir:
use_visualdl: False
seed: 2022
infer_img: train_data/SROIE/test/X00016469670.jpg
save_res_path: ./output/ser_layoutxlm_sroie/res/
Architecture:
model_type: vqa
algorithm: &algorithm "LayoutXLM"
Transform:
Backbone:
name: LayoutXLMForSer
pretrained: True
checkpoints:
num_classes: &num_classes 9
Loss:
name: VQASerTokenLayoutLMLoss
num_classes: *num_classes
Optimizer:
name: AdamW
beta1: 0.9
beta2: 0.999
lr:
name: Linear
learning_rate: 0.00005
epochs: *epoch_num
warmup_epoch: 2
regularizer:
name: L2
factor: 0.00000
PostProcess:
name: VQASerTokenLayoutLMPostProcess
class_path: &class_path ./train_data/SROIE/class_list.txt
Metric:
name: VQASerTokenMetric
main_indicator: hmean
Train:
dataset:
name: SimpleDataSet
data_dir: ./train_data/SROIE/train
label_file_list:
- ./train_data/SROIE/train.txt
ratio_list: [ 1.0 ]
transforms:
- DecodeImage: # load image
img_mode: RGB
channel_first: False
- VQATokenLabelEncode: # Class handling label
contains_re: False
algorithm: *algorithm
class_path: *class_path
- VQATokenPad:
max_seq_len: &max_seq_len 512
return_attention_mask: True
- VQASerTokenChunk:
max_seq_len: *max_seq_len
- Resize:
size: [224,224]
- NormalizeImage:
scale: 1
mean: [ 123.675, 116.28, 103.53 ]
std: [ 58.395, 57.12, 57.375 ]
order: 'hwc'
- ToCHWImage:
- KeepKeys:
# dataloader will return list in this order
keep_keys: [ 'input_ids', 'bbox', 'attention_mask', 'token_type_ids', 'image', 'labels']
loader:
shuffle: True
drop_last: False
batch_size_per_card: 8
num_workers: 4
Eval:
dataset:
name: SimpleDataSet
data_dir: train_data/SROIE/test
label_file_list:
- ./train_data/SROIE/test.txt
transforms:
- DecodeImage: # load image
img_mode: RGB
channel_first: False
- VQATokenLabelEncode: # Class handling label
contains_re: False
algorithm: *algorithm
class_path: *class_path
- VQATokenPad:
max_seq_len: *max_seq_len
return_attention_mask: True
- VQASerTokenChunk:
max_seq_len: *max_seq_len
- Resize:
size: [224,224]
- NormalizeImage:
scale: 1
mean: [ 123.675, 116.28, 103.53 ]
std: [ 58.395, 57.12, 57.375 ]
order: 'hwc'
- ToCHWImage:
- KeepKeys:
# dataloader will return list in this order
keep_keys: [ 'input_ids', 'bbox', 'attention_mask', 'token_type_ids', 'image', 'labels']
loader:
shuffle: False
drop_last: False
batch_size_per_card: 8
num_workers: 4
Global:
use_gpu: True
epoch_num: &epoch_num 100
log_smooth_window: 10
print_batch_step: 10
save_model_dir: ./output/ser_layoutxlm_wildreceipt
save_epoch_step: 2000
# evaluation is run every 10 iterations after the 0th iteration
eval_batch_step: [ 0, 200 ]
cal_metric_during_train: False
save_inference_dir:
use_visualdl: False
seed: 2022
infer_img: train_data/wildreceipt/image_files/Image_12/10/845be0dd6f5b04866a2042abd28d558032ef2576.jpeg
save_res_path: ./output/ser_layoutxlm_wildreceipt/res
Architecture:
model_type: vqa
algorithm: &algorithm "LayoutXLM"
Transform:
Backbone:
name: LayoutXLMForSer
pretrained: True
checkpoints:
num_classes: &num_classes 51
Loss:
name: VQASerTokenLayoutLMLoss
num_classes: *num_classes
Optimizer:
name: AdamW
beta1: 0.9
beta2: 0.999
lr:
name: Linear
learning_rate: 0.00005
epochs: *epoch_num
warmup_epoch: 2
regularizer:
name: L2
factor: 0.00000
PostProcess:
name: VQASerTokenLayoutLMPostProcess
class_path: &class_path ./train_data/wildreceipt/class_list.txt
Metric:
name: VQASerTokenMetric
main_indicator: hmean
Train:
dataset:
name: SimpleDataSet
data_dir: ./train_data/wildreceipt/
label_file_list:
- ./train_data/wildreceipt/wildreceipt_train.txt
ratio_list: [ 1.0 ]
transforms:
- DecodeImage: # load image
img_mode: RGB
channel_first: False
- VQATokenLabelEncode: # Class handling label
contains_re: False
algorithm: *algorithm
class_path: *class_path
- VQATokenPad:
max_seq_len: &max_seq_len 512
return_attention_mask: True
- VQASerTokenChunk:
max_seq_len: *max_seq_len
- Resize:
size: [224,224]
- NormalizeImage:
scale: 1
mean: [ 123.675, 116.28, 103.53 ]
std: [ 58.395, 57.12, 57.375 ]
order: 'hwc'
- ToCHWImage:
- KeepKeys:
# dataloader will return list in this order
keep_keys: [ 'input_ids', 'bbox', 'attention_mask', 'token_type_ids', 'image', 'labels']
loader:
shuffle: True
drop_last: False
batch_size_per_card: 8
num_workers: 4
Eval:
dataset:
name: SimpleDataSet
data_dir: train_data/wildreceipt
label_file_list:
- ./train_data/wildreceipt/wildreceipt_test.txt
transforms:
- DecodeImage: # load image
img_mode: RGB
channel_first: False
- VQATokenLabelEncode: # Class handling label
contains_re: False
algorithm: *algorithm
class_path: *class_path
- VQATokenPad:
max_seq_len: *max_seq_len
return_attention_mask: True
- VQASerTokenChunk:
max_seq_len: *max_seq_len
- Resize:
size: [224,224]
- NormalizeImage:
scale: 1
mean: [ 123.675, 116.28, 103.53 ]
std: [ 58.395, 57.12, 57.375 ]
order: 'hwc'
- ToCHWImage:
- KeepKeys:
# dataloader will return list in this order
keep_keys: [ 'input_ids', 'bbox', 'attention_mask', 'token_type_ids', 'image', 'labels']
loader:
shuffle: False
drop_last: False
batch_size_per_card: 8
num_workers: 4
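Across these SER configs, `num_classes` appears to follow `2 * len(class_list) - 1` — a B-/I- tag pair per class plus a single "other" tag. That formula is my inference from the 7/9/51 values in the FUNSD, SROIE, and wildreceipt configs, not a documented rule; it can be checked against the class list like so:

```python
# Hedged check of num_classes against class_list.txt for the wildreceipt config.
with open("./train_data/wildreceipt/class_list.txt", "r", encoding="utf-8") as f:
    n = sum(1 for line in f if line.strip())
print(2 * n - 1)  # 51 for a 26-entry class list, matching num_classes above
```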
......@@ -3,7 +3,7 @@ Global:
epoch_num: &epoch_num 200
log_smooth_window: 10
print_batch_step: 10
save_model_dir: ./output/ser_layoutxlm/
save_model_dir: ./output/ser_layoutxlm_xfund_zh
save_epoch_step: 2000
# evaluation is run every 10 iterations after the 0th iteration
eval_batch_step: [ 0, 19 ]
......@@ -11,8 +11,8 @@ Global:
save_inference_dir:
use_visualdl: False
seed: 2022
infer_img: doc/vqa/input/zh_val_42.jpg
save_res_path: ./output/ser
infer_img: ppstructure/docs/vqa/input/zh_val_42.jpg
save_res_path: ./output/ser_layoutxlm_xfund_zh/res
Architecture:
model_type: vqa
......@@ -43,7 +43,7 @@ Optimizer:
PostProcess:
name: VQASerTokenLayoutLMPostProcess
class_path: &class_path ppstructure/vqa/labels/labels_ser.txt
class_path: &class_path train_data/XFUND/class_list_xfun.txt
Metric:
name: VQASerTokenMetric
......@@ -54,7 +54,7 @@ Train:
name: SimpleDataSet
data_dir: train_data/XFUND/zh_train/image
label_file_list:
- train_data/XFUND/zh_train/xfun_normalize_train.json
- train_data/XFUND/zh_train/train.json
ratio_list: [ 1.0 ]
transforms:
- DecodeImage: # load image
......@@ -78,7 +78,7 @@ Train:
order: 'hwc'
- ToCHWImage:
- KeepKeys:
keep_keys: [ 'input_ids','labels', 'bbox', 'image', 'attention_mask', 'token_type_ids'] # dataloader will return list in this order
keep_keys: [ 'input_ids', 'bbox', 'attention_mask', 'token_type_ids', 'image', 'labels'] # dataloader will return list in this order
loader:
shuffle: True
drop_last: False
......@@ -90,7 +90,7 @@ Eval:
name: SimpleDataSet
data_dir: train_data/XFUND/zh_val/image
label_file_list:
- train_data/XFUND/zh_val/xfun_normalize_val.json
- train_data/XFUND/zh_val/val.json
transforms:
- DecodeImage: # load image
img_mode: RGB
......@@ -113,7 +113,7 @@ Eval:
order: 'hwc'
- ToCHWImage:
- KeepKeys:
keep_keys: [ 'input_ids', 'labels', 'bbox', 'image', 'attention_mask', 'token_type_ids'] # dataloader will return list in this order
keep_keys: [ 'input_ids', 'bbox', 'attention_mask', 'token_type_ids', 'image', 'labels'] # dataloader will return list in this order
loader:
shuffle: False
drop_last: False
......
......@@ -15,20 +15,24 @@
<!--- specific language governing permissions and limitations -->
<!--- under the License. -->
Running PaddleOCR text recognition model via TVM on bare metal Arm(R) Cortex(R)-M55 CPU and CMSIS-NN
===============================================================
English | [简体中文](README_ch.md)
This folder contains an example of how to use TVM to run a PaddleOCR model
on bare metal Cortex(R)-M55 CPU and CMSIS-NN.
Running PaddleOCR text recognition model on bare metal Arm(R) Cortex(R)-M55 CPU using Arm Virtual Hardware
======================================================================
Prerequisites
This folder contains an example of how to run a PaddleOCR model on bare metal [Cortex(R)-M55 CPU](https://www.arm.com/products/silicon-ip-cpu/cortex-m/cortex-m55) using [Arm Virtual Hardware](https://www.arm.com/products/development-tools/simulation/virtual-hardware).
Running environment and prerequisites
-------------
If the demo is run in the ci_cpu Docker container provided with TVM, then the following
software will already be installed.
Case 1: If the demo is run in an Arm Virtual Hardware Amazon Machine Image (AMI) instance hosted by [AWS](https://aws.amazon.com/marketplace/pp/prodview-urbpq7yo5va7g?sr=0-1&ref_=beagle&applicationId=AWSMPContessa)/[AWS China](https://awsmarketplace.amazonaws.cn/marketplace/pp/prodview-2y7nefntbmybu), the following software will be installed through the [configure_avh.sh](./configure_avh.sh) script. It is installed automatically when you run the application through the [run_demo.sh](./run_demo.sh) script.
You can refer to this [guide](https://arm-software.github.io/AVH/main/examples/html/MicroSpeech.html#amilaunch) to launch an Arm Virtual Hardware AMI instance.
Case 2: If the demo is run in the [ci_cpu Docker container](https://github.com/apache/tvm/blob/main/docker/Dockerfile.ci_cpu) provided with [TVM](https://github.com/apache/tvm), then the following software will already be installed.
If the demo is not run in the ci_cpu Docker container, then you will need the following:
Case 3: If the demo is not run in the ci_cpu Docker container, then you will need the following:
- Software required to build and run the demo (These can all be installed by running
https://github.com/apache/tvm/blob/main/docker/install/ubuntu_install_ethosu_driver_stack.sh .)
tvm/docker/install/ubuntu_install_ethosu_driver_stack.sh.)
- [Fixed Virtual Platform (FVP) based on Arm(R) Corstone(TM)-300 software](https://developer.arm.com/tools-and-software/open-source-software/arm-platforms-software/arm-ecosystem-fvps)
- [cmake 3.19.5](https://github.com/Kitware/CMake/releases/)
- [GCC toolchain from Arm(R)](https://developer.arm.com/-/media/Files/downloads/gnu-rm/10-2020q4/gcc-arm-none-eabi-10-2020-q4-major-x86_64-linux.tar.bz2)
......@@ -40,19 +44,22 @@ If the demo is not run in the ci_cpu Docker container, then you will need the fo
pip install -r ./requirements.txt
```
In Case 2 and Case 3:
You will need to update your PATH environment variable to include the path to cmake 3.19.5 and the FVP.
For example, if you've installed these in ```/opt/arm```, then you would do the following:
```bash
export PATH=/opt/arm/FVP_Corstone_SSE-300/models/Linux64_GCC-6.4:/opt/arm/cmake/bin:$PATH
```
You will also need TVM which can either be:
- Installed from TLCPack (see [TLCPack](https://tlcpack.ai/))
- Built from source (see [Install from Source](https://tvm.apache.org/docs/install/from_source.html))
- When building from source, the following need to be set in config.cmake:
- set(USE_CMSISNN ON)
- set(USE_MICRO ON)
- set(USE_LLVM ON)
- Installed from TLCPack nightly (see [TLCPack](https://tlcpack.ai/))
You will need to update your PATH environment variable to include the path to cmake 3.19.5 and the FVP.
For example, if you've installed these in ```/opt/arm```, then you would do the following:
```bash
export PATH=/opt/arm/FVP_Corstone_SSE-300/models/Linux64_GCC-6.4:/opt/arm/cmake/bin:$PATH
```
Running the demo application
----------------------------
......@@ -62,6 +69,12 @@ Type the following command to run the bare metal text recognition application ([
./run_demo.sh
```
If you are not able to use an Arm Virtual Hardware Amazon Machine Image (AMI) instance hosted by AWS/AWS China, set the --enable_FVP argument to 1 so that the application runs on local Fixed Virtual Platform (FVP) executables.
```bash
./run_demo.sh --enable_FVP 1
```
If the Ethos(TM)-U platform and/or CMSIS have not been installed in /opt/arm/ethosu, then
the locations for these can be specified as arguments to run_demo.sh, for example:
......@@ -70,13 +83,14 @@ the locations for these can be specified as arguments to run_demo.sh, for exampl
--ethosu_platform_path /home/tvm-user/ethosu/core_platform
```
This will:
Running the demo application with [run_demo.sh](./run_demo.sh) will:
- Set up the running environment by automatically installing the required prerequisites when running in an Arm Virtual Hardware Amazon AMI instance (i.e. when --enable_FVP is not set to 1)
- Download a PaddleOCR text recognition model
- Use tvmc to compile the text recognition model for Cortex(R)-M55 CPU and CMSIS-NN
- Create a C header file inputs.c containing the image data as a C array
- Create a C header file outputs.c containing a C array where the output of inference will be stored
- Build the demo application
- Run the demo application on a Fixed Virtual Platform (FVP) based on Arm(R) Corstone(TM)-300 software
- Run the demo application on Arm Virtual Hardware based on Arm(R) Corstone(TM)-300 software
- The application will report the text on the image and the corresponding score.
Using your own image
......@@ -92,9 +106,11 @@ python3 ./convert_image.py path/to/image
Model description
-----------------
In this demo, the model we use is an English recognition model based on [PP-OCRv3](https://github.com/PaddlePaddle/PaddleOCR/blob/dygraph/doc/doc_ch/PP-OCRv3_introduction.md). PP-OCRv3 is the third version of the PP-OCR series model released by [PaddleOCR](https://github.com/PaddlePaddle/PaddleOCR). This series of models has the following features:
The example is built on the [PP-OCRv3](https://github.com/PaddlePaddle/PaddleOCR/blob/dygraph/doc/doc_ch/PP-OCRv3_introduction.md) English recognition model released by [PaddleOCR](https://github.com/PaddlePaddle/PaddleOCR). Since the Arm(R) Cortex(R)-M55 CPU does not support the RNN operator, we removed the unsupported operators from the PP-OCRv3 text recognition model to obtain the current 2.7M English recognition model.
PP-OCRv3 is the third version of the PP-OCR series model. This series of models has the following features:
- PP-OCRv3: ultra-lightweight OCR system: detection (3.6M) + direction classifier (1.4M) + recognition (12M) = 17.0M
- Support for more than 80 multilingual recognition models, including English, Chinese, French, German, Arabic, Korean, Japanese, and more
- Support for vertical text recognition and long text recognition
The text recognition model in PP-OCRv3 supports more than 80 languages. During model development, since the Arm(R) Cortex(R)-M55 CPU does not support the RNN operator, we removed the unsupported operators from the PP-OCRv3 text recognition model to obtain the current model.
\ No newline at end of file
......@@ -14,6 +14,7 @@
<!--- KIND, either express or implied. See the License for the -->
<!--- specific language governing permissions and limitations -->
<!--- under the License. -->
[English](README.md) | 简体中文
Running the PaddleOCR text recognition model on an Arm(R) Cortex(R)-M55 CPU via TVM
===============================================================
......@@ -85,9 +86,9 @@ export PATH=/opt/arm/FVP_Corstone_SSE-300/models/Linux64_GCC-6.4:/opt/arm/cmake/
模型描述
-----------------
In this demo, the model we use is an English recognition model based on [PP-OCRv3](https://github.com/PaddlePaddle/PaddleOCR/blob/dygraph/doc/doc_ch/PP-OCRv3_introduction.md). PP-OCRv3 is the third version of the PP-OCR series released by [PaddleOCR](https://github.com/PaddlePaddle/PaddleOCR). The series has the following features:
In this demo, the model we use is an English recognition model based on [PP-OCRv3](https://github.com/PaddlePaddle/PaddleOCR/blob/dygraph/doc/doc_ch/PP-OCRv3_introduction.md). Since the Arm(R) Cortex(R)-M55 CPU does not support the RNN operator, we adapted the original PP-OCRv3 text recognition model; the final model size is 2.7M.
PP-OCRv3 is the third version of the PP-OCR series released by [PaddleOCR](https://github.com/PaddlePaddle/PaddleOCR), and the series has the following features:
- Ultra-lightweight OCR system: detection (3.6M) + direction classifier (1.4M) + recognition (12M) = 17.0M.
- Support for more than 80 multilingual recognition models, including English, Chinese, French, German, Arabic, Korean, Japanese, and more.
- Support for vertical text recognition and long text recognition.
The text recognition model in PP-OCRv3 supports more than 80 languages. During model development, since the Arm(R) Cortex(R)-M55 CPU does not support the RNN operator, we removed the unsupported operators from the PP-OCRv3 text recognition model to obtain the current model.
\ No newline at end of file
#!/bin/bash
# Copyright (c) 2022 Arm Limited and Contributors. All rights reserved.
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
set -e
set -u
set -o pipefail
# Show usage
function show_usage() {
cat <<EOF
Usage: Set up running environment by installing the required prerequisites.
-h, --help
Display this help message.
EOF
}
if [ "$#" -eq 1 ] && [ "$1" == "--help" -o "$1" == "-h" ]; then
show_usage
exit 0
elif [ "$#" -ge 1 ]; then
show_usage
exit 1
fi
echo -e "\e[36mStart setting up running environment\e[0m"
# Install CMSIS
echo -e "\e[36mStart installing CMSIS\e[0m"
CMSIS_PATH="/opt/arm/ethosu/cmsis"
mkdir -p "${CMSIS_PATH}"
CMSIS_SHA="977abe9849781a2e788b02282986480ff4e25ea6"
CMSIS_SHASUM="86c88d9341439fbb78664f11f3f25bc9fda3cd7de89359324019a4d87d169939eea85b7fdbfa6ad03aa428c6b515ef2f8cd52299ce1959a5444d4ac305f934cc"
CMSIS_URL="http://github.com/ARM-software/CMSIS_5/archive/${CMSIS_SHA}.tar.gz"
DOWNLOAD_PATH="/tmp/${CMSIS_SHA}.tar.gz"
wget ${CMSIS_URL} -O "${DOWNLOAD_PATH}"
echo "$CMSIS_SHASUM" ${DOWNLOAD_PATH} | sha512sum -c
tar -xf "${DOWNLOAD_PATH}" -C "${CMSIS_PATH}" --strip-components=1
touch "${CMSIS_PATH}"/"${CMSIS_SHA}".sha
echo -e "\e[36mCMSIS Installation SUCCESS\e[0m"
# Install Arm(R) Ethos(TM)-U NPU driver stack
echo -e "\e[36mStart installing Arm(R) Ethos(TM)-U NPU driver stack\e[0m"
git clone "https://review.mlplatform.org/ml/ethos-u/ethos-u-core-platform" /opt/arm/ethosu/core_platform
cd /opt/arm/ethosu/core_platform
git checkout tags/"21.11"
echo -e "\e[36mArm(R) Ethos(TM)-U Core Platform Installation SUCCESS\e[0m"
# Install Arm(R) GNU Toolchain
echo -e "\e[36mStart installing Arm(R) GNU Toolchain\e[0m"
mkdir -p /opt/arm/gcc-arm-none-eabi
export gcc_arm_url='https://developer.arm.com/-/media/Files/downloads/gnu-rm/10-2020q4/gcc-arm-none-eabi-10-2020-q4-major-x86_64-linux.tar.bz2?revision=ca0cbf9c-9de2-491c-ac48-898b5bbc0443&la=en&hash=68760A8AE66026BCF99F05AC017A6A50C6FD832A'
curl --retry 64 -sSL ${gcc_arm_url} | tar -C /opt/arm/gcc-arm-none-eabi --strip-components=1 -jx
export PATH=/opt/arm/gcc-arm-none-eabi/bin:$PATH
arm-none-eabi-gcc --version
arm-none-eabi-g++ --version
echo -e "\e[36mArm(R) Arm(R) GNU Toolchain Installation SUCCESS\e[0m"
# Install TVM from TLCPack
echo -e "\e[36mStart installing TVM\e[0m"
pip install tlcpack-nightly -f https://tlcpack.ai/wheels
echo -e "\e[36mTVM Installation SUCCESS\e[0m"
\ No newline at end of file
......@@ -24,18 +24,28 @@ function show_usage() {
cat <<EOF
Usage: run_demo.sh
-h, --help
Display this help message.
--cmsis_path CMSIS_PATH
Set path to CMSIS.
--ethosu_platform_path ETHOSU_PLATFORM_PATH
Set path to Arm(R) Ethos(TM)-U core platform.
--fvp_path FVP_PATH
Set path to FVP.
--cmake_path
Set path to cmake.
--enable_FVP
Set to 1 to run the application on local Fixed Virtual Platform (FVP) executables.
EOF
}
# Configure environment variables
FVP_enable=0
export PATH=/opt/arm/gcc-arm-none-eabi/bin:$PATH
# Install python libraries
echo -e "\e[36mInstall python libraries\e[0m"
sudo pip install -r ./requirements.txt
# Parse arguments
while (( $# )); do
case "$1" in
......@@ -91,6 +101,18 @@ while (( $# )); do
exit 1
fi
;;
--enable_FVP)
if [ $# -gt 1 ] && [ "$2" == "1" -o "$2" == "0" ];
then
FVP_enable="$2"
shift 2
else
echo 'ERROR: --enable_FVP requires an argument of 1 or 0' >&2
show_usage >&2
exit 1
fi
;;
-*|--*)
echo "Error: Unknown flag: $1" >&2
......@@ -100,17 +122,27 @@ while (( $# )); do
esac
done
# Choose running environment: cloud (default) or local
Platform="VHT_Corstone_SSE-300_Ethos-U55"
if [ $FVP_enable == "1" ]; then
Platform="FVP_Corstone_SSE-300_Ethos-U55"
echo -e "\e[36mRun application on local Fixed Virtual Platforms (FVPs)\e[0m"
else
if [ ! -d "/opt/arm/" ]; then
sudo ./configure_avh.sh
fi
fi
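# Example invocations (illustrative):
#   ./run_demo.sh                  # run on the cloud AVH/VHT target (default)
#   ./run_demo.sh --enable_FVP 1   # run on locally installed FVP executables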
# Directories
script_dir="$( cd "$( dirname "${BASH_SOURCE[0]}" )" &> /dev/null && pwd )"
# Make build directory
rm -rf build
make cleanall
mkdir -p build
cd build
# Get PaddlePaddle inference model
echo -e "\e[36mDownload PaddlePaddle inference model\e[0m"
wget https://paddleocr.bj.bcebos.com/tvm/ocr_en.tar
tar -xf ocr_en.tar
......@@ -144,9 +176,9 @@ cd ${script_dir}
echo ${script_dir}
make
# Run demo executable on the FVP
FVP_Corstone_SSE-300_Ethos-U55 -C cpu0.CFGDTCMSZ=15 \
# Run demo executable on the AVH
$Platform -C cpu0.CFGDTCMSZ=15 \
-C cpu0.CFGITCMSZ=15 -C mps3_board.uart0.out_file=\"-\" -C mps3_board.uart0.shutdown_tag=\"EXITTHESIM\" \
-C mps3_board.visualisation.disable-visualisation=1 -C mps3_board.telnetterminal0.start_telnet=0 \
-C mps3_board.telnetterminal1.start_telnet=0 -C mps3_board.telnetterminal2.start_telnet=0 -C mps3_board.telnetterminal5.start_telnet=0 \
./build/demo
\ No newline at end of file
./build/demo --stat
......@@ -34,12 +34,13 @@ cv::Mat CrnnResizeImg(cv::Mat img, float wh_ratio, int rec_image_height) {
resize_w = imgW;
else
resize_w = int(ceilf(imgH * ratio));
cv::Mat resize_img;
cv::resize(img, resize_img, cv::Size(resize_w, imgH), 0.f, 0.f,
cv::INTER_LINEAR);
cv::copyMakeBorder(resize_img, resize_img, 0, 0, 0,
int(imgW - resize_img.cols), cv::BORDER_CONSTANT,
{127, 127, 127});
return resize_img;
}
std::vector<std::string> ReadDict(std::string path) {
......
......@@ -474,7 +474,7 @@ void system(char **argv){
std::vector<double> rec_times;
RunRecModel(boxes, srcimg, rec_predictor, rec_text, rec_text_score,
charactor_dict, cls_predictor, use_direction_classify, &rec_times);
charactor_dict, cls_predictor, use_direction_classify, &rec_times, rec_image_height);
//// visualization
auto img_vis = Visualization(srcimg, boxes);
......
......@@ -5,9 +5,10 @@ PaddleOCR will **continuously add** support for cutting-edge OCR algorithms and models; the supported
- [Text detection algorithms](./algorithm_overview.md#11-%E6%96%87%E6%9C%AC%E6%A3%80%E6%B5%8B%E7%AE%97%E6%B3%95)
- [Text recognition algorithms](./algorithm_overview.md#12-%E6%96%87%E6%9C%AC%E8%AF%86%E5%88%AB%E7%AE%97%E6%B3%95)
- [End-to-end algorithms](./algorithm_overview.md#2-%E6%96%87%E6%9C%AC%E8%AF%86%E5%88%AB%E7%AE%97%E6%B3%95)
- [Table recognition algorithms](./algorithm_overview.md#3-%E8%A1%A8%E6%A0%BC%E8%AF%86%E5%88%AB%E7%AE%97%E6%B3%95)
**Developers are welcome to collaborate and contribute more algorithms; merged contributions are rewarded 🎁! See the [community competition](https://github.com/PaddlePaddle/PaddleOCR/issues/4982) for details.**
To add a new algorithm, refer to the following tutorial:
- [Add a new algorithm within the PaddleOCR framework](./add_new_algorithm.md)
\ No newline at end of file
- [Add a new algorithm within the PaddleOCR framework](./add_new_algorithm.md)
# DB
# DB and DB++
- [1. Introduction](#1)
- [2. Environment](#2)
......@@ -21,12 +21,24 @@
> Liao, Minghui and Wan, Zhaoyi and Yao, Cong and Chen, Kai and Bai, Xiang
> AAAI, 2020
> [Real-Time Scene Text Detection with Differentiable Binarization and Adaptive Scale Fusion](https://arxiv.org/abs/2202.10304)
> Liao, Minghui and Zou, Zhisheng and Wan, Zhaoyi and Yao, Cong and Bai, Xiang
> TPAMI, 2022
On the ICDAR2015 public text detection dataset, the reproduced results are as follows:
|Model|Backbone|Config|Precision|Recall|Hmean|Download link|
| --- | --- | --- | --- | --- | --- | --- |
|DB|ResNet50_vd|[configs/det/det_r50_vd_db.yml](../../configs/det/det_r50_vd_db.yml)|86.41%|78.72%|82.38%|[trained model](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/det_r50_vd_db_v2.0_train.tar)|
|DB|MobileNetV3|[configs/det/det_mv3_db.yml](../../configs/det/det_mv3_db.yml)|77.29%|73.08%|75.12%|[trained model](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/det_mv3_db_v2.0_train.tar)|
|DB++|ResNet50|[configs/det/det_r50_db++_icdar15.yml](../../configs/det/det_r50_db++_icdar15.yml)|90.89%|82.66%|86.58%|[pretrained model (synthetic data)](https://paddleocr.bj.bcebos.com/dygraph_v2.1/en_det/ResNet50_dcn_asf_synthtext_pretrained.pdparams) / [trained model](https://paddleocr.bj.bcebos.com/dygraph_v2.1/en_det/det_r50_db%2B%2B_icdar15_train.tar)|
On the TD_TR public text detection dataset, the reproduced results are as follows:
|Model|Backbone|Config|Precision|Recall|Hmean|Download link|
| --- | --- | --- | --- | --- | --- | --- |
|DB++|ResNet50|[configs/det/det_r50_db++_td_tr.yml](../../configs/det/det_r50_db++_td_tr.yml)|92.92%|86.48%|89.58%|[pretrained model (synthetic data)](https://paddleocr.bj.bcebos.com/dygraph_v2.1/en_det/ResNet50_dcn_asf_synthtext_pretrained.pdparams) / [trained model](https://paddleocr.bj.bcebos.com/dygraph_v2.1/en_det/det_r50_db%2B%2B_td_tr_train.tar)|
<a name="2"></a>
......@@ -54,7 +66,7 @@ python3 tools/export_model.py -c configs/det/det_r50_vd_db.yml -o Global.pretrai
For DB text detection model inference, run the following command:
```shell
python3 tools/infer/predict_det.py --image_dir="./doc/imgs_en/img_10.jpg" --det_model_dir="./inference/det_db/"
python3 tools/infer/predict_det.py --image_dir="./doc/imgs_en/img_10.jpg" --det_model_dir="./inference/det_db/" --det_algorithm="DB"
```
Visualized text detection results are saved to the `./inference_results` folder by default, with result file names prefixed with 'det_res'. An example result:
......@@ -96,4 +108,12 @@ The DB model also supports the following inference and deployment methods:
pages={11474--11481},
year={2020}
}
```
\ No newline at end of file
@article{liao2022real,
title={Real-Time Scene Text Detection with Differentiable Binarization and Adaptive Scale Fusion},
author={Liao, Minghui and Zou, Zhisheng and Wan, Zhaoyi and Yao, Cong and Bai, Xiang},
journal={IEEE Transactions on Pattern Analysis and Machine Intelligence},
year={2022},
publisher={IEEE}
}
```
# FCENet
- [1. Introduction](#1)
- [2. Environment](#2)
- [3. Model Training / Evaluation / Prediction](#3)
- [3.1 Training](#3-1)
- [3.2 Evaluation](#3-2)
- [3.3 Prediction](#3-3)
- [4. Inference and Deployment](#4)
- [4.1 Python Inference](#4-1)
- [4.2 C++ Inference](#4-2)
- [4.3 Serving](#4-3)
- [4.4 More](#4-4)
- [5. FAQ](#5)
- [1. Introduction](#1-算法简介)
- [2. Environment](#2-环境配置)
- [3. Model Training / Evaluation / Prediction](#3-模型训练评估预测)
- [4. Inference and Deployment](#4-推理部署)
- [4.1 Python Inference](#41-python推理)
- [4.2 C++ Inference](#42-c推理)
- [4.3 Serving](#43-serving服务化部署)
- [4.4 More](#44-更多推理部署)
- [5. FAQ](#5-faq)
- [Citation](#引用)
<a name="1"></a>
## 1. Introduction
......
# OCR Algorithms
- [1. Two-stage Algorithms](#1-两阶段算法)
- [1.1 Text Detection Algorithms](#11-文本检测算法)
- [1.2 Text Recognition Algorithms](#12-文本识别算法)
- [2. End-to-end Algorithms](#2-端到端算法)
- [1. Two-stage Algorithms](#1)
- [1.1 Text Detection Algorithms](#11)
- [1.2 Text Recognition Algorithms](#12)
- [2. End-to-end Algorithms](#2)
- [3. Table Recognition Algorithms](#3)
This page lists the OCR algorithms supported by PaddleOCR, together with each algorithm's models and metrics on **English public datasets**. It is mainly intended for algorithm introduction and performance comparison; for more models on other datasets, including Chinese, please refer to the [PP-OCR v2.0 series model list](./models_list.md)
......@@ -86,8 +87,9 @@
|SAR|Resnet31| 87.20% | rec_r31_sar | [trained model](https://paddleocr.bj.bcebos.com/dygraph_v2.1/rec/rec_r31_sar_train.tar) |
|SEED|Aster_Resnet| 85.35% | rec_resnet_stn_bilstm_att | [trained model](https://paddleocr.bj.bcebos.com/dygraph_v2.1/rec/rec_resnet_stn_bilstm_att.tar) |
|SVTR|SVTR-Tiny| 89.25% | rec_svtr_tiny_none_ctc_en | [trained model](https://paddleocr.bj.bcebos.com/PP-OCRv3/chinese/rec_svtr_tiny_none_ctc_en_train.tar) |
|ViTSTR|ViTSTR| 79.82% | rec_vitstr_none_ce_en | [trained model](https://paddleocr.bj.bcebos.com/rec_vitstr_none_ce_train.tar) |
|ABINet|Resnet45| 90.75% | rec_r45_abinet_en | [trained model](https://paddleocr.bj.bcebos.com/rec_r45_abinet_train.tar) |
|ViTSTR|ViTSTR| 79.82% | rec_vitstr_none_ce | [trained model](https://paddleocr.bj.bcebos.com/rec_vitstr_none_ce_train.tar) |
|ABINet|Resnet45| 90.75% | rec_r45_abinet | [trained model](https://paddleocr.bj.bcebos.com/rec_r45_abinet_train.tar) |
<a name="2"></a>
......@@ -95,3 +97,16 @@
Supported end-to-end OCR algorithms (click the link for the tutorial):
- [x] [PGNet](./algorithm_e2e_pgnet.md)
<a name="3"></a>
## 3. Table Recognition Algorithms
Supported table recognition algorithms (click the link for the tutorial):
- [x] [TableMaster](./algorithm_table_master.md)
On the PubTabNet table recognition public dataset, the results are as follows:
|Model|Backbone|Config|Acc|Download link|
|---|---|---|---|---|
|TableMaster|TableResNetExtra|[configs/table/table_master.yml](../../configs/table/table_master.yml)|77.47%|[trained model](https://paddleocr.bj.bcebos.com/ppstructure/models/tablemaster/table_structure_tablemaster_train.tar) / [inference model](https://paddleocr.bj.bcebos.com/ppstructure/models/tablemaster/table_structure_tablemaster_infer.tar)|
# Table Recognition Algorithm - TableMASTER
- [1. Introduction](#1-算法简介)
- [2. Environment](#2-环境配置)
- [3. Model Training / Evaluation / Prediction](#3-模型训练评估预测)
- [4. Inference and Deployment](#4-推理部署)
- [4.1 Python Inference](#41-python推理)
- [4.2 C++ Inference](#42-c推理部署)
- [4.3 Serving](#43-serving服务化部署)
- [4.4 More](#44-更多推理部署)
- [5. FAQ](#5-faq)
- [Citation](#引用)
<a name="1"></a>
## 1. Introduction
Paper:
> [TableMaster: PINGAN-VCGROUP’S SOLUTION FOR ICDAR 2021 COMPETITION ON SCIENTIFIC LITERATURE PARSING TASK B: TABLE RECOGNITION TO HTML](https://arxiv.org/pdf/2105.01848.pdf)
> Ye, Jiaquan and Qi, Xianbiao and He, Yelin and Chen, Yihao and Gu, Dengyi and Gao, Peng and Xiao, Rong
> 2021
On the PubTabNet table recognition public dataset, the reproduced results are as follows:
|Model|Backbone|Config|Acc|Download link|
| --- | --- | --- | --- | --- |
|TableMaster|TableResNetExtra|[configs/table/table_master.yml](../../configs/table/table_master.yml)|77.47%|[trained model](https://paddleocr.bj.bcebos.com/ppstructure/models/tablemaster/table_structure_tablemaster_train.tar) / [inference model](https://paddleocr.bj.bcebos.com/ppstructure/models/tablemaster/table_structure_tablemaster_infer.tar)|
<a name="2"></a>
## 2. Environment
Please refer to ["Environment Preparation"](./environment.md) to configure the PaddleOCR environment, and to ["Project Clone"](./clone.md) to clone the project code.
<a name="3"></a>
## 3. Model Training / Evaluation / Prediction
The TableMaster model above is trained on the PubTabNet table recognition public dataset; for the dataset download, see [table_datasets](./dataset/table_datasets.md).
After the data is downloaded, please refer to the [text recognition tutorial](./recognition.md) for training. PaddleOCR is modularized, so training a different model only requires **switching the configuration file**; a typical training command is sketched below.
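A minimal illustrative launch (the output path is an example; see the training tutorial for distributed training and resume options):
```shell
python3 tools/train.py -c configs/table/table_master.yml -o Global.save_model_dir=./output/table_master/
```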
<a name="4"></a>
## 4. Inference and Deployment
<a name="4-1"></a>
### 4.1 Python Inference
First, convert the best model obtained from training into an inference model. Taking the model trained on the PubTabNet dataset with the TableResNetExtra backbone as an example ([model download link](https://paddleocr.bj.bcebos.com/contribution/table_master.tar)), run the following conversion command:
```shell
# Note: set the pretrained_model path to a local path.
python3 tools/export_model.py -c configs/table/table_master.yml -o Global.pretrained_model=output/table_master/best_accuracy Global.save_inference_dir=./inference/table_master
```
**Note:**
- If you trained the model on your own dataset and adjusted the dictionary file, make sure that `character_dict_path` in the configuration file points to the correct dictionary file.
After a successful conversion, three files are located in the directory:
```
./inference/table_master/
├── inference.pdiparams # parameter file of the recognition inference model
├── inference.pdiparams.info # parameter info of the recognition inference model, can be ignored
└── inference.pdmodel # program file of the recognition inference model
```
Run the following command for model inference:
```shell
cd ppstructure/
python3.7 table/predict_structure.py --table_model_dir=../output/table_master/table_structure_tablemaster_infer/ --table_algorithm=TableMaster --table_char_dict_path=../ppocr/utils/dict/table_master_structure_dict.txt --table_max_len=480 --image_dir=docs/table/table.jpg
# To predict all images in a folder, set image_dir to the folder path, e.g. --image_dir='docs/table'.
```
After the command is executed, the prediction results for the image above (structure information and the coordinates of each table cell) are printed to the screen, and a visualization of the cell coordinates is saved. An example:
The result is as follows:
```shell
[2022/06/16 13:06:54] ppocr INFO: result: ['<html>', '<body>', '<table>', '<thead>', '<tr>', '<td></td>', '<td></td>', '<td></td>', '<td></td>', '<td></td>', '</tr>', '</thead>', '<tbody>', '<tr>', '<td></td>', '<td></td>', '<td></td>', '<td></td>', '<td></td>', '</tr>', '<tr>', '<td></td>', '<td></td>', '<td></td>', '<td></td>', '<td></td>', '</tr>', '<tr>', '<td></td>', '<td></td>', '<td></td>', '<td></td>', '<td></td>', '</tr>', '<tr>', '<td></td>', '<td></td>', '<td></td>', '<td></td>', '<td></td>', '</tr>', '<tr>', '<td></td>', '<td></td>', '<td></td>', '<td></td>', '<td></td>', '</tr>', '<tr>', '<td></td>', '<td></td>', '<td></td>', '<td></td>', '<td></td>', '</tr>', '<tr>', '<td></td>', '<td></td>', '<td></td>', '<td></td>', '<td></td>', '</tr>', '<tr>', '<td></td>', '<td></td>', '<td></td>', '<td></td>', '<td></td>', '</tr>', '<tr>', '<td></td>', '<td></td>', '<td></td>', '<td></td>', '<td></td>', '</tr>', '<tr>', '<td></td>', '<td></td>', '<td></td>', '<td></td>', '<td></td>', '</tr>', '<tr>', '<td></td>', '<td></td>', '<td></td>', '<td></td>', '<td></td>', '</tr>', '<tr>', '<td></td>', '<td></td>', '<td></td>', '<td></td>', '<td></td>', '</tr>', '<tr>', '<td></td>', '<td></td>', '<td></td>', '<td></td>', '<td></td>', '</tr>', '<tr>', '<td></td>', '<td></td>', '<td></td>', '<td></td>', '<td></td>', '</tr>', '<tr>', '<td></td>', '<td></td>', '<td></td>', '<td></td>', '<td></td>', '</tr>', '</tbody>', '</table>', '</body>', '</html>'], [[72.17591094970703, 10.759100914001465, 60.29658508300781, 16.6805362701416], [161.85562133789062, 10.884308815002441, 14.9495210647583, 16.727018356323242], [277.79876708984375, 29.54340362548828, 31.490320205688477, 18.143272399902344],
...
[336.11724853515625, 280.3601989746094, 39.456939697265625, 18.121286392211914]]
[2022/06/16 13:06:54] ppocr INFO: save vis result to ./output/table.jpg
[2022/06/16 13:06:54] ppocr INFO: Predict time of docs/table/table.jpg: 17.36806297302246
```
**Note**
- TableMaster is relatively slow at inference time; using a GPU is recommended.
<a name="4-2"></a>
### 4.2 C++ Inference
Not supported yet, since the C++ pre- and post-processing do not yet cover TableMaster.
<a name="4-3"></a>
### 4.3 Serving
Not supported yet
<a name="4-4"></a>
### 4.4 More
Not supported yet
<a name="5"></a>
## 5. FAQ
## Citation
```bibtex
@article{ye2021pingan,
title={PingAn-VCGroup's Solution for ICDAR 2021 Competition on Scientific Literature Parsing Task B: Table Recognition to HTML},
author={Ye, Jiaquan and Qi, Xianbiao and He, Yelin and Chen, Yihao and Gu, Dengyi and Gao, Peng and Xiao, Rong},
journal={arXiv preprint arXiv:2105.01848},
year={2021}
}
```
......@@ -34,6 +34,7 @@ Before json.dumps encoding, the image annotation is a list containing multiple dicts; each dict
| ICDAR 2015 |https://rrc.cvc.uab.es/?ch=4&com=downloads| [train](https://paddleocr.bj.bcebos.com/dataset/train_icdar2015_label.txt) / [test](https://paddleocr.bj.bcebos.com/dataset/test_icdar2015_label.txt) |
| ctw1500 |https://paddleocr.bj.bcebos.com/dataset/ctw1500.zip| included in the image download link |
| total text |https://paddleocr.bj.bcebos.com/dataset/total_text.tar| included in the image download link |
| td tr |https://paddleocr.bj.bcebos.com/dataset/TD_TR.tar| included in the image download link |
#### 1.2.1 ICDAR 2015
The ICDAR 2015 dataset contains 1000 training images and 500 test images. It can be downloaded from the link in the table above; registration is required on first download.
......
......@@ -7,7 +7,8 @@
- [1. Text Detection Model Inference](#1-文本检测模型推理)
- [2. Text Recognition Model Inference](#2-文本识别模型推理)
- [2.1 Lightweight Chinese Recognition Model Inference](#21-超轻量中文识别模型推理)
- [2.2 Multilingual Model Inference](#22-多语言模型的推理)
- [2.2 English Recognition Model Inference](#22-英文识别模型推理)
- [2.3 Multilingual Model Inference](#23-多语言模型的推理)
- [3. Angle Classification Model Inference](#3-方向分类模型推理)
- [4. Concatenated Inference of Text Detection, Angle Classification and Text Recognition](#4-文本检测方向分类和文字识别串联推理)
......@@ -78,9 +79,29 @@ python3 tools/infer/predict_rec.py --image_dir="./doc/imgs_words/ch/word_4.jpg"
Predicts of ./doc/imgs_words/ch/word_4.jpg:('实力活力', 0.9956803321838379)
```
<a name="英文识别模型推理"></a>
### 2.2 English Recognition Model Inference
For English recognition model inference, run the following command; note that the dictionary path must be set accordingly:
```
# Download the English (and digits) recognition model:
wget https://paddleocr.bj.bcebos.com/PP-OCRv3/english/en_PP-OCRv3_rec_infer.tar
tar xf en_PP-OCRv3_rec_infer.tar
python3 tools/infer/predict_rec.py --image_dir="./doc/imgs_words/en/word_1.png" --rec_model_dir="./en_PP-OCRv3_rec_infer/" --rec_char_dict_path="ppocr/utils/en_dict.txt"
```
![](../imgs_words/en/word_1.png)
After running the command, the prediction result for the image above is:
```
Predicts of ./doc/imgs_words/en/word_1.png: ('JOINT', 0.998160719871521)
```
<a name="多语言模型的推理"></a>
### 2.2 Multilingual Model Inference
### 2.3 Multilingual Model Inference
If you need to predict with a model for another language, you can find the corresponding inference model at [this link](./models_list.md#%E5%A4%9A%E8%AF%AD%E8%A8%80%E8%AF%86%E5%88%AB%E6%A8%A1%E5%9E%8B). When predicting with an inference model, the dictionary path must be specified via `--rec_char_dict_path`; to obtain correct visualization results, the font path must also be specified via `--vis_font_path`. Default fonts for minor languages are provided under `doc/fonts/`. For example, for Korean recognition:
```
......
......@@ -6,5 +6,6 @@ PaddleOCR will add cutting-edge OCR algorithms and models continuously. Check ou
- [text detection algorithms](./algorithm_overview_en.md#11)
- [text recognition algorithms](./algorithm_overview_en.md#12)
- [end-to-end algorithms](./algorithm_overview_en.md#2)
- [table recognition algorithms](./algorithm_overview_en.md#3)
Developers are welcome to contribute more algorithms! Please refer to [add new algorithm](./add_new_algorithm_en.md) guideline.
\ No newline at end of file
Developers are welcome to contribute more algorithms! Please refer to [add new algorithm](./add_new_algorithm_en.md) guideline.
# OCR Algorithms
- [1. Two-stage Algorithms](#1)
* [1.1 Text Detection Algorithms](#11)
* [1.2 Text Recognition Algorithms](#12)
- [1.1 Text Detection Algorithms](#11)
- [1.2 Text Recognition Algorithms](#12)
- [2. End-to-end Algorithms](#2)
- [3. Table Recognition Algorithms](#3)
This tutorial lists the OCR algorithms supported by PaddleOCR, as well as the models and metrics of each algorithm on **English public datasets**. It is mainly used for algorithm introduction and algorithm performance comparison. For more models on other datasets including Chinese, please refer to [PP-OCR v2.0 models list](./models_list_en.md).
......@@ -85,8 +86,9 @@ Refer to [DTRB](https://arxiv.org/abs/1904.01906), the training and evaluation r
|SAR|Resnet31| 87.20% | rec_r31_sar | [trained model](https://paddleocr.bj.bcebos.com/dygraph_v2.1/rec/rec_r31_sar_train.tar) |
|SEED|Aster_Resnet| 85.35% | rec_resnet_stn_bilstm_att | [trained model](https://paddleocr.bj.bcebos.com/dygraph_v2.1/rec/rec_resnet_stn_bilstm_att.tar) |
|SVTR|SVTR-Tiny| 89.25% | rec_svtr_tiny_none_ctc_en | [trained model](https://paddleocr.bj.bcebos.com/PP-OCRv3/chinese/rec_svtr_tiny_none_ctc_en_train.tar) |
|ViTSTR|ViTSTR| 79.82% | rec_vitstr_none_ce_en | [trained model](https://paddleocr.bj.bcebos.com/rec_vitstr_none_none_train.tar) |
|ABINet|Resnet45| 90.75% | rec_r45_abinet_en | [trained model](https://paddleocr.bj.bcebos.com/rec_r45_abinet_train.tar) |
|ViTSTR|ViTSTR| 79.82% | rec_vitstr_none_ce | [trained model](https://paddleocr.bj.bcebos.com/rec_vitstr_none_ce_train.tar) |
|ABINet|Resnet45| 90.75% | rec_r45_abinet | [trained model](https://paddleocr.bj.bcebos.com/rec_r45_abinet_train.tar) |
<a name="2"></a>
......@@ -94,3 +96,15 @@ Refer to [DTRB](https://arxiv.org/abs/1904.01906), the training and evaluation r
Supported end-to-end algorithms (Click the link to get the tutorial):
- [x] [PGNet](./algorithm_e2e_pgnet_en.md)
<a name="3"></a>
## 3. Table Recognition Algorithms
Supported table recognition algorithms (Click the link to get the tutorial):
- [x] [TableMaster](./algorithm_table_master_en.md)
On the PubTabNet dataset, the algorithm result is as follows:
|Model|Backbone|Config|Acc|Download link|
|---|---|---|---|---|
|TableMaster|TableResNetExtra|[configs/table/table_master.yml](../../configs/table/table_master.yml)|77.47%|[trained model](https://paddleocr.bj.bcebos.com/ppstructure/models/tablemaster/table_structure_tablemaster_train.tar) / [inference model](https://paddleocr.bj.bcebos.com/ppstructure/models/tablemaster/table_structure_tablemaster_infer.tar)|
# Table Recognition Algorithm-TableMASTER
- [1. Introduction](#1-introduction)
- [2. Environment](#2-environment)
- [3. Model Training / Evaluation / Prediction](#3-model-training--evaluation--prediction)
- [4. Inference and Deployment](#4-inference-and-deployment)
- [4.1 Python Inference](#41-python-inference)
- [4.2 C++ Inference](#42-c-inference)
- [4.3 Serving](#43-serving)
- [4.4 More](#44-more)
- [5. FAQ](#5-faq)
- [Citation](#citation)
<a name="1"></a>
## 1. Introduction
Paper:
> [TableMaster: PINGAN-VCGROUP’S SOLUTION FOR ICDAR 2021 COMPETITION ON SCIENTIFIC LITERATURE PARSING TASK B: TABLE RECOGNITION TO HTML](https://arxiv.org/pdf/2105.01848.pdf)
> Ye, Jiaquan and Qi, Xianbiao and He, Yelin and Chen, Yihao and Gu, Dengyi and Gao, Peng and Xiao, Rong
> 2021
On the PubTabNet table recognition public dataset, the reproduced accuracy is as follows:
|Model|Backbone|Config|Acc|Download link|
| --- | --- | --- | --- | --- |
|TableMaster|TableResNetExtra|[configs/table/table_master.yml](../../configs/table/table_master.yml)|77.47%|[trained model](https://paddleocr.bj.bcebos.com/ppstructure/models/tablemaster/table_structure_tablemaster_train.tar)/[inference model](https://paddleocr.bj.bcebos.com/ppstructure/models/tablemaster/table_structure_tablemaster_infer.tar)|
<a name="2"></a>
## 2. Environment
Please refer to ["Environment Preparation"](./environment_en.md) to configure the PaddleOCR environment, and refer to ["Project Clone"](./clone_en.md) to clone the project code.
<a name="3"></a>
## 3. Model Training / Evaluation / Prediction
The above TableMaster model is trained using the PubTabNet table recognition public dataset. For the download of the dataset, please refer to [table_datasets](./dataset/table_datasets_en.md).
After the data download is complete, please refer to the [Text Recognition Training Tutorial](./recognition_en.md) for training. PaddleOCR has modularized the code structure, so you only need to **replace the configuration file** to train a different model; a typical training command is sketched below.
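A minimal illustrative launch (the output path is an example; see the training tutorial for distributed training and resume options):
```shell
python3 tools/train.py -c configs/table/table_master.yml -o Global.save_model_dir=./output/table_master/
```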
<a name="4"></a>
## 4. Inference and Deployment
<a name="4-1"></a>
### 4.1 Python Inference
First, convert the model saved during TableMaster training into an inference model. Taking the model based on the TableResNetExtra backbone and trained on the PubTabNet dataset as an example ([model download link](https://paddleocr.bj.bcebos.com/contribution/table_master.tar)), you can use the following command to convert it:
```shell
python3 tools/export_model.py -c configs/table/table_master.yml -o Global.pretrained_model=output/table_master/best_accuracy Global.save_inference_dir=./inference/table_master
```
**Note:**
- If you trained the model on your own dataset and adjusted the dictionary file, make sure that `character_dict_path` in the configuration file points to the correct dictionary file.
Execute the following command for model inference:
```shell
cd ppstructure/
# When predicting all images in a folder, you can modify image_dir to a folder, such as --image_dir='docs/table'.
python3.7 table/predict_structure.py --table_model_dir=../output/table_master/table_structure_tablemaster_infer/ --table_algorithm=TableMaster --table_char_dict_path=../ppocr/utils/dict/table_master_structure_dict.txt --table_max_len=480 --image_dir=docs/table/table.jpg
```
After executing the command, the prediction results of the above image (structural information and the coordinates of each cell in the table) are printed to the screen, and the visualization of the cell coordinates is also saved. An example is as follows:
The result is as follows:
```shell
[2022/06/16 13:06:54] ppocr INFO: result: ['<html>', '<body>', '<table>', '<thead>', '<tr>', '<td></td>', '<td></td>', '<td></td>', '<td></td>', '<td></td>', '</tr>', '</thead>', '<tbody>', '<tr>', '<td></td>', '<td></td>', '<td></td>', '<td></td>', '<td></td>', '</tr>', '<tr>', '<td></td>', '<td></td>', '<td></td>', '<td></td>', '<td></td>', '</tr>', '<tr>', '<td></td>', '<td></td>', '<td></td>', '<td></td>', '<td></td>', '</tr>', '<tr>', '<td></td>', '<td></td>', '<td></td>', '<td></td>', '<td></td>', '</tr>', '<tr>', '<td></td>', '<td></td>', '<td></td>', '<td></td>', '<td></td>', '</tr>', '<tr>', '<td></td>', '<td></td>', '<td></td>', '<td></td>', '<td></td>', '</tr>', '<tr>', '<td></td>', '<td></td>', '<td></td>', '<td></td>', '<td></td>', '</tr>', '<tr>', '<td></td>', '<td></td>', '<td></td>', '<td></td>', '<td></td>', '</tr>', '<tr>', '<td></td>', '<td></td>', '<td></td>', '<td></td>', '<td></td>', '</tr>', '<tr>', '<td></td>', '<td></td>', '<td></td>', '<td></td>', '<td></td>', '</tr>', '<tr>', '<td></td>', '<td></td>', '<td></td>', '<td></td>', '<td></td>', '</tr>', '<tr>', '<td></td>', '<td></td>', '<td></td>', '<td></td>', '<td></td>', '</tr>', '<tr>', '<td></td>', '<td></td>', '<td></td>', '<td></td>', '<td></td>', '</tr>', '<tr>', '<td></td>', '<td></td>', '<td></td>', '<td></td>', '<td></td>', '</tr>', '<tr>', '<td></td>', '<td></td>', '<td></td>', '<td></td>', '<td></td>', '</tr>', '</tbody>', '</table>', '</body>', '</html>'], [[72.17591094970703, 10.759100914001465, 60.29658508300781, 16.6805362701416], [161.85562133789062, 10.884308815002441, 14.9495210647583, 16.727018356323242], [277.79876708984375, 29.54340362548828, 31.490320205688477, 18.143272399902344],
...
[336.11724853515625, 280.3601989746094, 39.456939697265625, 18.121286392211914]]
[2022/06/16 13:06:54] ppocr INFO: save vis result to ./output/table.jpg
[2022/06/16 13:06:54] ppocr INFO: Predict time of docs/table/table.jpg: 17.36806297302246
```
**Note**:
- TableMaster is relatively slow during inference; using a GPU is recommended.
<a name="4-2"></a>
### 4.2 C++ Inference
Since the pre- and post-processing are not implemented in C++, TableMaster does not support C++ inference yet.
<a name="4-3"></a>
### 4.3 Serving
Not supported
<a name="4-4"></a>
### 4.4 More
Not supported
<a name="5"></a>
## 5. FAQ
## Citation
```bibtex
@article{ye2021pingan,
title={PingAn-VCGroup's Solution for ICDAR 2021 Competition on Scientific Literature Parsing Task B: Table Recognition to HTML},
author={Ye, Jiaquan and Qi, Xianbiao and He, Yelin and Chen, Yihao and Gu, Dengyi and Gao, Peng and Xiao, Rong},
journal={arXiv preprint arXiv:2105.01848},
year={2021}
}
```
......@@ -8,7 +8,8 @@ This article introduces the use of the Python inference engine for the PP-OCR mo
- [Text Detection Model Inference](#text-detection-model-inference)
- [Text Recognition Model Inference](#text-recognition-model-inference)
- [1. Lightweight Chinese Recognition Model Inference](#1-lightweight-chinese-recognition-model-inference)
- [2. Multilingual Model Inference](#2-multilingual-model-inference)
- [2. English Recognition Model Inference](#2-english-recognition-model-inference)
- [3. Multilingual Model Inference](#3-multilingual-model-inference)
- [Angle Classification Model Inference](#angle-classification-model-inference)
- [Text Detection Angle Classification and Recognition Inference Concatenation](#text-detection-angle-classification-and-recognition-inference-concatenation)
......@@ -76,10 +77,31 @@ After executing the command, the prediction results (recognized text and score)
```bash
Predicts of ./doc/imgs_words_en/word_10.png:('PAIN', 0.988671)
```
<a name="2-english-recognition-model-inference"></a>
### 2. English Recognition Model Inference
<a name="MULTILINGUAL_MODEL_INFERENCE"></a>
For English recognition model inference, you can execute the following commands; you need to specify the dictionary path via `--rec_char_dict_path`:
### 2. Multilingual Model Inference
```
# Download the English recognition model:
wget https://paddleocr.bj.bcebos.com/PP-OCRv3/english/en_PP-OCRv3_rec_infer.tar
tar xf en_PP-OCRv3_rec_infer.tar
python3 tools/infer/predict_rec.py --image_dir="./doc/imgs_words/en/word_1.png" --rec_model_dir="./en_PP-OCRv3_rec_infer/" --rec_char_dict_path="ppocr/utils/en_dict.txt"
```
![](../imgs_words/en/word_1.png)
After executing the command, the prediction result of the above figure is:
```
Predicts of ./doc/imgs_words/en/word_1.png: ('JOINT', 0.998160719871521)
```
<a name="3-multilingual-model-inference"></a>
### 3. Multilingual Model Inference
If you need to predict with [other language models](./models_list_en.md#Multilingual), you need to specify the dictionary path via `--rec_char_dict_path` when running inference. To get correct visualization results,
you also need to specify the visual font path via `--vis_font_path`; fonts for minor languages are provided by default under the `doc/fonts` path, e.g. for Korean recognition:
......
......@@ -23,9 +23,10 @@ from .random_crop_data import EastRandomCropData, RandomCropImgMask
from .make_pse_gt import MakePseGt
from .rec_img_aug import RecAug, RecConAug, RecResizeImg, ClsResizeImg, \
SRNRecResizeImg, GrayRecResizeImg, SARRecResizeImg, PRENResizeImg, \
ABINetRecResizeImg, SVTRRecResizeImg, ABINetRecAug
from .rec_img_aug import BaseDataAugmentation, RecAug, RecConAug, RecResizeImg, ClsResizeImg, \
SRNRecResizeImg, GrayRecResizeImg, SARRecResizeImg, PRENResizeImg, \
ABINetRecResizeImg, SVTRRecResizeImg, ABINetRecAug
from .ssl_img_aug import SSLRotateResize
from .randaugment import RandAugment
from .copy_paste import CopyPaste
......@@ -36,7 +37,7 @@ from .label_ops import *
from .east_process import *
from .sast_process import *
from .pg_process import *
from .gen_table_mask import *
from .table_ops import *
from .vqa import *
......
......@@ -259,15 +259,26 @@ class E2ELabelEncodeTrain(object):
class KieLabelEncode(object):
def __init__(self, character_dict_path, norm=10, directed=False, **kwargs):
def __init__(self,
character_dict_path,
class_path,
norm=10,
directed=False,
**kwargs):
super(KieLabelEncode, self).__init__()
self.dict = dict({'': 0})
self.label2classid_map = dict()
with open(character_dict_path, 'r', encoding='utf-8') as fr:
idx = 1
for line in fr:
char = line.strip()
self.dict[char] = idx
idx += 1
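# class_path lists one class label per line; map each label to its line index (class id).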
with open(class_path, "r") as fin:
lines = fin.readlines()
for idx, line in enumerate(lines):
line = line.strip("\n")
self.label2classid_map[line] = idx
self.norm = norm
self.directed = directed
......@@ -408,7 +419,7 @@ class KieLabelEncode(object):
text_ind = [self.dict[c] for c in text if c in self.dict]
text_inds.append(text_ind)
if 'label' in ann.keys():
labels.append(ann['label'])
labels.append(self.label2classid_map[ann['label']])
elif 'key_cls' in ann.keys():
labels.append(ann['key_cls'])
else:
......@@ -551,171 +562,210 @@ class SRNLabelEncode(BaseRecLabelEncode):
return idx
class TableLabelEncode(object):
class TableLabelEncode(AttnLabelEncode):
""" Convert between text-label and text-index """
def __init__(self,
max_text_length,
max_elem_length,
max_cell_num,
character_dict_path,
span_weight=1.0,
replace_empty_cell_token=False,
merge_no_span_structure=False,
learn_empty_box=False,
point_num=2,
**kwargs):
self.max_text_length = max_text_length
self.max_elem_length = max_elem_length
self.max_cell_num = max_cell_num
list_character, list_elem = self.load_char_elem_dict(
character_dict_path)
list_character = self.add_special_char(list_character)
list_elem = self.add_special_char(list_elem)
self.dict_character = {}
for i, char in enumerate(list_character):
self.dict_character[char] = i
self.dict_elem = {}
for i, elem in enumerate(list_elem):
self.dict_elem[elem] = i
self.span_weight = span_weight
def load_char_elem_dict(self, character_dict_path):
list_character = []
list_elem = []
self.max_text_len = max_text_length
self.lower = False
self.learn_empty_box = learn_empty_box
self.merge_no_span_structure = merge_no_span_structure
self.replace_empty_cell_token = replace_empty_cell_token
dict_character = []
with open(character_dict_path, "rb") as fin:
lines = fin.readlines()
substr = lines[0].decode('utf-8').strip("\r\n").split("\t")
character_num = int(substr[0])
elem_num = int(substr[1])
for cno in range(1, 1 + character_num):
character = lines[cno].decode('utf-8').strip("\r\n")
list_character.append(character)
for eno in range(1 + character_num, 1 + character_num + elem_num):
elem = lines[eno].decode('utf-8').strip("\r\n")
list_elem.append(elem)
return list_character, list_elem
def add_special_char(self, list_character):
self.beg_str = "sos"
self.end_str = "eos"
list_character = [self.beg_str] + list_character + [self.end_str]
return list_character
for line in lines:
line = line.decode('utf-8').strip("\n").strip("\r\n")
dict_character.append(line)
def get_span_idx_list(self):
span_idx_list = []
for elem in self.dict_elem:
if 'span' in elem:
span_idx_list.append(self.dict_elem[elem])
return span_idx_list
dict_character = self.add_special_char(dict_character)
self.dict = {}
for i, char in enumerate(dict_character):
self.dict[char] = i
self.idx2char = {v: k for k, v in self.dict.items()}
self.character = dict_character
self.point_num = point_num
self.pad_idx = self.dict[self.beg_str]
self.start_idx = self.dict[self.beg_str]
self.end_idx = self.dict[self.end_str]
self.td_token = ['<td>', '<td', '<eb></eb>', '<td></td>']
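# Map the literal token content of an empty cell to a dedicated <ebN> placeholder; used when replace_empty_cell_token is enabled.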
self.empty_bbox_token_dict = {
"[]": '<eb></eb>',
"[' ']": '<eb1></eb1>',
"['<b>', ' ', '</b>']": '<eb2></eb2>',
"['\\u2028', '\\u2028']": '<eb3></eb3>',
"['<sup>', ' ', '</sup>']": '<eb4></eb4>',
"['<b>', '</b>']": '<eb5></eb5>',
"['<i>', ' ', '</i>']": '<eb6></eb6>',
"['<b>', '<i>', '</i>', '</b>']": '<eb7></eb7>',
"['<b>', '<i>', ' ', '</i>', '</b>']": '<eb8></eb8>',
"['<i>', '</i>']": '<eb9></eb9>',
"['<b>', ' ', '\\u2028', ' ', '\\u2028', ' ', '</b>']":
'<eb10></eb10>',
}
@property
def _max_text_len(self):
return self.max_text_len + 2
def __call__(self, data):
cells = data['cells']
structure = data['structure']['tokens']
structure = self.encode(structure, 'elem')
structure = data['structure']
if self.merge_no_span_structure:
structure = self._merge_no_span_structure(structure)
if self.replace_empty_cell_token:
structure = self._replace_empty_cell_token(structure, cells)
# remove empty token and add " " to span token
new_structure = []
for token in structure:
if token != '':
if 'span' in token and token[0] != ' ':
token = ' ' + token
new_structure.append(token)
# encode structure
structure = self.encode(new_structure)
if structure is None:
return None
elem_num = len(structure)
structure = [0] + structure + [len(self.dict_elem) - 1]
structure = structure + [0] * (self.max_elem_length + 2 - len(structure)
)
structure = [self.start_idx] + structure + [self.end_idx
] # add sos and eos
structure = structure + [self.pad_idx] * (self._max_text_len -
len(structure)) # pad
structure = np.array(structure)
data['structure'] = structure
elem_char_idx1 = self.dict_elem['<td>']
elem_char_idx2 = self.dict_elem['<td']
span_idx_list = self.get_span_idx_list()
td_idx_list = np.logical_or(structure == elem_char_idx1,
structure == elem_char_idx2)
td_idx_list = np.where(td_idx_list)[0]
structure_mask = np.ones(
(self.max_elem_length + 2, 1), dtype=np.float32)
bbox_list = np.zeros((self.max_elem_length + 2, 4), dtype=np.float32)
bbox_list_mask = np.zeros(
(self.max_elem_length + 2, 1), dtype=np.float32)
img_height, img_width, img_ch = data['image'].shape
if len(span_idx_list) > 0:
span_weight = len(td_idx_list) * 1.0 / len(span_idx_list)
span_weight = min(max(span_weight, 1.0), self.span_weight)
for cno in range(len(cells)):
if 'bbox' in cells[cno]:
bbox = cells[cno]['bbox'].copy()
bbox[0] = bbox[0] * 1.0 / img_width
bbox[1] = bbox[1] * 1.0 / img_height
bbox[2] = bbox[2] * 1.0 / img_width
bbox[3] = bbox[3] * 1.0 / img_height
td_idx = td_idx_list[cno]
bbox_list[td_idx] = bbox
bbox_list_mask[td_idx] = 1.0
cand_span_idx = td_idx + 1
if cand_span_idx < (self.max_elem_length + 2):
if structure[cand_span_idx] in span_idx_list:
structure_mask[cand_span_idx] = span_weight
data['bbox_list'] = bbox_list
data['bbox_list_mask'] = bbox_list_mask
data['structure_mask'] = structure_mask
char_beg_idx = self.get_beg_end_flag_idx('beg', 'char')
char_end_idx = self.get_beg_end_flag_idx('end', 'char')
elem_beg_idx = self.get_beg_end_flag_idx('beg', 'elem')
elem_end_idx = self.get_beg_end_flag_idx('end', 'elem')
data['sp_tokens'] = np.array([
char_beg_idx, char_end_idx, elem_beg_idx, elem_end_idx,
elem_char_idx1, elem_char_idx2, self.max_text_length,
self.max_elem_length, self.max_cell_num, elem_num
])
if len(structure) > self._max_text_len:
return None
# encode box
bboxes = np.zeros(
(self._max_text_len, self.point_num * 2), dtype=np.float32)
bbox_masks = np.zeros((self._max_text_len, 1), dtype=np.float32)
bbox_idx = 0
for i, token in enumerate(structure):
if self.idx2char[token] in self.td_token:
if 'bbox' in cells[bbox_idx] and len(cells[bbox_idx][
'tokens']) > 0:
bbox = cells[bbox_idx]['bbox'].copy()
bbox = np.array(bbox, dtype=np.float32).reshape(-1)
bboxes[i] = bbox
bbox_masks[i] = 1.0
if self.learn_empty_box:
bbox_masks[i] = 1.0
bbox_idx += 1
data['bboxes'] = bboxes
data['bbox_masks'] = bbox_masks
return data
def encode(self, text, char_or_elem):
"""convert text-label into text-index.
def _merge_no_span_structure(self, structure):
"""
if char_or_elem == "char":
max_len = self.max_text_length
current_dict = self.dict_character
else:
max_len = self.max_elem_length
current_dict = self.dict_elem
if len(text) > max_len:
return None
if len(text) == 0:
if char_or_elem == "char":
return [self.dict_character['space']]
else:
return None
text_list = []
for char in text:
if char not in current_dict:
return None
text_list.append(current_dict[char])
if len(text_list) == 0:
if char_or_elem == "char":
return [self.dict_character['space']]
This code is adapted from:
https://github.com/JiaquanYe/TableMASTER-mmocr/blob/master/table_recognition/data_preprocess.py
"""
new_structure = []
i = 0
while i < len(structure):
token = structure[i]
if token == '<td>':
token = '<td></td>'
i += 1
new_structure.append(token)
i += 1
return new_structure
def _replace_empty_cell_token(self, token_list, cells):
"""
This function is adapted from:
https://github.com/JiaquanYe/TableMASTER-mmocr/blob/master/table_recognition/data_preprocess.py
"""
bbox_idx = 0
add_empty_bbox_token_list = []
for token in token_list:
if token in ['<td></td>', '<td', '<td>']:
if 'bbox' not in cells[bbox_idx].keys():
content = str(cells[bbox_idx]['tokens'])
token = self.empty_bbox_token_dict[content]
add_empty_bbox_token_list.append(token)
bbox_idx += 1
else:
return None
return text_list
add_empty_bbox_token_list.append(token)
return add_empty_bbox_token_list
def get_ignored_tokens(self, char_or_elem):
beg_idx = self.get_beg_end_flag_idx("beg", char_or_elem)
end_idx = self.get_beg_end_flag_idx("end", char_or_elem)
return [beg_idx, end_idx]
def get_beg_end_flag_idx(self, beg_or_end, char_or_elem):
if char_or_elem == "char":
if beg_or_end == "beg":
idx = np.array(self.dict_character[self.beg_str])
elif beg_or_end == "end":
idx = np.array(self.dict_character[self.end_str])
else:
assert False, "Unsupport type %s in get_beg_end_flag_idx of char" \
% beg_or_end
elif char_or_elem == "elem":
if beg_or_end == "beg":
idx = np.array(self.dict_elem[self.beg_str])
elif beg_or_end == "end":
idx = np.array(self.dict_elem[self.end_str])
else:
assert False, "Unsupport type %s in get_beg_end_flag_idx of elem" \
% beg_or_end
else:
assert False, "Unsupport type %s in char_or_elem" \
% char_or_elem
return idx
class TableMasterLabelEncode(TableLabelEncode):
""" Convert between text-label and text-index """
def __init__(self,
max_text_length,
character_dict_path,
replace_empty_cell_token=False,
merge_no_span_structure=False,
learn_empty_box=False,
point_num=2,
**kwargs):
super(TableMasterLabelEncode, self).__init__(
max_text_length, character_dict_path, replace_empty_cell_token,
merge_no_span_structure, learn_empty_box, point_num, **kwargs)
self.pad_idx = self.dict[self.pad_str]
self.unknown_idx = self.dict[self.unknown_str]
@property
def _max_text_len(self):
return self.max_text_len
def add_special_char(self, dict_character):
self.beg_str = '<SOS>'
self.end_str = '<EOS>'
self.unknown_str = '<UKN>'
self.pad_str = '<PAD>'
dict_character = dict_character + [
self.unknown_str, self.beg_str, self.end_str, self.pad_str
]
return dict_character
class TableBoxEncode(object):
def __init__(self, use_xywh=False, **kwargs):
self.use_xywh = use_xywh
def __call__(self, data):
img_height, img_width = data['image'].shape[:2]
bboxes = data['bboxes']
if self.use_xywh and bboxes.shape[1] == 4:
bboxes = self.xyxy2xywh(bboxes)
bboxes[:, 0::2] /= img_width
bboxes[:, 1::2] /= img_height
data['bboxes'] = bboxes
return data
def xyxy2xywh(self, bboxes):
"""
Convert coord (x1,y1,x2,y2) to (x,y,w,h).
where (x1,y1) is top-left, (x2,y2) is bottom-right.
(x,y) is bbox center and (w,h) is width and height.
:param bboxes: (x1, y1, x2, y2)
:return:
"""
new_bboxes = np.empty_like(bboxes)
new_bboxes[:, 0] = (bboxes[:, 0] + bboxes[:, 2]) / 2 # x center
new_bboxes[:, 1] = (bboxes[:, 1] + bboxes[:, 3]) / 2 # y center
new_bboxes[:, 2] = bboxes[:, 2] - bboxes[:, 0] # width
new_bboxes[:, 3] = bboxes[:, 3] - bboxes[:, 1] # height
return new_bboxes
class SARLabelEncode(BaseRecLabelEncode):
......@@ -819,6 +869,7 @@ class VQATokenLabelEncode(object):
contains_re=False,
add_special_ids=False,
algorithm='LayoutXLM',
use_textline_bbox_info=True,
infer_mode=False,
ocr_engine=None,
**kwargs):
......@@ -847,11 +898,51 @@ class VQATokenLabelEncode(object):
self.add_special_ids = add_special_ids
self.infer_mode = infer_mode
self.ocr_engine = ocr_engine
self.use_textline_bbox_info = use_textline_bbox_info
def split_bbox(self, bbox, text, tokenizer):
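# Split a text-line bbox into per-token bboxes, giving each word a width proportional to its character count (assumes roughly uniform glyph width).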
words = text.split()
token_bboxes = []
x1, y1, x2, y2 = bbox
unit_w = (x2 - x1) / len(text)
for word in words:
curr_w = len(word) * unit_w
word_bbox = [x1, y1, x1 + curr_w, y2]
token_bboxes.extend([word_bbox] * len(tokenizer.tokenize(word)))
x1 += (len(word) + 1) * unit_w
return token_bboxes
def filter_empty_contents(self, ocr_info):
"""
find out the empty texts and remove the links
"""
new_ocr_info = []
empty_index = []
for idx, info in enumerate(ocr_info):
if len(info["transcription"]) > 0:
new_ocr_info.append(copy.deepcopy(info))
else:
empty_index.append(info["id"])
for idx, info in enumerate(new_ocr_info):
new_link = []
for link in info["linking"]:
if link[0] in empty_index or link[1] in empty_index:
continue
new_link.append(link)
new_ocr_info[idx]["linking"] = new_link
return new_ocr_info
def __call__(self, data):
# load bbox and label info
ocr_info = self._load_ocr_info(data)
# for re
train_re = self.contains_re and not self.infer_mode
if train_re:
ocr_info = self.filter_empty_contents(ocr_info)
height, width, _ = data['image'].shape
words_list = []
......@@ -863,8 +954,6 @@ class VQATokenLabelEncode(object):
entities = []
# for re
train_re = self.contains_re and not self.infer_mode
if train_re:
relations = []
id2label = {}
......@@ -874,17 +963,19 @@ class VQATokenLabelEncode(object):
data['ocr_info'] = copy.deepcopy(ocr_info)
for info in ocr_info:
text = info["transcription"]
if len(text) <= 0:
continue
if train_re:
# for re
if len(info["text"]) == 0:
if len(text) == 0:
empty_entity.add(info["id"])
continue
id2label[info["id"]] = info["label"]
relations.extend([tuple(sorted(l)) for l in info["linking"]])
# smooth_box
bbox = self._smooth_box(info["bbox"], height, width)
info["bbox"] = self.trans_poly_to_bbox(info["points"])
text = info["text"]
encode_res = self.tokenizer.encode(
text, pad_to_max_seq_len=False, return_attention_mask=True)
......@@ -895,6 +986,19 @@ class VQATokenLabelEncode(object):
-1]
encode_res["attention_mask"] = encode_res["attention_mask"][1:
-1]
if self.use_textline_bbox_info:
bbox = [info["bbox"]] * len(encode_res["input_ids"])
else:
bbox = self.split_bbox(info["bbox"], info["transcription"],
self.tokenizer)
if len(bbox) <= 0:
continue
bbox = self._smooth_box(bbox, height, width)
if self.add_special_ids:
bbox.insert(0, [0, 0, 0, 0])
bbox.append([0, 0, 0, 0])
# parse label
if not self.infer_mode:
label = info['label']
......@@ -919,7 +1023,7 @@ class VQATokenLabelEncode(object):
})
input_ids_list.extend(encode_res["input_ids"])
token_type_ids_list.extend(encode_res["token_type_ids"])
bbox_list.extend([bbox] * len(encode_res["input_ids"]))
bbox_list.extend(bbox)
words_list.append(text)
segment_offset_id.append(len(input_ids_list))
if not self.infer_mode:
......@@ -944,40 +1048,42 @@ class VQATokenLabelEncode(object):
data['entity_id_to_index_map'] = entity_id_to_index_map
return data
def _load_ocr_info(self, data):
def trans_poly_to_bbox(poly):
x1 = np.min([p[0] for p in poly])
x2 = np.max([p[0] for p in poly])
y1 = np.min([p[1] for p in poly])
y2 = np.max([p[1] for p in poly])
return [x1, y1, x2, y2]
def trans_poly_to_bbox(self, poly):
x1 = np.min([p[0] for p in poly])
x2 = np.max([p[0] for p in poly])
y1 = np.min([p[1] for p in poly])
y2 = np.max([p[1] for p in poly])
return [x1, y1, x2, y2]
def _load_ocr_info(self, data):
if self.infer_mode:
ocr_result = self.ocr_engine.ocr(data['image'], cls=False)
ocr_info = []
for res in ocr_result:
ocr_info.append({
"text": res[1][0],
"bbox": trans_poly_to_bbox(res[0]),
"poly": res[0],
"transcription": res[1][0],
"bbox": self.trans_poly_to_bbox(res[0]),
"points": res[0],
})
return ocr_info
else:
info = data['label']
# read text info
info_dict = json.loads(info)
return info_dict["ocr_info"]
return info_dict
def _smooth_box(self, bbox, height, width):
bbox[0] = int(bbox[0] * 1000.0 / width)
bbox[2] = int(bbox[2] * 1000.0 / width)
bbox[1] = int(bbox[1] * 1000.0 / height)
bbox[3] = int(bbox[3] * 1000.0 / height)
return bbox
def _smooth_box(self, bboxes, height, width):
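# Normalize box coordinates to the 0-1000 grid that LayoutLM-style layout models expect.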
bboxes = np.array(bboxes)
bboxes[:, 0] = bboxes[:, 0] * 1000 / width
bboxes[:, 2] = bboxes[:, 2] * 1000 / width
bboxes[:, 1] = bboxes[:, 1] * 1000 / height
bboxes[:, 3] = bboxes[:, 3] * 1000 / height
bboxes = bboxes.astype("int64").tolist()
return bboxes
def _parse_label(self, label, encode_res):
gt_label = []
if label.lower() == "other":
if label.lower() in ["other", "others", "ignore"]:
gt_label.extend([0] * len(encode_res["input_ids"]))
else:
gt_label.append(self.label2id_map[("b-" + label).upper()])
......@@ -1001,7 +1107,6 @@ class MultiLabelEncode(BaseRecLabelEncode):
use_space_char, **kwargs)
def __call__(self, data):
data_ctc = copy.deepcopy(data)
data_sar = copy.deepcopy(data)
data_out = dict()
......
......@@ -205,9 +205,12 @@ class DetResizeForTest(object):
def __init__(self, **kwargs):
super(DetResizeForTest, self).__init__()
self.resize_type = 0
self.keep_ratio = False
if 'image_shape' in kwargs:
self.image_shape = kwargs['image_shape']
self.resize_type = 1
if 'keep_ratio' in kwargs:
self.keep_ratio = kwargs['keep_ratio']
elif 'limit_side_len' in kwargs:
self.limit_side_len = kwargs['limit_side_len']
self.limit_type = kwargs.get('limit_type', 'min')
......@@ -237,6 +240,10 @@ class DetResizeForTest(object):
def resize_image_type1(self, img):
resize_h, resize_w = self.image_shape
ori_h, ori_w = img.shape[:2] # (h, w, c)
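# keep_ratio: derive the width from the fixed target height so the aspect ratio is preserved, then round up to a multiple of 32 (stride constraint of the detection backbone).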
if self.keep_ratio is True:
resize_w = ori_w * resize_h / ori_h
N = math.ceil(resize_w / 32)
resize_w = N * 32
ratio_h = float(resize_h) / ori_h
ratio_w = float(resize_w) / ori_w
img = cv2.resize(img, (int(resize_w), int(resize_h)))
......
......@@ -32,7 +32,7 @@ class GenTableMask(object):
self.shrink_h_max = 5
self.shrink_w_max = 5
self.mask_type = mask_type
def projection(self, erosion, h, w, spilt_threshold=0):
# horizontal projection
projection_map = np.ones_like(erosion)
......@@ -48,10 +48,12 @@ class GenTableMask(object):
in_text = False # whether we are currently inside a text region
box_list = []
for i in range(len(project_val_array)):
if in_text == False and project_val_array[i] > spilt_threshold: # entered a text region
if in_text == False and project_val_array[
i] > spilt_threshold: # entered a text region
in_text = True
start_idx = i
elif project_val_array[i] <= spilt_threshold and in_text == True: # entered a blank region
elif project_val_array[
i] <= spilt_threshold and in_text == True: # entered a blank region
end_idx = i
in_text = False
if end_idx - start_idx <= 2:
......@@ -70,7 +72,8 @@ class GenTableMask(object):
box_gray_img = cv2.cvtColor(box_img, cv2.COLOR_BGR2GRAY)
h, w = box_gray_img.shape
# binarize the grayscale image
ret, thresh1 = cv2.threshold(box_gray_img, 200, 255, cv2.THRESH_BINARY_INV)
ret, thresh1 = cv2.threshold(box_gray_img, 200, 255,
cv2.THRESH_BINARY_INV)
# vertical erosion
if h < w:
kernel = np.ones((2, 1), np.uint8)
......@@ -95,10 +98,12 @@ class GenTableMask(object):
box_list = []
spilt_threshold = 0
for i in range(len(project_val_array)):
if in_text == False and project_val_array[i] > spilt_threshold: # entered a text region
if in_text == False and project_val_array[
i] > spilt_threshold: # entered a text region
in_text = True
start_idx = i
elif project_val_array[i] <= spilt_threshold and in_text == True: # entered a blank region
elif project_val_array[
i] <= spilt_threshold and in_text == True: # entered a blank region
end_idx = i
in_text = False
if end_idx - start_idx <= 2:
......@@ -120,7 +125,8 @@ class GenTableMask(object):
h_end = h
word_img = erosion[h_start:h_end + 1, :]
word_h, word_w = word_img.shape
w_split_list, w_projection_map = self.projection(word_img.T, word_w, word_h)
w_split_list, w_projection_map = self.projection(word_img.T,
word_w, word_h)
w_start, w_end = w_split_list[0][0], w_split_list[-1][1]
if h_start > 0:
h_start -= 1
......@@ -170,75 +176,54 @@ class GenTableMask(object):
for sno in range(len(split_bbox_list)):
left, top, right, bottom = split_bbox_list[sno]
left, top, right, bottom = self.shrink_bbox([left, top, right, bottom])
left, top, right, bottom = self.shrink_bbox(
[left, top, right, bottom])
if self.mask_type == 1:
mask_img[top:bottom, left:right] = 1.0
data['mask_img'] = mask_img
else:
mask_img[top:bottom, left:right, :] = (255, 255, 255)
data['image'] = mask_img
return data
class ResizeTableImage(object):
def __init__(self, max_len, **kwargs):
def __init__(self, max_len, resize_bboxes=False, infer_mode=False,
**kwargs):
super(ResizeTableImage, self).__init__()
self.max_len = max_len
self.resize_bboxes = resize_bboxes
self.infer_mode = infer_mode
def get_img_bbox(self, cells):
bbox_list = []
if len(cells) == 0:
return bbox_list
cell_num = len(cells)
for cno in range(cell_num):
if "bbox" in cells[cno]:
bbox = cells[cno]['bbox']
bbox_list.append(bbox)
return bbox_list
def resize_img_table(self, img, bbox_list, max_len):
def __call__(self, data):
img = data['image']
height, width = img.shape[0:2]
ratio = max_len / (max(height, width) * 1.0)
ratio = self.max_len / (max(height, width) * 1.0)
resize_h = int(height * ratio)
resize_w = int(width * ratio)
img_new = cv2.resize(img, (resize_w, resize_h))
bbox_list_new = []
for bno in range(len(bbox_list)):
left, top, right, bottom = bbox_list[bno].copy()
left = int(left * ratio)
top = int(top * ratio)
right = int(right * ratio)
bottom = int(bottom * ratio)
bbox_list_new.append([left, top, right, bottom])
return img_new, bbox_list_new
def __call__(self, data):
img = data['image']
if 'cells' not in data:
cells = []
else:
cells = data['cells']
bbox_list = self.get_img_bbox(cells)
img_new, bbox_list_new = self.resize_img_table(img, bbox_list, self.max_len)
data['image'] = img_new
cell_num = len(cells)
bno = 0
for cno in range(cell_num):
if "bbox" in data['cells'][cno]:
data['cells'][cno]['bbox'] = bbox_list_new[bno]
bno += 1
resize_img = cv2.resize(img, (resize_w, resize_h))
if self.resize_bboxes and not self.infer_mode:
data['bboxes'] = data['bboxes'] * ratio
data['image'] = resize_img
data['src_img'] = img
data['shape'] = np.array([resize_h, resize_w, ratio, ratio])
data['max_len'] = self.max_len
return data
class PaddingTableImage(object):
def __init__(self, **kwargs):
def __init__(self, size, **kwargs):
super(PaddingTableImage, self).__init__()
self.size = size
def __call__(self, data):
img = data['image']
max_len = data['max_len']
padding_img = np.zeros((max_len, max_len, 3), dtype=np.float32)
pad_h, pad_w = self.size
padding_img = np.zeros((pad_h, pad_w, 3), dtype=np.float32)
height, width = img.shape[0:2]
padding_img[0:height, 0:width, :] = img.copy()
data['image'] = padding_img
shape = data['shape'].tolist()
shape.extend([pad_h, pad_w])
data['shape'] = np.array(shape)
return data
\ No newline at end of file
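Chained together, the two operators above implement the standard table preprocessing: scale the longest side to `max_len`, then zero-pad to a fixed square. A minimal sketch of that flow, assuming the `ResizeTableImage` and `PaddingTableImage` classes defined above are in scope (in PaddleOCR they normally run inside the transform pipeline built by `create_operators`; the file path and sizes are illustrative):
```python
import cv2

# Illustrative standalone use of the two transforms defined above.
data = {'image': cv2.imread('docs/table/table.jpg')}
data = ResizeTableImage(max_len=480)(data)       # longest side -> 480; records the scale in data['shape']
data = PaddingTableImage(size=[480, 480])(data)  # zero-pad bottom/right to a fixed 480x480 canvas
print(data['image'].shape)  # (480, 480, 3)
print(data['shape'])        # [resize_h, resize_w, ratio, ratio, 480, 480]
```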
......@@ -13,7 +13,12 @@
# limitations under the License.
from .token import VQATokenPad, VQASerTokenChunk, VQAReTokenChunk, VQAReTokenRelation
from .augment import DistortBBox
__all__ = [
'VQATokenPad', 'VQASerTokenChunk', 'VQAReTokenChunk', 'VQAReTokenRelation'
'VQATokenPad',
'VQASerTokenChunk',
'VQAReTokenChunk',
'VQAReTokenRelation',
'DistortBBox',
]
# copyright (c) 2022 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
import random
class DistortBBox:
def __init__(self, prob=0.5, max_scale=1, **kwargs):
"""Random distort bbox
"""
self.prob = prob
self.max_scale = max_scale
def __call__(self, data):
if random.random() > self.prob:
return data
bbox = np.array(data['bbox'])
rnd_scale = (np.random.rand(*bbox.shape) - 0.5) * 2 * self.max_scale
bbox = np.round(bbox + rnd_scale).astype(bbox.dtype)
bbox = np.clip(bbox, 0, 1000)
data['bbox'] = bbox.tolist()
return data
......@@ -16,6 +16,7 @@ import os
import random
from paddle.io import Dataset
import json
from copy import deepcopy
from .imaug import transform, create_operators
......@@ -29,33 +30,63 @@ class PubTabDataSet(Dataset):
dataset_config = config[mode]['dataset']
loader_config = config[mode]['loader']
label_file_path = dataset_config.pop('label_file_path')
label_file_list = dataset_config.pop('label_file_list')
data_source_num = len(label_file_list)
ratio_list = dataset_config.get("ratio_list", [1.0])
if isinstance(ratio_list, (float, int)):
ratio_list = [float(ratio_list)] * int(data_source_num)
assert len(
ratio_list
) == data_source_num, "The length of ratio_list should be the same as the file_list."
self.data_dir = dataset_config['data_dir']
self.do_shuffle = loader_config['shuffle']
self.do_hard_select = False
if 'hard_select' in loader_config:
self.do_hard_select = loader_config['hard_select']
self.hard_prob = loader_config['hard_prob']
if self.do_hard_select:
self.img_select_prob = self.load_hard_select_prob()
self.table_select_type = None
if 'table_select_type' in loader_config:
self.table_select_type = loader_config['table_select_type']
self.table_select_prob = loader_config['table_select_prob']
self.seed = seed
logger.info("Initialize indexs of datasets:%s" % label_file_path)
with open(label_file_path, "rb") as f:
self.data_lines = f.readlines()
self.data_idx_order_list = list(range(len(self.data_lines)))
if mode.lower() == "train":
self.mode = mode.lower()
logger.info("Initialize indexs of datasets:%s" % label_file_list)
self.data_lines = self.get_image_info_list(label_file_list, ratio_list)
# self.check(config['Global']['max_text_length'])
if mode.lower() == "train" and self.do_shuffle:
self.shuffle_data_random()
self.ops = create_operators(dataset_config['transforms'], global_config)
ratio_list = dataset_config.get("ratio_list", [1.0])
self.need_reset = True in [x < 1 for x in ratio_list]
def get_image_info_list(self, file_list, ratio_list):
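# In train mode (or when a ratio < 1.0), sample round(len(lines) * ratio) lines from each label file; the fixed seed keeps the subset reproducible across workers.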
if isinstance(file_list, str):
file_list = [file_list]
data_lines = []
for idx, file in enumerate(file_list):
with open(file, "rb") as f:
lines = f.readlines()
if self.mode == "train" or ratio_list[idx] < 1.0:
random.seed(self.seed)
lines = random.sample(lines,
round(len(lines) * ratio_list[idx]))
data_lines.extend(lines)
return data_lines
def check(self, max_text_length):
data_lines = []
for line in self.data_lines:
data_line = line.decode('utf-8').strip("\n")
info = json.loads(data_line)
file_name = info['filename']
cells = info['html']['cells'].copy()
structure = info['html']['structure']['tokens'].copy()
img_path = os.path.join(self.data_dir, file_name)
if not os.path.exists(img_path):
self.logger.warning("{} does not exist!".format(img_path))
continue
if len(structure) == 0 or len(structure) > max_text_length:
continue
# data = {'img_path': img_path, 'cells': cells, 'structure':structure,'file_name':file_name}
data_lines.append(line)
self.data_lines = data_lines
def shuffle_data_random(self):
if self.do_shuffle:
random.seed(self.seed)
......@@ -68,47 +99,35 @@ class PubTabDataSet(Dataset):
data_line = data_line.decode('utf-8').strip("\n")
info = json.loads(data_line)
file_name = info['filename']
cells = info['html']['cells'].copy()
structure = info['html']['structure']['tokens'].copy()
img_path = os.path.join(self.data_dir, file_name)
if not os.path.exists(img_path):
raise Exception("{} does not exist!".format(img_path))
data = {
'img_path': img_path,
'cells': cells,
'structure': structure,
'file_name': file_name
}
with open(data['img_path'], 'rb') as f:
img = f.read()
data['image'] = img
outs = transform(data, self.ops)
except:
import traceback
err = traceback.format_exc()
self.logger.error(
"When parsing line {}, error happened with msg: {}".format(
                    data_line, err))
outs = None
if outs is None:
rnd_idx = np.random.randint(self.__len__(
)) if self.mode == "train" else (idx + 1) % self.__len__()
return self.__getitem__(rnd_idx)
return outs
def __len__(self):
        return len(self.data_lines)
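# Illustration of the ratio_list sampling in get_image_info_list above: a
# ratio of 0.5 keeps about half of a label file's lines, drawn with the
# configured seed so runs are reproducible.
import random

lines = [str(i).encode() for i in range(100)]
random.seed(42)  # stands in for self.seed
kept = random.sample(lines, round(len(lines) * 0.5))
print(len(kept))  # 50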
......@@ -51,7 +51,7 @@ from .combined_loss import CombinedLoss
# table loss
from .table_att_loss import TableAttentionLoss
from .table_master_loss import TableMasterLoss
# vqa token loss
from .vqa_token_layoutlm_loss import VQASerTokenLayoutLMLoss
......@@ -61,7 +61,8 @@ def build_loss(config):
'DBLoss', 'PSELoss', 'EASTLoss', 'SASTLoss', 'FCELoss', 'CTCLoss',
'ClsLoss', 'AttentionLoss', 'SRNLoss', 'PGLoss', 'CombinedLoss',
'CELoss', 'TableAttentionLoss', 'SARLoss', 'AsterLoss', 'SDMGRLoss',
'VQASerTokenLayoutLMLoss', 'LossFromOutput', 'PRENLoss', 'MultiLoss'
'VQASerTokenLayoutLMLoss', 'LossFromOutput', 'PRENLoss', 'MultiLoss',
'TableMasterLoss'
]
config = copy.deepcopy(config)
module_name = config.pop('name')
......
......@@ -57,17 +57,24 @@ class CELoss(nn.Layer):
class KLJSLoss(object):
def __init__(self, mode='kl'):
        assert mode in ['kl', 'js', 'KL', 'JS'
                        ], "mode can only be one of ['kl', 'KL', 'js', 'JS']"
self.mode = mode
def __call__(self, p1, p2, reduction="mean"):
loss = paddle.multiply(p2, paddle.log((p2 + 1e-5) / (p1 + 1e-5) + 1e-5))
if self.mode.lower() == "js":
if self.mode.lower() == 'kl':
loss = paddle.multiply(p2, paddle.log((p2 + 1e-5) / (p1 + 1e-5) + 1e-5))
loss += paddle.multiply(
p1, paddle.log((p1 + 1e-5) / (p2 + 1e-5) + 1e-5))
loss *= 0.5
elif self.mode.lower() == "js":
loss = paddle.multiply(p2, paddle.log((2*p2 + 1e-5) / (p1 + p2 + 1e-5) + 1e-5))
loss += paddle.multiply(
p1, paddle.log((p1 + 1e-5) / (p2 + 1e-5) + 1e-5))
p1, paddle.log((2*p1 + 1e-5) / (p1 + p2 + 1e-5) + 1e-5))
loss *= 0.5
else:
raise ValueError("The mode.lower() if KLJSLoss should be one of ['kl', 'js']")
if reduction == "mean":
loss = paddle.mean(loss, axis=[1, 2])
elif reduction == "none" or reduction is None:
......@@ -95,7 +102,7 @@ class DMLLoss(nn.Layer):
self.act = None
self.use_log = use_log
self.jskl_loss = KLJSLoss(mode="js")
self.jskl_loss = KLJSLoss(mode="kl")
def _kldiv(self, x, target):
eps = 1.0e-10
......
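# Usage sketch for KLJSLoss (shapes are illustrative): both modes take
# probability tensors of the same shape; with reduction="mean" the loss is
# averaged over axes 1 and 2, giving one value per sample.
import paddle
import paddle.nn.functional as F

p1 = F.softmax(paddle.randn([2, 4, 8]), axis=-1)
p2 = F.softmax(paddle.randn([2, 4, 8]), axis=-1)
print(KLJSLoss(mode='kl')(p1, p2).shape)  # [2], symmetric KL halved
print(KLJSLoss(mode='js')(p1, p2).shape)  # [2], Jensen-Shannon style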
......@@ -20,15 +20,21 @@ import paddle
from paddle import nn
from paddle.nn import functional as F
class TableAttentionLoss(nn.Layer):
    def __init__(self,
                 structure_weight,
                 loc_weight,
                 use_giou=False,
                 giou_weight=1.0,
                 **kwargs):
super(TableAttentionLoss, self).__init__()
self.loss_func = nn.CrossEntropyLoss(weight=None, reduction='none')
self.structure_weight = structure_weight
self.loc_weight = loc_weight
self.use_giou = use_giou
self.giou_weight = giou_weight
def giou_loss(self, preds, bbox, eps=1e-7, reduction='mean'):
'''
:param preds:[[x1,y1,x2,y2], [x1,y1,x2,y2],,,]
......@@ -47,9 +53,10 @@ class TableAttentionLoss(nn.Layer):
inters = iw * ih
# union
        uni = (preds[:, 2] - preds[:, 0] + 1e-3) * (
            preds[:, 3] - preds[:, 1] + 1e-3) + (bbox[:, 2] - bbox[:, 0] + 1e-3
                                                 ) * (bbox[:, 3] - bbox[:, 1] +
                                                      1e-3) - inters + eps
# ious
ious = inters / uni
......@@ -79,30 +86,34 @@ class TableAttentionLoss(nn.Layer):
structure_probs = predicts['structure_probs']
structure_targets = batch[1].astype("int64")
structure_targets = structure_targets[:, 1:]
if len(batch) == 6:
structure_mask = batch[5].astype("int64")
structure_mask = structure_mask[:, 1:]
structure_mask = paddle.reshape(structure_mask, [-1])
        structure_probs = paddle.reshape(structure_probs,
                                         [-1, structure_probs.shape[-1]])
structure_targets = paddle.reshape(structure_targets, [-1])
structure_loss = self.loss_func(structure_probs, structure_targets)
if len(batch) == 6:
structure_loss = structure_loss * structure_mask
# structure_loss = paddle.sum(structure_loss) * self.structure_weight
structure_loss = paddle.mean(structure_loss) * self.structure_weight
loc_preds = predicts['loc_preds']
loc_targets = batch[2].astype("float32")
loc_targets_mask = batch[4].astype("float32")
loc_targets_mask = batch[3].astype("float32")
loc_targets = loc_targets[:, 1:, :]
loc_targets_mask = loc_targets_mask[:, 1:, :]
        loc_loss = F.mse_loss(loc_preds * loc_targets_mask,
                              loc_targets) * self.loc_weight
if self.use_giou:
            loc_loss_giou = self.giou_loss(loc_preds * loc_targets_mask,
                                           loc_targets) * self.giou_weight
total_loss = structure_loss + loc_loss + loc_loss_giou
            return {
                'loss': total_loss,
                "structure_loss": structure_loss,
                "loc_loss": loc_loss,
                "loc_loss_giou": loc_loss_giou
            }
else:
            total_loss = structure_loss + loc_loss
            return {
                'loss': total_loss,
                "structure_loss": structure_loss,
                "loc_loss": loc_loss
            }
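# Smoke-test sketch with dummy tensors (shapes are illustrative; targets carry
# a leading start token that forward() strips off):
import paddle

loss_fn = TableAttentionLoss(structure_weight=1.0, loc_weight=2.0)
B, L, C = 2, 6, 30
predicts = {
    'structure_probs': paddle.randn([B, L, C]),
    'loc_preds': paddle.rand([B, L, 4]),
}
batch = [
    None,
    paddle.randint(0, C, [B, L + 1]),  # structure targets
    paddle.rand([B, L + 1, 4]),        # bbox targets
    paddle.ones([B, L + 1, 4]),        # bbox target mask
]
out = loss_fn(predicts, batch)
print({k: float(v) for k, v in out.items()})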
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This code is refer from:
https://github.com/JiaquanYe/TableMASTER-mmocr/tree/master/mmocr/models/textrecog/losses
"""
import paddle
from paddle import nn
class TableMasterLoss(nn.Layer):
def __init__(self, ignore_index=-1):
super(TableMasterLoss, self).__init__()
self.structure_loss = nn.CrossEntropyLoss(
ignore_index=ignore_index, reduction='mean')
self.box_loss = nn.L1Loss(reduction='sum')
self.eps = 1e-12
def forward(self, predicts, batch):
# structure_loss
structure_probs = predicts['structure_probs']
structure_targets = batch[1]
structure_targets = structure_targets[:, 1:]
structure_probs = structure_probs.reshape(
[-1, structure_probs.shape[-1]])
structure_targets = structure_targets.reshape([-1])
structure_loss = self.structure_loss(structure_probs, structure_targets)
structure_loss = structure_loss.mean()
losses = dict(structure_loss=structure_loss)
# box loss
bboxes_preds = predicts['loc_preds']
bboxes_targets = batch[2][:, 1:, :]
bbox_masks = batch[3][:, 1:]
# mask empty-bbox or non-bbox structure token's bbox.
masked_bboxes_preds = bboxes_preds * bbox_masks
masked_bboxes_targets = bboxes_targets * bbox_masks
# horizon loss (x and width)
horizon_sum_loss = self.box_loss(masked_bboxes_preds[:, :, 0::2],
masked_bboxes_targets[:, :, 0::2])
horizon_loss = horizon_sum_loss / (bbox_masks.sum() + self.eps)
# vertical loss (y and height)
vertical_sum_loss = self.box_loss(masked_bboxes_preds[:, :, 1::2],
masked_bboxes_targets[:, :, 1::2])
vertical_loss = vertical_sum_loss / (bbox_masks.sum() + self.eps)
horizon_loss = horizon_loss.mean()
vertical_loss = vertical_loss.mean()
all_loss = structure_loss + horizon_loss + vertical_loss
losses.update({
'loss': all_loss,
'horizon_bbox_loss': horizon_loss,
'vertical_bbox_loss': vertical_loss
})
return losses
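# Smoke-test sketch (shapes are illustrative; the mask keeps a trailing
# singleton dim so it broadcasts against the 4 box coordinates):
import paddle

loss_fn = TableMasterLoss(ignore_index=-1)
B, L, C = 2, 5, 30
predicts = {
    'structure_probs': paddle.randn([B, L, C]),
    'loc_preds': paddle.rand([B, L, 4]),
}
batch = [
    None,
    paddle.randint(0, C, [B, L + 1]),  # structure targets with start token
    paddle.rand([B, L + 1, 4]),        # bbox targets
    paddle.ones([B, L + 1, 1]),        # bbox masks
]
print(sorted(loss_fn(predicts, batch).keys()))
# -> ['horizon_bbox_loss', 'loss', 'structure_loss', 'vertical_bbox_loss']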
......@@ -27,8 +27,8 @@ class VQASerTokenLayoutLMLoss(nn.Layer):
self.ignore_index = self.loss_class.ignore_index
def forward(self, predicts, batch):
        labels = batch[5]
        attention_mask = batch[2]
if attention_mask is not None:
active_loss = attention_mask.reshape([-1, ]) == 1
active_outputs = predicts.reshape(
......
......@@ -83,14 +83,10 @@ class DetectionIoUEvaluator(object):
evaluationLog = ""
# print(len(gt))
for n in range(len(gt)):
points = gt[n]['points']
# transcription = gt[n]['text']
dontCare = gt[n]['ignore']
            if not Polygon(points).is_valid:
continue
gtPol = points
......@@ -105,9 +101,7 @@ class DetectionIoUEvaluator(object):
for n in range(len(pred)):
points = pred[n]['points']
            if not Polygon(points).is_valid:
continue
detPol = points
......@@ -191,8 +185,6 @@ class DetectionIoUEvaluator(object):
methodHmean = 0 if methodRecall + methodPrecision == 0 else 2 * \
methodRecall * methodPrecision / (
methodRecall + methodPrecision)
methodMetrics = {
'precision': methodPrecision,
'recall': methodRecall,
......
......@@ -12,29 +12,30 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
from ppocr.metrics.det_metric import DetMetric
class TableStructureMetric(object):
    def __init__(self, main_indicator='acc', eps=1e-6, **kwargs):
        self.main_indicator = main_indicator
        self.eps = eps
self.reset()
    def __call__(self, pred_label, batch=None, *args, **kwargs):
        preds, labels = pred_label
        pred_structure_batch_list = preds['structure_batch_list']
        gt_structure_batch_list = labels['structure_batch_list']
        correct_num = 0
        all_num = 0
        for (pred, pred_conf), target in zip(pred_structure_batch_list,
                                             gt_structure_batch_list):
            pred_str = ''.join(pred)
            target_str = ''.join(target)
            if pred_str == target_str:
                correct_num += 1
            all_num += 1
self.correct_num += correct_num
self.all_num += all_num
return {'acc': correct_num * 1.0 / (all_num + self.eps), }
def get_metric(self):
"""
......@@ -49,3 +50,89 @@ class TableMetric(object):
def reset(self):
self.correct_num = 0
self.all_num = 0
self.len_acc_num = 0
self.token_nums = 0
self.anys_dict = dict()
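# Sketch: sequences are compared as joined token strings, so a single wrong
# structure token fails the whole table.
m = TableStructureMetric()
preds = {'structure_batch_list': [(['<tr>', '<td>', '</td>', '</tr>'], 0.99)]}
labels = {'structure_batch_list': [['<tr>', '<td>', '</td>', '</tr>']]}
m((preds, labels))
print(m.correct_num, m.all_num)  # 1 1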
class TableMetric(object):
def __init__(self,
main_indicator='acc',
compute_bbox_metric=False,
point_num=2,
**kwargs):
"""
@param sub_metrics: configs of sub_metric
@param main_matric: main_matric for save best_model
@param kwargs:
"""
self.structure_metric = TableStructureMetric()
self.bbox_metric = DetMetric() if compute_bbox_metric else None
self.main_indicator = main_indicator
self.point_num = point_num
self.reset()
def __call__(self, pred_label, batch=None, *args, **kwargs):
self.structure_metric(pred_label)
if self.bbox_metric is not None:
self.bbox_metric(*self.prepare_bbox_metric_input(pred_label))
def prepare_bbox_metric_input(self, pred_label):
pred_bbox_batch_list = []
gt_ignore_tags_batch_list = []
gt_bbox_batch_list = []
preds, labels = pred_label
batch_num = len(preds['bbox_batch_list'])
for batch_idx in range(batch_num):
# pred
pred_bbox_list = [
self.format_box(pred_box)
for pred_box in preds['bbox_batch_list'][batch_idx]
]
pred_bbox_batch_list.append({'points': pred_bbox_list})
# gt
gt_bbox_list = []
gt_ignore_tags_list = []
for gt_box in labels['bbox_batch_list'][batch_idx]:
gt_bbox_list.append(self.format_box(gt_box))
gt_ignore_tags_list.append(0)
gt_bbox_batch_list.append(gt_bbox_list)
gt_ignore_tags_batch_list.append(gt_ignore_tags_list)
return [
pred_bbox_batch_list,
[0, 0, gt_bbox_batch_list, gt_ignore_tags_batch_list]
]
def get_metric(self):
structure_metric = self.structure_metric.get_metric()
if self.bbox_metric is None:
return structure_metric
bbox_metric = self.bbox_metric.get_metric()
if self.main_indicator == self.bbox_metric.main_indicator:
output = bbox_metric
for sub_key in structure_metric:
output["structure_metric_{}".format(
sub_key)] = structure_metric[sub_key]
else:
output = structure_metric
for sub_key in bbox_metric:
output["bbox_metric_{}".format(sub_key)] = bbox_metric[sub_key]
return output
def reset(self):
self.structure_metric.reset()
if self.bbox_metric is not None:
self.bbox_metric.reset()
def format_box(self, box):
if self.point_num == 2:
x1, y1, x2, y2 = box
box = [[x1, y1], [x2, y1], [x2, y2], [x1, y2]]
elif self.point_num == 4:
x1, y1, x2, y2, x3, y3, x4, y4 = box
box = [[x1, y1], [x2, y2], [x3, y3], [x4, y4]]
return box
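# Sketch: format_box lifts flat coordinates into the 4-point layout DetMetric
# expects; with point_num=2 the two corners define an axis-aligned rectangle.
tm = TableMetric(point_num=2)
print(tm.format_box([0, 0, 10, 20]))
# -> [[0, 0], [10, 0], [10, 20], [0, 20]]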
......@@ -37,23 +37,26 @@ class VQAReTokenMetric(object):
gt_relations = []
for b in range(len(self.relations_list)):
rel_sent = []
for head, tail in zip(self.relations_list[b]["head"],
self.relations_list[b]["tail"]):
rel = {}
rel["head_id"] = head
rel["head"] = (self.entities_list[b]["start"][rel["head_id"]],
self.entities_list[b]["end"][rel["head_id"]])
rel["head_type"] = self.entities_list[b]["label"][rel[
"head_id"]]
rel["tail_id"] = tail
rel["tail"] = (self.entities_list[b]["start"][rel["tail_id"]],
self.entities_list[b]["end"][rel["tail_id"]])
rel["tail_type"] = self.entities_list[b]["label"][rel[
"tail_id"]]
rel["type"] = 1
rel_sent.append(rel)
if "head" in self.relations_list[b]:
for head, tail in zip(self.relations_list[b]["head"],
self.relations_list[b]["tail"]):
rel = {}
rel["head_id"] = head
rel["head"] = (
self.entities_list[b]["start"][rel["head_id"]],
self.entities_list[b]["end"][rel["head_id"]])
rel["head_type"] = self.entities_list[b]["label"][rel[
"head_id"]]
rel["tail_id"] = tail
rel["tail"] = (
self.entities_list[b]["start"][rel["tail_id"]],
self.entities_list[b]["end"][rel["tail_id"]])
rel["tail_type"] = self.entities_list[b]["label"][rel[
"tail_id"]]
rel["type"] = 1
rel_sent.append(rel)
gt_relations.append(rel_sent)
re_metrics = self.re_score(
self.pred_relations_list, gt_relations, mode="boundaries")
......
......@@ -18,9 +18,13 @@ __all__ = ["build_backbone"]
def build_backbone(config, model_type):
if model_type == "det" or model_type == "table":
from .det_mobilenet_v3 import MobileNetV3
        from .det_resnet import ResNet
        from .det_resnet_vd import ResNet_vd
        from .det_resnet_vd_sast import ResNet_SAST
        support_dict = ["MobileNetV3", "ResNet", "ResNet_vd", "ResNet_SAST"]
if model_type == "table":
from .table_master_resnet import TableResNetExtra
support_dict.append('TableResNetExtra')
elif model_type == "rec" or model_type == "cls":
from .rec_mobilenet_v3 import MobileNetV3
from .rec_resnet_vd import ResNet
......
# copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
import paddle
from paddle import ParamAttr
import paddle.nn as nn
import paddle.nn.functional as F
from paddle.nn import Conv2D, BatchNorm, Linear, Dropout
from paddle.nn import AdaptiveAvgPool2D, MaxPool2D, AvgPool2D
from paddle.nn.initializer import Uniform
import math
from paddle.vision.ops import DeformConv2D
from paddle.regularizer import L2Decay
from paddle.nn.initializer import Normal, Constant, XavierUniform
from .det_resnet_vd import DeformableConvV2, ConvBNLayer
class BottleneckBlock(nn.Layer):
def __init__(self,
num_channels,
num_filters,
stride,
shortcut=True,
is_dcn=False):
super(BottleneckBlock, self).__init__()
self.conv0 = ConvBNLayer(
in_channels=num_channels,
out_channels=num_filters,
kernel_size=1,
act="relu", )
self.conv1 = ConvBNLayer(
in_channels=num_filters,
out_channels=num_filters,
kernel_size=3,
stride=stride,
act="relu",
is_dcn=is_dcn,
dcn_groups=1, )
self.conv2 = ConvBNLayer(
in_channels=num_filters,
out_channels=num_filters * 4,
kernel_size=1,
act=None, )
if not shortcut:
self.short = ConvBNLayer(
in_channels=num_channels,
out_channels=num_filters * 4,
kernel_size=1,
stride=stride, )
self.shortcut = shortcut
self._num_channels_out = num_filters * 4
def forward(self, inputs):
y = self.conv0(inputs)
conv1 = self.conv1(y)
conv2 = self.conv2(conv1)
if self.shortcut:
short = inputs
else:
short = self.short(inputs)
y = paddle.add(x=short, y=conv2)
y = F.relu(y)
return y
class BasicBlock(nn.Layer):
def __init__(self,
num_channels,
num_filters,
stride,
shortcut=True,
name=None):
super(BasicBlock, self).__init__()
self.stride = stride
self.conv0 = ConvBNLayer(
in_channels=num_channels,
out_channels=num_filters,
kernel_size=3,
stride=stride,
act="relu")
self.conv1 = ConvBNLayer(
in_channels=num_filters,
out_channels=num_filters,
kernel_size=3,
act=None)
if not shortcut:
self.short = ConvBNLayer(
in_channels=num_channels,
out_channels=num_filters,
kernel_size=1,
stride=stride)
self.shortcut = shortcut
def forward(self, inputs):
y = self.conv0(inputs)
conv1 = self.conv1(y)
if self.shortcut:
short = inputs
else:
short = self.short(inputs)
y = paddle.add(x=short, y=conv1)
y = F.relu(y)
return y
class ResNet(nn.Layer):
def __init__(self,
in_channels=3,
layers=50,
out_indices=None,
dcn_stage=None):
super(ResNet, self).__init__()
self.layers = layers
self.input_image_channel = in_channels
supported_layers = [18, 34, 50, 101, 152]
assert layers in supported_layers, \
"supported layers are {} but input layer is {}".format(
supported_layers, layers)
if layers == 18:
depth = [2, 2, 2, 2]
elif layers == 34 or layers == 50:
depth = [3, 4, 6, 3]
elif layers == 101:
depth = [3, 4, 23, 3]
elif layers == 152:
depth = [3, 8, 36, 3]
num_channels = [64, 256, 512,
1024] if layers >= 50 else [64, 64, 128, 256]
num_filters = [64, 128, 256, 512]
self.dcn_stage = dcn_stage if dcn_stage is not None else [
False, False, False, False
]
self.out_indices = out_indices if out_indices is not None else [
0, 1, 2, 3
]
self.conv = ConvBNLayer(
in_channels=self.input_image_channel,
out_channels=64,
kernel_size=7,
stride=2,
act="relu", )
self.pool2d_max = MaxPool2D(
kernel_size=3,
stride=2,
padding=1, )
self.stages = []
self.out_channels = []
if layers >= 50:
for block in range(len(depth)):
shortcut = False
block_list = []
is_dcn = self.dcn_stage[block]
for i in range(depth[block]):
if layers in [101, 152] and block == 2:
if i == 0:
conv_name = "res" + str(block + 2) + "a"
else:
conv_name = "res" + str(block + 2) + "b" + str(i)
else:
conv_name = "res" + str(block + 2) + chr(97 + i)
bottleneck_block = self.add_sublayer(
conv_name,
BottleneckBlock(
num_channels=num_channels[block]
if i == 0 else num_filters[block] * 4,
num_filters=num_filters[block],
stride=2 if i == 0 and block != 0 else 1,
shortcut=shortcut,
is_dcn=is_dcn))
block_list.append(bottleneck_block)
shortcut = True
if block in self.out_indices:
self.out_channels.append(num_filters[block] * 4)
self.stages.append(nn.Sequential(*block_list))
else:
for block in range(len(depth)):
shortcut = False
block_list = []
for i in range(depth[block]):
conv_name = "res" + str(block + 2) + chr(97 + i)
basic_block = self.add_sublayer(
conv_name,
BasicBlock(
num_channels=num_channels[block]
if i == 0 else num_filters[block],
num_filters=num_filters[block],
stride=2 if i == 0 and block != 0 else 1,
shortcut=shortcut))
block_list.append(basic_block)
shortcut = True
if block in self.out_indices:
self.out_channels.append(num_filters[block])
self.stages.append(nn.Sequential(*block_list))
def forward(self, inputs):
y = self.conv(inputs)
y = self.pool2d_max(y)
out = []
for i, block in enumerate(self.stages):
y = block(y)
if i in self.out_indices:
out.append(y)
return out
......@@ -25,7 +25,7 @@ from paddle.vision.ops import DeformConv2D
from paddle.regularizer import L2Decay
from paddle.nn.initializer import Normal, Constant, XavierUniform
__all__ = ["ResNet"]
__all__ = ["ResNet_vd", "ConvBNLayer", "DeformableConvV2"]
class DeformableConvV2(nn.Layer):
......@@ -104,6 +104,7 @@ class ConvBNLayer(nn.Layer):
kernel_size,
stride=1,
groups=1,
dcn_groups=1,
is_vd_mode=False,
act=None,
is_dcn=False):
......@@ -128,7 +129,7 @@ class ConvBNLayer(nn.Layer):
kernel_size=kernel_size,
stride=stride,
padding=(kernel_size - 1) // 2,
            groups=dcn_groups,
bias_attr=False)
self._batch_norm = nn.BatchNorm(out_channels, act=act)
......@@ -162,7 +163,8 @@ class BottleneckBlock(nn.Layer):
kernel_size=3,
stride=stride,
act='relu',
            is_dcn=is_dcn,
            dcn_groups=2)
self.conv2 = ConvBNLayer(
in_channels=out_channels,
out_channels=out_channels * 4,
......@@ -238,14 +240,14 @@ class BasicBlock(nn.Layer):
return y
class ResNet_vd(nn.Layer):
def __init__(self,
in_channels=3,
layers=50,
dcn_stage=None,
out_indices=None,
**kwargs):
        super(ResNet_vd, self).__init__()
self.layers = layers
supported_layers = [18, 34, 50, 101, 152, 200]
......@@ -321,7 +323,6 @@ class ResNet(nn.Layer):
for block in range(len(depth)):
block_list = []
shortcut = False
for i in range(depth[block]):
basic_block = self.add_sublayer(
'bb_%d_%d' % (block, i),
......
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This code is refer from:
https://github.com/JiaquanYe/TableMASTER-mmocr/blob/master/mmocr/models/textrecog/backbones/table_resnet_extra.py
"""
import paddle
import paddle.nn as nn
import paddle.nn.functional as F
class BasicBlock(nn.Layer):
expansion = 1
def __init__(self,
inplanes,
planes,
stride=1,
downsample=None,
gcb_config=None):
super(BasicBlock, self).__init__()
self.conv1 = nn.Conv2D(
inplanes,
planes,
kernel_size=3,
stride=stride,
padding=1,
bias_attr=False)
self.bn1 = nn.BatchNorm2D(planes, momentum=0.9)
self.relu = nn.ReLU()
self.conv2 = nn.Conv2D(
planes, planes, kernel_size=3, stride=1, padding=1, bias_attr=False)
self.bn2 = nn.BatchNorm2D(planes, momentum=0.9)
self.downsample = downsample
self.stride = stride
self.gcb_config = gcb_config
if self.gcb_config is not None:
gcb_ratio = gcb_config['ratio']
gcb_headers = gcb_config['headers']
att_scale = gcb_config['att_scale']
fusion_type = gcb_config['fusion_type']
self.context_block = MultiAspectGCAttention(
inplanes=planes,
ratio=gcb_ratio,
headers=gcb_headers,
att_scale=att_scale,
fusion_type=fusion_type)
def forward(self, x):
residual = x
out = self.conv1(x)
out = self.bn1(out)
out = self.relu(out)
out = self.conv2(out)
out = self.bn2(out)
if self.gcb_config is not None:
out = self.context_block(out)
if self.downsample is not None:
residual = self.downsample(x)
out += residual
out = self.relu(out)
return out
def get_gcb_config(gcb_config, layer):
if gcb_config is None or not gcb_config['layers'][layer]:
return None
else:
return gcb_config
class TableResNetExtra(nn.Layer):
def __init__(self, layers, in_channels=3, gcb_config=None):
assert len(layers) >= 4
super(TableResNetExtra, self).__init__()
self.inplanes = 128
self.conv1 = nn.Conv2D(
in_channels,
64,
kernel_size=3,
stride=1,
padding=1,
bias_attr=False)
self.bn1 = nn.BatchNorm2D(64)
self.relu1 = nn.ReLU()
self.conv2 = nn.Conv2D(
64, 128, kernel_size=3, stride=1, padding=1, bias_attr=False)
self.bn2 = nn.BatchNorm2D(128)
self.relu2 = nn.ReLU()
self.maxpool1 = nn.MaxPool2D(kernel_size=2, stride=2)
self.layer1 = self._make_layer(
BasicBlock,
256,
layers[0],
stride=1,
gcb_config=get_gcb_config(gcb_config, 0))
self.conv3 = nn.Conv2D(
256, 256, kernel_size=3, stride=1, padding=1, bias_attr=False)
self.bn3 = nn.BatchNorm2D(256)
self.relu3 = nn.ReLU()
self.maxpool2 = nn.MaxPool2D(kernel_size=2, stride=2)
self.layer2 = self._make_layer(
BasicBlock,
256,
layers[1],
stride=1,
gcb_config=get_gcb_config(gcb_config, 1))
self.conv4 = nn.Conv2D(
256, 256, kernel_size=3, stride=1, padding=1, bias_attr=False)
self.bn4 = nn.BatchNorm2D(256)
self.relu4 = nn.ReLU()
self.maxpool3 = nn.MaxPool2D(kernel_size=2, stride=2)
self.layer3 = self._make_layer(
BasicBlock,
512,
layers[2],
stride=1,
gcb_config=get_gcb_config(gcb_config, 2))
self.conv5 = nn.Conv2D(
512, 512, kernel_size=3, stride=1, padding=1, bias_attr=False)
self.bn5 = nn.BatchNorm2D(512)
self.relu5 = nn.ReLU()
self.layer4 = self._make_layer(
BasicBlock,
512,
layers[3],
stride=1,
gcb_config=get_gcb_config(gcb_config, 3))
self.conv6 = nn.Conv2D(
512, 512, kernel_size=3, stride=1, padding=1, bias_attr=False)
self.bn6 = nn.BatchNorm2D(512)
self.relu6 = nn.ReLU()
self.out_channels = [256, 256, 512]
def _make_layer(self, block, planes, blocks, stride=1, gcb_config=None):
downsample = None
if stride != 1 or self.inplanes != planes * block.expansion:
downsample = nn.Sequential(
nn.Conv2D(
self.inplanes,
planes * block.expansion,
kernel_size=1,
stride=stride,
bias_attr=False),
nn.BatchNorm2D(planes * block.expansion), )
layers = []
layers.append(
block(
self.inplanes,
planes,
stride,
downsample,
gcb_config=gcb_config))
self.inplanes = planes * block.expansion
for _ in range(1, blocks):
layers.append(block(self.inplanes, planes))
return nn.Sequential(*layers)
def forward(self, x):
f = []
x = self.conv1(x)
x = self.bn1(x)
x = self.relu1(x)
x = self.conv2(x)
x = self.bn2(x)
x = self.relu2(x)
x = self.maxpool1(x)
x = self.layer1(x)
x = self.conv3(x)
x = self.bn3(x)
x = self.relu3(x)
f.append(x)
x = self.maxpool2(x)
x = self.layer2(x)
x = self.conv4(x)
x = self.bn4(x)
x = self.relu4(x)
f.append(x)
x = self.maxpool3(x)
x = self.layer3(x)
x = self.conv5(x)
x = self.bn5(x)
x = self.relu5(x)
x = self.layer4(x)
x = self.conv6(x)
x = self.bn6(x)
x = self.relu6(x)
f.append(x)
return f
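# Shape sketch (input size and layer config are assumptions): the three
# returned maps come out after maxpool1/maxpool2/maxpool3, i.e. at strides
# 2, 4 and 8 relative to the input.
import paddle

backbone = TableResNetExtra(layers=[1, 2, 5, 3], in_channels=3)
feats = backbone(paddle.randn([1, 3, 480, 480]))
print([f.shape for f in feats])
# -> [[1, 256, 240, 240], [1, 256, 120, 120], [1, 512, 60, 60]]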
class MultiAspectGCAttention(nn.Layer):
def __init__(self,
inplanes,
ratio,
headers,
pooling_type='att',
att_scale=False,
fusion_type='channel_add'):
super(MultiAspectGCAttention, self).__init__()
assert pooling_type in ['avg', 'att']
assert fusion_type in ['channel_add', 'channel_mul', 'channel_concat']
        assert inplanes % headers == 0 and inplanes >= 8  # inplanes must be evenly divisible by headers
self.headers = headers
self.inplanes = inplanes
self.ratio = ratio
self.planes = int(inplanes * ratio)
self.pooling_type = pooling_type
self.fusion_type = fusion_type
self.att_scale = False
self.single_header_inplanes = int(inplanes / headers)
if pooling_type == 'att':
self.conv_mask = nn.Conv2D(
self.single_header_inplanes, 1, kernel_size=1)
self.softmax = nn.Softmax(axis=2)
else:
self.avg_pool = nn.AdaptiveAvgPool2D(1)
if fusion_type == 'channel_add':
self.channel_add_conv = nn.Sequential(
nn.Conv2D(
self.inplanes, self.planes, kernel_size=1),
nn.LayerNorm([self.planes, 1, 1]),
nn.ReLU(),
nn.Conv2D(
self.planes, self.inplanes, kernel_size=1))
elif fusion_type == 'channel_concat':
self.channel_concat_conv = nn.Sequential(
nn.Conv2D(
self.inplanes, self.planes, kernel_size=1),
nn.LayerNorm([self.planes, 1, 1]),
nn.ReLU(),
nn.Conv2D(
self.planes, self.inplanes, kernel_size=1))
# for concat
self.cat_conv = nn.Conv2D(
2 * self.inplanes, self.inplanes, kernel_size=1)
elif fusion_type == 'channel_mul':
self.channel_mul_conv = nn.Sequential(
nn.Conv2D(
self.inplanes, self.planes, kernel_size=1),
nn.LayerNorm([self.planes, 1, 1]),
nn.ReLU(),
nn.Conv2D(
self.planes, self.inplanes, kernel_size=1))
def spatial_pool(self, x):
batch, channel, height, width = x.shape
if self.pooling_type == 'att':
# [N*headers, C', H , W] C = headers * C'
x = x.reshape([
batch * self.headers, self.single_header_inplanes, height, width
])
input_x = x
# [N*headers, C', H * W] C = headers * C'
# input_x = input_x.view(batch, channel, height * width)
input_x = input_x.reshape([
batch * self.headers, self.single_header_inplanes,
height * width
])
# [N*headers, 1, C', H * W]
input_x = input_x.unsqueeze(1)
# [N*headers, 1, H, W]
context_mask = self.conv_mask(x)
# [N*headers, 1, H * W]
context_mask = context_mask.reshape(
[batch * self.headers, 1, height * width])
# scale variance
if self.att_scale and self.headers > 1:
context_mask = context_mask / paddle.sqrt(
self.single_header_inplanes)
# [N*headers, 1, H * W]
context_mask = self.softmax(context_mask)
# [N*headers, 1, H * W, 1]
context_mask = context_mask.unsqueeze(-1)
# [N*headers, 1, C', 1] = [N*headers, 1, C', H * W] * [N*headers, 1, H * W, 1]
context = paddle.matmul(input_x, context_mask)
# [N, headers * C', 1, 1]
context = context.reshape(
[batch, self.headers * self.single_header_inplanes, 1, 1])
else:
# [N, C, 1, 1]
context = self.avg_pool(x)
return context
def forward(self, x):
# [N, C, 1, 1]
context = self.spatial_pool(x)
out = x
if self.fusion_type == 'channel_mul':
# [N, C, 1, 1]
channel_mul_term = F.sigmoid(self.channel_mul_conv(context))
out = out * channel_mul_term
elif self.fusion_type == 'channel_add':
# [N, C, 1, 1]
channel_add_term = self.channel_add_conv(context)
out = out + channel_add_term
else:
# [N, C, 1, 1]
channel_concat_term = self.channel_concat_conv(context)
# use concat
_, C1, _, _ = channel_concat_term.shape
N, C2, H, W = out.shape
out = paddle.concat(
[out, channel_concat_term.expand([-1, -1, H, W])], axis=1)
out = self.cat_conv(out)
out = F.layer_norm(out, [self.inplanes, H, W])
out = F.relu(out)
return out
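# Sketch: the global-context block is shape-preserving; with ratio=0.25 the
# bottleneck inside channel_add_conv is inplanes // 4 channels wide.
import paddle

gc = MultiAspectGCAttention(inplanes=64, ratio=0.25, headers=4,
                            att_scale=False, fusion_type='channel_add')
y = gc(paddle.randn([2, 64, 30, 30]))
print(y.shape)  # [2, 64, 30, 30], same as the input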
......@@ -43,9 +43,11 @@ class NLPBaseModel(nn.Layer):
super(NLPBaseModel, self).__init__()
if checkpoints is not None:
self.model = model_class.from_pretrained(checkpoints)
elif isinstance(pretrained, (str, )) and os.path.exists(pretrained):
self.model = model_class.from_pretrained(pretrained)
else:
pretrained_model_name = pretrained_model_dict[base_model_class]
            if pretrained is True:
base_model = base_model_class.from_pretrained(
pretrained_model_name)
else:
......@@ -74,9 +76,9 @@ class LayoutLMForSer(NLPBaseModel):
def forward(self, x):
x = self.model(
input_ids=x[0],
            bbox=x[1],
            attention_mask=x[2],
            token_type_ids=x[3],
position_ids=None,
output_hidden_states=False)
return x
......@@ -96,13 +98,15 @@ class LayoutLMv2ForSer(NLPBaseModel):
def forward(self, x):
x = self.model(
input_ids=x[0],
            bbox=x[1],
            attention_mask=x[2],
            token_type_ids=x[3],
            image=x[4],
position_ids=None,
head_mask=None,
labels=None)
if not self.training:
return x
return x[0]
......@@ -120,13 +124,15 @@ class LayoutXLMForSer(NLPBaseModel):
def forward(self, x):
x = self.model(
input_ids=x[0],
            bbox=x[1],
            attention_mask=x[2],
            token_type_ids=x[3],
            image=x[4],
position_ids=None,
head_mask=None,
labels=None)
if not self.training:
return x
return x[0]
......@@ -140,12 +146,12 @@ class LayoutLMv2ForRe(NLPBaseModel):
x = self.model(
input_ids=x[0],
bbox=x[1],
            attention_mask=x[2],
            token_type_ids=x[3],
            image=x[4],
position_ids=None,
head_mask=None,
labels=None,
entities=x[5],
relations=x[6])
return x
......@@ -161,12 +167,12 @@ class LayoutXLMForRe(NLPBaseModel):
x = self.model(
input_ids=x[0],
bbox=x[1],
            attention_mask=x[2],
            token_type_ids=x[3],
            image=x[4],
position_ids=None,
head_mask=None,
labels=None,
entities=x[5],
relations=x[6])
return x
......@@ -42,12 +42,13 @@ def build_head(config):
from .kie_sdmgr_head import SDMGRHead
from .table_att_head import TableAttentionHead
from .table_master_head import TableMasterHead
support_dict = [
'DBHead', 'PSEHead', 'FCEHead', 'EASTHead', 'SASTHead', 'CTCHead',
'ClsHead', 'AttentionHead', 'SRNHead', 'PGHead', 'Transformer',
'TableAttentionHead', 'SARHead', 'AsterHead', 'SDMGRHead', 'PRENHead',
        'MultiHead', 'ABINetHead', 'TableMasterHead'
]
......
......@@ -273,7 +273,8 @@ def _get_length(logit):
out = out.cast('int32')
out = out.argmax(-1)
out = out + 1
    len_seq = paddle.zeros_like(out) + logit.shape[1]
    out = paddle.where(abn, out, len_seq)
return out
......
......@@ -21,6 +21,8 @@ import paddle.nn as nn
import paddle.nn.functional as F
import numpy as np
from .rec_att_head import AttentionGRUCell
class TableAttentionHead(nn.Layer):
def __init__(self,
......@@ -28,21 +30,19 @@ class TableAttentionHead(nn.Layer):
hidden_size,
loc_type,
in_max_len=488,
                 max_text_length=800,
                 out_channels=30,
                 point_num=2,
**kwargs):
super(TableAttentionHead, self).__init__()
self.input_size = in_channels[-1]
self.hidden_size = hidden_size
        self.out_channels = out_channels
        self.max_text_length = max_text_length
self.structure_attention_cell = AttentionGRUCell(
            self.input_size, hidden_size, self.out_channels, use_gru=False)
        self.structure_generator = nn.Linear(hidden_size, self.out_channels)
self.loc_type = loc_type
self.in_max_len = in_max_len
......@@ -50,12 +50,13 @@ class TableAttentionHead(nn.Layer):
self.loc_generator = nn.Linear(hidden_size, 4)
else:
if self.in_max_len == 640:
                self.loc_fea_trans = nn.Linear(400, self.max_text_length + 1)
            elif self.in_max_len == 800:
                self.loc_fea_trans = nn.Linear(625, self.max_text_length + 1)
            else:
                self.loc_fea_trans = nn.Linear(256, self.max_text_length + 1)
            self.loc_generator = nn.Linear(self.input_size + hidden_size,
                                           point_num * 2)
def _char_to_onehot(self, input_char, onehot_dim):
input_ont_hot = F.one_hot(input_char, onehot_dim)
......@@ -77,9 +78,9 @@ class TableAttentionHead(nn.Layer):
output_hiddens = []
if self.training and targets is not None:
structure = targets[0]
            for i in range(self.max_text_length + 1):
                elem_onehots = self._char_to_onehot(
                    structure[:, i], onehot_dim=self.out_channels)
(outputs, hidden), alpha = self.structure_attention_cell(
hidden, fea, elem_onehots)
output_hiddens.append(paddle.unsqueeze(outputs, axis=1))
......@@ -102,11 +103,11 @@ class TableAttentionHead(nn.Layer):
elem_onehots = None
outputs = None
alpha = None
            max_text_length = paddle.to_tensor(self.max_text_length)
            i = 0
            while i < max_text_length + 1:
                elem_onehots = self._char_to_onehot(
                    temp_elem, onehot_dim=self.out_channels)
(outputs, hidden), alpha = self.structure_attention_cell(
hidden, fea, elem_onehots)
output_hiddens.append(paddle.unsqueeze(outputs, axis=1))
......@@ -128,119 +129,3 @@ class TableAttentionHead(nn.Layer):
loc_preds = self.loc_generator(loc_concat)
loc_preds = F.sigmoid(loc_preds)
return {'structure_probs': structure_probs, 'loc_preds': loc_preds}
class AttentionGRUCell(nn.Layer):
def __init__(self, input_size, hidden_size, num_embeddings, use_gru=False):
super(AttentionGRUCell, self).__init__()
self.i2h = nn.Linear(input_size, hidden_size, bias_attr=False)
self.h2h = nn.Linear(hidden_size, hidden_size)
self.score = nn.Linear(hidden_size, 1, bias_attr=False)
self.rnn = nn.GRUCell(
input_size=input_size + num_embeddings, hidden_size=hidden_size)
self.hidden_size = hidden_size
def forward(self, prev_hidden, batch_H, char_onehots):
batch_H_proj = self.i2h(batch_H)
prev_hidden_proj = paddle.unsqueeze(self.h2h(prev_hidden), axis=1)
res = paddle.add(batch_H_proj, prev_hidden_proj)
res = paddle.tanh(res)
e = self.score(res)
alpha = F.softmax(e, axis=1)
alpha = paddle.transpose(alpha, [0, 2, 1])
context = paddle.squeeze(paddle.mm(alpha, batch_H), axis=1)
concat_context = paddle.concat([context, char_onehots], 1)
cur_hidden = self.rnn(concat_context, prev_hidden)
return cur_hidden, alpha
class AttentionLSTM(nn.Layer):
def __init__(self, in_channels, out_channels, hidden_size, **kwargs):
super(AttentionLSTM, self).__init__()
self.input_size = in_channels
self.hidden_size = hidden_size
self.num_classes = out_channels
self.attention_cell = AttentionLSTMCell(
in_channels, hidden_size, out_channels, use_gru=False)
self.generator = nn.Linear(hidden_size, out_channels)
def _char_to_onehot(self, input_char, onehot_dim):
input_ont_hot = F.one_hot(input_char, onehot_dim)
return input_ont_hot
def forward(self, inputs, targets=None, batch_max_length=25):
batch_size = inputs.shape[0]
num_steps = batch_max_length
hidden = (paddle.zeros((batch_size, self.hidden_size)), paddle.zeros(
(batch_size, self.hidden_size)))
output_hiddens = []
if targets is not None:
for i in range(num_steps):
# one-hot vectors for a i-th char
char_onehots = self._char_to_onehot(
targets[:, i], onehot_dim=self.num_classes)
hidden, alpha = self.attention_cell(hidden, inputs,
char_onehots)
hidden = (hidden[1][0], hidden[1][1])
output_hiddens.append(paddle.unsqueeze(hidden[0], axis=1))
output = paddle.concat(output_hiddens, axis=1)
probs = self.generator(output)
else:
targets = paddle.zeros(shape=[batch_size], dtype="int32")
probs = None
for i in range(num_steps):
char_onehots = self._char_to_onehot(
targets, onehot_dim=self.num_classes)
hidden, alpha = self.attention_cell(hidden, inputs,
char_onehots)
probs_step = self.generator(hidden[0])
hidden = (hidden[1][0], hidden[1][1])
if probs is None:
probs = paddle.unsqueeze(probs_step, axis=1)
else:
probs = paddle.concat(
[probs, paddle.unsqueeze(
probs_step, axis=1)], axis=1)
next_input = probs_step.argmax(axis=1)
targets = next_input
return probs
class AttentionLSTMCell(nn.Layer):
def __init__(self, input_size, hidden_size, num_embeddings, use_gru=False):
super(AttentionLSTMCell, self).__init__()
self.i2h = nn.Linear(input_size, hidden_size, bias_attr=False)
self.h2h = nn.Linear(hidden_size, hidden_size)
self.score = nn.Linear(hidden_size, 1, bias_attr=False)
if not use_gru:
self.rnn = nn.LSTMCell(
input_size=input_size + num_embeddings, hidden_size=hidden_size)
else:
self.rnn = nn.GRUCell(
input_size=input_size + num_embeddings, hidden_size=hidden_size)
self.hidden_size = hidden_size
def forward(self, prev_hidden, batch_H, char_onehots):
batch_H_proj = self.i2h(batch_H)
prev_hidden_proj = paddle.unsqueeze(self.h2h(prev_hidden[0]), axis=1)
res = paddle.add(batch_H_proj, prev_hidden_proj)
res = paddle.tanh(res)
e = self.score(res)
alpha = F.softmax(e, axis=1)
alpha = paddle.transpose(alpha, [0, 2, 1])
context = paddle.squeeze(paddle.mm(alpha, batch_H), axis=1)
concat_context = paddle.concat([context, char_onehots], 1)
cur_hidden = self.rnn(concat_context, prev_hidden)
return cur_hidden, alpha
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This code is refer from:
https://github.com/JiaquanYe/TableMASTER-mmocr/blob/master/mmocr/models/textrecog/decoders/master_decoder.py
"""
import copy
import math
import paddle
from paddle import nn
from paddle.nn import functional as F
class TableMasterHead(nn.Layer):
"""
Split to two transformer header at the last layer.
Cls_layer is used to structure token classification.
Bbox_layer is used to regress bbox coord.
"""
def __init__(self,
in_channels,
out_channels=30,
headers=8,
d_ff=2048,
dropout=0,
max_text_length=500,
point_num=2,
**kwargs):
super(TableMasterHead, self).__init__()
hidden_size = in_channels[-1]
self.layers = clones(
DecoderLayer(headers, hidden_size, dropout, d_ff), 2)
self.cls_layer = clones(
DecoderLayer(headers, hidden_size, dropout, d_ff), 1)
self.bbox_layer = clones(
DecoderLayer(headers, hidden_size, dropout, d_ff), 1)
self.cls_fc = nn.Linear(hidden_size, out_channels)
self.bbox_fc = nn.Sequential(
# nn.Linear(hidden_size, hidden_size),
nn.Linear(hidden_size, point_num * 2),
nn.Sigmoid())
self.norm = nn.LayerNorm(hidden_size)
self.embedding = Embeddings(d_model=hidden_size, vocab=out_channels)
self.positional_encoding = PositionalEncoding(d_model=hidden_size)
self.SOS = out_channels - 3
self.PAD = out_channels - 1
self.out_channels = out_channels
self.point_num = point_num
self.max_text_length = max_text_length
def make_mask(self, tgt):
"""
Make mask for self attention.
:param src: [b, c, h, l_src]
:param tgt: [b, l_tgt]
:return:
"""
trg_pad_mask = (tgt != self.PAD).unsqueeze(1).unsqueeze(3)
tgt_len = paddle.shape(tgt)[1]
trg_sub_mask = paddle.tril(
paddle.ones(
([tgt_len, tgt_len]), dtype=paddle.float32))
tgt_mask = paddle.logical_and(
trg_pad_mask.astype(paddle.float32), trg_sub_mask)
return tgt_mask.astype(paddle.float32)
def decode(self, input, feature, src_mask, tgt_mask):
# main process of transformer decoder.
x = self.embedding(input) # x: 1*x*512, feature: 1*3600,512
x = self.positional_encoding(x)
# origin transformer layers
for i, layer in enumerate(self.layers):
x = layer(x, feature, src_mask, tgt_mask)
# cls head
for layer in self.cls_layer:
cls_x = layer(x, feature, src_mask, tgt_mask)
cls_x = self.norm(cls_x)
# bbox head
for layer in self.bbox_layer:
bbox_x = layer(x, feature, src_mask, tgt_mask)
bbox_x = self.norm(bbox_x)
return self.cls_fc(cls_x), self.bbox_fc(bbox_x)
def greedy_forward(self, SOS, feature):
input = SOS
output = paddle.zeros(
[input.shape[0], self.max_text_length + 1, self.out_channels])
bbox_output = paddle.zeros(
[input.shape[0], self.max_text_length + 1, self.point_num * 2])
max_text_length = paddle.to_tensor(self.max_text_length)
for i in range(max_text_length + 1):
target_mask = self.make_mask(input)
out_step, bbox_output_step = self.decode(input, feature, None,
target_mask)
prob = F.softmax(out_step, axis=-1)
next_word = prob.argmax(axis=2, dtype="int64")
input = paddle.concat(
[input, next_word[:, -1].unsqueeze(-1)], axis=1)
if i == self.max_text_length:
output = out_step
bbox_output = bbox_output_step
return output, bbox_output
def forward_train(self, out_enc, targets):
# x is token of label
# feat is feature after backbone before pe.
# out_enc is feature after pe.
padded_targets = targets[0]
src_mask = None
tgt_mask = self.make_mask(padded_targets[:, :-1])
output, bbox_output = self.decode(padded_targets[:, :-1], out_enc,
src_mask, tgt_mask)
return {'structure_probs': output, 'loc_preds': bbox_output}
def forward_test(self, out_enc):
batch_size = out_enc.shape[0]
SOS = paddle.zeros([batch_size, 1], dtype='int64') + self.SOS
output, bbox_output = self.greedy_forward(SOS, out_enc)
output = F.softmax(output)
return {'structure_probs': output, 'loc_preds': bbox_output}
def forward(self, feat, targets=None):
feat = feat[-1]
b, c, h, w = feat.shape
feat = feat.reshape([b, c, h * w]) # flatten 2D feature map
feat = feat.transpose((0, 2, 1))
out_enc = self.positional_encoding(feat)
if self.training:
return self.forward_train(out_enc, targets)
return self.forward_test(out_enc)
class DecoderLayer(nn.Layer):
"""
    Decoder is made of self attention, source attention and feed forward.
"""
def __init__(self, headers, d_model, dropout, d_ff):
super(DecoderLayer, self).__init__()
self.self_attn = MultiHeadAttention(headers, d_model, dropout)
self.src_attn = MultiHeadAttention(headers, d_model, dropout)
self.feed_forward = FeedForward(d_model, d_ff, dropout)
self.sublayer = clones(SubLayerConnection(d_model, dropout), 3)
def forward(self, x, feature, src_mask, tgt_mask):
x = self.sublayer[0](x, lambda x: self.self_attn(x, x, x, tgt_mask))
x = self.sublayer[1](
x, lambda x: self.src_attn(x, feature, feature, src_mask))
return self.sublayer[2](x, self.feed_forward)
class MultiHeadAttention(nn.Layer):
def __init__(self, headers, d_model, dropout):
super(MultiHeadAttention, self).__init__()
assert d_model % headers == 0
self.d_k = int(d_model / headers)
self.headers = headers
self.linears = clones(nn.Linear(d_model, d_model), 4)
self.attn = None
self.dropout = nn.Dropout(dropout)
def forward(self, query, key, value, mask=None):
B = query.shape[0]
# 1) Do all the linear projections in batch from d_model => h x d_k
query, key, value = \
[l(x).reshape([B, 0, self.headers, self.d_k]).transpose([0, 2, 1, 3])
for l, x in zip(self.linears, (query, key, value))]
# 2) Apply attention on all the projected vectors in batch
x, self.attn = self_attention(
query, key, value, mask=mask, dropout=self.dropout)
x = x.transpose([0, 2, 1, 3]).reshape([B, 0, self.headers * self.d_k])
return self.linears[-1](x)
class FeedForward(nn.Layer):
def __init__(self, d_model, d_ff, dropout):
super(FeedForward, self).__init__()
self.w_1 = nn.Linear(d_model, d_ff)
self.w_2 = nn.Linear(d_ff, d_model)
self.dropout = nn.Dropout(dropout)
def forward(self, x):
return self.w_2(self.dropout(F.relu(self.w_1(x))))
class SubLayerConnection(nn.Layer):
"""
A residual connection followed by a layer norm.
Note for code simplicity the norm is first as opposed to last.
"""
def __init__(self, size, dropout):
super(SubLayerConnection, self).__init__()
self.norm = nn.LayerNorm(size)
self.dropout = nn.Dropout(dropout)
def forward(self, x, sublayer):
return x + self.dropout(sublayer(self.norm(x)))
def masked_fill(x, mask, value):
mask = mask.astype(x.dtype)
return x * paddle.logical_not(mask).astype(x.dtype) + mask * value
def self_attention(query, key, value, mask=None, dropout=None):
"""
Compute 'Scale Dot Product Attention'
"""
d_k = value.shape[-1]
score = paddle.matmul(query, key.transpose([0, 1, 3, 2]) / math.sqrt(d_k))
if mask is not None:
# score = score.masked_fill(mask == 0, -1e9) # b, h, L, L
score = masked_fill(score, mask == 0, -6.55e4) # for fp16
p_attn = F.softmax(score, axis=-1)
if dropout is not None:
p_attn = dropout(p_attn)
return paddle.matmul(p_attn, value), p_attn
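# Sketch: scaled dot-product attention on random projections; each attention
# row sums to 1 over the key axis.
import paddle

q = paddle.randn([2, 8, 10, 64])  # [batch, heads, query_len, d_k]
k = paddle.randn([2, 8, 10, 64])
v = paddle.randn([2, 8, 10, 64])
ctx, p_attn = self_attention(q, k, v)
print(ctx.shape, p_attn.shape)        # [2, 8, 10, 64] [2, 8, 10, 10]
print(float(p_attn.sum(-1).mean()))   # ~1.0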
def clones(module, N):
""" Produce N identical layers """
return nn.LayerList([copy.deepcopy(module) for _ in range(N)])
class Embeddings(nn.Layer):
def __init__(self, d_model, vocab):
super(Embeddings, self).__init__()
self.lut = nn.Embedding(vocab, d_model)
self.d_model = d_model
def forward(self, *input):
x = input[0]
return self.lut(x) * math.sqrt(self.d_model)
class PositionalEncoding(nn.Layer):
""" Implement the PE function. """
def __init__(self, d_model, dropout=0., max_len=5000):
super(PositionalEncoding, self).__init__()
self.dropout = nn.Dropout(p=dropout)
# Compute the positional encodings once in log space.
pe = paddle.zeros([max_len, d_model])
position = paddle.arange(0, max_len).unsqueeze(1).astype('float32')
div_term = paddle.exp(
paddle.arange(0, d_model, 2) * -math.log(10000.0) / d_model)
pe[:, 0::2] = paddle.sin(position * div_term)
pe[:, 1::2] = paddle.cos(position * div_term)
pe = pe.unsqueeze(0)
self.register_buffer('pe', pe)
def forward(self, feat, **kwargs):
feat = feat + self.pe[:, :paddle.shape(feat)[1]] # pe 1*5000*512
return self.dropout(feat)
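# Inference sketch (configuration values are assumptions, not from the
# commit): the last backbone feature map is flattened to [b, h*w, c] and
# position-encoded, then decoded greedily from the SOS token, per forward()
# and greedy_forward() above.
import paddle

head = TableMasterHead(
    in_channels=[512], out_channels=30, max_text_length=10, point_num=2)
head.eval()
feat = paddle.randn([1, 512, 12, 12])  # stands in for the backbone output
with paddle.no_grad():
    out = head([feat])
print(out['structure_probs'].shape)  # [1, 11, 30]
print(out['loc_preds'].shape)        # [1, 11, 4]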
......@@ -105,9 +105,10 @@ class DSConv(nn.Layer):
class DBFPN(nn.Layer):
    def __init__(self, in_channels, out_channels, use_asf=False, **kwargs):
super(DBFPN, self).__init__()
self.out_channels = out_channels
self.use_asf = use_asf
weight_attr = paddle.nn.initializer.KaimingUniform()
self.in2_conv = nn.Conv2D(
......@@ -163,6 +164,9 @@ class DBFPN(nn.Layer):
weight_attr=ParamAttr(initializer=weight_attr),
bias_attr=False)
if self.use_asf is True:
self.asf = ASFBlock(self.out_channels, self.out_channels // 4)
def forward(self, x):
c2, c3, c4, c5 = x
......@@ -187,6 +191,10 @@ class DBFPN(nn.Layer):
p3 = F.upsample(p3, scale_factor=2, mode="nearest", align_mode=1)
fuse = paddle.concat([p5, p4, p3, p2], axis=1)
if self.use_asf is True:
fuse = self.asf(fuse, [p5, p4, p3, p2])
return fuse
......@@ -356,3 +364,64 @@ class LKPAN(nn.Layer):
fuse = paddle.concat([p5, p4, p3, p2], axis=1)
return fuse
class ASFBlock(nn.Layer):
"""
    This code is referred from:
https://github.com/MhLiao/DB/blob/master/decoders/feature_attention.py
"""
def __init__(self, in_channels, inter_channels, out_features_num=4):
"""
Adaptive Scale Fusion (ASF) block of DBNet++
Args:
in_channels: the number of channels in the input data
inter_channels: the number of middle channels
out_features_num: the number of fused stages
"""
super(ASFBlock, self).__init__()
weight_attr = paddle.nn.initializer.KaimingUniform()
self.in_channels = in_channels
self.inter_channels = inter_channels
self.out_features_num = out_features_num
self.conv = nn.Conv2D(in_channels, inter_channels, 3, padding=1)
self.spatial_scale = nn.Sequential(
#Nx1xHxW
nn.Conv2D(
in_channels=1,
out_channels=1,
kernel_size=3,
bias_attr=False,
padding=1,
weight_attr=ParamAttr(initializer=weight_attr)),
nn.ReLU(),
nn.Conv2D(
in_channels=1,
out_channels=1,
kernel_size=1,
bias_attr=False,
weight_attr=ParamAttr(initializer=weight_attr)),
nn.Sigmoid())
self.channel_scale = nn.Sequential(
nn.Conv2D(
in_channels=inter_channels,
out_channels=out_features_num,
kernel_size=1,
bias_attr=False,
weight_attr=ParamAttr(initializer=weight_attr)),
nn.Sigmoid())
def forward(self, fuse_features, features_list):
fuse_features = self.conv(fuse_features)
spatial_x = paddle.mean(fuse_features, axis=1, keepdim=True)
attention_scores = self.spatial_scale(spatial_x) + fuse_features
attention_scores = self.channel_scale(attention_scores)
assert len(features_list) == self.out_features_num
out_list = []
for i in range(self.out_features_num):
out_list.append(attention_scores[:, i:i + 1] * features_list[i])
return paddle.concat(out_list, axis=1)
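# Sketch (channel sizes assumed from the DBFPN wiring above, where each of
# the four stages carries out_channels // 4 channels):
import paddle

asf = ASFBlock(in_channels=256, inter_channels=64)
stages = [paddle.randn([1, 64, 160, 160]) for _ in range(4)]
fuse = paddle.concat(stages, axis=1)  # [1, 256, 160, 160]
print(asf(fuse, stages).shape)        # [1, 256, 160, 160]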
......@@ -308,3 +308,81 @@ class Const(object):
end_lr=self.learning_rate,
last_epoch=self.last_epoch)
return learning_rate
class DecayLearningRate(object):
"""
DecayLearningRate learning rate decay
    new_lr = (lr - end_lr) * (1 - step / decay_steps) ** power + end_lr
Args:
learning_rate(float): initial learning rate
step_each_epoch(int): steps each epoch
epochs(int): total training epochs
        factor(float): power of the polynomial; should be greater than 0.0 to get learning rate decay. Default: 0.9
end_lr(float): The minimum final learning rate. Default: 0.0.
"""
def __init__(self,
learning_rate,
step_each_epoch,
epochs,
factor=0.9,
end_lr=0,
**kwargs):
super(DecayLearningRate, self).__init__()
self.learning_rate = learning_rate
self.epochs = epochs + 1
self.factor = factor
        self.end_lr = end_lr
self.decay_steps = step_each_epoch * epochs
def __call__(self):
learning_rate = lr.PolynomialDecay(
learning_rate=self.learning_rate,
decay_steps=self.decay_steps,
power=self.factor,
end_lr=self.end_lr)
return learning_rate
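# Sketch of the resulting schedule: with lr=0.007, factor=0.9 and 10 epochs of
# 100 steps, step 200 gives 0.007 * (1 - 200 / 1000) ** 0.9 ~= 0.00573.
sched = DecayLearningRate(learning_rate=0.007, step_each_epoch=100,
                          epochs=10)()
for _ in range(200):
    sched.step()
print(round(sched.get_lr(), 5))  # 0.00573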
class MultiStepDecay(object):
"""
Piecewise learning rate decay
Args:
step_each_epoch(int): steps each epoch
learning_rate (float): The initial learning rate. It is a python float number.
        milestones (list[int]): epoch indices at which the learning rate is decayed.
        gamma (float, optional): the ratio by which the learning rate is reduced.
            ``new_lr = origin_lr * gamma``. It should be less than 1.0. Default: 0.1.
        last_epoch (int, optional): the index of the last epoch. Can be set to restart training. Default: -1, which means the initial learning rate.
"""
def __init__(self,
learning_rate,
milestones,
step_each_epoch,
gamma,
warmup_epoch=0,
last_epoch=-1,
**kwargs):
super(MultiStepDecay, self).__init__()
self.milestones = [step_each_epoch * e for e in milestones]
self.learning_rate = learning_rate
self.gamma = gamma
self.last_epoch = last_epoch
self.warmup_epoch = round(warmup_epoch * step_each_epoch)
def __call__(self):
learning_rate = lr.MultiStepDecay(
learning_rate=self.learning_rate,
milestones=self.milestones,
gamma=self.gamma,
last_epoch=self.last_epoch)
if self.warmup_epoch > 0:
learning_rate = lr.LinearWarmup(
learning_rate=learning_rate,
warmup_steps=self.warmup_epoch,
start_lr=0.0,
end_lr=self.learning_rate,
last_epoch=self.last_epoch)
return learning_rate
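# Sketch: milestones are given in epochs and converted to steps above; with
# step_each_epoch=100 and milestones=[2, 4], the lr drops by gamma at steps
# 200 and 400.
sched = MultiStepDecay(learning_rate=0.001, milestones=[2, 4],
                       step_each_epoch=100, gamma=0.1)()
for _ in range(250):
    sched.step()
print(round(sched.get_lr(), 6))  # 0.0001 after the first milestone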
......@@ -26,12 +26,13 @@ from .east_postprocess import EASTPostProcess
from .sast_postprocess import SASTPostProcess
from .fce_postprocess import FCEPostProcess
from .rec_postprocess import CTCLabelDecode, AttnLabelDecode, SRNLabelDecode, \
    DistillationCTCLabelDecode, NRTRLabelDecode, SARLabelDecode, \
SEEDLabelDecode, PRENLabelDecode, ViTSTRLabelDecode, ABINetLabelDecode
from .cls_postprocess import ClsPostProcess
from .pg_postprocess import PGPostProcess
from .vqa_token_ser_layoutlm_postprocess import VQASerTokenLayoutLMPostProcess
from .vqa_token_re_layoutlm_postprocess import VQAReTokenLayoutLMPostProcess
from .table_postprocess import TableMasterLabelDecode, TableLabelDecode
def build_post_process(config, global_config=None):
......@@ -42,7 +43,8 @@ def build_post_process(config, global_config=None):
'DistillationDBPostProcess', 'NRTRLabelDecode', 'SARLabelDecode',
'SEEDLabelDecode', 'VQASerTokenLayoutLMPostProcess',
'VQAReTokenLayoutLMPostProcess', 'PRENLabelDecode',
        'DistillationSARLabelDecode', 'ViTSTRLabelDecode', 'ABINetLabelDecode',
        'TableMasterLabelDecode'
]
if config['name'] == 'PSEPostProcess':
......
......@@ -38,6 +38,7 @@ class DBPostProcess(object):
unclip_ratio=2.0,
use_dilation=False,
score_mode="fast",
use_polygon=False,
**kwargs):
self.thresh = thresh
self.box_thresh = box_thresh
......@@ -45,6 +46,7 @@ class DBPostProcess(object):
self.unclip_ratio = unclip_ratio
self.min_size = 3
self.score_mode = score_mode
self.use_polygon = use_polygon
assert score_mode in [
"slow", "fast"
], "Score mode must be in [slow, fast] but got: {}".format(score_mode)
......@@ -52,6 +54,53 @@ class DBPostProcess(object):
self.dilation_kernel = None if not use_dilation else np.array(
[[1, 1], [1, 1]])
def polygons_from_bitmap(self, pred, _bitmap, dest_width, dest_height):
'''
_bitmap: single map with shape (1, H, W),
whose values are binarized as {0, 1}
'''
bitmap = _bitmap
height, width = bitmap.shape
boxes = []
scores = []
contours, _ = cv2.findContours((bitmap * 255).astype(np.uint8),
cv2.RETR_LIST, cv2.CHAIN_APPROX_SIMPLE)
for contour in contours[:self.max_candidates]:
epsilon = 0.002 * cv2.arcLength(contour, True)
approx = cv2.approxPolyDP(contour, epsilon, True)
points = approx.reshape((-1, 2))
if points.shape[0] < 4:
continue
score = self.box_score_fast(pred, points.reshape(-1, 2))
if self.box_thresh > score:
continue
if points.shape[0] > 2:
box = self.unclip(points, self.unclip_ratio)
if len(box) > 1:
continue
else:
continue
box = box.reshape(-1, 2)
_, sside = self.get_mini_boxes(box.reshape((-1, 1, 2)))
if sside < self.min_size + 2:
continue
box = np.array(box)
box[:, 0] = np.clip(
np.round(box[:, 0] / width * dest_width), 0, dest_width)
box[:, 1] = np.clip(
np.round(box[:, 1] / height * dest_height), 0, dest_height)
boxes.append(box.tolist())
scores.append(score)
return boxes, scores
def boxes_from_bitmap(self, pred, _bitmap, dest_width, dest_height):
'''
_bitmap: single map with shape (1, H, W),
......@@ -85,7 +134,7 @@ class DBPostProcess(object):
if self.box_thresh > score:
continue
box = self.unclip(points).reshape(-1, 1, 2)
box = self.unclip(points, self.unclip_ratio).reshape(-1, 1, 2)
box, sside = self.get_mini_boxes(box)
if sside < self.min_size + 2:
continue
......@@ -99,8 +148,7 @@ class DBPostProcess(object):
scores.append(score)
return np.array(boxes, dtype=np.int16), scores
def unclip(self, box):
unclip_ratio = self.unclip_ratio
def unclip(self, box, unclip_ratio):
poly = Polygon(box)
distance = poly.area * unclip_ratio / poly.length
offset = pyclipper.PyclipperOffset()
......@@ -185,8 +233,12 @@ class DBPostProcess(object):
self.dilation_kernel)
else:
mask = segmentation[batch_index]
boxes, scores = self.boxes_from_bitmap(pred[batch_index], mask,
src_w, src_h)
if self.use_polygon is True:
boxes, scores = self.polygons_from_bitmap(pred[batch_index],
mask, src_w, src_h)
else:
boxes, scores = self.boxes_from_bitmap(pred[batch_index], mask,
src_w, src_h)
boxes_batch.append({'points': boxes})
return boxes_batch
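# A minimal usage sketch (editor's illustration, not part of the original file),
# assuming the __call__(outs_dict, shape_list) signature this class exposes
# elsewhere in PaddleOCR; the probability map below is a random stand-in.
import numpy as np

post = DBPostProcess(thresh=0.3, box_thresh=0.6, use_polygon=True)
pred = {'maps': np.random.rand(1, 1, 960, 960).astype('float32')}
# one row per image: src_h, src_w, ratio_h, ratio_w
shape_list = np.array([[720, 1280, 960 / 720, 960 / 1280]])
boxes_batch = post(pred, shape_list)
# with use_polygon=True each entry holds free-form polygons instead of 4-point boxes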
......@@ -202,6 +254,7 @@ class DistillationDBPostProcess(object):
unclip_ratio=1.5,
use_dilation=False,
score_mode="fast",
use_polygon=False,
**kwargs):
self.model_name = model_name
self.key = key
......@@ -211,7 +264,8 @@ class DistillationDBPostProcess(object):
max_candidates=max_candidates,
unclip_ratio=unclip_ratio,
use_dilation=use_dilation,
score_mode=score_mode)
score_mode=score_mode,
use_polygon=use_polygon)
def __call__(self, predicts, shape_list):
results = {}
......
......@@ -58,6 +58,8 @@ class PSEPostProcess(object):
kernels = (pred > self.thresh).astype('float32')
text_mask = kernels[:, 0, :, :]
text_mask = paddle.unsqueeze(text_mask, axis=1)
kernels[:, 0:, :, :] = kernels[:, 0:, :, :] * text_mask
score = score.numpy()
......
......@@ -380,146 +380,6 @@ class SRNLabelDecode(BaseRecLabelDecode):
return idx
class TableLabelDecode(object):
""" """
def __init__(self, character_dict_path, **kwargs):
list_character, list_elem = self.load_char_elem_dict(
character_dict_path)
list_character = self.add_special_char(list_character)
list_elem = self.add_special_char(list_elem)
self.dict_character = {}
self.dict_idx_character = {}
for i, char in enumerate(list_character):
self.dict_idx_character[i] = char
self.dict_character[char] = i
self.dict_elem = {}
self.dict_idx_elem = {}
for i, elem in enumerate(list_elem):
self.dict_idx_elem[i] = elem
self.dict_elem[elem] = i
def load_char_elem_dict(self, character_dict_path):
list_character = []
list_elem = []
with open(character_dict_path, "rb") as fin:
lines = fin.readlines()
substr = lines[0].decode('utf-8').strip("\n").strip("\r\n").split(
"\t")
character_num = int(substr[0])
elem_num = int(substr[1])
for cno in range(1, 1 + character_num):
character = lines[cno].decode('utf-8').strip("\n").strip("\r\n")
list_character.append(character)
for eno in range(1 + character_num, 1 + character_num + elem_num):
elem = lines[eno].decode('utf-8').strip("\n").strip("\r\n")
list_elem.append(elem)
return list_character, list_elem
def add_special_char(self, list_character):
self.beg_str = "sos"
self.end_str = "eos"
list_character = [self.beg_str] + list_character + [self.end_str]
return list_character
def __call__(self, preds):
structure_probs = preds['structure_probs']
loc_preds = preds['loc_preds']
if isinstance(structure_probs, paddle.Tensor):
structure_probs = structure_probs.numpy()
if isinstance(loc_preds, paddle.Tensor):
loc_preds = loc_preds.numpy()
structure_idx = structure_probs.argmax(axis=2)
structure_probs = structure_probs.max(axis=2)
structure_str, structure_pos, result_score_list, result_elem_idx_list = self.decode(
structure_idx, structure_probs, 'elem')
res_html_code_list = []
res_loc_list = []
batch_num = len(structure_str)
for bno in range(batch_num):
res_loc = []
for sno in range(len(structure_str[bno])):
text = structure_str[bno][sno]
if text in ['<td>', '<td']:
pos = structure_pos[bno][sno]
res_loc.append(loc_preds[bno, pos])
res_html_code = ''.join(structure_str[bno])
res_loc = np.array(res_loc)
res_html_code_list.append(res_html_code)
res_loc_list.append(res_loc)
return {
'res_html_code': res_html_code_list,
'res_loc': res_loc_list,
'res_score_list': result_score_list,
'res_elem_idx_list': result_elem_idx_list,
'structure_str_list': structure_str
}
def decode(self, text_index, structure_probs, char_or_elem):
"""convert text-label into text-index.
"""
if char_or_elem == "char":
current_dict = self.dict_idx_character
else:
current_dict = self.dict_idx_elem
ignored_tokens = self.get_ignored_tokens('elem')
beg_idx, end_idx = ignored_tokens
result_list = []
result_pos_list = []
result_score_list = []
result_elem_idx_list = []
batch_size = len(text_index)
for batch_idx in range(batch_size):
char_list = []
elem_pos_list = []
elem_idx_list = []
score_list = []
for idx in range(len(text_index[batch_idx])):
tmp_elem_idx = int(text_index[batch_idx][idx])
if idx > 0 and tmp_elem_idx == end_idx:
break
if tmp_elem_idx in ignored_tokens:
continue
char_list.append(current_dict[tmp_elem_idx])
elem_pos_list.append(idx)
score_list.append(structure_probs[batch_idx, idx])
elem_idx_list.append(tmp_elem_idx)
result_list.append(char_list)
result_pos_list.append(elem_pos_list)
result_score_list.append(score_list)
result_elem_idx_list.append(elem_idx_list)
return result_list, result_pos_list, result_score_list, result_elem_idx_list
def get_ignored_tokens(self, char_or_elem):
beg_idx = self.get_beg_end_flag_idx("beg", char_or_elem)
end_idx = self.get_beg_end_flag_idx("end", char_or_elem)
return [beg_idx, end_idx]
def get_beg_end_flag_idx(self, beg_or_end, char_or_elem):
if char_or_elem == "char":
if beg_or_end == "beg":
idx = self.dict_character[self.beg_str]
elif beg_or_end == "end":
idx = self.dict_character[self.end_str]
else:
assert False, "Unsupport type %s in get_beg_end_flag_idx of char" \
% beg_or_end
elif char_or_elem == "elem":
if beg_or_end == "beg":
idx = self.dict_elem[self.beg_str]
elif beg_or_end == "end":
idx = self.dict_elem[self.end_str]
else:
assert False, "Unsupport type %s in get_beg_end_flag_idx of elem" \
% beg_or_end
else:
assert False, "Unsupport type %s in char_or_elem" \
% char_or_elem
return idx
class SARLabelDecode(BaseRecLabelDecode):
""" Convert between text-label and text-index """
......
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
import paddle
from .rec_postprocess import AttnLabelDecode
class TableLabelDecode(AttnLabelDecode):
""" """
def __init__(self, character_dict_path, **kwargs):
super(TableLabelDecode, self).__init__(character_dict_path)
self.td_token = ['<td>', '<td', '<eb></eb>', '<td></td>']
def __call__(self, preds, batch=None):
structure_probs = preds['structure_probs']
bbox_preds = preds['loc_preds']
if isinstance(structure_probs, paddle.Tensor):
structure_probs = structure_probs.numpy()
if isinstance(bbox_preds, paddle.Tensor):
bbox_preds = bbox_preds.numpy()
shape_list = batch[-1]
result = self.decode(structure_probs, bbox_preds, shape_list)
if len(batch) == 1: # only contains shape
return result
label_decode_result = self.decode_label(batch)
return result, label_decode_result
def decode(self, structure_probs, bbox_preds, shape_list):
"""convert text-label into text-index.
"""
ignored_tokens = self.get_ignored_tokens()
end_idx = self.dict[self.end_str]
structure_idx = structure_probs.argmax(axis=2)
structure_probs = structure_probs.max(axis=2)
structure_batch_list = []
bbox_batch_list = []
batch_size = len(structure_idx)
for batch_idx in range(batch_size):
structure_list = []
bbox_list = []
score_list = []
for idx in range(len(structure_idx[batch_idx])):
char_idx = int(structure_idx[batch_idx][idx])
if idx > 0 and char_idx == end_idx:
break
if char_idx in ignored_tokens:
continue
text = self.character[char_idx]
if text in self.td_token:
bbox = bbox_preds[batch_idx, idx]
bbox = self._bbox_decode(bbox, shape_list[batch_idx])
bbox_list.append(bbox)
structure_list.append(text)
score_list.append(structure_probs[batch_idx, idx])
structure_batch_list.append([structure_list, np.mean(score_list)])
bbox_batch_list.append(np.array(bbox_list))
result = {
'bbox_batch_list': bbox_batch_list,
'structure_batch_list': structure_batch_list,
}
return result
def decode_label(self, batch):
"""convert text-label into text-index.
"""
structure_idx = batch[1]
gt_bbox_list = batch[2]
shape_list = batch[-1]
ignored_tokens = self.get_ignored_tokens()
end_idx = self.dict[self.end_str]
structure_batch_list = []
bbox_batch_list = []
batch_size = len(structure_idx)
for batch_idx in range(batch_size):
structure_list = []
bbox_list = []
for idx in range(len(structure_idx[batch_idx])):
char_idx = int(structure_idx[batch_idx][idx])
if idx > 0 and char_idx == end_idx:
break
if char_idx in ignored_tokens:
continue
structure_list.append(self.character[char_idx])
bbox = gt_bbox_list[batch_idx][idx]
if bbox.sum() != 0:
bbox = self._bbox_decode(bbox, shape_list[batch_idx])
bbox_list.append(bbox)
structure_batch_list.append(structure_list)
bbox_batch_list.append(bbox_list)
result = {
'bbox_batch_list': bbox_batch_list,
'structure_batch_list': structure_batch_list,
}
return result
def _bbox_decode(self, bbox, shape):
h, w, ratio_h, ratio_w, pad_h, pad_w = shape
src_h = h / ratio_h
src_w = w / ratio_w
bbox[0::2] *= src_w
bbox[1::2] *= src_h
return bbox
class TableMasterLabelDecode(TableLabelDecode):
""" """
def __init__(self, character_dict_path, box_shape='ori', **kwargs):
super(TableMasterLabelDecode, self).__init__(character_dict_path)
self.box_shape = box_shape
assert box_shape in [
'ori', 'pad'
], 'The shape used for box normalization must be ori or pad'
def add_special_char(self, dict_character):
self.beg_str = '<SOS>'
self.end_str = '<EOS>'
self.unknown_str = '<UKN>'
self.pad_str = '<PAD>'
        dict_character = dict_character + [
            self.unknown_str, self.beg_str, self.end_str, self.pad_str
        ]
return dict_character
def get_ignored_tokens(self):
pad_idx = self.dict[self.pad_str]
start_idx = self.dict[self.beg_str]
end_idx = self.dict[self.end_str]
unknown_idx = self.dict[self.unknown_str]
return [start_idx, end_idx, pad_idx, unknown_idx]
def _bbox_decode(self, bbox, shape):
h, w, ratio_h, ratio_w, pad_h, pad_w = shape
if self.box_shape == 'pad':
h, w = pad_h, pad_w
bbox[0::2] *= w
bbox[1::2] *= h
bbox[0::2] /= ratio_w
bbox[1::2] /= ratio_h
return bbox
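# A worked example (editor's illustration, not part of the original file) of the
# 'pad' branch above: normalized coordinates are scaled by the padded size and
# the resize ratios are undone; all numbers are made up.
import numpy as np

bbox = np.array([0.25, 0.5, 0.75, 0.9], dtype='float32')  # normalized x1,y1,x2,y2
h, w, ratio_h, ratio_w, pad_h, pad_w = 480.0, 480.0, 0.5, 0.75, 488, 488
bbox[0::2] *= pad_w    # box_shape == 'pad': scale x by the padded width
bbox[1::2] *= pad_h    # scale y by the padded height
bbox[0::2] /= ratio_w  # undo the horizontal resize
bbox[1::2] /= ratio_h  # undo the vertical resize
# bbox is now in source-image pixel coordinates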
......@@ -41,11 +41,13 @@ class VQASerTokenLayoutLMPostProcess(object):
self.id2label_map_for_show[val] = key
def __call__(self, preds, batch=None, *args, **kwargs):
if isinstance(preds, tuple):
preds = preds[0]
if isinstance(preds, paddle.Tensor):
preds = preds.numpy()
if batch is not None:
return self._metric(preds, batch[1])
return self._metric(preds, batch[5])
else:
return self._infer(preds, **kwargs)
......@@ -63,11 +65,11 @@ class VQASerTokenLayoutLMPostProcess(object):
j]])
return decode_out_list, label_decode_out_list
def _infer(self, preds, attention_masks, segment_offset_ids, ocr_infos):
def _infer(self, preds, segment_offset_ids, ocr_infos):
results = []
for pred, attention_mask, segment_offset_id, ocr_info in zip(
preds, attention_masks, segment_offset_ids, ocr_infos):
for pred, segment_offset_id, ocr_info in zip(preds, segment_offset_ids,
ocr_infos):
pred = np.argmax(pred, axis=1)
pred = [self.id2label_map[idx] for idx in pred]
......
<thead>
<tr>
<td></td>
</tr>
</thead>
<tbody>
<eb></eb>
</tbody>
<td
colspan="5"
>
</td>
colspan="2"
colspan="3"
<eb2></eb2>
<eb1></eb1>
rowspan="2"
colspan="4"
colspan="6"
rowspan="3"
colspan="9"
colspan="10"
colspan="7"
rowspan="4"
rowspan="5"
rowspan="9"
colspan="8"
rowspan="8"
rowspan="6"
rowspan="7"
rowspan="10"
<eb3></eb3>
<eb4></eb4>
<eb5></eb5>
<eb6></eb6>
<eb7></eb7>
<eb8></eb8>
<eb9></eb9>
<eb10></eb10>
277 28 1267 1186
<b>
V
a
r
i
b
l
e
</b>
H
z
d
t
o
9
5
%
C
I
<i>
p
</i>
v
u
*
A
g
(
m
n
)
0
.
7
1
6
>
8
3
2
G
4
M
F
T
y
f
s
L
w
c
U
h
D
S
Q
R
x
P
-
E
O
/
k
,
+
N
K
q
[
]
<
<sup>
</sup>
μ
±
J
j
W
_
Δ
B
:
Y
α
λ
;
<sub>
</sub>
?
=
°
#
̊
̈
̂
Z
X
β
'
~
@
"
γ
&
χ
σ
§
|
×
$
\
π
®
^
<underline>
</underline>
́
·
£
φ
Ψ
ß
η
̃
Φ
ρ
̄
δ
̧
Ω
{
}
̀
ø
κ
ε
¥
`
ω
Σ
Β
̸
Χ
Α
ψ
ǂ
ζ
!
Γ
θ
υ
τ
Ø
©
С
˂
ɛ
¢
˃
­
Π
̌
<overline>
</overline>
¤
̆
ξ
÷

ι
ν
<strike>
</strike>
«
»
ł
ı
Θ
̇
æ
ʹ
ˆ
̨
Ι
Λ
А
<thead>
<tr>
<td>
......@@ -303,2457 +25,4 @@ $
rowspan="8"
rowspan="6"
rowspan="7"
rowspan="10"
0 2924682
1 3405345
2 2363468
3 2709165
4 4078680
5 3250792
6 1923159
7 1617890
8 1450532
9 1717624
10 1477550
11 1489223
12 915528
13 819193
14 593660
15 518924
16 682065
17 494584
18 400591
19 396421
20 340994
21 280688
22 250328
23 226786
24 199927
25 182707
26 164629
27 141613
28 127554
29 116286
30 107682
31 96367
32 88002
33 79234
34 72186
35 65921
36 60374
37 55976
38 52166
39 47414
40 44932
41 41279
42 38232
43 35463
44 33703
45 30557
46 29639
47 27000
48 25447
49 23186
50 22093
51 20412
52 19844
53 18261
54 17561
55 16499
56 15597
57 14558
58 14372
59 13445
60 13514
61 12058
62 11145
63 10767
64 10370
65 9630
66 9337
67 8881
68 8727
69 8060
70 7994
71 7740
72 7189
73 6729
74 6749
75 6548
76 6321
77 5957
78 5740
79 5407
80 5370
81 5035
82 4921
83 4656
84 4600
85 4519
86 4277
87 4023
88 3939
89 3910
90 3861
91 3560
92 3483
93 3406
94 3346
95 3229
96 3122
97 3086
98 3001
99 2884
100 2822
101 2677
102 2670
103 2610
104 2452
105 2446
106 2400
107 2300
108 2316
109 2196
110 2089
111 2083
112 2041
113 1881
114 1838
115 1896
116 1795
117 1786
118 1743
119 1765
120 1750
121 1683
122 1563
123 1499
124 1513
125 1462
126 1388
127 1441
128 1417
129 1392
130 1306
131 1321
132 1274
133 1294
134 1240
135 1126
136 1157
137 1130
138 1084
139 1130
140 1083
141 1040
142 980
143 1031
144 974
145 980
146 932
147 898
148 960
149 907
150 852
151 912
152 859
153 847
154 876
155 792
156 791
157 765
158 788
159 787
160 744
161 673
162 683
163 697
164 666
165 680
166 632
167 677
168 657
169 618
170 587
171 585
172 567
173 549
174 562
175 548
176 542
177 539
178 542
179 549
180 547
181 526
182 525
183 514
184 512
185 505
186 515
187 467
188 475
189 458
190 435
191 443
192 427
193 424
194 404
195 389
196 429
197 404
198 386
199 351
200 388
201 408
202 361
203 346
204 324
205 361
206 363
207 364
208 323
209 336
210 342
211 315
212 325
213 328
214 314
215 327
216 320
217 300
218 295
219 315
220 310
221 295
222 275
223 248
224 274
225 232
226 293
227 259
228 286
229 263
230 242
231 214
232 261
233 231
234 211
235 250
236 233
237 206
238 224
239 210
240 233
241 223
242 216
243 222
244 207
245 212
246 196
247 205
248 201
249 202
250 211
251 201
252 215
253 179
254 163
255 179
256 191
257 188
258 196
259 150
260 154
261 176
262 211
263 166
264 171
265 165
266 149
267 182
268 159
269 161
270 164
271 161
272 141
273 151
274 127
275 129
276 142
277 158
278 148
279 135
280 127
281 134
282 138
283 131
284 126
285 125
286 130
287 126
288 135
289 125
290 135
291 131
292 95
293 135
294 106
295 117
296 136
297 128
298 128
299 118
300 109
301 112
302 117
303 108
304 120
305 100
306 95
307 108
308 112
309 77
310 120
311 104
312 109
313 89
314 98
315 82
316 98
317 93
318 77
319 93
320 77
321 98
322 93
323 86
324 89
325 73
326 70
327 71
328 77
329 87
330 77
331 93
332 100
333 83
334 72
335 74
336 69
337 77
338 68
339 78
340 90
341 98
342 75
343 80
344 63
345 71
346 83
347 66
348 71
349 70
350 62
351 62
352 59
353 63
354 62
355 52
356 64
357 64
358 56
359 49
360 57
361 63
362 60
363 68
364 62
365 55
366 54
367 40
368 75
369 70
370 53
371 58
372 57
373 55
374 69
375 57
376 53
377 43
378 45
379 47
380 56
381 51
382 59
383 51
384 43
385 34
386 57
387 49
388 39
389 46
390 48
391 43
392 40
393 54
394 50
395 41
396 43
397 33
398 27
399 49
400 44
401 44
402 38
403 30
404 32
405 37
406 39
407 42
408 53
409 39
410 34
411 31
412 32
413 52
414 27
415 41
416 34
417 36
418 50
419 35
420 32
421 33
422 45
423 35
424 40
425 29
426 41
427 40
428 39
429 32
430 31
431 34
432 29
433 27
434 26
435 22
436 34
437 28
438 30
439 38
440 35
441 36
442 36
443 27
444 24
445 33
446 31
447 25
448 33
449 27
450 32
451 46
452 31
453 35
454 35
455 34
456 26
457 21
458 25
459 26
460 24
461 27
462 33
463 30
464 35
465 21
466 32
467 19
468 27
469 16
470 28
471 26
472 27
473 26
474 25
475 25
476 27
477 20
478 28
479 22
480 23
481 16
482 25
483 27
484 19
485 23
486 19
487 15
488 15
489 23
490 24
491 19
492 20
493 18
494 17
495 30
496 28
497 20
498 29
499 17
500 19
501 21
502 15
503 24
504 15
505 19
506 25
507 16
508 23
509 26
510 21
511 15
512 12
513 16
514 18
515 24
516 26
517 18
518 8
519 25
520 14
521 8
522 24
523 20
524 18
525 15
526 13
527 17
528 18
529 22
530 21
531 9
532 16
533 17
534 13
535 17
536 15
537 13
538 20
539 13
540 19
541 29
542 10
543 8
544 18
545 13
546 9
547 18
548 10
549 18
550 18
551 9
552 9
553 15
554 13
555 15
556 14
557 14
558 18
559 8
560 13
561 9
562 7
563 12
564 6
565 9
566 9
567 18
568 9
569 10
570 13
571 14
572 13
573 21
574 8
575 16
576 12
577 9
578 16
579 17
580 22
581 6
582 14
583 13
584 15
585 11
586 13
587 5
588 12
589 13
590 15
591 13
592 15
593 12
594 7
595 18
596 12
597 13
598 13
599 13
600 12
601 12
602 10
603 11
604 6
605 6
606 2
607 9
608 8
609 12
610 9
611 12
612 13
613 12
614 14
615 9
616 8
617 9
618 14
619 13
620 12
621 6
622 8
623 8
624 8
625 12
626 8
627 7
628 5
629 8
630 12
631 6
632 10
633 10
634 7
635 8
636 9
637 6
638 9
639 4
640 12
641 4
642 3
643 11
644 10
645 6
646 12
647 12
648 4
649 4
650 9
651 8
652 6
653 5
654 14
655 10
656 11
657 8
658 5
659 5
660 9
661 13
662 4
663 5
664 9
665 11
666 12
667 7
668 13
669 2
670 1
671 7
672 7
673 7
674 10
675 9
676 6
677 5
678 7
679 6
680 3
681 3
682 4
683 9
684 8
685 5
686 3
687 11
688 9
689 2
690 6
691 5
692 9
693 5
694 6
695 5
696 9
697 8
698 3
699 7
700 5
701 9
702 8
703 7
704 2
705 3
706 7
707 6
708 6
709 10
710 2
711 10
712 6
713 7
714 5
715 6
716 4
717 6
718 8
719 4
720 6
721 7
722 5
723 7
724 3
725 10
726 10
727 3
728 7
729 7
730 5
731 2
732 1
733 5
734 1
735 5
736 6
737 2
738 2
739 3
740 7
741 2
742 7
743 4
744 5
745 4
746 5
747 3
748 1
749 4
750 4
751 2
752 4
753 6
754 6
755 6
756 3
757 2
758 5
759 5
760 3
761 4
762 2
763 1
764 8
765 3
766 4
767 3
768 1
769 5
770 3
771 3
772 4
773 4
774 1
775 3
776 2
777 2
778 3
779 3
780 1
781 4
782 3
783 4
784 6
785 3
786 5
787 4
788 2
789 4
790 5
791 4
792 6
794 4
795 1
796 1
797 4
798 2
799 3
800 3
801 1
802 5
803 5
804 3
805 3
806 3
807 4
808 4
809 2
811 5
812 4
813 6
814 3
815 2
816 2
817 3
818 5
819 3
820 1
821 1
822 4
823 3
824 4
825 8
826 3
827 5
828 5
829 3
830 6
831 3
832 4
833 8
834 5
835 3
836 3
837 2
838 4
839 2
840 1
841 3
842 2
843 1
844 3
846 4
847 4
848 3
849 3
850 2
851 3
853 1
854 4
855 4
856 2
857 4
858 1
859 2
860 5
861 1
862 1
863 4
864 2
865 2
867 5
868 1
869 4
870 1
871 1
872 1
873 2
875 5
876 3
877 1
878 3
879 3
880 3
881 2
882 1
883 6
884 2
885 2
886 1
887 1
888 3
889 2
890 2
891 3
892 1
893 3
894 1
895 5
896 1
897 3
899 2
900 2
902 1
903 2
904 4
905 4
906 3
907 1
908 1
909 2
910 5
911 2
912 3
914 1
915 1
916 2
918 2
919 2
920 4
921 4
922 1
923 1
924 4
925 5
926 1
928 2
929 1
930 1
931 1
932 1
933 1
934 2
935 1
936 1
937 1
938 2
939 1
941 1
942 4
944 2
945 2
946 2
947 1
948 1
950 1
951 2
953 1
954 2
955 1
956 1
957 2
958 1
960 3
962 4
963 1
964 1
965 3
966 2
967 2
968 1
969 3
970 3
972 1
974 4
975 3
976 3
977 2
979 2
980 1
981 1
983 5
984 1
985 3
986 1
987 2
988 4
989 2
991 2
992 2
993 1
994 1
996 2
997 2
998 1
999 3
1000 2
1001 1
1002 3
1003 3
1004 2
1005 3
1006 1
1007 2
1009 1
1011 1
1013 3
1014 1
1016 2
1017 1
1018 1
1019 1
1020 4
1021 1
1022 2
1025 1
1026 1
1027 2
1028 1
1030 1
1031 2
1032 4
1034 3
1035 2
1036 1
1038 1
1039 1
1040 1
1041 1
1042 2
1043 1
1044 2
1045 4
1048 1
1050 1
1051 1
1052 2
1054 1
1055 3
1056 2
1057 1
1059 1
1061 2
1063 1
1064 1
1065 1
1066 1
1067 1
1068 1
1069 2
1074 1
1075 1
1077 1
1078 1
1079 1
1082 1
1085 1
1088 1
1090 1
1091 1
1092 2
1094 2
1097 2
1098 1
1099 2
1101 2
1102 1
1104 1
1105 1
1107 1
1109 1
1111 2
1112 1
1114 2
1115 2
1116 2
1117 1
1118 1
1119 1
1120 1
1122 1
1123 1
1127 1
1128 3
1132 2
1138 3
1142 1
1145 4
1150 1
1153 2
1154 1
1158 1
1159 1
1163 1
1165 1
1169 2
1174 1
1176 1
1177 1
1178 2
1179 1
1180 2
1181 1
1182 1
1183 2
1185 1
1187 1
1191 2
1193 1
1195 3
1196 1
1201 3
1203 1
1206 1
1210 1
1213 1
1214 1
1215 2
1218 1
1220 1
1221 1
1225 1
1226 1
1233 2
1241 1
1243 1
1249 1
1250 2
1251 1
1254 1
1255 2
1260 1
1268 1
1270 1
1273 1
1274 1
1277 1
1284 1
1287 1
1291 1
1292 2
1294 1
1295 2
1297 1
1298 1
1301 1
1307 1
1308 3
1311 2
1313 1
1316 1
1321 1
1324 1
1325 1
1330 1
1333 1
1334 1
1338 2
1340 1
1341 1
1342 1
1343 1
1345 1
1355 1
1357 1
1360 2
1375 1
1376 1
1380 1
1383 1
1387 1
1389 1
1393 1
1394 1
1396 1
1398 1
1410 1
1414 1
1419 1
1425 1
1434 1
1435 1
1438 1
1439 1
1447 1
1455 2
1460 1
1461 1
1463 1
1466 1
1470 1
1473 1
1478 1
1480 1
1483 1
1484 1
1485 2
1492 2
1499 1
1509 1
1512 1
1513 1
1523 1
1524 1
1525 2
1529 1
1539 1
1544 1
1568 1
1584 1
1591 1
1598 1
1600 1
1604 1
1614 1
1617 1
1621 1
1622 1
1626 1
1638 1
1648 1
1658 1
1661 1
1679 1
1682 1
1693 1
1700 1
1705 1
1707 1
1722 1
1728 1
1758 1
1762 1
1763 1
1775 1
1776 1
1801 1
1810 1
1812 1
1827 1
1834 1
1846 1
1847 1
1848 1
1851 1
1862 1
1866 1
1877 2
1884 1
1888 1
1903 1
1912 1
1925 1
1938 1
1955 1
1998 1
2054 1
2058 1
2065 1
2069 1
2076 1
2089 1
2104 1
2111 1
2133 1
2138 1
2156 1
2204 1
2212 1
2237 1
2246 2
2298 1
2304 1
2360 1
2400 1
2481 1
2544 1
2586 1
2622 1
2666 1
2682 1
2725 1
2920 1
3997 1
4019 1
5211 1
12 19
14 1
16 401
18 2
20 421
22 557
24 625
26 50
28 4481
30 52
32 550
34 5840
36 4644
38 87
40 5794
41 33
42 571
44 11805
46 4711
47 7
48 597
49 12
50 678
51 2
52 14715
53 3
54 7322
55 3
56 508
57 39
58 3486
59 11
60 8974
61 45
62 1276
63 4
64 15693
65 15
66 657
67 13
68 6409
69 10
70 3188
71 25
72 1889
73 27
74 10370
75 9
76 12432
77 23
78 520
79 15
80 1534
81 29
82 2944
83 23
84 12071
85 36
86 1502
87 10
88 10978
89 11
90 889
91 16
92 4571
93 17
94 7855
95 21
96 2271
97 33
98 1423
99 15
100 11096
101 21
102 4082
103 13
104 5442
105 25
106 2113
107 26
108 3779
109 43
110 1294
111 29
112 7860
113 29
114 4965
115 22
116 7898
117 25
118 1772
119 28
120 1149
121 38
122 1483
123 32
124 10572
125 25
126 1147
127 31
128 1699
129 22
130 5533
131 22
132 4669
133 34
134 3777
135 10
136 5412
137 21
138 855
139 26
140 2485
141 46
142 1970
143 27
144 6565
145 40
146 933
147 15
148 7923
149 16
150 735
151 23
152 1111
153 33
154 3714
155 27
156 2445
157 30
158 3367
159 10
160 4646
161 27
162 990
163 23
164 5679
165 25
166 2186
167 17
168 899
169 32
170 1034
171 22
172 6185
173 32
174 2685
175 17
176 1354
177 38
178 1460
179 15
180 3478
181 20
182 958
183 20
184 6055
185 23
186 2180
187 15
188 1416
189 30
190 1284
191 22
192 1341
193 21
194 2413
195 18
196 4984
197 13
198 830
199 22
200 1834
201 19
202 2238
203 9
204 3050
205 22
206 616
207 17
208 2892
209 22
210 711
211 30
212 2631
213 19
214 3341
215 21
216 987
217 26
218 823
219 9
220 3588
221 20
222 692
223 7
224 2925
225 31
226 1075
227 16
228 2909
229 18
230 673
231 20
232 2215
233 14
234 1584
235 21
236 1292
237 29
238 1647
239 25
240 1014
241 30
242 1648
243 19
244 4465
245 10
246 787
247 11
248 480
249 25
250 842
251 15
252 1219
253 23
254 1508
255 8
256 3525
257 16
258 490
259 12
260 1678
261 14
262 822
263 16
264 1729
265 28
266 604
267 11
268 2572
269 7
270 1242
271 15
272 725
273 18
274 1983
275 13
276 1662
277 19
278 491
279 12
280 1586
281 14
282 563
283 10
284 2363
285 10
286 656
287 14
288 725
289 28
290 871
291 9
292 2606
293 12
294 961
295 9
296 478
297 13
298 1252
299 10
300 736
301 19
302 466
303 13
304 2254
305 12
306 486
307 14
308 1145
309 13
310 955
311 13
312 1235
313 13
314 931
315 14
316 1768
317 11
318 330
319 10
320 539
321 23
322 570
323 12
324 1789
325 13
326 884
327 5
328 1422
329 14
330 317
331 11
332 509
333 13
334 1062
335 12
336 577
337 27
338 378
339 10
340 2313
341 9
342 391
343 13
344 894
345 17
346 664
347 9
348 453
349 6
350 363
351 15
352 1115
353 13
354 1054
355 8
356 1108
357 12
358 354
359 7
360 363
361 16
362 344
363 11
364 1734
365 12
366 265
367 10
368 969
369 16
370 316
371 12
372 757
373 7
374 563
375 15
376 857
377 9
378 469
379 9
380 385
381 12
382 921
383 15
384 764
385 14
386 246
387 6
388 1108
389 14
390 230
391 8
392 266
393 11
394 641
395 8
396 719
397 9
398 243
399 4
400 1108
401 7
402 229
403 7
404 903
405 7
406 257
407 12
408 244
409 3
410 541
411 6
412 744
413 8
414 419
415 8
416 388
417 19
418 470
419 14
420 612
421 6
422 342
423 3
424 1179
425 3
426 116
427 14
428 207
429 6
430 255
431 4
432 288
433 12
434 343
435 6
436 1015
437 3
438 538
439 10
440 194
441 6
442 188
443 15
444 524
445 7
446 214
447 7
448 574
449 6
450 214
451 5
452 635
453 9
454 464
455 5
456 205
457 9
458 163
459 2
460 558
461 4
462 171
463 14
464 444
465 11
466 543
467 5
468 388
469 6
470 141
471 4
472 647
473 3
474 210
475 4
476 193
477 7
478 195
479 7
480 443
481 10
482 198
483 3
484 816
485 6
486 128
487 9
488 215
489 9
490 328
491 7
492 158
493 11
494 335
495 8
496 435
497 6
498 174
499 1
500 373
501 5
502 140
503 7
504 330
505 9
506 149
507 5
508 642
509 3
510 179
511 3
512 159
513 8
514 204
515 7
516 306
517 4
518 110
519 5
520 326
521 6
522 305
523 6
524 294
525 7
526 268
527 5
528 149
529 4
530 133
531 2
532 513
533 10
534 116
535 5
536 258
537 4
538 113
539 4
540 138
541 6
542 116
544 485
545 4
546 93
547 9
548 299
549 3
550 256
551 6
552 92
553 3
554 175
555 6
556 253
557 7
558 95
559 2
560 128
561 4
562 206
563 2
564 465
565 3
566 69
567 3
568 157
569 7
570 97
571 8
572 118
573 5
574 130
575 4
576 301
577 6
578 177
579 2
580 397
581 3
582 80
583 1
584 128
585 5
586 52
587 2
588 72
589 1
590 84
591 6
592 323
593 11
594 77
595 5
596 205
597 1
598 244
599 4
600 69
601 3
602 89
603 5
604 254
605 6
606 147
607 3
608 83
609 3
610 77
611 3
612 194
613 1
614 98
615 3
616 243
617 3
618 50
619 8
620 188
621 4
622 67
623 4
624 123
625 2
626 50
627 1
628 239
629 2
630 51
631 4
632 65
633 5
634 188
636 81
637 3
638 46
639 3
640 103
641 1
642 136
643 3
644 188
645 3
646 58
648 122
649 4
650 47
651 2
652 155
653 4
654 71
655 1
656 71
657 3
658 50
659 2
660 177
661 5
662 66
663 2
664 183
665 3
666 50
667 2
668 53
669 2
670 115
672 66
673 2
674 47
675 1
676 197
677 2
678 46
679 3
680 95
681 3
682 46
683 3
684 107
685 1
686 86
687 2
688 158
689 4
690 51
691 1
692 80
694 56
695 4
696 40
698 43
699 3
700 95
701 2
702 51
703 2
704 133
705 1
706 100
707 2
708 121
709 2
710 15
711 3
712 35
713 2
714 20
715 3
716 37
717 2
718 78
720 55
721 1
722 42
723 2
724 218
725 3
726 23
727 2
728 26
729 1
730 64
731 2
732 65
734 24
735 2
736 53
737 1
738 32
739 1
740 60
742 81
743 1
744 77
745 1
746 47
747 1
748 62
749 1
750 19
751 1
752 86
753 3
754 40
756 55
757 2
758 38
759 1
760 101
761 1
762 22
764 67
765 2
766 35
767 1
768 38
769 1
770 22
771 1
772 82
773 1
774 73
776 29
777 1
778 55
780 23
781 1
782 16
784 84
785 3
786 28
788 59
789 1
790 33
791 3
792 24
794 13
795 1
796 110
797 2
798 15
800 22
801 3
802 29
803 1
804 87
806 21
808 29
810 48
812 28
813 1
814 58
815 1
816 48
817 1
818 31
819 1
820 66
822 17
823 2
824 58
826 10
827 2
828 25
829 1
830 29
831 1
832 63
833 1
834 26
835 3
836 52
837 1
838 18
840 27
841 2
842 12
843 1
844 83
845 1
846 7
847 1
848 10
850 26
852 25
853 1
854 15
856 27
858 32
859 1
860 15
862 43
864 32
865 1
866 6
868 39
870 11
872 25
873 1
874 10
875 1
876 20
877 2
878 19
879 1
880 30
882 11
884 53
886 25
887 1
888 28
890 6
892 36
894 10
896 13
898 14
900 31
902 14
903 2
904 43
906 25
908 9
910 11
911 1
912 16
913 1
914 24
916 27
918 6
920 15
922 27
923 1
924 23
926 13
928 42
929 1
930 3
932 27
934 17
936 8
937 1
938 11
940 33
942 4
943 1
944 18
946 15
948 13
950 18
952 12
954 11
956 21
958 10
960 13
962 5
964 32
966 13
968 8
970 8
971 1
972 23
973 2
974 12
975 1
976 22
978 7
979 1
980 14
982 8
984 22
985 1
986 6
988 17
989 1
990 6
992 13
994 19
996 11
998 4
1000 9
1002 2
1004 14
1006 5
1008 3
1010 9
1012 29
1014 6
1016 22
1017 1
1018 8
1019 1
1020 7
1022 6
1023 1
1024 10
1026 2
1028 8
1030 11
1031 2
1032 8
1034 9
1036 13
1038 12
1040 12
1042 3
1044 12
1046 3
1048 11
1050 2
1051 1
1052 2
1054 11
1056 6
1058 8
1059 1
1060 23
1062 6
1063 1
1064 8
1066 3
1068 6
1070 8
1071 1
1072 5
1074 3
1076 5
1078 3
1080 11
1081 1
1082 7
1084 18
1086 4
1087 1
1088 3
1090 3
1092 7
1094 3
1096 12
1098 6
1099 1
1100 2
1102 6
1104 14
1106 3
1108 6
1110 5
1112 2
1114 8
1116 3
1118 3
1120 7
1122 10
1124 6
1126 8
1128 1
1130 4
1132 3
1134 2
1136 5
1138 5
1140 8
1142 3
1144 7
1146 3
1148 11
1150 1
1152 5
1154 1
1156 5
1158 1
1160 5
1162 3
1164 6
1165 1
1166 1
1168 4
1169 1
1170 3
1171 1
1172 2
1174 5
1176 3
1177 1
1180 8
1182 2
1184 4
1186 2
1188 3
1190 2
1192 5
1194 6
1196 1
1198 2
1200 2
1204 10
1206 2
1208 9
1210 1
1214 6
1216 3
1218 4
1220 9
1221 2
1222 1
1224 5
1226 4
1228 8
1230 1
1232 1
1234 3
1236 5
1240 3
1242 1
1244 3
1245 1
1246 4
1248 6
1250 2
1252 7
1256 3
1258 2
1260 2
1262 3
1264 4
1265 1
1266 1
1270 1
1271 1
1272 2
1274 3
1276 3
1278 1
1280 3
1284 1
1286 1
1290 1
1292 3
1294 1
1296 7
1300 2
1302 4
1304 3
1306 2
1308 2
1312 1
1314 1
1316 3
1318 2
1320 1
1324 8
1326 1
1330 1
1331 1
1336 2
1338 1
1340 3
1341 1
1344 1
1346 2
1347 1
1348 3
1352 1
1354 2
1356 1
1358 1
1360 3
1362 1
1364 4
1366 1
1370 1
1372 3
1380 2
1384 2
1388 2
1390 2
1392 2
1394 1
1396 1
1398 1
1400 2
1402 1
1404 1
1406 1
1410 1
1412 5
1418 1
1420 1
1424 1
1432 2
1434 2
1442 3
1444 5
1448 1
1454 1
1456 1
1460 3
1462 4
1468 1
1474 1
1476 1
1478 2
1480 1
1486 2
1488 1
1492 1
1496 1
1500 3
1503 1
1506 1
1512 2
1516 1
1522 1
1524 2
1534 4
1536 1
1538 1
1540 2
1544 2
1548 1
1556 1
1560 1
1562 1
1564 2
1566 1
1568 1
1570 1
1572 1
1576 1
1590 1
1594 1
1604 1
1608 1
1614 1
1622 1
1624 2
1628 1
1629 1
1636 1
1642 1
1654 2
1660 1
1664 1
1670 1
1684 4
1698 1
1732 3
1742 1
1752 1
1760 1
1764 1
1772 2
1798 1
1808 1
1820 1
1852 1
1856 1
1874 1
1902 1
1908 1
1952 1
2004 1
2018 1
2020 1
2028 1
2174 1
2233 1
2244 1
2280 1
2290 1
2352 1
2604 1
4190 1
rowspan="10"
\ No newline at end of file
......@@ -91,18 +91,19 @@ def check_and_read_gif(img_path):
def load_vqa_bio_label_maps(label_map_path):
with open(label_map_path, "r", encoding='utf-8') as fin:
lines = fin.readlines()
lines = [line.strip() for line in lines]
if "O" not in lines:
lines.insert(0, "O")
labels = []
for line in lines:
if line == "O":
labels.append("O")
else:
labels.append("B-" + line)
labels.append("I-" + line)
label2id_map = {label: idx for idx, label in enumerate(labels)}
id2label_map = {idx: label for idx, label in enumerate(labels)}
old_lines = [line.strip() for line in lines]
lines = ["O"]
for line in old_lines:
# "O" has already been in lines
if line.upper() in ["OTHER", "OTHERS", "IGNORE"]:
continue
lines.append(line)
labels = ["O"]
for line in lines[1:]:
labels.append("B-" + line)
labels.append("I-" + line)
label2id_map = {label.upper(): idx for idx, label in enumerate(labels)}
id2label_map = {idx: label.upper() for idx, label in enumerate(labels)}
return label2id_map, id2label_map
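# A minimal usage sketch (editor's illustration, not part of the original file):
# the rewritten function drops OTHER/OTHERS/IGNORE aliases, prepends "O" and
# expands every class into upper-cased B-/I- tags. The file content is made up.
import os
import tempfile

with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, 'class_list.txt')
    with open(path, 'w', encoding='utf-8') as f:
        f.write('QUESTION\nANSWER\nOTHER\n')
    label2id, id2label = load_vqa_bio_label_maps(path)
    # label2id == {'O': 0, 'B-QUESTION': 1, 'I-QUESTION': 2,
    #              'B-ANSWER': 3, 'I-ANSWER': 4}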
......
......@@ -11,6 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import cv2
import os
import numpy as np
from PIL import Image, ImageDraw, ImageFont
......@@ -19,7 +20,7 @@ from PIL import Image, ImageDraw, ImageFont
def draw_ser_results(image,
ocr_results,
font_path="doc/fonts/simfang.ttf",
font_size=18):
font_size=14):
np.random.seed(2021)
color = (np.random.permutation(range(255)),
np.random.permutation(range(255)),
......@@ -40,9 +41,15 @@ def draw_ser_results(image,
if ocr_info["pred_id"] not in color_map:
continue
color = color_map[ocr_info["pred_id"]]
text = "{}: {}".format(ocr_info["pred"], ocr_info["text"])
text = "{}: {}".format(ocr_info["pred"], ocr_info["transcription"])
draw_box_txt(ocr_info["bbox"], text, draw, font, font_size, color)
if "bbox" in ocr_info:
# draw with ocr engine
bbox = ocr_info["bbox"]
else:
# draw with ocr groundtruth
bbox = trans_poly_to_bbox(ocr_info["points"])
draw_box_txt(bbox, text, draw, font, font_size, color)
img_new = Image.blend(image, img_new, 0.5)
return np.array(img_new)
......@@ -62,6 +69,14 @@ def draw_box_txt(bbox, text, draw, font, font_size, color):
draw.text((bbox[0][0] + 1, start_y), text, fill=(255, 255, 255), font=font)
def trans_poly_to_bbox(poly):
x1 = np.min([p[0] for p in poly])
x2 = np.max([p[0] for p in poly])
y1 = np.min([p[1] for p in poly])
y2 = np.max([p[1] for p in poly])
return [x1, y1, x2, y2]
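# A minimal usage sketch (editor's illustration, not part of the original file):
# a 4-point ground-truth polygon is reduced to the axis-aligned box that
# draw_box_txt expects; the coordinates are made up.
poly = [[10, 20], [110, 25], [108, 60], [12, 58]]
print(trans_poly_to_bbox(poly))  # -> [10, 20, 110, 60]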
def draw_re_results(image,
result,
font_path="doc/fonts/simfang.ttf",
......@@ -80,10 +95,10 @@ def draw_re_results(image,
color_line = (0, 255, 0)
for ocr_info_head, ocr_info_tail in result:
draw_box_txt(ocr_info_head["bbox"], ocr_info_head["text"], draw, font,
font_size, color_head)
draw_box_txt(ocr_info_tail["bbox"], ocr_info_tail["text"], draw, font,
font_size, color_tail)
draw_box_txt(ocr_info_head["bbox"], ocr_info_head["transcription"],
draw, font, font_size, color_head)
draw_box_txt(ocr_info_tail["bbox"], ocr_info_tail["transcription"],
draw, font, font_size, color_tail)
center_head = (
(ocr_info_head['bbox'][0] + ocr_info_head['bbox'][2]) // 2,
......@@ -96,3 +111,16 @@ def draw_re_results(image,
img_new = Image.blend(image, img_new, 0.5)
return np.array(img_new)
def draw_rectangle(img_path, boxes, use_xywh=False):
img = cv2.imread(img_path)
img_show = img.copy()
for box in boxes.astype(int):
if use_xywh:
x, y, w, h = box
x1, y1, x2, y2 = x - w // 2, y - h // 2, x + w // 2, y + h // 2
else:
x1, y1, x2, y2 = box
cv2.rectangle(img_show, (x1, y1), (x2, y2), (255, 0, 0), 2)
return img_show
\ No newline at end of file
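# A minimal usage sketch (editor's illustration, not part of the original file):
# with use_xywh=True, center-based (x, y, w, h) boxes such as those produced by
# TableMaster are converted to corner coordinates before drawing. The image
# path and box values are made up.
import numpy as np

boxes = np.array([[100, 50, 40, 20]])  # center x, center y, width, height
# internally this becomes the corner box (80, 40) - (120, 60)
vis = draw_rectangle('table.jpg', boxes, use_xywh=True)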
......@@ -16,7 +16,7 @@ SDMGR is a key information extraction algorithm that classifies each detected text region
The wildreceipt dataset is used for training and testing; download it with the following command:
```
wget https://paddleocr.bj.bcebos.com/dygraph_v2.1/kie/wildreceipt.tar && tar xf wildreceipt.tar
wget https://paddleocr.bj.bcebos.com/ppstructure/dataset/wildreceipt.tar && tar xf wildreceipt.tar
```
Run prediction:
......
......@@ -15,7 +15,7 @@ This section provides a tutorial example on how to quickly use, train, and evalu
[Wildreceipt dataset](https://paperswithcode.com/dataset/wildreceipt) is used for this tutorial. It contains 1765 photos, with 25 classes, and 50000 text boxes, which can be downloaded by wget:
```shell
wget https://paddleocr.bj.bcebos.com/dygraph_v2.1/kie/wildreceipt.tar && tar xf wildreceipt.tar
wget https://paddleocr.bj.bcebos.com/ppstructure/dataset/wildreceipt.tar && tar xf wildreceipt.tar
```
Download the pretrained model and predict the result:
......
# PP-Structure Model List
- [1. Layout Analysis Models](#1)
- [2. OCR and Table Recognition Models](#2)
  - [2.1 OCR](#21)
  - [2.2 Table Recognition Models](#22)
- [3. VQA Models](#3)
- [4. KIE Models](#4)
- [1. Layout Analysis Models](#1-版面分析模型)
- [2. OCR and Table Recognition Models](#2-ocr和表格识别模型)
  - [2.1 OCR](#21-ocr)
  - [2.2 Table Recognition Models](#22-表格识别模型)
- [3. VQA Models](#3-vqa模型)
- [4. KIE Models](#4-kie模型)
<a name="1"></a>
......@@ -35,18 +35,18 @@
|Model name|Description|Inference model size|Download|
| --- | --- | --- | --- |
|en_ppocr_mobile_v2.0_table_structure|Table structure prediction for English table scenes, trained on the PubLayNet dataset|18.6M|[inference model](https://paddleocr.bj.bcebos.com/dygraph_v2.0/table/en_ppocr_mobile_v2.0_table_structure_infer.tar) / [trained model](https://paddleocr.bj.bcebos.com/dygraph_v2.1/table/en_ppocr_mobile_v2.0_table_structure_train.tar) |
|en_ppocr_mobile_v2.0_table_structure|Table structure prediction for English table scenes, trained on the PubTabNet dataset|18.6M|[inference model](https://paddleocr.bj.bcebos.com/dygraph_v2.0/table/en_ppocr_mobile_v2.0_table_structure_infer.tar) / [trained model](https://paddleocr.bj.bcebos.com/dygraph_v2.1/table/en_ppocr_mobile_v2.0_table_structure_train.tar) |
<a name="3"></a>
## 3. VQA Models
|Model name|Description|Inference model size|Download|
| --- | --- | --- | --- |
|ser_LayoutXLM_xfun_zh|SER model trained on the xfun Chinese dataset based on LayoutXLM|1.4G|[inference model coming soon]() / [trained model](https://paddleocr.bj.bcebos.com/pplayout/re_LayoutXLM_xfun_zh.tar) |
|re_LayoutXLM_xfun_zh|RE model trained on the xfun Chinese dataset based on LayoutXLM|1.4G|[inference model coming soon]() / [trained model](https://paddleocr.bj.bcebos.com/pplayout/ser_LayoutXLM_xfun_zh.tar) |
|ser_LayoutLMv2_xfun_zh|SER model trained on the xfun Chinese dataset based on LayoutLMv2|778M|[inference model coming soon]() / [trained model](https://paddleocr.bj.bcebos.com/pplayout/ser_LayoutLMv2_xfun_zh.tar) |
|ser_LayoutXLM_xfun_zh|SER model trained on the xfun Chinese dataset based on LayoutXLM|1.4G|[inference model](https://paddleocr.bj.bcebos.com/pplayout/ser_LayoutXLM_xfun_zh_infer.tar) / [trained model](https://paddleocr.bj.bcebos.com/pplayout/ser_LayoutXLM_xfun_zh.tar) |
|re_LayoutXLM_xfun_zh|RE model trained on the xfun Chinese dataset based on LayoutXLM|1.4G|[inference model coming soon]() / [trained model](https://paddleocr.bj.bcebos.com/pplayout/re_LayoutXLM_xfun_zh.tar) |
|ser_LayoutLMv2_xfun_zh|SER model trained on the xfun Chinese dataset based on LayoutLMv2|778M|[inference model](https://paddleocr.bj.bcebos.com/pplayout/ser_LayoutLMv2_xfun_zh_infer.tar) / [trained model](https://paddleocr.bj.bcebos.com/pplayout/ser_LayoutLMv2_xfun_zh.tar) |
|re_LayoutLMv2_xfun_zh|RE model trained on the xfun Chinese dataset based on LayoutLMv2|765M|[inference model coming soon]() / [trained model](https://paddleocr.bj.bcebos.com/pplayout/re_LayoutLMv2_xfun_zh.tar) |
|ser_LayoutLM_xfun_zh|SER model trained on the xfun Chinese dataset based on LayoutLM|430M|[inference model coming soon]() / [trained model](https://paddleocr.bj.bcebos.com/pplayout/ser_LayoutLM_xfun_zh.tar) |
|ser_LayoutLM_xfun_zh|SER model trained on the xfun Chinese dataset based on LayoutLM|430M|[inference model](https://paddleocr.bj.bcebos.com/pplayout/ser_LayoutLM_xfun_zh_infer.tar) / [trained model](https://paddleocr.bj.bcebos.com/pplayout/ser_LayoutLM_xfun_zh.tar) |
<a name="4"></a>
## 4. KIE Models
......
# PP-Structure Model list
- [1. Layout Analysis](#1)
- [2. OCR and Table Recognition](#2)
- [2.1 OCR](#21)
- [2.2 Table Recognition](#22)
- [3. VQA](#3)
- [4. KIE](#4)
- [1. Layout Analysis](#1-layout-analysis)
- [2. OCR and Table Recognition](#2-ocr-and-table-recognition)
- [2.1 OCR](#21-ocr)
- [2.2 Table Recognition](#22-table-recognition)
- [3. VQA](#3-vqa)
- [4. KIE](#4-kie)
<a name="1"></a>
......@@ -42,11 +42,11 @@ If you need to use other OCR models, you can download the model in [PP-OCR model
|model| description |inference model size|download|
| --- |----------------------------------------------------------------| --- | --- |
|ser_LayoutXLM_xfun_zh| SER model trained on xfun Chinese dataset based on LayoutXLM |1.4G|[inference model coming soon]() / [trained model](https://paddleocr.bj.bcebos.com/pplayout/re_LayoutXLM_xfun_zh.tar) |
|re_LayoutXLM_xfun_zh| Re model trained on xfun Chinese dataset based on LayoutXLM |1.4G|[inference model coming soon]() / [trained model](https://paddleocr.bj.bcebos.com/pplayout/ser_LayoutXLM_xfun_zh.tar) |
|ser_LayoutLMv2_xfun_zh| SER model trained on xfun Chinese dataset based on LayoutXLMv2 |778M|[inference model coming soon]() / [trained model](https://paddleocr.bj.bcebos.com/pplayout/ser_LayoutLMv2_xfun_zh.tar) |
|ser_LayoutXLM_xfun_zh| SER model trained on xfun Chinese dataset based on LayoutXLM |1.4G|[inference model](https://paddleocr.bj.bcebos.com/pplayout/ser_LayoutXLM_xfun_zh_infer.tar) / [trained model](https://paddleocr.bj.bcebos.com/pplayout/ser_LayoutXLM_xfun_zh.tar) |
|re_LayoutXLM_xfun_zh| Re model trained on xfun Chinese dataset based on LayoutXLM |1.4G|[inference model coming soon]() / [trained model](https://paddleocr.bj.bcebos.com/pplayout/re_LayoutXLM_xfun_zh.tar) |
|ser_LayoutLMv2_xfun_zh| SER model trained on xfun Chinese dataset based on LayoutXLMv2 |778M|[inference model](https://paddleocr.bj.bcebos.com/pplayout/ser_LayoutLMv2_xfun_zh_infer.tar) / [trained model](https://paddleocr.bj.bcebos.com/pplayout/ser_LayoutLMv2_xfun_zh.tar) |
|re_LayoutLMv2_xfun_zh| Re model trained on xfun Chinese dataset based on LayoutXLMv2 |765M|[inference model coming soon]() / [trained model](https://paddleocr.bj.bcebos.com/pplayout/re_LayoutLMv2_xfun_zh.tar) |
|ser_LayoutLM_xfun_zh| SER model trained on xfun Chinese dataset based on LayoutLM |430M|[inference model coming soon]() / [trained model](https://paddleocr.bj.bcebos.com/pplayout/ser_LayoutLM_xfun_zh.tar) |
|ser_LayoutLM_xfun_zh| SER model trained on xfun Chinese dataset based on LayoutLM |430M|[inference model](https://paddleocr.bj.bcebos.com/pplayout/ser_LayoutLM_xfun_zh_infer.tar) / [trained model](https://paddleocr.bj.bcebos.com/pplayout/ser_LayoutLM_xfun_zh.tar) |
<a name="4"></a>
## 4. KIE
......
......@@ -34,7 +34,7 @@ from ppocr.utils.logging import get_logger
from tools.infer.predict_system import TextSystem
from ppstructure.table.predict_table import TableSystem, to_excel
from ppstructure.utility import parse_args, draw_structure_result
from ppstructure.recovery.docx import convert_info_docx
from ppstructure.recovery.recovery_to_doc import convert_info_docx
logger = get_logger()
......
......@@ -44,6 +44,12 @@ python3 -m pip install "paddlepaddle>=2.2" -i https://mirror.baidu.com/pypi/simp
For other requirements, please follow the instructions in the [installation documentation](https://www.paddlepaddle.org.cn/install/quick).
* **(2) Install dependencies**
```bash
python3 -m pip install -r ppstructure/recovery/requirements.txt
```
<a name="2.2"></a>
### 2.2 Install PaddleOCR
......
......@@ -23,43 +23,63 @@ os.environ["FLAGS_allocator_strategy"] = 'auto_growth'
import cv2
import numpy as np
import time
import json
import tools.infer.utility as utility
from ppocr.data import create_operators, transform
from ppocr.postprocess import build_post_process
from ppocr.utils.logging import get_logger
from ppocr.utils.utility import get_image_file_list, check_and_read_gif
from ppocr.utils.visual import draw_rectangle
from ppstructure.utility import parse_args
logger = get_logger()
def build_pre_process_list(args):
resize_op = {'ResizeTableImage': {'max_len': args.table_max_len, }}
pad_op = {
'PaddingTableImage': {
'size': [args.table_max_len, args.table_max_len]
}
}
normalize_op = {
'NormalizeImage': {
'std': [0.229, 0.224, 0.225] if
args.table_algorithm not in ['TableMaster'] else [0.5, 0.5, 0.5],
'mean': [0.485, 0.456, 0.406] if
args.table_algorithm not in ['TableMaster'] else [0.5, 0.5, 0.5],
'scale': '1./255.',
'order': 'hwc'
}
}
to_chw_op = {'ToCHWImage': None}
keep_keys_op = {'KeepKeys': {'keep_keys': ['image', 'shape']}}
if args.table_algorithm not in ['TableMaster']:
pre_process_list = [
resize_op, normalize_op, pad_op, to_chw_op, keep_keys_op
]
else:
pre_process_list = [
resize_op, pad_op, normalize_op, to_chw_op, keep_keys_op
]
return pre_process_list
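# A minimal usage sketch (editor's illustration, not part of the original file):
# for TableMaster the image is padded before normalization and mean/std switch
# to 0.5; args is a stand-in namespace, not the real parsed CLI arguments.
from types import SimpleNamespace

args = SimpleNamespace(table_max_len=488, table_algorithm='TableMaster')
ops = build_pre_process_list(args)
print([list(op.keys())[0] for op in ops])
# -> ['ResizeTableImage', 'PaddingTableImage', 'NormalizeImage',
#     'ToCHWImage', 'KeepKeys']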
class TableStructurer(object):
def __init__(self, args):
pre_process_list = [{
'ResizeTableImage': {
'max_len': args.table_max_len
}
}, {
'NormalizeImage': {
'std': [0.229, 0.224, 0.225],
'mean': [0.485, 0.456, 0.406],
'scale': '1./255.',
'order': 'hwc'
pre_process_list = build_pre_process_list(args)
if args.table_algorithm not in ['TableMaster']:
postprocess_params = {
'name': 'TableLabelDecode',
"character_dict_path": args.table_char_dict_path,
}
}, {
'PaddingTableImage': None
}, {
'ToCHWImage': None
}, {
'KeepKeys': {
'keep_keys': ['image']
else:
postprocess_params = {
'name': 'TableMasterLabelDecode',
"character_dict_path": args.table_char_dict_path,
'box_shape': 'pad'
}
}]
postprocess_params = {
'name': 'TableLabelDecode',
"character_dict_path": args.table_char_dict_path,
}
self.preprocess_op = create_operators(pre_process_list)
self.postprocess_op = build_post_process(postprocess_params)
......@@ -88,27 +108,17 @@ class TableStructurer(object):
preds['structure_probs'] = outputs[1]
preds['loc_preds'] = outputs[0]
post_result = self.postprocess_op(preds)
structure_str_list = post_result['structure_str_list']
res_loc = post_result['res_loc']
imgh, imgw = ori_im.shape[0:2]
res_loc_final = []
for rno in range(len(res_loc[0])):
x0, y0, x1, y1 = res_loc[0][rno]
left = max(int(imgw * x0), 0)
top = max(int(imgh * y0), 0)
right = min(int(imgw * x1), imgw - 1)
bottom = min(int(imgh * y1), imgh - 1)
res_loc_final.append([left, top, right, bottom])
structure_str_list = structure_str_list[0][:-1]
shape_list = np.expand_dims(data[-1], axis=0)
post_result = self.postprocess_op(preds, [shape_list])
structure_str_list = post_result['structure_batch_list'][0]
bbox_list = post_result['bbox_batch_list'][0]
structure_str_list = structure_str_list[0]
structure_str_list = [
'<html>', '<body>', '<table>'
] + structure_str_list + ['</table>', '</body>', '</html>']
elapse = time.time() - starttime
return (structure_str_list, res_loc_final), elapse
return (structure_str_list, bbox_list), elapse
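# A usage sketch (editor's illustration, not part of the original file): the
# structurer now returns the bbox array straight from the post-processor
# instead of re-scaling cell locations itself. Paths and args are made up.
#
#   table = TableStructurer(args)
#   (html_tokens, bboxes), elapse = table(cv2.imread('table.jpg'))
#   # html_tokens: ['<html>', '<body>', '<table>', ..., '</table>', '</body>', '</html>']
#   # bboxes: numpy array of cell coordinates in source-image pixels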
def main(args):
......@@ -116,21 +126,35 @@ def main(args):
table_structurer = TableStructurer(args)
count = 0
total_time = 0
for image_file in image_file_list:
img, flag = check_and_read_gif(image_file)
if not flag:
img = cv2.imread(image_file)
if img is None:
logger.info("error in loading image:{}".format(image_file))
continue
structure_res, elapse = table_structurer(img)
logger.info("result: {}".format(structure_res))
if count > 0:
total_time += elapse
count += 1
logger.info("Predict time of {}: {}".format(image_file, elapse))
use_xywh = args.table_algorithm in ['TableMaster']
os.makedirs(args.output, exist_ok=True)
with open(
os.path.join(args.output, 'infer.txt'), mode='w',
encoding='utf-8') as f_w:
for image_file in image_file_list:
img, flag = check_and_read_gif(image_file)
if not flag:
img = cv2.imread(image_file)
if img is None:
logger.info("error in loading image:{}".format(image_file))
continue
structure_res, elapse = table_structurer(img)
structure_str_list, bbox_list = structure_res
bbox_list_str = json.dumps(bbox_list.tolist())
logger.info("result: {}, {}".format(structure_str_list,
bbox_list_str))
f_w.write("result: {}, {}\n".format(structure_str_list,
bbox_list_str))
img = draw_rectangle(image_file, bbox_list, use_xywh)
img_save_path = os.path.join(args.output,
os.path.basename(image_file))
cv2.imwrite(img_save_path, img)
logger.info("save vis result to {}".format(img_save_path))
if count > 0:
total_time += elapse
count += 1
logger.info("Predict time of {}: {}".format(image_file, elapse))
if __name__ == "__main__":
......
......@@ -129,11 +129,25 @@ class TableSystem(object):
def rebuild_table(self, structure_res, dt_boxes, rec_res):
pred_structures, pred_bboxes = structure_res
        dt_boxes, rec_res = self.filter_ocr_result(pred_bboxes, dt_boxes, rec_res)
matched_index = self.match_result(dt_boxes, pred_bboxes)
pred_html, pred = self.get_pred_html(pred_structures, matched_index,
rec_res)
return pred_html, pred
    def filter_ocr_result(self, pred_bboxes, dt_boxes, rec_res):
        y1 = pred_bboxes[:, 1::2].min()
        new_dt_boxes = []
        new_rec_res = []
        for box, rec in zip(dt_boxes, rec_res):
if np.max(box[1::2]) < y1:
continue
new_dt_boxes.append(box)
new_rec_res.append(rec)
return new_dt_boxes, new_rec_res
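# A worked example (editor's illustration, not part of the original file) that
# replicates the filter standalone: OCR boxes lying entirely above the topmost
# predicted table cell (e.g. a caption line) are discarded; numbers are made up.
import numpy as np

pred_bboxes = np.array([[30, 100, 200, 140]])   # topmost predicted cell, y >= 100
dt_boxes = [np.array([40, 60, 180, 90]),        # caption above the table
            np.array([35, 105, 190, 135])]      # text inside the table
y1 = pred_bboxes[:, 1::2].min()
kept = [box for box in dt_boxes if np.max(box[1::2]) >= y1]
print(len(kept))  # -> 1, only the in-table box survives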
def match_result(self, dt_boxes, pred_bboxes):
matched = {}
for i, gt_box in enumerate(dt_boxes):
......
......@@ -25,6 +25,7 @@ def init_args():
parser.add_argument("--output", type=str, default='./output')
# params for table structure
parser.add_argument("--table_max_len", type=int, default=488)
parser.add_argument("--table_algorithm", type=str, default='TableAttn')
parser.add_argument("--table_model_dir", type=str)
parser.add_argument(
"--table_char_dict_path",
......@@ -40,6 +41,13 @@ def init_args():
type=ast.literal_eval,
default=None,
help='label map according to ppstructure/layout/README_ch.md')
# params for vqa
parser.add_argument("--vqa_algorithm", type=str, default='LayoutXLM')
parser.add_argument("--ser_model_dir", type=str)
parser.add_argument(
"--ser_dict_path",
type=str,
default="../train_data/XFUND/class_list_xfun.txt")
# params for inference
parser.add_argument(
"--mode",
......@@ -65,7 +73,7 @@ def init_args():
"--recovery",
type=bool,
default=False,
help='Whether to enable layout of recovery')
help='Whether to enable layout of recovery')
return parser
......
English | [简体中文](README_ch.md)
- [Document Visual Question Answering (Doc-VQA)](#Document-Visual-Question-Answering)
- [1. Introduction](#1-Introduction)
- [2. Performance](#2-performance)
- [3. Effect demo](#3-Effect-demo)
- [3.1 SER](#31-ser)
- [3.2 RE](#32-re)
- [4. Install](#4-Install)
- [4.1 Installation dependencies](#41-Install-dependencies)
- [4.2 Install PaddleOCR](#42-Install-PaddleOCR)
- [5. Usage](#5-Usage)
- [5.1 Data and Model Preparation](#51-Data-and-Model-Preparation)
- [5.2 SER](#52-ser)
- [5.3 RE](#53-re)
- [6. Reference](#6-Reference-Links)
- [1. Introduction](#1-introduction)
- [2. Performance](#2-performance)
- [3. Effect demo](#3-effect-demo)
  - [3.1 SER](#31-ser)
  - [3.2 RE](#32-re)
- [4. Install](#4-install)
  - [4.1 Install dependencies](#41-install-dependencies)
  - [4.2 Install PaddleOCR](#42-install-paddleocr)
- [5. Usage](#5-usage)
  - [5.1 Data and Model Preparation](#51-data-and-model-preparation)
  - [5.2 SER](#52-ser)
  - [5.3 RE](#53-re)
- [6. Reference Links](#6-reference-links)
- [License](#license)
# Document Visual Question Answering
......@@ -125,13 +121,13 @@ If you want to experience the prediction process directly, you can download the
* Download the processed dataset
The download address of the processed XFUND Chinese dataset: [https://paddleocr.bj.bcebos.com/dataset/XFUND.tar](https://paddleocr.bj.bcebos.com/dataset/XFUND.tar).
The download address of the processed XFUND Chinese dataset: [link](https://paddleocr.bj.bcebos.com/ppstructure/dataset/XFUND.tar).
Download and unzip the dataset, and place the dataset in the current directory after unzipping.
```shell
wget https://paddleocr.bj.bcebos.com/dataset/XFUND.tar
wget https://paddleocr.bj.bcebos.com/ppstructure/dataset/XFUND.tar
````
* Convert the dataset
......@@ -187,17 +183,17 @@ CUDA_VISIBLE_DEVICES=0 python3 tools/eval.py -c configs/vqa/ser/layoutxlm.yml -o
````
Finally, `precision`, `recall`, `hmean` and other indicators will be printed
* Use `OCR engine + SER` tandem prediction
* `OCR + SER` tandem prediction based on training engine
Use the following command to complete the series prediction of `OCR engine + SER`, taking the pretrained SER model as an example:
Use the following command to complete the tandem prediction of `OCR + SER` based on the training engine, taking the SER model based on LayoutXLM as an example:
```shell
CUDA_VISIBLE_DEVICES=0 python3 tools/infer_vqa_token_ser.py -c configs/vqa/ser/layoutxlm.yml -o Architecture.Backbone.checkpoints=pretrain/ser_LayoutXLM_xfun_zh/Global.infer_img=doc/vqa/input/zh_val_42.jpg
CUDA_VISIBLE_DEVICES=0 python3 tools/infer_vqa_token_ser.py -c configs/vqa/ser/layoutxlm.yml -o Architecture.Backbone.checkpoints=pretrain/ser_LayoutXLM_xfun_zh/ Global.infer_img=doc/vqa/input/zh_val_42.jpg
````
Finally, the prediction result visualization image and the prediction result text file will be saved in the directory configured by the `config.Global.save_res_path` field. The prediction result text file is named `infer_results.txt`.
* End-to-end evaluation of `OCR engine + SER` prediction system
* End-to-end evaluation of `OCR + SER` prediction system
First use the `tools/infer_vqa_token_ser.py` script to complete the prediction of the dataset, then use the following command to evaluate.
......@@ -205,6 +201,24 @@ First use the `tools/infer_vqa_token_ser.py` script to complete the prediction o
export CUDA_VISIBLE_DEVICES=0
python3 tools/eval_with_label_end2end.py --gt_json_path XFUND/zh_val/xfun_normalize_val.json --pred_json_path output_res/infer_results.txt
````
* Export model
Use the following command to export the SER model, taking the SER model based on LayoutXLM as an example:
```shell
python3.7 tools/export_model.py -c configs/vqa/ser/layoutxlm.yml -o Architecture.Backbone.checkpoints=pretrain/ser_LayoutXLM_xfun_zh/ Global.save_inference_dir=output/ser/infer
```
The converted model will be stored in the directory specified by the `Global.save_inference_dir` field.
* `OCR + SER` tandem prediction based on prediction engine
Use the following command to complete the tandem prediction of `OCR + SER` based on the prediction engine, taking the SER model based on LayoutXLM as an example:
```shell
cd ppstructure
CUDA_VISIBLE_DEVICES=0 python3.7 vqa/predict_vqa_token_ser.py --vqa_algorithm=LayoutXLM --ser_model_dir=../output/ser/infer --ser_dict_path=../train_data/XFUND/class_list_xfun.txt --image_dir=docs/vqa/input/zh_val_42.jpg --output=output
```
After the prediction is successful, the visualization images and results will be saved in the directory specified by the `output` field.
<a name="53"></a>
### 5.3 RE
......@@ -247,11 +261,19 @@ Finally, `precision`, `recall`, `hmean` and other indicators will be printed
Use the following command to complete the series prediction of `OCR engine + SER + RE`, taking the pretrained SER and RE models as an example:
```shell
export CUDA_VISIBLE_DEVICES=0
python3 tools/infer_vqa_token_ser_re.py -c configs/vqa/re/layoutxlm.yml -o Architecture.Backbone.checkpoints=pretrain/re_LayoutXLM_xfun_zh/ Global.infer_img=doc/vqa/input/zh_val_21.jpg -c_ser configs/vqa/ser/layoutxlm.yml -o_ser Architecture.Backbone.checkpoints=pretrain/ser_LayoutXLM_xfun_zh/
python3 tools/infer_vqa_token_ser_re.py -c configs/vqa/re/layoutxlm.yml -o Architecture.Backbone.checkpoints=pretrain/re_LayoutXLM_xfun_zh/ Global.infer_img=ppstructure/docs/vqa/input/zh_val_21.jpg -c_ser configs/vqa/ser/layoutxlm.yml -o_ser Architecture.Backbone.checkpoints=pretrain/ser_LayoutXLM_xfun_zh/
````
Finally, the prediction result visualization image and the prediction result text file will be saved in the directory configured by the `config.Global.save_res_path` field. The prediction result text file is named `infer_results.txt`.
* Export model
coming soon
* `OCR + SER + RE` tandem prediction based on the prediction engine
coming soon
## 6. Reference Links
- LayoutXLM: Multimodal Pre-training for Multilingual Visually-rich Document Understanding, https://arxiv.org/pdf/2104.08836.pdf
......
[English](README.md) | Simplified Chinese
- [Document Visual Question Answering (DOC-VQA)](#文档视觉问答doc-vqa)
  - [1. Introduction](#1-简介)
  - [2. Performance](#2-性能)
  - [3. Effect demo](#3-效果演示)
    - [3.1 SER](#31-ser)
    - [3.2 RE](#32-re)
  - [4. Installation](#4-安装)
    - [4.1 Install dependencies](#41-安装依赖)
    - [4.2 Install PaddleOCR (including PP-OCR and VQA)](#42-安装paddleocr包含-pp-ocr-和-vqa)
  - [5. Usage](#5-使用)
    - [5.1 Data and pretrained model preparation](#51-数据和预训练模型准备)
    - [5.2 SER](#52-ser)
    - [5.3 RE](#53-re)
  - [6. Reference links](#6-参考链接)
- [1. Introduction](#1-简介)
- [2. Performance](#2-性能)
- [3. Effect demo](#3-效果演示)
  - [3.1 SER](#31-ser)
  - [3.2 RE](#32-re)
- [4. Installation](#4-安装)
  - [4.1 Install dependencies](#41-安装依赖)
  - [4.2 Install PaddleOCR (including PP-OCR and VQA)](#42-安装paddleocr包含-pp-ocr-和-vqa)
- [5. Usage](#5-使用)
  - [5.1 Data and pretrained model preparation](#51-数据和预训练模型准备)
  - [5.2 SER](#52-ser)
  - [5.3 RE](#53-re)
- [6. Reference links](#6-参考链接)
- [License](#license)
# Document Visual Question Answering (DOC-VQA)
......@@ -122,13 +122,13 @@ python3 -m pip install -r ppstructure/vqa/requirements.txt
* Download the processed dataset
Download address of the processed XFUND Chinese dataset: [https://paddleocr.bj.bcebos.com/dataset/XFUND.tar](https://paddleocr.bj.bcebos.com/dataset/XFUND.tar)
Download address of the processed XFUND Chinese dataset: [link](https://paddleocr.bj.bcebos.com/ppstructure/dataset/XFUND.tar)
Download and unzip the dataset, then place it in the current directory.
```shell
wget https://paddleocr.bj.bcebos.com/dataset/XFUND.tar
wget https://paddleocr.bj.bcebos.com/ppstructure/dataset/XFUND.tar
```
* Convert the dataset
......@@ -183,16 +183,16 @@ CUDA_VISIBLE_DEVICES=0 python3 tools/eval.py -c configs/vqa/ser/layoutxlm.yml -o
```
Finally, metrics such as `precision`, `recall`, and `hmean` will be printed.
* Tandem prediction with `OCR engine + SER`
* `OCR + SER` tandem prediction based on the training engine
Use the following command to run the tandem prediction of `OCR engine + SER`, taking the pretrained SER model as an example:
Use the following command to run the training-engine-based tandem prediction of `OCR + SER`, taking the LayoutXLM-based SER model as an example:
```shell
CUDA_VISIBLE_DEVICES=0 python3 tools/infer_vqa_token_ser.py -c configs/vqa/ser/layoutxlm.yml -o Architecture.Backbone.checkpoints=pretrain/ser_LayoutXLM_xfun_zh/ Global.infer_img=doc/vqa/input/zh_val_42.jpg
```
Finally, the visualized prediction images and the prediction result text file will be saved under the directory configured by the `config.Global.save_res_path` field; the result text file is named `infer_results.txt`.
* End-to-end evaluation of the `OCR engine + SER` prediction system
* End-to-end evaluation of the `OCR + SER` prediction system
First use the `tools/infer_vqa_token_ser.py` script to run prediction on the dataset, then evaluate with the following command.
......@@ -200,6 +200,24 @@ CUDA_VISIBLE_DEVICES=0 python3 tools/infer_vqa_token_ser.py -c configs/vqa/ser/l
export CUDA_VISIBLE_DEVICES=0
python3 tools/eval_with_label_end2end.py --gt_json_path XFUND/zh_val/xfun_normalize_val.json --pred_json_path output_res/infer_results.txt
```
* Model export
Use the following command to export the SER model, taking the LayoutXLM-based SER model as an example:
```shell
python3.7 tools/export_model.py -c configs/vqa/ser/layoutxlm.yml -o Architecture.Backbone.checkpoints=pretrain/ser_LayoutXLM_xfun_zh/ Global.save_inference_dir=output/ser/infer
```
The converted model will be stored in the directory specified by the `Global.save_inference_dir` field.
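A typical export layout under `Global.save_inference_dir` (a sketch; standard Paddle inference exports produce these three files):
```
output/ser/infer/
├── inference.pdiparams
├── inference.pdiparams.info
└── inference.pdmodel
```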
* `OCR + SER` tandem prediction based on the prediction engine
Use the following command to complete the `OCR + SER` tandem prediction based on the prediction engine, taking the LayoutXLM-based SER model as an example:
```shell
cd ppstructure
CUDA_VISIBLE_DEVICES=0 python3.7 vqa/predict_vqa_token_ser.py --vqa_algorithm=LayoutXLM --ser_model_dir=../output/ser/infer --ser_dict_path=../train_data/XFUND/class_list_xfun.txt --image_dir=docs/vqa/input/zh_val_42.jpg --output=output
```
After the prediction succeeds, the visualized images and results will be saved in the directory specified by the `output` field.
### 5.3 RE
......@@ -236,16 +254,24 @@ CUDA_VISIBLE_DEVICES=0 python3 tools/eval.py -c configs/vqa/re/layoutxlm.yml -o
```
Finally, metrics such as `precision`, `recall` and `hmean` will be printed.
* Tandem prediction with `OCR engine + SER + RE`
* `OCR + SER + RE` tandem prediction based on the training engine
Use the following command to complete the tandem prediction of `OCR engine + SER + RE`, taking the pretrained SER and RE models as an example:
Use the following command to complete the `OCR + SER + RE` tandem prediction based on the training engine, taking the LayoutXLM-based SER and RE models as an example:
```shell
export CUDA_VISIBLE_DEVICES=0
python3 tools/infer_vqa_token_ser_re.py -c configs/vqa/re/layoutxlm.yml -o Architecture.Backbone.checkpoints=pretrain/re_LayoutXLM_xfun_zh/ Global.infer_img=doc/vqa/input/zh_val_21.jpg -c_ser configs/vqa/ser/layoutxlm.yml -o_ser Architecture.Backbone.checkpoints=pretrain/ser_LayoutXLM_xfun_zh/
python3 tools/infer_vqa_token_ser_re.py -c configs/vqa/re/layoutxlm.yml -o Architecture.Backbone.checkpoints=pretrain/re_LayoutXLM_xfun_zh/ Global.infer_img=ppstructure/docs/vqa/input/zh_val_21.jpg -c_ser configs/vqa/ser/layoutxlm.yml -o_ser Architecture.Backbone.checkpoints=pretrain/ser_LayoutXLM_xfun_zh/
```
Finally, the prediction result visualization image and the prediction result text file will be saved in the directory configured by the `config.Global.save_res_path` field. The prediction result text file is named `infer_results.txt`.
* Model export
coming soon
* `OCR + SER + RE` tandem prediction based on the prediction engine
coming soon
## 6. Reference Links
- LayoutXLM: Multimodal Pre-training for Multilingual Visually-rich Document Understanding, https://arxiv.org/pdf/2104.08836.pdf
......
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import sys
__dir__ = os.path.dirname(os.path.abspath(__file__))
sys.path.append(__dir__)
sys.path.insert(0, os.path.abspath(os.path.join(__dir__, '../..')))
os.environ["FLAGS_allocator_strategy"] = 'auto_growth'
import cv2
import json
import numpy as np
import time
import tools.infer.utility as utility
from ppocr.data import create_operators, transform
from ppocr.postprocess import build_post_process
from ppocr.utils.logging import get_logger
from ppocr.utils.visual import draw_ser_results
from ppocr.utils.utility import get_image_file_list, check_and_read_gif
from ppstructure.utility import parse_args
from paddleocr import PaddleOCR
logger = get_logger()
class SerPredictor(object):
def __init__(self, args):
self.ocr_engine = PaddleOCR(use_angle_cls=False, show_log=False)
pre_process_list = [{
'VQATokenLabelEncode': {
'algorithm': args.vqa_algorithm,
'class_path': args.ser_dict_path,
'contains_re': False,
'ocr_engine': self.ocr_engine
}
}, {
'VQATokenPad': {
'max_seq_len': 512,
'return_attention_mask': True
}
}, {
'VQASerTokenChunk': {
'max_seq_len': 512,
'return_attention_mask': True
}
}, {
'Resize': {
'size': [224, 224]
}
}, {
'NormalizeImage': {
'std': [58.395, 57.12, 57.375],
'mean': [123.675, 116.28, 103.53],
'scale': '1',
'order': 'hwc'
}
}, {
'ToCHWImage': None
}, {
'KeepKeys': {
'keep_keys': [
'input_ids', 'bbox', 'attention_mask', 'token_type_ids',
'image', 'labels', 'segment_offset_id', 'ocr_info',
'entities'
]
}
}]
postprocess_params = {
'name': 'VQASerTokenLayoutLMPostProcess',
"class_path": args.ser_dict_path,
}
self.preprocess_op = create_operators(pre_process_list,
{'infer_mode': True})
self.postprocess_op = build_post_process(postprocess_params)
self.predictor, self.input_tensor, self.output_tensors, self.config = \
utility.create_predictor(args, 'ser', logger)
def __call__(self, img):
ori_im = img.copy()
data = {'image': img}
data = transform(data, self.preprocess_op)
img = data[0]
if img is None:
return None, 0
img = np.expand_dims(img, axis=0)
img = img.copy()
starttime = time.time()
for idx in range(len(self.input_tensor)):
expand_input = np.expand_dims(data[idx], axis=0)
self.input_tensor[idx].copy_from_cpu(expand_input)
self.predictor.run()
outputs = []
for output_tensor in self.output_tensors:
output = output_tensor.copy_to_cpu()
outputs.append(output)
preds = outputs[0]
post_result = self.postprocess_op(
preds, segment_offset_ids=[data[6]], ocr_infos=[data[7]])
elapse = time.time() - starttime
return post_result, elapse
def main(args):
image_file_list = get_image_file_list(args.image_dir)
ser_predictor = SerPredictor(args)
count = 0
total_time = 0
os.makedirs(args.output, exist_ok=True)
with open(
os.path.join(args.output, 'infer.txt'), mode='w',
encoding='utf-8') as f_w:
for image_file in image_file_list:
img, flag = check_and_read_gif(image_file)
if not flag:
    img = cv2.imread(image_file)
if img is None:
    logger.info("error in loading image:{}".format(image_file))
    continue
# cv2 loads BGR; convert to RGB before prediction (must come after the None check)
img = img[:, :, ::-1]
ser_res, elapse = ser_predictor(img)
ser_res = ser_res[0]
res_str = '{}\t{}\n'.format(
image_file,
json.dumps(
{
"ocr_info": ser_res,
}, ensure_ascii=False))
f_w.write(res_str)
img_res = draw_ser_results(
image_file,
ser_res,
font_path="../doc/fonts/simfang.ttf", )
img_save_path = os.path.join(args.output,
os.path.basename(image_file))
cv2.imwrite(img_save_path, img_res)
logger.info("save vis result to {}".format(img_save_path))
if count > 0:
total_time += elapse
count += 1
logger.info("Predict time of {}: {}".format(image_file, elapse))
if __name__ == "__main__":
main(parse_args())
sentencepiece
yacs
seqeval
paddlenlp>=2.2.1
\ No newline at end of file
paddlenlp>=2.2.1
pypandoc
attrdict
python_docx
\ No newline at end of file
# copyright (c) 2022 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
import os
import sys
import cv2
import numpy as np
from copy import deepcopy
def trans_poly_to_bbox(poly):
x1 = np.min([p[0] for p in poly])
x2 = np.max([p[0] for p in poly])
y1 = np.min([p[1] for p in poly])
y2 = np.max([p[1] for p in poly])
return [x1, y1, x2, y2]
def get_outer_poly(bbox_list):
x1 = min([bbox[0] for bbox in bbox_list])
y1 = min([bbox[1] for bbox in bbox_list])
x2 = max([bbox[2] for bbox in bbox_list])
y2 = max([bbox[3] for bbox in bbox_list])
return [[x1, y1], [x2, y1], [x2, y2], [x1, y2]]
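# Illustrative examples of the two helpers above (assumed coordinates):
#   trans_poly_to_bbox([[10, 20], [50, 20], [50, 60], [10, 60]]) -> [10, 20, 50, 60]
#   get_outer_poly([[10, 20, 50, 60], [40, 10, 80, 50]]) -> [[10, 10], [80, 10], [80, 60], [10, 60]]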
def load_funsd_label(image_dir, anno_dir):
imgs = os.listdir(image_dir)
annos = os.listdir(anno_dir)
imgs = [img.replace(".png", "") for img in imgs]
annos = [anno.replace(".json", "") for anno in annos]
fn_info_map = dict()
for anno_fn in annos:
res = []
with open(os.path.join(anno_dir, anno_fn + ".json"), "r") as fin:
infos = json.load(fin)
infos = infos["form"]
old_id2new_id_map = dict()
global_new_id = 0
for info in infos:
if info["text"] is None:
continue
words = info["words"]
if len(words) <= 0:
continue
word_idx = 1
curr_bboxes = [words[0]["box"]]
curr_texts = [words[0]["text"]]
while word_idx < len(words):
# switch to a new link
if words[word_idx]["box"][0] + 10 <= words[word_idx - 1][
"box"][2]:
if len("".join(curr_texts[0])) > 0:
res.append({
"transcription": " ".join(curr_texts),
"label": info["label"],
"points": get_outer_poly(curr_bboxes),
"linking": info["linking"],
"id": global_new_id,
})
if info["id"] not in old_id2new_id_map:
old_id2new_id_map[info["id"]] = []
old_id2new_id_map[info["id"]].append(global_new_id)
global_new_id += 1
curr_bboxes = [words[word_idx]["box"]]
curr_texts = [words[word_idx]["text"]]
else:
curr_bboxes.append(words[word_idx]["box"])
curr_texts.append(words[word_idx]["text"])
word_idx += 1
if len("".join(curr_texts[0])) > 0:
res.append({
"transcription": " ".join(curr_texts),
"label": info["label"],
"points": get_outer_poly(curr_bboxes),
"linking": info["linking"],
"id": global_new_id,
})
if info["id"] not in old_id2new_id_map:
old_id2new_id_map[info["id"]] = []
old_id2new_id_map[info["id"]].append(global_new_id)
global_new_id += 1
res = sorted(
res, key=lambda r: (r["points"][0][1], r["points"][0][0]))
for i in range(len(res) - 1):
for j in range(i, 0, -1):
if abs(res[j + 1]["points"][0][1] - res[j]["points"][0][1]) < 20 and \
(res[j + 1]["points"][0][0] < res[j]["points"][0][0]):
tmp = deepcopy(res[j])
res[j] = deepcopy(res[j + 1])
res[j + 1] = deepcopy(tmp)
else:
break
# re-generate unique ids
for idx, r in enumerate(res):
new_links = []
for link in r["linking"]:
# illegal links will be removed
if link[0] not in old_id2new_id_map or link[
1] not in old_id2new_id_map:
continue
for src in old_id2new_id_map[link[0]]:
for dst in old_id2new_id_map[link[1]]:
new_links.append([src, dst])
res[idx]["linking"] = deepcopy(new_links)
fn_info_map[anno_fn] = res
return fn_info_map
def main():
test_image_dir = "train_data/FUNSD/testing_data/images/"
test_anno_dir = "train_data/FUNSD/testing_data/annotations/"
test_output_dir = "train_data/FUNSD/test.json"
fn_info_map = load_funsd_label(test_image_dir, test_anno_dir)
with open(test_output_dir, "w") as fout:
for fn in fn_info_map:
fout.write(fn + ".png" + "\t" + json.dumps(
fn_info_map[fn], ensure_ascii=False) + "\n")
train_image_dir = "train_data/FUNSD/training_data/images/"
train_anno_dir = "train_data/FUNSD/training_data/annotations/"
train_output_dir = "train_data/FUNSD/train.json"
fn_info_map = load_funsd_label(train_image_dir, train_anno_dir)
with open(train_output_dir, "w") as fout:
for fn in fn_info_map:
fout.write(fn + ".png" + "\t" + json.dumps(
fn_info_map[fn], ensure_ascii=False) + "\n")
print("====ok====")
return
if __name__ == "__main__":
main()
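# Each line of the generated train.json / test.json pairs an image name with
# its annotation list, e.g. (illustrative file name and values):
#   example.png	[{"transcription": "...", "label": "question", "points": [[x1, y1], ...], "linking": [[0, 1]], "id": 0}]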
......@@ -21,26 +21,22 @@ def transfer_xfun_data(json_path=None, output_file=None):
json_info = json.loads(lines[0])
documents = json_info["documents"]
label_info = {}
with open(output_file, "w", encoding='utf-8') as fout:
for idx, document in enumerate(documents):
label_info = []
img_info = document["img"]
document = document["document"]
image_path = img_info["fname"]
label_info["height"] = img_info["height"]
label_info["width"] = img_info["width"]
label_info["ocr_info"] = []
for doc in document:
label_info["ocr_info"].append({
"text": doc["text"],
x1, y1, x2, y2 = doc["box"]
points = [[x1, y1], [x2, y1], [x2, y2], [x1, y2]]
label_info.append({
"transcription": doc["text"],
"label": doc["label"],
"bbox": doc["box"],
"points": points,
"id": doc["id"],
"linking": doc["linking"],
"words": doc["words"]
"linking": doc["linking"]
})
fout.write(image_path + "\t" + json.dumps(
......
......@@ -21,6 +21,18 @@ function func_parser_params(){
echo ${tmp}
}
function set_dynamic_epoch(){
string=$1
num=$2
_str=${string:1:6}
IFS="C"
arr=(${_str})
M=${arr[0]}
P=${arr[1]}
ep=`expr $num \* $M \* $P`
echo $ep
}
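# Illustrative trace of set_dynamic_epoch (assumed input): set_dynamic_epoch "N1C8" 2
#   _str="1C8"; splitting on "C" gives M=1 (machines) and P=8 (GPUs per machine),
#   so ep = 2 * 1 * 8 = 16, i.e. the epoch count scales with the total device count.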
function func_sed_params(){
filename=$1
line=$2
......@@ -139,10 +151,11 @@ else
device_num=${params_list[4]}
IFS=";"
if [ ${precision} = "null" ];then
precision="fp32"
if [ ${precision} = "fp16" ];then
precision="amp"
fi
epoch=$(set_dynamic_epoch $device_num $epoch)
fp_items_list=($precision)
batch_size_list=($batch_size)
device_num_list=($device_num)
......@@ -150,10 +163,16 @@ fi
IFS="|"
for batch_size in ${batch_size_list[*]}; do
for precision in ${fp_items_list[*]}; do
for train_precision in ${fp_items_list[*]}; do
for device_num in ${device_num_list[*]}; do
# sed batchsize and precision
func_sed_params "$FILENAME" "${line_precision}" "$precision"
if [ ${train_precision} = "amp" ];then
precision="fp16"
else
precision="fp32"
fi
func_sed_params "$FILENAME" "${line_precision}" "$train_precision"
func_sed_params "$FILENAME" "${line_batchsize}" "$MODE=$batch_size"
func_sed_params "$FILENAME" "${line_epoch}" "$MODE=$epoch"
gpu_id=$(set_gpu_id $device_num)
......
......@@ -54,6 +54,6 @@ random_infer_input:[{float32,[3,640,640]}];[{float32,[3,960,960]}]
===========================train_benchmark_params==========================
batch_size:8|16
fp_items:fp32|fp16
epoch:2
epoch:15
--profiler_options:batch_range=[10,20];state=GPU;tracer_option=Default;profile_path=model.profile
flags:FLAGS_eager_delete_tensor_gb=0.0;FLAGS_fraction_of_gpu_memory_to_use=0.98;FLAGS_conv_workspace_size_limit=4096
......@@ -54,5 +54,5 @@ random_infer_input:[{float32,[3,640,640]}];[{float32,[3,960,960]}]
===========================train_benchmark_params==========================
batch_size:8|16
fp_items:fp32|fp16
epoch:2
epoch:15
--profiler_options:batch_range=[10,20];state=GPU;tracer_option=Default;profile_path=model.profile
===========================train_params===========================
model_name:det_r50_db_plusplus
python:python3.7
gpu_list:0|0,1
Global.use_gpu:True|True
Global.auto_cast:null
Global.epoch_num:lite_train_lite_infer=1|whole_train_whole_infer=300
Global.save_model_dir:./output/
Train.loader.batch_size_per_card:lite_train_lite_infer=2|whole_train_whole_infer=4
Global.pretrained_model:null
train_model_name:latest
train_infer_img_dir:./train_data/icdar2015/text_localization/ch4_test_images/
null:null
##
trainer:norm_train
norm_train:tools/train.py -c configs/det/det_r50_db++_icdar15.yml -o Global.pretrained_model=./pretrain_models/ResNet50_dcn_asf_synthtext_pretrained
pact_train:null
fpgm_train:null
distill_train:null
null:null
null:null
##
===========================eval_params===========================
eval:null
null:null
##
===========================infer_params===========================
Global.save_inference_dir:./output/
Global.checkpoints:
norm_export:tools/export_model.py -c configs/det/det_r50_db++_icdar15.yml -o
quant_export:null
fpgm_export:null
distill_export:null
export1:null
export2:null
inference_dir:null
train_model:./inference/det_r50_db++_train/best_accuracy
infer_export:tools/export_model.py -c configs/det/det_r50_db++_icdar15.yml -o
infer_quant:False
inference:tools/infer/predict_det.py --det_algorithm="DB++"
--use_gpu:True|False
--enable_mkldnn:False
--cpu_threads:6
--rec_batch_num:1
--use_tensorrt:False
--precision:fp32
--det_model_dir:
--image_dir:./inference/ch_det_data_50/all-sum-510/
null:null
--benchmark:True
null:null
===========================infer_benchmark_params==========================
random_infer_input:[{float32,[3,640,640]}];[{float32,[3,960,960]}]
===========================train_params===========================
model_name:det_r50_vd_east_v2_0
python:python3.7
gpu_list:0
gpu_list:0|0,1
Global.use_gpu:True|True
Global.auto_cast:fp32
Global.epoch_num:lite_train_lite_infer=1|whole_train_whole_infer=500
Global.save_model_dir:./output/
Train.loader.batch_size_per_card:lite_train_lite_infer=2|whole_train_whole_infer=4
Global.pretrained_model:null
Global.pretrained_model:./pretrain_models/det_r50_vd_east_v2.0_train/best_accuracy
train_model_name:latest
train_infer_img_dir:./train_data/icdar2015/text_localization/ch4_test_images/
null:null
......
......@@ -54,5 +54,5 @@ random_infer_input:[{float32,[3,640,640]}];[{float32,[3,960,960]}]
===========================train_benchmark_params==========================
batch_size:8
fp_items:fp32|fp16
epoch:2
epoch:10
--profiler_options:batch_range=[10,20];state=GPU;tracer_option=Default;profile_path=model.profile
......@@ -9,16 +9,15 @@ Global:
eval_batch_step: [0, 400]
cal_metric_during_train: True
pretrained_model:
checkpoints:
checkpoints:
save_inference_dir:
use_visualdl: False
infer_img: doc/table/table.jpg
infer_img: ppstructure/docs/table/table.jpg
save_res_path: output/table_mv3
# for data or label process
character_dict_path: ppocr/utils/dict/table_structure_dict.txt
character_type: en
max_text_length: 100
max_elem_length: 800
max_cell_num: 500
max_text_length: 800
infer_mode: False
process_total_num: 0
process_cut_num: 0
......@@ -44,11 +43,8 @@ Architecture:
Head:
name: TableAttentionHead
hidden_size: 256
l2_decay: 0.00001
loc_type: 2
max_text_length: 100
max_elem_length: 800
max_cell_num: 500
max_text_length: 800
Loss:
name: TableAttentionLoss
......@@ -61,28 +57,34 @@ PostProcess:
Metric:
name: TableMetric
main_indicator: acc
compute_bbox_metric: false # cost many time, set False for training
Train:
dataset:
name: PubTabDataSet
data_dir: ./train_data/pubtabnet/train
label_file_path: ./train_data/pubtabnet/train.jsonl
label_file_list: [./train_data/pubtabnet/train.jsonl]
transforms:
- DecodeImage: # load image
img_mode: BGR
channel_first: False
- TableLabelEncode:
learn_empty_box: False
merge_no_span_structure: False
replace_empty_cell_token: False
- TableBoxEncode:
- ResizeTableImage:
max_len: 488
- TableLabelEncode:
- NormalizeImage:
scale: 1./255.
mean: [0.485, 0.456, 0.406]
std: [0.229, 0.224, 0.225]
order: 'hwc'
- PaddingTableImage:
size: [488, 488]
- ToCHWImage:
- KeepKeys:
keep_keys: ['image', 'structure', 'bbox_list', 'sp_tokens', 'bbox_list_mask']
keep_keys: [ 'image', 'structure', 'bboxes', 'bbox_masks', 'shape' ]
loader:
shuffle: True
batch_size_per_card: 32
......@@ -93,23 +95,28 @@ Eval:
dataset:
name: PubTabDataSet
data_dir: ./train_data/pubtabnet/test/
label_file_path: ./train_data/pubtabnet/test.jsonl
label_file_list: [./train_data/pubtabnet/test.jsonl]
transforms:
- DecodeImage: # load image
img_mode: BGR
channel_first: False
- TableLabelEncode:
learn_empty_box: False
merge_no_span_structure: False
replace_empty_cell_token: False
- TableBoxEncode:
- ResizeTableImage:
max_len: 488
- TableLabelEncode:
- NormalizeImage:
scale: 1./255.
mean: [0.485, 0.456, 0.406]
std: [0.229, 0.224, 0.225]
order: 'hwc'
- PaddingTableImage:
size: [488, 488]
- ToCHWImage:
- KeepKeys:
keep_keys: ['image', 'structure', 'bbox_list', 'sp_tokens', 'bbox_list_mask']
keep_keys: [ 'image', 'structure', 'bboxes', 'bbox_masks', 'shape' ]
loader:
shuffle: False
drop_last: False
......
......@@ -6,7 +6,7 @@ Global.use_gpu:True|True
Global.auto_cast:fp32
Global.epoch_num:lite_train_lite_infer=1|whole_train_whole_infer=50
Global.save_model_dir:./output/
Train.loader.batch_size_per_card:lite_train_lite_infer=16|whole_train_whole_infer=128
Train.loader.batch_size_per_card:lite_train_lite_infer=2|whole_train_whole_infer=2
Global.pretrained_model:./pretrain_models/en_ppocr_mobile_v2.0_table_structure_train/best_accuracy
train_model_name:latest
train_infer_img_dir:./ppstructure/docs/table/table.jpg
......
Global:
use_gpu: true
epoch_num: 17
log_smooth_window: 20
print_batch_step: 100
save_model_dir: ./output/table_master/
save_epoch_step: 17
eval_batch_step: [0, 6259]
cal_metric_during_train: true
pretrained_model: null
checkpoints:
save_inference_dir: output/table_master/infer
use_visualdl: false
infer_img: ppstructure/docs/table/table.jpg
save_res_path: ./output/table_master
character_dict_path: ppocr/utils/dict/table_master_structure_dict.txt
infer_mode: false
max_text_length: 500
process_total_num: 0
process_cut_num: 0
Optimizer:
name: Adam
beta1: 0.9
beta2: 0.999
lr:
name: MultiStepDecay
learning_rate: 0.001
milestones: [12, 15]
gamma: 0.1
warmup_epoch: 0.02
regularizer:
name: L2
factor: 0.0
Architecture:
model_type: table
algorithm: TableMaster
Backbone:
name: TableResNetExtra
gcb_config:
ratio: 0.0625
headers: 1
att_scale: False
fusion_type: channel_add
layers: [False, True, True, True]
layers: [1,2,5,3]
Head:
name: TableMasterHead
hidden_size: 512
headers: 8
dropout: 0
d_ff: 2024
max_text_length: 500
Loss:
name: TableMasterLoss
ignore_index: 42 # set to len of dict + 3
PostProcess:
name: TableMasterLabelDecode
box_shape: pad
Metric:
name: TableMetric
main_indicator: acc
compute_bbox_metric: False
Train:
dataset:
name: PubTabDataSet
data_dir: ./train_data/pubtabnet/train
label_file_list: [./train_data/pubtabnet/train.jsonl]
transforms:
- DecodeImage:
img_mode: BGR
channel_first: False
- TableMasterLabelEncode:
learn_empty_box: False
merge_no_span_structure: True
replace_empty_cell_token: True
- ResizeTableImage:
max_len: 480
resize_bboxes: True
- PaddingTableImage:
size: [480, 480]
- TableBoxEncode:
use_xywh: True
- NormalizeImage:
scale: 1./255.
mean: [0.5, 0.5, 0.5]
std: [0.5, 0.5, 0.5]
order: hwc
- ToCHWImage: null
- KeepKeys:
keep_keys: [image, structure, bboxes, bbox_masks, shape]
loader:
shuffle: True
batch_size_per_card: 10
drop_last: True
num_workers: 8
Eval:
dataset:
name: PubTabDataSet
data_dir: ./train_data/pubtabnet/test/
label_file_list: [./train_data/pubtabnet/test.jsonl]
transforms:
- DecodeImage:
img_mode: BGR
channel_first: False
- TableMasterLabelEncode:
learn_empty_box: False
merge_no_span_structure: True
replace_empty_cell_token: True
- ResizeTableImage:
max_len: 480
resize_bboxes: True
- PaddingTableImage:
size: [480, 480]
- TableBoxEncode:
use_xywh: True
- NormalizeImage:
scale: 1./255.
mean: [0.5, 0.5, 0.5]
std: [0.5, 0.5, 0.5]
order: hwc
- ToCHWImage: null
- KeepKeys:
keep_keys: [image, structure, bboxes, bbox_masks, shape]
loader:
shuffle: False
drop_last: False
batch_size_per_card: 10
num_workers: 8
\ No newline at end of file
===========================train_params===========================
model_name:table_master
python:python3.7
gpu_list:0|0,1
Global.use_gpu:True|True
Global.auto_cast:fp32
Global.epoch_num:lite_train_lite_infer=1|whole_train_whole_infer=17
Global.save_model_dir:./output/
Train.loader.batch_size_per_card:lite_train_lite_infer=2|whole_train_whole_infer=4
Global.pretrained_model:./pretrain_models/table_structure_tablemaster_train/best_accuracy
train_model_name:latest
train_infer_img_dir:./ppstructure/docs/table/table.jpg
null:null
##
trainer:norm_train
norm_train:tools/train.py -c test_tipc/configs/table_master/table_master.yml -o Global.print_batch_step=10
pact_train:null
fpgm_train:null
distill_train:null
null:null
null:null
##
===========================eval_params===========================
eval:null
null:null
##
===========================infer_params===========================
Global.save_inference_dir:./output/
Global.checkpoints:
norm_export:tools/export_model.py -c test_tipc/configs/table_master/table_master.yml -o
quant_export:
fpgm_export:
distill_export:null
export1:null
export2:null
##
infer_model:null
infer_export:null
infer_quant:False
inference:ppstructure/table/predict_structure.py --table_char_dict_path=./ppocr/utils/dict/table_master_structure_dict.txt --image_dir=./ppstructure/docs/table/table.jpg --output ./output/table --table_algorithm=TableMaster --table_max_len=480
--use_gpu:True|False
--enable_mkldnn:False
--cpu_threads:6
--rec_batch_num:1
--use_tensorrt:False
--precision:fp32
--table_model_dir:
--image_dir:./ppstructure/docs/table/table.jpg
null:null
--benchmark:False
null:null
===========================infer_benchmark_params==========================
random_infer_input:[{float32,[3,480,480]}]
......@@ -9,7 +9,7 @@
```shell
# Usage: bash test_tipc/prepare.sh train_benchmark.txt mode
bash test_tipc/prepare.sh test_tipc/configs/det_mv3_db_v2_0/train_benchmark.txt benchmark_train
bash test_tipc/prepare.sh test_tipc/configs/det_mv3_db_v2_0/train_infer_python.txt benchmark_train
```
## 1.2 Function test
......@@ -33,7 +33,7 @@ dynamic_bs8_fp32_DP_N1C1 is the parameter passed to test_tipc/benchmark_train.sh, in the format
## 2. Log output
After running, the model's training log and the parsed log will be saved. The parsed training-log result when using the `test_tipc/configs/det_mv3_db_v2_0/train_benchmark.txt` parameter file is:
After running, the model's training log and the parsed log will be saved. The parsed training-log result when using the `test_tipc/configs/det_mv3_db_v2_0/train_infer_python.txt` parameter file is:
```
{"model_branch": "dygaph", "model_commit": "7c39a1996b19087737c05d883fd346d2f39dbcc0", "model_name": "det_mv3_db_v2_0_bs8_fp32_SingleP_DP", "batch_size": 8, "fp_item": "fp32", "run_process_type": "SingleP", "run_mode": "DP", "convergence_value": "5.413110", "convergence_key": "loss:", "ips": 19.333, "speed_unit": "samples/s", "device_num": "N1C1", "model_run_time": "0", "frame_commit": "8cc09552473b842c651ead3b9848d41827a3dbab", "frame_version": "0.0.0"}
......
......@@ -22,13 +22,19 @@ trainer_list=$(func_parser_value "${lines[14]}")
if [ ${MODE} = "benchmark_train" ];then
pip install -r requirements.txt
if [[ ${model_name} =~ "det_mv3_db_v2_0" || ${model_name} =~ "det_r50_vd_east_v2_0" || ${model_name} =~ "det_r50_vd_pse_v2_0" || ${model_name} =~ "det_r18_db_v2_0" ]];then
if [[ ${model_name} =~ "det_mv3_db_v2_0" || ${model_name} =~ "det_r50_vd_pse_v2_0" || ${model_name} =~ "det_r18_db_v2_0" ]];then
rm -rf ./train_data/icdar2015
wget -nc -P ./pretrain_models/ https://paddleocr.bj.bcebos.com/pretrained/MobileNetV3_large_x0_5_pretrained.pdparams --no-check-certificate
wget -nc -P ./train_data/ https://paddleocr.bj.bcebos.com/dygraph_v2.0/test/icdar2015.tar --no-check-certificate
cd ./train_data/ && tar xf icdar2015.tar && cd ../
fi
if [[ ${model_name} =~ "det_r50_vd_east_v2_0" || ${model_name} =~ "det_r50_vd_pse_v2_0" ]];then
if [[ ${model_name} =~ "det_r50_vd_east_v2_0" ]]; then
wget -nc -P ./pretrain_models/ https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/det_r50_vd_east_v2.0_train.tar --no-check-certificate
cd ./pretrain_models/ && tar xf det_r50_vd_east_v2.0_train.tar && cd ../
wget -nc -P ./train_data/ https://paddleocr.bj.bcebos.com/dygraph_v2.0/test/icdar2015.tar --no-check-certificate
cd ./train_data/ && tar xf icdar2015.tar && cd ../
fi
if [[ ${model_name} =~ "det_r50_vd_pse_v2_0" ]];then
wget -nc -P ./pretrain_models/ https://paddleocr.bj.bcebos.com/pretrained/ResNet50_vd_ssld_pretrained.pdparams --no-check-certificate
wget -nc -P ./train_data/ https://paddleocr.bj.bcebos.com/dygraph_v2.0/test/icdar2015.tar --no-check-certificate
cd ./train_data/ && tar xf icdar2015.tar && cd ../
......@@ -52,13 +58,20 @@ if [ ${MODE} = "lite_train_lite_infer" ];then
wget -nc -P ./pretrain_models/ https://paddleocr.bj.bcebos.com/PP-OCRv3/chinese/ch_PP-OCRv3_det_distill_train.tar --no-check-certificate
cd ./pretrain_models/ && tar xf ch_PP-OCRv3_det_distill_train.tar && cd ../
fi
if [ ${model_name} == "en_table_structure" ];then
if [ ${model_name} == "en_table_structure" ] || [ ${model_name} == "en_table_structure_PACT" ];then
wget -nc -P ./pretrain_models/ https://paddleocr.bj.bcebos.com/dygraph_v2.1/table/en_ppocr_mobile_v2.0_table_structure_train.tar --no-check-certificate
cd ./pretrain_models/ && tar xf en_ppocr_mobile_v2.0_table_structure_train.tar && cd ../
wget -nc -P ./inference/ https://paddleocr.bj.bcebos.com/dygraph_v2.0/table/en_ppocr_mobile_v2.0_table_det_infer.tar --no-check-certificate
wget -nc -P ./inference/ https://paddleocr.bj.bcebos.com/dygraph_v2.0/table/en_ppocr_mobile_v2.0_table_rec_infer.tar --no-check-certificate
cd ./inference/ && tar xf en_ppocr_mobile_v2.0_table_det_infer.tar && tar xf en_ppocr_mobile_v2.0_table_rec_infer.tar && cd ../
fi
if [[ ${model_name} =~ "det_r50_db_plusplus" ]];then
wget -nc -P ./pretrain_models/ https://paddleocr.bj.bcebos.com/dygraph_v2.1/en_det/ResNet50_dcn_asf_synthtext_pretrained.pdparams --no-check-certificate
fi
if [ ${model_name} == "table_master" ];then
wget -nc -P ./pretrain_models/ https://paddleocr.bj.bcebos.com/ppstructure/models/tablemaster/table_structure_tablemaster_train.tar --no-check-certificate
cd ./pretrain_models/ && tar xf table_structure_tablemaster_train.tar && cd ../
fi
cd ./pretrain_models/ && tar xf det_mv3_db_v2.0_train.tar && cd ../
rm -rf ./train_data/icdar2015
rm -rf ./train_data/ic15_data
......@@ -120,6 +133,10 @@ if [ ${MODE} = "lite_train_lite_infer" ];then
wget -nc -P ./pretrain_models/ https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/det_mv3_east_v2.0_train.tar --no-check-certificate
cd ./pretrain_models/ && tar xf det_mv3_east_v2.0_train.tar && cd ../
fi
if [ ${model_name} == "det_r50_vd_east_v2_0" ]; then
wget -nc -P ./pretrain_models/ https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/det_r50_vd_east_v2.0_train.tar --no-check-certificate
cd ./pretrain_models/ && tar xf det_r50_vd_east_v2.0_train.tar && cd ../
fi
elif [ ${MODE} = "whole_train_whole_infer" ];then
wget -nc -P ./pretrain_models/ https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/MobileNetV3_large_x0_5_pretrained.pdparams --no-check-certificate
......
......@@ -54,6 +54,7 @@
| NRTR |rec_mtb_nrtr | Recognition | Supported | multi-node multi-GPU <br> mixed precision | - | - |
| SAR |rec_r31_sar | Recognition | Supported | multi-node multi-GPU <br> mixed precision | - | - |
| PGNet |rec_r34_vd_none_none_ctc_v2.0 | End-to-end | Supported | multi-node multi-GPU <br> mixed precision | - | - |
| TableMaster |table_structure_tablemaster_train | Table recognition | Supported | multi-node multi-GPU <br> mixed precision | - | - |
......
......@@ -139,7 +139,7 @@ if [ ${MODE} = "whole_infer" ]; then
save_infer_dir="${infer_model}_klquant"
set_export_weight=$(func_set_params "${export_weight}" "${infer_model}")
set_save_infer_key=$(func_set_params "${save_infer_key}" "${save_infer_dir}")
export_log_path="${LOG_PATH}/_export_${Count}.log"
export_log_path="${LOG_PATH}/${MODE}_export_${Count}.log"
export_cmd="${python} ${infer_run_exports[Count]} ${set_export_weight} ${set_save_infer_key} > ${export_log_path} 2>&1 "
echo ${infer_run_exports[Count]}
echo $export_cmd
......
......@@ -87,11 +87,12 @@ function func_serving(){
set_image_dir=$(func_set_params "${image_dir_key}" "${image_dir_value}")
python_list=(${python_list})
cd ${serving_dir_value}
# cpp serving
for gpu_id in ${gpu_value[*]}; do
if [ ${gpu_id} = "null" ]; then
server_log_path="${LOG_PATH}/cpp_server_cpu.log"
web_service_cpp_cmd="${python_list[0]} ${web_service_py} --model ${det_server_value} ${rec_server_value} ${op_key} ${op_value} ${port_key} ${port_value} > ${server_log_path} 2>&1 "
web_service_cpp_cmd="nohup ${python_list[0]} ${web_service_py} --model ${det_server_value} ${rec_server_value} ${op_key} ${op_value} ${port_key} ${port_value} > ${server_log_path} 2>&1 &"
eval $web_service_cpp_cmd
last_status=${PIPESTATUS[0]}
status_check $last_status "${web_service_cpp_cmd}" "${status_log}" "${model_name}"
......@@ -105,7 +106,7 @@ function func_serving(){
ps ux | grep -i ${port_value} | awk '{print $2}' | xargs kill -s 9
else
server_log_path="${LOG_PATH}/cpp_server_gpu.log"
web_service_cpp_cmd="${python_list[0]} ${web_service_py} --model ${det_server_value} ${rec_server_value} ${op_key} ${op_value} ${port_key} ${port_value} ${gpu_key} ${gpu_id} > ${server_log_path} 2>&1 "
web_service_cpp_cmd="nohup ${python_list[0]} ${web_service_py} --model ${det_server_value} ${rec_server_value} ${op_key} ${op_value} ${port_key} ${port_value} ${gpu_key} ${gpu_id} > ${server_log_path} 2>&1 &"
eval $web_service_cpp_cmd
sleep 5s
_save_log_path="${LOG_PATH}/cpp_client_gpu.log"
......
......@@ -112,7 +112,7 @@ function func_serving(){
cd ${serving_dir_value}
python=${python_list[0]}
# python serving
for use_gpu in ${web_use_gpu_list[*]}; do
if [ ${use_gpu} = "null" ]; then
......@@ -123,19 +123,19 @@ function func_serving(){
if [ ${model_name} = "ch_PP-OCRv2" ] || [ ${model_name} = "ch_PP-OCRv3" ] || [ ${model_name} = "ch_ppocr_mobile_v2.0" ] || [ ${model_name} = "ch_ppocr_server_v2.0" ]; then
set_det_model_config=$(func_set_params "${det_server_key}" "${det_server_value}")
set_rec_model_config=$(func_set_params "${rec_server_key}" "${rec_server_value}")
web_service_cmd="${python} ${web_service_py} ${web_use_gpu_key}="" ${web_use_mkldnn_key}=${use_mkldnn} ${set_cpu_threads} ${set_det_model_config} ${set_rec_model_config} > ${server_log_path} 2>&1 "
web_service_cmd="nohup ${python} ${web_service_py} ${web_use_gpu_key}="" ${web_use_mkldnn_key}=${use_mkldnn} ${set_cpu_threads} ${set_det_model_config} ${set_rec_model_config} > ${server_log_path} 2>&1 &"
eval $web_service_cmd
last_status=${PIPESTATUS[0]}
status_check $last_status "${web_service_cmd}" "${status_log}" "${model_name}"
elif [[ ${model_name} =~ "det" ]]; then
set_det_model_config=$(func_set_params "${det_server_key}" "${det_server_value}")
web_service_cmd="${python} ${web_service_py} ${web_use_gpu_key}="" ${web_use_mkldnn_key}=${use_mkldnn} ${set_cpu_threads} ${set_det_model_config} > ${server_log_path} 2>&1 "
web_service_cmd="nohup ${python} ${web_service_py} ${web_use_gpu_key}="" ${web_use_mkldnn_key}=${use_mkldnn} ${set_cpu_threads} ${set_det_model_config} > ${server_log_path} 2>&1 &"
eval $web_service_cmd
last_status=${PIPESTATUS[0]}
status_check $last_status "${web_service_cmd}" "${status_log}" "${model_name}"
elif [[ ${model_name} =~ "rec" ]]; then
set_rec_model_config=$(func_set_params "${rec_server_key}" "${rec_server_value}")
web_service_cmd="${python} ${web_service_py} ${web_use_gpu_key}="" ${web_use_mkldnn_key}=${use_mkldnn} ${set_cpu_threads} ${set_rec_model_config} > ${server_log_path} 2>&1 "
web_service_cmd="nohup ${python} ${web_service_py} ${web_use_gpu_key}="" ${web_use_mkldnn_key}=${use_mkldnn} ${set_cpu_threads} ${set_rec_model_config} > ${server_log_path} 2>&1 &"
eval $web_service_cmd
last_status=${PIPESTATUS[0]}
status_check $last_status "${web_service_cmd}" "${status_log}" "${model_name}"
......@@ -174,19 +174,19 @@ function func_serving(){
if [ ${model_name} = "ch_PP-OCRv2" ] || [ ${model_name} = "ch_PP-OCRv3" ] || [ ${model_name} = "ch_ppocr_mobile_v2.0" ] || [ ${model_name} = "ch_ppocr_server_v2.0" ]; then
set_det_model_config=$(func_set_params "${det_server_key}" "${det_server_value}")
set_rec_model_config=$(func_set_params "${rec_server_key}" "${rec_server_value}")
web_service_cmd="${python} ${web_service_py} ${set_tensorrt} ${set_precision} ${set_det_model_config} ${set_rec_model_config} > ${server_log_path} 2>&1 "
web_service_cmd="nohup ${python} ${web_service_py} ${set_tensorrt} ${set_precision} ${set_det_model_config} ${set_rec_model_config} > ${server_log_path} 2>&1 &"
eval $web_service_cmd
last_status=${PIPESTATUS[0]}
status_check $last_status "${web_service_cmd}" "${status_log}" "${model_name}"
elif [[ ${model_name} =~ "det" ]]; then
set_det_model_config=$(func_set_params "${det_server_key}" "${det_server_value}")
web_service_cmd="${python} ${web_service_py} ${set_tensorrt} ${set_precision} ${set_det_model_config} > ${server_log_path} 2>&1 "
web_service_cmd="nohup ${python} ${web_service_py} ${set_tensorrt} ${set_precision} ${set_det_model_config} > ${server_log_path} 2>&1 &"
eval $web_service_cmd
last_status=${PIPESTATUS[0]}
status_check $last_status "${web_service_cmd}" "${status_log}" "${model_name}"
elif [[ ${model_name} =~ "rec" ]]; then
set_rec_model_config=$(func_set_params "${rec_server_key}" "${rec_server_value}")
web_service_cmd="${python} ${web_service_py} ${set_tensorrt} ${set_precision} ${set_rec_model_config} > ${server_log_path} 2>&1 "
web_service_cmd="nohup ${python} ${web_service_py} ${set_tensorrt} ${set_precision} ${set_rec_model_config} > ${server_log_path} 2>&1 &"
eval $web_service_cmd
last_status=${PIPESTATUS[0]}
status_check $last_status "${web_service_cmd}" "${status_log}" "${model_name}"
......
......@@ -193,7 +193,7 @@ if [ ${MODE} = "whole_infer" ]; then
save_infer_dir="${infer_model}"
set_export_weight=$(func_set_params "${export_weight}" "${infer_model}")
set_save_infer_key=$(func_set_params "${save_infer_key}" "${save_infer_dir}")
export_log_path="${LOG_PATH}/_export_${Count}.log"
export_log_path="${LOG_PATH}_export_${Count}.log"
export_cmd="${python} ${infer_run_exports[Count]} ${set_export_weight} ${set_save_infer_key} > ${export_log_path} 2>&1 "
echo ${infer_run_exports[Count]}
echo $export_cmd
......@@ -265,7 +265,7 @@ else
if [ ${run_train} = "null" ]; then
continue
fi
set_autocast=$(func_set_params "${autocast_key}" "${autocast}")
set_epoch=$(func_set_params "${epoch_key}" "${epoch_num}")
set_pretrain=$(func_set_params "${pretrain_model_key}" "${pretrain_model_value}")
set_batchsize=$(func_set_params "${train_batch_key}" "${train_batch_value}")
......@@ -287,14 +287,15 @@ else
set_save_model=$(func_set_params "${save_model_key}" "${save_log}")
if [ ${#gpu} -le 2 ];then # train with cpu or single gpu
cmd="${python} ${run_train} ${set_use_gpu} ${set_save_model} ${set_epoch} ${set_pretrain} ${set_autocast} ${set_batchsize} ${set_train_params1} ${set_amp_config} "
cmd="${python} ${run_train} ${set_use_gpu} ${set_save_model} ${set_epoch} ${set_pretrain} ${set_batchsize} ${set_train_params1} ${set_amp_config} "
elif [ ${#ips} -le 15 ];then # train with multi-gpu
cmd="${python} -m paddle.distributed.launch --gpus=${gpu} ${run_train} ${set_use_gpu} ${set_save_model} ${set_epoch} ${set_pretrain} ${set_autocast} ${set_batchsize} ${set_train_params1} ${set_amp_config}"
cmd="${python} -m paddle.distributed.launch --gpus=${gpu} ${run_train} ${set_use_gpu} ${set_save_model} ${set_epoch} ${set_pretrain} ${set_batchsize} ${set_train_params1} ${set_amp_config}"
else # train with multi-machine
cmd="${python} -m paddle.distributed.launch --ips=${ips} --gpus=${gpu} ${run_train} ${set_use_gpu} ${set_save_model} ${set_pretrain} ${set_epoch} ${set_autocast} ${set_batchsize} ${set_train_params1} ${set_amp_config}"
cmd="${python} -m paddle.distributed.launch --ips=${ips} --gpus=${gpu} ${run_train} ${set_use_gpu} ${set_save_model} ${set_pretrain} ${set_epoch} ${set_batchsize} ${set_train_params1} ${set_amp_config}"
fi
# run train
eval $cmd
eval "cat ${save_log}/train.log >> ${save_log}.log"
status_check $? "${cmd}" "${status_log}" "${model_name}"
set_eval_pretrain=$(func_set_params "${pretrain_model_key}" "${save_log}/${train_model_name}")
......
......@@ -97,6 +97,22 @@ def export_single_model(model,
shape=[None, 1, 32, 100], dtype="float32"),
]
model = to_static(model, input_spec=other_shape)
elif arch_config["algorithm"] in ["LayoutLM", "LayoutLMv2", "LayoutXLM"]:
input_spec = [
paddle.static.InputSpec(
shape=[None, 512], dtype="int64"), # input_ids
paddle.static.InputSpec(
shape=[None, 512, 4], dtype="int64"), # bbox
paddle.static.InputSpec(
shape=[None, 512], dtype="int64"), # attention_mask
paddle.static.InputSpec(
shape=[None, 512], dtype="int64"), # token_type_ids
paddle.static.InputSpec(
shape=[None, 3, 224, 224], dtype="int64"), # image
]
if arch_config["algorithm"] == "LayoutLM":
input_spec.pop(4)
model = to_static(model, input_spec=[input_spec])
else:
infer_shape = [3, -1, -1]
if arch_config["model_type"] == "rec":
......@@ -110,6 +126,8 @@ def export_single_model(model,
infer_shape[-1] = 100
elif arch_config["model_type"] == "table":
infer_shape = [3, 488, 488]
if arch_config["algorithm"] == "TableMaster":
infer_shape = [3, 480, 480]
model = to_static(
model,
input_spec=[
......@@ -172,7 +190,7 @@ def main():
config["Architecture"]["Head"]["out_channels"] = char_num
model = build_model(config["Architecture"])
load_model(config, model)
load_model(config, model, model_type=config['Architecture']["model_type"])
model.eval()
save_path = config["Global"]["save_inference_dir"]
......
......@@ -67,6 +67,23 @@ class TextDetector(object):
postprocess_params["unclip_ratio"] = args.det_db_unclip_ratio
postprocess_params["use_dilation"] = args.use_dilation
postprocess_params["score_mode"] = args.det_db_score_mode
elif self.det_algorithm == "DB++":
postprocess_params['name'] = 'DBPostProcess'
postprocess_params["thresh"] = args.det_db_thresh
postprocess_params["box_thresh"] = args.det_db_box_thresh
postprocess_params["max_candidates"] = 1000
postprocess_params["unclip_ratio"] = args.det_db_unclip_ratio
postprocess_params["use_dilation"] = args.use_dilation
postprocess_params["score_mode"] = args.det_db_score_mode
pre_process_list[1] = {
'NormalizeImage': {
'std': [1.0, 1.0, 1.0],
'mean':
[0.48109378172549, 0.45752457890196, 0.40787054090196],
'scale': '1./255.',
'order': 'hwc'
}
}
elif self.det_algorithm == "EAST":
postprocess_params['name'] = 'EASTPostProcess'
postprocess_params["score_thresh"] = args.det_east_score_thresh
......@@ -231,7 +248,7 @@ class TextDetector(object):
preds['f_score'] = outputs[1]
preds['f_tco'] = outputs[2]
preds['f_tvo'] = outputs[3]
elif self.det_algorithm in ['DB', 'PSE']:
elif self.det_algorithm in ['DB', 'PSE', 'DB++']:
preds['maps'] = outputs[0]
elif self.det_algorithm == 'FCE':
for i, output in enumerate(outputs):
......
......@@ -153,6 +153,8 @@ def create_predictor(args, mode, logger):
model_dir = args.rec_model_dir
elif mode == 'table':
model_dir = args.table_model_dir
elif mode == 'ser':
model_dir = args.ser_model_dir
else:
model_dir = args.e2e_model_dir
......@@ -316,8 +318,13 @@ def create_predictor(args, mode, logger):
# create predictor
predictor = inference.create_predictor(config)
input_names = predictor.get_input_names()
for name in input_names:
input_tensor = predictor.get_input_handle(name)
if mode in ['ser', 're']:
input_tensor = []
for name in input_names:
input_tensor.append(predictor.get_input_handle(name))
else:
for name in input_names:
input_tensor = predictor.get_input_handle(name)
output_tensors = get_output_tensors(args, mode, predictor)
return predictor, input_tensor, output_tensors, config
......
......@@ -44,7 +44,7 @@ def draw_det_res(dt_boxes, config, img, img_name, save_path):
import cv2
src_im = img
for box in dt_boxes:
box = box.astype(np.int32).reshape((-1, 1, 2))
box = np.array(box).astype(np.int32).reshape((-1, 1, 2))
cv2.polylines(src_im, [box], True, color=(255, 255, 0), thickness=2)
if not os.path.exists(save_path):
os.makedirs(save_path)
......@@ -106,7 +106,7 @@ def main():
dt_boxes_list = []
for box in boxes:
tmp_json = {"transcription": ""}
tmp_json['points'] = box.tolist()
tmp_json['points'] = np.array(box).tolist()
dt_boxes_list.append(tmp_json)
det_box_json[k] = dt_boxes_list
save_det_path = os.path.dirname(config['Global'][
......@@ -118,7 +118,7 @@ def main():
# write result
for box in boxes:
tmp_json = {"transcription": ""}
tmp_json['points'] = box.tolist()
tmp_json['points'] = np.array(box).tolist()
dt_boxes_json.append(tmp_json)
save_det_path = os.path.dirname(config['Global'][
'save_res_path']) + "/det_results/"
......
......@@ -39,13 +39,12 @@ import time
def read_class_list(filepath):
dict = {}
ret = {}
with open(filepath, "r") as f:
lines = f.readlines()
for line in lines:
key, value = line.split(" ")
dict[key] = value.rstrip()
return dict
for idx, line in enumerate(lines):
ret[idx] = line.strip("\n")
return ret
def draw_kie_result(batch, node, idx_to_cls, count):
......@@ -71,7 +70,7 @@ def draw_kie_result(batch, node, idx_to_cls, count):
x_min = int(min([point[0] for point in new_box]))
y_min = int(min([point[1] for point in new_box]))
pred_label = str(node_pred_label[i])
pred_label = node_pred_label[i]
if pred_label in idx_to_cls:
pred_label = idx_to_cls[pred_label]
pred_score = '{:.2f}'.format(node_pred_score[i])
......@@ -109,8 +108,7 @@ def main():
save_res_path = config['Global']['save_res_path']
class_path = config['Global']['class_path']
idx_to_cls = read_class_list(class_path)
if not os.path.exists(os.path.dirname(save_res_path)):
os.makedirs(os.path.dirname(save_res_path))
os.makedirs(os.path.dirname(save_res_path), exist_ok=True)
model.eval()
......
......@@ -36,10 +36,12 @@ from ppocr.modeling.architectures import build_model
from ppocr.postprocess import build_post_process
from ppocr.utils.save_load import load_model
from ppocr.utils.utility import get_image_file_list
from ppocr.utils.visual import draw_rectangle
import tools.program as program
import cv2
@paddle.no_grad()
def main(config, device, logger, vdl_writer):
global_config = config['Global']
......@@ -53,53 +55,61 @@ def main(config, device, logger, vdl_writer):
getattr(post_process_class, 'character'))
model = build_model(config['Architecture'])
algorithm = config['Architecture']['algorithm']
use_xywh = algorithm in ['TableMaster']
load_model(config, model)
# create data ops
transforms = []
use_padding = False
for op in config['Eval']['dataset']['transforms']:
op_name = list(op)[0]
if 'Label' in op_name:
if 'Encode' in op_name:
continue
if op_name == 'KeepKeys':
op[op_name]['keep_keys'] = ['image']
if op_name == "ResizeTableImage":
use_padding = True
padding_max_len = op['ResizeTableImage']['max_len']
op[op_name]['keep_keys'] = ['image', 'shape']
transforms.append(op)
global_config['infer_mode'] = True
ops = create_operators(transforms, global_config)
save_res_path = config['Global']['save_res_path']
os.makedirs(save_res_path, exist_ok=True)
model.eval()
for file in get_image_file_list(config['Global']['infer_img']):
logger.info("infer_img: {}".format(file))
with open(file, 'rb') as f:
img = f.read()
data = {'image': img}
batch = transform(data, ops)
images = np.expand_dims(batch[0], axis=0)
images = paddle.to_tensor(images)
preds = model(images)
post_result = post_process_class(preds)
res_html_code = post_result['res_html_code']
res_loc = post_result['res_loc']
img = cv2.imread(file)
imgh, imgw = img.shape[0:2]
res_loc_final = []
for rno in range(len(res_loc[0])):
x0, y0, x1, y1 = res_loc[0][rno]
left = max(int(imgw * x0), 0)
top = max(int(imgh * y0), 0)
right = min(int(imgw * x1), imgw - 1)
bottom = min(int(imgh * y1), imgh - 1)
cv2.rectangle(img, (left, top), (right, bottom), (0, 0, 255), 2)
res_loc_final.append([left, top, right, bottom])
res_loc_str = json.dumps(res_loc_final)
logger.info("result: {}, {}".format(res_html_code, res_loc_final))
logger.info("success!")
with open(
os.path.join(save_res_path, 'infer.txt'), mode='w',
encoding='utf-8') as f_w:
for file in get_image_file_list(config['Global']['infer_img']):
logger.info("infer_img: {}".format(file))
with open(file, 'rb') as f:
img = f.read()
data = {'image': img}
batch = transform(data, ops)
images = np.expand_dims(batch[0], axis=0)
shape_list = np.expand_dims(batch[1], axis=0)
images = paddle.to_tensor(images)
preds = model(images)
post_result = post_process_class(preds, [shape_list])
structure_str_list = post_result['structure_batch_list'][0]
bbox_list = post_result['bbox_batch_list'][0]
structure_str_list = structure_str_list[0]
structure_str_list = [
'<html>', '<body>', '<table>'
] + structure_str_list + ['</table>', '</body>', '</html>']
bbox_list_str = json.dumps(bbox_list.tolist())
logger.info("result: {}, {}".format(structure_str_list,
bbox_list_str))
f_w.write("result: {}, {}\n".format(structure_str_list,
bbox_list_str))
img = draw_rectangle(file, bbox_list, use_xywh)
cv2.imwrite(
os.path.join(save_res_path, os.path.basename(file)), img)
logger.info("success!")
if __name__ == '__main__':
......
......@@ -44,6 +44,7 @@ def to_tensor(data):
from collections import defaultdict
data_dict = defaultdict(list)
to_tensor_idxs = []
for idx, v in enumerate(data):
if isinstance(v, (np.ndarray, paddle.Tensor, numbers.Number)):
if idx not in to_tensor_idxs:
......@@ -57,6 +58,7 @@ def to_tensor(data):
class SerPredictor(object):
def __init__(self, config):
global_config = config['Global']
self.algorithm = config['Architecture']["algorithm"]
# build post process
self.post_process_class = build_post_process(config['PostProcess'],
......@@ -70,7 +72,10 @@ class SerPredictor(object):
from paddleocr import PaddleOCR
self.ocr_engine = PaddleOCR(use_angle_cls=False, show_log=False)
self.ocr_engine = PaddleOCR(
use_angle_cls=False,
show_log=False,
use_gpu=global_config['use_gpu'])
# create data ops
transforms = []
......@@ -80,29 +85,30 @@ class SerPredictor(object):
op[op_name]['ocr_engine'] = self.ocr_engine
elif op_name == 'KeepKeys':
op[op_name]['keep_keys'] = [
'input_ids', 'labels', 'bbox', 'image', 'attention_mask',
'token_type_ids', 'segment_offset_id', 'ocr_info',
'input_ids', 'bbox', 'attention_mask', 'token_type_ids',
'image', 'labels', 'segment_offset_id', 'ocr_info',
'entities'
]
transforms.append(op)
global_config['infer_mode'] = True
if config["Global"].get("infer_mode", None) is None:
global_config['infer_mode'] = True
self.ops = create_operators(config['Eval']['dataset']['transforms'],
global_config)
self.model.eval()
def __call__(self, img_path):
with open(img_path, 'rb') as f:
def __call__(self, data):
with open(data["img_path"], 'rb') as f:
img = f.read()
data = {'image': img}
data["image"] = img
batch = transform(data, self.ops)
batch = to_tensor(batch)
preds = self.model(batch)
if self.algorithm in ['LayoutLMv2', 'LayoutXLM']:
preds = preds[0]
post_result = self.post_process_class(
preds,
attention_masks=batch[4],
segment_offset_ids=batch[6],
ocr_infos=batch[7])
preds, segment_offset_ids=batch[6], ocr_infos=batch[7])
return post_result, batch
......@@ -112,20 +118,33 @@ if __name__ == '__main__':
ser_engine = SerPredictor(config)
infer_imgs = get_image_file_list(config['Global']['infer_img'])
if config["Global"].get("infer_mode", None) is False:
data_dir = config['Eval']['dataset']['data_dir']
with open(config['Global']['infer_img'], "rb") as f:
infer_imgs = f.readlines()
else:
infer_imgs = get_image_file_list(config['Global']['infer_img'])
with open(
os.path.join(config['Global']['save_res_path'],
"infer_results.txt"),
"w",
encoding='utf-8') as fout:
for idx, img_path in enumerate(infer_imgs):
for idx, info in enumerate(infer_imgs):
if config["Global"].get("infer_mode", None) is False:
data_line = info.decode('utf-8')
substr = data_line.strip("\n").split("\t")
img_path = os.path.join(data_dir, substr[0])
data = {'img_path': img_path, 'label': substr[1]}
else:
img_path = info
data = {'img_path': img_path}
save_img_path = os.path.join(
config['Global']['save_res_path'],
os.path.splitext(os.path.basename(img_path))[0] + "_ser.jpg")
logger.info("process: [{}/{}], save result to {}".format(
idx, len(infer_imgs), save_img_path))
result, _ = ser_engine(img_path)
result, _ = ser_engine(data)
result = result[0]
fout.write(img_path + "\t" + json.dumps(
{
......@@ -133,3 +152,6 @@ if __name__ == '__main__':
}, ensure_ascii=False) + "\n")
img_res = draw_ser_results(img_path, result)
cv2.imwrite(save_img_path, img_res)
logger.info("process: [{}/{}], save result to {}".format(
idx, len(infer_imgs), save_img_path))
......@@ -38,7 +38,7 @@ from ppocr.utils.save_load import load_model
from ppocr.utils.visual import draw_re_results
from ppocr.utils.logging import get_logger
from ppocr.utils.utility import get_image_file_list, load_vqa_bio_label_maps, print_dict
from tools.program import ArgsParser, load_config, merge_config, check_gpu
from tools.program import ArgsParser, load_config, merge_config
from tools.infer_vqa_token_ser import SerPredictor
......@@ -107,7 +107,7 @@ def make_input(ser_inputs, ser_results):
# remove ocr_info segment_offset_id and label in ser input
ser_inputs.pop(7)
ser_inputs.pop(6)
ser_inputs.pop(1)
ser_inputs.pop(5)
return ser_inputs, entity_idx_dict_batch
......@@ -131,9 +131,7 @@ class SerRePredictor(object):
self.model.eval()
def __call__(self, img_path):
ser_results, ser_inputs = self.ser_engine(img_path)
paddle.save(ser_inputs, 'ser_inputs.npy')
paddle.save(ser_results, 'ser_results.npy')
ser_results, ser_inputs = self.ser_engine({'img_path': img_path})
re_input, entity_idx_dict_batch = make_input(ser_inputs, ser_results)
preds = self.model(re_input)
post_result = self.post_process_class(
......@@ -155,7 +153,6 @@ def preprocess():
# check if set use_gpu=True in paddlepaddle cpu version
use_gpu = config['Global']['use_gpu']
check_gpu(use_gpu)
device = 'gpu:{}'.format(dist.ParallelEnv().dev_id) if use_gpu else 'cpu'
device = paddle.set_device(device)
......@@ -185,9 +182,7 @@ if __name__ == '__main__':
for idx, img_path in enumerate(infer_imgs):
save_img_path = os.path.join(
config['Global']['save_res_path'],
os.path.splitext(os.path.basename(img_path))[0] + "_ser.jpg")
logger.info("process: [{}/{}], save result to {}".format(
idx, len(infer_imgs), save_img_path))
os.path.splitext(os.path.basename(img_path))[0] + "_ser_re.jpg")
result = ser_re_engine(img_path)
result = result[0]
......@@ -197,3 +192,6 @@ if __name__ == '__main__':
}, ensure_ascii=False) + "\n")
img_res = draw_re_results(img_path, result)
cv2.imwrite(save_img_path, img_res)
logger.info("process: [{}/{}], save result to {}".format(
idx, len(infer_imgs), save_img_path))
......@@ -281,8 +281,11 @@ def train(config,
if cal_metric_during_train and epoch % calc_epoch_interval == 0: # only rec and cls need
batch = [item.numpy() for item in batch]
if model_type in ['table', 'kie']:
if model_type in ['kie']:
eval_class(preds, batch)
elif model_type in ['table']:
post_result = post_process_class(preds, batch)
eval_class(post_result, batch)
else:
if config['Loss']['name'] in ['MultiLoss', 'MultiLoss_v2'
]: # for multi head loss
......@@ -463,7 +466,6 @@ def eval(model,
preds = model(batch)
else:
preds = model(images)
batch_numpy = []
for item in batch:
if isinstance(item, paddle.Tensor):
......@@ -473,9 +475,9 @@ def eval(model,
# Obtain usable results from post-processing methods
total_time += time.time() - start
# Evaluate the results of the current batch
if model_type in ['table', 'kie']:
if model_type in ['kie']:
eval_class(preds, batch_numpy)
elif model_type in ['vqa']:
elif model_type in ['table', 'vqa']:
post_result = post_process_class(preds, batch_numpy)
eval_class(post_result, batch_numpy)
else:
......@@ -576,8 +578,8 @@ def preprocess(is_train=False):
assert alg in [
'EAST', 'DB', 'SAST', 'Rosetta', 'CRNN', 'STARNet', 'RARE', 'SRN',
'CLS', 'PGNet', 'Distillation', 'NRTR', 'TableAttn', 'SAR', 'PSE',
'SEED', 'SDMGR', 'LayoutXLM', 'LayoutLM', 'PREN', 'FCE', 'SVTR',
'ViTSTR', 'ABINet'
'SEED', 'SDMGR', 'LayoutXLM', 'LayoutLM', 'LayoutLMv2', 'PREN', 'FCE',
'SVTR', 'ViTSTR', 'ABINet', 'DB++', 'TableMaster'
]
if use_xpu:
......