Commit e34f4456 authored by andyjpaddle

Merge branch 'dygraph' of https://github.com/PaddlePaddle/PaddleOCR into dygraph

...@@ -10,7 +10,8 @@ __pycache__/ ...@@ -10,7 +10,8 @@ __pycache__/
inference/ inference/
inference_results/ inference_results/
output/ output/
train_data/
log/
*.DS_Store *.DS_Store
*.vs *.vs
*.user *.user
......
...@@ -28,7 +28,7 @@ from PyQt5.QtCore import QSize, Qt, QPoint, QByteArray, QTimer, QFileInfo, QPoin ...@@ -28,7 +28,7 @@ from PyQt5.QtCore import QSize, Qt, QPoint, QByteArray, QTimer, QFileInfo, QPoin
from PyQt5.QtGui import QImage, QCursor, QPixmap, QImageReader from PyQt5.QtGui import QImage, QCursor, QPixmap, QImageReader
from PyQt5.QtWidgets import QMainWindow, QListWidget, QVBoxLayout, QToolButton, QHBoxLayout, QDockWidget, QWidget, \ from PyQt5.QtWidgets import QMainWindow, QListWidget, QVBoxLayout, QToolButton, QHBoxLayout, QDockWidget, QWidget, \
QSlider, QGraphicsOpacityEffect, QMessageBox, QListView, QScrollArea, QWidgetAction, QApplication, QLabel, QGridLayout, \ QSlider, QGraphicsOpacityEffect, QMessageBox, QListView, QScrollArea, QWidgetAction, QApplication, QLabel, QGridLayout, \
QFileDialog, QListWidgetItem, QComboBox, QDialog QFileDialog, QListWidgetItem, QComboBox, QDialog, QAbstractItemView, QSizePolicy
__dir__ = os.path.dirname(os.path.abspath(__file__)) __dir__ = os.path.dirname(os.path.abspath(__file__))
...@@ -227,6 +227,21 @@ class MainWindow(QMainWindow): ...@@ -227,6 +227,21 @@ class MainWindow(QMainWindow):
listLayout.addWidget(leftTopToolBoxContainer) listLayout.addWidget(leftTopToolBoxContainer)
# ================== Label List ================== # ================== Label List ==================
labelIndexListlBox = QHBoxLayout()
# Create and add a widget for showing current label item index
self.indexList = QListWidget()
self.indexList.setMaximumSize(30, 16777215) # limit max width
self.indexList.setEditTriggers(QAbstractItemView.NoEditTriggers) # not editable
self.indexList.itemSelectionChanged.connect(self.indexSelectionChanged)
self.indexList.setVerticalScrollBarPolicy(Qt.ScrollBarAlwaysOff) # no scroll bar
self.indexListDock = QDockWidget('No.', self)
self.indexListDock.setWidget(self.indexList)
self.indexListDock.setFeatures(QDockWidget.NoDockWidgetFeatures)
labelIndexListlBox.addWidget(self.indexListDock, 1)
# no margin between two boxes
labelIndexListlBox.setSpacing(0)
# Create and add a widget for showing current label items # Create and add a widget for showing current label items
self.labelList = EditInList() self.labelList = EditInList()
labelListContainer = QWidget() labelListContainer = QWidget()
...@@ -240,7 +255,32 @@ class MainWindow(QMainWindow): ...@@ -240,7 +255,32 @@ class MainWindow(QMainWindow):
self.labelListDock = QDockWidget(self.labelListDockName, self) self.labelListDock = QDockWidget(self.labelListDockName, self)
self.labelListDock.setWidget(self.labelList) self.labelListDock.setWidget(self.labelList)
self.labelListDock.setFeatures(QDockWidget.NoDockWidgetFeatures) self.labelListDock.setFeatures(QDockWidget.NoDockWidgetFeatures)
listLayout.addWidget(self.labelListDock) labelIndexListlBox.addWidget(self.labelListDock, 10) # label list is wider than index list
# enable labelList drag_drop to adjust bbox order
# set the selection mode to single selection
self.labelList.setSelectionMode(QAbstractItemView.SingleSelection)
# enable dragging
self.labelList.setDragEnabled(True)
# accept drops on the viewport
self.labelList.viewport().setAcceptDrops(True)
# show the drop indicator at the target position
self.labelList.setDropIndicatorShown(True)
# set the drag-drop mode to move items; without this, the default is to copy items
self.labelList.setDragDropMode(QAbstractItemView.InternalMove)
# handle the drop
self.labelList.model().rowsMoved.connect(self.drag_drop_happened)
labelIndexListContainer = QWidget()
labelIndexListContainer.setLayout(labelIndexListlBox)
listLayout.addWidget(labelIndexListContainer)
# keep labelList and indexList scrolling in sync
self.labelListBar = self.labelList.verticalScrollBar()
self.indexListBar = self.indexList.verticalScrollBar()
self.labelListBar.valueChanged.connect(self.move_scrollbar)
self.indexListBar.valueChanged.connect(self.move_scrollbar)
# ================== Detection Box ================== # ================== Detection Box ==================
self.BoxList = QListWidget() self.BoxList = QListWidget()
...@@ -589,15 +629,23 @@ class MainWindow(QMainWindow): ...@@ -589,15 +629,23 @@ class MainWindow(QMainWindow):
self.displayLabelOption.setChecked(settings.get(SETTING_PAINT_LABEL, False)) self.displayLabelOption.setChecked(settings.get(SETTING_PAINT_LABEL, False))
self.displayLabelOption.triggered.connect(self.togglePaintLabelsOption) self.displayLabelOption.triggered.connect(self.togglePaintLabelsOption)
# Add option to enable/disable box index being displayed at the top of bounding boxes
self.displayIndexOption = QAction(getStr('displayIndex'), self)
self.displayIndexOption.setCheckable(True)
self.displayIndexOption.setChecked(settings.get(SETTING_PAINT_INDEX, False))
self.displayIndexOption.triggered.connect(self.togglePaintIndexOption)
self.labelDialogOption = QAction(getStr('labelDialogOption'), self) self.labelDialogOption = QAction(getStr('labelDialogOption'), self)
self.labelDialogOption.setShortcut("Ctrl+Shift+L") self.labelDialogOption.setShortcut("Ctrl+Shift+L")
self.labelDialogOption.setCheckable(True) self.labelDialogOption.setCheckable(True)
self.labelDialogOption.setChecked(settings.get(SETTING_PAINT_LABEL, False)) self.labelDialogOption.setChecked(settings.get(SETTING_PAINT_LABEL, False))
self.displayIndexOption.setChecked(settings.get(SETTING_PAINT_INDEX, False))
self.labelDialogOption.triggered.connect(self.speedChoose) self.labelDialogOption.triggered.connect(self.speedChoose)
self.autoSaveOption = QAction(getStr('autoSaveMode'), self) self.autoSaveOption = QAction(getStr('autoSaveMode'), self)
self.autoSaveOption.setCheckable(True) self.autoSaveOption.setCheckable(True)
self.autoSaveOption.setChecked(settings.get(SETTING_PAINT_LABEL, False)) self.autoSaveOption.setChecked(settings.get(SETTING_PAINT_LABEL, False))
self.displayIndexOption.setChecked(settings.get(SETTING_PAINT_INDEX, False))
self.autoSaveOption.triggered.connect(self.autoSaveFunc) self.autoSaveOption.triggered.connect(self.autoSaveFunc)
addActions(self.menus.file, addActions(self.menus.file,
...@@ -606,7 +654,7 @@ class MainWindow(QMainWindow): ...@@ -606,7 +654,7 @@ class MainWindow(QMainWindow):
addActions(self.menus.help, (showKeys, showSteps, showInfo)) addActions(self.menus.help, (showKeys, showSteps, showInfo))
addActions(self.menus.view, ( addActions(self.menus.view, (
self.displayLabelOption, self.labelDialogOption, self.displayLabelOption, self.displayIndexOption, self.labelDialogOption,
None, None,
hideAll, showAll, None, hideAll, showAll, None,
zoomIn, zoomOut, zoomOrg, None, zoomIn, zoomOut, zoomOrg, None,
...@@ -744,6 +792,7 @@ class MainWindow(QMainWindow): ...@@ -744,6 +792,7 @@ class MainWindow(QMainWindow):
self.shapesToItemsbox.clear() self.shapesToItemsbox.clear()
self.labelList.clear() self.labelList.clear()
self.BoxList.clear() self.BoxList.clear()
self.indexList.clear()
self.filePath = None self.filePath = None
self.imageData = None self.imageData = None
self.labelFile = None self.labelFile = None
...@@ -964,9 +1013,10 @@ class MainWindow(QMainWindow): ...@@ -964,9 +1013,10 @@ class MainWindow(QMainWindow):
else:
    self.canvas.selectedShapes_hShape = self.canvas.selectedShapes
for shape in self.canvas.selectedShapes_hShape:
    if shape in self.shapesToItemsbox.keys():
        item = self.shapesToItemsbox[shape]  # listitem
        text = [(int(p.x()), int(p.y())) for p in shape.points]
        item.setText(str(text))
self.actions.undo.setEnabled(True) self.actions.undo.setEnabled(True)
self.setDirty() self.setDirty()
...@@ -1004,13 +1054,19 @@ class MainWindow(QMainWindow): ...@@ -1004,13 +1054,19 @@ class MainWindow(QMainWindow):
for shape in self.canvas.selectedShapes: for shape in self.canvas.selectedShapes:
shape.selected = False shape.selected = False
self.labelList.clearSelection() self.labelList.clearSelection()
self.indexList.clearSelection()
self.canvas.selectedShapes = selected_shapes self.canvas.selectedShapes = selected_shapes
for shape in self.canvas.selectedShapes: for shape in self.canvas.selectedShapes:
shape.selected = True shape.selected = True
self.shapesToItems[shape].setSelected(True) self.shapesToItems[shape].setSelected(True)
self.shapesToItemsbox[shape].setSelected(True) self.shapesToItemsbox[shape].setSelected(True)
index = self.labelList.indexFromItem(self.shapesToItems[shape]).row()
self.indexList.item(index).setSelected(True)
self.labelList.scrollToItem(self.currentItem()) # QAbstractItemView.EnsureVisible self.labelList.scrollToItem(self.currentItem()) # QAbstractItemView.EnsureVisible
# map current label item to index item and select it
index = self.labelList.indexFromItem(self.currentItem()).row()
self.indexList.scrollToItem(self.indexList.item(index))
self.BoxList.scrollToItem(self.currentBox()) self.BoxList.scrollToItem(self.currentBox())
if self.kie_mode: if self.kie_mode:
...@@ -1040,13 +1096,21 @@ class MainWindow(QMainWindow): ...@@ -1040,13 +1096,21 @@ class MainWindow(QMainWindow):
def addLabel(self, shape): def addLabel(self, shape):
shape.paintLabel = self.displayLabelOption.isChecked() shape.paintLabel = self.displayLabelOption.isChecked()
shape.paintIdx = self.displayIndexOption.isChecked()
item = HashableQListWidgetItem(shape.label) item = HashableQListWidgetItem(shape.label)
# current difficult checkbox is disabled
# item.setFlags(item.flags() | Qt.ItemIsUserCheckable)
# item.setCheckState(Qt.Unchecked) if shape.difficult else item.setCheckState(Qt.Checked)
# Checked means difficult is False # Checked means difficult is False
# item.setBackground(generateColorByText(shape.label)) # item.setBackground(generateColorByText(shape.label))
self.itemsToShapes[item] = shape self.itemsToShapes[item] = shape
self.shapesToItems[shape] = item self.shapesToItems[shape] = item
# add current label item index before label string
current_index = QListWidgetItem(str(self.labelList.count()))
current_index.setTextAlignment(Qt.AlignHCenter)
self.indexList.addItem(current_index)
self.labelList.addItem(item) self.labelList.addItem(item)
# print('item in add label is ',[(p.x(), p.y()) for p in shape.points], shape.label) # print('item in add label is ',[(p.x(), p.y()) for p in shape.points], shape.label)
...@@ -1080,9 +1144,11 @@ class MainWindow(QMainWindow): ...@@ -1080,9 +1144,11 @@ class MainWindow(QMainWindow):
del self.shapesToItemsbox[shape] del self.shapesToItemsbox[shape]
del self.itemsToShapesbox[item] del self.itemsToShapesbox[item]
self.updateComboBox() self.updateComboBox()
self.updateIndexList()
def loadLabels(self, shapes): def loadLabels(self, shapes):
s = [] s = []
shape_index = 0
for label, points, line_color, key_cls, difficult in shapes: for label, points, line_color, key_cls, difficult in shapes:
shape = Shape(label=label, line_color=line_color, key_cls=key_cls) shape = Shape(label=label, line_color=line_color, key_cls=key_cls)
for x, y in points: for x, y in points:
...@@ -1094,6 +1160,8 @@ class MainWindow(QMainWindow): ...@@ -1094,6 +1160,8 @@ class MainWindow(QMainWindow):
shape.addPoint(QPointF(x, y)) shape.addPoint(QPointF(x, y))
shape.difficult = difficult shape.difficult = difficult
shape.idx = shape_index
shape_index += 1
# shape.locked = False # shape.locked = False
shape.close() shape.close()
s.append(shape) s.append(shape)
...@@ -1128,6 +1196,13 @@ class MainWindow(QMainWindow): ...@@ -1128,6 +1196,13 @@ class MainWindow(QMainWindow):
# self.comboBox.update_items(uniqueTextList) # self.comboBox.update_items(uniqueTextList)
def updateIndexList(self):
    self.indexList.clear()
    for i in range(self.labelList.count()):
        string = QListWidgetItem(str(i))
        string.setTextAlignment(Qt.AlignHCenter)
        self.indexList.addItem(string)
def saveLabels(self, annotationFilePath, mode='Auto'): def saveLabels(self, annotationFilePath, mode='Auto'):
# Mode is Auto means that labels will be loaded from self.result_dic totally, which is the output of ocr model # Mode is Auto means that labels will be loaded from self.result_dic totally, which is the output of ocr model
annotationFilePath = ustr(annotationFilePath) annotationFilePath = ustr(annotationFilePath)
...@@ -1183,6 +1258,10 @@ class MainWindow(QMainWindow): ...@@ -1183,6 +1258,10 @@ class MainWindow(QMainWindow):
# fix copy and delete # fix copy and delete
# self.shapeSelectionChanged(True) # self.shapeSelectionChanged(True)
def move_scrollbar(self, value):
self.labelListBar.setValue(value)
self.indexListBar.setValue(value)
def labelSelectionChanged(self): def labelSelectionChanged(self):
if self._noSelectionSlot: if self._noSelectionSlot:
return return
...@@ -1195,6 +1274,21 @@ class MainWindow(QMainWindow): ...@@ -1195,6 +1274,21 @@ class MainWindow(QMainWindow):
else: else:
self.canvas.deSelectShape() self.canvas.deSelectShape()
def indexSelectionChanged(self):
    if self._noSelectionSlot:
        return
    if self.canvas.editing():
        selected_shapes = []
        for item in self.indexList.selectedItems():
            # map index item to label item
            index = self.indexList.indexFromItem(item).row()
            item = self.labelList.item(index)
            selected_shapes.append(self.itemsToShapes[item])
        if selected_shapes:
            self.canvas.selectShapes(selected_shapes)
        else:
            self.canvas.deSelectShape()
def boxSelectionChanged(self): def boxSelectionChanged(self):
if self._noSelectionSlot: if self._noSelectionSlot:
# self.BoxList.scrollToItem(self.currentBox(), QAbstractItemView.PositionAtCenter) # self.BoxList.scrollToItem(self.currentBox(), QAbstractItemView.PositionAtCenter)
...@@ -1209,18 +1303,54 @@ class MainWindow(QMainWindow): ...@@ -1209,18 +1303,54 @@ class MainWindow(QMainWindow):
self.canvas.deSelectShape() self.canvas.deSelectShape()
def labelItemChanged(self, item):
    # avoid accidentally triggering the itemChanged signal with an unhashable item
    # unknown trigger condition
    if type(item) == HashableQListWidgetItem:
        shape = self.itemsToShapes[item]
        label = item.text()
        if label != shape.label:
            shape.label = item.text()
            # shape.line_color = generateColorByText(shape.label)
            self.setDirty()
        elif not ((item.checkState() == Qt.Unchecked) ^ (not shape.difficult)):
            shape.difficult = True if item.checkState() == Qt.Unchecked else False
            self.setDirty()
        else:  # User probably changed item visibility
            self.canvas.setShapeVisible(shape, True)  # item.checkState() == Qt.Checked
            # self.actions.save.setEnabled(True)
    else:
        print('enter labelItemChanged slot with unhashable item: ', item, item.text())
def drag_drop_happened(self):
    '''
    label list drag drop signal slot
    '''
    # print('___________________drag_drop_happened_______________')
    # should only select single item
    for item in self.labelList.selectedItems():
        newIndex = self.labelList.indexFromItem(item).row()
    # only support drag_drop one item
    assert len(self.canvas.selectedShapes) > 0
    for shape in self.canvas.selectedShapes:
        selectedShapeIndex = shape.idx
    if newIndex == selectedShapeIndex:
        return
    # move corresponding item in shape list
    shape = self.canvas.shapes.pop(selectedShapeIndex)
    self.canvas.shapes.insert(newIndex, shape)
    # update bbox index
    self.canvas.updateShapeIndex()
    # boxList update simultaneously
    item = self.BoxList.takeItem(selectedShapeIndex)
    self.BoxList.insertItem(newIndex, item)
    # changes happen
    self.setDirty()
# Callback functions: # Callback functions:
def newShape(self, value=True): def newShape(self, value=True):
...@@ -1453,6 +1583,7 @@ class MainWindow(QMainWindow): ...@@ -1453,6 +1583,7 @@ class MainWindow(QMainWindow):
if self.labelList.count(): if self.labelList.count():
self.labelList.setCurrentItem(self.labelList.item(self.labelList.count() - 1)) self.labelList.setCurrentItem(self.labelList.item(self.labelList.count() - 1))
self.labelList.item(self.labelList.count() - 1).setSelected(True) self.labelList.item(self.labelList.count() - 1).setSelected(True)
self.indexList.item(self.labelList.count() - 1).setSelected(True)
# show file list image count # show file list image count
select_indexes = self.fileListWidget.selectedIndexes() select_indexes = self.fileListWidget.selectedIndexes()
...@@ -1560,6 +1691,7 @@ class MainWindow(QMainWindow): ...@@ -1560,6 +1691,7 @@ class MainWindow(QMainWindow):
settings[SETTING_LAST_OPEN_DIR] = '' settings[SETTING_LAST_OPEN_DIR] = ''
settings[SETTING_PAINT_LABEL] = self.displayLabelOption.isChecked() settings[SETTING_PAINT_LABEL] = self.displayLabelOption.isChecked()
settings[SETTING_PAINT_INDEX] = self.displayIndexOption.isChecked()
settings[SETTING_DRAW_SQUARE] = self.drawSquaresOption.isChecked() settings[SETTING_DRAW_SQUARE] = self.drawSquaresOption.isChecked()
settings.save() settings.save()
try: try:
...@@ -1946,8 +2078,18 @@ class MainWindow(QMainWindow): ...@@ -1946,8 +2078,18 @@ class MainWindow(QMainWindow):
self.labelHist.append(line) self.labelHist.append(line)
def togglePaintLabelsOption(self):
    self.displayIndexOption.setChecked(False)
    for shape in self.canvas.shapes:
        shape.paintLabel = self.displayLabelOption.isChecked()
        shape.paintIdx = self.displayIndexOption.isChecked()
    self.canvas.repaint()

def togglePaintIndexOption(self):
    self.displayLabelOption.setChecked(False)
    for shape in self.canvas.shapes:
        shape.paintLabel = self.displayLabelOption.isChecked()
        shape.paintIdx = self.displayIndexOption.isChecked()
    self.canvas.repaint()
def toogleDrawSquare(self): def toogleDrawSquare(self):
self.canvas.setDrawingShapeToSquare(self.drawSquaresOption.isChecked()) self.canvas.setDrawingShapeToSquare(self.drawSquaresOption.isChecked())
...@@ -2042,7 +2184,7 @@ class MainWindow(QMainWindow): ...@@ -2042,7 +2184,7 @@ class MainWindow(QMainWindow):
self.init_key_list(self.Cachelabel) self.init_key_list(self.Cachelabel)
def reRecognition(self): def reRecognition(self):
img = cv2.imread(self.filePath) img = cv2.imdecode(np.fromfile(self.filePath,dtype=np.uint8),1)
# org_box = [dic['points'] for dic in self.PPlabel[self.getImglabelidx(self.filePath)]] # org_box = [dic['points'] for dic in self.PPlabel[self.getImglabelidx(self.filePath)]]
if self.canvas.shapes: if self.canvas.shapes:
self.result_dic = [] self.result_dic = []
...@@ -2111,7 +2253,7 @@ class MainWindow(QMainWindow): ...@@ -2111,7 +2253,7 @@ class MainWindow(QMainWindow):
QMessageBox.information(self, "Information", "Draw a box!") QMessageBox.information(self, "Information", "Draw a box!")
def singleRerecognition(self): def singleRerecognition(self):
img = cv2.imread(self.filePath) img = cv2.imdecode(np.fromfile(self.filePath,dtype=np.uint8),1)
for shape in self.canvas.selectedShapes: for shape in self.canvas.selectedShapes:
box = [[int(p.x()), int(p.y())] for p in shape.points] box = [[int(p.x()), int(p.y())] for p in shape.points]
if len(box) > 4: if len(box) > 4:
...@@ -2181,12 +2323,14 @@ class MainWindow(QMainWindow): ...@@ -2181,12 +2323,14 @@ class MainWindow(QMainWindow):
self.itemsToShapesbox.clear() # ADD self.itemsToShapesbox.clear() # ADD
self.shapesToItemsbox.clear() self.shapesToItemsbox.clear()
self.labelList.clear() self.labelList.clear()
self.indexList.clear()
self.BoxList.clear() self.BoxList.clear()
self.result_dic = [] self.result_dic = []
self.result_dic_locked = [] self.result_dic_locked = []
shapes = [] shapes = []
result_len = len(region['res']['boxes']) result_len = len(region['res']['boxes'])
order_index = 0
for i in range(result_len): for i in range(result_len):
bbox = np.array(region['res']['boxes'][i]) bbox = np.array(region['res']['boxes'][i])
rec_text = region['res']['rec_res'][i][0] rec_text = region['res']['rec_res'][i][0]
...@@ -2205,6 +2349,8 @@ class MainWindow(QMainWindow): ...@@ -2205,6 +2349,8 @@ class MainWindow(QMainWindow):
x, y, snapped = self.canvas.snapPointToCanvas(x, y) x, y, snapped = self.canvas.snapPointToCanvas(x, y)
shape.addPoint(QPointF(x, y)) shape.addPoint(QPointF(x, y))
shape.difficult = False shape.difficult = False
shape.idx = order_index
order_index += 1
# shape.locked = False # shape.locked = False
shape.close() shape.close()
self.addLabel(shape) self.addLabel(shape)
...@@ -2589,6 +2735,7 @@ class MainWindow(QMainWindow): ...@@ -2589,6 +2735,7 @@ class MainWindow(QMainWindow):
def undoShapeEdit(self): def undoShapeEdit(self):
self.canvas.restoreShape() self.canvas.restoreShape()
self.labelList.clear() self.labelList.clear()
self.indexList.clear()
self.BoxList.clear() self.BoxList.clear()
self.loadShapes(self.canvas.shapes) self.loadShapes(self.canvas.shapes)
self.actions.undo.setEnabled(self.canvas.isShapeRestorable) self.actions.undo.setEnabled(self.canvas.isShapeRestorable)
...@@ -2598,6 +2745,7 @@ class MainWindow(QMainWindow): ...@@ -2598,6 +2745,7 @@ class MainWindow(QMainWindow):
for shape in shapes: for shape in shapes:
self.addLabel(shape) self.addLabel(shape)
self.labelList.clearSelection() self.labelList.clearSelection()
self.indexList.clearSelection()
self._noSelectionSlot = False self._noSelectionSlot = False
self.canvas.loadShapes(shapes, replace=replace) self.canvas.loadShapes(shapes, replace=replace)
print("loadShapes") # 1 print("loadShapes") # 1
......
...@@ -2,7 +2,7 @@ English | [简体中文](README_ch.md) ...@@ -2,7 +2,7 @@ English | [简体中文](README_ch.md)
# PPOCRLabel # PPOCRLabel
PPOCRLabel is a semi-automatic graphic annotation tool suitable for OCR field, with built-in PPOCR model to automatically detect and re-recognize data. It is written in python3 and pyqt5, supporting rectangular box, table and multi-point annotation modes. Annotations can be directly used for the training of PPOCR detection and recognition models. PPOCRLabel is a semi-automatic graphic annotation tool suitable for OCR field, with built-in PP-OCR model to automatically detect and re-recognize data. It is written in python3 and pyqt5, supporting rectangular box, table and multi-point annotation modes. Annotations can be directly used for the training of PP-OCR detection and recognition models.
<img src="./data/gif/steps_en.gif" width="100%"/> <img src="./data/gif/steps_en.gif" width="100%"/>
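For reference, a minimal way to install and launch the tool (assuming the published pip package; when running from source, use `python PPOCRLabel.py` with the same flags) is:

```bash
# install the annotation tool (or build and install the whl as described in the installation docs)
pip3 install PPOCRLabel

# launch with the Chinese UI; use "--lang en" for English
PPOCRLabel --lang ch
```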
...@@ -142,14 +142,18 @@ In PPOCRLabel, complete the text information labeling (text and position), compl ...@@ -142,14 +142,18 @@ In PPOCRLabel, complete the text information labeling (text and position), compl
labeling in the Excel file, the recommended steps are: labeling in the Excel file, the recommended steps are:
1. Table annotation: After opening the table picture, click on the `Table Recognition` button in the upper right corner of PPOCRLabel, which will call the table recognition model in PP-Structure to automatically label the table and pop up Excel at the same time.
2. Change the recognition result: **label each cell** (i.e. the text in a cell is marked as a box). Right-click on the box and then click on `Cell Re-recognition`. You can use the model to automatically recognise the text within a cell.
3. Mark the table structure: for each cell that contains text, **mark it with any identifier (such as `1`) in Excel**, so that the merged-cell structure in Excel is the same as in the original picture.
   > Note: If there are blank cells in the table, you also need to mark them with a bounding box so that the total number of cells is the same as in the image.
4. ***Adjust cell order:*** Click on the menu `View` - `Show Box Number` to display the box ordinal numbers, then drag the entries under the 'Recognition Results' column on the right side of the software interface so that the box numbers are arranged from left to right and top to bottom.
5. Export JSON format annotation: close all Excel files corresponding to table images, then click `File` - `Export table JSON annotation` to obtain the JSON annotation results.
### 2.3 Note ### 2.3 Note
...@@ -219,14 +223,7 @@ PPOCRLabel supports three ways to export Label.txt ...@@ -219,14 +223,7 @@ PPOCRLabel supports three ways to export Label.txt
- Close application export - Close application export
### 3.4 Dataset division
### 3.4 Export Partial Recognition Results
For some data that are difficult to recognize, the recognition results will not be exported by **unchecking** the corresponding tags in the recognition results checkbox. The unchecked recognition result is saved as `True` in the `difficult` variable in the label file `label.txt`.
> *Note: The status of the checkboxes in the recognition results still needs to be saved manually by clicking Save Button.*
### 3.5 Dataset division
- Enter the following command in the terminal to execute the dataset division script: - Enter the following command in the terminal to execute the dataset division script:
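  A typical invocation of the division script looks like the following (the ratio and dataset path are illustrative; adjust them to your own data):

  ```bash
  # split the labeled data into train/val/test subsets at a 6:2:2 ratio
  python gen_ocr_train_val_test.py --trainValTestRatio 6:2:2 --datasetRootPath ../train_data
  ```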
...@@ -255,7 +252,7 @@ For some data that are difficult to recognize, the recognition results will not ...@@ -255,7 +252,7 @@ For some data that are difficult to recognize, the recognition results will not
| ... | ...
``` ```
### 3.6 Error message ### 3.5 Error message
- If paddleocr is installed with whl, it has a higher priority than calling PaddleOCR class with paddleocr.py, which may cause an exception if whl package is not updated. - If paddleocr is installed with whl, it has a higher priority than calling PaddleOCR class with paddleocr.py, which may cause an exception if whl package is not updated.
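  A quick, hedged workaround is to refresh or remove the installed wheel so that the local source tree takes effect:

  ```bash
  # either update the installed package ...
  pip3 install --upgrade paddleocr
  # ... or uninstall it so the local paddleocr.py is used instead
  pip3 uninstall -y paddleocr
  ```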
......
...@@ -7,8 +7,8 @@ PPOCRLabel是一款适用于OCR领域的半自动化图形标注工具,内置P ...@@ -7,8 +7,8 @@ PPOCRLabel是一款适用于OCR领域的半自动化图形标注工具,内置P
<img src="./data/gif/steps.gif" width="100%"/> <img src="./data/gif/steps.gif" width="100%"/>
#### 近期更新 #### 近期更新
- 2022.05:新增表格标注,使用方法见下方`2.2 表格标注`(by [whjdark](https://github.com/peterh0323); [Evezerest](https://github.com/Evezerest)) - 2022.05:**新增表格标注**,使用方法见下方`2.2 表格标注`(by [whjdark](https://github.com/peterh0323); [Evezerest](https://github.com/Evezerest))
- 2022.02:新增关键信息标注、优化标注体验(by [PeterH0323](https://github.com/peterh0323) - 2022.02:**新增关键信息标注**、优化标注体验(by [PeterH0323](https://github.com/peterh0323)
- 新增:使用 `--kie` 进入 KIE 功能,用于打【检测+识别+关键字提取】的标签 - 新增:使用 `--kie` 进入 KIE 功能,用于打【检测+识别+关键字提取】的标签
- 提升用户体验:新增文件与标记数目提示、优化交互、修复gpu使用等问题。 - 提升用户体验:新增文件与标记数目提示、优化交互、修复gpu使用等问题。
- 新增功能:使用 `C``X` 对标记框进行旋转。 - 新增功能:使用 `C``X` 对标记框进行旋转。
...@@ -113,23 +113,29 @@ pip3 install dist/PPOCRLabel-1.0.2-py2.py3-none-any.whl -i https://mirror.baidu. ...@@ -113,23 +113,29 @@ pip3 install dist/PPOCRLabel-1.0.2-py2.py3-none-any.whl -i https://mirror.baidu.
1. 安装与运行:使用上述命令安装与运行程序。 1. 安装与运行:使用上述命令安装与运行程序。
2. 打开文件夹:在菜单栏点击 “文件” - "打开目录" 选择待标记图片的文件夹<sup>[1]</sup>. 2. 打开文件夹:在菜单栏点击 “文件” - "打开目录" 选择待标记图片的文件夹<sup>[1]</sup>.
3. 自动标注:点击 ”自动标注“,使用PPOCR超轻量模型对图片文件名前图片状态<sup>[2]</sup>为 “X” 的图片进行自动标注。 3. 自动标注:点击 ”自动标注“,使用PP-OCR超轻量模型对图片文件名前图片状态<sup>[2]</sup>为 “X” 的图片进行自动标注。
4. 手动标注:点击 “矩形标注”(推荐直接在英文模式下点击键盘中的 “W”),用户可对当前图片中模型未检出的部分进行手动绘制标记框。点击键盘Q,则使用四点标注模式(或点击“编辑” - “四点标注”),用户依次点击4个点后,双击左键表示标注完成。 4. 手动标注:点击 “矩形标注”(推荐直接在英文模式下点击键盘中的 “W”),用户可对当前图片中模型未检出的部分进行手动绘制标记框。点击键盘Q,则使用四点标注模式(或点击“编辑” - “四点标注”),用户依次点击4个点后,双击左键表示标注完成。
5. 标记框绘制完成后,用户点击 “确认”,检测框会先被预分配一个 “待识别” 标签。 5. 标记框绘制完成后,用户点击 “确认”,检测框会先被预分配一个 “待识别” 标签。
6. 重新识别:将图片中的所有检测画绘制/调整完成后,点击 “重新识别”,PPOCR模型会对当前图片中的**所有检测框**重新识别<sup>[3]</sup> 6. 重新识别:将图片中的所有检测画绘制/调整完成后,点击 “重新识别”,PP-OCR模型会对当前图片中的**所有检测框**重新识别<sup>[3]</sup>
7. 内容更改:单击识别结果,对不准确的识别结果进行手动更改。 7. 内容更改:单击识别结果,对不准确的识别结果进行手动更改。
8. **确认标记:点击 “确认”,图片状态切换为 “√”,跳转至下一张。** 8. **确认标记:点击 “确认”,图片状态切换为 “√”,跳转至下一张。**
9. 删除:点击 “删除图像”,图片将会被删除至回收站。 9. 删除:点击 “删除图像”,图片将会被删除至回收站。
10. 导出结果:用户可以通过菜单中“文件-导出标记结果”手动导出,同时也可以点击“文件 - 自动导出标记结果”开启自动导出。手动确认过的标记将会被存放在所打开图片文件夹下的*Label.txt*中。在菜单栏点击 “文件” - "导出识别结果"后,会将此类图片的识别训练数据保存在*crop_img*文件夹下,识别标签保存在*rec_gt.txt*<sup>[4]</sup> 10. 导出结果:用户可以通过菜单中“文件-导出标记结果”手动导出,同时也可以点击“文件 - 自动导出标记结果”开启自动导出。手动确认过的标记将会被存放在所打开图片文件夹下的*Label.txt*中。在菜单栏点击 “文件” - "导出识别结果"后,会将此类图片的识别训练数据保存在*crop_img*文件夹下,识别标签保存在*rec_gt.txt*<sup>[4]</sup>
### 2.2 表格标注 ### 2.2 表格标注
表格标注针对表格的结构化提取,将图片中的表格转换为Excel格式,因此标注时需要配合外部软件打开Excel同时完成。 表格标注针对表格的结构化提取,将图片中的表格转换为Excel格式,因此标注时需要配合外部软件打开Excel同时完成。在PPOCRLabel软件中完成表格中的文字信息标注(文字与位置)、在Excel文件中完成表格结构信息标注,推荐的步骤为:
在PPOCRLabel软件中完成表格中的文字信息标注(文字与位置)、在Excel文件中完成表格结构信息标注,推荐的步骤为:
1. 表格识别:打开表格图片后,点击软件右上角 `表格识别` 按钮,软件调用PP-Structure中的表格识别模型,自动为表格打标签,同时弹出Excel 1. 表格识别:打开表格图片后,点击软件右上角 `表格识别` 按钮,软件调用PP-Structure中的表格识别模型,自动为表格打标签,同时弹出Excel
2. 更改识别结果:**以表格中的单元格为单位增加标注框**(即一个单元格内的文字都标记为一个框)。标注框上鼠标右键后点击 `单元格重识别`
2. 更改标注结果:**以表格中的单元格为单位增加标注框**(即一个单元格内的文字都标记为一个框)。标注框上鼠标右键后点击 `单元格重识别`
可利用模型自动识别单元格内的文字。 可利用模型自动识别单元格内的文字。
3. 标注表格结构:将表格图像中有文字的单元格,**在Excel中标记为任意标识符(如`1`)**,保证Excel中的单元格合并情况与原图相同即可。
4. 导出JSON格式:关闭所有表格图像对应的Excel,点击 `文件`-`导出表格JSON标注` 获得JSON标注结果。 > 注意:如果表格中存在空白单元格,同样需要使用一个标注框将其标出,使得单元格总数与图像中保持一致。
3. **调整单元格顺序:**点击软件`视图-显示框编号` 打开标注框序号,在软件界面右侧拖动 `识别结果` 一栏下的所有结果,使得标注框编号按照从左到右,从上到下的顺序排列
4. 标注表格结构:**在外部Excel软件中,将存在文字的单元格标记为任意标识符(如 `1` )**,保证Excel中的单元格合并情况与原图相同即可(即不需要Excel中的单元格文字与图片中的文字完全相同)
5. 导出JSON格式:关闭所有表格图像对应的Excel,点击 `文件`-`导出表格JSON标注` 获得JSON标注结果。
### 2.3 注意 ### 2.3 注意
...@@ -197,13 +203,7 @@ PPOCRLabel支持三种导出方式: ...@@ -197,13 +203,7 @@ PPOCRLabel支持三种导出方式:
- 关闭应用程序导出 - 关闭应用程序导出
### 3.4 导出部分识别结果 ### 3.4 数据集划分
针对部分难以识别的数据,通过在识别结果的复选框中**取消勾选**相应的标记,其识别结果不会被导出。被取消勾选的识别结果在标记文件 `label.txt` 中的 `difficult` 变量保存为 `True`
> *注意:识别结果中的复选框状态仍需用户手动点击确认后才能保留*
### 3.5 数据集划分
在终端中输入以下命令执行数据集划分脚本: 在终端中输入以下命令执行数据集划分脚本:
...@@ -232,7 +232,7 @@ python gen_ocr_train_val_test.py --trainValTestRatio 6:2:2 --datasetRootPath ../ ...@@ -232,7 +232,7 @@ python gen_ocr_train_val_test.py --trainValTestRatio 6:2:2 --datasetRootPath ../
| ... | ...
``` ```
### 3.6 错误提示 ### 3.5 错误提示
- 如果同时使用whl包安装了paddleocr,其优先级大于通过paddleocr.py调用PaddleOCR类,whl包未更新时会导致程序异常。 - 如果同时使用whl包安装了paddleocr,其优先级大于通过paddleocr.py调用PaddleOCR类,whl包未更新时会导致程序异常。
......
...@@ -314,21 +314,23 @@ class Canvas(QWidget): ...@@ -314,21 +314,23 @@ class Canvas(QWidget):
QApplication.restoreOverrideCursor() # ? QApplication.restoreOverrideCursor() # ?
if self.movingShape and self.hShape:
    if self.hShape in self.shapes:
        index = self.shapes.index(self.hShape)
        if (
            self.shapesBackups[-1][index].points
            != self.shapes[index].points
        ):
            self.storeShapes()
            self.shapeMoved.emit()  # connect to updateBoxlist in PPOCRLabel.py
    self.movingShape = False
def endMove(self, copy=False): def endMove(self, copy=False):
assert self.selectedShapes and self.selectedShapesCopy assert self.selectedShapes and self.selectedShapesCopy
assert len(self.selectedShapesCopy) == len(self.selectedShapes) assert len(self.selectedShapesCopy) == len(self.selectedShapes)
if copy: if copy:
for i, shape in enumerate(self.selectedShapesCopy): for i, shape in enumerate(self.selectedShapesCopy):
shape.idx = len(self.shapes) # add current box index
self.shapes.append(shape) self.shapes.append(shape)
self.selectedShapes[i].selected = False self.selectedShapes[i].selected = False
self.selectedShapes[i] = shape self.selectedShapes[i] = shape
...@@ -524,6 +526,9 @@ class Canvas(QWidget): ...@@ -524,6 +526,9 @@ class Canvas(QWidget):
self.storeShapes() self.storeShapes()
self.selectedShapes = [] self.selectedShapes = []
self.update() self.update()
self.updateShapeIndex()
return deleted_shapes return deleted_shapes
def storeShapes(self): def storeShapes(self):
...@@ -619,6 +624,13 @@ class Canvas(QWidget): ...@@ -619,6 +624,13 @@ class Canvas(QWidget):
pal.setColor(self.backgroundRole(), QColor(232, 232, 232, 255)) pal.setColor(self.backgroundRole(), QColor(232, 232, 232, 255))
self.setPalette(pal) self.setPalette(pal)
# adaptive BBOX label & index font size
if self.pixmap:
    h, w = self.pixmap.size().height(), self.pixmap.size().width()
    fontsize = int(max(h, w) / 96)
    for s in self.shapes:
        s.fontsize = fontsize
p.end() p.end()
def fillDrawing(self): def fillDrawing(self):
...@@ -651,7 +663,8 @@ class Canvas(QWidget): ...@@ -651,7 +663,8 @@ class Canvas(QWidget):
return return
self.current.close()
self.current.idx = len(self.shapes)  # add current box index
self.shapes.append(self.current)
self.current = None self.current = None
self.setHiding(False) self.setHiding(False)
self.newShape.emit() self.newShape.emit()
...@@ -842,6 +855,7 @@ class Canvas(QWidget): ...@@ -842,6 +855,7 @@ class Canvas(QWidget):
self.hVertex = None self.hVertex = None
# self.hEdge = None # self.hEdge = None
self.storeShapes() self.storeShapes()
self.updateShapeIndex()
self.repaint() self.repaint()
def setShapeVisible(self, shape, value): def setShapeVisible(self, shape, value):
...@@ -883,10 +897,16 @@ class Canvas(QWidget): ...@@ -883,10 +897,16 @@ class Canvas(QWidget):
self.selectedShapes = [] self.selectedShapes = []
for shape in self.shapes: for shape in self.shapes:
shape.selected = False shape.selected = False
self.updateShapeIndex()
self.repaint() self.repaint()
@property @property
def isShapeRestorable(self): def isShapeRestorable(self):
if len(self.shapesBackups) < 2: if len(self.shapesBackups) < 2:
return False return False
return True return True
\ No newline at end of file
def updateShapeIndex(self):
    for i in range(len(self.shapes)):
        self.shapes[i].idx = i
    self.update()
\ No newline at end of file
...@@ -21,6 +21,7 @@ SETTING_ADVANCE_MODE = 'advanced' ...@@ -21,6 +21,7 @@ SETTING_ADVANCE_MODE = 'advanced'
SETTING_WIN_STATE = 'window/state' SETTING_WIN_STATE = 'window/state'
SETTING_SAVE_DIR = 'savedir' SETTING_SAVE_DIR = 'savedir'
SETTING_PAINT_LABEL = 'paintlabel' SETTING_PAINT_LABEL = 'paintlabel'
SETTING_PAINT_INDEX = 'paintindex'
SETTING_LAST_OPEN_DIR = 'lastOpenDir' SETTING_LAST_OPEN_DIR = 'lastOpenDir'
SETTING_AUTO_SAVE = 'autosave' SETTING_AUTO_SAVE = 'autosave'
SETTING_SINGLE_CLASS = 'singleclass' SETTING_SINGLE_CLASS = 'singleclass'
......
...@@ -26,4 +26,4 @@ class EditInList(QListWidget): ...@@ -26,4 +26,4 @@ class EditInList(QListWidget):
def leaveEvent(self, event): def leaveEvent(self, event):
# close edit # close edit
for i in range(self.count()): for i in range(self.count()):
self.closePersistentEditor(self.item(i)) self.closePersistentEditor(self.item(i))
\ No newline at end of file
...@@ -46,15 +46,16 @@ class Shape(object): ...@@ -46,15 +46,16 @@ class Shape(object):
point_size = 8 point_size = 8
scale = 1.0 scale = 1.0
def __init__(self, label=None, line_color=None, difficult=False, key_cls="None", paintLabel=False): def __init__(self, label=None, line_color=None, difficult=False, key_cls="None", paintLabel=False, paintIdx=False):
self.label = label self.label = label
self.idx = 0 self.idx = None # bbox order, only for table annotation
self.points = [] self.points = []
self.fill = False self.fill = False
self.selected = False self.selected = False
self.difficult = difficult self.difficult = difficult
self.key_cls = key_cls self.key_cls = key_cls
self.paintLabel = paintLabel self.paintLabel = paintLabel
self.paintIdx = paintIdx
self.locked = False self.locked = False
self.direction = 0 self.direction = 0
self.center = None self.center = None
...@@ -65,6 +66,7 @@ class Shape(object): ...@@ -65,6 +66,7 @@ class Shape(object):
self.NEAR_VERTEX: (4, self.P_ROUND), self.NEAR_VERTEX: (4, self.P_ROUND),
self.MOVE_VERTEX: (1.5, self.P_SQUARE), self.MOVE_VERTEX: (1.5, self.P_SQUARE),
} }
self.fontsize = 8
self._closed = False self._closed = False
...@@ -124,7 +126,7 @@ class Shape(object): ...@@ -124,7 +126,7 @@ class Shape(object):
color = self.select_line_color if self.selected else self.line_color color = self.select_line_color if self.selected else self.line_color
pen = QPen(color) pen = QPen(color)
# Try using integer sizes for smoother drawing(?) # Try using integer sizes for smoother drawing(?)
pen.setWidth(max(1, int(round(2.0 / self.scale)))) # pen.setWidth(max(1, int(round(2.0 / self.scale))))
painter.setPen(pen) painter.setPen(pen)
line_path = QPainterPath() line_path = QPainterPath()
...@@ -155,7 +157,7 @@ class Shape(object): ...@@ -155,7 +157,7 @@ class Shape(object):
min_y = min(min_y, point.y()) min_y = min(min_y, point.y())
if min_x != sys.maxsize and min_y != sys.maxsize: if min_x != sys.maxsize and min_y != sys.maxsize:
font = QFont() font = QFont()
font.setPointSize(8) font.setPointSize(self.fontsize)
font.setBold(True) font.setBold(True)
painter.setFont(font) painter.setFont(font)
if self.label is None: if self.label is None:
...@@ -164,6 +166,25 @@ class Shape(object): ...@@ -164,6 +166,25 @@ class Shape(object):
min_y += MIN_Y_LABEL min_y += MIN_Y_LABEL
painter.drawText(min_x, min_y, self.label) painter.drawText(min_x, min_y, self.label)
# Draw the box index near the top-left corner of the box
if self.paintIdx:
    min_x = sys.maxsize
    min_y = sys.maxsize
    for point in self.points:
        min_x = min(min_x, point.x())
        min_y = min(min_y, point.y())
    if min_x != sys.maxsize and min_y != sys.maxsize:
        font = QFont()
        font.setPointSize(self.fontsize)
        font.setBold(True)
        painter.setFont(font)
        text = ''
        if self.idx != None:
            text = str(self.idx)
        if min_y < MIN_Y_LABEL:
            min_y += MIN_Y_LABEL
        painter.drawText(min_x, min_y, text)
if self.fill: if self.fill:
color = self.select_fill_color if self.selected else self.fill_color color = self.select_fill_color if self.selected else self.fill_color
painter.fillPath(line_path, color) painter.fillPath(line_path, color)
......
...@@ -61,6 +61,7 @@ labels=Labels ...@@ -61,6 +61,7 @@ labels=Labels
autoSaveMode=Auto Save mode autoSaveMode=Auto Save mode
singleClsMode=Single Class Mode singleClsMode=Single Class Mode
displayLabel=Display Labels displayLabel=Display Labels
displayIndex=Display box index
fileList=File List fileList=File List
files=Files files=Files
advancedMode=Advanced Mode advancedMode=Advanced Mode
......
...@@ -61,6 +61,7 @@ labels=标签 ...@@ -61,6 +61,7 @@ labels=标签
autoSaveMode=自动保存模式 autoSaveMode=自动保存模式
singleClsMode=单一类别模式 singleClsMode=单一类别模式
displayLabel=显示类别 displayLabel=显示类别
displayIndex=显示box序号
fileList=文件列表 fileList=文件列表
files=文件 files=文件
advancedMode=专家模式 advancedMode=专家模式
......
...@@ -72,6 +72,7 @@ PaddleOCR support a variety of cutting-edge algorithms related to OCR, and devel ...@@ -72,6 +72,7 @@ PaddleOCR support a variety of cutting-edge algorithms related to OCR, and devel
<img src="https://raw.githubusercontent.com/PaddlePaddle/PaddleOCR/dygraph/doc/joinus.PNG" width = "200" height = "200" /> <img src="https://raw.githubusercontent.com/PaddlePaddle/PaddleOCR/dygraph/doc/joinus.PNG" width = "200" height = "200" />
</div> </div>
<a name="Supported-Chinese-model-list"></a> <a name="Supported-Chinese-model-list"></a>
## PP-OCR Series Model List(Update on September 8th) ## PP-OCR Series Model List(Update on September 8th)
| Model introduction | Model name | Recommended scene | Detection model | Direction classifier | Recognition model | | Model introduction | Model name | Recommended scene | Detection model | Direction classifier | Recognition model |
......
...@@ -71,6 +71,8 @@ PaddleOCR旨在打造一套丰富、领先、且实用的OCR工具库,助力 ...@@ -71,6 +71,8 @@ PaddleOCR旨在打造一套丰富、领先、且实用的OCR工具库,助力
## 《动手学OCR》电子书 ## 《动手学OCR》电子书
- [《动手学OCR》电子书📚](./doc/doc_ch/ocr_book.md) - [《动手学OCR》电子书📚](./doc/doc_ch/ocr_book.md)
## 场景应用
- PaddleOCR场景应用覆盖通用,制造、金融、交通行业的主要OCR垂类应用,在PP-OCR、PP-Structure的通用能力基础之上,以notebook的形式展示利用场景数据微调、模型优化方法、数据增广等内容,为开发者快速落地OCR应用提供示范与启发。详情可查看[README](./applications)
<a name="开源社区"></a> <a name="开源社区"></a>
## 开源社区 ## 开源社区
......
# 场景应用
PaddleOCR场景应用覆盖通用,制造、金融、交通行业的主要OCR垂类应用,在PP-OCR、PP-Structure的通用能力基础之上,以notebook的形式展示利用场景数据微调、模型优化方法、数据增广等内容,为开发者快速落地OCR应用提供示范与启发。
> 如需下载全部垂类模型,可以扫描下方二维码,关注公众号填写问卷后,加入PaddleOCR官方交流群获取20G OCR学习大礼包(内含《动手学OCR》电子书、课程回放视频、前沿论文等重磅资料)
<div align="center">
<img src="https://ai-studio-static-online.cdn.bcebos.com/dd721099bd50478f9d5fb13d8dd00fad69c22d6848244fd3a1d3980d7fefc63e" width = "150" height = "150" />
</div>
> 如果您是企业开发者且未在下述场景中找到合适的方案,可以填写[OCR应用合作调研问卷](https://paddle.wjx.cn/vj/QwF7GKw.aspx),免费与官方团队展开不同层次的合作,包括但不限于问题抽象、确定技术方案、项目答疑、共同研发等。如果您已经使用PaddleOCR落地项目,也可以填写此问卷,与飞桨平台共同宣传推广,提升企业技术品宣。期待您的提交!
## 通用
| 类别 | 亮点 | 类别 | 亮点 |
| ------------------------------------------------- | -------- | ---------- | ------------ |
| [高精度中文识别模型SVTR](./高精度中文识别模型.md) | 新增模型 | 手写体识别 | 新增字形支持 |
## 制造
| 类别 | 亮点 | 类别 | 亮点 |
| ------------------------------------------------------------ | ------------------------------ | ------------------------------------------- | -------------------- |
| [数码管识别](./光功率计数码管字符识别/光功率计数码管字符识别.md) | 数码管数据合成、漏识别调优 | 电表识别 | 大分辨率图像检测调优 |
| [液晶屏读数识别](./液晶屏读数识别.md) | 检测模型蒸馏、Serving部署 | [PCB文字识别](./PCB字符识别/PCB字符识别.md) | 小尺寸文本检测与识别 |
| [包装生产日期](./包装生产日期识别.md) | 点阵字符合成、过曝过暗文字识别 | 液晶屏缺陷检测 | 非文字字符识别 |
## 金融
| 类别 | 亮点 | 类别 | 亮点 |
| ------------------------------ | ------------------------ | ------------ | --------------------- |
| [表单VQA](./多模态表单识别.md) | 多模态通用表单结构化提取 | 通用卡证识别 | 通用结构化提取 |
| 增值税发票 | 敬请期待 | 身份证识别 | 结构化提取、图像阴影 |
| 印章检测与识别 | 端到端弯曲文本识别 | 合同比对 | 密集文本检测、NLP串联 |
## 交通
| 类别 | 亮点 | 类别 | 亮点 |
| ------------------------------- | ------------------------------ | ---------- | -------- |
| [车牌识别](./轻量级车牌识别.md) | 多角度图像、轻量模型、端侧部署 | 快递单识别 | 敬请期待 |
| 驾驶证/行驶证识别 | 敬请期待 | | |
\ No newline at end of file
# 一种基于PaddleOCR的产品包装生产日期识别模型
- [1. 项目介绍](#1-项目介绍)
- [2. 环境搭建](#2-环境搭建)
- [3. 数据准备](#3-数据准备)
- [4. 直接使用PP-OCRv3模型评估](#4-直接使用PPOCRv3模型评估)
- [5. 基于合成数据finetune](#5-基于合成数据finetune)
- [5.1 Text Renderer数据合成方法](#51-TextRenderer数据合成方法)
- [5.1.1 下载Text Renderer代码](#511-下载TextRenderer代码)
- [5.1.2 准备背景图片](#512-准备背景图片)
- [5.1.3 准备语料](#513-准备语料)
- [5.1.4 下载字体](#514-下载字体)
- [5.1.5 运行数据合成命令](#515-运行数据合成命令)
- [5.2 模型训练](#52-模型训练)
- [6. 基于真实数据finetune](#6-基于真实数据finetune)
- [6.1 python爬虫获取数据](#61-python爬虫获取数据)
- [6.2 数据挖掘](#62-数据挖掘)
- [6.3 模型训练](#63-模型训练)
- [7. 基于合成+真实数据finetune](#7-基于合成+真实数据finetune)
## 1. 项目介绍
产品包装生产日期是计算机视觉图像识别技术在工业场景中的一种应用。产品包装生产日期识别技术要求能够将产品生产日期从复杂背景中提取并识别出来,在物流管理、物资管理中得到广泛应用。
![](https://ai-studio-static-online.cdn.bcebos.com/d9e0533cc1df47ffa3bbe99de9e42639a3ebfa5bce834bafb1ca4574bf9db684)
- 项目难点
1. 没有训练数据
2. 图像质量层次不齐: 角度倾斜、图片模糊、光照不足、过曝等问题严重
针对以上问题, 本例选用PP-OCRv3这一开源超轻量OCR系统进行包装产品生产日期识别系统的开发。直接使用PP-OCRv3进行评估的精度为62.99%。为提升识别精度,我们首先使用数据合成工具合成了3k数据,基于这部分数据进行finetune,识别精度提升至73.66%。由于合成数据与真实数据之间的分布存在差异,为进一步提升精度,我们使用网络爬虫配合数据挖掘策略得到了1k带标签的真实数据,基于真实数据finetune的精度为71.33%。最后,我们综合使用合成数据和真实数据进行finetune,将识别精度提升至86.99%。各策略的精度提升效果如下:
| 策略 | 精度|
| :--------------- | :-------- |
| PP-OCRv3评估 | 62.99|
| 合成数据finetune | 73.66|
| 真实数据finetune | 71.33|
| 真实+合成数据finetune | 86.99|
AIStudio项目链接: [一种基于PaddleOCR的包装生产日期识别方法](https://aistudio.baidu.com/aistudio/projectdetail/4287736)
## 2. 环境搭建
本任务基于Aistudio完成, 具体环境如下:
- 操作系统: Linux
- PaddlePaddle: 2.3
- PaddleOCR: Release/2.5
- text_renderer: master
下载PaddlleOCR代码并安装依赖库:
```bash
git clone -b dygraph https://gitee.com/paddlepaddle/PaddleOCR
# 安装依赖库
cd PaddleOCR
pip install -r PaddleOCR/requirements.txt
```
## 3. 数据准备
本项目使用人工预标注的300张图像作为测试集。
部分数据示例如下:
![](https://ai-studio-static-online.cdn.bcebos.com/39ff30e0ab0442579712255e6a9ea6b5271169c98e624e6eb2b8781f003bfea0)
标签文件格式如下:
```txt
数据路径 标签(中间以制表符分隔)
```
|数据集类型|数量|
|---|---|
|测试集| 300|
数据集[下载链接](https://aistudio.baidu.com/aistudio/datasetdetail/149770),下载后可以通过下方命令解压:
```bash
tar -xvf data.tar
mv data ${PaddleOCR_root}
```
数据解压后的文件结构如下:
```shell
PaddleOCR
├── data
│ ├── mining_images # 挖掘的真实数据示例
│ ├── mining_train.list # 挖掘的真实数据文件列表
│ ├── render_images # 合成数据示例
│ ├── render_train.list # 合成数据文件列表
│ ├── val # 测试集数据
│ └── val.list # 测试集数据文件列表
| ├── bg # 合成数据所需背景图像
│ └── corpus # 合成数据所需语料
```
## 4. 直接使用PP-OCRv3模型评估
准备好测试数据后,可以使用PaddleOCR的PP-OCRv3模型进行识别。
- 下载预训练模型
首先需要下载PP-OCR v3中英文识别模型文件,下载链接可以在https://github.com/PaddlePaddle/PaddleOCR/blob/release/2.5/doc/doc_ch/ppocr_introduction.md#6 获取,下载命令:
```bash
cd ${PaddleOCR_root}
mkdir ckpt
wget -nc -P ckpt https://paddleocr.bj.bcebos.com/PP-OCRv3/chinese/ch_PP-OCRv3_rec_train.tar
pushd ckpt/
tar -xvf ch_PP-OCRv3_rec_train.tar
popd
```
- 模型评估
使用以下命令进行PP-OCRv3评估:
```bash
python tools/eval.py -c configs/rec/PP-OCRv3/ch_PP-OCRv3_rec_distillation.yml \
-o Global.checkpoints=ckpt/ch_PP-OCRv3_rec_train/best_accuracy \
Eval.dataset.data_dir=./data \
Eval.dataset.label_file_list=["./data/val.list"]
```
其中各参数含义如下:
```bash
-c: 指定使用的配置文件,ch_PP-OCRv3_rec_distillation.yml对应于OCRv3识别模型。
-o: 覆盖配置文件中参数
Global.checkpoints: 指定评估使用的模型文件路径
Eval.dataset.data_dir: 指定评估数据集路径
Eval.dataset.label_file_list: 指定评估数据集文件列表
```
## 5. 基于合成数据finetune
### 5.1 Text Renderer数据合成方法
#### 5.1.1 下载Text Renderer代码
首先从github或gitee下载Text Renderer代码,并安装相关依赖。
```bash
git clone https://gitee.com/wowowoll/text_renderer.git
# 安装依赖库
cd text_renderer
pip install -r requirements.txt
```
使用text renderer合成数据之前需要准备好背景图片、语料以及字体库,下面将逐一介绍各个步骤。
#### 5.1.2 准备背景图片
观察日常生活中常见的包装生产日期图片,我们可以发现其背景相对简单。为此我们可以从网上找一下图片,截取部分图像块作为背景图像。
本项目已准备了部分图像作为背景图片,在第3部分完成数据准备后,可以得到我们准备好的背景图像,示例如下:
![](https://ai-studio-static-online.cdn.bcebos.com/456ae2acb27d4a94896c478812aee0bc3551c703d7bd40c9be4dc983c7b3fc8a)
背景图像存放于如下位置:
```shell
PaddleOCR
├── data
| ├── bg # 合成数据所需背景图像
```
#### 5.1.3 准备语料
观察测试集生产日期图像,我们可以知道如下数据有如下特点:
1. 由年月日组成,中间可能以“/”、“-”、“:”、“.”或者空格间隔,也可能以汉字年月日分隔
2. 有些生产日期包含在产品批号中,此时可能包含具体时间、英文字母或数字标识
基于以上两点,我们编写语料生成脚本:
```python
import random
from random import choice
import os
cropus_num = 2000 #设置语料数量
def get_cropus(f):
# 随机生成年份
year = random.randint(0, 22)
# 随机生成月份
month = random.randint(1, 12)
# 随机生成日期
day_dict = {31: [1,3,5,7,8,10,12], 30: [4,6,9,11], 28: [2]}
for item in day_dict:
if month in day_dict[item]:
day = random.randint(0, item)
# 随机生成小时
hours = random.randint(0, 24)
# 随机生成分钟
minute = random.randint(0, 60)
# 随机生成秒数
second = random.randint(0, 60)
# 随机生成产品标识字符
length = random.randint(0, 6)
file_id = []
flag = 0
my_dict = [i for i in range(48,58)] + [j for j in range(40, 42)] + [k for k in range(65,90)] # 数字 + 括号 + 大写字母
for i in range(1, length):
if flag:
if i == flag+2: #括号匹配
file_id.append(')')
flag = 0
continue
sel = choice(my_dict)
if sel == 41:
continue
if sel == 40:
if i == 1 or i > length-3:
continue
flag = i
my_ascii = chr(sel)
file_id.append(my_ascii)
file_id_str = ''.join(file_id)
#随机生成产品标识字符
file_id2 = random.randint(0, 9)
rad = random.random()
if rad < 0.3:
f.write('20{:02d}{:02d}{:02d} {}'.format(year, month, day, file_id_str))
elif 0.3 < rad < 0.5:
f.write('20{:02d}年{:02d}月{:02d}日'.format(year, month, day))
elif 0.5 < rad < 0.7:
f.write('20{:02d}/{:02d}/{:02d}'.format(year, month, day))
elif 0.7 < rad < 0.8:
f.write('20{:02d}-{:02d}-{:02d}'.format(year, month, day))
elif 0.8 < rad < 0.9:
f.write('20{:02d}.{:02d}.{:02d}'.format(year, month, day))
else:
f.write('{:02d}:{:02d}:{:02d} {:02d}'.format(hours, minute, second, file_id2))
if __name__ == "__main__":
file_path = '/home/aistudio/text_renderer/my_data/cropus'
if not os.path.exists(file_path):
os.makedirs(file_path)
file_name = os.path.join(file_path, 'books.txt')
f = open(file_name, 'w')
for i in range(cropus_num):
get_cropus(f)
if i < cropus_num-1:
f.write('\n')
f.close()
```
本项目已准备了部分语料,在第3部分完成数据准备后,可以得到我们准备好的语料库,默认位置如下:
```shell
PaddleOCR
├── data
│ └── corpus #合成数据所需语料
```
#### 5.1.4 下载字体
观察包装生产日期,我们可以发现其使用的字体为点阵体。字体可以在如下网址下载:
https://www.fonts.net.cn/fonts-en/tag-dianzhen-1.html
本项目已准备了部分字体,在第3部分完成数据准备后,可以得到我们准备好的字体,默认位置如下:
```shell
PaddleOCR
├── data
│ └── fonts #合成数据所需字体
```
下载好字体后,还需要在list文件中指定字体文件存放路径,脚本如下:
```bash
cd text_renderer/my_data/
touch fonts.list
ls /home/aistudio/PaddleOCR/data/fonts/* > fonts.list
```
#### 5.1.5 运行数据合成命令
完成数据准备后,my_data文件结构如下:
```shell
my_data/
├── cropus
│ └── books.txt #语料库
├── eng.txt #字符列表
└── fonts.list #字体列表
```
在运行合成数据命令之前,还有两处细节需要手动修改:
1. 将默认配置文件`text_renderer/configs/default.yaml`中第9行enable的值设为`true`,即允许合成彩色图像。否则合成的都是灰度图。
```yaml
# color boundary is in R,G,B format
font_color:
+ enable: true #false
```
2.`text_renderer/textrenderer/renderer.py`第184行作如下修改,取消padding。否则图片两端会有一些空白。
```python
padding = random.randint(s_bbox_width // 10, s_bbox_width // 8) #修改前
padding = 0 #修改后
```
运行数据合成命令:
```bash
cd /home/aistudio/text_renderer/
python main.py --num_img=3000 \
--fonts_list='./my_data/fonts.list' \
--corpus_dir "./my_data/cropus" \
--corpus_mode "list" \
--bg_dir "/home/aistudio/PaddleOCR/data/bg/" \
--img_width 0
```
合成好的数据默认保存在`text_renderer/output`目录下,可进入该目录查看合成的数据。
合成数据示例如下
![](https://ai-studio-static-online.cdn.bcebos.com/d686a48d465a43d09fbee51924fdca42ee21c50e676646da8559fb9967b94185)
数据合成好后,还需要生成如下格式的训练所需的标注文件,
```
图像路径 标签
```
使用如下脚本即可生成标注文件:
```python
import random
abspath = '/home/aistudio/text_renderer/output/default/'
#标注文件生成路径
fout = open('./render_train.list', 'w', encoding='utf-8')
with open('./output/default/tmp_labels.txt','r') as f:
lines = f.readlines()
for item in lines:
label = item[9:]
filename = item[:8] + '.jpg'
fout.write(abspath + filename + '\t' + label)
fout.close()
```
经过以上步骤,我们便完成了包装生产日期数据合成。
数据位于`text_renderer/output`,标注文件位于`text_renderer/render_train.list`
本项目提供了生成好的数据供大家体验,完成步骤3的数据准备后,可得数据路径位于:
```shell
PaddleOCR
├── data
│ ├── render_images # 合成数据示例
│ ├── render_train.list #合成数据文件列表
```
### 5.2 模型训练
准备好合成数据后,我们可以使用以下命令,利用合成数据进行finetune:
```bash
cd ${PaddleOCR_root}
python tools/train.py -c configs/rec/PP-OCRv3/ch_PP-OCRv3_rec_distillation.yml \
-o Global.pretrained_model=./ckpt/ch_PP-OCRv3_rec_train/best_accuracy \
Global.epoch_num=20 \
Global.eval_batch_step='[0, 20]' \
Train.dataset.data_dir=./data \
Train.dataset.label_file_list=['./data/render_train.list'] \
Train.loader.batch_size_per_card=64 \
Eval.dataset.data_dir=./data \
Eval.dataset.label_file_list=["./data/val.list"] \
Eval.loader.batch_size_per_card=64
```
其中各参数含义如下:
```txt
-c: 指定使用的配置文件,ch_PP-OCRv3_rec_distillation.yml对应于OCRv3识别模型。
-o: 覆盖配置文件中参数
Global.pretrained_model: 指定finetune使用的预训练模型
Global.epoch_num: 指定训练的epoch数
Global.eval_batch_step: 间隔多少step做一次评估
Train.dataset.data_dir: 训练数据集路径
Train.dataset.label_file_list: 训练集文件列表
Train.loader.batch_size_per_card: 训练单卡batch size
Eval.dataset.data_dir: 评估数据集路径
Eval.dataset.label_file_list: 评估数据集文件列表
Eval.loader.batch_size_per_card: 评估单卡batch size
```
## 6. 基于真实数据finetune
使用合成数据finetune能提升我们模型的识别精度,但由于合成数据和真实数据之间的分布可能有一定差异,因此作用有限。为进一步提高识别精度,本节介绍如何挖掘真实数据进行模型finetune。
数据挖掘的整体思路如下:
1. 使用python爬虫从网上获取大量无标签数据
2. 使用模型从大量无标签数据中构建出有效训练集
### 6.1 python爬虫获取数据
- 推荐使用[爬虫工具](https://github.com/Joeclinton1/google-images-download)获取无标签图片。
图片获取后,可按如下目录格式组织:
```txt
sprider
├── file.list
├── data
│ ├── 00000.jpg
│ ├── 00001.jpg
...
```
### 6.2 数据挖掘
我们使用PaddleOCR对获取到的图片进行挖掘,具体步骤如下:
1. 使用 PP-OCRv3检测模型+svtr-tiny识别模型,对每张图片进行预测。
2. 使用数据挖掘策略,得到有效图片。
3. 将有效图片对应的图像区域和标签提取出来,构建训练集。
首先下载预训练模型,PP-OCRv3检测模型下载链接:https://paddleocr.bj.bcebos.com/PP-OCRv3/chinese/ch_PP-OCRv3_det_infer.tar
如需获取svtr-tiny高精度中文识别预训练模型,请扫码填写问卷,加入PaddleOCR官方交流群获取全部OCR垂类模型下载链接、《动手学OCR》电子书等全套OCR学习资料🎁
<div align="left">
<img src="https://ai-studio-static-online.cdn.bcebos.com/dd721099bd50478f9d5fb13d8dd00fad69c22d6848244fd3a1d3980d7fefc63e" width = "150" height = "150" />
</div>
完成下载后,可将模型存储于如下位置:
```shell
PaddleOCR
├── data
│ ├── rec_vit_sub_64_363_all/ # svtr_tiny高精度识别模型
```
```bash
# 下载解压PP-OCRv3检测模型
cd ${PaddleOCR_root}
wget -nc -P ckpt https://paddleocr.bj.bcebos.com/PP-OCRv3/chinese/ch_PP-OCRv3_det_infer.tar
pushd ckpt
tar -xvf ch_PP-OCRv3_det_infer.tar
popd
```
在使用PPOCRv3检测模型+svtr-tiny识别模型进行预测之前,有如下两处细节需要手动修改:
1.`tools/infer/predict_rec.py`中第110行`imgW`修改为`320`
```python
#imgW = int((imgH * max_wh_ratio))
imgW = 320
```
2.`tools/infer/predict_system.py`第169行添加如下一行,将预测分数也写入结果文件中。
```python
"scores": rec_res[idx][1],
```
模型预测命令:
```bash
python tools/infer/predict_system.py \
--image_dir="/home/aistudio/sprider/data" \
--det_model_dir="./ckpt/ch_PP-OCRv3_det_infer/" \
--rec_model_dir="/home/aistudio/PaddleOCR/data/rec_vit_sub_64_363_all/" \
--rec_image_shape="3,32,320"
```
获得预测结果后,我们使用数据挖掘策略得到有效图片。具体挖掘策略如下:
1. 预测置信度高于95%
2. 识别结果包含字符‘20’,即年份
3. 没有中文,或者有中文并且‘日’和'月'同时在识别结果中
```python
# 获取有效预测
import json
import re
zh_pattern = re.compile(u'[\u4e00-\u9fa5]+') #正则表达式,筛选字符是否包含中文
file_path = '/home/aistudio/PaddleOCR/inference_results/system_results.txt'
out_path = '/home/aistudio/PaddleOCR/selected_results.txt'
f_out = open(out_path, 'w')
with open(file_path, "r", encoding='utf-8') as fin:
lines = fin.readlines()
for line in lines:
flag = False
# 读取文件内容
file_name, json_file = line.strip().split('\t')
preds = json.loads(json_file)
res = []
for item in preds:
transcription = item['transcription'] #获取识别结果
scores = item['scores'] #获取识别得分
# 挖掘策略
if scores > 0.95:
if '20' in transcription and len(transcription) > 4 and len(transcription) < 12:
word = transcription
if not(zh_pattern.search(word) and ('日' not in word or '月' not in word)):
flag = True
res.append(item)
save_pred = file_name + "\t" + json.dumps(
res, ensure_ascii=False) + "\n"
if flag ==True:
f_out.write(save_pred)
f_out.close()
```
然后将有效预测对应的图像区域和标签提取出来,构建训练集。具体实现脚本如下:
```python
import os
import cv2
import json
import numpy as np
PATH = '/home/aistudio/PaddleOCR/inference_results/' #数据原始路径
SAVE_PATH = '/home/aistudio/mining_images/' #裁剪后数据保存路径
file_list = '/home/aistudio/PaddleOCR/selected_results.txt' #数据预测结果
label_file = '/home/aistudio/mining_images/mining_train.list' #输出真实数据训练集标签list
if not os.path.exists(SAVE_PATH):
os.mkdir(SAVE_PATH)
f_label = open(label_file, 'w')
def get_rotate_crop_image(img, points):
"""
根据检测结果points,从输入图像img中裁剪出相应的区域
"""
assert len(points) == 4, "shape of points must be 4*2"
img_crop_width = int(
max(
np.linalg.norm(points[0] - points[1]),
np.linalg.norm(points[2] - points[3])))
img_crop_height = int(
max(
np.linalg.norm(points[0] - points[3]),
np.linalg.norm(points[1] - points[2])))
pts_std = np.float32([[0, 0], [img_crop_width, 0],
[img_crop_width, img_crop_height],
[0, img_crop_height]])
M = cv2.getPerspectiveTransform(points, pts_std)
# 形变或倾斜,会做透视变换,reshape成矩形
dst_img = cv2.warpPerspective(
img,
M, (img_crop_width, img_crop_height),
borderMode=cv2.BORDER_REPLICATE,
flags=cv2.INTER_CUBIC)
dst_img_height, dst_img_width = dst_img.shape[0:2]
if dst_img_height * 1.0 / dst_img_width >= 1.5:
dst_img = np.rot90(dst_img)
return dst_img
def crop_and_get_filelist(file_list):
with open(file_list, "r", encoding='utf-8') as fin:
lines = fin.readlines()
img_num = 0
for line in lines:
img_name, json_file = line.strip().split('\t')
preds = json.loads(json_file)
for item in preds:
transcription = item['transcription']
points = item['points']
points = np.array(points).astype('float32')
#print('processing {}...'.format(img_name))
img = cv2.imread(PATH+img_name)
dst_img = get_rotate_crop_image(img, points)
h, w, c = dst_img.shape
newWidth = int((32. / h) * w)
newImg = cv2.resize(dst_img, (newWidth, 32))
new_img_name = '{:05d}.jpg'.format(img_num)
cv2.imwrite(SAVE_PATH+new_img_name, dst_img)
f_label.write(SAVE_PATH+new_img_name+'\t'+transcription+'\n')
img_num += 1
crop_and_get_filelist(file_list)
f_label.close()
```
### 6.3 模型训练
通过数据挖掘,我们得到了真实场景数据和对应的标签。接下来使用真实数据finetune,观察精度提升效果。
利用真实数据进行finetune:
```bash
cd ${PaddleOCR_root}
python tools/train.py -c configs/rec/PP-OCRv3/ch_PP-OCRv3_rec_distillation.yml \
-o Global.pretrained_model=./ckpt/ch_PP-OCRv3_rec_train/best_accuracy \
Global.epoch_num=20 \
Global.eval_batch_step='[0, 20]' \
Train.dataset.data_dir=./data \
Train.dataset.label_file_list=['./data/mining_train.list'] \
Train.loader.batch_size_per_card=64 \
Eval.dataset.data_dir=./data \
Eval.dataset.label_file_list=["./data/val.list"] \
Eval.loader.batch_size_per_card=64
```
各参数含义参考第5部分合成数据finetune,只需要对训练数据路径做相应的修改:
```txt
Train.dataset.data_dir: 训练数据集路径
Train.dataset.label_file_list: 训练集文件列表
```
示例使用我们提供的真实数据进行finetune,如想换成自己的数据,只需要相应的修改`Train.dataset.data_dir``Train.dataset.label_file_list`参数即可。
由于数据量不大,这里仅训练20个epoch即可。训练完成后,可以得到合成数据finetune后的精度为best acc=**71.33%**
由于数量比较少,精度会比合成数据finetue的略低。
## 7. Finetune with Synthetic + Real Data
To further improve model accuracy, we finetune using the synthetic data and the mined real data together.
For the combined synthetic + real data finetune, the parameters again have the same meaning as in the earlier synthetic-data finetune section; only the training data paths need to be changed:
```txt
Train.dataset.data_dir: path to the training dataset
Train.dataset.label_file_list: list of training label files
```
Generate the training list file:
```bash
# concatenate the synthetic and mined real training lists
cat /home/aistudio/PaddleOCR/data/render_train.list /home/aistudio/PaddleOCR/data/mining_train.list > /home/aistudio/PaddleOCR/data/render_mining_train.list
```
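If you prefer to build the combined list in Python (for instance to shuffle synthetic and real samples together), a minimal equivalent sketch using the same paths as the command above (the shuffling step is an optional addition):
```python
import random

src_lists = [
    '/home/aistudio/PaddleOCR/data/render_train.list',
    '/home/aistudio/PaddleOCR/data/mining_train.list',
]
dst_list = '/home/aistudio/PaddleOCR/data/render_mining_train.list'

samples = []
for path in src_lists:
    with open(path, 'r', encoding='utf-8') as fin:
        samples.extend(fin.readlines())

random.shuffle(samples)  # optional: mix synthetic and real samples

with open(dst_list, 'w', encoding='utf-8') as fout:
    fout.writelines(samples)
```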
Start training:
```bash
cd ${PaddleOCR_root}
python tools/train.py -c configs/rec/PP-OCRv3/ch_PP-OCRv3_rec_distillation.yml \
-o Global.pretrained_model=./ckpt/ch_PP-OCRv3_rec_train/best_accuracy \
Global.epoch_num=40 \
Global.eval_batch_step='[0, 20]' \
Train.dataset.data_dir=./data \
Train.dataset.label_file_list=['./data/render_mining_train.list'] \
Train.loader.batch_size_per_card=64 \
Eval.dataset.data_dir=./data \
Eval.dataset.label_file_list=["./data/val.list"] \
Eval.loader.batch_size_per_card=64
```
The example finetunes on the real + synthetic data we provide; to use your own data, simply adjust the Train.dataset.data_dir and Train.dataset.label_file_list parameters accordingly.
Because the dataset is small, training for only 40 epochs is sufficient. After training, the model finetuned on synthetic + real data reaches best acc = **86.99%**.
Compared with the original PP-OCRv3 recognition accuracy of 62.99%, finetuning on synthetic + real data improves recognition accuracy by 24 percentage points.
To obtain the trained model, you can likewise scan the QR code above to download it; place the downloaded or self-trained model in the corresponding directory to run model inference.
For inference and deployment, refer to the repo documentation: https://github.com/PaddlePaddle/PaddleOCR/blob/release/2.5/deploy/README_ch.md
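For a quick check without the full deployment pipeline, the exported recognition model can also be loaded through the `paddleocr` pip package; a minimal sketch, assuming the package is installed and with `rec_model_dir` and the sample image being placeholder paths you should replace with your own:
```python
# pip install paddleocr   (assumed to be installed)
from paddleocr import PaddleOCR

# rec_model_dir is an assumed export directory; point it at your own exported inference model
ocr = PaddleOCR(use_angle_cls=False,
                lang='ch',
                rec_model_dir='./inference/ch_PP-OCRv3_rec_finetune/')

result = ocr.ocr('./data/val_images/demo.jpg')  # assumed sample image path
for line in result:
    print(line)  # each entry contains the detected box and the (text, score) pair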
...@@ -28,7 +28,7 @@ Architecture: ...@@ -28,7 +28,7 @@ Architecture:
algorithm: DB algorithm: DB
Transform: Transform:
Backbone: Backbone:
name: ResNet name: ResNet_vd
layers: 18 layers: 18
Neck: Neck:
name: DBFPN name: DBFPN
......
...@@ -45,7 +45,7 @@ Architecture: ...@@ -45,7 +45,7 @@ Architecture:
algorithm: DB algorithm: DB
Transform: Transform:
Backbone: Backbone:
name: ResNet name: ResNet_vd
layers: 18 layers: 18
Neck: Neck:
name: DBFPN name: DBFPN
......
...@@ -65,7 +65,7 @@ Loss: ...@@ -65,7 +65,7 @@ Loss:
- ["Student", "Teacher"] - ["Student", "Teacher"]
maps_name: "thrink_maps" maps_name: "thrink_maps"
weight: 1.0 weight: 1.0
act: "softmax" # act: None
model_name_pairs: ["Student", "Teacher"] model_name_pairs: ["Student", "Teacher"]
key: maps key: maps
- DistillationDBLoss: - DistillationDBLoss:
......
...@@ -61,7 +61,7 @@ Architecture: ...@@ -61,7 +61,7 @@ Architecture:
model_type: det model_type: det
algorithm: DB algorithm: DB
Backbone: Backbone:
name: ResNet name: ResNet_vd
in_channels: 3 in_channels: 3
layers: 50 layers: 50
Neck: Neck:
......
...@@ -25,7 +25,7 @@ Architecture: ...@@ -25,7 +25,7 @@ Architecture:
model_type: det model_type: det
algorithm: DB algorithm: DB
Backbone: Backbone:
name: ResNet name: ResNet_vd
in_channels: 3 in_channels: 3
layers: 50 layers: 50
Neck: Neck:
...@@ -40,7 +40,7 @@ Architecture: ...@@ -40,7 +40,7 @@ Architecture:
model_type: det model_type: det
algorithm: DB algorithm: DB
Backbone: Backbone:
name: ResNet name: ResNet_vd
in_channels: 3 in_channels: 3
layers: 50 layers: 50
Neck: Neck:
...@@ -60,7 +60,7 @@ Loss: ...@@ -60,7 +60,7 @@ Loss:
- ["Student", "Student2"] - ["Student", "Student2"]
maps_name: "thrink_maps" maps_name: "thrink_maps"
weight: 1.0 weight: 1.0
act: "softmax" # act: None
model_name_pairs: ["Student", "Student2"] model_name_pairs: ["Student", "Student2"]
key: maps key: maps
- DistillationDBLoss: - DistillationDBLoss:
......
...@@ -20,7 +20,7 @@ Architecture: ...@@ -20,7 +20,7 @@ Architecture:
algorithm: DB algorithm: DB
Transform: Transform:
Backbone: Backbone:
name: ResNet name: ResNet_vd
layers: 18 layers: 18
disable_se: True disable_se: True
Neck: Neck:
......
...@@ -101,7 +101,7 @@ Train: ...@@ -101,7 +101,7 @@ Train:
drop_last: False drop_last: False
batch_size_per_card: 16 batch_size_per_card: 16
num_workers: 8 num_workers: 8
use_shared_memory: False use_shared_memory: True
Eval: Eval:
dataset: dataset:
...@@ -129,4 +129,4 @@ Eval: ...@@ -129,4 +129,4 @@ Eval:
drop_last: False drop_last: False
batch_size_per_card: 1 # must be 1 batch_size_per_card: 1 # must be 1
num_workers: 8 num_workers: 8
use_shared_memory: False use_shared_memory: True
Global:
debug: false
use_gpu: true
epoch_num: 1000
log_smooth_window: 20
print_batch_step: 10
save_model_dir: ./output/det_r50_icdar15/
save_epoch_step: 200
eval_batch_step:
- 0
- 2000
cal_metric_during_train: false
pretrained_model: ./pretrain_models/ResNet50_dcn_asf_synthtext_pretrained
checkpoints: null
save_inference_dir: null
use_visualdl: false
infer_img: doc/imgs_en/img_10.jpg
save_res_path: ./checkpoints/det_db/predicts_db.txt
Architecture:
model_type: det
algorithm: DB++
Transform: null
Backbone:
name: ResNet
layers: 50
dcn_stage: [False, True, True, True]
Neck:
name: DBFPN
out_channels: 256
use_asf: True
Head:
name: DBHead
k: 50
Loss:
name: DBLoss
balance_loss: true
main_loss_type: BCELoss
alpha: 5
beta: 10
ohem_ratio: 3
Optimizer:
name: Momentum
momentum: 0.9
lr:
name: DecayLearningRate
learning_rate: 0.007
epochs: 1000
factor: 0.9
end_lr: 0
weight_decay: 0.0001
PostProcess:
name: DBPostProcess
thresh: 0.3
box_thresh: 0.6
max_candidates: 1000
unclip_ratio: 1.5
Metric:
name: DetMetric
main_indicator: hmean
Train:
dataset:
name: SimpleDataSet
data_dir: ./train_data/icdar2015/text_localization/
label_file_list:
- ./train_data/icdar2015/text_localization/train_icdar2015_label.txt
ratio_list:
- 1.0
transforms:
- DecodeImage:
img_mode: BGR
channel_first: false
- DetLabelEncode: null
- IaaAugment:
augmenter_args:
- type: Fliplr
args:
p: 0.5
- type: Affine
args:
rotate:
- -10
- 10
- type: Resize
args:
size:
- 0.5
- 3
- EastRandomCropData:
size:
- 640
- 640
max_tries: 10
keep_ratio: true
- MakeShrinkMap:
shrink_ratio: 0.4
min_text_size: 8
- MakeBorderMap:
shrink_ratio: 0.4
thresh_min: 0.3
thresh_max: 0.7
- NormalizeImage:
scale: 1./255.
mean:
- 0.48109378172549
- 0.45752457890196
- 0.40787054090196
std:
- 1.0
- 1.0
- 1.0
order: hwc
- ToCHWImage: null
- KeepKeys:
keep_keys:
- image
- threshold_map
- threshold_mask
- shrink_map
- shrink_mask
loader:
shuffle: true
drop_last: false
batch_size_per_card: 4
num_workers: 8
Eval:
dataset:
name: SimpleDataSet
data_dir: ./train_data/icdar2015/text_localization
label_file_list:
- ./train_data/icdar2015/text_localization/test_icdar2015_label.txt
transforms:
- DecodeImage:
img_mode: BGR
channel_first: false
- DetLabelEncode: null
- DetResizeForTest:
image_shape:
- 1152
- 2048
- NormalizeImage:
scale: 1./255.
mean:
- 0.48109378172549
- 0.45752457890196
- 0.40787054090196
std:
- 1.0
- 1.0
- 1.0
order: hwc
- ToCHWImage: null
- KeepKeys:
keep_keys:
- image
- shape
- polys
- ignore_tags
loader:
shuffle: false
drop_last: false
batch_size_per_card: 1
num_workers: 2
profiler_options: null
Global:
debug: false
use_gpu: true
epoch_num: 1000
log_smooth_window: 20
print_batch_step: 10
save_model_dir: ./output/det_r50_td_tr/
save_epoch_step: 200
eval_batch_step:
- 0
- 2000
cal_metric_during_train: false
pretrained_model: ./pretrain_models/ResNet50_dcn_asf_synthtext_pretrained
checkpoints: null
save_inference_dir: null
use_visualdl: false
infer_img: doc/imgs_en/img_10.jpg
save_res_path: ./checkpoints/det_db/predicts_db.txt
Architecture:
model_type: det
algorithm: DB++
Transform: null
Backbone:
name: ResNet
layers: 50
dcn_stage: [False, True, True, True]
Neck:
name: DBFPN
out_channels: 256
use_asf: True
Head:
name: DBHead
k: 50
Loss:
name: DBLoss
balance_loss: true
main_loss_type: BCELoss
alpha: 5
beta: 10
ohem_ratio: 3
Optimizer:
name: Momentum
momentum: 0.9
lr:
name: DecayLearningRate
learning_rate: 0.007
epochs: 1000
factor: 0.9
end_lr: 0
weight_decay: 0.0001
PostProcess:
name: DBPostProcess
thresh: 0.3
box_thresh: 0.5
max_candidates: 1000
unclip_ratio: 1.5
Metric:
name: DetMetric
main_indicator: hmean
Train:
dataset:
name: SimpleDataSet
data_dir: ./train_data/
label_file_list:
- ./train_data/TD_TR/TD500/train_gt_labels.txt
- ./train_data/TD_TR/TR400/gt_labels.txt
ratio_list:
- 1.0
- 1.0
transforms:
- DecodeImage:
img_mode: BGR
channel_first: false
- DetLabelEncode: null
- IaaAugment:
augmenter_args:
- type: Fliplr
args:
p: 0.5
- type: Affine
args:
rotate:
- -10
- 10
- type: Resize
args:
size:
- 0.5
- 3
- EastRandomCropData:
size:
- 640
- 640
max_tries: 10
keep_ratio: true
- MakeShrinkMap:
shrink_ratio: 0.4
min_text_size: 8
- MakeBorderMap:
shrink_ratio: 0.4
thresh_min: 0.3
thresh_max: 0.7
- NormalizeImage:
scale: 1./255.
mean:
- 0.48109378172549
- 0.45752457890196
- 0.40787054090196
std:
- 1.0
- 1.0
- 1.0
order: hwc
- ToCHWImage: null
- KeepKeys:
keep_keys:
- image
- threshold_map
- threshold_mask
- shrink_map
- shrink_mask
loader:
shuffle: true
drop_last: false
batch_size_per_card: 4
num_workers: 8
Eval:
dataset:
name: SimpleDataSet
data_dir: ./train_data/
label_file_list:
- ./train_data/TD_TR/TD500/test_gt_labels.txt
transforms:
- DecodeImage:
img_mode: BGR
channel_first: false
- DetLabelEncode: null
- DetResizeForTest:
image_shape:
- 736
- 736
keep_ratio: True
- NormalizeImage:
scale: 1./255.
mean:
- 0.48109378172549
- 0.45752457890196
- 0.40787054090196
std:
- 1.0
- 1.0
- 1.0
order: hwc
- ToCHWImage: null
- KeepKeys:
keep_keys:
- image
- shape
- polys
- ignore_tags
loader:
shuffle: false
drop_last: false
batch_size_per_card: 1
num_workers: 2
profiler_options: null
...@@ -20,7 +20,7 @@ Architecture: ...@@ -20,7 +20,7 @@ Architecture:
algorithm: DB algorithm: DB
Transform: Transform:
Backbone: Backbone:
name: ResNet name: ResNet_vd
layers: 50 layers: 50
Neck: Neck:
name: DBFPN name: DBFPN
......
...@@ -21,7 +21,7 @@ Architecture: ...@@ -21,7 +21,7 @@ Architecture:
algorithm: FCE algorithm: FCE
Transform: Transform:
Backbone: Backbone:
name: ResNet name: ResNet_vd
layers: 50 layers: 50
dcn_stage: [False, True, True, True] dcn_stage: [False, True, True, True]
out_indices: [1,2,3] out_indices: [1,2,3]
......
...@@ -20,7 +20,7 @@ Architecture: ...@@ -20,7 +20,7 @@ Architecture:
algorithm: EAST algorithm: EAST
Transform: Transform:
Backbone: Backbone:
name: ResNet name: ResNet_vd
layers: 50 layers: 50
Neck: Neck:
name: EASTFPN name: EASTFPN
......
...@@ -20,7 +20,7 @@ Architecture: ...@@ -20,7 +20,7 @@ Architecture:
algorithm: PSE algorithm: PSE
Transform: Transform:
Backbone: Backbone:
name: ResNet name: ResNet_vd
layers: 50 layers: 50
Neck: Neck:
name: FPN name: FPN
......
...@@ -20,7 +20,7 @@ Architecture: ...@@ -20,7 +20,7 @@ Architecture:
algorithm: DB algorithm: DB
Transform: Transform:
Backbone: Backbone:
name: ResNet name: ResNet_vd
layers: 18 layers: 18
disable_se: True disable_se: True
Neck: Neck:
......
...@@ -17,7 +17,7 @@ Global: ...@@ -17,7 +17,7 @@ Global:
checkpoints: checkpoints:
save_inference_dir: save_inference_dir:
use_visualdl: False use_visualdl: False
class_path: ./train_data/wildreceipt/class_list.txt class_path: &class_path ./train_data/wildreceipt/class_list.txt
infer_img: ./train_data/wildreceipt/1.txt infer_img: ./train_data/wildreceipt/1.txt
save_res_path: ./output/sdmgr_kie/predicts_kie.txt save_res_path: ./output/sdmgr_kie/predicts_kie.txt
img_scale: [ 1024, 512 ] img_scale: [ 1024, 512 ]
...@@ -72,6 +72,7 @@ Train: ...@@ -72,6 +72,7 @@ Train:
order: 'hwc' order: 'hwc'
- KieLabelEncode: # Class handling label - KieLabelEncode: # Class handling label
character_dict_path: ./train_data/wildreceipt/dict.txt character_dict_path: ./train_data/wildreceipt/dict.txt
class_path: *class_path
- KieResize: - KieResize:
- ToCHWImage: - ToCHWImage:
- KeepKeys: - KeepKeys:
...@@ -88,7 +89,6 @@ Eval: ...@@ -88,7 +89,6 @@ Eval:
data_dir: ./train_data/wildreceipt data_dir: ./train_data/wildreceipt
label_file_list: label_file_list:
- ./train_data/wildreceipt/wildreceipt_test.txt - ./train_data/wildreceipt/wildreceipt_test.txt
# - /paddle/data/PaddleOCR/train_data/wildreceipt/1.txt
transforms: transforms:
- DecodeImage: # load image - DecodeImage: # load image
img_mode: RGB img_mode: RGB
......
...@@ -82,7 +82,7 @@ Train: ...@@ -82,7 +82,7 @@ Train:
Eval: Eval:
dataset: dataset:
name: LMDBDataSet name: LMDBDataSet
data_dir: ./train_data/data_lmdb_release/validation/ data_dir: ./train_data/data_lmdb_release/evaluaiton/
transforms: transforms:
- DecodeImage: # load image - DecodeImage: # load image
img_mode: BGR img_mode: BGR
......
Global:
use_gpu: True
epoch_num: 6
log_smooth_window: 50
print_batch_step: 50
save_model_dir: ./output/rec/rec_r32_gaspin_bilstm_att/
save_epoch_step: 3
# evaluation is run every 2000 iterations after the 0th iteration
eval_batch_step: [0, 2000]
cal_metric_during_train: True
pretrained_model:
checkpoints:
save_inference_dir:
use_visualdl: False
infer_img: doc/imgs_words/ch/word_1.jpg
# for data or label process
character_dict_path: ./ppocr/utils/dict/spin_dict.txt
max_text_length: 25
infer_mode: False
use_space_char: False
save_res_path: ./output/rec/predicts_r32_gaspin_bilstm_att.txt
Optimizer:
name: AdamW
beta1: 0.9
beta2: 0.999
lr:
name: Piecewise
decay_epochs: [3, 4, 5]
values: [0.001, 0.0003, 0.00009, 0.000027]
clip_norm: 5
Architecture:
model_type: rec
algorithm: SPIN
in_channels: 1
Transform:
name: GA_SPIN
offsets: True
default_type: 6
loc_lr: 0.1
stn: True
Backbone:
name: ResNet32
out_channels: 512
Neck:
name: SequenceEncoder
encoder_type: cascadernn
hidden_size: 256
out_channels: [256, 512]
with_linear: True
Head:
name: SPINAttentionHead
hidden_size: 256
Loss:
name: SPINAttentionLoss
ignore_index: 0
PostProcess:
name: SPINLabelDecode
use_space_char: False
Metric:
name: RecMetric
main_indicator: acc
is_filter: True
Train:
dataset:
name: SimpleDataSet
data_dir: ./train_data/ic15_data/
label_file_list: ["./train_data/ic15_data/rec_gt_train.txt"]
transforms:
- DecodeImage: # load image
img_mode: BGR
channel_first: False
- SPINLabelEncode: # Class handling label
- SPINRecResizeImg:
image_shape: [100, 32]
interpolation : 2
mean: [127.5]
std: [127.5]
- KeepKeys:
keep_keys: ['image', 'label', 'length'] # dataloader will return list in this order
loader:
shuffle: True
batch_size_per_card: 8
drop_last: True
num_workers: 4
Eval:
dataset:
name: SimpleDataSet
data_dir: ./train_data/ic15_data
label_file_list: ["./train_data/ic15_data/rec_gt_test.txt"]
transforms:
- DecodeImage: # load image
img_mode: BGR
channel_first: False
- SPINLabelEncode: # Class handling label
- SPINRecResizeImg:
image_shape: [100, 32]
interpolation : 2
mean: [127.5]
std: [127.5]
- KeepKeys:
keep_keys: ['image', 'label', 'length'] # dataloader will return list in this order
loader:
shuffle: False
drop_last: False
batch_size_per_card: 8
num_workers: 2
...@@ -8,7 +8,7 @@ Global: ...@@ -8,7 +8,7 @@ Global:
# evaluation is run every 2000 iterations # evaluation is run every 2000 iterations
eval_batch_step: [0, 2000] eval_batch_step: [0, 2000]
cal_metric_during_train: True cal_metric_during_train: True
pretrained_model: pretrained_model: ./pretrain_models/abinet_vl_pretrained
checkpoints: checkpoints:
save_inference_dir: save_inference_dir:
use_visualdl: False use_visualdl: False
...@@ -82,7 +82,7 @@ Train: ...@@ -82,7 +82,7 @@ Train:
Eval: Eval:
dataset: dataset:
name: LMDBDataSet name: LMDBDataSet
data_dir: ./train_data/data_lmdb_release/validation/ data_dir: ./train_data/data_lmdb_release/evaluation/
transforms: transforms:
- DecodeImage: # load image - DecodeImage: # load image
img_mode: RGB img_mode: RGB
......
...@@ -77,7 +77,7 @@ Metric: ...@@ -77,7 +77,7 @@ Metric:
Train: Train:
dataset: dataset:
name: LMDBDataSet name: LMDBDataSet
data_dir: ./train_data/data_lmdb_release/training data_dir: ./train_data/data_lmdb_release/training/
transforms: transforms:
- DecodeImage: # load image - DecodeImage: # load image
img_mode: BGR img_mode: BGR
...@@ -97,7 +97,7 @@ Train: ...@@ -97,7 +97,7 @@ Train:
Eval: Eval:
dataset: dataset:
name: LMDBDataSet name: LMDBDataSet
data_dir: ./train_data/data_lmdb_release/validation data_dir: ./train_data/data_lmdb_release/evaluation/
transforms: transforms:
- DecodeImage: # load image - DecodeImage: # load image
img_mode: BGR img_mode: BGR
......
...@@ -81,7 +81,7 @@ Train: ...@@ -81,7 +81,7 @@ Train:
Eval: Eval:
dataset: dataset:
name: LMDBDataSet name: LMDBDataSet
data_dir: ./train_data/data_lmdb_release/validation/ data_dir: ./train_data/data_lmdb_release/evaluation/
transforms: transforms:
- DecodeImage: # load image - DecodeImage: # load image
img_mode: BGR img_mode: BGR
......
Global:
use_gpu: true
epoch_num: 17
log_smooth_window: 20
print_batch_step: 100
save_model_dir: ./output/table_master/
save_epoch_step: 17
eval_batch_step: [0, 6259]
cal_metric_during_train: true
pretrained_model: null
checkpoints:
save_inference_dir: output/table_master/infer
use_visualdl: false
infer_img: ppstructure/docs/table/table.jpg
save_res_path: ./output/table_master
character_dict_path: ppocr/utils/dict/table_master_structure_dict.txt
infer_mode: false
max_text_length: 500
process_total_num: 0
process_cut_num: 0
Optimizer:
name: Adam
beta1: 0.9
beta2: 0.999
lr:
name: MultiStepDecay
learning_rate: 0.001
milestones: [12, 15]
gamma: 0.1
warmup_epoch: 0.02
regularizer:
name: L2
factor: 0.0
Architecture:
model_type: table
algorithm: TableMaster
Backbone:
name: TableResNetExtra
gcb_config:
ratio: 0.0625
headers: 1
att_scale: False
fusion_type: channel_add
layers: [False, True, True, True]
layers: [1,2,5,3]
Head:
name: TableMasterHead
hidden_size: 512
headers: 8
dropout: 0
d_ff: 2024
max_text_length: 500
Loss:
name: TableMasterLoss
ignore_index: 42 # set to len of dict + 3
PostProcess:
name: TableMasterLabelDecode
box_shape: pad
Metric:
name: TableMetric
main_indicator: acc
compute_bbox_metric: False
Train:
dataset:
name: PubTabDataSet
data_dir: train_data/table/pubtabnet/train/
label_file_list: [train_data/table/pubtabnet/PubTabNet_2.0.0_train.jsonl]
transforms:
- DecodeImage:
img_mode: BGR
channel_first: False
- TableMasterLabelEncode:
learn_empty_box: False
merge_no_span_structure: True
replace_empty_cell_token: True
- ResizeTableImage:
max_len: 480
resize_bboxes: True
- PaddingTableImage:
size: [480, 480]
- TableBoxEncode:
use_xywh: True
- NormalizeImage:
scale: 1./255.
mean: [0.5, 0.5, 0.5]
std: [0.5, 0.5, 0.5]
order: hwc
- ToCHWImage: null
- KeepKeys:
keep_keys: [image, structure, bboxes, bbox_masks, shape]
loader:
shuffle: True
batch_size_per_card: 10
drop_last: True
num_workers: 8
Eval:
dataset:
name: PubTabDataSet
data_dir: train_data/table/pubtabnet/val/
label_file_list: [train_data/table/pubtabnet/PubTabNet_2.0.0_val.jsonl]
transforms:
- DecodeImage:
img_mode: BGR
channel_first: False
- TableMasterLabelEncode:
learn_empty_box: False
merge_no_span_structure: True
replace_empty_cell_token: True
- ResizeTableImage:
max_len: 480
resize_bboxes: True
- PaddingTableImage:
size: [480, 480]
- TableBoxEncode:
use_xywh: True
- NormalizeImage:
scale: 1./255.
mean: [0.5, 0.5, 0.5]
std: [0.5, 0.5, 0.5]
order: hwc
- ToCHWImage: null
- KeepKeys:
keep_keys: [image, structure, bboxes, bbox_masks, shape]
loader:
shuffle: False
drop_last: False
batch_size_per_card: 10
num_workers: 8
...@@ -4,21 +4,20 @@ Global: ...@@ -4,21 +4,20 @@ Global:
log_smooth_window: 20 log_smooth_window: 20
print_batch_step: 5 print_batch_step: 5
save_model_dir: ./output/table_mv3/ save_model_dir: ./output/table_mv3/
save_epoch_step: 3 save_epoch_step: 400
# evaluation is run every 400 iterations after the 0th iteration # evaluation is run every 400 iterations after the 0th iteration
eval_batch_step: [0, 400] eval_batch_step: [0, 400]
cal_metric_during_train: True cal_metric_during_train: True
pretrained_model: pretrained_model:
checkpoints: checkpoints:
save_inference_dir: save_inference_dir:
use_visualdl: False use_visualdl: False
infer_img: doc/table/table.jpg infer_img: ppstructure/docs/table/table.jpg
save_res_path: output/table_mv3
# for data or label process # for data or label process
character_dict_path: ppocr/utils/dict/table_structure_dict.txt character_dict_path: ppocr/utils/dict/table_structure_dict.txt
character_type: en character_type: en
max_text_length: 100 max_text_length: 800
max_elem_length: 800
max_cell_num: 500
infer_mode: False infer_mode: False
process_total_num: 0 process_total_num: 0
process_cut_num: 0 process_cut_num: 0
...@@ -44,11 +43,8 @@ Architecture: ...@@ -44,11 +43,8 @@ Architecture:
Head: Head:
name: TableAttentionHead name: TableAttentionHead
hidden_size: 256 hidden_size: 256
l2_decay: 0.00001
loc_type: 2 loc_type: 2
max_text_length: 100 max_text_length: 800
max_elem_length: 800
max_cell_num: 500
Loss: Loss:
name: TableAttentionLoss name: TableAttentionLoss
...@@ -61,28 +57,34 @@ PostProcess: ...@@ -61,28 +57,34 @@ PostProcess:
Metric: Metric:
name: TableMetric name: TableMetric
main_indicator: acc main_indicator: acc
compute_bbox_metric: false # cost many time, set False for training
Train: Train:
dataset: dataset:
name: PubTabDataSet name: PubTabDataSet
data_dir: train_data/table/pubtabnet/train/ data_dir: train_data/table/pubtabnet/train/
label_file_path: train_data/table/pubtabnet/PubTabNet_2.0.0_train.jsonl label_file_list: [train_data/table/pubtabnet/PubTabNet_2.0.0_train.jsonl]
transforms: transforms:
- DecodeImage: # load image - DecodeImage: # load image
img_mode: BGR img_mode: BGR
channel_first: False channel_first: False
- TableLabelEncode:
learn_empty_box: False
merge_no_span_structure: False
replace_empty_cell_token: False
- TableBoxEncode:
- ResizeTableImage: - ResizeTableImage:
max_len: 488 max_len: 488
- TableLabelEncode:
- NormalizeImage: - NormalizeImage:
scale: 1./255. scale: 1./255.
mean: [0.485, 0.456, 0.406] mean: [0.485, 0.456, 0.406]
std: [0.229, 0.224, 0.225] std: [0.229, 0.224, 0.225]
order: 'hwc' order: 'hwc'
- PaddingTableImage: - PaddingTableImage:
size: [488, 488]
- ToCHWImage: - ToCHWImage:
- KeepKeys: - KeepKeys:
keep_keys: ['image', 'structure', 'bbox_list', 'sp_tokens', 'bbox_list_mask'] keep_keys: [ 'image', 'structure', 'bboxes', 'bbox_masks', 'shape' ]
loader: loader:
shuffle: True shuffle: True
batch_size_per_card: 32 batch_size_per_card: 32
...@@ -92,24 +94,29 @@ Train: ...@@ -92,24 +94,29 @@ Train:
Eval: Eval:
dataset: dataset:
name: PubTabDataSet name: PubTabDataSet
data_dir: train_data/table/pubtabnet/val/ data_dir: /home/zhoujun20/table/PubTabNe/pubtabnet/val/
label_file_path: train_data/table/pubtabnet/PubTabNet_2.0.0_val.jsonl label_file_list: [/home/zhoujun20/table/PubTabNe/pubtabnet/val_500.jsonl]
transforms: transforms:
- DecodeImage: # load image - DecodeImage: # load image
img_mode: BGR img_mode: BGR
channel_first: False channel_first: False
- TableLabelEncode:
learn_empty_box: False
merge_no_span_structure: False
replace_empty_cell_token: False
- TableBoxEncode:
- ResizeTableImage: - ResizeTableImage:
max_len: 488 max_len: 488
- TableLabelEncode:
- NormalizeImage: - NormalizeImage:
scale: 1./255. scale: 1./255.
mean: [0.485, 0.456, 0.406] mean: [0.485, 0.456, 0.406]
std: [0.229, 0.224, 0.225] std: [0.229, 0.224, 0.225]
order: 'hwc' order: 'hwc'
- PaddingTableImage: - PaddingTableImage:
size: [488, 488]
- ToCHWImage: - ToCHWImage:
- KeepKeys: - KeepKeys:
keep_keys: ['image', 'structure', 'bbox_list', 'sp_tokens', 'bbox_list_mask'] keep_keys: [ 'image', 'structure', 'bboxes', 'bbox_masks', 'shape' ]
loader: loader:
shuffle: False shuffle: False
drop_last: False drop_last: False
......
Global:
use_gpu: True
epoch_num: &epoch_num 200
log_smooth_window: 10
print_batch_step: 10
save_model_dir: ./output/re_layoutlmv2_funsd
save_epoch_step: 2000
# evaluation is run every 57 iterations after the 0th iteration
eval_batch_step: [ 0, 57 ]
cal_metric_during_train: False
save_inference_dir:
use_visualdl: False
seed: 2022
infer_img: train_data/FUNSD/testing_data/images/83624198.png
save_res_path: ./output/re_layoutlmv2_funsd/res/
Architecture:
model_type: vqa
algorithm: &algorithm "LayoutLMv2"
Transform:
Backbone:
name: LayoutLMv2ForRe
pretrained: True
checkpoints:
Loss:
name: LossFromOutput
key: loss
reduction: mean
Optimizer:
name: AdamW
beta1: 0.9
beta2: 0.999
clip_norm: 10
lr:
learning_rate: 0.00005
warmup_epoch: 10
regularizer:
name: L2
factor: 0.00000
PostProcess:
name: VQAReTokenLayoutLMPostProcess
Metric:
name: VQAReTokenMetric
main_indicator: hmean
Train:
dataset:
name: SimpleDataSet
data_dir: ./train_data/FUNSD/training_data/images/
label_file_list:
- ./train_data/FUNSD/train.json
ratio_list: [ 1.0 ]
transforms:
- DecodeImage: # load image
img_mode: RGB
channel_first: False
- VQATokenLabelEncode: # Class handling label
contains_re: True
algorithm: *algorithm
class_path: &class_path train_data/FUNSD/class_list.txt
- VQATokenPad:
max_seq_len: &max_seq_len 512
return_attention_mask: True
- VQAReTokenRelation:
- VQAReTokenChunk:
max_seq_len: *max_seq_len
- Resize:
size: [224,224]
- NormalizeImage:
scale: 1./255.
mean: [0.485, 0.456, 0.406]
std: [0.229, 0.224, 0.225]
order: 'hwc'
- ToCHWImage:
- KeepKeys:
# dataloader will return list in this order
keep_keys: [ 'input_ids', 'bbox', 'attention_mask', 'token_type_ids', 'image', 'entities', 'relations']
loader:
shuffle: True
drop_last: False
batch_size_per_card: 8
num_workers: 8
collate_fn: ListCollator
Eval:
dataset:
name: SimpleDataSet
data_dir: ./train_data/FUNSD/testing_data/images/
label_file_list:
- ./train_data/FUNSD/test.json
transforms:
- DecodeImage: # load image
img_mode: RGB
channel_first: False
- VQATokenLabelEncode: # Class handling label
contains_re: True
algorithm: *algorithm
class_path: *class_path
- VQATokenPad:
max_seq_len: *max_seq_len
return_attention_mask: True
- VQAReTokenRelation:
- VQAReTokenChunk:
max_seq_len: *max_seq_len
- Resize:
size: [224,224]
- NormalizeImage:
scale: 1./255.
mean: [0.485, 0.456, 0.406]
std: [0.229, 0.224, 0.225]
order: 'hwc'
- ToCHWImage:
- KeepKeys:
# dataloader will return list in this order
keep_keys: [ 'input_ids', 'bbox', 'attention_mask', 'token_type_ids', 'image', 'entities', 'relations']
loader:
shuffle: False
drop_last: False
batch_size_per_card: 8
num_workers: 8
collate_fn: ListCollator
...@@ -3,16 +3,16 @@ Global: ...@@ -3,16 +3,16 @@ Global:
epoch_num: &epoch_num 200 epoch_num: &epoch_num 200
log_smooth_window: 10 log_smooth_window: 10
print_batch_step: 10 print_batch_step: 10
save_model_dir: ./output/re_layoutlmv2/ save_model_dir: ./output/re_layoutlmv2_xfund_zh
save_epoch_step: 2000 save_epoch_step: 2000
# evaluation is run every 10 iterations after the 0th iteration # evaluation is run every 10 iterations after the 0th iteration
eval_batch_step: [ 0, 19 ] eval_batch_step: [ 0, 57 ]
cal_metric_during_train: False cal_metric_during_train: False
save_inference_dir: save_inference_dir:
use_visualdl: False use_visualdl: False
seed: 2048 seed: 2048
infer_img: doc/vqa/input/zh_val_21.jpg infer_img: ppstructure/docs/vqa/input/zh_val_21.jpg
save_res_path: ./output/re/ save_res_path: ./output/re_layoutlmv2_xfund_zh/res/
Architecture: Architecture:
model_type: vqa model_type: vqa
...@@ -21,7 +21,7 @@ Architecture: ...@@ -21,7 +21,7 @@ Architecture:
Backbone: Backbone:
name: LayoutLMv2ForRe name: LayoutLMv2ForRe
pretrained: True pretrained: True
checkpoints: checkpoints:
Loss: Loss:
name: LossFromOutput name: LossFromOutput
...@@ -52,7 +52,7 @@ Train: ...@@ -52,7 +52,7 @@ Train:
name: SimpleDataSet name: SimpleDataSet
data_dir: train_data/XFUND/zh_train/image data_dir: train_data/XFUND/zh_train/image
label_file_list: label_file_list:
- train_data/XFUND/zh_train/xfun_normalize_train.json - train_data/XFUND/zh_train/train.json
ratio_list: [ 1.0 ] ratio_list: [ 1.0 ]
transforms: transforms:
- DecodeImage: # load image - DecodeImage: # load image
...@@ -61,7 +61,7 @@ Train: ...@@ -61,7 +61,7 @@ Train:
- VQATokenLabelEncode: # Class handling label - VQATokenLabelEncode: # Class handling label
contains_re: True contains_re: True
algorithm: *algorithm algorithm: *algorithm
class_path: &class_path ppstructure/vqa/labels/labels_ser.txt class_path: &class_path train_data/XFUND/class_list_xfun.txt
- VQATokenPad: - VQATokenPad:
max_seq_len: &max_seq_len 512 max_seq_len: &max_seq_len 512
return_attention_mask: True return_attention_mask: True
...@@ -77,7 +77,7 @@ Train: ...@@ -77,7 +77,7 @@ Train:
order: 'hwc' order: 'hwc'
- ToCHWImage: - ToCHWImage:
- KeepKeys: - KeepKeys:
keep_keys: [ 'input_ids', 'bbox', 'image', 'attention_mask', 'token_type_ids','entities', 'relations'] # dataloader will return list in this order keep_keys: [ 'input_ids', 'bbox', 'attention_mask', 'token_type_ids','image', 'entities', 'relations'] # dataloader will return list in this order
loader: loader:
shuffle: True shuffle: True
drop_last: False drop_last: False
...@@ -90,7 +90,7 @@ Eval: ...@@ -90,7 +90,7 @@ Eval:
name: SimpleDataSet name: SimpleDataSet
data_dir: train_data/XFUND/zh_val/image data_dir: train_data/XFUND/zh_val/image
label_file_list: label_file_list:
- train_data/XFUND/zh_val/xfun_normalize_val.json - train_data/XFUND/zh_val/val.json
transforms: transforms:
- DecodeImage: # load image - DecodeImage: # load image
img_mode: RGB img_mode: RGB
...@@ -114,7 +114,7 @@ Eval: ...@@ -114,7 +114,7 @@ Eval:
order: 'hwc' order: 'hwc'
- ToCHWImage: - ToCHWImage:
- KeepKeys: - KeepKeys:
keep_keys: [ 'input_ids', 'bbox', 'image', 'attention_mask', 'token_type_ids','entities', 'relations'] # dataloader will return list in this order keep_keys: [ 'input_ids', 'bbox', 'attention_mask', 'token_type_ids', 'image','entities', 'relations'] # dataloader will return list in this order
loader: loader:
shuffle: False shuffle: False
drop_last: False drop_last: False
......
Global:
use_gpu: True
epoch_num: &epoch_num 200
log_smooth_window: 10
print_batch_step: 10
save_model_dir: ./output/re_layoutxlm_funsd
save_epoch_step: 2000
# evaluation is run every 57 iterations after the 0th iteration
eval_batch_step: [ 0, 57 ]
cal_metric_during_train: False
save_inference_dir:
use_visualdl: False
seed: 2022
infer_img: train_data/FUNSD/testing_data/images/83624198.png
save_res_path: ./output/re_layoutxlm_funsd/res/
Architecture:
model_type: vqa
algorithm: &algorithm "LayoutXLM"
Transform:
Backbone:
name: LayoutXLMForRe
pretrained: True
checkpoints:
Loss:
name: LossFromOutput
key: loss
reduction: mean
Optimizer:
name: AdamW
beta1: 0.9
beta2: 0.999
clip_norm: 10
lr:
learning_rate: 0.00005
warmup_epoch: 10
regularizer:
name: L2
factor: 0.00000
PostProcess:
name: VQAReTokenLayoutLMPostProcess
Metric:
name: VQAReTokenMetric
main_indicator: hmean
Train:
dataset:
name: SimpleDataSet
data_dir: ./train_data/FUNSD/training_data/images/
label_file_list:
- ./train_data/FUNSD/train_v4.json
# - ./train_data/FUNSD/train.json
ratio_list: [ 1.0 ]
transforms:
- DecodeImage: # load image
img_mode: RGB
channel_first: False
- VQATokenLabelEncode: # Class handling label
contains_re: True
algorithm: *algorithm
class_path: &class_path ./train_data/FUNSD/class_list.txt
use_textline_bbox_info: &use_textline_bbox_info True
- VQATokenPad:
max_seq_len: &max_seq_len 512
return_attention_mask: True
- VQAReTokenRelation:
- VQAReTokenChunk:
max_seq_len: *max_seq_len
- Resize:
size: [224,224]
- NormalizeImage:
scale: 1
mean: [ 123.675, 116.28, 103.53 ]
std: [ 58.395, 57.12, 57.375 ]
order: 'hwc'
- ToCHWImage:
- KeepKeys:
# dataloader will return list in this order
keep_keys: [ 'input_ids', 'bbox', 'attention_mask', 'token_type_ids', 'image', 'entities', 'relations']
loader:
shuffle: False
drop_last: False
batch_size_per_card: 8
num_workers: 16
collate_fn: ListCollator
Eval:
dataset:
name: SimpleDataSet
data_dir: ./train_data/FUNSD/testing_data/images/
label_file_list:
- ./train_data/FUNSD/test_v4.json
# - ./train_data/FUNSD/test.json
transforms:
- DecodeImage: # load image
img_mode: RGB
channel_first: False
- VQATokenLabelEncode: # Class handling label
contains_re: True
algorithm: *algorithm
class_path: *class_path
use_textline_bbox_info: *use_textline_bbox_info
- VQATokenPad:
max_seq_len: *max_seq_len
return_attention_mask: True
- VQAReTokenRelation:
- VQAReTokenChunk:
max_seq_len: *max_seq_len
- Resize:
size: [224,224]
- NormalizeImage:
scale: 1
mean: [ 123.675, 116.28, 103.53 ]
std: [ 58.395, 57.12, 57.375 ]
order: 'hwc'
- ToCHWImage:
- KeepKeys:
# dataloader will return list in this order
keep_keys: [ 'input_ids', 'bbox', 'attention_mask', 'token_type_ids', 'image', 'entities', 'relations']
loader:
shuffle: False
drop_last: False
batch_size_per_card: 8
num_workers: 8
collate_fn: ListCollator
...@@ -11,7 +11,7 @@ Global: ...@@ -11,7 +11,7 @@ Global:
save_inference_dir: save_inference_dir:
use_visualdl: False use_visualdl: False
seed: 2022 seed: 2022
infer_img: doc/vqa/input/zh_val_21.jpg infer_img: ppstructure/docs/vqa/input/zh_val_21.jpg
save_res_path: ./output/re/ save_res_path: ./output/re/
Architecture: Architecture:
...@@ -52,7 +52,7 @@ Train: ...@@ -52,7 +52,7 @@ Train:
name: SimpleDataSet name: SimpleDataSet
data_dir: train_data/XFUND/zh_train/image data_dir: train_data/XFUND/zh_train/image
label_file_list: label_file_list:
- train_data/XFUND/zh_train/xfun_normalize_train.json - train_data/XFUND/zh_train/train.json
ratio_list: [ 1.0 ] ratio_list: [ 1.0 ]
transforms: transforms:
- DecodeImage: # load image - DecodeImage: # load image
...@@ -61,7 +61,7 @@ Train: ...@@ -61,7 +61,7 @@ Train:
- VQATokenLabelEncode: # Class handling label - VQATokenLabelEncode: # Class handling label
contains_re: True contains_re: True
algorithm: *algorithm algorithm: *algorithm
class_path: &class_path ppstructure/vqa/labels/labels_ser.txt class_path: &class_path train_data/XFUND/class_list_xfun.txt
- VQATokenPad: - VQATokenPad:
max_seq_len: &max_seq_len 512 max_seq_len: &max_seq_len 512
return_attention_mask: True return_attention_mask: True
...@@ -77,7 +77,7 @@ Train: ...@@ -77,7 +77,7 @@ Train:
order: 'hwc' order: 'hwc'
- ToCHWImage: - ToCHWImage:
- KeepKeys: - KeepKeys:
keep_keys: [ 'input_ids', 'bbox', 'image', 'attention_mask', 'token_type_ids','entities', 'relations'] # dataloader will return list in this order keep_keys: [ 'input_ids', 'bbox','attention_mask', 'token_type_ids', 'image', 'entities', 'relations'] # dataloader will return list in this order
loader: loader:
shuffle: True shuffle: True
drop_last: False drop_last: False
...@@ -90,7 +90,7 @@ Eval: ...@@ -90,7 +90,7 @@ Eval:
name: SimpleDataSet name: SimpleDataSet
data_dir: train_data/XFUND/zh_val/image data_dir: train_data/XFUND/zh_val/image
label_file_list: label_file_list:
- train_data/XFUND/zh_val/xfun_normalize_val.json - train_data/XFUND/zh_val/val.json
transforms: transforms:
- DecodeImage: # load image - DecodeImage: # load image
img_mode: RGB img_mode: RGB
...@@ -114,7 +114,7 @@ Eval: ...@@ -114,7 +114,7 @@ Eval:
order: 'hwc' order: 'hwc'
- ToCHWImage: - ToCHWImage:
- KeepKeys: - KeepKeys:
keep_keys: [ 'input_ids', 'bbox', 'image', 'attention_mask', 'token_type_ids','entities', 'relations'] # dataloader will return list in this order keep_keys: [ 'input_ids', 'bbox', 'attention_mask', 'token_type_ids', 'image', 'entities', 'relations'] # dataloader will return list in this order
loader: loader:
shuffle: False shuffle: False
drop_last: False drop_last: False
......
Global:
use_gpu: True
epoch_num: &epoch_num 200
log_smooth_window: 10
print_batch_step: 10
save_model_dir: ./output/ser_layoutlm_funsd
save_epoch_step: 2000
# evaluation is run every 57 iterations after the 0th iteration
eval_batch_step: [ 0, 57 ]
cal_metric_during_train: False
save_inference_dir:
use_visualdl: False
seed: 2022
infer_img: train_data/FUNSD/testing_data/images/83624198.png
save_res_path: ./output/ser_layoutlm_funsd/res/
Architecture:
model_type: vqa
algorithm: &algorithm "LayoutLM"
Transform:
Backbone:
name: LayoutLMForSer
pretrained: True
checkpoints:
num_classes: &num_classes 7
Loss:
name: VQASerTokenLayoutLMLoss
num_classes: *num_classes
Optimizer:
name: AdamW
beta1: 0.9
beta2: 0.999
lr:
name: Linear
learning_rate: 0.00005
epochs: *epoch_num
warmup_epoch: 2
regularizer:
name: L2
factor: 0.00000
PostProcess:
name: VQASerTokenLayoutLMPostProcess
class_path: &class_path ./train_data/FUNSD/class_list.txt
Metric:
name: VQASerTokenMetric
main_indicator: hmean
Train:
dataset:
name: SimpleDataSet
data_dir: ./train_data/FUNSD/training_data/images/
label_file_list:
- ./train_data/FUNSD/train.json
transforms:
- DecodeImage: # load image
img_mode: RGB
channel_first: False
- VQATokenLabelEncode: # Class handling label
contains_re: False
algorithm: *algorithm
class_path: *class_path
use_textline_bbox_info: &use_textline_bbox_info True
- VQATokenPad:
max_seq_len: &max_seq_len 512
return_attention_mask: True
- VQASerTokenChunk:
max_seq_len: *max_seq_len
- Resize:
size: [224,224]
- NormalizeImage:
scale: 1
mean: [ 123.675, 116.28, 103.53 ]
std: [ 58.395, 57.12, 57.375 ]
order: 'hwc'
- ToCHWImage:
- KeepKeys:
# dataloader will return list in this order
keep_keys: [ 'input_ids', 'bbox', 'attention_mask', 'token_type_ids', 'image', 'labels']
loader:
shuffle: True
drop_last: False
batch_size_per_card: 8
num_workers: 4
Eval:
dataset:
name: SimpleDataSet
data_dir: train_data/FUNSD/testing_data/images/
label_file_list:
- ./train_data/FUNSD/test.json
transforms:
- DecodeImage: # load image
img_mode: RGB
channel_first: False
- VQATokenLabelEncode: # Class handling label
contains_re: False
algorithm: *algorithm
class_path: *class_path
use_textline_bbox_info: *use_textline_bbox_info
- VQATokenPad:
max_seq_len: *max_seq_len
return_attention_mask: True
- VQASerTokenChunk:
max_seq_len: *max_seq_len
- Resize:
size: [224,224]
- NormalizeImage:
scale: 1
mean: [ 123.675, 116.28, 103.53 ]
std: [ 58.395, 57.12, 57.375 ]
order: 'hwc'
- ToCHWImage:
- KeepKeys:
# dataloader will return list in this order
keep_keys: [ 'input_ids', 'bbox', 'attention_mask', 'token_type_ids', 'image', 'labels']
loader:
shuffle: False
drop_last: False
batch_size_per_card: 8
num_workers: 4
Global:
use_gpu: True
epoch_num: &epoch_num 200
log_smooth_window: 10
print_batch_step: 10
save_model_dir: ./output/ser_layoutlm_sroie
save_epoch_step: 2000
# evaluation is run every 200 iterations after the 0th iteration
eval_batch_step: [ 0, 200 ]
cal_metric_during_train: False
save_inference_dir:
use_visualdl: False
seed: 2022
infer_img: train_data/SROIE/test/X00016469670.jpg
save_res_path: ./output/ser_layoutlm_sroie/res/
Architecture:
model_type: vqa
algorithm: &algorithm "LayoutLM"
Transform:
Backbone:
name: LayoutLMForSer
pretrained: True
checkpoints:
num_classes: &num_classes 9
Loss:
name: VQASerTokenLayoutLMLoss
num_classes: *num_classes
Optimizer:
name: AdamW
beta1: 0.9
beta2: 0.999
lr:
name: Linear
learning_rate: 0.00005
epochs: *epoch_num
warmup_epoch: 2
regularizer:
name: L2
factor: 0.00000
PostProcess:
name: VQASerTokenLayoutLMPostProcess
class_path: &class_path ./train_data/SROIE/class_list.txt
Metric:
name: VQASerTokenMetric
main_indicator: hmean
Train:
dataset:
name: SimpleDataSet
data_dir: ./train_data/SROIE/train
label_file_list:
- ./train_data/SROIE/train.txt
transforms:
- DecodeImage: # load image
img_mode: RGB
channel_first: False
- VQATokenLabelEncode: # Class handling label
contains_re: False
algorithm: *algorithm
class_path: *class_path
use_textline_bbox_info: &use_textline_bbox_info True
- VQATokenPad:
max_seq_len: &max_seq_len 512
return_attention_mask: True
- VQASerTokenChunk:
max_seq_len: *max_seq_len
- Resize:
size: [224,224]
- NormalizeImage:
scale: 1
mean: [ 123.675, 116.28, 103.53 ]
std: [ 58.395, 57.12, 57.375 ]
order: 'hwc'
- ToCHWImage:
- KeepKeys:
# dataloader will return list in this order
keep_keys: [ 'input_ids', 'bbox', 'attention_mask', 'token_type_ids', 'image', 'labels']
loader:
shuffle: True
drop_last: False
batch_size_per_card: 8
num_workers: 4
Eval:
dataset:
name: SimpleDataSet
data_dir: ./train_data/SROIE/test
label_file_list:
- ./train_data/SROIE/test.txt
transforms:
- DecodeImage: # load image
img_mode: RGB
channel_first: False
- VQATokenLabelEncode: # Class handling label
contains_re: False
algorithm: *algorithm
class_path: *class_path
use_textline_bbox_info: *use_textline_bbox_info
- VQATokenPad:
max_seq_len: *max_seq_len
return_attention_mask: True
- VQASerTokenChunk:
max_seq_len: *max_seq_len
- Resize:
size: [224,224]
- NormalizeImage:
scale: 1
mean: [ 123.675, 116.28, 103.53 ]
std: [ 58.395, 57.12, 57.375 ]
order: 'hwc'
- ToCHWImage:
- KeepKeys:
# dataloader will return list in this order
keep_keys: [ 'input_ids', 'bbox', 'attention_mask', 'token_type_ids', 'image', 'labels']
loader:
shuffle: False
drop_last: False
batch_size_per_card: 8
num_workers: 4
...@@ -3,16 +3,16 @@ Global: ...@@ -3,16 +3,16 @@ Global:
epoch_num: &epoch_num 200 epoch_num: &epoch_num 200
log_smooth_window: 10 log_smooth_window: 10
print_batch_step: 10 print_batch_step: 10
save_model_dir: ./output/ser_layoutlm/ save_model_dir: ./output/ser_layoutlm_xfund_zh
save_epoch_step: 2000 save_epoch_step: 2000
# evaluation is run every 10 iterations after the 0th iteration # evaluation is run every 10 iterations after the 0th iteration
eval_batch_step: [ 0, 19 ] eval_batch_step: [ 0, 57 ]
cal_metric_during_train: False cal_metric_during_train: False
save_inference_dir: save_inference_dir:
use_visualdl: False use_visualdl: False
seed: 2022 seed: 2022
infer_img: doc/vqa/input/zh_val_0.jpg infer_img: ppstructure/docs/vqa/input/zh_val_42.jpg
save_res_path: ./output/ser/ save_res_path: ./output/ser_layoutlm_xfund_zh/res/
Architecture: Architecture:
model_type: vqa model_type: vqa
...@@ -43,7 +43,7 @@ Optimizer: ...@@ -43,7 +43,7 @@ Optimizer:
PostProcess: PostProcess:
name: VQASerTokenLayoutLMPostProcess name: VQASerTokenLayoutLMPostProcess
class_path: &class_path ppstructure/vqa/labels/labels_ser.txt class_path: &class_path train_data/XFUND/class_list_xfun.txt
Metric: Metric:
name: VQASerTokenMetric name: VQASerTokenMetric
...@@ -54,7 +54,7 @@ Train: ...@@ -54,7 +54,7 @@ Train:
name: SimpleDataSet name: SimpleDataSet
data_dir: train_data/XFUND/zh_train/image data_dir: train_data/XFUND/zh_train/image
label_file_list: label_file_list:
- train_data/XFUND/zh_train/xfun_normalize_train.json - train_data/XFUND/zh_train/train.json
transforms: transforms:
- DecodeImage: # load image - DecodeImage: # load image
img_mode: RGB img_mode: RGB
...@@ -77,7 +77,7 @@ Train: ...@@ -77,7 +77,7 @@ Train:
order: 'hwc' order: 'hwc'
- ToCHWImage: - ToCHWImage:
- KeepKeys: - KeepKeys:
keep_keys: [ 'input_ids','labels', 'bbox', 'image', 'attention_mask', 'token_type_ids'] # dataloader will return list in this order keep_keys: [ 'input_ids', 'bbox', 'attention_mask', 'token_type_ids', 'image', 'labels'] # dataloader will return list in this order
loader: loader:
shuffle: True shuffle: True
drop_last: False drop_last: False
...@@ -89,7 +89,7 @@ Eval: ...@@ -89,7 +89,7 @@ Eval:
name: SimpleDataSet name: SimpleDataSet
data_dir: train_data/XFUND/zh_val/image data_dir: train_data/XFUND/zh_val/image
label_file_list: label_file_list:
- train_data/XFUND/zh_val/xfun_normalize_val.json - train_data/XFUND/zh_val/val.json
transforms: transforms:
- DecodeImage: # load image - DecodeImage: # load image
img_mode: RGB img_mode: RGB
...@@ -112,7 +112,7 @@ Eval: ...@@ -112,7 +112,7 @@ Eval:
order: 'hwc' order: 'hwc'
- ToCHWImage: - ToCHWImage:
- KeepKeys: - KeepKeys:
keep_keys: [ 'input_ids', 'labels', 'bbox', 'image', 'attention_mask', 'token_type_ids'] # dataloader will return list in this order keep_keys: [ 'input_ids', 'bbox', 'attention_mask', 'token_type_ids', 'image', 'labels'] # dataloader will return list in this order
loader: loader:
shuffle: False shuffle: False
drop_last: False drop_last: False
......
Global:
use_gpu: True
epoch_num: &epoch_num 200
log_smooth_window: 10
print_batch_step: 10
save_model_dir: ./output/ser_layoutlmv2_funsd
save_epoch_step: 2000
# evaluation is run every 100 iterations after the 0th iteration
eval_batch_step: [ 0, 100 ]
cal_metric_during_train: False
save_inference_dir:
use_visualdl: False
seed: 2022
infer_img: train_data/FUNSD/testing_data/images/83624198.png
save_res_path: ./output/ser_layoutlmv2_funsd/res/
Architecture:
model_type: vqa
algorithm: &algorithm "LayoutLMv2"
Transform:
Backbone:
name: LayoutLMv2ForSer
pretrained: True
checkpoints:
num_classes: &num_classes 7
Loss:
name: VQASerTokenLayoutLMLoss
num_classes: *num_classes
Optimizer:
name: AdamW
beta1: 0.9
beta2: 0.999
lr:
name: Linear
learning_rate: 0.00005
epochs: *epoch_num
warmup_epoch: 2
regularizer:
name: L2
factor: 0.00000
PostProcess:
name: VQASerTokenLayoutLMPostProcess
class_path: &class_path train_data/FUNSD/class_list.txt
Metric:
name: VQASerTokenMetric
main_indicator: hmean
Train:
dataset:
name: SimpleDataSet
data_dir: ./train_data/FUNSD/training_data/images/
label_file_list:
- ./train_data/FUNSD/train.json
transforms:
- DecodeImage: # load image
img_mode: RGB
channel_first: False
- VQATokenLabelEncode: # Class handling label
contains_re: False
algorithm: *algorithm
class_path: *class_path
- VQATokenPad:
max_seq_len: &max_seq_len 512
return_attention_mask: True
- VQASerTokenChunk:
max_seq_len: *max_seq_len
- Resize:
size: [224,224]
- NormalizeImage:
scale: 1
mean: [ 123.675, 116.28, 103.53 ]
std: [ 58.395, 57.12, 57.375 ]
order: 'hwc'
- ToCHWImage:
- KeepKeys:
# dataloader will return list in this order
keep_keys: [ 'input_ids', 'bbox', 'attention_mask', 'token_type_ids', 'image', 'labels']
loader:
shuffle: True
drop_last: False
batch_size_per_card: 8
num_workers: 4
Eval:
dataset:
name: SimpleDataSet
data_dir: ./train_data/FUNSD/testing_data/images/
label_file_list:
- ./train_data/FUNSD/test.json
transforms:
- DecodeImage: # load image
img_mode: RGB
channel_first: False
- VQATokenLabelEncode: # Class handling label
contains_re: False
algorithm: *algorithm
class_path: *class_path
- VQATokenPad:
max_seq_len: *max_seq_len
return_attention_mask: True
- VQASerTokenChunk:
max_seq_len: *max_seq_len
- Resize:
size: [224,224]
- NormalizeImage:
scale: 1
mean: [ 123.675, 116.28, 103.53 ]
std: [ 58.395, 57.12, 57.375 ]
order: 'hwc'
- ToCHWImage:
- KeepKeys:
# dataloader will return list in this order
keep_keys: [ 'input_ids', 'bbox', 'attention_mask', 'token_type_ids', 'image', 'labels']
loader:
shuffle: False
drop_last: False
batch_size_per_card: 8
num_workers: 4
Global:
use_gpu: True
epoch_num: &epoch_num 200
log_smooth_window: 10
print_batch_step: 10
save_model_dir: ./output/ser_layoutlmv2_sroie
save_epoch_step: 2000
# evaluation is run every 200 iterations after the 0th iteration
eval_batch_step: [ 0, 200 ]
cal_metric_during_train: False
save_inference_dir:
use_visualdl: False
seed: 2022
infer_img: train_data/SROIE/test/X00016469670.jpg
save_res_path: ./output/ser_layoutlmv2_sroie/res/
Architecture:
model_type: vqa
algorithm: &algorithm "LayoutLMv2"
Transform:
Backbone:
name: LayoutLMv2ForSer
pretrained: True
checkpoints:
num_classes: &num_classes 9
Loss:
name: VQASerTokenLayoutLMLoss
num_classes: *num_classes
Optimizer:
name: AdamW
beta1: 0.9
beta2: 0.999
lr:
name: Linear
learning_rate: 0.00005
epochs: *epoch_num
warmup_epoch: 2
regularizer:
name: L2
factor: 0.00000
PostProcess:
name: VQASerTokenLayoutLMPostProcess
class_path: &class_path ./train_data/SROIE/class_list.txt
Metric:
name: VQASerTokenMetric
main_indicator: hmean
Train:
dataset:
name: SimpleDataSet
data_dir: ./train_data/SROIE/train
label_file_list:
- ./train_data/SROIE/train.txt
transforms:
- DecodeImage: # load image
img_mode: RGB
channel_first: False
- VQATokenLabelEncode: # Class handling label
contains_re: False
algorithm: *algorithm
class_path: *class_path
- VQATokenPad:
max_seq_len: &max_seq_len 512
return_attention_mask: True
- VQASerTokenChunk:
max_seq_len: *max_seq_len
- Resize:
size: [224,224]
- NormalizeImage:
scale: 1
mean: [ 123.675, 116.28, 103.53 ]
std: [ 58.395, 57.12, 57.375 ]
order: 'hwc'
- ToCHWImage:
- KeepKeys:
# dataloader will return list in this order
keep_keys: [ 'input_ids', 'bbox', 'attention_mask', 'token_type_ids', 'image', 'labels']
loader:
shuffle: True
drop_last: False
batch_size_per_card: 8
num_workers: 4
Eval:
dataset:
name: SimpleDataSet
data_dir: ./train_data/SROIE/test
label_file_list:
- ./train_data/SROIE/test.txt
transforms:
- DecodeImage: # load image
img_mode: RGB
channel_first: False
- VQATokenLabelEncode: # Class handling label
contains_re: False
algorithm: *algorithm
class_path: *class_path
- VQATokenPad:
max_seq_len: *max_seq_len
return_attention_mask: True
- VQASerTokenChunk:
max_seq_len: *max_seq_len
- Resize:
size: [224,224]
- NormalizeImage:
scale: 1
mean: [ 123.675, 116.28, 103.53 ]
std: [ 58.395, 57.12, 57.375 ]
order: 'hwc'
- ToCHWImage:
- KeepKeys:
# dataloader will return list in this order
keep_keys: [ 'input_ids', 'bbox', 'attention_mask', 'token_type_ids', 'image', 'labels']
loader:
shuffle: False
drop_last: False
batch_size_per_card: 8
num_workers: 4
...@@ -3,7 +3,7 @@ Global: ...@@ -3,7 +3,7 @@ Global:
epoch_num: &epoch_num 200 epoch_num: &epoch_num 200
log_smooth_window: 10 log_smooth_window: 10
print_batch_step: 10 print_batch_step: 10
save_model_dir: ./output/ser_layoutlmv2/ save_model_dir: ./output/ser_layoutlmv2_xfund_zh/
save_epoch_step: 2000 save_epoch_step: 2000
# evaluation is run every 10 iterations after the 0th iteration # evaluation is run every 10 iterations after the 0th iteration
eval_batch_step: [ 0, 19 ] eval_batch_step: [ 0, 19 ]
...@@ -11,8 +11,8 @@ Global: ...@@ -11,8 +11,8 @@ Global:
save_inference_dir: save_inference_dir:
use_visualdl: False use_visualdl: False
seed: 2022 seed: 2022
infer_img: doc/vqa/input/zh_val_0.jpg infer_img: ppstructure/docs/vqa/input/zh_val_42.jpg
save_res_path: ./output/ser/ save_res_path: ./output/ser_layoutlmv2_xfund_zh/res/
Architecture: Architecture:
model_type: vqa model_type: vqa
...@@ -44,7 +44,7 @@ Optimizer: ...@@ -44,7 +44,7 @@ Optimizer:
PostProcess: PostProcess:
name: VQASerTokenLayoutLMPostProcess name: VQASerTokenLayoutLMPostProcess
class_path: &class_path ppstructure/vqa/labels/labels_ser.txt class_path: &class_path train_data/XFUND/class_list_xfun.txt
Metric: Metric:
name: VQASerTokenMetric name: VQASerTokenMetric
...@@ -55,7 +55,7 @@ Train: ...@@ -55,7 +55,7 @@ Train:
name: SimpleDataSet name: SimpleDataSet
data_dir: train_data/XFUND/zh_train/image data_dir: train_data/XFUND/zh_train/image
label_file_list: label_file_list:
- train_data/XFUND/zh_train/xfun_normalize_train.json - train_data/XFUND/zh_train/train.json
transforms: transforms:
- DecodeImage: # load image - DecodeImage: # load image
img_mode: RGB img_mode: RGB
...@@ -78,7 +78,7 @@ Train: ...@@ -78,7 +78,7 @@ Train:
order: 'hwc' order: 'hwc'
- ToCHWImage: - ToCHWImage:
- KeepKeys: - KeepKeys:
keep_keys: [ 'input_ids','labels', 'bbox', 'image', 'attention_mask', 'token_type_ids'] # dataloader will return list in this order keep_keys: [ 'input_ids', 'bbox', 'attention_mask', 'token_type_ids', 'image', 'labels'] # dataloader will return list in this order
loader: loader:
shuffle: True shuffle: True
drop_last: False drop_last: False
...@@ -90,7 +90,7 @@ Eval: ...@@ -90,7 +90,7 @@ Eval:
name: SimpleDataSet name: SimpleDataSet
data_dir: train_data/XFUND/zh_val/image data_dir: train_data/XFUND/zh_val/image
label_file_list: label_file_list:
- train_data/XFUND/zh_val/xfun_normalize_val.json - train_data/XFUND/zh_val/val.json
transforms: transforms:
- DecodeImage: # load image - DecodeImage: # load image
img_mode: RGB img_mode: RGB
...@@ -113,7 +113,7 @@ Eval: ...@@ -113,7 +113,7 @@ Eval:
order: 'hwc' order: 'hwc'
- ToCHWImage: - ToCHWImage:
- KeepKeys: - KeepKeys:
keep_keys: [ 'input_ids', 'labels', 'bbox', 'image', 'attention_mask', 'token_type_ids'] # dataloader will return list in this order keep_keys: [ 'input_ids', 'bbox', 'attention_mask', 'token_type_ids', 'image', 'labels'] # dataloader will return list in this order
loader: loader:
shuffle: False shuffle: False
drop_last: False drop_last: False
......
Global:
use_gpu: True
epoch_num: &epoch_num 200
log_smooth_window: 10
print_batch_step: 10
save_model_dir: ./output/ser_layoutxlm_funsd
save_epoch_step: 2000
# evaluation is run every 57 iterations after the 0th iteration
eval_batch_step: [ 0, 57 ]
cal_metric_during_train: False
save_inference_dir:
use_visualdl: False
seed: 2022
infer_img: train_data/FUNSD/testing_data/images/83624198.png
save_res_path: output/ser_layoutxlm_funsd/res/
Architecture:
model_type: vqa
algorithm: &algorithm "LayoutXLM"
Transform:
Backbone:
name: LayoutXLMForSer
pretrained: True
checkpoints:
num_classes: &num_classes 7
Loss:
name: VQASerTokenLayoutLMLoss
num_classes: *num_classes
Optimizer:
name: AdamW
beta1: 0.9
beta2: 0.999
lr:
name: Linear
learning_rate: 0.00005
epochs: *epoch_num
warmup_epoch: 2
regularizer:
name: L2
factor: 0.00000
PostProcess:
name: VQASerTokenLayoutLMPostProcess
class_path: &class_path ./train_data/FUNSD/class_list.txt
Metric:
name: VQASerTokenMetric
main_indicator: hmean
Train:
dataset:
name: SimpleDataSet
data_dir: ./train_data/FUNSD/training_data/images/
label_file_list:
- ./train_data/FUNSD/train.json
ratio_list: [ 1.0 ]
transforms:
- DecodeImage: # load image
img_mode: RGB
channel_first: False
- VQATokenLabelEncode: # Class handling label
contains_re: False
algorithm: *algorithm
class_path: *class_path
- VQATokenPad:
max_seq_len: &max_seq_len 512
return_attention_mask: True
- VQASerTokenChunk:
max_seq_len: *max_seq_len
- Resize:
size: [224,224]
- NormalizeImage:
scale: 1
mean: [ 123.675, 116.28, 103.53 ]
std: [ 58.395, 57.12, 57.375 ]
order: 'hwc'
- ToCHWImage:
- KeepKeys:
# dataloader will return list in this order
keep_keys: [ 'input_ids', 'bbox', 'attention_mask', 'token_type_ids', 'image', 'labels']
loader:
shuffle: True
drop_last: False
batch_size_per_card: 8
num_workers: 4
Eval:
dataset:
name: SimpleDataSet
data_dir: train_data/FUNSD/testing_data/images/
label_file_list:
- ./train_data/FUNSD/test.json
transforms:
- DecodeImage: # load image
img_mode: RGB
channel_first: False
- VQATokenLabelEncode: # Class handling label
contains_re: False
algorithm: *algorithm
class_path: *class_path
- VQATokenPad:
max_seq_len: *max_seq_len
return_attention_mask: True
- VQASerTokenChunk:
max_seq_len: *max_seq_len
- Resize:
size: [224,224]
- NormalizeImage:
scale: 1
mean: [ 123.675, 116.28, 103.53 ]
std: [ 58.395, 57.12, 57.375 ]
order: 'hwc'
- ToCHWImage:
- KeepKeys:
# dataloader will return list in this order
keep_keys: [ 'input_ids', 'bbox', 'attention_mask', 'token_type_ids', 'image', 'labels']
loader:
shuffle: False
drop_last: False
batch_size_per_card: 8
num_workers: 4
Global:
use_gpu: True
epoch_num: &epoch_num 200
log_smooth_window: 10
print_batch_step: 10
save_model_dir: ./output/ser_layoutxlm_sroie
save_epoch_step: 2000
# evaluation is run every 200 iterations after the 0th iteration
eval_batch_step: [ 0, 200 ]
cal_metric_during_train: False
save_inference_dir:
use_visualdl: False
seed: 2022
infer_img: train_data/SROIE/test/X00016469670.jpg
save_res_path: res_img_aug_with_gt
Architecture:
model_type: vqa
algorithm: &algorithm "LayoutXLM"
Transform:
Backbone:
name: LayoutXLMForSer
pretrained: True
checkpoints:
num_classes: &num_classes 9
Loss:
name: VQASerTokenLayoutLMLoss
num_classes: *num_classes
Optimizer:
name: AdamW
beta1: 0.9
beta2: 0.999
lr:
name: Linear
learning_rate: 0.00005
epochs: *epoch_num
warmup_epoch: 2
regularizer:
name: L2
factor: 0.00000
PostProcess:
name: VQASerTokenLayoutLMPostProcess
class_path: &class_path ./train_data/SROIE/class_list.txt
Metric:
name: VQASerTokenMetric
main_indicator: hmean
Train:
dataset:
name: SimpleDataSet
data_dir: ./train_data/SROIE/train
label_file_list:
- ./train_data/SROIE/train.txt
ratio_list: [ 1.0 ]
transforms:
- DecodeImage: # load image
img_mode: RGB
channel_first: False
- VQATokenLabelEncode: # Class handling label
contains_re: False
algorithm: *algorithm
class_path: *class_path
- VQATokenPad:
max_seq_len: &max_seq_len 512
return_attention_mask: True
- VQASerTokenChunk:
max_seq_len: *max_seq_len
- Resize:
size: [224,224]
- NormalizeImage:
scale: 1
mean: [ 123.675, 116.28, 103.53 ]
std: [ 58.395, 57.12, 57.375 ]
order: 'hwc'
- ToCHWImage:
- KeepKeys:
# dataloader will return list in this order
keep_keys: [ 'input_ids', 'bbox', 'attention_mask', 'token_type_ids', 'image', 'labels']
loader:
shuffle: True
drop_last: False
batch_size_per_card: 8
num_workers: 4
Eval:
dataset:
name: SimpleDataSet
data_dir: train_data/SROIE/test
label_file_list:
- ./train_data/SROIE/test.txt
transforms:
- DecodeImage: # load image
img_mode: RGB
channel_first: False
- VQATokenLabelEncode: # Class handling label
contains_re: False
algorithm: *algorithm
class_path: *class_path
- VQATokenPad:
max_seq_len: *max_seq_len
return_attention_mask: True
- VQASerTokenChunk:
max_seq_len: *max_seq_len
- Resize:
size: [224,224]
- NormalizeImage:
scale: 1
mean: [ 123.675, 116.28, 103.53 ]
std: [ 58.395, 57.12, 57.375 ]
order: 'hwc'
- ToCHWImage:
- KeepKeys:
# dataloader will return list in this order
keep_keys: [ 'input_ids', 'bbox', 'attention_mask', 'token_type_ids', 'image', 'labels']
loader:
shuffle: False
drop_last: False
batch_size_per_card: 8
num_workers: 4
Global:
use_gpu: True
epoch_num: &epoch_num 100
log_smooth_window: 10
print_batch_step: 10
save_model_dir: ./output/ser_layoutxlm_wildreceipt
save_epoch_step: 2000
# evaluation is run every 200 iterations after the 0th iteration
eval_batch_step: [ 0, 200 ]
cal_metric_during_train: False
save_inference_dir:
use_visualdl: False
seed: 2022
infer_img: train_data//wildreceipt/image_files/Image_12/10/845be0dd6f5b04866a2042abd28d558032ef2576.jpeg
save_res_path: ./output/ser_layoutxlm_wildreceipt/res
Architecture:
model_type: vqa
algorithm: &algorithm "LayoutXLM"
Transform:
Backbone:
name: LayoutXLMForSer
pretrained: True
checkpoints:
num_classes: &num_classes 51
Loss:
name: VQASerTokenLayoutLMLoss
num_classes: *num_classes
Optimizer:
name: AdamW
beta1: 0.9
beta2: 0.999
lr:
name: Linear
learning_rate: 0.00005
epochs: *epoch_num
warmup_epoch: 2
regularizer:
name: L2
factor: 0.00000
PostProcess:
name: VQASerTokenLayoutLMPostProcess
class_path: &class_path ./train_data/wildreceipt/class_list.txt
Metric:
name: VQASerTokenMetric
main_indicator: hmean
Train:
dataset:
name: SimpleDataSet
data_dir: ./train_data/wildreceipt/
label_file_list:
- ./train_data/wildreceipt/wildreceipt_train.txt
ratio_list: [ 1.0 ]
transforms:
- DecodeImage: # load image
img_mode: RGB
channel_first: False
- VQATokenLabelEncode: # Class handling label
contains_re: False
algorithm: *algorithm
class_path: *class_path
- VQATokenPad:
max_seq_len: &max_seq_len 512
return_attention_mask: True
- VQASerTokenChunk:
max_seq_len: *max_seq_len
- Resize:
size: [224,224]
- NormalizeImage:
scale: 1
mean: [ 123.675, 116.28, 103.53 ]
std: [ 58.395, 57.12, 57.375 ]
order: 'hwc'
- ToCHWImage:
- KeepKeys:
# dataloader will return list in this order
keep_keys: [ 'input_ids', 'bbox', 'attention_mask', 'token_type_ids', 'image', 'labels']
loader:
shuffle: True
drop_last: False
batch_size_per_card: 8
num_workers: 4
Eval:
dataset:
name: SimpleDataSet
data_dir: train_data/wildreceipt
label_file_list:
- ./train_data/wildreceipt/wildreceipt_test.txt
transforms:
- DecodeImage: # load image
img_mode: RGB
channel_first: False
- VQATokenLabelEncode: # Class handling label
contains_re: False
algorithm: *algorithm
class_path: *class_path
- VQATokenPad:
max_seq_len: *max_seq_len
return_attention_mask: True
- VQASerTokenChunk:
max_seq_len: *max_seq_len
- Resize:
size: [224,224]
- NormalizeImage:
scale: 1
mean: [ 123.675, 116.28, 103.53 ]
std: [ 58.395, 57.12, 57.375 ]
order: 'hwc'
- ToCHWImage:
- KeepKeys:
# dataloader will return list in this order
keep_keys: [ 'input_ids', 'bbox', 'attention_mask', 'token_type_ids', 'image', 'labels']
loader:
shuffle: False
drop_last: False
batch_size_per_card: 8
num_workers: 4
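With scale set to 1, the mean and std above are expressed in 0-255 pixel units (the ImageNet statistics multiplied by 255) and are applied per channel while the image is still in HWC order, after which ToCHWImage transposes it. A NumPy sketch of the implied arithmetic (an illustration, not the PaddleOCR operator):

```python
# Sketch of the arithmetic implied by NormalizeImage / ToCHWImage above;
# illustration only, not the PaddleOCR implementation.
import numpy as np

scale = 1.0
mean = np.array([123.675, 116.28, 103.53], dtype=np.float32)  # per-channel, 0-255 range
std = np.array([58.395, 57.12, 57.375], dtype=np.float32)

def normalize_hwc(img_hwc: np.ndarray) -> np.ndarray:
    """img_hwc: float32 HxWx3 array with values in [0, 255]."""
    normalized = (img_hwc * scale - mean) / std   # order: 'hwc'
    return normalized.transpose(2, 0, 1)          # ToCHWImage: HWC -> CHW

dummy = np.random.uniform(0, 255, (224, 224, 3)).astype(np.float32)
print(normalize_hwc(dummy).shape)  # (3, 224, 224)
```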
...@@ -3,7 +3,7 @@ Global: ...@@ -3,7 +3,7 @@ Global:
epoch_num: &epoch_num 200 epoch_num: &epoch_num 200
log_smooth_window: 10 log_smooth_window: 10
print_batch_step: 10 print_batch_step: 10
save_model_dir: ./output/ser_layoutxlm/ save_model_dir: ./output/ser_layoutxlm_xfund_zh
save_epoch_step: 2000 save_epoch_step: 2000
# evaluation is run every 19 iterations after the 0th iteration # evaluation is run every 19 iterations after the 0th iteration
eval_batch_step: [ 0, 19 ] eval_batch_step: [ 0, 19 ]
...@@ -11,8 +11,8 @@ Global: ...@@ -11,8 +11,8 @@ Global:
save_inference_dir: save_inference_dir:
use_visualdl: False use_visualdl: False
seed: 2022 seed: 2022
infer_img: doc/vqa/input/zh_val_42.jpg infer_img: ppstructure/docs/vqa/input/zh_val_42.jpg
save_res_path: ./output/ser save_res_path: ./output/ser_layoutxlm_xfund_zh/res
Architecture: Architecture:
model_type: vqa model_type: vqa
...@@ -43,7 +43,7 @@ Optimizer: ...@@ -43,7 +43,7 @@ Optimizer:
PostProcess: PostProcess:
name: VQASerTokenLayoutLMPostProcess name: VQASerTokenLayoutLMPostProcess
class_path: &class_path ppstructure/vqa/labels/labels_ser.txt class_path: &class_path train_data/XFUND/class_list_xfun.txt
Metric: Metric:
name: VQASerTokenMetric name: VQASerTokenMetric
...@@ -54,7 +54,7 @@ Train: ...@@ -54,7 +54,7 @@ Train:
name: SimpleDataSet name: SimpleDataSet
data_dir: train_data/XFUND/zh_train/image data_dir: train_data/XFUND/zh_train/image
label_file_list: label_file_list:
- train_data/XFUND/zh_train/xfun_normalize_train.json - train_data/XFUND/zh_train/train.json
ratio_list: [ 1.0 ] ratio_list: [ 1.0 ]
transforms: transforms:
- DecodeImage: # load image - DecodeImage: # load image
...@@ -78,7 +78,7 @@ Train: ...@@ -78,7 +78,7 @@ Train:
order: 'hwc' order: 'hwc'
- ToCHWImage: - ToCHWImage:
- KeepKeys: - KeepKeys:
keep_keys: [ 'input_ids','labels', 'bbox', 'image', 'attention_mask', 'token_type_ids'] # dataloader will return list in this order keep_keys: [ 'input_ids', 'bbox', 'attention_mask', 'token_type_ids', 'image', 'labels'] # dataloader will return list in this order
loader: loader:
shuffle: True shuffle: True
drop_last: False drop_last: False
...@@ -90,7 +90,7 @@ Eval: ...@@ -90,7 +90,7 @@ Eval:
name: SimpleDataSet name: SimpleDataSet
data_dir: train_data/XFUND/zh_val/image data_dir: train_data/XFUND/zh_val/image
label_file_list: label_file_list:
- train_data/XFUND/zh_val/xfun_normalize_val.json - train_data/XFUND/zh_val/val.json
transforms: transforms:
- DecodeImage: # load image - DecodeImage: # load image
img_mode: RGB img_mode: RGB
...@@ -113,7 +113,7 @@ Eval: ...@@ -113,7 +113,7 @@ Eval:
order: 'hwc' order: 'hwc'
- ToCHWImage: - ToCHWImage:
- KeepKeys: - KeepKeys:
keep_keys: [ 'input_ids', 'labels', 'bbox', 'image', 'attention_mask', 'token_type_ids'] # dataloader will return list in this order keep_keys: [ 'input_ids', 'bbox', 'attention_mask', 'token_type_ids', 'image', 'labels'] # dataloader will return list in this order
loader: loader:
shuffle: False shuffle: False
drop_last: False drop_last: False
......
...@@ -15,20 +15,24 @@ ...@@ -15,20 +15,24 @@
<!--- specific language governing permissions and limitations --> <!--- specific language governing permissions and limitations -->
<!--- under the License. --> <!--- under the License. -->
Running PaddleOCR text recognition model via TVM on bare metal Arm(R) Cortex(R)-M55 CPU and CMSIS-NN English | [简体中文](README_ch.md)
===============================================================
This folder contains an example of how to use TVM to run a PaddleOCR model Running PaddleOCR text recognition model on bare metal Arm(R) Cortex(R)-M55 CPU using Arm Virtual Hardware
on bare metal Cortex(R)-M55 CPU and CMSIS-NN. ======================================================================
Prerequisites This folder contains an example of how to run a PaddleOCR model on bare metal [Cortex(R)-M55 CPU](https://www.arm.com/products/silicon-ip-cpu/cortex-m/cortex-m55) using [Arm Virtual Hardware](https://www.arm.com/products/development-tools/simulation/virtual-hardware).
Running environment and prerequisites
------------- -------------
If the demo is run in the ci_cpu Docker container provided with TVM, then the following Case 1: If the demo is run in an Arm Virtual Hardware Amazon Machine Image (AMI) instance hosted by [AWS](https://aws.amazon.com/marketplace/pp/prodview-urbpq7yo5va7g?sr=0-1&ref_=beagle&applicationId=AWSMPContessa)/[AWS China](https://awsmarketplace.amazonaws.cn/marketplace/pp/prodview-2y7nefntbmybu), the following software will be installed by the [configure_avh.sh](./configure_avh.sh) script. It is installed automatically when you run the application through the [run_demo.sh](./run_demo.sh) script.
software will already be installed. You can refer to this [guide](https://arm-software.github.io/AVH/main/examples/html/MicroSpeech.html#amilaunch) to launch an Arm Virtual Hardware AMI instance.
Case 2: If the demo is run in the [ci_cpu Docker container](https://github.com/apache/tvm/blob/main/docker/Dockerfile.ci_cpu) provided with [TVM](https://github.com/apache/tvm), then the following software will already be installed.
If the demo is not run in the ci_cpu Docker container, then you will need the following: Case 3: If the demo is not run in the ci_cpu Docker container, then you will need the following:
- Software required to build and run the demo (These can all be installed by running - Software required to build and run the demo (These can all be installed by running
https://github.com/apache/tvm/blob/main/docker/install/ubuntu_install_ethosu_driver_stack.sh .) tvm/docker/install/ubuntu_install_ethosu_driver_stack.sh.)
- [Fixed Virtual Platform (FVP) based on Arm(R) Corstone(TM)-300 software](https://developer.arm.com/tools-and-software/open-source-software/arm-platforms-software/arm-ecosystem-fvps) - [Fixed Virtual Platform (FVP) based on Arm(R) Corstone(TM)-300 software](https://developer.arm.com/tools-and-software/open-source-software/arm-platforms-software/arm-ecosystem-fvps)
- [cmake 3.19.5](https://github.com/Kitware/CMake/releases/) - [cmake 3.19.5](https://github.com/Kitware/CMake/releases/)
- [GCC toolchain from Arm(R)](https://developer.arm.com/-/media/Files/downloads/gnu-rm/10-2020q4/gcc-arm-none-eabi-10-2020-q4-major-x86_64-linux.tar.bz2) - [GCC toolchain from Arm(R)](https://developer.arm.com/-/media/Files/downloads/gnu-rm/10-2020q4/gcc-arm-none-eabi-10-2020-q4-major-x86_64-linux.tar.bz2)
...@@ -40,19 +44,22 @@ If the demo is not run in the ci_cpu Docker container, then you will need the fo ...@@ -40,19 +44,22 @@ If the demo is not run in the ci_cpu Docker container, then you will need the fo
pip install -r ./requirements.txt pip install -r ./requirements.txt
``` ```
In case 2 and case 3:
You will need to update your PATH environment variable to include the path to cmake 3.19.5 and the FVP.
For example if you've installed these in ```/opt/arm``` , then you would do the following:
```bash
export PATH=/opt/arm/FVP_Corstone_SSE-300/models/Linux64_GCC-6.4:/opt/arm/cmake/bin:$PATH
```
You will also need TVM which can either be: You will also need TVM which can either be:
- Installed from TLCPack(see [TLCPack](https://tlcpack.ai/))
- Built from source (see [Install from Source](https://tvm.apache.org/docs/install/from_source.html)) - Built from source (see [Install from Source](https://tvm.apache.org/docs/install/from_source.html))
- When building from source, the following need to be set in config.cmake: - When building from source, the following need to be set in config.cmake:
- set(USE_CMSISNN ON) - set(USE_CMSISNN ON)
- set(USE_MICRO ON) - set(USE_MICRO ON)
- set(USE_LLVM ON) - set(USE_LLVM ON)
- Installed from TLCPack nightly(see [TLCPack](https://tlcpack.ai/))
You will need to update your PATH environment variable to include the path to cmake 3.19.5 and the FVP.
For example if you've installed these in ```/opt/arm``` , then you would do the following:
```bash
export PATH=/opt/arm/FVP_Corstone_SSE-300/models/Linux64_GCC-6.4:/opt/arm/cmake/bin:$PATH
```
Running the demo application Running the demo application
---------------------------- ----------------------------
...@@ -62,6 +69,12 @@ Type the following command to run the bare metal text recognition application ([ ...@@ -62,6 +69,12 @@ Type the following command to run the bare metal text recognition application ([
./run_demo.sh ./run_demo.sh
``` ```
If you are not able to use an Arm Virtual Hardware Amazon Machine Image (AMI) instance hosted by AWS/AWS China, set the --enable_FVP argument to 1 so that the application runs on local Fixed Virtual Platform (FVP) executables.
```bash
./run_demo.sh --enable_FVP 1
```
If the Ethos(TM)-U platform and/or CMSIS have not been installed in /opt/arm/ethosu then If the Ethos(TM)-U platform and/or CMSIS have not been installed in /opt/arm/ethosu then
the locations for these can be specified as arguments to run_demo.sh, for example: the locations for these can be specified as arguments to run_demo.sh, for example:
...@@ -70,13 +83,14 @@ the locations for these can be specified as arguments to run_demo.sh, for exampl ...@@ -70,13 +83,14 @@ the locations for these can be specified as arguments to run_demo.sh, for exampl
--ethosu_platform_path /home/tvm-user/ethosu/core_platform --ethosu_platform_path /home/tvm-user/ethosu/core_platform
``` ```
This will: Running the demo application through [run_demo.sh](./run_demo.sh) will:
- Set up the running environment by automatically installing the required prerequisites when running in an Arm Virtual Hardware Amazon AMI instance (i.e. when --enable_FVP is not set to 1)
- Download a PaddleOCR text recognition model - Download a PaddleOCR text recognition model
- Use tvmc to compile the text recognition model for Cortex(R)-M55 CPU and CMSIS-NN - Use tvmc to compile the text recognition model for Cortex(R)-M55 CPU and CMSIS-NN
- Create a C header file inputs.c containing the image data as a C array - Create a C header file inputs.c containing the image data as a C array
- Create a C header file outputs.c containing a C array where the output of inference will be stored - Create a C header file outputs.c containing a C array where the output of inference will be stored
- Build the demo application - Build the demo application
- Run the demo application on a Fixed Virtual Platform (FVP) based on Arm(R) Corstone(TM)-300 software - Run the demo application on Arm Virtual Hardware based on Arm(R) Corstone(TM)-300 software
- The application will report the text on the image and the corresponding score. - The application will report the text on the image and the corresponding score.
Using your own image Using your own image
...@@ -92,9 +106,11 @@ python3 ./convert_image.py path/to/image ...@@ -92,9 +106,11 @@ python3 ./convert_image.py path/to/image
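convert_image.py turns the input image into C source so the bare-metal application can link the pixel data in. A hedged sketch of that idea follows; the real script's preprocessing, input size and symbol names may differ:

```python
# Hedged sketch of serializing an image into a C array source file, in the spirit
# of convert_image.py; the assumed input size (320x32 grayscale) and the symbol
# name input_image are illustrative, not the script's actual behaviour.
import sys
import numpy as np
from PIL import Image

def image_to_c_array(image_path: str, out_path: str = "inputs.c") -> None:
    img = Image.open(image_path).convert("L").resize((320, 32))
    data = np.asarray(img, dtype=np.uint8).flatten()
    rows = [", ".join(str(v) for v in data[i:i + 16]) for i in range(0, len(data), 16)]
    with open(out_path, "w") as f:
        f.write("const unsigned char input_image[%d] = {\n  %s\n};\n"
                % (len(data), ",\n  ".join(rows)))

if __name__ == "__main__":
    image_to_c_array(sys.argv[1])
```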
Model description Model description
----------------- -----------------
In this demo, the model we use is an English recognition model based on [PP-OCRv3](https://github.com/PaddlePaddle/PaddleOCR/blob/dygraph/doc/doc_ch/PP-OCRv3_introduction.md). PP-OCRv3 is the third version of the PP-OCR series model released by [PaddleOCR](https://github.com/PaddlePaddle/PaddleOCR). This series of models has the following features: The example is built on the [PP-OCRv3](https://github.com/PaddlePaddle/PaddleOCR/blob/dygraph/doc/doc_ch/PP-OCRv3_introduction.md) English recognition model released by [PaddleOCR](https://github.com/PaddlePaddle/PaddleOCR). Since the Arm(R) Cortex(R)-M55 CPU does not support the rnn operator, we removed the unsupported operator from the PP-OCRv3 text recognition model to obtain the current 2.7M English recognition model.
PP-OCRv3 is the third version of the PP-OCR series model. This series of models has the following features:
- PP-OCRv3: ultra-lightweight OCR system: detection (3.6M) + direction classifier (1.4M) + recognition (12M) = 17.0M - PP-OCRv3: ultra-lightweight OCR system: detection (3.6M) + direction classifier (1.4M) + recognition (12M) = 17.0M
- Support more than 80 kinds of multi-language recognition models, including English, Chinese, French, German, Arabic, Korean, Japanese and so on. For details - Support more than 80 kinds of multi-language recognition models, including English, Chinese, French, German, Arabic, Korean, Japanese and so on. For details
- Support vertical text recognition, and long text recognition - Support vertical text recognition, and long text recognition
The text recognition model in PP-OCRv3 supports more than 80 languages. In the process of model development, since Arm(R) Cortex(R)-M55 CPU does not support rnn operator, we delete the unsupported operator based on the PP-OCRv3 text recognition model to obtain the current model.
\ No newline at end of file
...@@ -14,6 +14,7 @@ ...@@ -14,6 +14,7 @@
<!--- KIND, either express or implied. See the License for the --> <!--- KIND, either express or implied. See the License for the -->
<!--- specific language governing permissions and limitations --> <!--- specific language governing permissions and limitations -->
<!--- under the License. --> <!--- under the License. -->
[English](README.md) | 简体中文
通过TVM在 Arm(R) Cortex(R)-M55 CPU 上运行 PaddleOCR文本识别模型 通过TVM在 Arm(R) Cortex(R)-M55 CPU 上运行 PaddleOCR文本识别模型
=============================================================== ===============================================================
...@@ -85,9 +86,9 @@ export PATH=/opt/arm/FVP_Corstone_SSE-300/models/Linux64_GCC-6.4:/opt/arm/cmake/ ...@@ -85,9 +86,9 @@ export PATH=/opt/arm/FVP_Corstone_SSE-300/models/Linux64_GCC-6.4:/opt/arm/cmake/
模型描述 模型描述
----------------- -----------------
在这个demo中,我们使用的模型是基于[PP-OCRv3](https://github.com/PaddlePaddle/PaddleOCR/blob/dygraph/doc/doc_ch/PP-OCRv3_introduction.md)的英文识别模型。 PP-OCRv3是[PaddleOCR](https://github.com/PaddlePaddle/PaddleOCR)发布的PP-OCR系列模型的第三个版本。 该系列模型具有以下特点: 在这个demo中,我们使用的模型是基于[PP-OCRv3](https://github.com/PaddlePaddle/PaddleOCR/blob/dygraph/doc/doc_ch/PP-OCRv3_introduction.md)的英文识别模型。由于Arm(R) Cortex(R)-M55 CPU不支持rnn算子,我们在PP-OCRv3原始文本识别模型的基础上进行适配,最终模型大小为2.7M。
PP-OCRv3是[PaddleOCR](https://github.com/PaddlePaddle/PaddleOCR)发布的PP-OCR系列模型的第三个版本,该系列模型具有以下特点:
- 超轻量级OCR系统:检测(3.6M)+方向分类器(1.4M)+识别(12M)=17.0M。 - 超轻量级OCR系统:检测(3.6M)+方向分类器(1.4M)+识别(12M)=17.0M。
- 支持80多种多语言识别模型,包括英文、中文、法文、德文、阿拉伯文、韩文、日文等。 - 支持80多种多语言识别模型,包括英文、中文、法文、德文、阿拉伯文、韩文、日文等。
- 支持竖排文本识别,长文本识别。 - 支持竖排文本识别,长文本识别。
PP-OCRv3 中的文本识别模型支持 80 多种语言。 在模型开发过程中,由于Arm(R) Cortex(R)-M55 CPU不支持rnn算子,我们在PP-OCRv3文本识别模型的基础上删除了不支持的算子,得到当前模型。
\ No newline at end of file
#!/bin/bash
# Copyright (c) 2022 Arm Limited and Contributors. All rights reserved.
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
set -e
set -u
set -o pipefail
# Show usage
function show_usage() {
cat <<EOF
Usage: Set up running environment by installing the required prerequisites.
-h, --help
Display this help message.
EOF
}
if [ "$#" -eq 1 ] && [ "$1" == "--help" -o "$1" == "-h" ]; then
show_usage
exit 0
elif [ "$#" -ge 1 ]; then
show_usage
exit 1
fi
echo -e "\e[36mStart setting up running environment\e[0m"
# Install CMSIS
echo -e "\e[36mStart installing CMSIS\e[0m"
CMSIS_PATH="/opt/arm/ethosu/cmsis"
mkdir -p "${CMSIS_PATH}"
CMSIS_SHA="977abe9849781a2e788b02282986480ff4e25ea6"
CMSIS_SHASUM="86c88d9341439fbb78664f11f3f25bc9fda3cd7de89359324019a4d87d169939eea85b7fdbfa6ad03aa428c6b515ef2f8cd52299ce1959a5444d4ac305f934cc"
CMSIS_URL="http://github.com/ARM-software/CMSIS_5/archive/${CMSIS_SHA}.tar.gz"
DOWNLOAD_PATH="/tmp/${CMSIS_SHA}.tar.gz"
wget ${CMSIS_URL} -O "${DOWNLOAD_PATH}"
echo "$CMSIS_SHASUM" ${DOWNLOAD_PATH} | sha512sum -c
tar -xf "${DOWNLOAD_PATH}" -C "${CMSIS_PATH}" --strip-components=1
touch "${CMSIS_PATH}"/"${CMSIS_SHA}".sha
echo -e "\e[36mCMSIS Installation SUCCESS\e[0m"
# Install Arm(R) Ethos(TM)-U NPU driver stack
echo -e "\e[36mStart installing Arm(R) Ethos(TM)-U NPU driver stack\e[0m"
git clone "https://review.mlplatform.org/ml/ethos-u/ethos-u-core-platform" /opt/arm/ethosu/core_platform
cd /opt/arm/ethosu/core_platform
git checkout tags/"21.11"
echo -e "\e[36mArm(R) Ethos(TM)-U Core Platform Installation SUCCESS\e[0m"
# Install Arm(R) GNU Toolchain
echo -e "\e[36mStart installing Arm(R) GNU Toolchain\e[0m"
mkdir -p /opt/arm/gcc-arm-none-eabi
export gcc_arm_url='https://developer.arm.com/-/media/Files/downloads/gnu-rm/10-2020q4/gcc-arm-none-eabi-10-2020-q4-major-x86_64-linux.tar.bz2?revision=ca0cbf9c-9de2-491c-ac48-898b5bbc0443&la=en&hash=68760A8AE66026BCF99F05AC017A6A50C6FD832A'
curl --retry 64 -sSL ${gcc_arm_url} | tar -C /opt/arm/gcc-arm-none-eabi --strip-components=1 -jx
export PATH=/opt/arm/gcc-arm-none-eabi/bin:$PATH
arm-none-eabi-gcc --version
arm-none-eabi-g++ --version
echo -e "\e[36mArm(R) GNU Toolchain Installation SUCCESS\e[0m"
# Install TVM from TLCPack
echo -e "\e[36mStart installing TVM\e[0m"
pip install tlcpack-nightly -f https://tlcpack.ai/wheels
echo -e "\e[36mTVM Installation SUCCESS\e[0m"
\ No newline at end of file
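configure_avh.sh pins CMSIS to a specific commit and checks the downloaded tarball against a SHA-512 digest before extracting it. The same verification can be sketched with the Python standard library (the path and digest below are placeholders):

```python
# Standard-library sketch of the `sha512sum -c` check performed above;
# the file path and expected digest are placeholders.
import hashlib

def verify_sha512(path: str, expected_hex: str) -> bool:
    h = hashlib.sha512()
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(1 << 20), b""):
            h.update(chunk)
    return h.hexdigest() == expected_hex

# verify_sha512("/tmp/<sha>.tar.gz", "<expected sha512 hex digest>")
```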
...@@ -24,18 +24,28 @@ function show_usage() { ...@@ -24,18 +24,28 @@ function show_usage() {
cat <<EOF cat <<EOF
Usage: run_demo.sh Usage: run_demo.sh
-h, --help -h, --help
Display this help message. Display this help message.
--cmsis_path CMSIS_PATH --cmsis_path CMSIS_PATH
Set path to CMSIS. Set path to CMSIS.
--ethosu_platform_path ETHOSU_PLATFORM_PATH --ethosu_platform_path ETHOSU_PLATFORM_PATH
Set path to Arm(R) Ethos(TM)-U core platform. Set path to Arm(R) Ethos(TM)-U core platform.
--fvp_path FVP_PATH --fvp_path FVP_PATH
Set path to FVP. Set path to FVP.
--cmake_path --cmake_path
Set path to cmake. Set path to cmake.
--enable_FVP
Set 1 to run application on local Fixed Virtual Platforms (FVPs) executables.
EOF EOF
} }
# Configure environment variables
FVP_enable=0
export PATH=/opt/arm/gcc-arm-none-eabi/bin:$PATH
# Install python libraries
echo -e "\e[36mInstall python libraries\e[0m"
sudo pip install -r ./requirements.txt
# Parse arguments # Parse arguments
while (( $# )); do while (( $# )); do
case "$1" in case "$1" in
...@@ -91,6 +101,18 @@ while (( $# )); do ...@@ -91,6 +101,18 @@ while (( $# )); do
exit 1 exit 1
fi fi
;; ;;
--enable_FVP)
if [ $# -gt 1 ] && [ "$2" == "1" -o "$2" == "0" ];
then
FVP_enable="$2"
shift 2
else
echo 'ERROR: --enable_FVP requires an argument of 1 or 0' >&2
show_usage >&2
exit 1
fi
;;
-*|--*) -*|--*)
echo "Error: Unknown flag: $1" >&2 echo "Error: Unknown flag: $1" >&2
...@@ -100,17 +122,27 @@ while (( $# )); do ...@@ -100,17 +122,27 @@ while (( $# )); do
esac esac
done done
# Choose running environment: cloud(default) or local environment
Platform="VHT_Corstone_SSE-300_Ethos-U55"
if [ $FVP_enable == "1" ]; then
Platform="FVP_Corstone_SSE-300_Ethos-U55"
echo -e "\e[36mRun application on local Fixed Virtual Platforms (FVPs)\e[0m"
else
if [ ! -d "/opt/arm/" ]; then
sudo ./configure_avh.sh
fi
fi
# Directories # Directories
script_dir="$( cd "$( dirname "${BASH_SOURCE[0]}" )" &> /dev/null && pwd )" script_dir="$( cd "$( dirname "${BASH_SOURCE[0]}" )" &> /dev/null && pwd )"
# Make build directory # Make build directory
rm -rf build
make cleanall make cleanall
mkdir -p build mkdir -p build
cd build cd build
# Get PaddlePaddle inference model
echo -e "\e[36mDownload PaddlePaddle inference model\e[0m"
wget https://paddleocr.bj.bcebos.com/tvm/ocr_en.tar wget https://paddleocr.bj.bcebos.com/tvm/ocr_en.tar
tar -xf ocr_en.tar tar -xf ocr_en.tar
...@@ -144,9 +176,9 @@ cd ${script_dir} ...@@ -144,9 +176,9 @@ cd ${script_dir}
echo ${script_dir} echo ${script_dir}
make make
# Run demo executable on the FVP # Run demo executable on the AVH
FVP_Corstone_SSE-300_Ethos-U55 -C cpu0.CFGDTCMSZ=15 \ $Platform -C cpu0.CFGDTCMSZ=15 \
-C cpu0.CFGITCMSZ=15 -C mps3_board.uart0.out_file=\"-\" -C mps3_board.uart0.shutdown_tag=\"EXITTHESIM\" \ -C cpu0.CFGITCMSZ=15 -C mps3_board.uart0.out_file=\"-\" -C mps3_board.uart0.shutdown_tag=\"EXITTHESIM\" \
-C mps3_board.visualisation.disable-visualisation=1 -C mps3_board.telnetterminal0.start_telnet=0 \ -C mps3_board.visualisation.disable-visualisation=1 -C mps3_board.telnetterminal0.start_telnet=0 \
-C mps3_board.telnetterminal1.start_telnet=0 -C mps3_board.telnetterminal2.start_telnet=0 -C mps3_board.telnetterminal5.start_telnet=0 \ -C mps3_board.telnetterminal1.start_telnet=0 -C mps3_board.telnetterminal2.start_telnet=0 -C mps3_board.telnetterminal5.start_telnet=0 \
./build/demo ./build/demo --stat
\ No newline at end of file
...@@ -34,12 +34,13 @@ cv::Mat CrnnResizeImg(cv::Mat img, float wh_ratio, int rec_image_height) { ...@@ -34,12 +34,13 @@ cv::Mat CrnnResizeImg(cv::Mat img, float wh_ratio, int rec_image_height) {
resize_w = imgW; resize_w = imgW;
else else
resize_w = int(ceilf(imgH * ratio)); resize_w = int(ceilf(imgH * ratio));
cv::Mat resize_img;
cv::resize(img, resize_img, cv::Size(resize_w, imgH), 0.f, 0.f, cv::resize(img, resize_img, cv::Size(resize_w, imgH), 0.f, 0.f,
cv::INTER_LINEAR); cv::INTER_LINEAR);
cv::copyMakeBorder(resize_img, resize_img, 0, 0, 0, cv::copyMakeBorder(resize_img, resize_img, 0, 0, 0,
int(imgW - resize_img.cols), cv::BORDER_CONSTANT, int(imgW - resize_img.cols), cv::BORDER_CONSTANT,
{127, 127, 127}); {127, 127, 127});
return resize_img;
} }
std::vector<std::string> ReadDict(std::string path) { std::vector<std::string> ReadDict(std::string path) {
......
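The diff above makes CrnnResizeImg declare resize_img locally and return it. For reference, the same preprocessing sketched in Python with OpenCV: keep the target height, scale the width by the crop's aspect ratio (capped at imgH * wh_ratio), and right-pad with gray (127) to the fixed width. This is an illustration of the logic shown, not PaddleOCR deployment code.

```python
# Python/OpenCV sketch of the CrnnResizeImg logic shown above; illustration only.
import math
import cv2
import numpy as np

def crnn_resize_img(img: np.ndarray, wh_ratio: float, rec_image_height: int) -> np.ndarray:
    imgH = rec_image_height
    imgW = int(imgH * wh_ratio)
    ratio = img.shape[1] / float(img.shape[0])
    resize_w = imgW if math.ceil(imgH * ratio) > imgW else int(math.ceil(imgH * ratio))
    resized = cv2.resize(img, (resize_w, imgH), interpolation=cv2.INTER_LINEAR)
    return cv2.copyMakeBorder(resized, 0, 0, 0, imgW - resized.shape[1],
                              cv2.BORDER_CONSTANT, value=(127, 127, 127))
```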
...@@ -474,7 +474,7 @@ void system(char **argv){ ...@@ -474,7 +474,7 @@ void system(char **argv){
std::vector<double> rec_times; std::vector<double> rec_times;
RunRecModel(boxes, srcimg, rec_predictor, rec_text, rec_text_score, RunRecModel(boxes, srcimg, rec_predictor, rec_text, rec_text_score,
charactor_dict, cls_predictor, use_direction_classify, &rec_times); charactor_dict, cls_predictor, use_direction_classify, &rec_times, rec_image_height);
//// visualization //// visualization
auto img_vis = Visualization(srcimg, boxes); auto img_vis = Visualization(srcimg, boxes);
......
...@@ -5,9 +5,10 @@ PaddleOCR将**持续新增**支持OCR领域前沿算法与模型,已支持的 ...@@ -5,9 +5,10 @@ PaddleOCR将**持续新增**支持OCR领域前沿算法与模型,已支持的
- [文本检测算法](./algorithm_overview.md#11-%E6%96%87%E6%9C%AC%E6%A3%80%E6%B5%8B%E7%AE%97%E6%B3%95) - [文本检测算法](./algorithm_overview.md#11-%E6%96%87%E6%9C%AC%E6%A3%80%E6%B5%8B%E7%AE%97%E6%B3%95)
- [文本识别算法](./algorithm_overview.md#12-%E6%96%87%E6%9C%AC%E8%AF%86%E5%88%AB%E7%AE%97%E6%B3%95) - [文本识别算法](./algorithm_overview.md#12-%E6%96%87%E6%9C%AC%E8%AF%86%E5%88%AB%E7%AE%97%E6%B3%95)
- [端到端算法](./algorithm_overview.md#2-%E6%96%87%E6%9C%AC%E8%AF%86%E5%88%AB%E7%AE%97%E6%B3%95) - [端到端算法](./algorithm_overview.md#2-%E6%96%87%E6%9C%AC%E8%AF%86%E5%88%AB%E7%AE%97%E6%B3%95)
- [表格识别](./algorithm_overview.md#3-%E8%A1%A8%E6%A0%BC%E8%AF%86%E5%88%AB%E7%AE%97%E6%B3%95)
**欢迎广大开发者合作共建,贡献更多算法,合入有奖🎁!具体可查看[社区常规赛](https://github.com/PaddlePaddle/PaddleOCR/issues/4982)。** **欢迎广大开发者合作共建,贡献更多算法,合入有奖🎁!具体可查看[社区常规赛](https://github.com/PaddlePaddle/PaddleOCR/issues/4982)。**
新增算法可参考如下教程: 新增算法可参考如下教程:
- [使用PaddleOCR架构添加新算法](./add_new_algorithm.md) - [使用PaddleOCR架构添加新算法](./add_new_algorithm.md)
\ No newline at end of file
# DB # DB与DB++
- [1. 算法简介](#1) - [1. 算法简介](#1)
- [2. 环境配置](#2) - [2. 环境配置](#2)
...@@ -21,12 +21,24 @@ ...@@ -21,12 +21,24 @@
> Liao, Minghui and Wan, Zhaoyi and Yao, Cong and Chen, Kai and Bai, Xiang > Liao, Minghui and Wan, Zhaoyi and Yao, Cong and Chen, Kai and Bai, Xiang
> AAAI, 2020 > AAAI, 2020
> [Real-Time Scene Text Detection with Differentiable Binarization and Adaptive Scale Fusion](https://arxiv.org/abs/2202.10304)
> Liao, Minghui and Zou, Zhisheng and Wan, Zhaoyi and Yao, Cong and Bai, Xiang
> TPAMI, 2022
在ICDAR2015文本检测公开数据集上,算法复现效果如下: 在ICDAR2015文本检测公开数据集上,算法复现效果如下:
|模型|骨干网络|配置文件|precision|recall|Hmean|下载链接| |模型|骨干网络|配置文件|precision|recall|Hmean|下载链接|
| --- | --- | --- | --- | --- | --- | --- | | --- | --- | --- | --- | --- | --- | --- |
|DB|ResNet50_vd|[configs/det/det_r50_vd_db.yml](../../configs/det/det_r50_vd_db.yml)|86.41%|78.72%|82.38%|[训练模型](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/det_r50_vd_db_v2.0_train.tar)| |DB|ResNet50_vd|[configs/det/det_r50_vd_db.yml](../../configs/det/det_r50_vd_db.yml)|86.41%|78.72%|82.38%|[训练模型](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/det_r50_vd_db_v2.0_train.tar)|
|DB|MobileNetV3|[configs/det/det_mv3_db.yml](../../configs/det/det_mv3_db.yml)|77.29%|73.08%|75.12%|[训练模型](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/det_mv3_db_v2.0_train.tar)| |DB|MobileNetV3|[configs/det/det_mv3_db.yml](../../configs/det/det_mv3_db.yml)|77.29%|73.08%|75.12%|[训练模型](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/det_mv3_db_v2.0_train.tar)|
|DB++|ResNet50|[configs/det/det_r50_db++_icdar15.yml](../../configs/det/det_r50_db++_icdar15.yml)|90.89%|82.66%|86.58%|[合成数据预训练模型](https://paddleocr.bj.bcebos.com/dygraph_v2.1/en_det/ResNet50_dcn_asf_synthtext_pretrained.pdparams)/[训练模型](https://paddleocr.bj.bcebos.com/dygraph_v2.1/en_det/det_r50_db%2B%2B_icdar15_train.tar)|
在TD_TR文本检测公开数据集上,算法复现效果如下:
|模型|骨干网络|配置文件|precision|recall|Hmean|下载链接|
| --- | --- | --- | --- | --- | --- | --- |
|DB++|ResNet50|[configs/det/det_r50_db++_td_tr.yml](../../configs/det/det_r50_db++_td_tr.yml)|92.92%|86.48%|89.58%|[合成数据预训练模型](https://paddleocr.bj.bcebos.com/dygraph_v2.1/en_det/ResNet50_dcn_asf_synthtext_pretrained.pdparams)/[训练模型](https://paddleocr.bj.bcebos.com/dygraph_v2.1/en_det/det_r50_db%2B%2B_td_tr_train.tar)|
<a name="2"></a> <a name="2"></a>
...@@ -54,7 +66,7 @@ python3 tools/export_model.py -c configs/det/det_r50_vd_db.yml -o Global.pretrai ...@@ -54,7 +66,7 @@ python3 tools/export_model.py -c configs/det/det_r50_vd_db.yml -o Global.pretrai
DB文本检测模型推理,可以执行如下命令: DB文本检测模型推理,可以执行如下命令:
```shell ```shell
python3 tools/infer/predict_det.py --image_dir="./doc/imgs_en/img_10.jpg" --det_model_dir="./inference/det_db/" python3 tools/infer/predict_det.py --image_dir="./doc/imgs_en/img_10.jpg" --det_model_dir="./inference/det_db/" --det_algorithm="DB"
``` ```
可视化文本检测结果默认保存到`./inference_results`文件夹里面,结果文件的名称前缀为'det_res'。结果示例如下: 可视化文本检测结果默认保存到`./inference_results`文件夹里面,结果文件的名称前缀为'det_res'。结果示例如下:
...@@ -96,4 +108,12 @@ DB模型还支持以下推理部署方式: ...@@ -96,4 +108,12 @@ DB模型还支持以下推理部署方式:
pages={11474--11481}, pages={11474--11481},
year={2020} year={2020}
} }
```
\ No newline at end of file @article{liao2022real,
title={Real-Time Scene Text Detection with Differentiable Binarization and Adaptive Scale Fusion},
author={Liao, Minghui and Zou, Zhisheng and Wan, Zhaoyi and Yao, Cong and Bai, Xiang},
journal={IEEE Transactions on Pattern Analysis and Machine Intelligence},
year={2022},
publisher={IEEE}
}
```
# FCENet # FCENet
- [1. 算法简介](#1) - [1. 算法简介](#1-算法简介)
- [2. 环境配置](#2) - [2. 环境配置](#2-环境配置)
- [3. 模型训练、评估、预测](#3) - [3. 模型训练、评估、预测](#3-模型训练评估预测)
- [3.1 训练](#3-1) - [4. 推理部署](#4-推理部署)
- [3.2 评估](#3-2) - [4.1 Python推理](#41-python推理)
- [3.3 预测](#3-3) - [4.2 C++推理](#42-c推理)
- [4. 推理部署](#4) - [4.3 Serving服务化部署](#43-serving服务化部署)
- [4.1 Python推理](#4-1) - [4.4 更多推理部署](#44-更多推理部署)
- [4.2 C++推理](#4-2) - [5. FAQ](#5-faq)
- [4.3 Serving服务化部署](#4-3) - [引用](#引用)
- [4.4 更多推理部署](#4-4)
- [5. FAQ](#5)
<a name="1"></a> <a name="1"></a>
## 1. 算法简介 ## 1. 算法简介
......
# OCR算法 # OCR算法
- [1. 两阶段算法](#1-两阶段算法) - [1. 两阶段算法](#1)
- [1.1 文本检测算法](#11-文本检测算法) - [1.1 文本检测算法](#11)
- [1.2 文本识别算法](#12-文本识别算法) - [1.2 文本识别算法](#12)
- [2. 端到端算法](#2-端到端算法) - [2. 端到端算法](#2)
- [3. 表格识别算法](#3)
本文给出了PaddleOCR已支持的OCR算法列表,以及每个算法在**英文公开数据集**上的模型和指标,主要用于算法简介和算法性能对比,更多包括中文在内的其他数据集上的模型请参考[PP-OCR v2.0 系列模型下载](./models_list.md) 本文给出了PaddleOCR已支持的OCR算法列表,以及每个算法在**英文公开数据集**上的模型和指标,主要用于算法简介和算法性能对比,更多包括中文在内的其他数据集上的模型请参考[PP-OCR v2.0 系列模型下载](./models_list.md)
...@@ -68,6 +69,7 @@ ...@@ -68,6 +69,7 @@
- [x] [SVTR](./algorithm_rec_svtr.md) - [x] [SVTR](./algorithm_rec_svtr.md)
- [x] [ViTSTR](./algorithm_rec_vitstr.md) - [x] [ViTSTR](./algorithm_rec_vitstr.md)
- [x] [ABINet](./algorithm_rec_abinet.md) - [x] [ABINet](./algorithm_rec_abinet.md)
- [x] [SPIN](./algorithm_rec_spin.md)
参考[DTRB](https://arxiv.org/abs/1904.01906)[3]文字识别训练和评估流程,使用MJSynth和SynthText两个文字识别数据集训练,在IIIT, SVT, IC03, IC13, IC15, SVTP, CUTE数据集上进行评估,算法效果如下: 参考[DTRB](https://arxiv.org/abs/1904.01906)[3]文字识别训练和评估流程,使用MJSynth和SynthText两个文字识别数据集训练,在IIIT, SVT, IC03, IC13, IC15, SVTP, CUTE数据集上进行评估,算法效果如下:
...@@ -86,8 +88,10 @@ ...@@ -86,8 +88,10 @@
|SAR|Resnet31| 87.20% | rec_r31_sar | [训练模型](https://paddleocr.bj.bcebos.com/dygraph_v2.1/rec/rec_r31_sar_train.tar) | |SAR|Resnet31| 87.20% | rec_r31_sar | [训练模型](https://paddleocr.bj.bcebos.com/dygraph_v2.1/rec/rec_r31_sar_train.tar) |
|SEED|Aster_Resnet| 85.35% | rec_resnet_stn_bilstm_att | [训练模型](https://paddleocr.bj.bcebos.com/dygraph_v2.1/rec/rec_resnet_stn_bilstm_att.tar) | |SEED|Aster_Resnet| 85.35% | rec_resnet_stn_bilstm_att | [训练模型](https://paddleocr.bj.bcebos.com/dygraph_v2.1/rec/rec_resnet_stn_bilstm_att.tar) |
|SVTR|SVTR-Tiny| 89.25% | rec_svtr_tiny_none_ctc_en | [训练模型](https://paddleocr.bj.bcebos.com/PP-OCRv3/chinese/rec_svtr_tiny_none_ctc_en_train.tar) | |SVTR|SVTR-Tiny| 89.25% | rec_svtr_tiny_none_ctc_en | [训练模型](https://paddleocr.bj.bcebos.com/PP-OCRv3/chinese/rec_svtr_tiny_none_ctc_en_train.tar) |
|ViTSTR|ViTSTR| 79.82% | rec_vitstr_none_ce_en | [训练模型](https://paddleocr.bj.bcebos.com/rec_vitstr_none_ce_train.tar) | |ViTSTR|ViTSTR| 79.82% | rec_vitstr_none_ce | [训练模型](https://paddleocr.bj.bcebos.com/rec_vitstr_none_ce_train.tar) |
|ABINet|Resnet45| 90.75% | rec_r45_abinet_en | [训练模型](https://paddleocr.bj.bcebos.com/rec_r45_abinet_train.tar) | |ABINet|Resnet45| 90.75% | rec_r45_abinet | [训练模型](https://paddleocr.bj.bcebos.com/rec_r45_abinet_train.tar) |
|SPIN|ResNet32| 90.00% | rec_r32_gaspin_bilstm_att | coming soon |
<a name="2"></a> <a name="2"></a>
...@@ -95,3 +99,16 @@ ...@@ -95,3 +99,16 @@
已支持的端到端OCR算法列表(戳链接获取使用教程): 已支持的端到端OCR算法列表(戳链接获取使用教程):
- [x] [PGNet](./algorithm_e2e_pgnet.md) - [x] [PGNet](./algorithm_e2e_pgnet.md)
<a name="3"></a>
## 3. 表格识别算法
已支持的表格识别算法列表(戳链接获取使用教程):
- [x] [TableMaster](./algorithm_table_master.md)
在PubTabNet表格识别公开数据集上,算法效果如下:
|模型|骨干网络|配置文件|acc|下载链接|
|---|---|---|---|---|
|TableMaster|TableResNetExtra|[configs/table/table_master.yml](../../configs/table/table_master.yml)|77.47%|[训练模型](https://paddleocr.bj.bcebos.com/ppstructure/models/tablemaster/table_structure_tablemaster_train.tar) / [推理模型](https://paddleocr.bj.bcebos.com/ppstructure/models/tablemaster/table_structure_tablemaster_infer.tar)|
# SPIN: Structure-Preserving Inner Offset Network for Scene Text Recognition
- [1. 算法简介](#1)
- [2. 环境配置](#2)
- [3. 模型训练、评估、预测](#3)
- [3.1 训练](#3-1)
- [3.2 评估](#3-2)
- [3.3 预测](#3-3)
- [4. 推理部署](#4)
- [4.1 Python推理](#4-1)
- [4.2 C++推理](#4-2)
- [4.3 Serving服务化部署](#4-3)
- [4.4 更多推理部署](#4-4)
- [5. FAQ](#5)
<a name="1"></a>
## 1. 算法简介
论文信息:
> [SPIN: Structure-Preserving Inner Offset Network for Scene Text Recognition](https://arxiv.org/abs/2005.13117)
> Chengwei Zhang, Yunlu Xu, Zhanzhan Cheng, Shiliang Pu, Yi Niu, Fei Wu, Futai Zou
> AAAI, 2020
SPIN收录于AAAI 2020,主要用于OCR识别任务。在任意形状文本识别中,矫正网络是一种较为常见的前置处理模块,但诸如RARE/ASTER/ESIR等方法只考虑了空间变换,并没有考虑色度变换。本文提出了一种结构保持内部偏移网络(Structure-Preserving Inner Offset Network, SPIN),可以在色彩空间上进行变换。该模块是可微分的,可以加入到任意识别器中。
使用MJSynth和SynthText两个合成文字识别数据集训练,在IIIT, SVT, IC03, IC13, IC15, SVTP, CUTE数据集上进行评估,算法复现效果如下:
|模型|骨干网络|配置文件|Acc|下载链接|
| --- | --- | --- | --- | --- |
|SPIN|ResNet32|[rec_r32_gaspin_bilstm_att.yml](../../configs/rec/rec_r32_gaspin_bilstm_att.yml)|90.0%|coming soon|
<a name="2"></a>
## 2. 环境配置
请先参考[《运行环境准备》](./environment.md)配置PaddleOCR运行环境,参考[《项目克隆》](./clone.md)克隆项目代码。
<a name="3"></a>
## 3. 模型训练、评估、预测
请参考[文本识别教程](./recognition.md)。PaddleOCR对代码进行了模块化,训练不同的识别模型只需要**更换配置文件**即可。
训练
具体地,在完成数据准备后,便可以启动训练,训练命令如下:
```
#单卡训练(训练周期长,不建议)
python3 tools/train.py -c configs/rec/rec_r32_gaspin_bilstm_att.yml
#多卡训练,通过--gpus参数指定卡号
python3 -m paddle.distributed.launch --gpus '0,1,2,3' tools/train.py -c configs/rec/rec_r32_gaspin_bilstm_att.yml
```
评估
```
# GPU 评估, Global.pretrained_model 为待测权重
python3 -m paddle.distributed.launch --gpus '0' tools/eval.py -c configs/rec/rec_r32_gaspin_bilstm_att.yml -o Global.pretrained_model={path/to/weights}/best_accuracy
```
预测:
```
# 预测使用的配置文件必须与训练一致
python3 tools/infer_rec.py -c configs/rec/rec_r32_gaspin_bilstm_att.yml -o Global.pretrained_model={path/to/weights}/best_accuracy Global.infer_img=doc/imgs_words/en/word_1.png
```
<a name="4"></a>
## 4. 推理部署
<a name="4-1"></a>
### 4.1 Python推理
首先将SPIN文本识别训练过程中保存的模型,转换成inference model。可以使用如下命令进行转换:
```
python3 tools/export_model.py -c configs/rec/rec_r32_gaspin_bilstm_att.yml -o Global.pretrained_model={path/to/weights}/best_accuracy Global.save_inference_dir=./inference/rec_r32_gaspin_bilstm_att
```
SPIN文本识别模型推理,可以执行如下命令:
```
python3 tools/infer/predict_rec.py --image_dir="./doc/imgs_words/en/word_1.png" --rec_model_dir="./inference/rec_r32_gaspin_bilstm_att/" --rec_image_shape="3, 32, 100" --rec_algorithm="SPIN" --rec_char_dict_path="./ppocr/utils/dict/spin_dict.txt" --use_space_char=False
```
<a name="4-2"></a>
### 4.2 C++推理
由于C++预处理后处理还未支持SPIN,所以暂未支持
<a name="4-3"></a>
### 4.3 Serving服务化部署
暂不支持
<a name="4-4"></a>
### 4.4 更多推理部署
暂不支持
<a name="5"></a>
## 5. FAQ
## 引用
```bibtex
@article{2020SPIN,
title={SPIN: Structure-Preserving Inner Offset Network for Scene Text Recognition},
author={Chengwei Zhang and Yunlu Xu and Zhanzhan Cheng and Shiliang Pu and Yi Niu and Fei Wu and Futai Zou},
journal={AAAI2020},
year={2020},
}
```
# 表格识别算法-TableMASTER
- [1. 算法简介](#1-算法简介)
- [2. 环境配置](#2-环境配置)
- [3. 模型训练、评估、预测](#3-模型训练评估预测)
- [4. 推理部署](#4-推理部署)
- [4.1 Python推理](#41-python推理)
- [4.2 C++推理部署](#42-c推理部署)
- [4.3 Serving服务化部署](#43-serving服务化部署)
- [4.4 更多推理部署](#44-更多推理部署)
- [5. FAQ](#5-faq)
- [引用](#引用)
<a name="1"></a>
## 1. 算法简介
论文信息:
> [TableMaster: PINGAN-VCGROUP’S SOLUTION FOR ICDAR 2021 COMPETITION ON SCIENTIFIC LITERATURE PARSING TASK B: TABLE RECOGNITION TO HTML](https://arxiv.org/pdf/2105.01848.pdf)
> Ye, Jiaquan and Qi, Xianbiao and He, Yelin and Chen, Yihao and Gu, Dengyi and Gao, Peng and Xiao, Rong
> 2021
在PubTabNet表格识别公开数据集上,算法复现效果如下:
|模型|骨干网络|配置文件|acc|下载链接|
| --- | --- | --- | --- | --- |
|TableMaster|TableResNetExtra|[configs/table/table_master.yml](../../configs/table/table_master.yml)|77.47%|[训练模型](https://paddleocr.bj.bcebos.com/ppstructure/models/tablemaster/table_structure_tablemaster_train.tar)/[推理模型](https://paddleocr.bj.bcebos.com/ppstructure/models/tablemaster/table_structure_tablemaster_infer.tar)|
<a name="2"></a>
## 2. 环境配置
请先参考[《运行环境准备》](./environment.md)配置PaddleOCR运行环境,参考[《项目克隆》](./clone.md)克隆项目代码。
<a name="3"></a>
## 3. 模型训练、评估、预测
上述TableMaster模型使用PubTabNet表格识别公开数据集训练得到,数据集下载可参考 [table_datasets](./dataset/table_datasets.md)
数据下载完成后,请参考[文本识别教程](./recognition.md)进行训练。PaddleOCR对代码进行了模块化,训练不同的模型只需要**更换配置文件**即可。
<a name="4"></a>
## 4. 推理部署
<a name="4-1"></a>
### 4.1 Python推理
首先将训练得到best模型,转换成inference model。以基于TableResNetExtra骨干网络,在PubTabNet数据集训练的模型为例([模型下载地址](https://paddleocr.bj.bcebos.com/contribution/table_master.tar)),可以使用如下命令进行转换:
```shell
# 注意将pretrained_model的路径设置为本地路径。
python3 tools/export_model.py -c configs/table/table_master.yml -o Global.pretrained_model=output/table_master/best_accuracy Global.save_inference_dir=./inference/table_master
```
**注意:**
- 如果您是在自己的数据集上训练的模型,并且调整了字典文件,请注意检查配置文件中的`character_dict_path`是否为正确的字典文件。
转换成功后,在目录下有三个文件:
```
./inference/table_master/
├── inference.pdiparams # 识别inference模型的参数文件
├── inference.pdiparams.info # 识别inference模型的参数信息,可忽略
└── inference.pdmodel # 识别inference模型的program文件
```
执行如下命令进行模型推理:
```shell
cd ppstructure/
python3.7 table/predict_structure.py --table_model_dir=../output/table_master/table_structure_tablemaster_infer/ --table_algorithm=TableMaster --table_char_dict_path=../ppocr/utils/dict/table_master_structure_dict.txt --table_max_len=480 --image_dir=docs/table/table.jpg
# 预测文件夹下所有图像时,可修改image_dir为文件夹,如 --image_dir='docs/table'。
```
执行命令后,上面图像的预测结果(结构信息和表格中每个单元格的坐标)会打印到屏幕上,同时会保存单元格坐标的可视化结果。示例如下:
结果如下:
```shell
[2022/06/16 13:06:54] ppocr INFO: result: ['<html>', '<body>', '<table>', '<thead>', '<tr>', '<td></td>', '<td></td>', '<td></td>', '<td></td>', '<td></td>', '</tr>', '</thead>', '<tbody>', '<tr>', '<td></td>', '<td></td>', '<td></td>', '<td></td>', '<td></td>', '</tr>', '<tr>', '<td></td>', '<td></td>', '<td></td>', '<td></td>', '<td></td>', '</tr>', '<tr>', '<td></td>', '<td></td>', '<td></td>', '<td></td>', '<td></td>', '</tr>', '<tr>', '<td></td>', '<td></td>', '<td></td>', '<td></td>', '<td></td>', '</tr>', '<tr>', '<td></td>', '<td></td>', '<td></td>', '<td></td>', '<td></td>', '</tr>', '<tr>', '<td></td>', '<td></td>', '<td></td>', '<td></td>', '<td></td>', '</tr>', '<tr>', '<td></td>', '<td></td>', '<td></td>', '<td></td>', '<td></td>', '</tr>', '<tr>', '<td></td>', '<td></td>', '<td></td>', '<td></td>', '<td></td>', '</tr>', '<tr>', '<td></td>', '<td></td>', '<td></td>', '<td></td>', '<td></td>', '</tr>', '<tr>', '<td></td>', '<td></td>', '<td></td>', '<td></td>', '<td></td>', '</tr>', '<tr>', '<td></td>', '<td></td>', '<td></td>', '<td></td>', '<td></td>', '</tr>', '<tr>', '<td></td>', '<td></td>', '<td></td>', '<td></td>', '<td></td>', '</tr>', '<tr>', '<td></td>', '<td></td>', '<td></td>', '<td></td>', '<td></td>', '</tr>', '<tr>', '<td></td>', '<td></td>', '<td></td>', '<td></td>', '<td></td>', '</tr>', '<tr>', '<td></td>', '<td></td>', '<td></td>', '<td></td>', '<td></td>', '</tr>', '</tbody>', '</table>', '</body>', '</html>'], [[72.17591094970703, 10.759100914001465, 60.29658508300781, 16.6805362701416], [161.85562133789062, 10.884308815002441, 14.9495210647583, 16.727018356323242], [277.79876708984375, 29.54340362548828, 31.490320205688477, 18.143272399902344],
...
[336.11724853515625, 280.3601989746094, 39.456939697265625, 18.121286392211914]]
[2022/06/16 13:06:54] ppocr INFO: save vis result to ./output/table.jpg
[2022/06/16 13:06:54] ppocr INFO: Predict time of docs/table/table.jpg: 17.36806297302246
```
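上面打印的结果由两部分组成:表格结构的HTML token序列,以及每个单元格的坐标。下面给出一个消费该结果的示意代码(仅为示例,并非PaddleOCR源码;此处假设每个单元格坐标为 [x, y, w, h] 形式,使用前请以实际输出为准):

```python
# Illustrative only (not PaddleOCR source): join the structure tokens into an HTML
# string and draw the cell boxes on the input image. The per-cell box format is
# assumed to be [x, y, w, h]; verify against the actual output before relying on it.
import cv2

def visualize_table_result(image_path, structure_tokens, cell_boxes, out_path="table_vis.jpg"):
    html = "".join(structure_tokens)
    img = cv2.imread(image_path)
    for x, y, w, h in cell_boxes:
        cv2.rectangle(img, (int(x), int(y)), (int(x + w), int(y + h)), (0, 0, 255), 1)
    cv2.imwrite(out_path, img)
    return html
```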
**注意**
- TableMaster在推理时比较慢,建议使用GPU进行推理。
<a name="4-2"></a>
### 4.2 C++推理部署
由于C++预处理后处理还未支持TableMaster,所以暂未支持
<a name="4-3"></a>
### 4.3 Serving服务化部署
暂不支持
<a name="4-4"></a>
### 4.4 更多推理部署
暂不支持
<a name="5"></a>
## 5. FAQ
## 引用
```bibtex
@article{ye2021pingan,
title={PingAn-VCGroup's Solution for ICDAR 2021 Competition on Scientific Literature Parsing Task B: Table Recognition to HTML},
author={Ye, Jiaquan and Qi, Xianbiao and He, Yelin and Chen, Yihao and Gu, Dengyi and Gao, Peng and Xiao, Rong},
journal={arXiv preprint arXiv:2105.01848},
year={2021}
}
```
# 场景应用
PaddleOCR场景应用覆盖通用,制造、金融、交通行业的主要OCR垂类应用,在PP-OCR、PP-Structure的通用能力基础之上,以notebook的形式展示利用场景数据微调、模型优化方法、数据增广等内容,为开发者快速落地OCR应用提供示范与启发。
> 如需下载全部垂类模型,可以扫描下方二维码,关注公众号填写问卷后,加入PaddleOCR官方交流群获取20G OCR学习大礼包(内含《动手学OCR》电子书、课程回放视频、前沿论文等重磅资料)
<div align="center">
<img src="https://ai-studio-static-online.cdn.bcebos.com/dd721099bd50478f9d5fb13d8dd00fad69c22d6848244fd3a1d3980d7fefc63e" width = "150" height = "150" />
</div>
> 如果您是企业开发者且未在下述场景中找到合适的方案,可以填写[OCR应用合作调研问卷](https://paddle.wjx.cn/vj/QwF7GKw.aspx),免费与官方团队展开不同层次的合作,包括但不限于问题抽象、确定技术方案、项目答疑、共同研发等。如果您已经使用PaddleOCR落地项目,也可以填写此问卷,与飞桨平台共同宣传推广,提升企业技术品宣。期待您的提交!
## 通用
| 类别 | 亮点 | 类别 | 亮点 |
| ---------------------- | -------- | ---------- | ------------ |
| 高精度中文识别模型SVTR | 新增模型 | 手写体识别 | 新增字形支持 |
## 制造
| 类别 | 亮点 | 类别 | 亮点 |
| -------------- | ------------------------------ | -------------- | -------------------- |
| 数码管识别 | 数码管数据合成、漏识别调优 | 电表识别 | 大分辨率图像检测调优 |
| 液晶屏读数识别 | 检测模型蒸馏、Serving部署 | PCB文字识别 | 小尺寸文本检测与识别 |
| 包装生产日期 | 点阵字符合成、过曝过暗文字识别 | 液晶屏缺陷检测 | 非文字形态识别 |
## 金融
| 类别 | 亮点 | 类别 | 亮点 |
| -------------- | ------------------------ | ------------ | --------------------- |
| 表单VQA | 多模态通用表单结构化提取 | 通用卡证识别 | 通用结构化提取 |
| 增值税发票 | 敬请期待 | 身份证识别 | 结构化提取、图像阴影 |
| 印章检测与识别 | 端到端弯曲文本识别 | 合同比对 | 密集文本检测、NLP串联 |
## 交通
| 类别 | 亮点 | 类别 | 亮点 |
| ----------------- | ------------------------------ | ---------- | -------- |
| 车牌识别 | 多角度图像、轻量模型、端侧部署 | 快递单识别 | 敬请期待 |
| 驾驶证/行驶证识别 | 敬请期待 | | |
\ No newline at end of file
...@@ -34,6 +34,7 @@ json.dumps编码前的图像标注信息是包含多个字典的list,字典中 ...@@ -34,6 +34,7 @@ json.dumps编码前的图像标注信息是包含多个字典的list,字典中
| ICDAR 2015 |https://rrc.cvc.uab.es/?ch=4&com=downloads| [train](https://paddleocr.bj.bcebos.com/dataset/train_icdar2015_label.txt) / [test](https://paddleocr.bj.bcebos.com/dataset/test_icdar2015_label.txt) | | ICDAR 2015 |https://rrc.cvc.uab.es/?ch=4&com=downloads| [train](https://paddleocr.bj.bcebos.com/dataset/train_icdar2015_label.txt) / [test](https://paddleocr.bj.bcebos.com/dataset/test_icdar2015_label.txt) |
| ctw1500 |https://paddleocr.bj.bcebos.com/dataset/ctw1500.zip| 图片下载地址中已包含 | | ctw1500 |https://paddleocr.bj.bcebos.com/dataset/ctw1500.zip| 图片下载地址中已包含 |
| total text |https://paddleocr.bj.bcebos.com/dataset/total_text.tar| 图片下载地址中已包含 | | total text |https://paddleocr.bj.bcebos.com/dataset/total_text.tar| 图片下载地址中已包含 |
| td tr |https://paddleocr.bj.bcebos.com/dataset/TD_TR.tar| 图片下载地址中已包含 |
#### 1.2.1 ICDAR 2015 #### 1.2.1 ICDAR 2015
ICDAR 2015 数据集包含1000张训练图像和500张测试图像。ICDAR 2015 数据集可以从上表中链接下载,首次下载需注册。 ICDAR 2015 数据集包含1000张训练图像和500张测试图像。ICDAR 2015 数据集可以从上表中链接下载,首次下载需注册。
......
...@@ -65,7 +65,7 @@ python3 -m paddle.distributed.launch --gpus '0,1,2,3' tools/train.py -c configs/ ...@@ -65,7 +65,7 @@ python3 -m paddle.distributed.launch --gpus '0,1,2,3' tools/train.py -c configs/
``` ```
上述指令中,通过-c 选择训练使用configs/det/det_db_mv3.yml配置文件。 上述指令中,通过-c 选择训练使用configs/det/det_mv3_db.yml配置文件。
有关配置文件的详细解释,请参考[链接](./config.md) 有关配置文件的详细解释,请参考[链接](./config.md)
您也可以通过-o参数在不需要修改yml文件的情况下,改变训练的参数,比如,调整训练的学习率为0.0001 您也可以通过-o参数在不需要修改yml文件的情况下,改变训练的参数,比如,调整训练的学习率为0.0001
......
...@@ -7,7 +7,8 @@ ...@@ -7,7 +7,8 @@
- [1. 文本检测模型推理](#1-文本检测模型推理) - [1. 文本检测模型推理](#1-文本检测模型推理)
- [2. 文本识别模型推理](#2-文本识别模型推理) - [2. 文本识别模型推理](#2-文本识别模型推理)
- [2.1 超轻量中文识别模型推理](#21-超轻量中文识别模型推理) - [2.1 超轻量中文识别模型推理](#21-超轻量中文识别模型推理)
- [2.2 多语言模型的推理](#22-多语言模型的推理) - [2.2 英文识别模型推理](#22-英文识别模型推理)
- [2.3 多语言模型的推理](#23-多语言模型的推理)
- [3. 方向分类模型推理](#3-方向分类模型推理) - [3. 方向分类模型推理](#3-方向分类模型推理)
- [4. 文本检测、方向分类和文字识别串联推理](#4-文本检测方向分类和文字识别串联推理) - [4. 文本检测、方向分类和文字识别串联推理](#4-文本检测方向分类和文字识别串联推理)
...@@ -78,9 +79,29 @@ python3 tools/infer/predict_rec.py --image_dir="./doc/imgs_words/ch/word_4.jpg" ...@@ -78,9 +79,29 @@ python3 tools/infer/predict_rec.py --image_dir="./doc/imgs_words/ch/word_4.jpg"
Predicts of ./doc/imgs_words/ch/word_4.jpg:('实力活力', 0.9956803321838379) Predicts of ./doc/imgs_words/ch/word_4.jpg:('实力活力', 0.9956803321838379)
``` ```
<a name="英文识别模型推理"></a>
### 2.2 英文识别模型推理
英文识别模型推理,可以执行如下命令, 注意修改字典路径:
```
# 下载英文数字识别模型:
wget https://paddleocr.bj.bcebos.com/PP-OCRv3/english/en_PP-OCRv3_rec_infer.tar
tar xf en_PP-OCRv3_rec_infer.tar
python3 tools/infer/predict_rec.py --image_dir="./doc/imgs_words/en/word_1.png" --rec_model_dir="./en_PP-OCRv3_rec_infer/" --rec_char_dict_path="ppocr/utils/en_dict.txt"
```
![](../imgs_words/en/word_1.png)
执行命令后,上图的预测结果为:
```
Predicts of ./doc/imgs_words/en/word_1.png: ('JOINT', 0.998160719871521)
```
<a name="多语言模型的推理"></a> <a name="多语言模型的推理"></a>
### 2.2 多语言模型的推理 ### 2.3 多语言模型的推理
如果您需要预测的是其他语言模型,可以在[此链接](./models_list.md#%E5%A4%9A%E8%AF%AD%E8%A8%80%E8%AF%86%E5%88%AB%E6%A8%A1%E5%9E%8B)中找到对应语言的inference模型,在使用inference模型预测时,需要通过`--rec_char_dict_path`指定使用的字典路径, 同时为了得到正确的可视化结果,需要通过 `--vis_font_path` 指定可视化的字体路径,`doc/fonts/` 路径下有默认提供的小语种字体,例如韩文识别: 如果您需要预测的是其他语言模型,可以在[此链接](./models_list.md#%E5%A4%9A%E8%AF%AD%E8%A8%80%E8%AF%86%E5%88%AB%E6%A8%A1%E5%9E%8B)中找到对应语言的inference模型,在使用inference模型预测时,需要通过`--rec_char_dict_path`指定使用的字典路径, 同时为了得到正确的可视化结果,需要通过 `--vis_font_path` 指定可视化的字体路径,`doc/fonts/` 路径下有默认提供的小语种字体,例如韩文识别:
``` ```
......
...@@ -6,5 +6,6 @@ PaddleOCR will add cutting-edge OCR algorithms and models continuously. Check ou ...@@ -6,5 +6,6 @@ PaddleOCR will add cutting-edge OCR algorithms and models continuously. Check ou
- [text detection algorithms](./algorithm_overview_en.md#11) - [text detection algorithms](./algorithm_overview_en.md#11)
- [text recognition algorithms](./algorithm_overview_en.md#12) - [text recognition algorithms](./algorithm_overview_en.md#12)
- [end-to-end algorithms](./algorithm_overview_en.md#2) - [end-to-end algorithms](./algorithm_overview_en.md#2)
- [table recognition algorithms](./algorithm_overview_en.md#3)
Developers are welcome to contribute more algorithms! Please refer to [add new algorithm](./add_new_algorithm_en.md) guideline. Developers are welcome to contribute more algorithms! Please refer to [add new algorithm](./add_new_algorithm_en.md) guideline.
\ No newline at end of file
# OCR Algorithms # OCR Algorithms
- [1. Two-stage Algorithms](#1) - [1. Two-stage Algorithms](#1)
* [1.1 Text Detection Algorithms](#11) - [1.1 Text Detection Algorithms](#11)
* [1.2 Text Recognition Algorithms](#12) - [1.2 Text Recognition Algorithms](#12)
- [2. End-to-end Algorithms](#2) - [2. End-to-end Algorithms](#2)
- [3. Table Recognition Algorithms](#3)
This tutorial lists the OCR algorithms supported by PaddleOCR, as well as the models and metrics of each algorithm on **English public datasets**. It is mainly used for algorithm introduction and algorithm performance comparison. For more models on other datasets including Chinese, please refer to [PP-OCR v2.0 models list](./models_list_en.md). This tutorial lists the OCR algorithms supported by PaddleOCR, as well as the models and metrics of each algorithm on **English public datasets**. It is mainly used for algorithm introduction and algorithm performance comparison. For more models on other datasets including Chinese, please refer to [PP-OCR v2.0 models list](./models_list_en.md).
...@@ -67,6 +68,7 @@ Supported text recognition algorithms (Click the link to get the tutorial): ...@@ -67,6 +68,7 @@ Supported text recognition algorithms (Click the link to get the tutorial):
- [x] [SVTR](./algorithm_rec_svtr_en.md) - [x] [SVTR](./algorithm_rec_svtr_en.md)
- [x] [ViTSTR](./algorithm_rec_vitstr_en.md) - [x] [ViTSTR](./algorithm_rec_vitstr_en.md)
- [x] [ABINet](./algorithm_rec_abinet_en.md) - [x] [ABINet](./algorithm_rec_abinet_en.md)
- [x] [SPIN](./algorithm_rec_spin_en.md)
Refer to [DTRB](https://arxiv.org/abs/1904.01906), the training and evaluation result of these above text recognition (using MJSynth and SynthText for training, evaluate on IIIT, SVT, IC03, IC13, IC15, SVTP, CUTE) is as follow: Refer to [DTRB](https://arxiv.org/abs/1904.01906), the training and evaluation result of these above text recognition (using MJSynth and SynthText for training, evaluate on IIIT, SVT, IC03, IC13, IC15, SVTP, CUTE) is as follow:
...@@ -85,8 +87,10 @@ Refer to [DTRB](https://arxiv.org/abs/1904.01906), the training and evaluation r ...@@ -85,8 +87,10 @@ Refer to [DTRB](https://arxiv.org/abs/1904.01906), the training and evaluation r
|SAR|Resnet31| 87.20% | rec_r31_sar | [trained model](https://paddleocr.bj.bcebos.com/dygraph_v2.1/rec/rec_r31_sar_train.tar) | |SAR|Resnet31| 87.20% | rec_r31_sar | [trained model](https://paddleocr.bj.bcebos.com/dygraph_v2.1/rec/rec_r31_sar_train.tar) |
|SEED|Aster_Resnet| 85.35% | rec_resnet_stn_bilstm_att | [trained model](https://paddleocr.bj.bcebos.com/dygraph_v2.1/rec/rec_resnet_stn_bilstm_att.tar) | |SEED|Aster_Resnet| 85.35% | rec_resnet_stn_bilstm_att | [trained model](https://paddleocr.bj.bcebos.com/dygraph_v2.1/rec/rec_resnet_stn_bilstm_att.tar) |
|SVTR|SVTR-Tiny| 89.25% | rec_svtr_tiny_none_ctc_en | [trained model](https://paddleocr.bj.bcebos.com/PP-OCRv3/chinese/rec_svtr_tiny_none_ctc_en_train.tar) | |SVTR|SVTR-Tiny| 89.25% | rec_svtr_tiny_none_ctc_en | [trained model](https://paddleocr.bj.bcebos.com/PP-OCRv3/chinese/rec_svtr_tiny_none_ctc_en_train.tar) |
|ViTSTR|ViTSTR| 79.82% | rec_vitstr_none_ce_en | [trained model](https://paddleocr.bj.bcebos.com/rec_vitstr_none_none_train.tar) | |ViTSTR|ViTSTR| 79.82% | rec_vitstr_none_ce | [trained model](https://paddleocr.bj.bcebos.com/rec_vitstr_none_none_train.tar) |
|ABINet|Resnet45| 90.75% | rec_r45_abinet_en | [trained model](https://paddleocr.bj.bcebos.com/rec_r45_abinet_train.tar) | |ABINet|Resnet45| 90.75% | rec_r45_abinet | [trained model](https://paddleocr.bj.bcebos.com/rec_r45_abinet_train.tar) |
|SPIN|ResNet32| 90.00% | rec_r32_gaspin_bilstm_att | coming soon |
<a name="2"></a> <a name="2"></a>
...@@ -94,3 +98,15 @@ Refer to [DTRB](https://arxiv.org/abs/1904.01906), the training and evaluation r ...@@ -94,3 +98,15 @@ Refer to [DTRB](https://arxiv.org/abs/1904.01906), the training and evaluation r
Supported end-to-end algorithms (Click the link to get the tutorial): Supported end-to-end algorithms (Click the link to get the tutorial):
- [x] [PGNet](./algorithm_e2e_pgnet_en.md) - [x] [PGNet](./algorithm_e2e_pgnet_en.md)
<a name="3"></a>
## 3. Table Recognition Algorithms
Supported table recognition algorithms (Click the link to get the tutorial):
- [x] [TableMaster](./algorithm_table_master_en.md)
On the PubTabNet dataset, the algorithm result is as follows:
|Model|Backbone|Config|Acc|Download link|
|---|---|---|---|---|
|TableMaster|TableResNetExtra|[configs/table/table_master.yml](../../configs/table/table_master.yml)|77.47%|[trained](https://paddleocr.bj.bcebos.com/ppstructure/models/tablemaster/table_structure_tablemaster_train.tar) / [inference model](https://paddleocr.bj.bcebos.com/ppstructure/models/tablemaster/table_structure_tablemaster_infer.tar)|
# SPIN: Structure-Preserving Inner Offset Network for Scene Text Recognition
- [1. Introduction](#1)
- [2. Environment](#2)
- [3. Model Training / Evaluation / Prediction](#3)
- [3.1 Training](#3-1)
- [3.2 Evaluation](#3-2)
- [3.3 Prediction](#3-3)
- [4. Inference and Deployment](#4)
- [4.1 Python Inference](#4-1)
- [4.2 C++ Inference](#4-2)
- [4.3 Serving](#4-3)
- [4.4 More](#4-4)
- [5. FAQ](#5)
<a name="1"></a>
## 1. Introduction
Paper:
> [SPIN: Structure-Preserving Inner Offset Network for Scene Text Recognition](https://arxiv.org/abs/2005.13117)
> Chengwei Zhang, Yunlu Xu, Zhanzhan Cheng, Shiliang Pu, Yi Niu, Fei Wu, Futai Zou
> AAAI, 2020
Using the MJSynth and SynthText text recognition datasets for training, and evaluating on the IIIT, SVT, IC03, IC13, IC15, SVTP and CUTE datasets, the algorithm reproduction effect is as follows:
|Model|Backbone|config|Acc|Download link|
| --- | --- | --- | --- | --- |
|SPIN|ResNet32|[rec_r32_gaspin_bilstm_att.yml](../../configs/rec/rec_r32_gaspin_bilstm_att.yml)|90.0%|coming soon|
<a name="2"></a>
## 2. Environment
Please refer to ["Environment Preparation"](./environment_en.md) to configure the PaddleOCR environment, and refer to ["Project Clone"](./clone_en.md) to clone the project code.
<a name="3"></a>
## 3. Model Training / Evaluation / Prediction
Please refer to [Text Recognition Tutorial](./recognition_en.md). PaddleOCR modularizes the code, and training different recognition models only requires **changing the configuration file**.
Training:
Specifically, after the data preparation is completed, the training can be started. The training command is as follows:
```
#Single GPU training (long training period, not recommended)
python3 tools/train.py -c configs/rec/rec_r32_gaspin_bilstm_att.yml
#Multi GPU training, specify the gpu number through the --gpus parameter
python3 -m paddle.distributed.launch --gpus '0,1,2,3' tools/train.py -c configs/rec/rec_r32_gaspin_bilstm_att.yml
```
Evaluation:
```
# GPU evaluation
python3 -m paddle.distributed.launch --gpus '0' tools/eval.py -c configs/rec/rec_r32_gaspin_bilstm_att.yml -o Global.pretrained_model={path/to/weights}/best_accuracy
```
Prediction:
```
# The configuration file used for prediction must match the training
python3 tools/infer_rec.py -c configs/rec/rec_r32_gaspin_bilstm_att.yml -o Global.pretrained_model={path/to/weights}/best_accuracy Global.infer_img=doc/imgs_words/en/word_1.png
```
<a name="4"></a>
## 4. Inference and Deployment
<a name="4-1"></a>
### 4.1 Python Inference
First, the model saved during the SPIN text recognition training process is converted into an inference model. You can use the following command to convert:
```
python3 tools/export_model.py -c configs/rec/rec_r32_gaspin_bilstm_att.yml -o Global.pretrained_model={path/to/weights}/best_accuracy Global.save_inference_dir=./inference/rec_r32_gaspin_bilstm_att
```
For SPIN text recognition model inference, the following commands can be executed:
```
python3 tools/infer/predict_rec.py --image_dir="./doc/imgs_words/en/word_1.png" --rec_model_dir="./inference/rec_r32_gaspin_bilstm_att/" --rec_image_shape="3, 32, 100" --rec_algorithm="SPIN" --rec_char_dict_path="./ppocr/utils/dict/spin_dict.txt" --use_space_char=False
```
<a name="4-2"></a>
### 4.2 C++ Inference
Not supported
<a name="4-3"></a>
### 4.3 Serving
Not supported
<a name="4-4"></a>
### 4.4 More
Not supported
<a name="5"></a>
## 5. FAQ
## Citation
```bibtex
@article{2020SPIN,
title={SPIN: Structure-Preserving Inner Offset Network for Scene Text Recognition},
author={Chengwei Zhang and Yunlu Xu and Zhanzhan Cheng and Shiliang Pu and Yi Niu and Fei Wu and Futai Zou},
journal={AAAI2020},
year={2020},
}
```
# Table Recognition Algorithm-TableMASTER
- [1. Introduction](#1-introduction)
- [2. Environment](#2-environment)
- [3. Model Training / Evaluation / Prediction](#3-model-training--evaluation--prediction)
- [4. Inference and Deployment](#4-inference-and-deployment)
- [4.1 Python Inference](#41-python-inference)
- [4.2 C++ Inference](#42-c-inference)
- [4.3 Serving](#43-serving)
- [4.4 More](#44-more)
- [5. FAQ](#5-faq)
- [Citation](#citation)
<a name="1"></a>
## 1. Introduction
Paper:
> [TableMaster: PINGAN-VCGROUP’S SOLUTION FOR ICDAR 2021 COMPETITION ON SCIENTIFIC LITERATURE PARSING TASK B: TABLE RECOGNITION TO HTML](https://arxiv.org/pdf/2105.01848.pdf)
> Ye, Jiaquan and Qi, Xianbiao and He, Yelin and Chen, Yihao and Gu, Dengyi and Gao, Peng and Xiao, Rong
> 2021
On the PubTabNet table recognition public dataset, the algorithm reproduction accuracy is as follows:
|Model|Backbone|Config|Acc|Download link|
| --- | --- | --- | --- | --- |
|TableMaster|TableResNetExtra|[configs/table/table_master.yml](../../configs/table/table_master.yml)|77.47%|[trained model](https://paddleocr.bj.bcebos.com/ppstructure/models/tablemaster/table_structure_tablemaster_train.tar)/[inference model](https://paddleocr.bj.bcebos.com/ppstructure/models/tablemaster/table_structure_tablemaster_infer.tar)|
<a name="2"></a>
## 2. Environment
Please refer to ["Environment Preparation"](./environment_en.md) to configure the PaddleOCR environment, and refer to ["Project Clone"](./clone_en.md) to clone the project code.
<a name="3"></a>
## 3. Model Training / Evaluation / Prediction
The above TableMaster model is trained using the PubTabNet table recognition public dataset. For the download of the dataset, please refer to [table_datasets](./dataset/table_datasets_en.md).
After the data download is complete, please refer to [Text Recognition Training Tutorial](./recognition_en.md) for training. PaddleOCR has modularized the code structure, so that you only need to **replace the configuration file** to train different models.
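For reference, training and evaluation use the same entry points as the other algorithms in this documentation; a minimal sketch with the TableMaster configuration (GPU ids and weight paths are placeholders to adjust for your environment) is:
```shell
# Single GPU training (long training period, not recommended)
python3 tools/train.py -c configs/table/table_master.yml
# Multi GPU training, specify the gpu number through the --gpus parameter
python3 -m paddle.distributed.launch --gpus '0,1,2,3' tools/train.py -c configs/table/table_master.yml
# GPU evaluation
python3 -m paddle.distributed.launch --gpus '0' tools/eval.py -c configs/table/table_master.yml -o Global.pretrained_model={path/to/weights}/best_accuracy
```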
<a name="4"></a>
## 4. Inference and Deployment
<a name="4-1"></a>
### 4.1 Python Inference
First, convert the model saved during the TableMaster table recognition training process into an inference model. Taking the model based on the TableResNetExtra backbone network and trained on the PubTabNet dataset as an example ([model download link](https://paddleocr.bj.bcebos.com/contribution/table_master.tar)), you can use the following command to convert it:
```shell
python3 tools/export_model.py -c configs/table/table_master.yml -o Global.pretrained_model=output/table_master/best_accuracy Global.save_inference_dir=./inference/table_master
```
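If the conversion succeeds, the `./inference/table_master/` directory contains the exported inference files. The listing below shows the usual layout produced by `tools/export_model.py`; exact file names may differ across Paddle versions:
```
inference/table_master/
    ├── inference.pdiparams         # parameters of the inference model
    ├── inference.pdiparams.info    # parameter information, can be ignored
    └── inference.pdmodel           # program file of the inference model
```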
**Note**:
- If you trained the model on your own dataset and adjusted the dictionary file, please check that the `character_dict_path` in the modified configuration file points to the correct dictionary file; see the example below.
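For example, the dictionary path can also be overridden explicitly with the standard `-o` option when exporting, so the inference model is built with the intended dictionary (the path below is the default TableMaster structure dictionary; replace it with your own file if needed):
```shell
python3 tools/export_model.py -c configs/table/table_master.yml -o Global.pretrained_model=output/table_master/best_accuracy Global.character_dict_path=ppocr/utils/dict/table_master_structure_dict.txt Global.save_inference_dir=./inference/table_master
```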
Execute the following command for model inference:
```shell
cd ppstructure/
# To predict all images in a folder, set --image_dir to the folder path, such as --image_dir='docs/table'.
python3.7 table/predict_structure.py --table_model_dir=../output/table_master/table_structure_tablemaster_infer/ --table_algorithm=TableMaster --table_char_dict_path=../ppocr/utils/dict/table_master_structure_dict.txt --table_max_len=480 --image_dir=docs/table/table.jpg
```
After executing the command, the prediction results of the above image (the structure information and the coordinates of each cell in the table) are printed to the screen, and a visualization of the cell coordinates is also saved. An example is as follows:
result:
```shell
[2022/06/16 13:06:54] ppocr INFO: result: ['<html>', '<body>', '<table>', '<thead>', '<tr>', '<td></td>', '<td></td>', '<td></td>', '<td></td>', '<td></td>', '</tr>', '</thead>', '<tbody>', '<tr>', '<td></td>', '<td></td>', '<td></td>', '<td></td>', '<td></td>', '</tr>', '<tr>', '<td></td>', '<td></td>', '<td></td>', '<td></td>', '<td></td>', '</tr>', '<tr>', '<td></td>', '<td></td>', '<td></td>', '<td></td>', '<td></td>', '</tr>', '<tr>', '<td></td>', '<td></td>', '<td></td>', '<td></td>', '<td></td>', '</tr>', '<tr>', '<td></td>', '<td></td>', '<td></td>', '<td></td>', '<td></td>', '</tr>', '<tr>', '<td></td>', '<td></td>', '<td></td>', '<td></td>', '<td></td>', '</tr>', '<tr>', '<td></td>', '<td></td>', '<td></td>', '<td></td>', '<td></td>', '</tr>', '<tr>', '<td></td>', '<td></td>', '<td></td>', '<td></td>', '<td></td>', '</tr>', '<tr>', '<td></td>', '<td></td>', '<td></td>', '<td></td>', '<td></td>', '</tr>', '<tr>', '<td></td>', '<td></td>', '<td></td>', '<td></td>', '<td></td>', '</tr>', '<tr>', '<td></td>', '<td></td>', '<td></td>', '<td></td>', '<td></td>', '</tr>', '<tr>', '<td></td>', '<td></td>', '<td></td>', '<td></td>', '<td></td>', '</tr>', '<tr>', '<td></td>', '<td></td>', '<td></td>', '<td></td>', '<td></td>', '</tr>', '<tr>', '<td></td>', '<td></td>', '<td></td>', '<td></td>', '<td></td>', '</tr>', '<tr>', '<td></td>', '<td></td>', '<td></td>', '<td></td>', '<td></td>', '</tr>', '</tbody>', '</table>', '</body>', '</html>'], [[72.17591094970703, 10.759100914001465, 60.29658508300781, 16.6805362701416], [161.85562133789062, 10.884308815002441, 14.9495210647583, 16.727018356323242], [277.79876708984375, 29.54340362548828, 31.490320205688477, 18.143272399902344],
...
[336.11724853515625, 280.3601989746094, 39.456939697265625, 18.121286392211914]]
[2022/06/16 13:06:54] ppocr INFO: save vis result to ./output/table.jpg
[2022/06/16 13:06:54] ppocr INFO: Predict time of docs/table/table.jpg: 17.36806297302246
```
**Note**:
- TableMaster is relatively slow during inference, so it is recommended to run it on a GPU; see the example below.
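For example, assuming the standard `--use_gpu` flag shared by the PaddleOCR predict scripts, GPU inference can be requested explicitly:
```shell
python3.7 table/predict_structure.py --table_model_dir=../output/table_master/table_structure_tablemaster_infer/ --table_algorithm=TableMaster --table_char_dict_path=../ppocr/utils/dict/table_master_structure_dict.txt --table_max_len=480 --image_dir=docs/table/table.jpg --use_gpu=True
```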
<a name="4-2"></a>
### 4.2 C++ Inference
Since the post-processing is not implemented in C++, TableMaster does not support C++ inference.
<a name="4-3"></a>
### 4.3 Serving
Not supported
<a name="4-4"></a>
### 4.4 More
Not supported
<a name="5"></a>
## 5. FAQ
## Citation
```bibtex
@article{ye2021pingan,
title={PingAn-VCGroup's Solution for ICDAR 2021 Competition on Scientific Literature Parsing Task B: Table Recognition to HTML},
author={Ye, Jiaquan and Qi, Xianbiao and He, Yelin and Chen, Yihao and Gu, Dengyi and Gao, Peng and Xiao, Rong},
journal={arXiv preprint arXiv:2105.01848},
year={2021}
}
```
...@@ -51,7 +51,7 @@ python3 tools/train.py -c configs/det/det_mv3_db.yml \ ...@@ -51,7 +51,7 @@ python3 tools/train.py -c configs/det/det_mv3_db.yml \
-o Global.pretrained_model=./pretrain_models/MobileNetV3_large_x0_5_pretrained -o Global.pretrained_model=./pretrain_models/MobileNetV3_large_x0_5_pretrained
``` ```
In the above instruction, use `-c` to select the training to use the `configs/det/det_db_mv3.yml` configuration file. In the above instruction, use `-c` to select the training to use the `configs/det/det_mv3_db.yml` configuration file.
For a detailed explanation of the configuration file, please refer to [config](./config_en.md). For a detailed explanation of the configuration file, please refer to [config](./config_en.md).
You can also use `-o` to change the training parameters without modifying the yml file. For example, adjust the training learning rate to 0.0001 You can also use `-o` to change the training parameters without modifying the yml file. For example, adjust the training learning rate to 0.0001
......
...@@ -8,7 +8,8 @@ This article introduces the use of the Python inference engine for the PP-OCR mo ...@@ -8,7 +8,8 @@ This article introduces the use of the Python inference engine for the PP-OCR mo
- [Text Detection Model Inference](#text-detection-model-inference) - [Text Detection Model Inference](#text-detection-model-inference)
- [Text Recognition Model Inference](#text-recognition-model-inference) - [Text Recognition Model Inference](#text-recognition-model-inference)
- [1. Lightweight Chinese Recognition Model Inference](#1-lightweight-chinese-recognition-model-inference) - [1. Lightweight Chinese Recognition Model Inference](#1-lightweight-chinese-recognition-model-inference)
- [2. Multilingual Model Inference](#2-multilingual-model-inference) - [2. English Recognition Model Inference](#2-english-recognition-model-inference)
- [3. Multilingual Model Inference](#3-multilingual-model-inference)
- [Angle Classification Model Inference](#angle-classification-model-inference) - [Angle Classification Model Inference](#angle-classification-model-inference)
- [Text Detection Angle Classification and Recognition Inference Concatenation](#text-detection-angle-classification-and-recognition-inference-concatenation) - [Text Detection Angle Classification and Recognition Inference Concatenation](#text-detection-angle-classification-and-recognition-inference-concatenation)
...@@ -76,10 +77,31 @@ After executing the command, the prediction results (recognized text and score) ...@@ -76,10 +77,31 @@ After executing the command, the prediction results (recognized text and score)
```bash ```bash
Predicts of ./doc/imgs_words_en/word_10.png:('PAIN', 0.988671) Predicts of ./doc/imgs_words_en/word_10.png:('PAIN', 0.988671)
``` ```
<a name="2-english-recognition-model-inference"></a>
### 2. English Recognition Model Inference
<a name="MULTILINGUAL_MODEL_INFERENCE"></a> For English recognition model inference, you can execute the following commands,you need to specify the dictionary path used by `--rec_char_dict_path`:
### 2. Multilingual Model Inference ```
# download the English recognition model:
wget https://paddleocr.bj.bcebos.com/PP-OCRv3/english/en_PP-OCRv3_rec_infer.tar
tar xf en_PP-OCRv3_rec_infer.tar
python3 tools/infer/predict_rec.py --image_dir="./doc/imgs_words/en/word_1.png" --rec_model_dir="./en_PP-OCRv3_rec_infer/" --rec_char_dict_path="ppocr/utils/en_dict.txt"
```
![](../imgs_words/en/word_1.png)
After executing the command, the prediction result of the above figure is:
```
Predicts of ./doc/imgs_words/en/word_1.png: ('JOINT', 0.998160719871521)
```
<a name="3-multilingual-model-inference"></a>
### 3. Multilingual Model Inference
If you need to predict [other language models](./models_list_en.md#Multilingual), when using inference model prediction, you need to specify the dictionary path used by `--rec_char_dict_path`. At the same time, in order to get the correct visualization results, If you need to predict [other language models](./models_list_en.md#Multilingual), when using inference model prediction, you need to specify the dictionary path used by `--rec_char_dict_path`. At the same time, in order to get the correct visualization results,
You need to specify the visual font path through `--vis_font_path`. There are small language fonts provided by default under the `doc/fonts` path, such as Korean recognition: You need to specify the visual font path through `--vis_font_path`. There are small language fonts provided by default under the `doc/fonts` path, such as Korean recognition:
......
...@@ -23,9 +23,10 @@ from .random_crop_data import EastRandomCropData, RandomCropImgMask ...@@ -23,9 +23,10 @@ from .random_crop_data import EastRandomCropData, RandomCropImgMask
from .make_pse_gt import MakePseGt from .make_pse_gt import MakePseGt
from .rec_img_aug import RecAug, RecConAug, RecResizeImg, ClsResizeImg, \
SRNRecResizeImg, GrayRecResizeImg, SARRecResizeImg, PRENResizeImg, \ from .rec_img_aug import BaseDataAugmentation, RecAug, RecConAug, RecResizeImg, ClsResizeImg, \
ABINetRecResizeImg, SVTRRecResizeImg, ABINetRecAug SRNRecResizeImg, GrayRecResizeImg, SARRecResizeImg, PRENResizeImg, \
ABINetRecResizeImg, SVTRRecResizeImg, ABINetRecAug, SPINRecResizeImg
from .ssl_img_aug import SSLRotateResize from .ssl_img_aug import SSLRotateResize
from .randaugment import RandAugment from .randaugment import RandAugment
from .copy_paste import CopyPaste from .copy_paste import CopyPaste
...@@ -36,7 +37,7 @@ from .label_ops import * ...@@ -36,7 +37,7 @@ from .label_ops import *
from .east_process import * from .east_process import *
from .sast_process import * from .sast_process import *
from .pg_process import * from .pg_process import *
from .gen_table_mask import * from .table_ops import *
from .vqa import * from .vqa import *
......
...@@ -259,15 +259,26 @@ class E2ELabelEncodeTrain(object): ...@@ -259,15 +259,26 @@ class E2ELabelEncodeTrain(object):
class KieLabelEncode(object): class KieLabelEncode(object):
def __init__(self, character_dict_path, norm=10, directed=False, **kwargs): def __init__(self,
character_dict_path,
class_path,
norm=10,
directed=False,
**kwargs):
super(KieLabelEncode, self).__init__() super(KieLabelEncode, self).__init__()
self.dict = dict({'': 0}) self.dict = dict({'': 0})
self.label2classid_map = dict()
with open(character_dict_path, 'r', encoding='utf-8') as fr: with open(character_dict_path, 'r', encoding='utf-8') as fr:
idx = 1 idx = 1
for line in fr: for line in fr:
char = line.strip() char = line.strip()
self.dict[char] = idx self.dict[char] = idx
idx += 1 idx += 1
with open(class_path, "r") as fin:
lines = fin.readlines()
for idx, line in enumerate(lines):
line = line.strip("\n")
self.label2classid_map[line] = idx
self.norm = norm self.norm = norm
self.directed = directed self.directed = directed
...@@ -408,7 +419,7 @@ class KieLabelEncode(object): ...@@ -408,7 +419,7 @@ class KieLabelEncode(object):
text_ind = [self.dict[c] for c in text if c in self.dict] text_ind = [self.dict[c] for c in text if c in self.dict]
text_inds.append(text_ind) text_inds.append(text_ind)
if 'label' in ann.keys(): if 'label' in ann.keys():
labels.append(ann['label']) labels.append(self.label2classid_map[ann['label']])
elif 'key_cls' in ann.keys(): elif 'key_cls' in ann.keys():
labels.append(ann['key_cls']) labels.append(ann['key_cls'])
else: else:
...@@ -551,171 +562,210 @@ class SRNLabelEncode(BaseRecLabelEncode): ...@@ -551,171 +562,210 @@ class SRNLabelEncode(BaseRecLabelEncode):
return idx return idx
class TableLabelEncode(object): class TableLabelEncode(AttnLabelEncode):
""" Convert between text-label and text-index """ """ Convert between text-label and text-index """
def __init__(self, def __init__(self,
max_text_length, max_text_length,
max_elem_length,
max_cell_num,
character_dict_path, character_dict_path,
span_weight=1.0, replace_empty_cell_token=False,
merge_no_span_structure=False,
learn_empty_box=False,
point_num=2,
**kwargs): **kwargs):
self.max_text_length = max_text_length self.max_text_len = max_text_length
self.max_elem_length = max_elem_length self.lower = False
self.max_cell_num = max_cell_num self.learn_empty_box = learn_empty_box
list_character, list_elem = self.load_char_elem_dict( self.merge_no_span_structure = merge_no_span_structure
character_dict_path) self.replace_empty_cell_token = replace_empty_cell_token
list_character = self.add_special_char(list_character)
list_elem = self.add_special_char(list_elem) dict_character = []
self.dict_character = {}
for i, char in enumerate(list_character):
self.dict_character[char] = i
self.dict_elem = {}
for i, elem in enumerate(list_elem):
self.dict_elem[elem] = i
self.span_weight = span_weight
def load_char_elem_dict(self, character_dict_path):
list_character = []
list_elem = []
with open(character_dict_path, "rb") as fin: with open(character_dict_path, "rb") as fin:
lines = fin.readlines() lines = fin.readlines()
substr = lines[0].decode('utf-8').strip("\r\n").split("\t") for line in lines:
character_num = int(substr[0]) line = line.decode('utf-8').strip("\n").strip("\r\n")
elem_num = int(substr[1]) dict_character.append(line)
for cno in range(1, 1 + character_num):
character = lines[cno].decode('utf-8').strip("\r\n") dict_character = self.add_special_char(dict_character)
list_character.append(character) self.dict = {}
for eno in range(1 + character_num, 1 + character_num + elem_num): for i, char in enumerate(dict_character):
elem = lines[eno].decode('utf-8').strip("\r\n") self.dict[char] = i
list_elem.append(elem) self.idx2char = {v: k for k, v in self.dict.items()}
return list_character, list_elem
self.character = dict_character
def add_special_char(self, list_character): self.point_num = point_num
self.beg_str = "sos" self.pad_idx = self.dict[self.beg_str]
self.end_str = "eos" self.start_idx = self.dict[self.beg_str]
list_character = [self.beg_str] + list_character + [self.end_str] self.end_idx = self.dict[self.end_str]
return list_character
self.td_token = ['<td>', '<td', '<eb></eb>', '<td></td>']
self.empty_bbox_token_dict = {
"[]": '<eb></eb>',
"[' ']": '<eb1></eb1>',
"['<b>', ' ', '</b>']": '<eb2></eb2>',
"['\\u2028', '\\u2028']": '<eb3></eb3>',
"['<sup>', ' ', '</sup>']": '<eb4></eb4>',
"['<b>', '</b>']": '<eb5></eb5>',
"['<i>', ' ', '</i>']": '<eb6></eb6>',
"['<b>', '<i>', '</i>', '</b>']": '<eb7></eb7>',
"['<b>', '<i>', ' ', '</i>', '</b>']": '<eb8></eb8>',
"['<i>', '</i>']": '<eb9></eb9>',
"['<b>', ' ', '\\u2028', ' ', '\\u2028', ' ', '</b>']":
'<eb10></eb10>',
}
def get_span_idx_list(self): @property
span_idx_list = [] def _max_text_len(self):
for elem in self.dict_elem: return self.max_text_len + 2
if 'span' in elem:
span_idx_list.append(self.dict_elem[elem])
return span_idx_list
def __call__(self, data): def __call__(self, data):
cells = data['cells'] cells = data['cells']
structure = data['structure']['tokens'] structure = data['structure']
structure = self.encode(structure, 'elem') if self.merge_no_span_structure:
structure = self._merge_no_span_structure(structure)
if self.replace_empty_cell_token:
structure = self._replace_empty_cell_token(structure, cells)
# remove empty token and add " " to span token
new_structure = []
for token in structure:
if token != '':
if 'span' in token and token[0] != ' ':
token = ' ' + token
new_structure.append(token)
# encode structure
structure = self.encode(new_structure)
if structure is None: if structure is None:
return None return None
elem_num = len(structure)
structure = [0] + structure + [len(self.dict_elem) - 1] structure = [self.start_idx] + structure + [self.end_idx
structure = structure + [0] * (self.max_elem_length + 2 - len(structure) ] # add sos abd eos
) structure = structure + [self.pad_idx] * (self._max_text_len -
len(structure)) # pad
structure = np.array(structure) structure = np.array(structure)
data['structure'] = structure data['structure'] = structure
elem_char_idx1 = self.dict_elem['<td>']
elem_char_idx2 = self.dict_elem['<td'] if len(structure) > self._max_text_len:
span_idx_list = self.get_span_idx_list() return None
td_idx_list = np.logical_or(structure == elem_char_idx1,
structure == elem_char_idx2) # encode box
td_idx_list = np.where(td_idx_list)[0] bboxes = np.zeros(
(self._max_text_len, self.point_num * 2), dtype=np.float32)
structure_mask = np.ones( bbox_masks = np.zeros((self._max_text_len, 1), dtype=np.float32)
(self.max_elem_length + 2, 1), dtype=np.float32)
bbox_list = np.zeros((self.max_elem_length + 2, 4), dtype=np.float32) bbox_idx = 0
bbox_list_mask = np.zeros(
(self.max_elem_length + 2, 1), dtype=np.float32) for i, token in enumerate(structure):
img_height, img_width, img_ch = data['image'].shape if self.idx2char[token] in self.td_token:
if len(span_idx_list) > 0: if 'bbox' in cells[bbox_idx] and len(cells[bbox_idx][
span_weight = len(td_idx_list) * 1.0 / len(span_idx_list) 'tokens']) > 0:
span_weight = min(max(span_weight, 1.0), self.span_weight) bbox = cells[bbox_idx]['bbox'].copy()
for cno in range(len(cells)): bbox = np.array(bbox, dtype=np.float32).reshape(-1)
if 'bbox' in cells[cno]: bboxes[i] = bbox
bbox = cells[cno]['bbox'].copy() bbox_masks[i] = 1.0
bbox[0] = bbox[0] * 1.0 / img_width if self.learn_empty_box:
bbox[1] = bbox[1] * 1.0 / img_height bbox_masks[i] = 1.0
bbox[2] = bbox[2] * 1.0 / img_width bbox_idx += 1
bbox[3] = bbox[3] * 1.0 / img_height data['bboxes'] = bboxes
td_idx = td_idx_list[cno] data['bbox_masks'] = bbox_masks
bbox_list[td_idx] = bbox
bbox_list_mask[td_idx] = 1.0
cand_span_idx = td_idx + 1
if cand_span_idx < (self.max_elem_length + 2):
if structure[cand_span_idx] in span_idx_list:
structure_mask[cand_span_idx] = span_weight
data['bbox_list'] = bbox_list
data['bbox_list_mask'] = bbox_list_mask
data['structure_mask'] = structure_mask
char_beg_idx = self.get_beg_end_flag_idx('beg', 'char')
char_end_idx = self.get_beg_end_flag_idx('end', 'char')
elem_beg_idx = self.get_beg_end_flag_idx('beg', 'elem')
elem_end_idx = self.get_beg_end_flag_idx('end', 'elem')
data['sp_tokens'] = np.array([
char_beg_idx, char_end_idx, elem_beg_idx, elem_end_idx,
elem_char_idx1, elem_char_idx2, self.max_text_length,
self.max_elem_length, self.max_cell_num, elem_num
])
return data return data
def encode(self, text, char_or_elem): def _merge_no_span_structure(self, structure):
"""convert text-label into text-index.
""" """
if char_or_elem == "char": This code is refer from:
max_len = self.max_text_length https://github.com/JiaquanYe/TableMASTER-mmocr/blob/master/table_recognition/data_preprocess.py
current_dict = self.dict_character """
else: new_structure = []
max_len = self.max_elem_length i = 0
current_dict = self.dict_elem while i < len(structure):
if len(text) > max_len: token = structure[i]
return None if token == '<td>':
if len(text) == 0: token = '<td></td>'
if char_or_elem == "char": i += 1
return [self.dict_character['space']] new_structure.append(token)
else: i += 1
return None return new_structure
text_list = []
for char in text: def _replace_empty_cell_token(self, token_list, cells):
if char not in current_dict: """
return None This fun code is refer from:
text_list.append(current_dict[char]) https://github.com/JiaquanYe/TableMASTER-mmocr/blob/master/table_recognition/data_preprocess.py
if len(text_list) == 0: """
if char_or_elem == "char":
return [self.dict_character['space']] bbox_idx = 0
add_empty_bbox_token_list = []
for token in token_list:
if token in ['<td></td>', '<td', '<td>']:
if 'bbox' not in cells[bbox_idx].keys():
content = str(cells[bbox_idx]['tokens'])
token = self.empty_bbox_token_dict[content]
add_empty_bbox_token_list.append(token)
bbox_idx += 1
else: else:
return None add_empty_bbox_token_list.append(token)
return text_list return add_empty_bbox_token_list
def get_ignored_tokens(self, char_or_elem):
beg_idx = self.get_beg_end_flag_idx("beg", char_or_elem)
end_idx = self.get_beg_end_flag_idx("end", char_or_elem)
return [beg_idx, end_idx]
def get_beg_end_flag_idx(self, beg_or_end, char_or_elem): class TableMasterLabelEncode(TableLabelEncode):
if char_or_elem == "char": """ Convert between text-label and text-index """
if beg_or_end == "beg":
idx = np.array(self.dict_character[self.beg_str]) def __init__(self,
elif beg_or_end == "end": max_text_length,
idx = np.array(self.dict_character[self.end_str]) character_dict_path,
else: replace_empty_cell_token=False,
assert False, "Unsupport type %s in get_beg_end_flag_idx of char" \ merge_no_span_structure=False,
% beg_or_end learn_empty_box=False,
elif char_or_elem == "elem": point_num=2,
if beg_or_end == "beg": **kwargs):
idx = np.array(self.dict_elem[self.beg_str]) super(TableMasterLabelEncode, self).__init__(
elif beg_or_end == "end": max_text_length, character_dict_path, replace_empty_cell_token,
idx = np.array(self.dict_elem[self.end_str]) merge_no_span_structure, learn_empty_box, point_num, **kwargs)
else: self.pad_idx = self.dict[self.pad_str]
assert False, "Unsupport type %s in get_beg_end_flag_idx of elem" \ self.unknown_idx = self.dict[self.unknown_str]
% beg_or_end
else: @property
assert False, "Unsupport type %s in char_or_elem" \ def _max_text_len(self):
% char_or_elem return self.max_text_len
return idx
def add_special_char(self, dict_character):
self.beg_str = '<SOS>'
self.end_str = '<EOS>'
self.unknown_str = '<UKN>'
self.pad_str = '<PAD>'
dict_character = dict_character
dict_character = dict_character + [
self.unknown_str, self.beg_str, self.end_str, self.pad_str
]
return dict_character
class TableBoxEncode(object):
def __init__(self, use_xywh=False, **kwargs):
self.use_xywh = use_xywh
def __call__(self, data):
img_height, img_width = data['image'].shape[:2]
bboxes = data['bboxes']
if self.use_xywh and bboxes.shape[1] == 4:
bboxes = self.xyxy2xywh(bboxes)
bboxes[:, 0::2] /= img_width
bboxes[:, 1::2] /= img_height
data['bboxes'] = bboxes
return data
def xyxy2xywh(self, bboxes):
"""
Convert coord (x1,y1,x2,y2) to (x,y,w,h).
where (x1,y1) is top-left, (x2,y2) is bottom-right.
(x,y) is bbox center and (w,h) is width and height.
:param bboxes: (x1, y1, x2, y2)
:return:
"""
new_bboxes = np.empty_like(bboxes)
new_bboxes[:, 0] = (bboxes[:, 0] + bboxes[:, 2]) / 2 # x center
new_bboxes[:, 1] = (bboxes[:, 1] + bboxes[:, 3]) / 2 # y center
new_bboxes[:, 2] = bboxes[:, 2] - bboxes[:, 0] # width
new_bboxes[:, 3] = bboxes[:, 3] - bboxes[:, 1] # height
return new_bboxes
class SARLabelEncode(BaseRecLabelEncode): class SARLabelEncode(BaseRecLabelEncode):
...@@ -819,6 +869,7 @@ class VQATokenLabelEncode(object): ...@@ -819,6 +869,7 @@ class VQATokenLabelEncode(object):
contains_re=False, contains_re=False,
add_special_ids=False, add_special_ids=False,
algorithm='LayoutXLM', algorithm='LayoutXLM',
use_textline_bbox_info=True,
infer_mode=False, infer_mode=False,
ocr_engine=None, ocr_engine=None,
**kwargs): **kwargs):
...@@ -847,11 +898,51 @@ class VQATokenLabelEncode(object): ...@@ -847,11 +898,51 @@ class VQATokenLabelEncode(object):
self.add_special_ids = add_special_ids self.add_special_ids = add_special_ids
self.infer_mode = infer_mode self.infer_mode = infer_mode
self.ocr_engine = ocr_engine self.ocr_engine = ocr_engine
self.use_textline_bbox_info = use_textline_bbox_info
def split_bbox(self, bbox, text, tokenizer):
words = text.split()
token_bboxes = []
curr_word_idx = 0
x1, y1, x2, y2 = bbox
unit_w = (x2 - x1) / len(text)
for idx, word in enumerate(words):
curr_w = len(word) * unit_w
word_bbox = [x1, y1, x1 + curr_w, y2]
token_bboxes.extend([word_bbox] * len(tokenizer.tokenize(word)))
x1 += (len(word) + 1) * unit_w
return token_bboxes
def filter_empty_contents(self, ocr_info):
"""
find out the empty texts and remove the links
"""
new_ocr_info = []
empty_index = []
for idx, info in enumerate(ocr_info):
if len(info["transcription"]) > 0:
new_ocr_info.append(copy.deepcopy(info))
else:
empty_index.append(info["id"])
for idx, info in enumerate(new_ocr_info):
new_link = []
for link in info["linking"]:
if link[0] in empty_index or link[1] in empty_index:
continue
new_link.append(link)
new_ocr_info[idx]["linking"] = new_link
return new_ocr_info
def __call__(self, data): def __call__(self, data):
# load bbox and label info # load bbox and label info
ocr_info = self._load_ocr_info(data) ocr_info = self._load_ocr_info(data)
# for re
train_re = self.contains_re and not self.infer_mode
if train_re:
ocr_info = self.filter_empty_contents(ocr_info)
height, width, _ = data['image'].shape height, width, _ = data['image'].shape
words_list = [] words_list = []
...@@ -863,8 +954,6 @@ class VQATokenLabelEncode(object): ...@@ -863,8 +954,6 @@ class VQATokenLabelEncode(object):
entities = [] entities = []
# for re
train_re = self.contains_re and not self.infer_mode
if train_re: if train_re:
relations = [] relations = []
id2label = {} id2label = {}
...@@ -874,19 +963,24 @@ class VQATokenLabelEncode(object): ...@@ -874,19 +963,24 @@ class VQATokenLabelEncode(object):
data['ocr_info'] = copy.deepcopy(ocr_info) data['ocr_info'] = copy.deepcopy(ocr_info)
for info in ocr_info: for info in ocr_info:
text = info["transcription"]
if len(text) <= 0:
continue
if train_re: if train_re:
# for re # for re
if len(info["text"]) == 0: if len(text) == 0:
empty_entity.add(info["id"]) empty_entity.add(info["id"])
continue continue
id2label[info["id"]] = info["label"] id2label[info["id"]] = info["label"]
relations.extend([tuple(sorted(l)) for l in info["linking"]]) relations.extend([tuple(sorted(l)) for l in info["linking"]])
# smooth_box # smooth_box
bbox = self._smooth_box(info["bbox"], height, width) info["bbox"] = self.trans_poly_to_bbox(info["points"])
text = info["text"]
encode_res = self.tokenizer.encode( encode_res = self.tokenizer.encode(
text, pad_to_max_seq_len=False, return_attention_mask=True) text,
pad_to_max_seq_len=False,
return_attention_mask=True,
return_token_type_ids=True)
if not self.add_special_ids: if not self.add_special_ids:
# TODO: use tok.all_special_ids to remove # TODO: use tok.all_special_ids to remove
...@@ -895,6 +989,19 @@ class VQATokenLabelEncode(object): ...@@ -895,6 +989,19 @@ class VQATokenLabelEncode(object):
-1] -1]
encode_res["attention_mask"] = encode_res["attention_mask"][1: encode_res["attention_mask"] = encode_res["attention_mask"][1:
-1] -1]
if self.use_textline_bbox_info:
bbox = [info["bbox"]] * len(encode_res["input_ids"])
else:
bbox = self.split_bbox(info["bbox"], info["transcription"],
self.tokenizer)
if len(bbox) <= 0:
continue
bbox = self._smooth_box(bbox, height, width)
if self.add_special_ids:
bbox.insert(0, [0, 0, 0, 0])
bbox.append([0, 0, 0, 0])
# parse label # parse label
if not self.infer_mode: if not self.infer_mode:
label = info['label'] label = info['label']
...@@ -919,7 +1026,7 @@ class VQATokenLabelEncode(object): ...@@ -919,7 +1026,7 @@ class VQATokenLabelEncode(object):
}) })
input_ids_list.extend(encode_res["input_ids"]) input_ids_list.extend(encode_res["input_ids"])
token_type_ids_list.extend(encode_res["token_type_ids"]) token_type_ids_list.extend(encode_res["token_type_ids"])
bbox_list.extend([bbox] * len(encode_res["input_ids"])) bbox_list.extend(bbox)
words_list.append(text) words_list.append(text)
segment_offset_id.append(len(input_ids_list)) segment_offset_id.append(len(input_ids_list))
if not self.infer_mode: if not self.infer_mode:
...@@ -944,40 +1051,42 @@ class VQATokenLabelEncode(object): ...@@ -944,40 +1051,42 @@ class VQATokenLabelEncode(object):
data['entity_id_to_index_map'] = entity_id_to_index_map data['entity_id_to_index_map'] = entity_id_to_index_map
return data return data
def _load_ocr_info(self, data): def trans_poly_to_bbox(self, poly):
def trans_poly_to_bbox(poly): x1 = np.min([p[0] for p in poly])
x1 = np.min([p[0] for p in poly]) x2 = np.max([p[0] for p in poly])
x2 = np.max([p[0] for p in poly]) y1 = np.min([p[1] for p in poly])
y1 = np.min([p[1] for p in poly]) y2 = np.max([p[1] for p in poly])
y2 = np.max([p[1] for p in poly]) return [x1, y1, x2, y2]
return [x1, y1, x2, y2]
def _load_ocr_info(self, data):
if self.infer_mode: if self.infer_mode:
ocr_result = self.ocr_engine.ocr(data['image'], cls=False) ocr_result = self.ocr_engine.ocr(data['image'], cls=False)
ocr_info = [] ocr_info = []
for res in ocr_result: for res in ocr_result:
ocr_info.append({ ocr_info.append({
"text": res[1][0], "transcription": res[1][0],
"bbox": trans_poly_to_bbox(res[0]), "bbox": self.trans_poly_to_bbox(res[0]),
"poly": res[0], "points": res[0],
}) })
return ocr_info return ocr_info
else: else:
info = data['label'] info = data['label']
# read text info # read text info
info_dict = json.loads(info) info_dict = json.loads(info)
return info_dict["ocr_info"] return info_dict
def _smooth_box(self, bbox, height, width): def _smooth_box(self, bboxes, height, width):
bbox[0] = int(bbox[0] * 1000.0 / width) bboxes = np.array(bboxes)
bbox[2] = int(bbox[2] * 1000.0 / width) bboxes[:, 0] = bboxes[:, 0] * 1000 / width
bbox[1] = int(bbox[1] * 1000.0 / height) bboxes[:, 2] = bboxes[:, 2] * 1000 / width
bbox[3] = int(bbox[3] * 1000.0 / height) bboxes[:, 1] = bboxes[:, 1] * 1000 / height
return bbox bboxes[:, 3] = bboxes[:, 3] * 1000 / height
bboxes = bboxes.astype("int64").tolist()
return bboxes
def _parse_label(self, label, encode_res): def _parse_label(self, label, encode_res):
gt_label = [] gt_label = []
if label.lower() == "other": if label.lower() in ["other", "others", "ignore"]:
gt_label.extend([0] * len(encode_res["input_ids"])) gt_label.extend([0] * len(encode_res["input_ids"]))
else: else:
gt_label.append(self.label2id_map[("b-" + label).upper()]) gt_label.append(self.label2id_map[("b-" + label).upper()])
...@@ -1001,7 +1110,6 @@ class MultiLabelEncode(BaseRecLabelEncode): ...@@ -1001,7 +1110,6 @@ class MultiLabelEncode(BaseRecLabelEncode):
use_space_char, **kwargs) use_space_char, **kwargs)
def __call__(self, data): def __call__(self, data):
data_ctc = copy.deepcopy(data) data_ctc = copy.deepcopy(data)
data_sar = copy.deepcopy(data) data_sar = copy.deepcopy(data)
data_out = dict() data_out = dict()
...@@ -1111,3 +1219,38 @@ class ABINetLabelEncode(BaseRecLabelEncode): ...@@ -1111,3 +1219,38 @@ class ABINetLabelEncode(BaseRecLabelEncode):
def add_special_char(self, dict_character): def add_special_char(self, dict_character):
dict_character = ['</s>'] + dict_character dict_character = ['</s>'] + dict_character
return dict_character return dict_character
class SPINLabelEncode(AttnLabelEncode):
""" Convert between text-label and text-index """
def __init__(self,
max_text_length,
character_dict_path=None,
use_space_char=False,
lower=True,
**kwargs):
super(SPINLabelEncode, self).__init__(
max_text_length, character_dict_path, use_space_char)
self.lower = lower
def add_special_char(self, dict_character):
self.beg_str = "sos"
self.end_str = "eos"
dict_character = [self.beg_str] + [self.end_str] + dict_character
return dict_character
def __call__(self, data):
text = data['label']
text = self.encode(text)
if text is None:
return None
if len(text) > self.max_text_len:
return None
data['length'] = np.array(len(text))
target = [0] + text + [1]
padded_text = [0 for _ in range(self.max_text_len + 2)]
padded_text[:len(target)] = target
data['label'] = np.array(padded_text)
return data
...@@ -205,9 +205,12 @@ class DetResizeForTest(object): ...@@ -205,9 +205,12 @@ class DetResizeForTest(object):
def __init__(self, **kwargs): def __init__(self, **kwargs):
super(DetResizeForTest, self).__init__() super(DetResizeForTest, self).__init__()
self.resize_type = 0 self.resize_type = 0
self.keep_ratio = False
if 'image_shape' in kwargs: if 'image_shape' in kwargs:
self.image_shape = kwargs['image_shape'] self.image_shape = kwargs['image_shape']
self.resize_type = 1 self.resize_type = 1
if 'keep_ratio' in kwargs:
self.keep_ratio = kwargs['keep_ratio']
elif 'limit_side_len' in kwargs: elif 'limit_side_len' in kwargs:
self.limit_side_len = kwargs['limit_side_len'] self.limit_side_len = kwargs['limit_side_len']
self.limit_type = kwargs.get('limit_type', 'min') self.limit_type = kwargs.get('limit_type', 'min')
...@@ -237,6 +240,10 @@ class DetResizeForTest(object): ...@@ -237,6 +240,10 @@ class DetResizeForTest(object):
def resize_image_type1(self, img): def resize_image_type1(self, img):
resize_h, resize_w = self.image_shape resize_h, resize_w = self.image_shape
ori_h, ori_w = img.shape[:2] # (h, w, c) ori_h, ori_w = img.shape[:2] # (h, w, c)
if self.keep_ratio is True:
resize_w = ori_w * resize_h / ori_h
N = math.ceil(resize_w / 32)
resize_w = N * 32
ratio_h = float(resize_h) / ori_h ratio_h = float(resize_h) / ori_h
ratio_w = float(resize_w) / ori_w ratio_w = float(resize_w) / ori_w
img = cv2.resize(img, (int(resize_w), int(resize_h))) img = cv2.resize(img, (int(resize_w), int(resize_h)))
......
...@@ -259,6 +259,49 @@ class PRENResizeImg(object): ...@@ -259,6 +259,49 @@ class PRENResizeImg(object):
data['image'] = resized_img.astype(np.float32) data['image'] = resized_img.astype(np.float32)
return data return data
class SPINRecResizeImg(object):
def __init__(self,
image_shape,
interpolation=2,
mean=(127.5, 127.5, 127.5),
std=(127.5, 127.5, 127.5),
**kwargs):
self.image_shape = image_shape
self.mean = np.array(mean, dtype=np.float32)
self.std = np.array(std, dtype=np.float32)
self.interpolation = interpolation
def __call__(self, data):
img = data['image']
img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
# different interpolation type corresponding the OpenCV
if self.interpolation == 0:
interpolation = cv2.INTER_NEAREST
elif self.interpolation == 1:
interpolation = cv2.INTER_LINEAR
elif self.interpolation == 2:
interpolation = cv2.INTER_CUBIC
elif self.interpolation == 3:
interpolation = cv2.INTER_AREA
else:
raise Exception("Unsupported interpolation type !!!")
# Deal with the image error during image loading
if img is None:
return None
img = cv2.resize(img, tuple(self.image_shape), interpolation)
img = np.array(img, np.float32)
img = np.expand_dims(img, -1)
img = img.transpose((2, 0, 1))
# normalize the image
img = img.copy().astype(np.float32)
mean = np.float64(self.mean.reshape(1, -1))
stdinv = 1 / np.float64(self.std.reshape(1, -1))
img -= mean
img *= stdinv
data['image'] = img
return data
class GrayRecResizeImg(object): class GrayRecResizeImg(object):
def __init__(self, def __init__(self,
......
...@@ -32,7 +32,7 @@ class GenTableMask(object): ...@@ -32,7 +32,7 @@ class GenTableMask(object):
self.shrink_h_max = 5 self.shrink_h_max = 5
self.shrink_w_max = 5 self.shrink_w_max = 5
self.mask_type = mask_type self.mask_type = mask_type
def projection(self, erosion, h, w, spilt_threshold=0): def projection(self, erosion, h, w, spilt_threshold=0):
# 水平投影 # 水平投影
projection_map = np.ones_like(erosion) projection_map = np.ones_like(erosion)
...@@ -48,10 +48,12 @@ class GenTableMask(object): ...@@ -48,10 +48,12 @@ class GenTableMask(object):
in_text = False # 是否遍历到了字符区内 in_text = False # 是否遍历到了字符区内
box_list = [] box_list = []
for i in range(len(project_val_array)): for i in range(len(project_val_array)):
if in_text == False and project_val_array[i] > spilt_threshold: # 进入字符区了 if in_text == False and project_val_array[
i] > spilt_threshold: # 进入字符区了
in_text = True in_text = True
start_idx = i start_idx = i
elif project_val_array[i] <= spilt_threshold and in_text == True: # 进入空白区了 elif project_val_array[
i] <= spilt_threshold and in_text == True: # 进入空白区了
end_idx = i end_idx = i
in_text = False in_text = False
if end_idx - start_idx <= 2: if end_idx - start_idx <= 2:
...@@ -70,7 +72,8 @@ class GenTableMask(object): ...@@ -70,7 +72,8 @@ class GenTableMask(object):
box_gray_img = cv2.cvtColor(box_img, cv2.COLOR_BGR2GRAY) box_gray_img = cv2.cvtColor(box_img, cv2.COLOR_BGR2GRAY)
h, w = box_gray_img.shape h, w = box_gray_img.shape
# 灰度图片进行二值化处理 # 灰度图片进行二值化处理
ret, thresh1 = cv2.threshold(box_gray_img, 200, 255, cv2.THRESH_BINARY_INV) ret, thresh1 = cv2.threshold(box_gray_img, 200, 255,
cv2.THRESH_BINARY_INV)
# 纵向腐蚀 # 纵向腐蚀
if h < w: if h < w:
kernel = np.ones((2, 1), np.uint8) kernel = np.ones((2, 1), np.uint8)
...@@ -95,10 +98,12 @@ class GenTableMask(object): ...@@ -95,10 +98,12 @@ class GenTableMask(object):
box_list = [] box_list = []
spilt_threshold = 0 spilt_threshold = 0
for i in range(len(project_val_array)): for i in range(len(project_val_array)):
if in_text == False and project_val_array[i] > spilt_threshold: # 进入字符区了 if in_text == False and project_val_array[
i] > spilt_threshold: # 进入字符区了
in_text = True in_text = True
start_idx = i start_idx = i
elif project_val_array[i] <= spilt_threshold and in_text == True: # 进入空白区了 elif project_val_array[
i] <= spilt_threshold and in_text == True: # 进入空白区了
end_idx = i end_idx = i
in_text = False in_text = False
if end_idx - start_idx <= 2: if end_idx - start_idx <= 2:
...@@ -120,7 +125,8 @@ class GenTableMask(object): ...@@ -120,7 +125,8 @@ class GenTableMask(object):
h_end = h h_end = h
word_img = erosion[h_start:h_end + 1, :] word_img = erosion[h_start:h_end + 1, :]
word_h, word_w = word_img.shape word_h, word_w = word_img.shape
w_split_list, w_projection_map = self.projection(word_img.T, word_w, word_h) w_split_list, w_projection_map = self.projection(word_img.T,
word_w, word_h)
w_start, w_end = w_split_list[0][0], w_split_list[-1][1] w_start, w_end = w_split_list[0][0], w_split_list[-1][1]
if h_start > 0: if h_start > 0:
h_start -= 1 h_start -= 1
...@@ -170,75 +176,54 @@ class GenTableMask(object): ...@@ -170,75 +176,54 @@ class GenTableMask(object):
for sno in range(len(split_bbox_list)): for sno in range(len(split_bbox_list)):
left, top, right, bottom = split_bbox_list[sno] left, top, right, bottom = split_bbox_list[sno]
left, top, right, bottom = self.shrink_bbox([left, top, right, bottom]) left, top, right, bottom = self.shrink_bbox(
[left, top, right, bottom])
if self.mask_type == 1: if self.mask_type == 1:
mask_img[top:bottom, left:right] = 1.0 mask_img[top:bottom, left:right] = 1.0
data['mask_img'] = mask_img data['mask_img'] = mask_img
else: else:
mask_img[top:bottom, left:right, :] = (255, 255, 255) mask_img[top:bottom, left:right, :] = (255, 255, 255)
data['image'] = mask_img data['image'] = mask_img
return data return data
class ResizeTableImage(object): class ResizeTableImage(object):
def __init__(self, max_len, **kwargs): def __init__(self, max_len, resize_bboxes=False, infer_mode=False,
**kwargs):
super(ResizeTableImage, self).__init__() super(ResizeTableImage, self).__init__()
self.max_len = max_len self.max_len = max_len
self.resize_bboxes = resize_bboxes
self.infer_mode = infer_mode
def get_img_bbox(self, cells): def __call__(self, data):
bbox_list = [] img = data['image']
if len(cells) == 0:
return bbox_list
cell_num = len(cells)
for cno in range(cell_num):
if "bbox" in cells[cno]:
bbox = cells[cno]['bbox']
bbox_list.append(bbox)
return bbox_list
def resize_img_table(self, img, bbox_list, max_len):
height, width = img.shape[0:2] height, width = img.shape[0:2]
ratio = max_len / (max(height, width) * 1.0) ratio = self.max_len / (max(height, width) * 1.0)
resize_h = int(height * ratio) resize_h = int(height * ratio)
resize_w = int(width * ratio) resize_w = int(width * ratio)
img_new = cv2.resize(img, (resize_w, resize_h)) resize_img = cv2.resize(img, (resize_w, resize_h))
bbox_list_new = [] if self.resize_bboxes and not self.infer_mode:
for bno in range(len(bbox_list)): data['bboxes'] = data['bboxes'] * ratio
left, top, right, bottom = bbox_list[bno].copy() data['image'] = resize_img
left = int(left * ratio) data['src_img'] = img
top = int(top * ratio) data['shape'] = np.array([resize_h, resize_w, ratio, ratio])
right = int(right * ratio)
bottom = int(bottom * ratio)
bbox_list_new.append([left, top, right, bottom])
return img_new, bbox_list_new
def __call__(self, data):
img = data['image']
if 'cells' not in data:
cells = []
else:
cells = data['cells']
bbox_list = self.get_img_bbox(cells)
img_new, bbox_list_new = self.resize_img_table(img, bbox_list, self.max_len)
data['image'] = img_new
cell_num = len(cells)
bno = 0
for cno in range(cell_num):
if "bbox" in data['cells'][cno]:
data['cells'][cno]['bbox'] = bbox_list_new[bno]
bno += 1
data['max_len'] = self.max_len data['max_len'] = self.max_len
return data return data
class PaddingTableImage(object): class PaddingTableImage(object):
def __init__(self, **kwargs): def __init__(self, size, **kwargs):
super(PaddingTableImage, self).__init__() super(PaddingTableImage, self).__init__()
self.size = size
def __call__(self, data): def __call__(self, data):
img = data['image'] img = data['image']
max_len = data['max_len'] pad_h, pad_w = self.size
padding_img = np.zeros((max_len, max_len, 3), dtype=np.float32) padding_img = np.zeros((pad_h, pad_w, 3), dtype=np.float32)
height, width = img.shape[0:2] height, width = img.shape[0:2]
padding_img[0:height, 0:width, :] = img.copy() padding_img[0:height, 0:width, :] = img.copy()
data['image'] = padding_img data['image'] = padding_img
shape = data['shape'].tolist()
shape.extend([pad_h, pad_w])
data['shape'] = np.array(shape)
return data return data
\ No newline at end of file
...@@ -13,7 +13,12 @@ ...@@ -13,7 +13,12 @@
# limitations under the License. # limitations under the License.
from .token import VQATokenPad, VQASerTokenChunk, VQAReTokenChunk, VQAReTokenRelation from .token import VQATokenPad, VQASerTokenChunk, VQAReTokenChunk, VQAReTokenRelation
from .augment import DistortBBox
__all__ = [ __all__ = [
'VQATokenPad', 'VQASerTokenChunk', 'VQAReTokenChunk', 'VQAReTokenRelation' 'VQATokenPad',
'VQASerTokenChunk',
'VQAReTokenChunk',
'VQAReTokenRelation',
'DistortBBox',
] ]
# copyright (c) 2022 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import sys
import numpy as np
import random
class DistortBBox:
def __init__(self, prob=0.5, max_scale=1, **kwargs):
"""Random distort bbox
"""
self.prob = prob
self.max_scale = max_scale
def __call__(self, data):
if random.random() > self.prob:
return data
bbox = np.array(data['bbox'])
rnd_scale = (np.random.rand(*bbox.shape) - 0.5) * 2 * self.max_scale
bbox = np.round(bbox + rnd_scale).astype(bbox.dtype)
data['bbox'] = np.clip(data['bbox'], 0, 1000)
data['bbox'] = bbox.tolist()
sys.stdout.flush()
return data
...@@ -16,6 +16,7 @@ import os ...@@ -16,6 +16,7 @@ import os
import random import random
from paddle.io import Dataset from paddle.io import Dataset
import json import json
from copy import deepcopy
from .imaug import transform, create_operators from .imaug import transform, create_operators
...@@ -29,33 +30,63 @@ class PubTabDataSet(Dataset): ...@@ -29,33 +30,63 @@ class PubTabDataSet(Dataset):
dataset_config = config[mode]['dataset'] dataset_config = config[mode]['dataset']
loader_config = config[mode]['loader'] loader_config = config[mode]['loader']
label_file_path = dataset_config.pop('label_file_path') label_file_list = dataset_config.pop('label_file_list')
data_source_num = len(label_file_list)
ratio_list = dataset_config.get("ratio_list", [1.0])
if isinstance(ratio_list, (float, int)):
ratio_list = [float(ratio_list)] * int(data_source_num)
assert len(
ratio_list
) == data_source_num, "The length of ratio_list should be the same as the file_list."
self.data_dir = dataset_config['data_dir'] self.data_dir = dataset_config['data_dir']
self.do_shuffle = loader_config['shuffle'] self.do_shuffle = loader_config['shuffle']
self.do_hard_select = False
if 'hard_select' in loader_config:
self.do_hard_select = loader_config['hard_select']
self.hard_prob = loader_config['hard_prob']
if self.do_hard_select:
self.img_select_prob = self.load_hard_select_prob()
self.table_select_type = None
if 'table_select_type' in loader_config:
self.table_select_type = loader_config['table_select_type']
self.table_select_prob = loader_config['table_select_prob']
self.seed = seed self.seed = seed
logger.info("Initialize indexs of datasets:%s" % label_file_path) self.mode = mode.lower()
with open(label_file_path, "rb") as f: logger.info("Initialize indexs of datasets:%s" % label_file_list)
self.data_lines = f.readlines() self.data_lines = self.get_image_info_list(label_file_list, ratio_list)
self.data_idx_order_list = list(range(len(self.data_lines))) # self.check(config['Global']['max_text_length'])
if mode.lower() == "train":
if mode.lower() == "train" and self.do_shuffle:
self.shuffle_data_random() self.shuffle_data_random()
self.ops = create_operators(dataset_config['transforms'], global_config) self.ops = create_operators(dataset_config['transforms'], global_config)
ratio_list = dataset_config.get("ratio_list", [1.0])
self.need_reset = True in [x < 1 for x in ratio_list] self.need_reset = True in [x < 1 for x in ratio_list]
def get_image_info_list(self, file_list, ratio_list):
if isinstance(file_list, str):
file_list = [file_list]
data_lines = []
for idx, file in enumerate(file_list):
with open(file, "rb") as f:
lines = f.readlines()
if self.mode == "train" or ratio_list[idx] < 1.0:
random.seed(self.seed)
lines = random.sample(lines,
round(len(lines) * ratio_list[idx]))
data_lines.extend(lines)
return data_lines
def check(self, max_text_length):
data_lines = []
for line in self.data_lines:
data_line = line.decode('utf-8').strip("\n")
info = json.loads(data_line)
file_name = info['filename']
cells = info['html']['cells'].copy()
structure = info['html']['structure']['tokens'].copy()
img_path = os.path.join(self.data_dir, file_name)
if not os.path.exists(img_path):
self.logger.warning("{} does not exist!".format(img_path))
continue
if len(structure) == 0 or len(structure) > max_text_length:
continue
# data = {'img_path': img_path, 'cells': cells, 'structure':structure,'file_name':file_name}
data_lines.append(line)
self.data_lines = data_lines
def shuffle_data_random(self): def shuffle_data_random(self):
if self.do_shuffle: if self.do_shuffle:
random.seed(self.seed) random.seed(self.seed)
...@@ -68,47 +99,35 @@ class PubTabDataSet(Dataset): ...@@ -68,47 +99,35 @@ class PubTabDataSet(Dataset):
data_line = data_line.decode('utf-8').strip("\n") data_line = data_line.decode('utf-8').strip("\n")
info = json.loads(data_line) info = json.loads(data_line)
file_name = info['filename'] file_name = info['filename']
select_flag = True cells = info['html']['cells'].copy()
if self.do_hard_select: structure = info['html']['structure']['tokens'].copy()
prob = self.img_select_prob[file_name]
if prob < random.uniform(0, 1): img_path = os.path.join(self.data_dir, file_name)
select_flag = False if not os.path.exists(img_path):
raise Exception("{} does not exist!".format(img_path))
if self.table_select_type: data = {
structure = info['html']['structure']['tokens'].copy() 'img_path': img_path,
structure_str = ''.join(structure) 'cells': cells,
table_type = "simple" 'structure': structure,
if 'colspan' in structure_str or 'rowspan' in structure_str: 'file_name': file_name
table_type = "complex" }
if table_type == "complex":
if self.table_select_prob < random.uniform(0, 1): with open(data['img_path'], 'rb') as f:
select_flag = False img = f.read()
data['image'] = img
if select_flag: outs = transform(data, self.ops)
cells = info['html']['cells'].copy() except:
structure = info['html']['structure'].copy() import traceback
img_path = os.path.join(self.data_dir, file_name) err = traceback.format_exc()
data = {
'img_path': img_path,
'cells': cells,
'structure': structure
}
if not os.path.exists(img_path):
raise Exception("{} does not exist!".format(img_path))
with open(data['img_path'], 'rb') as f:
img = f.read()
data['image'] = img
outs = transform(data, self.ops)
else:
outs = None
except Exception as e:
self.logger.error( self.logger.error(
"When parsing line {}, error happened with msg: {}".format( "When parsing line {}, error happened with msg: {}".format(
data_line, e)) data_line, err))
outs = None outs = None
if outs is None: if outs is None:
return self.__getitem__(np.random.randint(self.__len__())) rnd_idx = np.random.randint(self.__len__(
)) if self.mode == "train" else (idx + 1) % self.__len__()
return self.__getitem__(rnd_idx)
return outs return outs
def __len__(self): def __len__(self):
return len(self.data_idx_order_list) return len(self.data_lines)
...@@ -35,6 +35,7 @@ from .rec_sar_loss import SARLoss ...@@ -35,6 +35,7 @@ from .rec_sar_loss import SARLoss
from .rec_aster_loss import AsterLoss from .rec_aster_loss import AsterLoss
from .rec_pren_loss import PRENLoss from .rec_pren_loss import PRENLoss
from .rec_multi_loss import MultiLoss from .rec_multi_loss import MultiLoss
from .rec_spin_att_loss import SPINAttentionLoss
# cls loss # cls loss
from .cls_loss import ClsLoss from .cls_loss import ClsLoss
...@@ -51,7 +52,7 @@ from .combined_loss import CombinedLoss ...@@ -51,7 +52,7 @@ from .combined_loss import CombinedLoss
# table loss # table loss
from .table_att_loss import TableAttentionLoss from .table_att_loss import TableAttentionLoss
from .table_master_loss import TableMasterLoss
# vqa token loss # vqa token loss
from .vqa_token_layoutlm_loss import VQASerTokenLayoutLMLoss from .vqa_token_layoutlm_loss import VQASerTokenLayoutLMLoss
...@@ -61,7 +62,8 @@ def build_loss(config): ...@@ -61,7 +62,8 @@ def build_loss(config):
'DBLoss', 'PSELoss', 'EASTLoss', 'SASTLoss', 'FCELoss', 'CTCLoss', 'DBLoss', 'PSELoss', 'EASTLoss', 'SASTLoss', 'FCELoss', 'CTCLoss',
'ClsLoss', 'AttentionLoss', 'SRNLoss', 'PGLoss', 'CombinedLoss', 'ClsLoss', 'AttentionLoss', 'SRNLoss', 'PGLoss', 'CombinedLoss',
'CELoss', 'TableAttentionLoss', 'SARLoss', 'AsterLoss', 'SDMGRLoss', 'CELoss', 'TableAttentionLoss', 'SARLoss', 'AsterLoss', 'SDMGRLoss',
'VQASerTokenLayoutLMLoss', 'LossFromOutput', 'PRENLoss', 'MultiLoss' 'VQASerTokenLayoutLMLoss', 'LossFromOutput', 'PRENLoss', 'MultiLoss',
'TableMasterLoss', 'SPINAttentionLoss'
] ]
config = copy.deepcopy(config) config = copy.deepcopy(config)
module_name = config.pop('name') module_name = config.pop('name')
......
...@@ -57,17 +57,24 @@ class CELoss(nn.Layer): ...@@ -57,17 +57,24 @@ class CELoss(nn.Layer):
class KLJSLoss(object): class KLJSLoss(object):
def __init__(self, mode='kl'): def __init__(self, mode='kl'):
assert mode in ['kl', 'js', 'KL', 'JS' assert mode in ['kl', 'js', 'KL', 'JS'
], "mode can only be one of ['kl', 'js', 'KL', 'JS']" ], "mode can only be one of ['kl', 'KL', 'js', 'JS']"
self.mode = mode self.mode = mode
def __call__(self, p1, p2, reduction="mean"): def __call__(self, p1, p2, reduction="mean"):
loss = paddle.multiply(p2, paddle.log((p2 + 1e-5) / (p1 + 1e-5) + 1e-5)) if self.mode.lower() == 'kl':
loss = paddle.multiply(p2, paddle.log((p2 + 1e-5) / (p1 + 1e-5) + 1e-5))
if self.mode.lower() == "js": loss += paddle.multiply(
p1, paddle.log((p1 + 1e-5) / (p2 + 1e-5) + 1e-5))
loss *= 0.5
elif self.mode.lower() == "js":
loss = paddle.multiply(p2, paddle.log((2*p2 + 1e-5) / (p1 + p2 + 1e-5) + 1e-5))
loss += paddle.multiply( loss += paddle.multiply(
p1, paddle.log((p1 + 1e-5) / (p2 + 1e-5) + 1e-5)) p1, paddle.log((2*p1 + 1e-5) / (p1 + p2 + 1e-5) + 1e-5))
loss *= 0.5 loss *= 0.5
else:
raise ValueError("The mode.lower() if KLJSLoss should be one of ['kl', 'js']")
if reduction == "mean": if reduction == "mean":
loss = paddle.mean(loss, axis=[1, 2]) loss = paddle.mean(loss, axis=[1, 2])
elif reduction == "none" or reduction is None: elif reduction == "none" or reduction is None:
...@@ -95,7 +102,7 @@ class DMLLoss(nn.Layer): ...@@ -95,7 +102,7 @@ class DMLLoss(nn.Layer):
self.act = None self.act = None
self.use_log = use_log self.use_log = use_log
self.jskl_loss = KLJSLoss(mode="js") self.jskl_loss = KLJSLoss(mode="kl")
def _kldiv(self, x, target): def _kldiv(self, x, target):
eps = 1.0e-10 eps = 1.0e-10
......
# copyright (c) 2022 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import paddle
from paddle import nn
'''This code is refer from:
https://github.com/hikopensource/DAVAR-Lab-OCR
'''
class SPINAttentionLoss(nn.Layer):
def __init__(self, reduction='mean', ignore_index=-100, **kwargs):
super(SPINAttentionLoss, self).__init__()
self.loss_func = nn.CrossEntropyLoss(weight=None, reduction=reduction, ignore_index=ignore_index)
def forward(self, predicts, batch):
targets = batch[1].astype("int64")
targets = targets[:, 1:] # remove [eos] in label
label_lengths = batch[2].astype('int64')
batch_size, num_steps, num_classes = predicts.shape[0], predicts.shape[
1], predicts.shape[2]
assert len(targets.shape) == len(list(predicts.shape)) - 1, \
"The target's shape and inputs's shape is [N, d] and [N, num_steps]"
inputs = paddle.reshape(predicts, [-1, predicts.shape[-1]])
targets = paddle.reshape(targets, [-1])
return {'loss': self.loss_func(inputs, targets)}
...@@ -20,15 +20,21 @@ import paddle ...@@ -20,15 +20,21 @@ import paddle
from paddle import nn from paddle import nn
from paddle.nn import functional as F from paddle.nn import functional as F
class TableAttentionLoss(nn.Layer): class TableAttentionLoss(nn.Layer):
def __init__(self, structure_weight, loc_weight, use_giou=False, giou_weight=1.0, **kwargs): def __init__(self,
structure_weight,
loc_weight,
use_giou=False,
giou_weight=1.0,
**kwargs):
super(TableAttentionLoss, self).__init__() super(TableAttentionLoss, self).__init__()
self.loss_func = nn.CrossEntropyLoss(weight=None, reduction='none') self.loss_func = nn.CrossEntropyLoss(weight=None, reduction='none')
self.structure_weight = structure_weight self.structure_weight = structure_weight
self.loc_weight = loc_weight self.loc_weight = loc_weight
self.use_giou = use_giou self.use_giou = use_giou
self.giou_weight = giou_weight self.giou_weight = giou_weight
def giou_loss(self, preds, bbox, eps=1e-7, reduction='mean'): def giou_loss(self, preds, bbox, eps=1e-7, reduction='mean'):
''' '''
:param preds:[[x1,y1,x2,y2], [x1,y1,x2,y2],,,] :param preds:[[x1,y1,x2,y2], [x1,y1,x2,y2],,,]
...@@ -47,9 +53,10 @@ class TableAttentionLoss(nn.Layer): ...@@ -47,9 +53,10 @@ class TableAttentionLoss(nn.Layer):
inters = iw * ih inters = iw * ih
# union # union
uni = (preds[:, 2] - preds[:, 0] + 1e-3) * (preds[:, 3] - preds[:, 1] + 1e-3 uni = (preds[:, 2] - preds[:, 0] + 1e-3) * (
) + (bbox[:, 2] - bbox[:, 0] + 1e-3) * ( preds[:, 3] - preds[:, 1] + 1e-3) + (bbox[:, 2] - bbox[:, 0] + 1e-3
bbox[:, 3] - bbox[:, 1] + 1e-3) - inters + eps ) * (bbox[:, 3] - bbox[:, 1] +
1e-3) - inters + eps
# ious # ious
ious = inters / uni ious = inters / uni
...@@ -79,30 +86,34 @@ class TableAttentionLoss(nn.Layer): ...@@ -79,30 +86,34 @@ class TableAttentionLoss(nn.Layer):
structure_probs = predicts['structure_probs'] structure_probs = predicts['structure_probs']
structure_targets = batch[1].astype("int64") structure_targets = batch[1].astype("int64")
structure_targets = structure_targets[:, 1:] structure_targets = structure_targets[:, 1:]
if len(batch) == 6: structure_probs = paddle.reshape(structure_probs,
structure_mask = batch[5].astype("int64") [-1, structure_probs.shape[-1]])
structure_mask = structure_mask[:, 1:]
structure_mask = paddle.reshape(structure_mask, [-1])
structure_probs = paddle.reshape(structure_probs, [-1, structure_probs.shape[-1]])
structure_targets = paddle.reshape(structure_targets, [-1]) structure_targets = paddle.reshape(structure_targets, [-1])
structure_loss = self.loss_func(structure_probs, structure_targets) structure_loss = self.loss_func(structure_probs, structure_targets)
if len(batch) == 6:
structure_loss = structure_loss * structure_mask
# structure_loss = paddle.sum(structure_loss) * self.structure_weight
structure_loss = paddle.mean(structure_loss) * self.structure_weight structure_loss = paddle.mean(structure_loss) * self.structure_weight
loc_preds = predicts['loc_preds'] loc_preds = predicts['loc_preds']
loc_targets = batch[2].astype("float32") loc_targets = batch[2].astype("float32")
loc_targets_mask = batch[4].astype("float32") loc_targets_mask = batch[3].astype("float32")
loc_targets = loc_targets[:, 1:, :] loc_targets = loc_targets[:, 1:, :]
loc_targets_mask = loc_targets_mask[:, 1:, :] loc_targets_mask = loc_targets_mask[:, 1:, :]
loc_loss = F.mse_loss(loc_preds * loc_targets_mask, loc_targets) * self.loc_weight loc_loss = F.mse_loss(loc_preds * loc_targets_mask,
loc_targets) * self.loc_weight
if self.use_giou: if self.use_giou:
loc_loss_giou = self.giou_loss(loc_preds * loc_targets_mask, loc_targets) * self.giou_weight loc_loss_giou = self.giou_loss(loc_preds * loc_targets_mask,
loc_targets) * self.giou_weight
total_loss = structure_loss + loc_loss + loc_loss_giou total_loss = structure_loss + loc_loss + loc_loss_giou
return {'loss':total_loss, "structure_loss":structure_loss, "loc_loss":loc_loss, "loc_loss_giou":loc_loss_giou} return {
'loss': total_loss,
"structure_loss": structure_loss,
"loc_loss": loc_loss,
"loc_loss_giou": loc_loss_giou
}
else: else:
total_loss = structure_loss + loc_loss total_loss = structure_loss + loc_loss
return {'loss':total_loss, "structure_loss":structure_loss, "loc_loss":loc_loss} return {
\ No newline at end of file
'loss': total_loss,
"structure_loss": structure_loss,
"loc_loss": loc_loss
}
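For reference, a self-contained sketch of the GIoU term that TableAttentionLoss adds when use_giou=True (the diff above only shows the intersection/union part); boxes are assumed to be axis-aligned [x1, y1, x2, y2] rows and the epsilon handling is illustrative:

import paddle

def giou_loss_sketch(preds, bbox, eps=1e-7):
    # intersection
    ixmin = paddle.maximum(preds[:, 0], bbox[:, 0])
    iymin = paddle.maximum(preds[:, 1], bbox[:, 1])
    ixmax = paddle.minimum(preds[:, 2], bbox[:, 2])
    iymax = paddle.minimum(preds[:, 3], bbox[:, 3])
    inters = (ixmax - ixmin).clip(0) * (iymax - iymin).clip(0)
    # union
    area_p = (preds[:, 2] - preds[:, 0]) * (preds[:, 3] - preds[:, 1])
    area_g = (bbox[:, 2] - bbox[:, 0]) * (bbox[:, 3] - bbox[:, 1])
    uni = area_p + area_g - inters + eps
    ious = inters / uni
    # smallest enclosing box; GIoU = IoU - (enclose - union) / enclose
    exmin = paddle.minimum(preds[:, 0], bbox[:, 0])
    eymin = paddle.minimum(preds[:, 1], bbox[:, 1])
    exmax = paddle.maximum(preds[:, 2], bbox[:, 2])
    eymax = paddle.maximum(preds[:, 3], bbox[:, 3])
    enclose = (exmax - exmin) * (eymax - eymin) + eps
    gious = ious - (enclose - uni) / enclose
    return (1 - gious).mean()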
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This code is adapted from:
https://github.com/JiaquanYe/TableMASTER-mmocr/tree/master/mmocr/models/textrecog/losses
"""
import paddle
from paddle import nn
class TableMasterLoss(nn.Layer):
def __init__(self, ignore_index=-1):
super(TableMasterLoss, self).__init__()
self.structure_loss = nn.CrossEntropyLoss(
ignore_index=ignore_index, reduction='mean')
self.box_loss = nn.L1Loss(reduction='sum')
self.eps = 1e-12
def forward(self, predicts, batch):
# structure_loss
structure_probs = predicts['structure_probs']
structure_targets = batch[1]
structure_targets = structure_targets[:, 1:]
structure_probs = structure_probs.reshape(
[-1, structure_probs.shape[-1]])
structure_targets = structure_targets.reshape([-1])
structure_loss = self.structure_loss(structure_probs, structure_targets)
structure_loss = structure_loss.mean()
losses = dict(structure_loss=structure_loss)
# box loss
bboxes_preds = predicts['loc_preds']
bboxes_targets = batch[2][:, 1:, :]
bbox_masks = batch[3][:, 1:]
# mask empty-bbox or non-bbox structure token's bbox.
masked_bboxes_preds = bboxes_preds * bbox_masks
masked_bboxes_targets = bboxes_targets * bbox_masks
# horizon loss (x and width)
horizon_sum_loss = self.box_loss(masked_bboxes_preds[:, :, 0::2],
masked_bboxes_targets[:, :, 0::2])
horizon_loss = horizon_sum_loss / (bbox_masks.sum() + self.eps)
# vertical loss (y and height)
vertical_sum_loss = self.box_loss(masked_bboxes_preds[:, :, 1::2],
masked_bboxes_targets[:, :, 1::2])
vertical_loss = vertical_sum_loss / (bbox_masks.sum() + self.eps)
horizon_loss = horizon_loss.mean()
vertical_loss = vertical_loss.mean()
all_loss = structure_loss + horizon_loss + vertical_loss
losses.update({
'loss': all_loss,
'horizon_bbox_loss': horizon_loss,
'vertical_bbox_loss': vertical_loss
})
return losses
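A hedged shape check (dummy data, illustrative sizes) for TableMasterLoss: the structure branch is a token-level cross entropy after the leading token is stripped, and the box branch is a masked L1 loss split into the (x, w) and (y, h) coordinate pairs:

import paddle

loss_fn = TableMasterLoss(ignore_index=42)
structure_probs = paddle.rand([2, 10, 43])                 # [N, T, num_structure_tokens]
structure_targets = paddle.randint(0, 42, [2, 11])         # leading token removed inside forward
loc_preds = paddle.rand([2, 10, 4])                        # normalized x1, y1, x2, y2
bboxes = paddle.rand([2, 11, 4])
bbox_masks = paddle.ones([2, 11, 1])                       # 1 where the structure token owns a bbox
losses = loss_fn({'structure_probs': structure_probs, 'loc_preds': loc_preds},
                 [None, structure_targets, bboxes, bbox_masks])
print(sorted(losses.keys()))
# ['horizon_bbox_loss', 'loss', 'structure_loss', 'vertical_bbox_loss']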
...@@ -27,8 +27,8 @@ class VQASerTokenLayoutLMLoss(nn.Layer): ...@@ -27,8 +27,8 @@ class VQASerTokenLayoutLMLoss(nn.Layer):
self.ignore_index = self.loss_class.ignore_index self.ignore_index = self.loss_class.ignore_index
def forward(self, predicts, batch): def forward(self, predicts, batch):
labels = batch[1] labels = batch[5]
attention_mask = batch[4] attention_mask = batch[2]
if attention_mask is not None: if attention_mask is not None:
active_loss = attention_mask.reshape([-1, ]) == 1 active_loss = attention_mask.reshape([-1, ]) == 1
active_outputs = predicts.reshape( active_outputs = predicts.reshape(
......
...@@ -83,14 +83,10 @@ class DetectionIoUEvaluator(object): ...@@ -83,14 +83,10 @@ class DetectionIoUEvaluator(object):
evaluationLog = "" evaluationLog = ""
# print(len(gt))
for n in range(len(gt)): for n in range(len(gt)):
points = gt[n]['points'] points = gt[n]['points']
# transcription = gt[n]['text']
dontCare = gt[n]['ignore'] dontCare = gt[n]['ignore']
# points = Polygon(points) if not Polygon(points).is_valid:
# points = points.buffer(0)
if not Polygon(points).is_valid or not Polygon(points).is_simple:
continue continue
gtPol = points gtPol = points
...@@ -105,9 +101,7 @@ class DetectionIoUEvaluator(object): ...@@ -105,9 +101,7 @@ class DetectionIoUEvaluator(object):
for n in range(len(pred)): for n in range(len(pred)):
points = pred[n]['points'] points = pred[n]['points']
# points = Polygon(points) if not Polygon(points).is_valid:
# points = points.buffer(0)
if not Polygon(points).is_valid or not Polygon(points).is_simple:
continue continue
detPol = points detPol = points
...@@ -191,8 +185,6 @@ class DetectionIoUEvaluator(object): ...@@ -191,8 +185,6 @@ class DetectionIoUEvaluator(object):
methodHmean = 0 if methodRecall + methodPrecision == 0 else 2 * \ methodHmean = 0 if methodRecall + methodPrecision == 0 else 2 * \
methodRecall * methodPrecision / ( methodRecall * methodPrecision / (
methodRecall + methodPrecision) methodRecall + methodPrecision)
# print(methodRecall, methodPrecision, methodHmean)
# sys.exit(-1)
methodMetrics = { methodMetrics = {
'precision': methodPrecision, 'precision': methodPrecision,
'recall': methodRecall, 'recall': methodRecall,
......
...@@ -12,29 +12,30 @@ ...@@ -12,29 +12,30 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import numpy as np import numpy as np
from ppocr.metrics.det_metric import DetMetric
class TableMetric(object): class TableStructureMetric(object):
def __init__(self, main_indicator='acc', **kwargs): def __init__(self, main_indicator='acc', eps=1e-6, **kwargs):
self.main_indicator = main_indicator self.main_indicator = main_indicator
self.eps = 1e-5 self.eps = eps
self.reset() self.reset()
def __call__(self, pred, batch, *args, **kwargs): def __call__(self, pred_label, batch=None, *args, **kwargs):
structure_probs = pred['structure_probs'].numpy() preds, labels = pred_label
structure_labels = batch[1] pred_structure_batch_list = preds['structure_batch_list']
gt_structure_batch_list = labels['structure_batch_list']
correct_num = 0 correct_num = 0
all_num = 0 all_num = 0
structure_probs = np.argmax(structure_probs, axis=2) for (pred, pred_conf), target in zip(pred_structure_batch_list,
structure_labels = structure_labels[:, 1:] gt_structure_batch_list):
batch_size = structure_probs.shape[0] pred_str = ''.join(pred)
for bno in range(batch_size): target_str = ''.join(target)
all_num += 1 if pred_str == target_str:
if (structure_probs[bno] == structure_labels[bno]).all():
correct_num += 1 correct_num += 1
all_num += 1
self.correct_num += correct_num self.correct_num += correct_num
self.all_num += all_num self.all_num += all_num
return {'acc': correct_num * 1.0 / (all_num + self.eps), }
def get_metric(self): def get_metric(self):
""" """
...@@ -49,3 +50,89 @@ class TableMetric(object): ...@@ -49,3 +50,89 @@ class TableMetric(object):
def reset(self): def reset(self):
self.correct_num = 0 self.correct_num = 0
self.all_num = 0 self.all_num = 0
self.len_acc_num = 0
self.token_nums = 0
self.anys_dict = dict()
class TableMetric(object):
def __init__(self,
main_indicator='acc',
compute_bbox_metric=False,
point_num=2,
**kwargs):
"""
@param main_indicator: main indicator used to select the best model
@param compute_bbox_metric: whether to additionally compute the bbox detection metric
@param point_num: number of box corner points (2 for [x1, y1, x2, y2], 4 for quadrilaterals)
@param kwargs:
"""
self.structure_metric = TableStructureMetric()
self.bbox_metric = DetMetric() if compute_bbox_metric else None
self.main_indicator = main_indicator
self.point_num = point_num
self.reset()
def __call__(self, pred_label, batch=None, *args, **kwargs):
self.structure_metric(pred_label)
if self.bbox_metric is not None:
self.bbox_metric(*self.prepare_bbox_metric_input(pred_label))
def prepare_bbox_metric_input(self, pred_label):
pred_bbox_batch_list = []
gt_ignore_tags_batch_list = []
gt_bbox_batch_list = []
preds, labels = pred_label
batch_num = len(preds['bbox_batch_list'])
for batch_idx in range(batch_num):
# pred
pred_bbox_list = [
self.format_box(pred_box)
for pred_box in preds['bbox_batch_list'][batch_idx]
]
pred_bbox_batch_list.append({'points': pred_bbox_list})
# gt
gt_bbox_list = []
gt_ignore_tags_list = []
for gt_box in labels['bbox_batch_list'][batch_idx]:
gt_bbox_list.append(self.format_box(gt_box))
gt_ignore_tags_list.append(0)
gt_bbox_batch_list.append(gt_bbox_list)
gt_ignore_tags_batch_list.append(gt_ignore_tags_list)
return [
pred_bbox_batch_list,
[0, 0, gt_bbox_batch_list, gt_ignore_tags_batch_list]
]
def get_metric(self):
structure_metric = self.structure_metric.get_metric()
if self.bbox_metric is None:
return structure_metric
bbox_metric = self.bbox_metric.get_metric()
if self.main_indicator == self.bbox_metric.main_indicator:
output = bbox_metric
for sub_key in structure_metric:
output["structure_metric_{}".format(
sub_key)] = structure_metric[sub_key]
else:
output = structure_metric
for sub_key in bbox_metric:
output["bbox_metric_{}".format(sub_key)] = bbox_metric[sub_key]
return output
def reset(self):
self.structure_metric.reset()
if self.bbox_metric is not None:
self.bbox_metric.reset()
def format_box(self, box):
if self.point_num == 2:
x1, y1, x2, y2 = box
box = [[x1, y1], [x2, y1], [x2, y2], [x1, y2]]
elif self.point_num == 4:
x1, y1, x2, y2, x3, y3, x4, y4 = box
box = [[x1, y1], [x2, y2], [x3, y3], [x4, y4]]
return box
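A small usage sketch of the new TableStructureMetric (assuming get_metric, elided in the diff above, returns the accumulated accuracy as the other ppocr metrics do): a sample only counts as correct when the predicted structure-token sequence matches the target exactly:

metric = TableStructureMetric()
preds = {'structure_batch_list': [(['<tr>', '<td>', '</td>', '</tr>'], 0.9),
                                  (['<td>', '</td>'], 0.8)]}
labels = {'structure_batch_list': [['<tr>', '<td>', '</td>', '</tr>'],
                                   ['<td></td>', '<td>']]}
metric((preds, labels))
print(metric.get_metric())   # expected roughly {'acc': 0.5}; one of the two samples matches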
...@@ -37,23 +37,26 @@ class VQAReTokenMetric(object): ...@@ -37,23 +37,26 @@ class VQAReTokenMetric(object):
gt_relations = [] gt_relations = []
for b in range(len(self.relations_list)): for b in range(len(self.relations_list)):
rel_sent = [] rel_sent = []
for head, tail in zip(self.relations_list[b]["head"], if "head" in self.relations_list[b]:
self.relations_list[b]["tail"]): for head, tail in zip(self.relations_list[b]["head"],
rel = {} self.relations_list[b]["tail"]):
rel["head_id"] = head rel = {}
rel["head"] = (self.entities_list[b]["start"][rel["head_id"]], rel["head_id"] = head
self.entities_list[b]["end"][rel["head_id"]]) rel["head"] = (
rel["head_type"] = self.entities_list[b]["label"][rel[ self.entities_list[b]["start"][rel["head_id"]],
"head_id"]] self.entities_list[b]["end"][rel["head_id"]])
rel["head_type"] = self.entities_list[b]["label"][rel[
rel["tail_id"] = tail "head_id"]]
rel["tail"] = (self.entities_list[b]["start"][rel["tail_id"]],
self.entities_list[b]["end"][rel["tail_id"]]) rel["tail_id"] = tail
rel["tail_type"] = self.entities_list[b]["label"][rel[ rel["tail"] = (
"tail_id"]] self.entities_list[b]["start"][rel["tail_id"]],
self.entities_list[b]["end"][rel["tail_id"]])
rel["type"] = 1 rel["tail_type"] = self.entities_list[b]["label"][rel[
rel_sent.append(rel) "tail_id"]]
rel["type"] = 1
rel_sent.append(rel)
gt_relations.append(rel_sent) gt_relations.append(rel_sent)
re_metrics = self.re_score( re_metrics = self.re_score(
self.pred_relations_list, gt_relations, mode="boundaries") self.pred_relations_list, gt_relations, mode="boundaries")
......
...@@ -18,9 +18,13 @@ __all__ = ["build_backbone"] ...@@ -18,9 +18,13 @@ __all__ = ["build_backbone"]
def build_backbone(config, model_type): def build_backbone(config, model_type):
if model_type == "det" or model_type == "table": if model_type == "det" or model_type == "table":
from .det_mobilenet_v3 import MobileNetV3 from .det_mobilenet_v3 import MobileNetV3
from .det_resnet_vd import ResNet from .det_resnet import ResNet
from .det_resnet_vd import ResNet_vd
from .det_resnet_vd_sast import ResNet_SAST from .det_resnet_vd_sast import ResNet_SAST
support_dict = ["MobileNetV3", "ResNet", "ResNet_SAST"] support_dict = ["MobileNetV3", "ResNet", "ResNet_vd", "ResNet_SAST"]
if model_type == "table":
from .table_master_resnet import TableResNetExtra
support_dict.append('TableResNetExtra')
elif model_type == "rec" or model_type == "cls": elif model_type == "rec" or model_type == "cls":
from .rec_mobilenet_v3 import MobileNetV3 from .rec_mobilenet_v3 import MobileNetV3
from .rec_resnet_vd import ResNet from .rec_resnet_vd import ResNet
...@@ -28,6 +32,7 @@ def build_backbone(config, model_type): ...@@ -28,6 +32,7 @@ def build_backbone(config, model_type):
from .rec_mv1_enhance import MobileNetV1Enhance from .rec_mv1_enhance import MobileNetV1Enhance
from .rec_nrtr_mtb import MTB from .rec_nrtr_mtb import MTB
from .rec_resnet_31 import ResNet31 from .rec_resnet_31 import ResNet31
from .rec_resnet_32 import ResNet32
from .rec_resnet_45 import ResNet45 from .rec_resnet_45 import ResNet45
from .rec_resnet_aster import ResNet_ASTER from .rec_resnet_aster import ResNet_ASTER
from .rec_micronet import MicroNet from .rec_micronet import MicroNet
...@@ -37,7 +42,7 @@ def build_backbone(config, model_type): ...@@ -37,7 +42,7 @@ def build_backbone(config, model_type):
support_dict = [ support_dict = [
'MobileNetV1Enhance', 'MobileNetV3', 'ResNet', 'ResNetFPN', 'MTB', 'MobileNetV1Enhance', 'MobileNetV3', 'ResNet', 'ResNetFPN', 'MTB',
'ResNet31', 'ResNet45', 'ResNet_ASTER', 'MicroNet', 'ResNet31', 'ResNet45', 'ResNet_ASTER', 'MicroNet',
'EfficientNetb3_PREN', 'SVTRNet', 'ViTSTR' 'EfficientNetb3_PREN', 'SVTRNet', 'ViTSTR', 'ResNet32'
] ]
elif model_type == 'e2e': elif model_type == 'e2e':
from .e2e_resnet_vd_pg import ResNet from .e2e_resnet_vd_pg import ResNet
......
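Assuming build_backbone's usual dispatch (not shown in this diff), where the 'name' key picks a class out of support_dict and the remaining keys are forwarded as keyword arguments, the newly registered table backbone could be selected like this (layer counts illustrative):

config = {'name': 'TableResNetExtra', 'layers': [1, 2, 5, 3]}
backbone = build_backbone(config, model_type='table')
print(backbone.out_channels)   # [256, 256, 512]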
# copyright (c) 2022 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
import paddle
from paddle import ParamAttr
import paddle.nn as nn
import paddle.nn.functional as F
from paddle.nn import Conv2D, BatchNorm, Linear, Dropout
from paddle.nn import AdaptiveAvgPool2D, MaxPool2D, AvgPool2D
from paddle.nn.initializer import Uniform
import math
from paddle.vision.ops import DeformConv2D
from paddle.regularizer import L2Decay
from paddle.nn.initializer import Normal, Constant, XavierUniform
from .det_resnet_vd import DeformableConvV2, ConvBNLayer
class BottleneckBlock(nn.Layer):
def __init__(self,
num_channels,
num_filters,
stride,
shortcut=True,
is_dcn=False):
super(BottleneckBlock, self).__init__()
self.conv0 = ConvBNLayer(
in_channels=num_channels,
out_channels=num_filters,
kernel_size=1,
act="relu", )
self.conv1 = ConvBNLayer(
in_channels=num_filters,
out_channels=num_filters,
kernel_size=3,
stride=stride,
act="relu",
is_dcn=is_dcn,
dcn_groups=1, )
self.conv2 = ConvBNLayer(
in_channels=num_filters,
out_channels=num_filters * 4,
kernel_size=1,
act=None, )
if not shortcut:
self.short = ConvBNLayer(
in_channels=num_channels,
out_channels=num_filters * 4,
kernel_size=1,
stride=stride, )
self.shortcut = shortcut
self._num_channels_out = num_filters * 4
def forward(self, inputs):
y = self.conv0(inputs)
conv1 = self.conv1(y)
conv2 = self.conv2(conv1)
if self.shortcut:
short = inputs
else:
short = self.short(inputs)
y = paddle.add(x=short, y=conv2)
y = F.relu(y)
return y
class BasicBlock(nn.Layer):
def __init__(self,
num_channels,
num_filters,
stride,
shortcut=True,
name=None):
super(BasicBlock, self).__init__()
self.stride = stride
self.conv0 = ConvBNLayer(
in_channels=num_channels,
out_channels=num_filters,
kernel_size=3,
stride=stride,
act="relu")
self.conv1 = ConvBNLayer(
in_channels=num_filters,
out_channels=num_filters,
kernel_size=3,
act=None)
if not shortcut:
self.short = ConvBNLayer(
in_channels=num_channels,
out_channels=num_filters,
kernel_size=1,
stride=stride)
self.shortcut = shortcut
def forward(self, inputs):
y = self.conv0(inputs)
conv1 = self.conv1(y)
if self.shortcut:
short = inputs
else:
short = self.short(inputs)
y = paddle.add(x=short, y=conv1)
y = F.relu(y)
return y
class ResNet(nn.Layer):
def __init__(self,
in_channels=3,
layers=50,
out_indices=None,
dcn_stage=None):
super(ResNet, self).__init__()
self.layers = layers
self.input_image_channel = in_channels
supported_layers = [18, 34, 50, 101, 152]
assert layers in supported_layers, \
"supported layers are {} but input layer is {}".format(
supported_layers, layers)
if layers == 18:
depth = [2, 2, 2, 2]
elif layers == 34 or layers == 50:
depth = [3, 4, 6, 3]
elif layers == 101:
depth = [3, 4, 23, 3]
elif layers == 152:
depth = [3, 8, 36, 3]
num_channels = [64, 256, 512,
1024] if layers >= 50 else [64, 64, 128, 256]
num_filters = [64, 128, 256, 512]
self.dcn_stage = dcn_stage if dcn_stage is not None else [
False, False, False, False
]
self.out_indices = out_indices if out_indices is not None else [
0, 1, 2, 3
]
self.conv = ConvBNLayer(
in_channels=self.input_image_channel,
out_channels=64,
kernel_size=7,
stride=2,
act="relu", )
self.pool2d_max = MaxPool2D(
kernel_size=3,
stride=2,
padding=1, )
self.stages = []
self.out_channels = []
if layers >= 50:
for block in range(len(depth)):
shortcut = False
block_list = []
is_dcn = self.dcn_stage[block]
for i in range(depth[block]):
if layers in [101, 152] and block == 2:
if i == 0:
conv_name = "res" + str(block + 2) + "a"
else:
conv_name = "res" + str(block + 2) + "b" + str(i)
else:
conv_name = "res" + str(block + 2) + chr(97 + i)
bottleneck_block = self.add_sublayer(
conv_name,
BottleneckBlock(
num_channels=num_channels[block]
if i == 0 else num_filters[block] * 4,
num_filters=num_filters[block],
stride=2 if i == 0 and block != 0 else 1,
shortcut=shortcut,
is_dcn=is_dcn))
block_list.append(bottleneck_block)
shortcut = True
if block in self.out_indices:
self.out_channels.append(num_filters[block] * 4)
self.stages.append(nn.Sequential(*block_list))
else:
for block in range(len(depth)):
shortcut = False
block_list = []
for i in range(depth[block]):
conv_name = "res" + str(block + 2) + chr(97 + i)
basic_block = self.add_sublayer(
conv_name,
BasicBlock(
num_channels=num_channels[block]
if i == 0 else num_filters[block],
num_filters=num_filters[block],
stride=2 if i == 0 and block != 0 else 1,
shortcut=shortcut))
block_list.append(basic_block)
shortcut = True
if block in self.out_indices:
self.out_channels.append(num_filters[block])
self.stages.append(nn.Sequential(*block_list))
def forward(self, inputs):
y = self.conv(inputs)
y = self.pool2d_max(y)
out = []
for i, block in enumerate(self.stages):
y = block(y)
if i in self.out_indices:
out.append(y)
return out
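An illustrative shape check for the new det_resnet.ResNet (a plain, non-vd variant): with a 3x640x640 input, each returned stage is downsampled by a further factor of two and the ResNet-50 channel widths are 256/512/1024/2048:

import paddle

model = ResNet(in_channels=3, layers=50)
feats = model(paddle.rand([1, 3, 640, 640]))
print([f.shape for f in feats])
# roughly [[1, 256, 160, 160], [1, 512, 80, 80], [1, 1024, 40, 40], [1, 2048, 20, 20]]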
...@@ -25,7 +25,7 @@ from paddle.vision.ops import DeformConv2D ...@@ -25,7 +25,7 @@ from paddle.vision.ops import DeformConv2D
from paddle.regularizer import L2Decay from paddle.regularizer import L2Decay
from paddle.nn.initializer import Normal, Constant, XavierUniform from paddle.nn.initializer import Normal, Constant, XavierUniform
__all__ = ["ResNet"] __all__ = ["ResNet_vd", "ConvBNLayer", "DeformableConvV2"]
class DeformableConvV2(nn.Layer): class DeformableConvV2(nn.Layer):
...@@ -104,6 +104,7 @@ class ConvBNLayer(nn.Layer): ...@@ -104,6 +104,7 @@ class ConvBNLayer(nn.Layer):
kernel_size, kernel_size,
stride=1, stride=1,
groups=1, groups=1,
dcn_groups=1,
is_vd_mode=False, is_vd_mode=False,
act=None, act=None,
is_dcn=False): is_dcn=False):
...@@ -128,7 +129,7 @@ class ConvBNLayer(nn.Layer): ...@@ -128,7 +129,7 @@ class ConvBNLayer(nn.Layer):
kernel_size=kernel_size, kernel_size=kernel_size,
stride=stride, stride=stride,
padding=(kernel_size - 1) // 2, padding=(kernel_size - 1) // 2,
groups=2, #groups, groups=dcn_groups, #groups,
bias_attr=False) bias_attr=False)
self._batch_norm = nn.BatchNorm(out_channels, act=act) self._batch_norm = nn.BatchNorm(out_channels, act=act)
...@@ -162,7 +163,8 @@ class BottleneckBlock(nn.Layer): ...@@ -162,7 +163,8 @@ class BottleneckBlock(nn.Layer):
kernel_size=3, kernel_size=3,
stride=stride, stride=stride,
act='relu', act='relu',
is_dcn=is_dcn) is_dcn=is_dcn,
dcn_groups=2)
self.conv2 = ConvBNLayer( self.conv2 = ConvBNLayer(
in_channels=out_channels, in_channels=out_channels,
out_channels=out_channels * 4, out_channels=out_channels * 4,
...@@ -238,14 +240,14 @@ class BasicBlock(nn.Layer): ...@@ -238,14 +240,14 @@ class BasicBlock(nn.Layer):
return y return y
class ResNet(nn.Layer): class ResNet_vd(nn.Layer):
def __init__(self, def __init__(self,
in_channels=3, in_channels=3,
layers=50, layers=50,
dcn_stage=None, dcn_stage=None,
out_indices=None, out_indices=None,
**kwargs): **kwargs):
super(ResNet, self).__init__() super(ResNet_vd, self).__init__()
self.layers = layers self.layers = layers
supported_layers = [18, 34, 50, 101, 152, 200] supported_layers = [18, 34, 50, 101, 152, 200]
...@@ -321,7 +323,6 @@ class ResNet(nn.Layer): ...@@ -321,7 +323,6 @@ class ResNet(nn.Layer):
for block in range(len(depth)): for block in range(len(depth)):
block_list = [] block_list = []
shortcut = False shortcut = False
# is_dcn = self.dcn_stage[block]
for i in range(depth[block]): for i in range(depth[block]):
basic_block = self.add_sublayer( basic_block = self.add_sublayer(
'bb_%d_%d' % (block, i), 'bb_%d_%d' % (block, i),
......
# copyright (c) 2022 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This code is adapted from:
https://github.com/hikopensource/DAVAR-Lab-OCR/davarocr/davar_rcg/models/backbones/ResNet32.py
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import paddle.nn as nn
__all__ = ["ResNet32"]
conv_weight_attr = nn.initializer.KaimingNormal()
class ResNet32(nn.Layer):
"""
Feature Extractor is proposed in FAN Ref [1]
Ref [1]: Focusing Attention: Towards Accurate Text Recognition in Natural Images ICCV-2017
"""
def __init__(self, in_channels, out_channels=512):
"""
Args:
in_channels (int): input channel
out_channels (int): output channel
"""
super(ResNet32, self).__init__()
self.out_channels = out_channels
self.ConvNet = ResNet(in_channels, out_channels, BasicBlock, [1, 2, 5, 3])
def forward(self, inputs):
"""
Args:
inputs: input feature
Returns:
output feature
"""
return self.ConvNet(inputs)
class BasicBlock(nn.Layer):
"""Res-net Basic Block"""
expansion = 1
def __init__(self, inplanes, planes,
stride=1, downsample=None,
norm_type='BN', **kwargs):
"""
Args:
inplanes (int): input channel
planes (int): channels of the middle feature
stride (int): stride of the convolution
downsample (nn.Layer): optional down-sampling layer applied to the residual branch
norm_type (str): type of the normalization
**kwargs (None): backup parameter
"""
super(BasicBlock, self).__init__()
self.conv1 = self._conv3x3(inplanes, planes)
self.bn1 = nn.BatchNorm2D(planes)
self.conv2 = self._conv3x3(planes, planes)
self.bn2 = nn.BatchNorm2D(planes)
self.relu = nn.ReLU()
self.downsample = downsample
self.stride = stride
def _conv3x3(self, in_planes, out_planes, stride=1):
"""
Args:
in_planes (int): input channel
out_planes (int): channels of the middle feature
stride (int): stride of the convolution
Returns:
nn.Layer: Conv2D with kernel = 3
"""
return nn.Conv2D(in_planes, out_planes,
kernel_size=3, stride=stride,
padding=1, weight_attr=conv_weight_attr,
bias_attr=False)
def forward(self, x):
residual = x
out = self.conv1(x)
out = self.bn1(out)
out = self.relu(out)
out = self.conv2(out)
out = self.bn2(out)
if self.downsample is not None:
residual = self.downsample(x)
out += residual
out = self.relu(out)
return out
class ResNet(nn.Layer):
"""Res-Net network structure"""
def __init__(self, input_channel,
output_channel, block, layers):
"""
Args:
input_channel (int): input channel
output_channel (int): output channel
block (BasicBlock): convolution block
layers (list): layers of the block
"""
super(ResNet, self).__init__()
self.output_channel_block = [int(output_channel / 4),
int(output_channel / 2),
output_channel,
output_channel]
self.inplanes = int(output_channel / 8)
self.conv0_1 = nn.Conv2D(input_channel, int(output_channel / 16),
kernel_size=3, stride=1,
padding=1,
weight_attr=conv_weight_attr,
bias_attr=False)
self.bn0_1 = nn.BatchNorm2D(int(output_channel / 16))
self.conv0_2 = nn.Conv2D(int(output_channel / 16), self.inplanes,
kernel_size=3, stride=1,
padding=1,
weight_attr=conv_weight_attr,
bias_attr=False)
self.bn0_2 = nn.BatchNorm2D(self.inplanes)
self.relu = nn.ReLU()
self.maxpool1 = nn.MaxPool2D(kernel_size=2, stride=2, padding=0)
self.layer1 = self._make_layer(block,
self.output_channel_block[0],
layers[0])
self.conv1 = nn.Conv2D(self.output_channel_block[0],
self.output_channel_block[0],
kernel_size=3, stride=1,
padding=1,
weight_attr=conv_weight_attr,
bias_attr=False)
self.bn1 = nn.BatchNorm2D(self.output_channel_block[0])
self.maxpool2 = nn.MaxPool2D(kernel_size=2, stride=2, padding=0)
self.layer2 = self._make_layer(block,
self.output_channel_block[1],
layers[1], stride=1)
self.conv2 = nn.Conv2D(self.output_channel_block[1],
self.output_channel_block[1],
kernel_size=3, stride=1,
padding=1,
weight_attr=conv_weight_attr,
bias_attr=False,)
self.bn2 = nn.BatchNorm2D(self.output_channel_block[1])
self.maxpool3 = nn.MaxPool2D(kernel_size=2,
stride=(2, 1),
padding=(0, 1))
self.layer3 = self._make_layer(block, self.output_channel_block[2],
layers[2], stride=1)
self.conv3 = nn.Conv2D(self.output_channel_block[2],
self.output_channel_block[2],
kernel_size=3, stride=1,
padding=1,
weight_attr=conv_weight_attr,
bias_attr=False)
self.bn3 = nn.BatchNorm2D(self.output_channel_block[2])
self.layer4 = self._make_layer(block, self.output_channel_block[3],
layers[3], stride=1)
self.conv4_1 = nn.Conv2D(self.output_channel_block[3],
self.output_channel_block[3],
kernel_size=2, stride=(2, 1),
padding=(0, 1),
weight_attr=conv_weight_attr,
bias_attr=False)
self.bn4_1 = nn.BatchNorm2D(self.output_channel_block[3])
self.conv4_2 = nn.Conv2D(self.output_channel_block[3],
self.output_channel_block[3],
kernel_size=2, stride=1,
padding=0,
weight_attr=conv_weight_attr,
bias_attr=False)
self.bn4_2 = nn.BatchNorm2D(self.output_channel_block[3])
def _make_layer(self, block, planes, blocks, stride=1):
"""
Args:
block (block): convolution block
planes (int): input channels
blocks (list): layers of the block
stride (int): stride of the convolution
Returns:
nn.Sequential: the combination of the convolution block
"""
downsample = None
if stride != 1 or self.inplanes != planes * block.expansion:
downsample = nn.Sequential(
nn.Conv2D(self.inplanes, planes * block.expansion,
kernel_size=1, stride=stride,
weight_attr=conv_weight_attr,
bias_attr=False),
nn.BatchNorm2D(planes * block.expansion),
)
layers = list()
layers.append(block(self.inplanes, planes, stride, downsample))
self.inplanes = planes * block.expansion
for _ in range(1, blocks):
layers.append(block(self.inplanes, planes))
return nn.Sequential(*layers)
def forward(self, x):
x = self.conv0_1(x)
x = self.bn0_1(x)
x = self.relu(x)
x = self.conv0_2(x)
x = self.bn0_2(x)
x = self.relu(x)
x = self.maxpool1(x)
x = self.layer1(x)
x = self.conv1(x)
x = self.bn1(x)
x = self.relu(x)
x = self.maxpool2(x)
x = self.layer2(x)
x = self.conv2(x)
x = self.bn2(x)
x = self.relu(x)
x = self.maxpool3(x)
x = self.layer3(x)
x = self.conv3(x)
x = self.bn3(x)
x = self.relu(x)
x = self.layer4(x)
x = self.conv4_1(x)
x = self.bn4_1(x)
x = self.relu(x)
x = self.conv4_2(x)
x = self.bn4_2(x)
x = self.relu(x)
return x
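An illustrative shape check for ResNet32, assuming the 32x100 crops commonly fed to recognition models; the asymmetric poolings and strides keep the width nearly intact while collapsing the height to 1:

import paddle

backbone = ResNet32(in_channels=3, out_channels=512)
feat = backbone(paddle.rand([1, 3, 32, 100]))
print(feat.shape)   # expected for this input: [1, 512, 1, 26]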
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This code is adapted from:
https://github.com/JiaquanYe/TableMASTER-mmocr/blob/master/mmocr/models/textrecog/backbones/table_resnet_extra.py
"""
import paddle
import paddle.nn as nn
import paddle.nn.functional as F
class BasicBlock(nn.Layer):
expansion = 1
def __init__(self,
inplanes,
planes,
stride=1,
downsample=None,
gcb_config=None):
super(BasicBlock, self).__init__()
self.conv1 = nn.Conv2D(
inplanes,
planes,
kernel_size=3,
stride=stride,
padding=1,
bias_attr=False)
self.bn1 = nn.BatchNorm2D(planes, momentum=0.9)
self.relu = nn.ReLU()
self.conv2 = nn.Conv2D(
planes, planes, kernel_size=3, stride=1, padding=1, bias_attr=False)
self.bn2 = nn.BatchNorm2D(planes, momentum=0.9)
self.downsample = downsample
self.stride = stride
self.gcb_config = gcb_config
if self.gcb_config is not None:
gcb_ratio = gcb_config['ratio']
gcb_headers = gcb_config['headers']
att_scale = gcb_config['att_scale']
fusion_type = gcb_config['fusion_type']
self.context_block = MultiAspectGCAttention(
inplanes=planes,
ratio=gcb_ratio,
headers=gcb_headers,
att_scale=att_scale,
fusion_type=fusion_type)
def forward(self, x):
residual = x
out = self.conv1(x)
out = self.bn1(out)
out = self.relu(out)
out = self.conv2(out)
out = self.bn2(out)
if self.gcb_config is not None:
out = self.context_block(out)
if self.downsample is not None:
residual = self.downsample(x)
out += residual
out = self.relu(out)
return out
def get_gcb_config(gcb_config, layer):
if gcb_config is None or not gcb_config['layers'][layer]:
return None
else:
return gcb_config
class TableResNetExtra(nn.Layer):
def __init__(self, layers, in_channels=3, gcb_config=None):
assert len(layers) >= 4
super(TableResNetExtra, self).__init__()
self.inplanes = 128
self.conv1 = nn.Conv2D(
in_channels,
64,
kernel_size=3,
stride=1,
padding=1,
bias_attr=False)
self.bn1 = nn.BatchNorm2D(64)
self.relu1 = nn.ReLU()
self.conv2 = nn.Conv2D(
64, 128, kernel_size=3, stride=1, padding=1, bias_attr=False)
self.bn2 = nn.BatchNorm2D(128)
self.relu2 = nn.ReLU()
self.maxpool1 = nn.MaxPool2D(kernel_size=2, stride=2)
self.layer1 = self._make_layer(
BasicBlock,
256,
layers[0],
stride=1,
gcb_config=get_gcb_config(gcb_config, 0))
self.conv3 = nn.Conv2D(
256, 256, kernel_size=3, stride=1, padding=1, bias_attr=False)
self.bn3 = nn.BatchNorm2D(256)
self.relu3 = nn.ReLU()
self.maxpool2 = nn.MaxPool2D(kernel_size=2, stride=2)
self.layer2 = self._make_layer(
BasicBlock,
256,
layers[1],
stride=1,
gcb_config=get_gcb_config(gcb_config, 1))
self.conv4 = nn.Conv2D(
256, 256, kernel_size=3, stride=1, padding=1, bias_attr=False)
self.bn4 = nn.BatchNorm2D(256)
self.relu4 = nn.ReLU()
self.maxpool3 = nn.MaxPool2D(kernel_size=2, stride=2)
self.layer3 = self._make_layer(
BasicBlock,
512,
layers[2],
stride=1,
gcb_config=get_gcb_config(gcb_config, 2))
self.conv5 = nn.Conv2D(
512, 512, kernel_size=3, stride=1, padding=1, bias_attr=False)
self.bn5 = nn.BatchNorm2D(512)
self.relu5 = nn.ReLU()
self.layer4 = self._make_layer(
BasicBlock,
512,
layers[3],
stride=1,
gcb_config=get_gcb_config(gcb_config, 3))
self.conv6 = nn.Conv2D(
512, 512, kernel_size=3, stride=1, padding=1, bias_attr=False)
self.bn6 = nn.BatchNorm2D(512)
self.relu6 = nn.ReLU()
self.out_channels = [256, 256, 512]
def _make_layer(self, block, planes, blocks, stride=1, gcb_config=None):
downsample = None
if stride != 1 or self.inplanes != planes * block.expansion:
downsample = nn.Sequential(
nn.Conv2D(
self.inplanes,
planes * block.expansion,
kernel_size=1,
stride=stride,
bias_attr=False),
nn.BatchNorm2D(planes * block.expansion), )
layers = []
layers.append(
block(
self.inplanes,
planes,
stride,
downsample,
gcb_config=gcb_config))
self.inplanes = planes * block.expansion
for _ in range(1, blocks):
layers.append(block(self.inplanes, planes))
return nn.Sequential(*layers)
def forward(self, x):
f = []
x = self.conv1(x)
x = self.bn1(x)
x = self.relu1(x)
x = self.conv2(x)
x = self.bn2(x)
x = self.relu2(x)
x = self.maxpool1(x)
x = self.layer1(x)
x = self.conv3(x)
x = self.bn3(x)
x = self.relu3(x)
f.append(x)
x = self.maxpool2(x)
x = self.layer2(x)
x = self.conv4(x)
x = self.bn4(x)
x = self.relu4(x)
f.append(x)
x = self.maxpool3(x)
x = self.layer3(x)
x = self.conv5(x)
x = self.bn5(x)
x = self.relu5(x)
x = self.layer4(x)
x = self.conv6(x)
x = self.bn6(x)
x = self.relu6(x)
f.append(x)
return f
class MultiAspectGCAttention(nn.Layer):
def __init__(self,
inplanes,
ratio,
headers,
pooling_type='att',
att_scale=False,
fusion_type='channel_add'):
super(MultiAspectGCAttention, self).__init__()
assert pooling_type in ['avg', 'att']
assert fusion_type in ['channel_add', 'channel_mul', 'channel_concat']
assert inplanes % headers == 0 and inplanes >= 8  # inplanes must be divisible by headers
self.headers = headers
self.inplanes = inplanes
self.ratio = ratio
self.planes = int(inplanes * ratio)
self.pooling_type = pooling_type
self.fusion_type = fusion_type
self.att_scale = False
self.single_header_inplanes = int(inplanes / headers)
if pooling_type == 'att':
self.conv_mask = nn.Conv2D(
self.single_header_inplanes, 1, kernel_size=1)
self.softmax = nn.Softmax(axis=2)
else:
self.avg_pool = nn.AdaptiveAvgPool2D(1)
if fusion_type == 'channel_add':
self.channel_add_conv = nn.Sequential(
nn.Conv2D(
self.inplanes, self.planes, kernel_size=1),
nn.LayerNorm([self.planes, 1, 1]),
nn.ReLU(),
nn.Conv2D(
self.planes, self.inplanes, kernel_size=1))
elif fusion_type == 'channel_concat':
self.channel_concat_conv = nn.Sequential(
nn.Conv2D(
self.inplanes, self.planes, kernel_size=1),
nn.LayerNorm([self.planes, 1, 1]),
nn.ReLU(),
nn.Conv2D(
self.planes, self.inplanes, kernel_size=1))
# for concat
self.cat_conv = nn.Conv2D(
2 * self.inplanes, self.inplanes, kernel_size=1)
elif fusion_type == 'channel_mul':
self.channel_mul_conv = nn.Sequential(
nn.Conv2D(
self.inplanes, self.planes, kernel_size=1),
nn.LayerNorm([self.planes, 1, 1]),
nn.ReLU(),
nn.Conv2D(
self.planes, self.inplanes, kernel_size=1))
def spatial_pool(self, x):
batch, channel, height, width = x.shape
if self.pooling_type == 'att':
# [N*headers, C', H , W] C = headers * C'
x = x.reshape([
batch * self.headers, self.single_header_inplanes, height, width
])
input_x = x
# [N*headers, C', H * W] C = headers * C'
# input_x = input_x.view(batch, channel, height * width)
input_x = input_x.reshape([
batch * self.headers, self.single_header_inplanes,
height * width
])
# [N*headers, 1, C', H * W]
input_x = input_x.unsqueeze(1)
# [N*headers, 1, H, W]
context_mask = self.conv_mask(x)
# [N*headers, 1, H * W]
context_mask = context_mask.reshape(
[batch * self.headers, 1, height * width])
# scale variance
if self.att_scale and self.headers > 1:
context_mask = context_mask / paddle.sqrt(
self.single_header_inplanes)
# [N*headers, 1, H * W]
context_mask = self.softmax(context_mask)
# [N*headers, 1, H * W, 1]
context_mask = context_mask.unsqueeze(-1)
# [N*headers, 1, C', 1] = [N*headers, 1, C', H * W] * [N*headers, 1, H * W, 1]
context = paddle.matmul(input_x, context_mask)
# [N, headers * C', 1, 1]
context = context.reshape(
[batch, self.headers * self.single_header_inplanes, 1, 1])
else:
# [N, C, 1, 1]
context = self.avg_pool(x)
return context
def forward(self, x):
# [N, C, 1, 1]
context = self.spatial_pool(x)
out = x
if self.fusion_type == 'channel_mul':
# [N, C, 1, 1]
channel_mul_term = F.sigmoid(self.channel_mul_conv(context))
out = out * channel_mul_term
elif self.fusion_type == 'channel_add':
# [N, C, 1, 1]
channel_add_term = self.channel_add_conv(context)
out = out + channel_add_term
else:
# [N, C, 1, 1]
channel_concat_term = self.channel_concat_conv(context)
# use concat
_, C1, _, _ = channel_concat_term.shape
N, C2, H, W = out.shape
out = paddle.concat(
[out, channel_concat_term.expand([-1, -1, H, W])], axis=1)
out = self.cat_conv(out)
out = F.layer_norm(out, [self.inplanes, H, W])
out = F.relu(out)
return out
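A hedged instantiation of the TableMASTER backbone; the gcb_config values below mirror a typical TableMaster setup (global context attention switched on for the last three stages) but are illustrative here:

import paddle

gcb_config = {
    'ratio': 0.0625,
    'headers': 1,
    'att_scale': False,
    'fusion_type': 'channel_add',
    'layers': [False, True, True, True],
}
backbone = TableResNetExtra(layers=[1, 2, 5, 3], in_channels=3, gcb_config=gcb_config)
feats = backbone(paddle.rand([1, 3, 480, 480]))
print([f.shape for f in feats])
# roughly [[1, 256, 240, 240], [1, 256, 120, 120], [1, 512, 60, 60]]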
...@@ -43,9 +43,11 @@ class NLPBaseModel(nn.Layer): ...@@ -43,9 +43,11 @@ class NLPBaseModel(nn.Layer):
super(NLPBaseModel, self).__init__() super(NLPBaseModel, self).__init__()
if checkpoints is not None: if checkpoints is not None:
self.model = model_class.from_pretrained(checkpoints) self.model = model_class.from_pretrained(checkpoints)
elif isinstance(pretrained, (str, )) and os.path.exists(pretrained):
self.model = model_class.from_pretrained(pretrained)
else: else:
pretrained_model_name = pretrained_model_dict[base_model_class] pretrained_model_name = pretrained_model_dict[base_model_class]
if pretrained: if pretrained is True:
base_model = base_model_class.from_pretrained( base_model = base_model_class.from_pretrained(
pretrained_model_name) pretrained_model_name)
else: else:
...@@ -74,9 +76,9 @@ class LayoutLMForSer(NLPBaseModel): ...@@ -74,9 +76,9 @@ class LayoutLMForSer(NLPBaseModel):
def forward(self, x): def forward(self, x):
x = self.model( x = self.model(
input_ids=x[0], input_ids=x[0],
bbox=x[2], bbox=x[1],
attention_mask=x[4], attention_mask=x[2],
token_type_ids=x[5], token_type_ids=x[3],
position_ids=None, position_ids=None,
output_hidden_states=False) output_hidden_states=False)
return x return x
...@@ -96,13 +98,15 @@ class LayoutLMv2ForSer(NLPBaseModel): ...@@ -96,13 +98,15 @@ class LayoutLMv2ForSer(NLPBaseModel):
def forward(self, x): def forward(self, x):
x = self.model( x = self.model(
input_ids=x[0], input_ids=x[0],
bbox=x[2], bbox=x[1],
image=x[3], attention_mask=x[2],
attention_mask=x[4], token_type_ids=x[3],
token_type_ids=x[5], image=x[4],
position_ids=None, position_ids=None,
head_mask=None, head_mask=None,
labels=None) labels=None)
if not self.training:
return x
return x[0] return x[0]
...@@ -120,13 +124,15 @@ class LayoutXLMForSer(NLPBaseModel): ...@@ -120,13 +124,15 @@ class LayoutXLMForSer(NLPBaseModel):
def forward(self, x): def forward(self, x):
x = self.model( x = self.model(
input_ids=x[0], input_ids=x[0],
bbox=x[2], bbox=x[1],
image=x[3], attention_mask=x[2],
attention_mask=x[4], token_type_ids=x[3],
token_type_ids=x[5], image=x[4],
position_ids=None, position_ids=None,
head_mask=None, head_mask=None,
labels=None) labels=None)
if not self.training:
return x
return x[0] return x[0]
...@@ -140,12 +146,12 @@ class LayoutLMv2ForRe(NLPBaseModel): ...@@ -140,12 +146,12 @@ class LayoutLMv2ForRe(NLPBaseModel):
x = self.model( x = self.model(
input_ids=x[0], input_ids=x[0],
bbox=x[1], bbox=x[1],
labels=None, attention_mask=x[2],
image=x[2], token_type_ids=x[3],
attention_mask=x[3], image=x[4],
token_type_ids=x[4],
position_ids=None, position_ids=None,
head_mask=None, head_mask=None,
labels=None,
entities=x[5], entities=x[5],
relations=x[6]) relations=x[6])
return x return x
...@@ -161,12 +167,12 @@ class LayoutXLMForRe(NLPBaseModel): ...@@ -161,12 +167,12 @@ class LayoutXLMForRe(NLPBaseModel):
x = self.model( x = self.model(
input_ids=x[0], input_ids=x[0],
bbox=x[1], bbox=x[1],
labels=None, attention_mask=x[2],
image=x[2], token_type_ids=x[3],
attention_mask=x[3], image=x[4],
token_type_ids=x[4],
position_ids=None, position_ids=None,
head_mask=None, head_mask=None,
labels=None,
entities=x[5], entities=x[5],
relations=x[6]) relations=x[6])
return x return x
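The index changes above imply a reordered VQA batch; an illustrative reading of the layout the updated SER/RE wrappers and the SER loss now consume (field names are labels for the indices, not code from the diff):

SER_BATCH = ['input_ids', 'bbox', 'attention_mask', 'token_type_ids', 'image', 'labels']
RE_BATCH = ['input_ids', 'bbox', 'attention_mask', 'token_type_ids', 'image', 'entities', 'relations']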
...@@ -33,6 +33,7 @@ def build_head(config): ...@@ -33,6 +33,7 @@ def build_head(config):
from .rec_aster_head import AsterHead from .rec_aster_head import AsterHead
from .rec_pren_head import PRENHead from .rec_pren_head import PRENHead
from .rec_multi_head import MultiHead from .rec_multi_head import MultiHead
from .rec_spin_att_head import SPINAttentionHead
from .rec_abinet_head import ABINetHead from .rec_abinet_head import ABINetHead
# cls head # cls head
...@@ -42,12 +43,13 @@ def build_head(config): ...@@ -42,12 +43,13 @@ def build_head(config):
from .kie_sdmgr_head import SDMGRHead from .kie_sdmgr_head import SDMGRHead
from .table_att_head import TableAttentionHead from .table_att_head import TableAttentionHead
from .table_master_head import TableMasterHead
support_dict = [ support_dict = [
'DBHead', 'PSEHead', 'FCEHead', 'EASTHead', 'SASTHead', 'CTCHead', 'DBHead', 'PSEHead', 'FCEHead', 'EASTHead', 'SASTHead', 'CTCHead',
'ClsHead', 'AttentionHead', 'SRNHead', 'PGHead', 'Transformer', 'ClsHead', 'AttentionHead', 'SRNHead', 'PGHead', 'Transformer',
'TableAttentionHead', 'SARHead', 'AsterHead', 'SDMGRHead', 'PRENHead', 'TableAttentionHead', 'SARHead', 'AsterHead', 'SDMGRHead', 'PRENHead',
'MultiHead', 'ABINetHead' 'MultiHead', 'ABINetHead', 'TableMasterHead', 'SPINAttentionHead'
] ]
#table head #table head
......
...@@ -273,7 +273,8 @@ def _get_length(logit): ...@@ -273,7 +273,8 @@ def _get_length(logit):
out = out.cast('int32') out = out.cast('int32')
out = out.argmax(-1) out = out.argmax(-1)
out = out + 1 out = out + 1
out = paddle.where(abn, out, paddle.to_tensor(logit.shape[1])) len_seq = paddle.zeros_like(out) + logit.shape[1]
out = paddle.where(abn, out, len_seq)
return out return out
......
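A minimal repro (illustrative values) of why the rewrite helps: paddle.where needs its branches to broadcast against the condition, so the fallback sequence length is first expanded to the shape of out instead of being passed as a bare scalar tensor:

import paddle

out = paddle.to_tensor([3, 7], dtype='int32')        # lengths found via argmax
abn = paddle.to_tensor([True, False])                 # whether an end token was found
len_seq = paddle.zeros_like(out) + 26                 # pretend logit.shape[1] == 26
print(paddle.where(abn, out, len_seq).numpy())        # [ 3 26]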
# copyright (c) 2022 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This code is adapted from:
https://github.com/hikopensource/DAVAR-Lab-OCR/davarocr/davar_rcg/models/sequence_heads/att_head.py
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import paddle
import paddle.nn as nn
import paddle.nn.functional as F
class SPINAttentionHead(nn.Layer):
def __init__(self, in_channels, out_channels, hidden_size, **kwargs):
super(SPINAttentionHead, self).__init__()
self.input_size = in_channels
self.hidden_size = hidden_size
self.num_classes = out_channels
self.attention_cell = AttentionLSTMCell(
in_channels, hidden_size, out_channels, use_gru=False)
self.generator = nn.Linear(hidden_size, out_channels)
def _char_to_onehot(self, input_char, onehot_dim):
input_ont_hot = F.one_hot(input_char, onehot_dim)
return input_ont_hot
def forward(self, inputs, targets=None, batch_max_length=25):
batch_size = paddle.shape(inputs)[0]
num_steps = batch_max_length + 1 # +1 for [sos] at end of sentence
hidden = (paddle.zeros((batch_size, self.hidden_size)),
paddle.zeros((batch_size, self.hidden_size)))
output_hiddens = []
if self.training: # for train
targets = targets[0]
for i in range(num_steps):
char_onehots = self._char_to_onehot(
targets[:, i], onehot_dim=self.num_classes)
(outputs, hidden), alpha = self.attention_cell(hidden, inputs,
char_onehots)
output_hiddens.append(paddle.unsqueeze(outputs, axis=1))
output = paddle.concat(output_hiddens, axis=1)
probs = self.generator(output)
else:
targets = paddle.zeros(shape=[batch_size], dtype="int32")
probs = None
char_onehots = None
outputs = None
alpha = None
for i in range(num_steps):
char_onehots = self._char_to_onehot(
targets, onehot_dim=self.num_classes)
(outputs, hidden), alpha = self.attention_cell(hidden, inputs,
char_onehots)
probs_step = self.generator(outputs)
if probs is None:
probs = paddle.unsqueeze(probs_step, axis=1)
else:
probs = paddle.concat(
[probs, paddle.unsqueeze(
probs_step, axis=1)], axis=1)
next_input = probs_step.argmax(axis=1)
targets = next_input
if not self.training:
probs = paddle.nn.functional.softmax(probs, axis=2)
return probs
class AttentionLSTMCell(nn.Layer):
def __init__(self, input_size, hidden_size, num_embeddings, use_gru=False):
super(AttentionLSTMCell, self).__init__()
self.i2h = nn.Linear(input_size, hidden_size, bias_attr=False)
self.h2h = nn.Linear(hidden_size, hidden_size)
self.score = nn.Linear(hidden_size, 1, bias_attr=False)
if not use_gru:
self.rnn = nn.LSTMCell(
input_size=input_size + num_embeddings, hidden_size=hidden_size)
else:
self.rnn = nn.GRUCell(
input_size=input_size + num_embeddings, hidden_size=hidden_size)
self.hidden_size = hidden_size
def forward(self, prev_hidden, batch_H, char_onehots):
batch_H_proj = self.i2h(batch_H)
prev_hidden_proj = paddle.unsqueeze(self.h2h(prev_hidden[0]), axis=1)
res = paddle.add(batch_H_proj, prev_hidden_proj)
res = paddle.tanh(res)
e = self.score(res)
alpha = F.softmax(e, axis=1)
alpha = paddle.transpose(alpha, [0, 2, 1])
context = paddle.squeeze(paddle.mm(alpha, batch_H), axis=1)
concat_context = paddle.concat([context, char_onehots], 1)
cur_hidden = self.rnn(concat_context, prev_hidden)
return cur_hidden, alpha
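A hedged inference sketch for SPINAttentionHead: in eval mode it decodes greedily from a zero start token for batch_max_length + 1 steps and returns softmaxed per-step class probabilities (feature and vocabulary sizes below are illustrative):

import paddle

head = SPINAttentionHead(in_channels=512, out_channels=37, hidden_size=256)
head.eval()
feats = paddle.rand([2, 26, 512])          # [N, T, C] sequence features from the neck
probs = head(feats, batch_max_length=25)
print(probs.shape)                          # [2, 26, 37]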
...@@ -21,6 +21,8 @@ import paddle.nn as nn ...@@ -21,6 +21,8 @@ import paddle.nn as nn
import paddle.nn.functional as F import paddle.nn.functional as F
import numpy as np import numpy as np
from .rec_att_head import AttentionGRUCell
class TableAttentionHead(nn.Layer): class TableAttentionHead(nn.Layer):
def __init__(self, def __init__(self,
...@@ -28,21 +30,19 @@ class TableAttentionHead(nn.Layer): ...@@ -28,21 +30,19 @@ class TableAttentionHead(nn.Layer):
hidden_size, hidden_size,
loc_type, loc_type,
in_max_len=488, in_max_len=488,
max_text_length=100, max_text_length=800,
max_elem_length=800, out_channels=30,
max_cell_num=500, point_num=2,
**kwargs): **kwargs):
super(TableAttentionHead, self).__init__() super(TableAttentionHead, self).__init__()
self.input_size = in_channels[-1] self.input_size = in_channels[-1]
self.hidden_size = hidden_size self.hidden_size = hidden_size
self.elem_num = 30 self.out_channels = out_channels
self.max_text_length = max_text_length self.max_text_length = max_text_length
self.max_elem_length = max_elem_length
self.max_cell_num = max_cell_num
self.structure_attention_cell = AttentionGRUCell( self.structure_attention_cell = AttentionGRUCell(
self.input_size, hidden_size, self.elem_num, use_gru=False) self.input_size, hidden_size, self.out_channels, use_gru=False)
self.structure_generator = nn.Linear(hidden_size, self.elem_num) self.structure_generator = nn.Linear(hidden_size, self.out_channels)
self.loc_type = loc_type self.loc_type = loc_type
self.in_max_len = in_max_len self.in_max_len = in_max_len
...@@ -50,12 +50,13 @@ class TableAttentionHead(nn.Layer): ...@@ -50,12 +50,13 @@ class TableAttentionHead(nn.Layer):
self.loc_generator = nn.Linear(hidden_size, 4) self.loc_generator = nn.Linear(hidden_size, 4)
else: else:
if self.in_max_len == 640: if self.in_max_len == 640:
self.loc_fea_trans = nn.Linear(400, self.max_elem_length + 1) self.loc_fea_trans = nn.Linear(400, self.max_text_length + 1)
elif self.in_max_len == 800: elif self.in_max_len == 800:
self.loc_fea_trans = nn.Linear(625, self.max_elem_length + 1) self.loc_fea_trans = nn.Linear(625, self.max_text_length + 1)
else: else:
self.loc_fea_trans = nn.Linear(256, self.max_elem_length + 1) self.loc_fea_trans = nn.Linear(256, self.max_text_length + 1)
self.loc_generator = nn.Linear(self.input_size + hidden_size, 4) self.loc_generator = nn.Linear(self.input_size + hidden_size,
point_num * 2)
def _char_to_onehot(self, input_char, onehot_dim): def _char_to_onehot(self, input_char, onehot_dim):
input_ont_hot = F.one_hot(input_char, onehot_dim) input_ont_hot = F.one_hot(input_char, onehot_dim)
...@@ -77,9 +78,9 @@ class TableAttentionHead(nn.Layer): ...@@ -77,9 +78,9 @@ class TableAttentionHead(nn.Layer):
output_hiddens = [] output_hiddens = []
if self.training and targets is not None: if self.training and targets is not None:
structure = targets[0] structure = targets[0]
for i in range(self.max_elem_length + 1): for i in range(self.max_text_length + 1):
elem_onehots = self._char_to_onehot( elem_onehots = self._char_to_onehot(
structure[:, i], onehot_dim=self.elem_num) structure[:, i], onehot_dim=self.out_channels)
(outputs, hidden), alpha = self.structure_attention_cell( (outputs, hidden), alpha = self.structure_attention_cell(
hidden, fea, elem_onehots) hidden, fea, elem_onehots)
output_hiddens.append(paddle.unsqueeze(outputs, axis=1)) output_hiddens.append(paddle.unsqueeze(outputs, axis=1))
...@@ -102,11 +103,11 @@ class TableAttentionHead(nn.Layer): ...@@ -102,11 +103,11 @@ class TableAttentionHead(nn.Layer):
elem_onehots = None elem_onehots = None
outputs = None outputs = None
alpha = None alpha = None
max_elem_length = paddle.to_tensor(self.max_elem_length) max_text_length = paddle.to_tensor(self.max_text_length)
i = 0 i = 0
while i < max_elem_length + 1: while i < max_text_length + 1:
elem_onehots = self._char_to_onehot( elem_onehots = self._char_to_onehot(
temp_elem, onehot_dim=self.elem_num) temp_elem, onehot_dim=self.out_channels)
(outputs, hidden), alpha = self.structure_attention_cell( (outputs, hidden), alpha = self.structure_attention_cell(
hidden, fea, elem_onehots) hidden, fea, elem_onehots)
output_hiddens.append(paddle.unsqueeze(outputs, axis=1)) output_hiddens.append(paddle.unsqueeze(outputs, axis=1))
...@@ -128,119 +129,3 @@ class TableAttentionHead(nn.Layer): ...@@ -128,119 +129,3 @@ class TableAttentionHead(nn.Layer):
loc_preds = self.loc_generator(loc_concat) loc_preds = self.loc_generator(loc_concat)
loc_preds = F.sigmoid(loc_preds) loc_preds = F.sigmoid(loc_preds)
return {'structure_probs': structure_probs, 'loc_preds': loc_preds} return {'structure_probs': structure_probs, 'loc_preds': loc_preds}
class AttentionGRUCell(nn.Layer):
def __init__(self, input_size, hidden_size, num_embeddings, use_gru=False):
super(AttentionGRUCell, self).__init__()
self.i2h = nn.Linear(input_size, hidden_size, bias_attr=False)
self.h2h = nn.Linear(hidden_size, hidden_size)
self.score = nn.Linear(hidden_size, 1, bias_attr=False)
self.rnn = nn.GRUCell(
input_size=input_size + num_embeddings, hidden_size=hidden_size)
self.hidden_size = hidden_size
def forward(self, prev_hidden, batch_H, char_onehots):
batch_H_proj = self.i2h(batch_H)
prev_hidden_proj = paddle.unsqueeze(self.h2h(prev_hidden), axis=1)
res = paddle.add(batch_H_proj, prev_hidden_proj)
res = paddle.tanh(res)
e = self.score(res)
alpha = F.softmax(e, axis=1)
alpha = paddle.transpose(alpha, [0, 2, 1])
context = paddle.squeeze(paddle.mm(alpha, batch_H), axis=1)
concat_context = paddle.concat([context, char_onehots], 1)
cur_hidden = self.rnn(concat_context, prev_hidden)
return cur_hidden, alpha
class AttentionLSTM(nn.Layer):
def __init__(self, in_channels, out_channels, hidden_size, **kwargs):
super(AttentionLSTM, self).__init__()
self.input_size = in_channels
self.hidden_size = hidden_size
self.num_classes = out_channels
self.attention_cell = AttentionLSTMCell(
in_channels, hidden_size, out_channels, use_gru=False)
self.generator = nn.Linear(hidden_size, out_channels)
def _char_to_onehot(self, input_char, onehot_dim):
input_ont_hot = F.one_hot(input_char, onehot_dim)
return input_ont_hot
def forward(self, inputs, targets=None, batch_max_length=25):
batch_size = inputs.shape[0]
num_steps = batch_max_length
hidden = (paddle.zeros((batch_size, self.hidden_size)), paddle.zeros(
(batch_size, self.hidden_size)))
output_hiddens = []
if targets is not None:
for i in range(num_steps):
# one-hot vectors for a i-th char
char_onehots = self._char_to_onehot(
targets[:, i], onehot_dim=self.num_classes)
hidden, alpha = self.attention_cell(hidden, inputs,
char_onehots)
hidden = (hidden[1][0], hidden[1][1])
output_hiddens.append(paddle.unsqueeze(hidden[0], axis=1))
output = paddle.concat(output_hiddens, axis=1)
probs = self.generator(output)
else:
targets = paddle.zeros(shape=[batch_size], dtype="int32")
probs = None
for i in range(num_steps):
char_onehots = self._char_to_onehot(
targets, onehot_dim=self.num_classes)
hidden, alpha = self.attention_cell(hidden, inputs,
char_onehots)
probs_step = self.generator(hidden[0])
hidden = (hidden[1][0], hidden[1][1])
if probs is None:
probs = paddle.unsqueeze(probs_step, axis=1)
else:
probs = paddle.concat(
[probs, paddle.unsqueeze(
probs_step, axis=1)], axis=1)
next_input = probs_step.argmax(axis=1)
targets = next_input
return probs
class AttentionLSTMCell(nn.Layer):
def __init__(self, input_size, hidden_size, num_embeddings, use_gru=False):
super(AttentionLSTMCell, self).__init__()
self.i2h = nn.Linear(input_size, hidden_size, bias_attr=False)
self.h2h = nn.Linear(hidden_size, hidden_size)
self.score = nn.Linear(hidden_size, 1, bias_attr=False)
if not use_gru:
self.rnn = nn.LSTMCell(
input_size=input_size + num_embeddings, hidden_size=hidden_size)
else:
self.rnn = nn.GRUCell(
input_size=input_size + num_embeddings, hidden_size=hidden_size)
self.hidden_size = hidden_size
def forward(self, prev_hidden, batch_H, char_onehots):
batch_H_proj = self.i2h(batch_H)
prev_hidden_proj = paddle.unsqueeze(self.h2h(prev_hidden[0]), axis=1)
res = paddle.add(batch_H_proj, prev_hidden_proj)
res = paddle.tanh(res)
e = self.score(res)
alpha = F.softmax(e, axis=1)
alpha = paddle.transpose(alpha, [0, 2, 1])
context = paddle.squeeze(paddle.mm(alpha, batch_H), axis=1)
concat_context = paddle.concat([context, char_onehots], 1)
cur_hidden = self.rnn(concat_context, prev_hidden)
return cur_hidden, alpha
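For reference, a single-step sketch of the GRU attention cell the table head now imports from rec_att_head (the removed copy above is functionally the same): it scores the T encoder positions, pools a context vector, and runs one GRU step on [context; one-hot token]:

import paddle
import paddle.nn.functional as F

cell = AttentionGRUCell(input_size=512, hidden_size=256, num_embeddings=30)
batch_H = paddle.rand([2, 100, 512])                  # encoder features, [N, T, C]
prev_hidden = paddle.zeros([2, 256])
onehots = F.one_hot(paddle.to_tensor([3, 7]), 30)     # previous structure tokens
(output, hidden), alpha = cell(prev_hidden, batch_H, onehots)
print(output.shape, alpha.shape)                       # [2, 256] [2, 1, 100]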
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This code is adapted from:
https://github.com/JiaquanYe/TableMASTER-mmocr/blob/master/mmocr/models/textrecog/decoders/master_decoder.py
"""
import copy
import math
import paddle
from paddle import nn
from paddle.nn import functional as F
class TableMasterHead(nn.Layer):
"""
Splits into two transformer heads at the last layer:
cls_layer performs structure-token classification,
bbox_layer regresses the bbox coordinates.
"""
def __init__(self,
in_channels,
out_channels=30,
headers=8,
d_ff=2048,
dropout=0,
max_text_length=500,
point_num=2,
**kwargs):
super(TableMasterHead, self).__init__()
hidden_size = in_channels[-1]
self.layers = clones(
DecoderLayer(headers, hidden_size, dropout, d_ff), 2)
self.cls_layer = clones(
DecoderLayer(headers, hidden_size, dropout, d_ff), 1)
self.bbox_layer = clones(
DecoderLayer(headers, hidden_size, dropout, d_ff), 1)
self.cls_fc = nn.Linear(hidden_size, out_channels)
self.bbox_fc = nn.Sequential(
# nn.Linear(hidden_size, hidden_size),
nn.Linear(hidden_size, point_num * 2),
nn.Sigmoid())
self.norm = nn.LayerNorm(hidden_size)
self.embedding = Embeddings(d_model=hidden_size, vocab=out_channels)
self.positional_encoding = PositionalEncoding(d_model=hidden_size)
self.SOS = out_channels - 3
self.PAD = out_channels - 1
self.out_channels = out_channels
self.point_num = point_num
self.max_text_length = max_text_length
def make_mask(self, tgt):
"""
Make the padding + causal mask for self attention.
:param tgt: [b, l_tgt]
:return: mask of shape [b, 1, l_tgt, l_tgt]
"""
trg_pad_mask = (tgt != self.PAD).unsqueeze(1).unsqueeze(3)
tgt_len = paddle.shape(tgt)[1]
trg_sub_mask = paddle.tril(
paddle.ones(
([tgt_len, tgt_len]), dtype=paddle.float32))
tgt_mask = paddle.logical_and(
trg_pad_mask.astype(paddle.float32), trg_sub_mask)
return tgt_mask.astype(paddle.float32)
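# --- illustrative note (not part of the original commit) ---
# make_mask combines a padding mask (tgt != PAD) with a lower-triangular
# causal mask: e.g. for PAD index 2 and tgt = [[5, 7, 2]] the result is a
# [1, 1, 3, 3] mask in which position i may attend only to the non-padding
# positions j <= i.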
def decode(self, input, feature, src_mask, tgt_mask):
# main process of transformer decoder.
x = self.embedding(input)  # x: 1 x len x 512, feature: 1 x 3600 x 512
x = self.positional_encoding(x)
# origin transformer layers
for i, layer in enumerate(self.layers):
x = layer(x, feature, src_mask, tgt_mask)
# cls head
for layer in self.cls_layer:
cls_x = layer(x, feature, src_mask, tgt_mask)
cls_x = self.norm(cls_x)
# bbox head
for layer in self.bbox_layer:
bbox_x = layer(x, feature, src_mask, tgt_mask)
bbox_x = self.norm(bbox_x)
return self.cls_fc(cls_x), self.bbox_fc(bbox_x)
def greedy_forward(self, SOS, feature):
input = SOS
output = paddle.zeros(
[input.shape[0], self.max_text_length + 1, self.out_channels])
bbox_output = paddle.zeros(
[input.shape[0], self.max_text_length + 1, self.point_num * 2])
max_text_length = paddle.to_tensor(self.max_text_length)
for i in range(max_text_length + 1):
target_mask = self.make_mask(input)
out_step, bbox_output_step = self.decode(input, feature, None,
target_mask)
prob = F.softmax(out_step, axis=-1)
next_word = prob.argmax(axis=2, dtype="int64")
input = paddle.concat(
[input, next_word[:, -1].unsqueeze(-1)], axis=1)
if i == self.max_text_length:
output = out_step
bbox_output = bbox_output_step
return output, bbox_output
def forward_train(self, out_enc, targets):
# x is token of label
# feat is feature after backbone before pe.
# out_enc is feature after pe.
padded_targets = targets[0]
src_mask = None
tgt_mask = self.make_mask(padded_targets[:, :-1])
output, bbox_output = self.decode(padded_targets[:, :-1], out_enc,
src_mask, tgt_mask)
return {'structure_probs': output, 'loc_preds': bbox_output}
def forward_test(self, out_enc):
batch_size = out_enc.shape[0]
SOS = paddle.zeros([batch_size, 1], dtype='int64') + self.SOS
output, bbox_output = self.greedy_forward(SOS, out_enc)
output = F.softmax(output)
return {'structure_probs': output, 'loc_preds': bbox_output}
def forward(self, feat, targets=None):
feat = feat[-1]
b, c, h, w = feat.shape
feat = feat.reshape([b, c, h * w]) # flatten 2D feature map
feat = feat.transpose((0, 2, 1))
out_enc = self.positional_encoding(feat)
if self.training:
return self.forward_train(out_enc, targets)
return self.forward_test(out_enc)
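# --- illustrative usage sketch (not part of the original commit) ---
# Hypothetical shapes: the head flattens the last backbone feature map into a
# token sequence, then decodes structure tokens plus box regressions.
#
#   head = TableMasterHead(in_channels=[512], out_channels=43, point_num=2)
#   feat = paddle.rand([1, 512, 12, 40])      # B x C x H x W from the backbone
#   head.eval()
#   preds = head([feat])                      # greedy decoding in eval mode
#   # preds['structure_probs']: [1, max_text_length + 1, 43]
#   # preds['loc_preds']:       [1, max_text_length + 1, point_num * 2]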
class DecoderLayer(nn.Layer):
"""
Decoder is made of self attention, source attention and feed forward.
"""
def __init__(self, headers, d_model, dropout, d_ff):
super(DecoderLayer, self).__init__()
self.self_attn = MultiHeadAttention(headers, d_model, dropout)
self.src_attn = MultiHeadAttention(headers, d_model, dropout)
self.feed_forward = FeedForward(d_model, d_ff, dropout)
self.sublayer = clones(SubLayerConnection(d_model, dropout), 3)
def forward(self, x, feature, src_mask, tgt_mask):
x = self.sublayer[0](x, lambda x: self.self_attn(x, x, x, tgt_mask))
x = self.sublayer[1](
x, lambda x: self.src_attn(x, feature, feature, src_mask))
return self.sublayer[2](x, self.feed_forward)
class MultiHeadAttention(nn.Layer):
def __init__(self, headers, d_model, dropout):
super(MultiHeadAttention, self).__init__()
assert d_model % headers == 0
self.d_k = int(d_model / headers)
self.headers = headers
self.linears = clones(nn.Linear(d_model, d_model), 4)
self.attn = None
self.dropout = nn.Dropout(dropout)
def forward(self, query, key, value, mask=None):
B = query.shape[0]
# 1) Do all the linear projections in batch from d_model => h x d_k
query, key, value = \
[l(x).reshape([B, 0, self.headers, self.d_k]).transpose([0, 2, 1, 3])
for l, x in zip(self.linears, (query, key, value))]
# 2) Apply attention on all the projected vectors in batch
x, self.attn = self_attention(
query, key, value, mask=mask, dropout=self.dropout)
x = x.transpose([0, 2, 1, 3]).reshape([B, 0, self.headers * self.d_k])
return self.linears[-1](x)
class FeedForward(nn.Layer):
def __init__(self, d_model, d_ff, dropout):
super(FeedForward, self).__init__()
self.w_1 = nn.Linear(d_model, d_ff)
self.w_2 = nn.Linear(d_ff, d_model)
self.dropout = nn.Dropout(dropout)
def forward(self, x):
return self.w_2(self.dropout(F.relu(self.w_1(x))))
class SubLayerConnection(nn.Layer):
"""
A residual connection followed by a layer norm.
Note for code simplicity the norm is first as opposed to last.
"""
def __init__(self, size, dropout):
super(SubLayerConnection, self).__init__()
self.norm = nn.LayerNorm(size)
self.dropout = nn.Dropout(dropout)
def forward(self, x, sublayer):
return x + self.dropout(sublayer(self.norm(x)))
def masked_fill(x, mask, value):
mask = mask.astype(x.dtype)
return x * paddle.logical_not(mask).astype(x.dtype) + mask * value
def self_attention(query, key, value, mask=None, dropout=None):
"""
Compute 'Scaled Dot-Product Attention'
"""
d_k = value.shape[-1]
score = paddle.matmul(query, key.transpose([0, 1, 3, 2]) / math.sqrt(d_k))
if mask is not None:
# score = score.masked_fill(mask == 0, -1e9) # b, h, L, L
score = masked_fill(score, mask == 0, -6.55e4) # for fp16
p_attn = F.softmax(score, axis=-1)
if dropout is not None:
p_attn = dropout(p_attn)
return paddle.matmul(p_attn, value), p_attn
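# The function above implements scaled dot-product attention,
#     Attention(Q, K, V) = softmax(Q K^T / sqrt(d_k)) V,
# where masked positions are filled with -6.55e4 (instead of -1e9) so the
# scores remain representable in float16.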
def clones(module, N):
""" Produce N identical layers """
return nn.LayerList([copy.deepcopy(module) for _ in range(N)])
class Embeddings(nn.Layer):
def __init__(self, d_model, vocab):
super(Embeddings, self).__init__()
self.lut = nn.Embedding(vocab, d_model)
self.d_model = d_model
def forward(self, *input):
x = input[0]
return self.lut(x) * math.sqrt(self.d_model)
class PositionalEncoding(nn.Layer):
""" Implement the PE function. """
def __init__(self, d_model, dropout=0., max_len=5000):
super(PositionalEncoding, self).__init__()
self.dropout = nn.Dropout(p=dropout)
# Compute the positional encodings once in log space.
pe = paddle.zeros([max_len, d_model])
position = paddle.arange(0, max_len).unsqueeze(1).astype('float32')
div_term = paddle.exp(
paddle.arange(0, d_model, 2) * -math.log(10000.0) / d_model)
pe[:, 0::2] = paddle.sin(position * div_term)
pe[:, 1::2] = paddle.cos(position * div_term)
pe = pe.unsqueeze(0)
self.register_buffer('pe', pe)
def forward(self, feat, **kwargs):
feat = feat + self.pe[:, :paddle.shape(feat)[1]] # pe 1*5000*512
return self.dropout(feat)
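# The buffer above is the standard sinusoidal encoding,
#     PE(pos, 2i)   = sin(pos / 10000^(2i / d_model))
#     PE(pos, 2i+1) = cos(pos / 10000^(2i / d_model)),
# added to the flattened feature sequence (and to the target embeddings)
# before the decoder layers.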
...@@ -105,9 +105,10 @@ class DSConv(nn.Layer): ...@@ -105,9 +105,10 @@ class DSConv(nn.Layer):
class DBFPN(nn.Layer): class DBFPN(nn.Layer):
def __init__(self, in_channels, out_channels, **kwargs): def __init__(self, in_channels, out_channels, use_asf=False, **kwargs):
super(DBFPN, self).__init__() super(DBFPN, self).__init__()
self.out_channels = out_channels self.out_channels = out_channels
self.use_asf = use_asf
weight_attr = paddle.nn.initializer.KaimingUniform() weight_attr = paddle.nn.initializer.KaimingUniform()
self.in2_conv = nn.Conv2D( self.in2_conv = nn.Conv2D(
...@@ -163,6 +164,9 @@ class DBFPN(nn.Layer): ...@@ -163,6 +164,9 @@ class DBFPN(nn.Layer):
weight_attr=ParamAttr(initializer=weight_attr), weight_attr=ParamAttr(initializer=weight_attr),
bias_attr=False) bias_attr=False)
if self.use_asf is True:
self.asf = ASFBlock(self.out_channels, self.out_channels // 4)
def forward(self, x): def forward(self, x):
c2, c3, c4, c5 = x c2, c3, c4, c5 = x
...@@ -187,6 +191,10 @@ class DBFPN(nn.Layer): ...@@ -187,6 +191,10 @@ class DBFPN(nn.Layer):
p3 = F.upsample(p3, scale_factor=2, mode="nearest", align_mode=1) p3 = F.upsample(p3, scale_factor=2, mode="nearest", align_mode=1)
fuse = paddle.concat([p5, p4, p3, p2], axis=1) fuse = paddle.concat([p5, p4, p3, p2], axis=1)
if self.use_asf is True:
fuse = self.asf(fuse, [p5, p4, p3, p2])
return fuse return fuse
...@@ -356,3 +364,64 @@ class LKPAN(nn.Layer): ...@@ -356,3 +364,64 @@ class LKPAN(nn.Layer):
fuse = paddle.concat([p5, p4, p3, p2], axis=1) fuse = paddle.concat([p5, p4, p3, p2], axis=1)
return fuse return fuse
class ASFBlock(nn.Layer):
"""
This code is referred from:
https://github.com/MhLiao/DB/blob/master/decoders/feature_attention.py
"""
def __init__(self, in_channels, inter_channels, out_features_num=4):
"""
Adaptive Scale Fusion (ASF) block of DBNet++
Args:
in_channels: the number of channels in the input data
inter_channels: the number of middle channels
out_features_num: the number of fused stages
"""
super(ASFBlock, self).__init__()
weight_attr = paddle.nn.initializer.KaimingUniform()
self.in_channels = in_channels
self.inter_channels = inter_channels
self.out_features_num = out_features_num
self.conv = nn.Conv2D(in_channels, inter_channels, 3, padding=1)
self.spatial_scale = nn.Sequential(
#Nx1xHxW
nn.Conv2D(
in_channels=1,
out_channels=1,
kernel_size=3,
bias_attr=False,
padding=1,
weight_attr=ParamAttr(initializer=weight_attr)),
nn.ReLU(),
nn.Conv2D(
in_channels=1,
out_channels=1,
kernel_size=1,
bias_attr=False,
weight_attr=ParamAttr(initializer=weight_attr)),
nn.Sigmoid())
self.channel_scale = nn.Sequential(
nn.Conv2D(
in_channels=inter_channels,
out_channels=out_features_num,
kernel_size=1,
bias_attr=False,
weight_attr=ParamAttr(initializer=weight_attr)),
nn.Sigmoid())
def forward(self, fuse_features, features_list):
fuse_features = self.conv(fuse_features)
spatial_x = paddle.mean(fuse_features, axis=1, keepdim=True)
attention_scores = self.spatial_scale(spatial_x) + fuse_features
attention_scores = self.channel_scale(attention_scores)
assert len(features_list) == self.out_features_num
out_list = []
for i in range(self.out_features_num):
out_list.append(attention_scores[:, i:i + 1] * features_list[i])
return paddle.concat(out_list, axis=1)
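# --- illustrative usage sketch (not part of the original commit) ---
# Assumes the module above (with its imports) is available. Hypothetical sizes:
# the DBFPN fuse map has C channels, built from four pyramid maps of C // 4
# channels each; ASFBlock re-weights each pyramid level with a learned score.
import paddle
C = 256
asf = ASFBlock(in_channels=C, inter_channels=C // 4)
pyramid = [paddle.rand([1, C // 4, 40, 40]) for _ in range(4)]   # p5, p4, p3, p2
fuse = paddle.concat(pyramid, axis=1)                            # [1, C, 40, 40]
out = asf(fuse, pyramid)
print(out.shape)                                                 # [1, C, 40, 40]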
...@@ -47,6 +47,56 @@ class EncoderWithRNN(nn.Layer): ...@@ -47,6 +47,56 @@ class EncoderWithRNN(nn.Layer):
x, _ = self.lstm(x) x, _ = self.lstm(x)
return x return x
class BidirectionalLSTM(nn.Layer):
def __init__(self, input_size,
hidden_size,
output_size=None,
num_layers=1,
dropout=0,
direction=False,
time_major=False,
with_linear=False):
super(BidirectionalLSTM, self).__init__()
self.with_linear = with_linear
self.rnn = nn.LSTM(input_size,
hidden_size,
num_layers=num_layers,
dropout=dropout,
direction=direction,
time_major=time_major)
# optional linear projection back to output_size after the bidirectional LSTM
if self.with_linear:
self.linear = nn.Linear(hidden_size * 2, output_size)
def forward(self, input_feature):
recurrent, _ = self.rnn(input_feature) # batch_size x T x input_size -> batch_size x T x (2*hidden_size)
if self.with_linear:
output = self.linear(recurrent) # batch_size x T x output_size
return output
return recurrent
class EncoderWithCascadeRNN(nn.Layer):
def __init__(self, in_channels, hidden_size, out_channels, num_layers=2, with_linear=False):
super(EncoderWithCascadeRNN, self).__init__()
self.out_channels = out_channels[-1]
self.encoder = nn.LayerList(
[BidirectionalLSTM(
in_channels if i == 0 else out_channels[i - 1],
hidden_size,
output_size=out_channels[i],
num_layers=1,
direction='bidirectional',
with_linear=with_linear)
for i in range(num_layers)]
)
def forward(self, x):
for i, l in enumerate(self.encoder):
x = l(x)
return x
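# --- illustrative usage sketch (not part of the original commit) ---
# Assumes the module above is available. Hypothetical sizes: each cascaded
# bidirectional LSTM maps B x T x C_in to B x T x out_channels[i].
import paddle
enc = EncoderWithCascadeRNN(in_channels=64, hidden_size=48,
                            out_channels=[96, 96], num_layers=2, with_linear=True)
feat = paddle.rand([4, 25, 64])   # batch x time steps x channels
print(enc(feat).shape)            # [4, 25, 96]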
class EncoderWithFC(nn.Layer): class EncoderWithFC(nn.Layer):
def __init__(self, in_channels, hidden_size): def __init__(self, in_channels, hidden_size):
...@@ -166,13 +216,17 @@ class SequenceEncoder(nn.Layer): ...@@ -166,13 +216,17 @@ class SequenceEncoder(nn.Layer):
'reshape': Im2Seq, 'reshape': Im2Seq,
'fc': EncoderWithFC, 'fc': EncoderWithFC,
'rnn': EncoderWithRNN, 'rnn': EncoderWithRNN,
'svtr': EncoderWithSVTR 'svtr': EncoderWithSVTR,
'cascadernn': EncoderWithCascadeRNN
} }
assert encoder_type in support_encoder_dict, '{} must in {}'.format( assert encoder_type in support_encoder_dict, '{} must in {}'.format(
encoder_type, support_encoder_dict.keys()) encoder_type, support_encoder_dict.keys())
if encoder_type == "svtr": if encoder_type == "svtr":
self.encoder = support_encoder_dict[encoder_type]( self.encoder = support_encoder_dict[encoder_type](
self.encoder_reshape.out_channels, **kwargs) self.encoder_reshape.out_channels, **kwargs)
elif encoder_type == 'cascadernn':
self.encoder = support_encoder_dict[encoder_type](
self.encoder_reshape.out_channels, hidden_size, **kwargs)
else: else:
self.encoder = support_encoder_dict[encoder_type]( self.encoder = support_encoder_dict[encoder_type](
self.encoder_reshape.out_channels, hidden_size) self.encoder_reshape.out_channels, hidden_size)
......
...@@ -18,8 +18,10 @@ __all__ = ['build_transform'] ...@@ -18,8 +18,10 @@ __all__ = ['build_transform']
def build_transform(config): def build_transform(config):
from .tps import TPS from .tps import TPS
from .stn import STN_ON from .stn import STN_ON
from .gaspin_transformer import GA_SPIN_Transformer as GA_SPIN
support_dict = ['TPS', 'STN_ON']
support_dict = ['TPS', 'STN_ON', 'GA_SPIN']
module_name = config.pop('name') module_name = config.pop('name')
assert module_name in support_dict, Exception( assert module_name in support_dict, Exception(
......
# copyright (c) 2022 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import math
import paddle
from paddle import nn, ParamAttr
from paddle.nn import functional as F
import numpy as np
import functools
from .tps import GridGenerator
'''This code is referred from:
https://github.com/hikopensource/DAVAR-Lab-OCR/davarocr/davar_rcg/models/transformations/gaspin_transformation.py
'''
class SP_TransformerNetwork(nn.Layer):
"""
Structure-Preserving Transformation (SPT) as Eq. (2) in Ref. [1]
Ref: [1] SPIN: Structure-Preserving Inner Offset Network for Scene Text Recognition. AAAI-2021.
"""
def __init__(self, nc=1, default_type=5):
""" Based on SPIN
Args:
nc (int): number of input channels (usually in 1 or 3)
default_type (int): the complexity of transformation intensities (by default set to 6, as in the paper)
"""
super(SP_TransformerNetwork, self).__init__()
self.power_list = self.cal_K(default_type)
self.sigmoid = nn.Sigmoid()
self.bn = nn.InstanceNorm2D(nc)
def cal_K(self, k=5):
"""
Args:
k (int): the complexity of transformation intensities (by default set to 6, as in the paper)
Returns:
List: the normalized intensity of each pixel in [0,1], denoted as \beta [1x(2K+1)]
"""
from math import log
x = []
if k != 0:
for i in range(1, k+1):
lower = round(log(1-(0.5/(k+1))*i)/log((0.5/(k+1))*i), 2)
upper = round(1/lower, 2)
x.append(lower)
x.append(upper)
x.append(1.00)
return x
def forward(self, batch_I, weights, offsets, lambda_color=None):
"""
Args:
batch_I (Tensor): batch of input images [batch_size x nc x I_height x I_width]
weights:
offsets: the predicted offset by AIN, a scalar
lambda_color: the learnable update gate \alpha in Eq. (5) as
g(x) = (1 - \alpha) \odot x + \alpha \odot x_{offsets}
Returns:
Tensor: transformed images by SPN as Eq. (4) in Ref. [1]
[batch_size x I_channel_num x I_r_height x I_r_width]
"""
batch_I = (batch_I + 1) * 0.5
if offsets is not None:
batch_I = batch_I*(1-lambda_color) + offsets*lambda_color
batch_weight_params = paddle.unsqueeze(paddle.unsqueeze(weights, -1), -1)
batch_I_power = paddle.stack([batch_I.pow(p) for p in self.power_list], axis=1)
batch_weight_sum = paddle.sum(batch_I_power * batch_weight_params, axis=1)
batch_weight_sum = self.bn(batch_weight_sum)
batch_weight_sum = self.sigmoid(batch_weight_sum)
batch_weight_sum = batch_weight_sum * 2 - 1
return batch_weight_sum
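# --- illustrative sketch (not part of the original commit) ---
# The exponents produced by cal_K come in reciprocal pairs (upper = 1 / lower)
# plus the identity 1.0, i.e. 2*K + 1 gamma-like curves that darken/brighten
# symmetrically on a log scale. Standalone computation for the paper's K = 6:
from math import log
k = 6
powers = []
for i in range(1, k + 1):
    lower = round(log(1 - (0.5 / (k + 1)) * i) / log((0.5 / (k + 1)) * i), 2)
    powers += [lower, round(1 / lower, 2)]
powers.append(1.00)
print(len(powers), sorted(powers))   # 13 exponents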
class GA_SPIN_Transformer(nn.Layer):
"""
Geometric-Absorbed SPIN Transformation (GA-SPIN) proposed in Ref. [1]
Ref: [1] SPIN: Structure-Preserving Inner Offset Network for Scene Text Recognition. AAAI-2021.
"""
def __init__(self, in_channels=1,
I_r_size=(32, 100),
offsets=False,
norm_type='BN',
default_type=6,
loc_lr=1,
stn=True):
"""
Args:
in_channels (int): channel of input features,
set it to 1 for grayscale images and 3 for RGB input
I_r_size (tuple): size of rectified images (used in STN transformations)
offsets (bool): set it to False to use SPN without AIN,
and to True to use SPIN (both SPN and AIN)
norm_type (str): the normalization type of the module,
set it to 'BN' by default, 'IN' optionally
default_type (int): the K chromatic space,
set it to 3/5/6 depending on the complexity of transformation intensities
loc_lr (float): learning rate of location network
stn (bool): whether to use STN.
"""
super(GA_SPIN_Transformer, self).__init__()
self.nc = in_channels
self.spt = True
self.offsets = offsets
self.stn = stn  # set to True in GA-SPIN and to False in plain SPIN
self.I_r_size = I_r_size
self.out_channels = in_channels
if norm_type == 'BN':
norm_layer = functools.partial(nn.BatchNorm2D, use_global_stats=True)
elif norm_type == 'IN':
norm_layer = functools.partial(nn.InstanceNorm2D, weight_attr=False,
use_global_stats=False)
else:
raise NotImplementedError('normalization layer [%s] is not found' % norm_type)
if self.spt:
self.sp_net = SP_TransformerNetwork(in_channels,
default_type)
self.spt_convnet = nn.Sequential(
# 32*100
nn.Conv2D(in_channels, 32, 3, 1, 1, bias_attr=False),
norm_layer(32), nn.ReLU(),
nn.MaxPool2D(kernel_size=2, stride=2),
# 16*50
nn.Conv2D(32, 64, 3, 1, 1, bias_attr=False),
norm_layer(64), nn.ReLU(),
nn.MaxPool2D(kernel_size=2, stride=2),
# 8*25
nn.Conv2D(64, 128, 3, 1, 1, bias_attr=False),
norm_layer(128), nn.ReLU(),
nn.MaxPool2D(kernel_size=2, stride=2),
# 4*12
)
self.stucture_fc1 = nn.Sequential(
nn.Conv2D(128, 256, 3, 1, 1, bias_attr=False),
norm_layer(256), nn.ReLU(),
nn.MaxPool2D(kernel_size=2, stride=2),
nn.Conv2D(256, 256, 3, 1, 1, bias_attr=False),
norm_layer(256), nn.ReLU(), # 2*6
nn.MaxPool2D(kernel_size=2, stride=2),
nn.Conv2D(256, 512, 3, 1, 1, bias_attr=False),
norm_layer(512), nn.ReLU(), # 1*3
nn.AdaptiveAvgPool2D(1),
nn.Flatten(1, -1), # batch_size x 512
nn.Linear(512, 256, weight_attr=nn.initializer.Normal(0.001)),
nn.BatchNorm1D(256), nn.ReLU()
)
self.out_weight = 2*default_type+1
self.spt_length = 2*default_type+1
if offsets:
self.out_weight += 1
if self.stn:
self.F = 20
self.out_weight += self.F * 2
self.GridGenerator = GridGenerator(self.F*2, self.F)
# self.out_weight*=nc
# Init structure_fc2 in LocalizationNetwork
initial_bias = self.init_spin(default_type*2)
initial_bias = initial_bias.reshape(-1)
param_attr = ParamAttr(
learning_rate=loc_lr,
initializer=nn.initializer.Assign(np.zeros([256, self.out_weight])))
bias_attr = ParamAttr(
learning_rate=loc_lr,
initializer=nn.initializer.Assign(initial_bias))
self.stucture_fc2 = nn.Linear(256, self.out_weight,
weight_attr=param_attr,
bias_attr=bias_attr)
self.sigmoid = nn.Sigmoid()
if offsets:
self.offset_fc1 = nn.Sequential(nn.Conv2D(128, 16,
3, 1, 1,
bias_attr=False),
norm_layer(16),
nn.ReLU(),)
self.offset_fc2 = nn.Conv2D(16, in_channels,
3, 1, 1)
self.pool = nn.MaxPool2D(2, 2)
def init_spin(self, nz):
"""
Args:
nz (int): number of paired \beta exponents, i.e. the value of K x 2
"""
init_id = [0.00]*nz+[5.00]
if self.offsets:
init_id += [-5.00]
# init_id *=3
init = np.array(init_id)
if self.stn:
F = self.F
ctrl_pts_x = np.linspace(-1.0, 1.0, int(F / 2))
ctrl_pts_y_top = np.linspace(0.0, -1.0, num=int(F / 2))
ctrl_pts_y_bottom = np.linspace(1.0, 0.0, num=int(F / 2))
ctrl_pts_top = np.stack([ctrl_pts_x, ctrl_pts_y_top], axis=1)
ctrl_pts_bottom = np.stack([ctrl_pts_x, ctrl_pts_y_bottom], axis=1)
initial_bias = np.concatenate([ctrl_pts_top, ctrl_pts_bottom], axis=0)
initial_bias = initial_bias.reshape(-1)
init = np.concatenate([init, initial_bias], axis=0)
return init
def forward(self, x, return_weight=False):
"""
Args:
x (Tensor): input image batch
return_weight (bool): set to False by default,
if set to True return the predicted offsets of AIN, denoted as x_{offsets}
Returns:
Tensor: rectified image [batch_size x I_channel_num x I_height x I_width], the same as the input size
"""
if self.spt:
feat = self.spt_convnet(x)
fc1 = self.stucture_fc1(feat)
sp_weight_fusion = self.stucture_fc2(fc1)
sp_weight_fusion = sp_weight_fusion.reshape([x.shape[0], self.out_weight, 1])
if self.offsets: # SPIN w. AIN
lambda_color = sp_weight_fusion[:, self.spt_length, 0]
lambda_color = self.sigmoid(lambda_color).unsqueeze(-1).unsqueeze(-1).unsqueeze(-1)
sp_weight = sp_weight_fusion[:, :self.spt_length, :]
offsets = self.pool(self.offset_fc2(self.offset_fc1(feat)))
assert offsets.shape[2] == 2 # 2
assert offsets.shape[3] == 6 # 16
offsets = self.sigmoid(offsets) # v12
if return_weight:
return offsets
offsets = nn.functional.upsample(offsets, size=(x.shape[2], x.shape[3]), mode='bilinear')
if self.stn:
batch_C_prime = sp_weight_fusion[:, (self.spt_length + 1):, :].reshape([x.shape[0], self.F, 2])
build_P_prime = self.GridGenerator(batch_C_prime, self.I_r_size)
build_P_prime_reshape = build_P_prime.reshape([build_P_prime.shape[0],
self.I_r_size[0],
self.I_r_size[1],
2])
else: # SPIN w.o. AIN
sp_weight = sp_weight_fusion[:, :self.spt_length, :]
lambda_color, offsets = None, None
if self.stn:
batch_C_prime = sp_weight_fusion[:, self.spt_length:, :].reshape([x.shape[0], self.F, 2])
build_P_prime = self.GridGenerator(batch_C_prime, self.I_r_size)
build_P_prime_reshape = build_P_prime.reshape([build_P_prime.shape[0],
self.I_r_size[0],
self.I_r_size[1],
2])
x = self.sp_net(x, sp_weight, offsets, lambda_color)
if self.stn:
x = F.grid_sample(x=x, grid=build_P_prime_reshape, padding_mode='border')
return x
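# --- illustrative usage sketch (not part of the original commit) ---
# The import path below is an assumption (the file is added alongside tps.py in
# the transforms package); input is expected in [-1, 1], since the forward pass
# rescales it with (x + 1) * 0.5.
import paddle
from ppocr.modeling.transforms.gaspin_transformer import GA_SPIN_Transformer

spin = GA_SPIN_Transformer(in_channels=1, I_r_size=(32, 100), offsets=True, stn=True)
spin.eval()
x = paddle.rand([1, 1, 32, 100]) * 2 - 1    # dummy grayscale batch
with paddle.no_grad():
    y = spin(x)
print(y.shape)                              # rectified image, same size as input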
...@@ -308,3 +308,81 @@ class Const(object): ...@@ -308,3 +308,81 @@ class Const(object):
end_lr=self.learning_rate, end_lr=self.learning_rate,
last_epoch=self.last_epoch) last_epoch=self.last_epoch)
return learning_rate return learning_rate
class DecayLearningRate(object):
"""
DecayLearningRate learning rate decay
new_lr = (lr - end_lr) * (1 - epoch/decay_steps)**power + end_lr
Args:
learning_rate(float): initial learning rate
step_each_epoch(int): steps each epoch
epochs(int): total training epochs
factor(float): Power of the polynomial; should be greater than 0.0 to get learning rate decay. Default: 0.9
end_lr(float): The minimum final learning rate. Default: 0.0.
"""
def __init__(self,
learning_rate,
step_each_epoch,
epochs,
factor=0.9,
end_lr=0,
**kwargs):
super(DecayLearningRate, self).__init__()
self.learning_rate = learning_rate
self.epochs = epochs + 1
self.factor = factor
self.end_lr = 0
self.decay_steps = step_each_epoch * epochs
def __call__(self):
learning_rate = lr.PolynomialDecay(
learning_rate=self.learning_rate,
decay_steps=self.decay_steps,
power=self.factor,
end_lr=self.end_lr)
return learning_rate
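# --- illustrative sketch (not part of the original commit) ---
# The schedule above is Paddle's PolynomialDecay, i.e.
#     new_lr = (lr - end_lr) * (1 - step / decay_steps) ** power + end_lr
# Worked with hypothetical numbers (lr=0.007, end_lr=0, power=0.9, decay_steps=100):
lr0, end_lr, power, decay_steps = 0.007, 0.0, 0.9, 100
for step in (0, 50, 100):
    print(step, (lr0 - end_lr) * (1 - step / decay_steps) ** power + end_lr)
# -> 0.007 at step 0, ~0.00375 at step 50, 0.0 at step 100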
class MultiStepDecay(object):
"""
Piecewise learning rate decay
Args:
step_each_epoch(int): steps each epoch
learning_rate (float): The initial learning rate. It is a python float number.
milestones (list[int]): the epochs at which the learning rate is decayed.
gamma (float, optional): The ratio by which the learning rate is reduced: ``new_lr = origin_lr * gamma``.
It should be less than 1.0. Default: 0.1.
last_epoch (int, optional): The index of last epoch. Can be set to restart training. Default: -1, means initial learning rate.
"""
def __init__(self,
learning_rate,
milestones,
step_each_epoch,
gamma,
warmup_epoch=0,
last_epoch=-1,
**kwargs):
super(MultiStepDecay, self).__init__()
self.milestones = [step_each_epoch * e for e in milestones]
self.learning_rate = learning_rate
self.gamma = gamma
self.last_epoch = last_epoch
self.warmup_epoch = round(warmup_epoch * step_each_epoch)
def __call__(self):
learning_rate = lr.MultiStepDecay(
learning_rate=self.learning_rate,
milestones=self.milestones,
gamma=self.gamma,
last_epoch=self.last_epoch)
if self.warmup_epoch > 0:
learning_rate = lr.LinearWarmup(
learning_rate=learning_rate,
warmup_steps=self.warmup_epoch,
start_lr=0.0,
end_lr=self.learning_rate,
last_epoch=self.last_epoch)
return learning_rate
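# --- note (not part of the original commit) ---
# With warmup_epoch > 0 the piecewise schedule above is wrapped in LinearWarmup:
# the LR ramps linearly from 0.0 to `learning_rate` over the warmup steps, and
# only afterwards drops by `gamma` at each milestone (milestones are given in
# epochs and converted to steps via step_each_epoch).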
...@@ -26,12 +26,14 @@ from .east_postprocess import EASTPostProcess ...@@ -26,12 +26,14 @@ from .east_postprocess import EASTPostProcess
from .sast_postprocess import SASTPostProcess from .sast_postprocess import SASTPostProcess
from .fce_postprocess import FCEPostProcess from .fce_postprocess import FCEPostProcess
from .rec_postprocess import CTCLabelDecode, AttnLabelDecode, SRNLabelDecode, \ from .rec_postprocess import CTCLabelDecode, AttnLabelDecode, SRNLabelDecode, \
DistillationCTCLabelDecode, TableLabelDecode, NRTRLabelDecode, SARLabelDecode, \ DistillationCTCLabelDecode, NRTRLabelDecode, SARLabelDecode, \
SEEDLabelDecode, PRENLabelDecode, ViTSTRLabelDecode, ABINetLabelDecode SEEDLabelDecode, PRENLabelDecode, ViTSTRLabelDecode, ABINetLabelDecode, \
SPINLabelDecode
from .cls_postprocess import ClsPostProcess from .cls_postprocess import ClsPostProcess
from .pg_postprocess import PGPostProcess from .pg_postprocess import PGPostProcess
from .vqa_token_ser_layoutlm_postprocess import VQASerTokenLayoutLMPostProcess from .vqa_token_ser_layoutlm_postprocess import VQASerTokenLayoutLMPostProcess
from .vqa_token_re_layoutlm_postprocess import VQAReTokenLayoutLMPostProcess from .vqa_token_re_layoutlm_postprocess import VQAReTokenLayoutLMPostProcess
from .table_postprocess import TableMasterLabelDecode, TableLabelDecode
def build_post_process(config, global_config=None): def build_post_process(config, global_config=None):
...@@ -42,7 +44,8 @@ def build_post_process(config, global_config=None): ...@@ -42,7 +44,8 @@ def build_post_process(config, global_config=None):
'DistillationDBPostProcess', 'NRTRLabelDecode', 'SARLabelDecode', 'DistillationDBPostProcess', 'NRTRLabelDecode', 'SARLabelDecode',
'SEEDLabelDecode', 'VQASerTokenLayoutLMPostProcess', 'SEEDLabelDecode', 'VQASerTokenLayoutLMPostProcess',
'VQAReTokenLayoutLMPostProcess', 'PRENLabelDecode', 'VQAReTokenLayoutLMPostProcess', 'PRENLabelDecode',
'DistillationSARLabelDecode', 'ViTSTRLabelDecode', 'ABINetLabelDecode' 'DistillationSARLabelDecode', 'ViTSTRLabelDecode', 'ABINetLabelDecode',
'TableMasterLabelDecode', 'SPINLabelDecode'
] ]
if config['name'] == 'PSEPostProcess': if config['name'] == 'PSEPostProcess':
......
...@@ -38,6 +38,7 @@ class DBPostProcess(object): ...@@ -38,6 +38,7 @@ class DBPostProcess(object):
unclip_ratio=2.0, unclip_ratio=2.0,
use_dilation=False, use_dilation=False,
score_mode="fast", score_mode="fast",
use_polygon=False,
**kwargs): **kwargs):
self.thresh = thresh self.thresh = thresh
self.box_thresh = box_thresh self.box_thresh = box_thresh
...@@ -45,6 +46,7 @@ class DBPostProcess(object): ...@@ -45,6 +46,7 @@ class DBPostProcess(object):
self.unclip_ratio = unclip_ratio self.unclip_ratio = unclip_ratio
self.min_size = 3 self.min_size = 3
self.score_mode = score_mode self.score_mode = score_mode
self.use_polygon = use_polygon
assert score_mode in [ assert score_mode in [
"slow", "fast" "slow", "fast"
], "Score mode must be in [slow, fast] but got: {}".format(score_mode) ], "Score mode must be in [slow, fast] but got: {}".format(score_mode)
...@@ -52,6 +54,53 @@ class DBPostProcess(object): ...@@ -52,6 +54,53 @@ class DBPostProcess(object):
self.dilation_kernel = None if not use_dilation else np.array( self.dilation_kernel = None if not use_dilation else np.array(
[[1, 1], [1, 1]]) [[1, 1], [1, 1]])
def polygons_from_bitmap(self, pred, _bitmap, dest_width, dest_height):
'''
_bitmap: single map with shape (1, H, W),
whose values are binarized as {0, 1}
'''
bitmap = _bitmap
height, width = bitmap.shape
boxes = []
scores = []
contours, _ = cv2.findContours((bitmap * 255).astype(np.uint8),
cv2.RETR_LIST, cv2.CHAIN_APPROX_SIMPLE)
for contour in contours[:self.max_candidates]:
epsilon = 0.002 * cv2.arcLength(contour, True)
approx = cv2.approxPolyDP(contour, epsilon, True)
points = approx.reshape((-1, 2))
if points.shape[0] < 4:
continue
score = self.box_score_fast(pred, points.reshape(-1, 2))
if self.box_thresh > score:
continue
if points.shape[0] > 2:
box = self.unclip(points, self.unclip_ratio)
if len(box) > 1:
continue
else:
continue
box = box.reshape(-1, 2)
_, sside = self.get_mini_boxes(box.reshape((-1, 1, 2)))
if sside < self.min_size + 2:
continue
box = np.array(box)
box[:, 0] = np.clip(
np.round(box[:, 0] / width * dest_width), 0, dest_width)
box[:, 1] = np.clip(
np.round(box[:, 1] / height * dest_height), 0, dest_height)
boxes.append(box.tolist())
scores.append(score)
return boxes, scores
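# --- note (not part of the original commit) ---
# polygons_from_bitmap mirrors boxes_from_bitmap below, but approximates each
# contour with cv2.approxPolyDP, so detections can have an arbitrary number of
# vertices; it is selected through the new `use_polygon` option in __call__.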
def boxes_from_bitmap(self, pred, _bitmap, dest_width, dest_height): def boxes_from_bitmap(self, pred, _bitmap, dest_width, dest_height):
''' '''
_bitmap: single map with shape (1, H, W), _bitmap: single map with shape (1, H, W),
...@@ -85,7 +134,7 @@ class DBPostProcess(object): ...@@ -85,7 +134,7 @@ class DBPostProcess(object):
if self.box_thresh > score: if self.box_thresh > score:
continue continue
box = self.unclip(points).reshape(-1, 1, 2) box = self.unclip(points, self.unclip_ratio).reshape(-1, 1, 2)
box, sside = self.get_mini_boxes(box) box, sside = self.get_mini_boxes(box)
if sside < self.min_size + 2: if sside < self.min_size + 2:
continue continue
...@@ -99,8 +148,7 @@ class DBPostProcess(object): ...@@ -99,8 +148,7 @@ class DBPostProcess(object):
scores.append(score) scores.append(score)
return np.array(boxes, dtype=np.int16), scores return np.array(boxes, dtype=np.int16), scores
def unclip(self, box): def unclip(self, box, unclip_ratio):
unclip_ratio = self.unclip_ratio
poly = Polygon(box) poly = Polygon(box)
distance = poly.area * unclip_ratio / poly.length distance = poly.area * unclip_ratio / poly.length
offset = pyclipper.PyclipperOffset() offset = pyclipper.PyclipperOffset()
...@@ -185,8 +233,12 @@ class DBPostProcess(object): ...@@ -185,8 +233,12 @@ class DBPostProcess(object):
self.dilation_kernel) self.dilation_kernel)
else: else:
mask = segmentation[batch_index] mask = segmentation[batch_index]
boxes, scores = self.boxes_from_bitmap(pred[batch_index], mask, if self.use_polygon is True:
src_w, src_h) boxes, scores = self.polygons_from_bitmap(pred[batch_index],
mask, src_w, src_h)
else:
boxes, scores = self.boxes_from_bitmap(pred[batch_index], mask,
src_w, src_h)
boxes_batch.append({'points': boxes}) boxes_batch.append({'points': boxes})
return boxes_batch return boxes_batch
...@@ -202,6 +254,7 @@ class DistillationDBPostProcess(object): ...@@ -202,6 +254,7 @@ class DistillationDBPostProcess(object):
unclip_ratio=1.5, unclip_ratio=1.5,
use_dilation=False, use_dilation=False,
score_mode="fast", score_mode="fast",
use_polygon=False,
**kwargs): **kwargs):
self.model_name = model_name self.model_name = model_name
self.key = key self.key = key
...@@ -211,7 +264,8 @@ class DistillationDBPostProcess(object): ...@@ -211,7 +264,8 @@ class DistillationDBPostProcess(object):
max_candidates=max_candidates, max_candidates=max_candidates,
unclip_ratio=unclip_ratio, unclip_ratio=unclip_ratio,
use_dilation=use_dilation, use_dilation=use_dilation,
score_mode=score_mode) score_mode=score_mode,
use_polygon=use_polygon)
def __call__(self, predicts, shape_list): def __call__(self, predicts, shape_list):
results = {} results = {}
......
...@@ -58,6 +58,8 @@ class PSEPostProcess(object): ...@@ -58,6 +58,8 @@ class PSEPostProcess(object):
kernels = (pred > self.thresh).astype('float32') kernels = (pred > self.thresh).astype('float32')
text_mask = kernels[:, 0, :, :] text_mask = kernels[:, 0, :, :]
text_mask = paddle.unsqueeze(text_mask, axis=1)
kernels[:, 0:, :, :] = kernels[:, 0:, :, :] * text_mask kernels[:, 0:, :, :] = kernels[:, 0:, :, :] * text_mask
score = score.numpy() score = score.numpy()
......
...@@ -380,146 +380,6 @@ class SRNLabelDecode(BaseRecLabelDecode): ...@@ -380,146 +380,6 @@ class SRNLabelDecode(BaseRecLabelDecode):
return idx return idx
class TableLabelDecode(object):
""" """
def __init__(self, character_dict_path, **kwargs):
list_character, list_elem = self.load_char_elem_dict(
character_dict_path)
list_character = self.add_special_char(list_character)
list_elem = self.add_special_char(list_elem)
self.dict_character = {}
self.dict_idx_character = {}
for i, char in enumerate(list_character):
self.dict_idx_character[i] = char
self.dict_character[char] = i
self.dict_elem = {}
self.dict_idx_elem = {}
for i, elem in enumerate(list_elem):
self.dict_idx_elem[i] = elem
self.dict_elem[elem] = i
def load_char_elem_dict(self, character_dict_path):
list_character = []
list_elem = []
with open(character_dict_path, "rb") as fin:
lines = fin.readlines()
substr = lines[0].decode('utf-8').strip("\n").strip("\r\n").split(
"\t")
character_num = int(substr[0])
elem_num = int(substr[1])
for cno in range(1, 1 + character_num):
character = lines[cno].decode('utf-8').strip("\n").strip("\r\n")
list_character.append(character)
for eno in range(1 + character_num, 1 + character_num + elem_num):
elem = lines[eno].decode('utf-8').strip("\n").strip("\r\n")
list_elem.append(elem)
return list_character, list_elem
def add_special_char(self, list_character):
self.beg_str = "sos"
self.end_str = "eos"
list_character = [self.beg_str] + list_character + [self.end_str]
return list_character
def __call__(self, preds):
structure_probs = preds['structure_probs']
loc_preds = preds['loc_preds']
if isinstance(structure_probs, paddle.Tensor):
structure_probs = structure_probs.numpy()
if isinstance(loc_preds, paddle.Tensor):
loc_preds = loc_preds.numpy()
structure_idx = structure_probs.argmax(axis=2)
structure_probs = structure_probs.max(axis=2)
structure_str, structure_pos, result_score_list, result_elem_idx_list = self.decode(
structure_idx, structure_probs, 'elem')
res_html_code_list = []
res_loc_list = []
batch_num = len(structure_str)
for bno in range(batch_num):
res_loc = []
for sno in range(len(structure_str[bno])):
text = structure_str[bno][sno]
if text in ['<td>', '<td']:
pos = structure_pos[bno][sno]
res_loc.append(loc_preds[bno, pos])
res_html_code = ''.join(structure_str[bno])
res_loc = np.array(res_loc)
res_html_code_list.append(res_html_code)
res_loc_list.append(res_loc)
return {
'res_html_code': res_html_code_list,
'res_loc': res_loc_list,
'res_score_list': result_score_list,
'res_elem_idx_list': result_elem_idx_list,
'structure_str_list': structure_str
}
def decode(self, text_index, structure_probs, char_or_elem):
"""convert text-label into text-index.
"""
if char_or_elem == "char":
current_dict = self.dict_idx_character
else:
current_dict = self.dict_idx_elem
ignored_tokens = self.get_ignored_tokens('elem')
beg_idx, end_idx = ignored_tokens
result_list = []
result_pos_list = []
result_score_list = []
result_elem_idx_list = []
batch_size = len(text_index)
for batch_idx in range(batch_size):
char_list = []
elem_pos_list = []
elem_idx_list = []
score_list = []
for idx in range(len(text_index[batch_idx])):
tmp_elem_idx = int(text_index[batch_idx][idx])
if idx > 0 and tmp_elem_idx == end_idx:
break
if tmp_elem_idx in ignored_tokens:
continue
char_list.append(current_dict[tmp_elem_idx])
elem_pos_list.append(idx)
score_list.append(structure_probs[batch_idx, idx])
elem_idx_list.append(tmp_elem_idx)
result_list.append(char_list)
result_pos_list.append(elem_pos_list)
result_score_list.append(score_list)
result_elem_idx_list.append(elem_idx_list)
return result_list, result_pos_list, result_score_list, result_elem_idx_list
def get_ignored_tokens(self, char_or_elem):
beg_idx = self.get_beg_end_flag_idx("beg", char_or_elem)
end_idx = self.get_beg_end_flag_idx("end", char_or_elem)
return [beg_idx, end_idx]
def get_beg_end_flag_idx(self, beg_or_end, char_or_elem):
if char_or_elem == "char":
if beg_or_end == "beg":
idx = self.dict_character[self.beg_str]
elif beg_or_end == "end":
idx = self.dict_character[self.end_str]
else:
assert False, "Unsupport type %s in get_beg_end_flag_idx of char" \
% beg_or_end
elif char_or_elem == "elem":
if beg_or_end == "beg":
idx = self.dict_elem[self.beg_str]
elif beg_or_end == "end":
idx = self.dict_elem[self.end_str]
else:
assert False, "Unsupport type %s in get_beg_end_flag_idx of elem" \
% beg_or_end
else:
assert False, "Unsupport type %s in char_or_elem" \
% char_or_elem
return idx
class SARLabelDecode(BaseRecLabelDecode): class SARLabelDecode(BaseRecLabelDecode):
""" Convert between text-label and text-index """ """ Convert between text-label and text-index """
...@@ -807,3 +667,18 @@ class ABINetLabelDecode(NRTRLabelDecode): ...@@ -807,3 +667,18 @@ class ABINetLabelDecode(NRTRLabelDecode):
def add_special_char(self, dict_character): def add_special_char(self, dict_character):
dict_character = ['</s>'] + dict_character dict_character = ['</s>'] + dict_character
return dict_character return dict_character
class SPINLabelDecode(AttnLabelDecode):
""" Convert between text-label and text-index """
def __init__(self, character_dict_path=None, use_space_char=False,
**kwargs):
super(SPINLabelDecode, self).__init__(character_dict_path,
use_space_char)
def add_special_char(self, dict_character):
self.beg_str = "sos"
self.end_str = "eos"
dict_character = dict_character
dict_character = [self.beg_str] + [self.end_str] + dict_character
return dict_character
\ No newline at end of file
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
import paddle
from .rec_postprocess import AttnLabelDecode
class TableLabelDecode(AttnLabelDecode):
""" """
def __init__(self, character_dict_path, **kwargs):
super(TableLabelDecode, self).__init__(character_dict_path)
self.td_token = ['<td>', '<td', '<eb></eb>', '<td></td>']
def __call__(self, preds, batch=None):
structure_probs = preds['structure_probs']
bbox_preds = preds['loc_preds']
if isinstance(structure_probs, paddle.Tensor):
structure_probs = structure_probs.numpy()
if isinstance(bbox_preds, paddle.Tensor):
bbox_preds = bbox_preds.numpy()
shape_list = batch[-1]
result = self.decode(structure_probs, bbox_preds, shape_list)
if len(batch) == 1: # only contains shape
return result
label_decode_result = self.decode_label(batch)
return result, label_decode_result
def decode(self, structure_probs, bbox_preds, shape_list):
"""convert text-label into text-index.
"""
ignored_tokens = self.get_ignored_tokens()
end_idx = self.dict[self.end_str]
structure_idx = structure_probs.argmax(axis=2)
structure_probs = structure_probs.max(axis=2)
structure_batch_list = []
bbox_batch_list = []
batch_size = len(structure_idx)
for batch_idx in range(batch_size):
structure_list = []
bbox_list = []
score_list = []
for idx in range(len(structure_idx[batch_idx])):
char_idx = int(structure_idx[batch_idx][idx])
if idx > 0 and char_idx == end_idx:
break
if char_idx in ignored_tokens:
continue
text = self.character[char_idx]
if text in self.td_token:
bbox = bbox_preds[batch_idx, idx]
bbox = self._bbox_decode(bbox, shape_list[batch_idx])
bbox_list.append(bbox)
structure_list.append(text)
score_list.append(structure_probs[batch_idx, idx])
structure_batch_list.append([structure_list, np.mean(score_list)])
bbox_batch_list.append(np.array(bbox_list))
result = {
'bbox_batch_list': bbox_batch_list,
'structure_batch_list': structure_batch_list,
}
return result
def decode_label(self, batch):
"""convert text-label into text-index.
"""
structure_idx = batch[1]
gt_bbox_list = batch[2]
shape_list = batch[-1]
ignored_tokens = self.get_ignored_tokens()
end_idx = self.dict[self.end_str]
structure_batch_list = []
bbox_batch_list = []
batch_size = len(structure_idx)
for batch_idx in range(batch_size):
structure_list = []
bbox_list = []
for idx in range(len(structure_idx[batch_idx])):
char_idx = int(structure_idx[batch_idx][idx])
if idx > 0 and char_idx == end_idx:
break
if char_idx in ignored_tokens:
continue
structure_list.append(self.character[char_idx])
bbox = gt_bbox_list[batch_idx][idx]
if bbox.sum() != 0:
bbox = self._bbox_decode(bbox, shape_list[batch_idx])
bbox_list.append(bbox)
structure_batch_list.append(structure_list)
bbox_batch_list.append(bbox_list)
result = {
'bbox_batch_list': bbox_batch_list,
'structure_batch_list': structure_batch_list,
}
return result
def _bbox_decode(self, bbox, shape):
h, w, ratio_h, ratio_w, pad_h, pad_w = shape
src_h = h / ratio_h
src_w = w / ratio_w
bbox[0::2] *= src_w
bbox[1::2] *= src_h
return bbox
class TableMasterLabelDecode(TableLabelDecode):
""" """
def __init__(self, character_dict_path, box_shape='ori', **kwargs):
super(TableMasterLabelDecode, self).__init__(character_dict_path)
self.box_shape = box_shape
assert box_shape in [
'ori', 'pad'
], 'The shape used for box normalization must be ori or pad'
def add_special_char(self, dict_character):
self.beg_str = '<SOS>'
self.end_str = '<EOS>'
self.unknown_str = '<UKN>'
self.pad_str = '<PAD>'
dict_character = dict_character
dict_character = dict_character + [
self.unknown_str, self.beg_str, self.end_str, self.pad_str
]
return dict_character
def get_ignored_tokens(self):
pad_idx = self.dict[self.pad_str]
start_idx = self.dict[self.beg_str]
end_idx = self.dict[self.end_str]
unknown_idx = self.dict[self.unknown_str]
return [start_idx, end_idx, pad_idx, unknown_idx]
def _bbox_decode(self, bbox, shape):
h, w, ratio_h, ratio_w, pad_h, pad_w = shape
if self.box_shape == 'pad':
h, w = pad_h, pad_w
bbox[0::2] *= w
bbox[1::2] *= h
bbox[0::2] /= ratio_w
bbox[1::2] /= ratio_h
return bbox
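# --- illustrative sketch (not part of the original commit) ---
# How the two _bbox_decode variants map a normalized box back to the source
# image, with a hypothetical shape record (h, w, ratio_h, ratio_w, pad_h, pad_w):
import numpy as np
bbox = np.array([0.1, 0.2, 0.5, 0.6])                 # normalized x1, y1, x2, y2
h, w, ratio_h, ratio_w, pad_h, pad_w = 240, 320, 0.5, 0.5, 244, 324
# TableLabelDecode: scale by the original (pre-resize) image size
src_h, src_w = h / ratio_h, w / ratio_w               # 480 x 640
print(bbox[0::2] * src_w, bbox[1::2] * src_h)
# TableMasterLabelDecode with box_shape='pad': scale by the padded size,
# then undo the resize ratios
print(bbox[0::2] * pad_w / ratio_w, bbox[1::2] * pad_h / ratio_h)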
...@@ -41,11 +41,13 @@ class VQASerTokenLayoutLMPostProcess(object): ...@@ -41,11 +41,13 @@ class VQASerTokenLayoutLMPostProcess(object):
self.id2label_map_for_show[val] = key self.id2label_map_for_show[val] = key
def __call__(self, preds, batch=None, *args, **kwargs): def __call__(self, preds, batch=None, *args, **kwargs):
if isinstance(preds, tuple):
preds = preds[0]
if isinstance(preds, paddle.Tensor): if isinstance(preds, paddle.Tensor):
preds = preds.numpy() preds = preds.numpy()
if batch is not None: if batch is not None:
return self._metric(preds, batch[1]) return self._metric(preds, batch[5])
else: else:
return self._infer(preds, **kwargs) return self._infer(preds, **kwargs)
...@@ -63,11 +65,11 @@ class VQASerTokenLayoutLMPostProcess(object): ...@@ -63,11 +65,11 @@ class VQASerTokenLayoutLMPostProcess(object):
j]]) j]])
return decode_out_list, label_decode_out_list return decode_out_list, label_decode_out_list
def _infer(self, preds, attention_masks, segment_offset_ids, ocr_infos): def _infer(self, preds, segment_offset_ids, ocr_infos):
results = [] results = []
for pred, attention_mask, segment_offset_id, ocr_info in zip( for pred, segment_offset_id, ocr_info in zip(preds, segment_offset_ids,
preds, attention_masks, segment_offset_ids, ocr_infos): ocr_infos):
pred = np.argmax(pred, axis=1) pred = np.argmax(pred, axis=1)
pred = [self.id2label_map[idx] for idx in pred] pred = [self.id2label_map[idx] for idx in pred]
......
0
1
2
3
4
5
6
7
8
9
a
b
c
d
e
f
g
h
i
j
k
l
m
n
o
p
q
r
s
t
u
v
w
x
y
z
:
(
'
-
,
%
>
.
[
?
)
"
=
_
*
]
;
&
+
$
@
/
|
!
<
#
`
{
~
\
}
^
\ No newline at end of file
<thead>
<tr>
<td></td>
</tr>
</thead>
<tbody>
<eb></eb>
</tbody>
<td
colspan="5"
>
</td>
colspan="2"
colspan="3"
<eb2></eb2>
<eb1></eb1>
rowspan="2"
colspan="4"
colspan="6"
rowspan="3"
colspan="9"
colspan="10"
colspan="7"
rowspan="4"
rowspan="5"
rowspan="9"
colspan="8"
rowspan="8"
rowspan="6"
rowspan="7"
rowspan="10"
<eb3></eb3>
<eb4></eb4>
<eb5></eb5>
<eb6></eb6>
<eb7></eb7>
<eb8></eb8>
<eb9></eb9>
<eb10></eb10>
277 28 1267 1186
<b>
V
a
r
i
b
l
e
</b>
H
z
d
t
o
9
5
%
C
I
<i>
p
</i>
v
u
*
A
g
(
m
n
)
0
.
7
1
6
>
8
3
2
G
4
M
F
T
y
f
s
L
w
c
U
h
D
S
Q
R
x
P
-
E
O
/
k
,
+
N
K
q
[
]
<
<sup>
</sup>
μ
±
J
j
W
_
Δ
B
:
Y
α
λ
;
<sub>
</sub>
?
=
°
#
̊
̈
̂
Z
X
β
'
~
@
"
γ
&
χ
σ
§
|
×
$
\
π
®
^
<underline>
</underline>
́
·
£
φ
Ψ
ß
η
̃
Φ
ρ
̄
δ
̧
Ω
{
}
̀
ø
κ
ε
¥
`
ω
Σ
Β
̸
Χ
Α
ψ
ǂ
ζ
!
Γ
θ
υ
τ
Ø
©
С
˂
ɛ
¢
˃
­
Π
̌
<overline>
</overline>
¤
̆
ξ
÷

ι
ν
<strike>
</strike>
«
»
ł
ı
Θ
̇
æ
ʹ
ˆ
̨
Ι
Λ
А
<thead> <thead>
<tr> <tr>
<td> <td>
...@@ -303,2457 +25,4 @@ $ ...@@ -303,2457 +25,4 @@ $
rowspan="8" rowspan="8"
rowspan="6" rowspan="6"
rowspan="7" rowspan="7"
rowspan="10" rowspan="10"
0 2924682 \ No newline at end of file
1 3405345
2 2363468
3 2709165
4 4078680
5 3250792
6 1923159
7 1617890
8 1450532
9 1717624
10 1477550
11 1489223
12 915528
13 819193
14 593660
15 518924
16 682065
17 494584
18 400591
19 396421
20 340994
21 280688
22 250328
23 226786
24 199927
25 182707
26 164629
27 141613
28 127554
29 116286
30 107682
31 96367
32 88002
33 79234
34 72186
35 65921
36 60374
37 55976
38 52166
39 47414
40 44932
41 41279
42 38232
43 35463
44 33703
45 30557
46 29639
47 27000
48 25447
49 23186
50 22093
51 20412
52 19844
53 18261
54 17561
55 16499
56 15597
57 14558
58 14372
59 13445
60 13514
61 12058
62 11145
63 10767
64 10370
65 9630
66 9337
67 8881
68 8727
69 8060
70 7994
71 7740
72 7189
73 6729
74 6749
75 6548
76 6321
77 5957
78 5740
79 5407
80 5370
81 5035
82 4921
83 4656
84 4600
85 4519
86 4277
87 4023
88 3939
89 3910
90 3861
91 3560
92 3483
93 3406
94 3346
95 3229
96 3122
97 3086
98 3001
99 2884
100 2822
101 2677
102 2670
103 2610
104 2452
105 2446
106 2400
107 2300
108 2316
109 2196
110 2089
111 2083
112 2041
113 1881
114 1838
115 1896
116 1795
117 1786
118 1743
119 1765
120 1750
121 1683
122 1563
123 1499
124 1513
125 1462
126 1388
127 1441
128 1417
129 1392
130 1306
131 1321
132 1274
133 1294
134 1240
135 1126
136 1157
137 1130
138 1084
139 1130
140 1083
141 1040
142 980
143 1031
144 974
145 980
146 932
147 898
148 960
149 907
150 852
151 912
152 859
153 847
154 876
155 792
156 791
157 765
158 788
159 787
160 744
161 673
162 683
163 697
164 666
165 680
166 632
167 677
168 657
169 618
170 587
171 585
172 567
173 549
174 562
175 548
176 542
177 539
178 542
179 549
180 547
181 526
182 525
183 514
184 512
185 505
186 515
187 467
188 475
189 458
190 435
191 443
192 427
193 424
194 404
195 389
196 429
197 404
198 386
199 351
200 388
201 408
202 361
203 346
204 324
205 361
206 363
207 364
208 323
209 336
210 342
211 315
212 325
213 328
214 314
215 327
216 320
217 300
218 295
219 315
220 310
221 295
222 275
223 248
224 274
225 232
226 293
227 259
228 286
229 263
230 242
231 214
232 261
233 231
234 211
235 250
236 233
237 206
238 224
239 210
240 233
241 223
242 216
243 222
244 207
245 212
246 196
247 205
248 201
249 202
250 211
251 201
252 215
253 179
254 163
255 179
256 191
257 188
258 196
259 150
260 154
261 176
262 211
263 166
264 171
265 165
266 149
267 182
268 159
269 161
270 164
271 161
272 141
273 151
274 127
275 129
276 142
277 158
278 148
279 135
280 127
281 134
282 138
283 131
284 126
285 125
286 130
287 126
288 135
289 125
290 135
291 131
292 95
293 135
294 106
295 117
296 136
297 128
298 128
299 118
300 109
301 112
302 117
303 108
304 120
305 100
306 95
307 108
308 112
309 77
310 120
311 104
312 109
313 89
314 98
315 82
316 98
317 93
318 77
319 93
320 77
321 98
322 93
323 86
324 89
325 73
326 70
327 71
328 77
329 87
330 77
331 93
332 100
333 83
334 72
335 74
336 69
337 77
338 68
339 78
340 90
341 98
342 75
343 80
344 63
345 71
346 83
347 66
348 71
349 70
350 62
351 62
352 59
353 63
354 62
355 52
356 64
357 64
358 56
359 49
360 57
361 63
362 60
363 68
364 62
365 55
366 54
367 40
368 75
369 70
370 53
371 58
372 57
373 55
374 69
375 57
376 53
377 43
378 45
379 47
380 56
381 51
382 59
383 51
384 43
385 34
386 57
387 49
388 39
389 46
390 48
391 43
392 40
393 54
394 50
395 41
396 43
397 33
398 27
399 49
400 44
401 44
402 38
403 30
404 32
405 37
406 39
407 42
408 53
409 39
410 34
411 31
412 32
413 52
414 27
415 41
416 34
417 36
418 50
419 35
420 32
421 33
422 45
423 35
424 40
425 29
426 41
427 40
428 39
429 32
430 31
431 34
432 29
433 27
434 26
435 22
436 34
437 28
438 30
439 38
440 35
441 36
442 36
443 27
444 24
445 33
446 31
447 25
448 33
449 27
450 32
451 46
452 31
453 35
454 35
455 34
456 26
457 21
458 25
459 26
460 24
461 27
462 33
463 30
464 35
465 21
466 32
467 19
468 27
469 16
470 28
471 26
472 27
473 26
474 25
475 25
476 27
477 20
478 28
479 22
480 23
481 16
482 25
483 27
484 19
485 23
486 19
487 15
488 15
489 23
490 24
491 19
492 20
493 18
494 17
495 30
496 28
497 20
498 29
499 17
500 19
501 21
502 15
503 24
504 15
505 19
506 25
507 16
508 23
509 26
510 21
511 15
512 12
513 16
514 18
515 24
516 26
517 18
518 8
519 25
520 14
521 8
522 24
523 20
524 18
525 15
526 13
527 17
528 18
529 22
530 21
531 9
532 16
533 17
534 13
535 17
536 15
537 13
538 20
539 13
540 19
541 29
542 10
543 8
544 18
545 13
546 9
547 18
548 10
549 18
550 18
551 9
552 9
553 15
554 13
555 15
556 14
557 14
558 18
559 8
560 13
561 9
562 7
563 12
564 6
565 9
566 9
567 18
568 9
569 10
570 13
571 14
572 13
573 21
574 8
575 16
576 12
577 9
578 16
579 17
580 22
581 6
582 14
583 13
584 15
585 11
586 13
587 5
588 12
589 13
590 15
591 13
592 15
593 12
594 7
595 18
596 12
597 13
598 13
599 13
600 12
601 12
602 10
603 11
604 6
605 6
606 2
607 9
608 8
609 12
610 9
611 12
612 13
613 12
614 14
615 9
616 8
617 9
618 14
619 13
620 12
621 6
622 8
623 8
624 8
625 12
626 8
627 7
628 5
629 8
630 12
631 6
632 10
633 10
634 7
635 8
636 9
637 6
638 9
639 4
640 12
641 4
642 3
643 11
644 10
645 6
646 12
647 12
648 4
649 4
650 9
651 8
652 6
653 5
654 14
655 10
656 11
657 8
658 5
659 5
660 9
661 13
662 4
663 5
664 9
665 11
666 12
667 7
668 13
669 2
670 1
671 7
672 7
673 7
674 10
675 9
676 6
677 5
678 7
679 6
680 3
681 3
682 4
683 9
684 8
685 5
686 3
687 11
688 9
689 2
690 6
691 5
692 9
693 5
694 6
695 5
696 9
697 8
698 3
699 7
700 5
701 9
702 8
703 7
704 2
705 3
706 7
707 6
708 6
709 10
710 2
711 10
712 6
713 7
714 5
715 6
716 4
717 6
718 8
719 4
720 6
721 7
722 5
723 7
724 3
725 10
726 10
727 3
728 7
729 7
730 5
731 2
732 1
733 5
734 1
735 5
736 6
737 2
738 2
739 3
740 7
741 2
742 7
743 4
744 5
745 4
746 5
747 3
748 1
749 4
750 4
751 2
752 4
753 6
754 6
755 6
756 3
757 2
758 5
759 5
760 3
761 4
762 2
763 1
764 8
765 3
766 4
767 3
768 1
769 5
770 3
771 3
772 4
773 4
774 1
775 3
776 2
777 2
778 3
779 3
780 1
781 4
782 3
783 4
784 6
785 3
786 5
787 4
788 2
789 4
790 5
791 4
792 6
794 4
795 1
796 1
797 4
798 2
799 3
800 3
801 1
802 5
803 5
804 3
805 3
806 3
807 4
808 4
809 2
811 5
812 4
813 6
814 3
815 2
816 2
817 3
818 5
819 3
820 1
821 1
822 4
823 3
824 4
825 8
826 3
827 5
828 5
829 3
830 6
831 3
832 4
833 8
834 5
835 3
836 3
837 2
838 4
839 2
840 1
841 3
842 2
843 1
844 3
846 4
847 4
848 3
849 3
850 2
851 3
853 1
854 4
855 4
856 2
857 4
858 1
859 2
860 5
861 1
862 1
863 4
864 2
865 2
867 5
868 1
869 4
870 1
871 1
872 1
873 2
875 5
876 3
877 1
878 3
879 3
880 3
881 2
882 1
883 6
884 2
885 2
886 1
887 1
888 3
889 2
890 2
891 3
892 1
893 3
894 1
895 5
896 1
897 3
899 2
900 2
902 1
903 2
904 4
905 4
906 3
907 1
908 1
909 2
910 5
911 2
912 3
914 1
915 1
916 2
918 2
919 2
920 4
921 4
922 1
923 1
924 4
925 5
926 1
928 2
929 1
930 1
931 1
932 1
933 1
934 2
935 1
936 1
937 1
938 2
939 1
941 1
942 4
944 2
945 2
946 2
947 1
948 1
950 1
951 2
953 1
954 2
955 1
956 1
957 2
958 1
960 3
962 4
963 1
964 1
965 3
966 2
967 2
968 1
969 3
970 3
972 1
974 4
975 3
976 3
977 2
979 2
980 1
981 1
983 5
984 1
985 3
986 1
987 2
988 4
989 2
991 2
992 2
993 1
994 1
996 2
997 2
998 1
999 3
1000 2
1001 1
1002 3
1003 3
1004 2
1005 3
1006 1
1007 2
1009 1
1011 1
1013 3
1014 1
1016 2
1017 1
1018 1
1019 1
1020 4
1021 1
1022 2
1025 1
1026 1
1027 2
1028 1
1030 1
1031 2
1032 4
1034 3
1035 2
1036 1
1038 1
1039 1
1040 1
1041 1
1042 2
1043 1
1044 2
1045 4
1048 1
1050 1
1051 1
1052 2
1054 1
1055 3
1056 2
1057 1
1059 1
1061 2
1063 1
1064 1
1065 1
1066 1
1067 1
1068 1
1069 2
1074 1
1075 1
1077 1
1078 1
1079 1
1082 1
1085 1
1088 1
1090 1
1091 1
1092 2
1094 2
1097 2
1098 1
1099 2
1101 2
1102 1
1104 1
1105 1
1107 1
1109 1
1111 2
1112 1
1114 2
1115 2
1116 2
1117 1
1118 1
1119 1
1120 1
1122 1
1123 1
1127 1
1128 3
1132 2
1138 3
1142 1
1145 4
1150 1
1153 2
1154 1
1158 1
1159 1
1163 1
1165 1
1169 2
1174 1
1176 1
1177 1
1178 2
1179 1
1180 2
1181 1
1182 1
1183 2
1185 1
1187 1
1191 2
1193 1
1195 3
1196 1
1201 3
1203 1
1206 1
1210 1
1213 1
1214 1
1215 2
1218 1
1220 1
1221 1
1225 1
1226 1
1233 2
1241 1
1243 1
1249 1
1250 2
1251 1
1254 1
1255 2
1260 1
1268 1
1270 1
1273 1
1274 1
1277 1
1284 1
1287 1
1291 1
1292 2
1294 1
1295 2
1297 1
1298 1
1301 1
1307 1
1308 3
1311 2
1313 1
1316 1
1321 1
1324 1
1325 1
1330 1
1333 1
1334 1
1338 2
1340 1
1341 1
1342 1
1343 1
1345 1
1355 1
1357 1
1360 2
1375 1
1376 1
1380 1
1383 1
1387 1
1389 1
1393 1
1394 1
1396 1
1398 1
1410 1
1414 1
1419 1
1425 1
1434 1
1435 1
1438 1
1439 1
1447 1
1455 2
1460 1
1461 1
1463 1
1466 1
1470 1
1473 1
1478 1
1480 1
1483 1
1484 1
1485 2
1492 2
1499 1
1509 1
1512 1
1513 1
1523 1
1524 1
1525 2
1529 1
1539 1
1544 1
1568 1
1584 1
1591 1
1598 1
1600 1
1604 1
1614 1
1617 1
1621 1
1622 1
1626 1
1638 1
1648 1
1658 1
1661 1
1679 1
1682 1
1693 1
1700 1
1705 1
1707 1
1722 1
1728 1
1758 1
1762 1
1763 1
1775 1
1776 1
1801 1
1810 1
1812 1
1827 1
1834 1
1846 1
1847 1
1848 1
1851 1
1862 1
1866 1
1877 2
1884 1
1888 1
1903 1
1912 1
1925 1
1938 1
1955 1
1998 1
2054 1
2058 1
2065 1
2069 1
2076 1
2089 1
2104 1
2111 1
2133 1
2138 1
2156 1
2204 1
2212 1
2237 1
2246 2
2298 1
2304 1
2360 1
2400 1
2481 1
2544 1
2586 1
2622 1
2666 1
2682 1
2725 1
2920 1
3997 1
4019 1
5211 1
12 19
14 1
16 401
18 2
20 421
22 557
24 625
26 50
28 4481
30 52
32 550
34 5840
36 4644
38 87
40 5794
41 33
42 571
44 11805
46 4711
47 7
48 597
49 12
50 678
51 2
52 14715
53 3
54 7322
55 3
56 508
57 39
58 3486
59 11
60 8974
61 45
62 1276
63 4
64 15693
65 15
66 657
67 13
68 6409
69 10
70 3188
71 25
72 1889
73 27
74 10370
75 9
76 12432
77 23
78 520
79 15
80 1534
81 29
82 2944
83 23
84 12071
85 36
86 1502
87 10
88 10978
89 11
90 889
91 16
92 4571
93 17
94 7855
95 21
96 2271
97 33
98 1423
99 15
100 11096
101 21
102 4082
103 13
104 5442
105 25
106 2113
107 26
108 3779
109 43
110 1294
111 29
112 7860
113 29
114 4965
115 22
116 7898
117 25
118 1772
119 28
120 1149
121 38
122 1483
123 32
124 10572
125 25
126 1147
127 31
128 1699
129 22
130 5533
131 22
132 4669
133 34
134 3777
135 10
136 5412
137 21
138 855
139 26
140 2485
141 46
142 1970
143 27
144 6565
145 40
146 933
147 15
148 7923
149 16
150 735
151 23
152 1111
153 33
154 3714
155 27
156 2445
157 30
158 3367
159 10
160 4646
161 27
162 990
163 23
164 5679
165 25
166 2186
167 17
168 899
169 32
170 1034
171 22
172 6185
173 32
174 2685
175 17
176 1354
177 38
178 1460
179 15
180 3478
181 20
182 958
183 20
184 6055
185 23
186 2180
187 15
188 1416
189 30
190 1284
191 22
192 1341
193 21
194 2413
195 18
196 4984
197 13
198 830
199 22
200 1834
201 19
202 2238
203 9
204 3050
205 22
206 616
207 17
208 2892
209 22
210 711
211 30
212 2631
213 19
214 3341
215 21
216 987
217 26
218 823
219 9
220 3588
221 20
222 692
223 7
224 2925
225 31
226 1075
227 16
228 2909
229 18
230 673
231 20
232 2215
233 14
234 1584
235 21
236 1292
237 29
238 1647
239 25
240 1014
241 30
242 1648
243 19
244 4465
245 10
246 787
247 11
248 480
249 25
250 842
251 15
252 1219
253 23
254 1508
255 8
256 3525
257 16
258 490
259 12
260 1678
261 14
262 822
263 16
264 1729
265 28
266 604
267 11
268 2572
269 7
270 1242
271 15
272 725
273 18
274 1983
275 13
276 1662
277 19
278 491
279 12
280 1586
281 14
282 563
283 10
284 2363
285 10
286 656
287 14
288 725
289 28
290 871
291 9
292 2606
293 12
294 961
295 9
296 478
297 13
298 1252
299 10
300 736
301 19
302 466
303 13
304 2254
305 12
306 486
307 14
308 1145
309 13
310 955
311 13
312 1235
313 13
314 931
315 14
316 1768
317 11
318 330
319 10
320 539
321 23
322 570
323 12
324 1789
325 13
326 884
327 5
328 1422
329 14
330 317
331 11
332 509
333 13
334 1062
335 12
336 577
337 27
338 378
339 10
340 2313
341 9
342 391
343 13
344 894
345 17
346 664
347 9
348 453
349 6
350 363
351 15
352 1115
353 13
354 1054
355 8
356 1108
357 12
358 354
359 7
360 363
361 16
362 344
363 11
364 1734
365 12
366 265
367 10
368 969
369 16
370 316
371 12
372 757
373 7
374 563
375 15
376 857
377 9
378 469
379 9
380 385
381 12
382 921
383 15
384 764
385 14
386 246
387 6
388 1108
389 14
390 230
391 8
392 266
393 11
394 641
395 8
396 719
397 9
398 243
399 4
400 1108
401 7
402 229
403 7
404 903
405 7
406 257
407 12
408 244
409 3
410 541
411 6
412 744
413 8
414 419
415 8
416 388
417 19
418 470
419 14
420 612
421 6
422 342
423 3
424 1179
425 3
426 116
427 14
428 207
429 6
430 255
431 4
432 288
433 12
434 343
435 6
436 1015
437 3
438 538
439 10
440 194
441 6
442 188
443 15
444 524
445 7
446 214
447 7
448 574
449 6
450 214
451 5
452 635
453 9
454 464
455 5
456 205
457 9
458 163
459 2
460 558
461 4
462 171
463 14
464 444
465 11
466 543
467 5
468 388
469 6
470 141
471 4
472 647
473 3
474 210
475 4
476 193
477 7
478 195
479 7
480 443
481 10
482 198
483 3
484 816
485 6
486 128
487 9
488 215
489 9
490 328
491 7
492 158
493 11
494 335
495 8
496 435
497 6
498 174
499 1
500 373
501 5
502 140
503 7
504 330
505 9
506 149
507 5
508 642
509 3
510 179
511 3
512 159
513 8
514 204
515 7
516 306
517 4
518 110
519 5
520 326
521 6
522 305
523 6
524 294
525 7
526 268
527 5
528 149
529 4
530 133
531 2
532 513
533 10
534 116
535 5
536 258
537 4
538 113
539 4
540 138
541 6
542 116
544 485
545 4
546 93
547 9
548 299
549 3
550 256
551 6
552 92
553 3
554 175
555 6
556 253
557 7
558 95
559 2
560 128
561 4
562 206
563 2
564 465
565 3
566 69
567 3
568 157
569 7
570 97
571 8
572 118
573 5
574 130
575 4
576 301
577 6
578 177
579 2
580 397
581 3
582 80
583 1
584 128
585 5
586 52
587 2
588 72
589 1
590 84
591 6
592 323
593 11
594 77
595 5
596 205
597 1
598 244
599 4
600 69
601 3
602 89
603 5
604 254
605 6
606 147
607 3
608 83
609 3
610 77
611 3
612 194
613 1
614 98
615 3
616 243
617 3
618 50
619 8
620 188
621 4
622 67
623 4
624 123
625 2
626 50
627 1
628 239
629 2
630 51
631 4
632 65
633 5
634 188
636 81
637 3
638 46
639 3
640 103
641 1
642 136
643 3
644 188
645 3
646 58
648 122
649 4
650 47
651 2
652 155
653 4
654 71
655 1
656 71
657 3
658 50
659 2
660 177
661 5
662 66
663 2
664 183
665 3
666 50
667 2
668 53
669 2
670 115
672 66
673 2
674 47
675 1
676 197
677 2
678 46
679 3
680 95
681 3
682 46
683 3
684 107
685 1
686 86
687 2
688 158
689 4
690 51
691 1
692 80
694 56
695 4
696 40
698 43
699 3
700 95
701 2
702 51
703 2
704 133
705 1
706 100
707 2
708 121
709 2
710 15
711 3
712 35
713 2
714 20
715 3
716 37
717 2
718 78
720 55
721 1
722 42
723 2
724 218
725 3
726 23
727 2
728 26
729 1
730 64
731 2
732 65
734 24
735 2
736 53
737 1
738 32
739 1
740 60
742 81
743 1
744 77
745 1
746 47
747 1
748 62
749 1
750 19
751 1
752 86
753 3
754 40
756 55
757 2
758 38
759 1
760 101
761 1
762 22
764 67
765 2
766 35
767 1
768 38
769 1
770 22
771 1
772 82
773 1
774 73
776 29
777 1
778 55
780 23
781 1
782 16
784 84
785 3
786 28
788 59
789 1
790 33
791 3
792 24
794 13
795 1
796 110
797 2
798 15
800 22
801 3
802 29
803 1
804 87
806 21
808 29
810 48
812 28
813 1
814 58
815 1
816 48
817 1
818 31
819 1
820 66
822 17
823 2
824 58
826 10
827 2
828 25
829 1
830 29
831 1
832 63
833 1
834 26
835 3
836 52
837 1
838 18
840 27
841 2
842 12
843 1
844 83
845 1
846 7
847 1
848 10
850 26
852 25
853 1
854 15
856 27
858 32
859 1
860 15
862 43
864 32
865 1
866 6
868 39
870 11
872 25
873 1
874 10
875 1
876 20
877 2
878 19
879 1
880 30
882 11
884 53
886 25
887 1
888 28
890 6
892 36
894 10
896 13
898 14
900 31
902 14
903 2
904 43
906 25
908 9
910 11
911 1
912 16
913 1
914 24
916 27
918 6
920 15
922 27
923 1
924 23
926 13
928 42
929 1
930 3
932 27
934 17
936 8
937 1
938 11
940 33
942 4
943 1
944 18
946 15
948 13
950 18
952 12
954 11
956 21
958 10
960 13
962 5
964 32
966 13
968 8
970 8
971 1
972 23
973 2
974 12
975 1
976 22
978 7
979 1
980 14
982 8
984 22
985 1
986 6
988 17
989 1
990 6
992 13
994 19
996 11
998 4
1000 9
1002 2
1004 14
1006 5
1008 3
1010 9
1012 29
1014 6
1016 22
1017 1
1018 8
1019 1
1020 7
1022 6
1023 1
1024 10
1026 2
1028 8
1030 11
1031 2
1032 8
1034 9
1036 13
1038 12
1040 12
1042 3
1044 12
1046 3
1048 11
1050 2
1051 1
1052 2
1054 11
1056 6
1058 8
1059 1
1060 23
1062 6
1063 1
1064 8
1066 3
1068 6
1070 8
1071 1
1072 5
1074 3
1076 5
1078 3
1080 11
1081 1
1082 7
1084 18
1086 4
1087 1
1088 3
1090 3
1092 7
1094 3
1096 12
1098 6
1099 1
1100 2
1102 6
1104 14
1106 3
1108 6
1110 5
1112 2
1114 8
1116 3
1118 3
1120 7
1122 10
1124 6
1126 8
1128 1
1130 4
1132 3
1134 2
1136 5
1138 5
1140 8
1142 3
1144 7
1146 3
1148 11
1150 1
1152 5
1154 1
1156 5
1158 1
1160 5
1162 3
1164 6
1165 1
1166 1
1168 4
1169 1
1170 3
1171 1
1172 2
1174 5
1176 3
1177 1
1180 8
1182 2
1184 4
1186 2
1188 3
1190 2
1192 5
1194 6
1196 1
1198 2
1200 2
1204 10
1206 2
1208 9
1210 1
1214 6
1216 3
1218 4
1220 9
1221 2
1222 1
1224 5
1226 4
1228 8
1230 1
1232 1
1234 3
1236 5
1240 3
1242 1
1244 3
1245 1
1246 4
1248 6
1250 2
1252 7
1256 3
1258 2
1260 2
1262 3
1264 4
1265 1
1266 1
1270 1
1271 1
1272 2
1274 3
1276 3
1278 1
1280 3
1284 1
1286 1
1290 1
1292 3
1294 1
1296 7
1300 2
1302 4
1304 3
1306 2
1308 2
1312 1
1314 1
1316 3
1318 2
1320 1
1324 8
1326 1
1330 1
1331 1
1336 2
1338 1
1340 3
1341 1
1344 1
1346 2
1347 1
1348 3
1352 1
1354 2
1356 1
1358 1
1360 3
1362 1
1364 4
1366 1
1370 1
1372 3
1380 2
1384 2
1388 2
1390 2
1392 2
1394 1
1396 1
1398 1
1400 2
1402 1
1404 1
1406 1
1410 1
1412 5
1418 1
1420 1
1424 1
1432 2
1434 2
1442 3
1444 5
1448 1
1454 1
1456 1
1460 3
1462 4
1468 1
1474 1
1476 1
1478 2
1480 1
1486 2
1488 1
1492 1
1496 1
1500 3
1503 1
1506 1
1512 2
1516 1
1522 1
1524 2
1534 4
1536 1
1538 1
1540 2
1544 2
1548 1
1556 1
1560 1
1562 1
1564 2
1566 1
1568 1
1570 1
1572 1
1576 1
1590 1
1594 1
1604 1
1608 1
1614 1
1622 1
1624 2
1628 1
1629 1
1636 1
1642 1
1654 2
1660 1
1664 1
1670 1
1684 4
1698 1
1732 3
1742 1
1752 1
1760 1
1764 1
1772 2
1798 1
1808 1
1820 1
1852 1
1856 1
1874 1
1902 1
1908 1
1952 1
2004 1
2018 1
2020 1
2028 1
2174 1
2233 1
2244 1
2280 1
2290 1
2352 1
2604 1
4190 1
...@@ -91,18 +91,19 @@ def check_and_read_gif(img_path): ...@@ -91,18 +91,19 @@ def check_and_read_gif(img_path):
def load_vqa_bio_label_maps(label_map_path): def load_vqa_bio_label_maps(label_map_path):
with open(label_map_path, "r", encoding='utf-8') as fin: with open(label_map_path, "r", encoding='utf-8') as fin:
lines = fin.readlines() lines = fin.readlines()
lines = [line.strip() for line in lines] old_lines = [line.strip() for line in lines]
if "O" not in lines: lines = ["O"]
lines.insert(0, "O") for line in old_lines:
labels = [] # "O" has already been in lines
for line in lines: if line.upper() in ["OTHER", "OTHERS", "IGNORE"]:
if line == "O": continue
labels.append("O") lines.append(line)
else: labels = ["O"]
labels.append("B-" + line) for line in lines[1:]:
labels.append("I-" + line) labels.append("B-" + line)
label2id_map = {label: idx for idx, label in enumerate(labels)} labels.append("I-" + line)
id2label_map = {idx: label for idx, label in enumerate(labels)} label2id_map = {label.upper(): idx for idx, label in enumerate(labels)}
id2label_map = {idx: label.upper() for idx, label in enumerate(labels)}
return label2id_map, id2label_map return label2id_map, id2label_map
......
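The rewritten label-map hunk above folds OTHER/OTHERS/IGNORE classes into the single `O` tag and upper-cases every key. A minimal standalone sketch of that behavior (not the repository function itself, and with a made-up class list):

```python
# Standalone sketch of the revised BIO label-map construction (sample classes are made up).
def build_bio_label_maps(raw_labels):
    lines = ["O"]
    for line in raw_labels:
        # classes meaning "other/ignore" are folded into the single "O" tag
        if line.upper() in ["OTHER", "OTHERS", "IGNORE"]:
            continue
        lines.append(line)
    labels = ["O"]
    for line in lines[1:]:
        labels.append("B-" + line)
        labels.append("I-" + line)
    label2id_map = {label.upper(): idx for idx, label in enumerate(labels)}
    id2label_map = {idx: label.upper() for idx, label in enumerate(labels)}
    return label2id_map, id2label_map

label2id, id2label = build_bio_label_maps(["OTHER", "QUESTION", "ANSWER", "HEADER"])
print(label2id)  # {'O': 0, 'B-QUESTION': 1, 'I-QUESTION': 2, ..., 'I-HEADER': 6}
```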
...@@ -11,6 +11,7 @@ ...@@ -11,6 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import cv2
import os import os
import numpy as np import numpy as np
from PIL import Image, ImageDraw, ImageFont from PIL import Image, ImageDraw, ImageFont
...@@ -19,7 +20,7 @@ from PIL import Image, ImageDraw, ImageFont ...@@ -19,7 +20,7 @@ from PIL import Image, ImageDraw, ImageFont
def draw_ser_results(image, def draw_ser_results(image,
ocr_results, ocr_results,
font_path="doc/fonts/simfang.ttf", font_path="doc/fonts/simfang.ttf",
font_size=18): font_size=14):
np.random.seed(2021) np.random.seed(2021)
color = (np.random.permutation(range(255)), color = (np.random.permutation(range(255)),
np.random.permutation(range(255)), np.random.permutation(range(255)),
...@@ -40,9 +41,15 @@ def draw_ser_results(image, ...@@ -40,9 +41,15 @@ def draw_ser_results(image,
if ocr_info["pred_id"] not in color_map: if ocr_info["pred_id"] not in color_map:
continue continue
color = color_map[ocr_info["pred_id"]] color = color_map[ocr_info["pred_id"]]
text = "{}: {}".format(ocr_info["pred"], ocr_info["text"]) text = "{}: {}".format(ocr_info["pred"], ocr_info["transcription"])
draw_box_txt(ocr_info["bbox"], text, draw, font, font_size, color) if "bbox" in ocr_info:
# draw with ocr engine
bbox = ocr_info["bbox"]
else:
# draw with ocr groundtruth
bbox = trans_poly_to_bbox(ocr_info["points"])
draw_box_txt(bbox, text, draw, font, font_size, color)
img_new = Image.blend(image, img_new, 0.5) img_new = Image.blend(image, img_new, 0.5)
return np.array(img_new) return np.array(img_new)
...@@ -62,6 +69,14 @@ def draw_box_txt(bbox, text, draw, font, font_size, color): ...@@ -62,6 +69,14 @@ def draw_box_txt(bbox, text, draw, font, font_size, color):
draw.text((bbox[0][0] + 1, start_y), text, fill=(255, 255, 255), font=font) draw.text((bbox[0][0] + 1, start_y), text, fill=(255, 255, 255), font=font)
def trans_poly_to_bbox(poly):
x1 = np.min([p[0] for p in poly])
x2 = np.max([p[0] for p in poly])
y1 = np.min([p[1] for p in poly])
y2 = np.max([p[1] for p in poly])
return [x1, y1, x2, y2]
def draw_re_results(image, def draw_re_results(image,
result, result,
font_path="doc/fonts/simfang.ttf", font_path="doc/fonts/simfang.ttf",
...@@ -80,10 +95,10 @@ def draw_re_results(image, ...@@ -80,10 +95,10 @@ def draw_re_results(image,
color_line = (0, 255, 0) color_line = (0, 255, 0)
for ocr_info_head, ocr_info_tail in result: for ocr_info_head, ocr_info_tail in result:
draw_box_txt(ocr_info_head["bbox"], ocr_info_head["text"], draw, font, draw_box_txt(ocr_info_head["bbox"], ocr_info_head["transcription"],
font_size, color_head) draw, font, font_size, color_head)
draw_box_txt(ocr_info_tail["bbox"], ocr_info_tail["text"], draw, font, draw_box_txt(ocr_info_tail["bbox"], ocr_info_tail["transcription"],
font_size, color_tail) draw, font, font_size, color_tail)
center_head = ( center_head = (
(ocr_info_head['bbox'][0] + ocr_info_head['bbox'][2]) // 2, (ocr_info_head['bbox'][0] + ocr_info_head['bbox'][2]) // 2,
...@@ -96,3 +111,16 @@ def draw_re_results(image, ...@@ -96,3 +111,16 @@ def draw_re_results(image,
img_new = Image.blend(image, img_new, 0.5) img_new = Image.blend(image, img_new, 0.5)
return np.array(img_new) return np.array(img_new)
def draw_rectangle(img_path, boxes, use_xywh=False):
img = cv2.imread(img_path)
img_show = img.copy()
for box in boxes.astype(int):
if use_xywh:
x, y, w, h = box
x1, y1, x2, y2 = x - w // 2, y - h // 2, x + w // 2, y + h // 2
else:
x1, y1, x2, y2 = box
cv2.rectangle(img_show, (x1, y1), (x2, y2), (255, 0, 0), 2)
return img_show
\ No newline at end of file
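The two helpers added to `ppocr/utils/visual.py` convert a 4-point text polygon into an axis-aligned box and draw boxes onto an image. A small usage sketch, with hypothetical coordinates and an in-memory stand-in for the image file:

```python
# Usage sketch for trans_poly_to_bbox and draw_rectangle (values are hypothetical).
import cv2
import numpy as np

poly = [[10, 12], [118, 10], [120, 40], [9, 42]]       # 4-point text polygon
x1, y1 = min(p[0] for p in poly), min(p[1] for p in poly)
x2, y2 = max(p[0] for p in poly), max(p[1] for p in poly)
bbox = [x1, y1, x2, y2]                                 # what trans_poly_to_bbox returns
print(bbox)                                             # [9, 10, 120, 42]

# draw_rectangle reads the image from disk; here an in-memory canvas stands in for it.
# With use_xywh=True (TableMaster output) each row would instead be cx, cy, w, h.
boxes = np.array([bbox], dtype=float)
canvas = np.full((200, 200, 3), 255, dtype=np.uint8)
for bx1, by1, bx2, by2 in boxes.astype(int).tolist():
    cv2.rectangle(canvas, (bx1, by1), (bx2, by2), (255, 0, 0), 2)
```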
...@@ -16,7 +16,7 @@ SDMGR是一个关键信息提取算法,将每个检测到的文本区域分类 ...@@ -16,7 +16,7 @@ SDMGR是一个关键信息提取算法,将每个检测到的文本区域分类
训练和测试的数据采用wildreceipt数据集,通过如下指令下载数据集: 训练和测试的数据采用wildreceipt数据集,通过如下指令下载数据集:
``` ```
wget https://paddleocr.bj.bcebos.com/dygraph_v2.1/kie/wildreceipt.tar && tar xf wildreceipt.tar wget https://paddleocr.bj.bcebos.com/ppstructure/dataset/wildreceipt.tar && tar xf wildreceipt.tar
``` ```
执行预测: 执行预测:
......
...@@ -15,7 +15,7 @@ This section provides a tutorial example on how to quickly use, train, and evalu ...@@ -15,7 +15,7 @@ This section provides a tutorial example on how to quickly use, train, and evalu
[Wildreceipt dataset](https://paperswithcode.com/dataset/wildreceipt) is used for this tutorial. It contains 1765 photos, with 25 classes, and 50000 text boxes, which can be downloaded by wget: [Wildreceipt dataset](https://paperswithcode.com/dataset/wildreceipt) is used for this tutorial. It contains 1765 photos, with 25 classes, and 50000 text boxes, which can be downloaded by wget:
```shell ```shell
wget https://paddleocr.bj.bcebos.com/dygraph_v2.1/kie/wildreceipt.tar && tar xf wildreceipt.tar wget https://paddleocr.bj.bcebos.com/ppstructure/dataset/wildreceipt.tar && tar xf wildreceipt.tar
``` ```
Download the pretrained model and predict the result: Download the pretrained model and predict the result:
......
# PP-Structure 系列模型列表 # PP-Structure 系列模型列表
- [1. 版面分析模型](#1) - [1. 版面分析模型](#1-版面分析模型)
- [2. OCR和表格识别模型](#2) - [2. OCR和表格识别模型](#2-ocr和表格识别模型)
- [2.1 OCR](#21) - [2.1 OCR](#21-ocr)
- [2.2 表格识别模型](#22) - [2.2 表格识别模型](#22-表格识别模型)
- [3. VQA模型](#3) - [3. VQA模型](#3-vqa模型)
- [4. KIE模型](#4) - [4. KIE模型](#4-kie模型)
<a name="1"></a> <a name="1"></a>
...@@ -35,18 +35,18 @@ ...@@ -35,18 +35,18 @@
|模型名称|模型简介|推理模型大小|下载地址| |模型名称|模型简介|推理模型大小|下载地址|
| --- | --- | --- | --- | | --- | --- | --- | --- |
|en_ppocr_mobile_v2.0_table_structure|PubLayNet数据集训练的英文表格场景的表格结构预测|18.6M|[推理模型](https://paddleocr.bj.bcebos.com/dygraph_v2.0/table/en_ppocr_mobile_v2.0_table_structure_infer.tar) / [训练模型](https://paddleocr.bj.bcebos.com/dygraph_v2.1/table/en_ppocr_mobile_v2.0_table_structure_train.tar) | |en_ppocr_mobile_v2.0_table_structure|PubTabNet数据集训练的英文表格场景的表格结构预测|18.6M|[推理模型](https://paddleocr.bj.bcebos.com/dygraph_v2.0/table/en_ppocr_mobile_v2.0_table_structure_infer.tar) / [训练模型](https://paddleocr.bj.bcebos.com/dygraph_v2.1/table/en_ppocr_mobile_v2.0_table_structure_train.tar) |
<a name="3"></a> <a name="3"></a>
## 3. VQA模型 ## 3. VQA模型
|模型名称|模型简介|推理模型大小|下载地址| |模型名称|模型简介|推理模型大小|下载地址|
| --- | --- | --- | --- | | --- | --- | --- | --- |
|ser_LayoutXLM_xfun_zh|基于LayoutXLM在xfun中文数据集上训练的SER模型|1.4G|[推理模型 coming soon]() / [训练模型](https://paddleocr.bj.bcebos.com/pplayout/re_LayoutXLM_xfun_zh.tar) | |ser_LayoutXLM_xfun_zh|基于LayoutXLM在xfun中文数据集上训练的SER模型|1.4G|[推理模型](https://paddleocr.bj.bcebos.com/pplayout/ser_LayoutXLM_xfun_zh_infer.tar) / [训练模型](https://paddleocr.bj.bcebos.com/pplayout/ser_LayoutXLM_xfun_zh.tar) |
|re_LayoutXLM_xfun_zh|基于LayoutXLM在xfun中文数据集上训练的RE模型|1.4G|[推理模型 coming soon]() / [训练模型](https://paddleocr.bj.bcebos.com/pplayout/ser_LayoutXLM_xfun_zh.tar) | |re_LayoutXLM_xfun_zh|基于LayoutXLM在xfun中文数据集上训练的RE模型|1.4G|[推理模型 coming soon]() / [训练模型](https://paddleocr.bj.bcebos.com/pplayout/re_LayoutXLM_xfun_zh.tar) |
|ser_LayoutLMv2_xfun_zh|基于LayoutLMv2在xfun中文数据集上训练的SER模型|778M|[推理模型 coming soon]() / [训练模型](https://paddleocr.bj.bcebos.com/pplayout/ser_LayoutLMv2_xfun_zh.tar) | |ser_LayoutLMv2_xfun_zh|基于LayoutLMv2在xfun中文数据集上训练的SER模型|778M|[推理模型](https://paddleocr.bj.bcebos.com/pplayout/ser_LayoutLMv2_xfun_zh_infer.tar) / [训练模型](https://paddleocr.bj.bcebos.com/pplayout/ser_LayoutLMv2_xfun_zh.tar) |
|re_LayoutLMv2_xfun_zh|基于LayoutLMv2在xfun中文数据集上训练的RE模型|765M|[推理模型 coming soon]() / [训练模型](https://paddleocr.bj.bcebos.com/pplayout/re_LayoutLMv2_xfun_zh.tar) | |re_LayoutLMv2_xfun_zh|基于LayoutLMv2在xfun中文数据集上训练的RE模型|765M|[推理模型 coming soon]() / [训练模型](https://paddleocr.bj.bcebos.com/pplayout/re_LayoutLMv2_xfun_zh.tar) |
|ser_LayoutLM_xfun_zh|基于LayoutLM在xfun中文数据集上训练的SER模型|430M|[推理模型 coming soon]() / [训练模型](https://paddleocr.bj.bcebos.com/pplayout/ser_LayoutLM_xfun_zh.tar) | |ser_LayoutLM_xfun_zh|基于LayoutLM在xfun中文数据集上训练的SER模型|430M|[推理模型](https://paddleocr.bj.bcebos.com/pplayout/ser_LayoutLM_xfun_zh_infer.tar) / [训练模型](https://paddleocr.bj.bcebos.com/pplayout/ser_LayoutLM_xfun_zh.tar) |
<a name="4"></a> <a name="4"></a>
## 4. KIE模型 ## 4. KIE模型
......
# PP-Structure Model list # PP-Structure Model list
- [1. Layout Analysis](#1) - [1. Layout Analysis](#1-layout-analysis)
- [2. OCR and Table Recognition](#2) - [2. OCR and Table Recognition](#2-ocr-and-table-recognition)
- [2.1 OCR](#21) - [2.1 OCR](#21-ocr)
- [2.2 Table Recognition](#22) - [2.2 Table Recognition](#22-table-recognition)
- [3. VQA](#3) - [3. VQA](#3-vqa)
- [4. KIE](#4) - [4. KIE](#4-kie)
<a name="1"></a> <a name="1"></a>
...@@ -42,11 +42,11 @@ If you need to use other OCR models, you can download the model in [PP-OCR model ...@@ -42,11 +42,11 @@ If you need to use other OCR models, you can download the model in [PP-OCR model
|model| description |inference model size|download| |model| description |inference model size|download|
| --- |----------------------------------------------------------------| --- | --- | | --- |----------------------------------------------------------------| --- | --- |
|ser_LayoutXLM_xfun_zh| SER model trained on xfun Chinese dataset based on LayoutXLM |1.4G|[inference model coming soon]() / [trained model](https://paddleocr.bj.bcebos.com/pplayout/re_LayoutXLM_xfun_zh.tar) | |ser_LayoutXLM_xfun_zh| SER model trained on xfun Chinese dataset based on LayoutXLM |1.4G|[inference model](https://paddleocr.bj.bcebos.com/pplayout/ser_LayoutXLM_xfun_zh_infer.tar) / [trained model](https://paddleocr.bj.bcebos.com/pplayout/ser_LayoutXLM_xfun_zh.tar) |
|re_LayoutXLM_xfun_zh| Re model trained on xfun Chinese dataset based on LayoutXLM |1.4G|[inference model coming soon]() / [trained model](https://paddleocr.bj.bcebos.com/pplayout/ser_LayoutXLM_xfun_zh.tar) | |re_LayoutXLM_xfun_zh| Re model trained on xfun Chinese dataset based on LayoutXLM |1.4G|[inference model coming soon]() / [trained model](https://paddleocr.bj.bcebos.com/pplayout/re_LayoutXLM_xfun_zh.tar) |
|ser_LayoutLMv2_xfun_zh| SER model trained on xfun Chinese dataset based on LayoutXLMv2 |778M|[inference model coming soon]() / [trained model](https://paddleocr.bj.bcebos.com/pplayout/ser_LayoutLMv2_xfun_zh.tar) | |ser_LayoutLMv2_xfun_zh| SER model trained on xfun Chinese dataset based on LayoutXLMv2 |778M|[inference model](https://paddleocr.bj.bcebos.com/pplayout/ser_LayoutLMv2_xfun_zh_infer.tar) / [trained model](https://paddleocr.bj.bcebos.com/pplayout/ser_LayoutLMv2_xfun_zh.tar) |
|re_LayoutLMv2_xfun_zh| Re model trained on xfun Chinese dataset based on LayoutXLMv2 |765M|[inference model coming soon]() / [trained model](https://paddleocr.bj.bcebos.com/pplayout/re_LayoutLMv2_xfun_zh.tar) | |re_LayoutLMv2_xfun_zh| Re model trained on xfun Chinese dataset based on LayoutXLMv2 |765M|[inference model coming soon]() / [trained model](https://paddleocr.bj.bcebos.com/pplayout/re_LayoutLMv2_xfun_zh.tar) |
|ser_LayoutLM_xfun_zh| SER model trained on xfun Chinese dataset based on LayoutLM |430M|[inference model coming soon]() / [trained model](https://paddleocr.bj.bcebos.com/pplayout/ser_LayoutLM_xfun_zh.tar) | |ser_LayoutLM_xfun_zh| SER model trained on xfun Chinese dataset based on LayoutLM |430M|[inference model](https://paddleocr.bj.bcebos.com/pplayout/ser_LayoutLM_xfun_zh_infer.tar) / [trained model](https://paddleocr.bj.bcebos.com/pplayout/ser_LayoutLM_xfun_zh.tar) |
<a name="4"></a> <a name="4"></a>
## 4. KIE ## 4. KIE
......
...@@ -34,7 +34,7 @@ from ppocr.utils.logging import get_logger ...@@ -34,7 +34,7 @@ from ppocr.utils.logging import get_logger
from tools.infer.predict_system import TextSystem from tools.infer.predict_system import TextSystem
from ppstructure.table.predict_table import TableSystem, to_excel from ppstructure.table.predict_table import TableSystem, to_excel
from ppstructure.utility import parse_args, draw_structure_result from ppstructure.utility import parse_args, draw_structure_result
from ppstructure.recovery.docx import convert_info_docx from ppstructure.recovery.recovery_to_doc import convert_info_docx
logger = get_logger() logger = get_logger()
......
...@@ -44,6 +44,12 @@ python3 -m pip install "paddlepaddle>=2.2" -i https://mirror.baidu.com/pypi/simp ...@@ -44,6 +44,12 @@ python3 -m pip install "paddlepaddle>=2.2" -i https://mirror.baidu.com/pypi/simp
更多需求,请参照[安装文档](https://www.paddlepaddle.org.cn/install/quick)中的说明进行操作。 更多需求,请参照[安装文档](https://www.paddlepaddle.org.cn/install/quick)中的说明进行操作。
* **(2)安装依赖**
```bash
python3 -m pip install -r ppstructure/recovery/requirements.txt
```
<a name="2.2"></a> <a name="2.2"></a>
### 2.2 安装PaddleOCR ### 2.2 安装PaddleOCR
......
...@@ -23,43 +23,63 @@ os.environ["FLAGS_allocator_strategy"] = 'auto_growth' ...@@ -23,43 +23,63 @@ os.environ["FLAGS_allocator_strategy"] = 'auto_growth'
import cv2 import cv2
import numpy as np import numpy as np
import time import time
import json
import tools.infer.utility as utility import tools.infer.utility as utility
from ppocr.data import create_operators, transform from ppocr.data import create_operators, transform
from ppocr.postprocess import build_post_process from ppocr.postprocess import build_post_process
from ppocr.utils.logging import get_logger from ppocr.utils.logging import get_logger
from ppocr.utils.utility import get_image_file_list, check_and_read_gif from ppocr.utils.utility import get_image_file_list, check_and_read_gif
from ppocr.utils.visual import draw_rectangle
from ppstructure.utility import parse_args from ppstructure.utility import parse_args
logger = get_logger() logger = get_logger()
def build_pre_process_list(args):
resize_op = {'ResizeTableImage': {'max_len': args.table_max_len, }}
pad_op = {
'PaddingTableImage': {
'size': [args.table_max_len, args.table_max_len]
}
}
normalize_op = {
'NormalizeImage': {
'std': [0.229, 0.224, 0.225] if
args.table_algorithm not in ['TableMaster'] else [0.5, 0.5, 0.5],
'mean': [0.485, 0.456, 0.406] if
args.table_algorithm not in ['TableMaster'] else [0.5, 0.5, 0.5],
'scale': '1./255.',
'order': 'hwc'
}
}
to_chw_op = {'ToCHWImage': None}
keep_keys_op = {'KeepKeys': {'keep_keys': ['image', 'shape']}}
if args.table_algorithm not in ['TableMaster']:
pre_process_list = [
resize_op, normalize_op, pad_op, to_chw_op, keep_keys_op
]
else:
pre_process_list = [
resize_op, pad_op, normalize_op, to_chw_op, keep_keys_op
]
return pre_process_list
class TableStructurer(object): class TableStructurer(object):
def __init__(self, args): def __init__(self, args):
pre_process_list = [{ pre_process_list = build_pre_process_list(args)
'ResizeTableImage': { if args.table_algorithm not in ['TableMaster']:
'max_len': args.table_max_len postprocess_params = {
} 'name': 'TableLabelDecode',
}, { "character_dict_path": args.table_char_dict_path,
'NormalizeImage': {
'std': [0.229, 0.224, 0.225],
'mean': [0.485, 0.456, 0.406],
'scale': '1./255.',
'order': 'hwc'
} }
}, { else:
'PaddingTableImage': None postprocess_params = {
}, { 'name': 'TableMasterLabelDecode',
'ToCHWImage': None "character_dict_path": args.table_char_dict_path,
}, { 'box_shape': 'pad'
'KeepKeys': {
'keep_keys': ['image']
} }
}]
postprocess_params = {
'name': 'TableLabelDecode',
"character_dict_path": args.table_char_dict_path,
}
self.preprocess_op = create_operators(pre_process_list) self.preprocess_op = create_operators(pre_process_list)
self.postprocess_op = build_post_process(postprocess_params) self.postprocess_op = build_post_process(postprocess_params)
...@@ -88,27 +108,17 @@ class TableStructurer(object): ...@@ -88,27 +108,17 @@ class TableStructurer(object):
preds['structure_probs'] = outputs[1] preds['structure_probs'] = outputs[1]
preds['loc_preds'] = outputs[0] preds['loc_preds'] = outputs[0]
post_result = self.postprocess_op(preds) shape_list = np.expand_dims(data[-1], axis=0)
post_result = self.postprocess_op(preds, [shape_list])
structure_str_list = post_result['structure_str_list']
res_loc = post_result['res_loc'] structure_str_list = post_result['structure_batch_list'][0]
imgh, imgw = ori_im.shape[0:2] bbox_list = post_result['bbox_batch_list'][0]
res_loc_final = [] structure_str_list = structure_str_list[0]
for rno in range(len(res_loc[0])):
x0, y0, x1, y1 = res_loc[0][rno]
left = max(int(imgw * x0), 0)
top = max(int(imgh * y0), 0)
right = min(int(imgw * x1), imgw - 1)
bottom = min(int(imgh * y1), imgh - 1)
res_loc_final.append([left, top, right, bottom])
structure_str_list = structure_str_list[0][:-1]
structure_str_list = [ structure_str_list = [
'<html>', '<body>', '<table>' '<html>', '<body>', '<table>'
] + structure_str_list + ['</table>', '</body>', '</html>'] ] + structure_str_list + ['</table>', '</body>', '</html>']
elapse = time.time() - starttime elapse = time.time() - starttime
return (structure_str_list, res_loc_final), elapse return (structure_str_list, bbox_list), elapse
def main(args): def main(args):
...@@ -116,21 +126,35 @@ def main(args): ...@@ -116,21 +126,35 @@ def main(args):
table_structurer = TableStructurer(args) table_structurer = TableStructurer(args)
count = 0 count = 0
total_time = 0 total_time = 0
for image_file in image_file_list: use_xywh = args.table_algorithm in ['TableMaster']
img, flag = check_and_read_gif(image_file) os.makedirs(args.output, exist_ok=True)
if not flag: with open(
img = cv2.imread(image_file) os.path.join(args.output, 'infer.txt'), mode='w',
if img is None: encoding='utf-8') as f_w:
logger.info("error in loading image:{}".format(image_file)) for image_file in image_file_list:
continue img, flag = check_and_read_gif(image_file)
structure_res, elapse = table_structurer(img) if not flag:
img = cv2.imread(image_file)
logger.info("result: {}".format(structure_res)) if img is None:
logger.info("error in loading image:{}".format(image_file))
if count > 0: continue
total_time += elapse structure_res, elapse = table_structurer(img)
count += 1 structure_str_list, bbox_list = structure_res
logger.info("Predict time of {}: {}".format(image_file, elapse)) bbox_list_str = json.dumps(bbox_list.tolist())
logger.info("result: {}, {}".format(structure_str_list,
bbox_list_str))
f_w.write("result: {}, {}\n".format(structure_str_list,
bbox_list_str))
img = draw_rectangle(image_file, bbox_list, use_xywh)
img_save_path = os.path.join(args.output,
os.path.basename(image_file))
cv2.imwrite(img_save_path, img)
logger.info("save vis result to {}".format(img_save_path))
if count > 0:
total_time += elapse
count += 1
logger.info("Predict time of {}: {}".format(image_file, elapse))
if __name__ == "__main__": if __name__ == "__main__":
......
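The refactored `predict_structure.py` above selects a different preprocessing order and normalization depending on whether the algorithm is TableMaster. A compact sketch of that selection logic, using a hypothetical `args` stand-in:

```python
# Sketch of the per-algorithm preprocessing choice in build_pre_process_list;
# args is a hypothetical stand-in for the parsed command-line arguments.
from types import SimpleNamespace

def sketch_pre_process_list(args):
    resize_op = {'ResizeTableImage': {'max_len': args.table_max_len}}
    pad_op = {'PaddingTableImage': {'size': [args.table_max_len, args.table_max_len]}}
    is_master = args.table_algorithm in ['TableMaster']
    normalize_op = {'NormalizeImage': {
        'std': [0.5, 0.5, 0.5] if is_master else [0.229, 0.224, 0.225],
        'mean': [0.5, 0.5, 0.5] if is_master else [0.485, 0.456, 0.406],
        'scale': '1./255.',
        'order': 'hwc'}}
    tail = [{'ToCHWImage': None}, {'KeepKeys': {'keep_keys': ['image', 'shape']}}]
    # TableMaster pads before normalizing; the attention-based model normalizes first.
    head = [resize_op, pad_op, normalize_op] if is_master else [resize_op, normalize_op, pad_op]
    return head + tail

print(sketch_pre_process_list(SimpleNamespace(table_algorithm='TableMaster', table_max_len=488)))
```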
...@@ -129,11 +129,25 @@ class TableSystem(object): ...@@ -129,11 +129,25 @@ class TableSystem(object):
def rebuild_table(self, structure_res, dt_boxes, rec_res): def rebuild_table(self, structure_res, dt_boxes, rec_res):
pred_structures, pred_bboxes = structure_res pred_structures, pred_bboxes = structure_res
dt_boxes, rec_res = self.filter_ocr_result(pred_bboxes,dt_boxes, rec_res)
matched_index = self.match_result(dt_boxes, pred_bboxes) matched_index = self.match_result(dt_boxes, pred_bboxes)
pred_html, pred = self.get_pred_html(pred_structures, matched_index, pred_html, pred = self.get_pred_html(pred_structures, matched_index,
rec_res) rec_res)
return pred_html, pred return pred_html, pred
def filter_ocr_result(self, pred_bboxes,dt_boxes, rec_res):
y1 = pred_bboxes[:,1::2].min()
new_dt_boxes = []
new_rec_res = []
for box,rec in zip(dt_boxes, rec_res):
if np.max(box[1::2]) < y1:
continue
new_dt_boxes.append(box)
new_rec_res.append(rec)
return new_dt_boxes, new_rec_res
def match_result(self, dt_boxes, pred_bboxes): def match_result(self, dt_boxes, pred_bboxes):
matched = {} matched = {}
for i, gt_box in enumerate(dt_boxes): for i, gt_box in enumerate(dt_boxes):
......
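The new `filter_ocr_result` drops OCR boxes that lie entirely above the topmost predicted table cell (for example, a caption line above the table). A toy illustration, treating each box as a flat `(x1, y1, x2, y2)` array for simplicity:

```python
# Toy illustration: OCR boxes entirely above the topmost predicted cell are dropped.
import numpy as np

pred_bboxes = np.array([[40, 100, 200, 130],     # predicted table cells (x1, y1, x2, y2)
                        [40, 135, 200, 165]])
dt_boxes = [np.array([50, 20, 180, 45]),         # caption above the table -> filtered out
            np.array([55, 102, 150, 128])]       # text inside the first row -> kept
rec_res = [("Table 1. Results", 0.98), ("mAP 83.2", 0.95)]

y1 = pred_bboxes[:, 1::2].min()                  # topmost y of any predicted cell
kept = [(box, rec) for box, rec in zip(dt_boxes, rec_res) if np.max(box[1::2]) >= y1]
print([rec for _, rec in kept])                  # [('mAP 83.2', 0.95)]
```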
...@@ -25,6 +25,7 @@ def init_args(): ...@@ -25,6 +25,7 @@ def init_args():
parser.add_argument("--output", type=str, default='./output') parser.add_argument("--output", type=str, default='./output')
# params for table structure # params for table structure
parser.add_argument("--table_max_len", type=int, default=488) parser.add_argument("--table_max_len", type=int, default=488)
parser.add_argument("--table_algorithm", type=str, default='TableAttn')
parser.add_argument("--table_model_dir", type=str) parser.add_argument("--table_model_dir", type=str)
parser.add_argument( parser.add_argument(
"--table_char_dict_path", "--table_char_dict_path",
...@@ -40,6 +41,13 @@ def init_args(): ...@@ -40,6 +41,13 @@ def init_args():
type=ast.literal_eval, type=ast.literal_eval,
default=None, default=None,
help='label map according to ppstructure/layout/README_ch.md') help='label map according to ppstructure/layout/README_ch.md')
# params for vqa
parser.add_argument("--vqa_algorithm", type=str, default='LayoutXLM')
parser.add_argument("--ser_model_dir", type=str)
parser.add_argument(
"--ser_dict_path",
type=str,
default="../train_data/XFUND/class_list_xfun.txt")
# params for inference # params for inference
parser.add_argument( parser.add_argument(
"--mode", "--mode",
...@@ -65,7 +73,7 @@ def init_args(): ...@@ -65,7 +73,7 @@ def init_args():
"--recovery", "--recovery",
type=bool, type=bool,
default=False, default=False,
help='Whether to enable layout of recovery') help='Whether to enable layout of recovery')
return parser return parser
......
English | [简体中文](README_ch.md) English | [简体中文](README_ch.md)
- [Document Visual Question Answering (Doc-VQA)](#Document-Visual-Question-Answering) - [1 Introduction](#1-introduction)
- [1. Introduction](#1-Introduction) - [2. Performance](#2-performance)
- [2. Performance](#2-performance) - [3. Effect demo](#3-effect-demo)
- [3. Effect demo](#3-Effect-demo) - [3.1 SER](#31-ser)
- [3.1 SER](#31-ser) - [3.2 RE](#32-re)
- [3.2 RE](#32-re) - [4. Install](#4-install)
- [4. Install](#4-Install) - [4.1 Install dependencies](#41-install-dependencies)
- [4.1 Installation dependencies](#41-Install-dependencies) - [5.3 RE](#53-re)
- [4.2 Install PaddleOCR](#42-Install-PaddleOCR) - [6. Reference Links](#6-reference-links)
- [5. Usage](#5-Usage) - [License](#license)
- [5.1 Data and Model Preparation](#51-Data-and-Model-Preparation)
- [5.2 SER](#52-ser)
- [5.3 RE](#53-re)
- [6. Reference](#6-Reference-Links)
# Document Visual Question Answering # Document Visual Question Answering
...@@ -125,13 +121,13 @@ If you want to experience the prediction process directly, you can download the ...@@ -125,13 +121,13 @@ If you want to experience the prediction process directly, you can download the
* Download the processed dataset * Download the processed dataset
The download address of the processed XFUND Chinese dataset: [https://paddleocr.bj.bcebos.com/dataset/XFUND.tar](https://paddleocr.bj.bcebos.com/dataset/XFUND.tar). The download address of the processed XFUND Chinese dataset: [link](https://paddleocr.bj.bcebos.com/ppstructure/dataset/XFUND.tar).
Download and unzip the dataset, and place the dataset in the current directory after unzipping. Download and unzip the dataset, and place the dataset in the current directory after unzipping.
```shell ```shell
wget https://paddleocr.bj.bcebos.com/dataset/XFUND.tar wget https://paddleocr.bj.bcebos.com/ppstructure/dataset/XFUND.tar
```` ````
* Convert the dataset * Convert the dataset
...@@ -187,17 +183,17 @@ CUDA_VISIBLE_DEVICES=0 python3 tools/eval.py -c configs/vqa/ser/layoutxlm.yml -o ...@@ -187,17 +183,17 @@ CUDA_VISIBLE_DEVICES=0 python3 tools/eval.py -c configs/vqa/ser/layoutxlm.yml -o
```` ````
Finally, `precision`, `recall`, `hmean` and other indicators will be printed Finally, `precision`, `recall`, `hmean` and other indicators will be printed
* Use `OCR engine + SER` tandem prediction * `OCR + SER` tandem prediction based on training engine
Use the following command to complete the series prediction of `OCR engine + SER`, taking the pretrained SER model as an example: Use the following command to complete the series prediction of `OCR engine + SER`, taking the SER model based on LayoutXLM as an example:
```shell ```shell
CUDA_VISIBLE_DEVICES=0 python3 tools/infer_vqa_token_ser.py -c configs/vqa/ser/layoutxlm.yml -o Architecture.Backbone.checkpoints=pretrain/ser_LayoutXLM_xfun_zh/ Global.infer_img=doc/vqa/input/zh_val_42.jpg python3.7 tools/export_model.py -c configs/vqa/ser/layoutxlm.yml -o Architecture.Backbone.checkpoints=pretrain/ser_LayoutXLM_xfun_zh/ Global.save_inference_dir=output/ser/infer
```` ````
Finally, the prediction result visualization image and the prediction result text file will be saved in the directory configured by the `config.Global.save_res_path` field. The prediction result text file is named `infer_results.txt`. Finally, the prediction result visualization image and the prediction result text file will be saved in the directory configured by the `config.Global.save_res_path` field. The prediction result text file is named `infer_results.txt`.
* End-to-end evaluation of `OCR engine + SER` prediction system * End-to-end evaluation of `OCR + SER` prediction system
First use the `tools/infer_vqa_token_ser.py` script to complete the prediction of the dataset, then use the following command to evaluate. First use the `tools/infer_vqa_token_ser.py` script to complete the prediction of the dataset, then use the following command to evaluate.
...@@ -205,6 +201,24 @@ First use the `tools/infer_vqa_token_ser.py` script to complete the prediction o ...@@ -205,6 +201,24 @@ First use the `tools/infer_vqa_token_ser.py` script to complete the prediction o
export CUDA_VISIBLE_DEVICES=0 export CUDA_VISIBLE_DEVICES=0
python3 tools/eval_with_label_end2end.py --gt_json_path XFUND/zh_val/xfun_normalize_val.json --pred_json_path output_res/infer_results.txt python3 tools/eval_with_label_end2end.py --gt_json_path XFUND/zh_val/xfun_normalize_val.json --pred_json_path output_res/infer_results.txt
```` ````
* export model
Use the following command to complete the model export of the SER model, taking the SER model based on LayoutXLM as an example:
```shell
python3.7 tools/export_model.py -c configs/vqa/ser/layoutxlm.yml -o Architecture.Backbone.checkpoints=pretrain/ser_LayoutXLM_xfun_zh/ Global.save_inference_dir=output/ser/infer
```
The converted model will be stored in the directory specified by the `Global.save_inference_dir` field.
* `OCR + SER` tandem prediction based on prediction engine
Use the following command to complete the tandem prediction of `OCR + SER` based on the prediction engine, taking the SER model based on LayoutXLM as an example:
```shell
cd ppstructure
CUDA_VISIBLE_DEVICES=0 python3.7 vqa/predict_vqa_token_ser.py --vqa_algorithm=LayoutXLM --ser_model_dir=../output/ser/infer --ser_dict_path=../train_data/XFUND/class_list_xfun.txt --vis_font_path=../doc/fonts/simfang.ttf --image_dir=docs/vqa/input/zh_val_42.jpg --output=output
```
After the prediction is successful, the visualization images and results will be saved in the directory specified by the `output` field
<a name="53"></a> <a name="53"></a>
### 5.3 RE ### 5.3 RE
...@@ -247,11 +261,19 @@ Finally, `precision`, `recall`, `hmean` and other indicators will be printed ...@@ -247,11 +261,19 @@ Finally, `precision`, `recall`, `hmean` and other indicators will be printed
Use the following command to complete the series prediction of `OCR engine + SER + RE`, taking the pretrained SER and RE models as an example: Use the following command to complete the series prediction of `OCR engine + SER + RE`, taking the pretrained SER and RE models as an example:
```shell ```shell
export CUDA_VISIBLE_DEVICES=0 export CUDA_VISIBLE_DEVICES=0
python3 tools/infer_vqa_token_ser_re.py -c configs/vqa/re/layoutxlm.yml -o Architecture.Backbone.checkpoints=pretrain/re_LayoutXLM_xfun_zh/ Global.infer_img=doc/vqa/input/zh_val_21.jpg -c_ser configs/vqa/ser/layoutxlm.yml -o_ser Architecture.Backbone.checkpoints=pretrain/ser_LayoutXLM_xfun_zh/ python3 tools/infer_vqa_token_ser_re.py -c configs/vqa/re/layoutxlm.yml -o Architecture.Backbone.checkpoints=pretrain/re_LayoutXLM_xfun_zh/ Global.infer_img=ppstructure/docs/vqa/input/zh_val_21.jpg -c_ser configs/vqa/ser/layoutxlm.yml -o_ser Architecture.Backbone.checkpoints=pretrain/ser_LayoutXLM_xfun_zh/
```` ````
Finally, the prediction result visualization image and the prediction result text file will be saved in the directory configured by the `config.Global.save_res_path` field. The prediction result text file is named `infer_results.txt`. Finally, the prediction result visualization image and the prediction result text file will be saved in the directory configured by the `config.Global.save_res_path` field. The prediction result text file is named `infer_results.txt`.
* export model
coming soon
* `OCR + SER + RE` tandem prediction based on prediction engine
coming soon
## 6. Reference Links ## 6. Reference Links
- LayoutXLM: Multimodal Pre-training for Multilingual Visually-rich Document Understanding, https://arxiv.org/pdf/2104.08836.pdf - LayoutXLM: Multimodal Pre-training for Multilingual Visually-rich Document Understanding, https://arxiv.org/pdf/2104.08836.pdf
......
[English](README.md) | 简体中文 [English](README.md) | 简体中文
- [文档视觉问答(DOC-VQA)](#文档视觉问答doc-vqa) - [1. 简介](#1-简介)
- [1. 简介](#1-简介) - [2. 性能](#2-性能)
- [2. 性能](#2-性能) - [3. 效果演示](#3-效果演示)
- [3. 效果演示](#3-效果演示) - [3.1 SER](#31-ser)
- [3.1 SER](#31-ser) - [3.2 RE](#32-re)
- [3.2 RE](#32-re) - [4. 安装](#4-安装)
- [4. 安装](#4-安装) - [4.1 安装依赖](#41-安装依赖)
- [4.1 安装依赖](#41-安装依赖) - [4.2 安装PaddleOCR(包含 PP-OCR 和 VQA)](#42-安装paddleocr包含-pp-ocr-和-vqa)
- [4.2 安装PaddleOCR(包含 PP-OCR 和 VQA)](#42-安装paddleocr包含-pp-ocr-和-vqa) - [5. 使用](#5-使用)
- [5. 使用](#5-使用) - [5.1 数据和预训练模型准备](#51-数据和预训练模型准备)
- [5.1 数据和预训练模型准备](#51-数据和预训练模型准备) - [5.2 SER](#52-ser)
- [5.2 SER](#52-ser) - [5.3 RE](#53-re)
- [5.3 RE](#53-re) - [6. 参考链接](#6-参考链接)
- [6. 参考链接](#6-参考链接) - [License](#license)
# 文档视觉问答(DOC-VQA) # 文档视觉问答(DOC-VQA)
...@@ -122,13 +122,13 @@ python3 -m pip install -r ppstructure/vqa/requirements.txt ...@@ -122,13 +122,13 @@ python3 -m pip install -r ppstructure/vqa/requirements.txt
* 下载处理好的数据集 * 下载处理好的数据集
处理好的XFUND中文数据集下载地址:[https://paddleocr.bj.bcebos.com/dataset/XFUND.tar](https://paddleocr.bj.bcebos.com/dataset/XFUND.tar) 处理好的XFUND中文数据集下载地址:[链接](https://paddleocr.bj.bcebos.com/ppstructure/dataset/XFUND.tar)
下载并解压该数据集,解压后将数据集放置在当前目录下。 下载并解压该数据集,解压后将数据集放置在当前目录下。
```shell ```shell
wget https://paddleocr.bj.bcebos.com/dataset/XFUND.tar wget https://paddleocr.bj.bcebos.com/ppstructure/dataset/XFUND.tar
``` ```
* 转换数据集 * 转换数据集
...@@ -183,16 +183,16 @@ CUDA_VISIBLE_DEVICES=0 python3 tools/eval.py -c configs/vqa/ser/layoutxlm.yml -o ...@@ -183,16 +183,16 @@ CUDA_VISIBLE_DEVICES=0 python3 tools/eval.py -c configs/vqa/ser/layoutxlm.yml -o
``` ```
最终会打印出`precision`, `recall`, `hmean`等指标 最终会打印出`precision`, `recall`, `hmean`等指标
* 使用`OCR引擎 + SER`串联预测 * 基于训练引擎的`OCR + SER`串联预测
使用如下命令即可完成`OCR引擎 + SER`的串联预测, 以SER预训练模型为例: 使用如下命令即可完成基于训练引擎的`OCR + SER`的串联预测, 以基于LayoutXLM的SER模型为例:
```shell ```shell
CUDA_VISIBLE_DEVICES=0 python3 tools/infer_vqa_token_ser.py -c configs/vqa/ser/layoutxlm.yml -o Architecture.Backbone.checkpoints=pretrain/ser_LayoutXLM_xfun_zh/ Global.infer_img=doc/vqa/input/zh_val_42.jpg CUDA_VISIBLE_DEVICES=0 python3 tools/infer_vqa_token_ser.py -c configs/vqa/ser/layoutxlm.yml -o Architecture.Backbone.checkpoints=pretrain/ser_LayoutXLM_xfun_zh/ Global.infer_img=doc/vqa/input/zh_val_42.jpg
``` ```
最终会在`config.Global.save_res_path`字段所配置的目录下保存预测结果可视化图像以及预测结果文本文件,预测结果文本文件名为`infer_results.txt` 最终会在`config.Global.save_res_path`字段所配置的目录下保存预测结果可视化图像以及预测结果文本文件,预测结果文本文件名为`infer_results.txt`
*`OCR引擎 + SER`预测系统进行端到端评估 *`OCR + SER`预测系统进行端到端评估
首先使用 `tools/infer_vqa_token_ser.py` 脚本完成数据集的预测,然后使用下面的命令进行评估。 首先使用 `tools/infer_vqa_token_ser.py` 脚本完成数据集的预测,然后使用下面的命令进行评估。
...@@ -200,6 +200,24 @@ CUDA_VISIBLE_DEVICES=0 python3 tools/infer_vqa_token_ser.py -c configs/vqa/ser/l ...@@ -200,6 +200,24 @@ CUDA_VISIBLE_DEVICES=0 python3 tools/infer_vqa_token_ser.py -c configs/vqa/ser/l
export CUDA_VISIBLE_DEVICES=0 export CUDA_VISIBLE_DEVICES=0
python3 tools/eval_with_label_end2end.py --gt_json_path XFUND/zh_val/xfun_normalize_val.json --pred_json_path output_res/infer_results.txt python3 tools/eval_with_label_end2end.py --gt_json_path XFUND/zh_val/xfun_normalize_val.json --pred_json_path output_res/infer_results.txt
``` ```
* 模型导出
使用如下命令即可完成SER模型的模型导出, 以基于LayoutXLM的SER模型为例:
```shell
python3.7 tools/export_model.py -c configs/vqa/ser/layoutxlm.yml -o Architecture.Backbone.checkpoints=pretrain/ser_LayoutXLM_xfun_zh/ Global.save_inference_dir=output/ser/infer
```
转换后的模型会存放在`Global.save_inference_dir`字段指定的目录下。
* 基于预测引擎的`OCR + SER`串联预测
使用如下命令即可完成基于预测引擎的`OCR + SER`的串联预测, 以基于LayoutXLM的SER模型为例:
```shell
cd ppstructure
CUDA_VISIBLE_DEVICES=0 python3.7 vqa/predict_vqa_token_ser.py --vqa_algorithm=LayoutXLM --ser_model_dir=../output/ser/infer --ser_dict_path=../train_data/XFUND/class_list_xfun.txt --vis_font_path=../doc/fonts/simfang.ttf --image_dir=docs/vqa/input/zh_val_42.jpg --output=output
```
预测成功后,可视化图片和结果会保存在`output`字段指定的目录下
### 5.3 RE ### 5.3 RE
...@@ -236,16 +254,24 @@ CUDA_VISIBLE_DEVICES=0 python3 tools/eval.py -c configs/vqa/re/layoutxlm.yml -o ...@@ -236,16 +254,24 @@ CUDA_VISIBLE_DEVICES=0 python3 tools/eval.py -c configs/vqa/re/layoutxlm.yml -o
``` ```
最终会打印出`precision`, `recall`, `hmean`等指标 最终会打印出`precision`, `recall`, `hmean`等指标
* 使用`OCR引擎 + SER + RE`串联预测 * 基于训练引擎的`OCR + SER + RE`串联预测
使用如下命令即可完成`OCR引擎 + SER + RE`的串联预测, 以预训练SER和RE模型为例: 使用如下命令即可完成基于训练引擎的`OCR + SER + RE`串联预测, 以基于LayoutXLM的SER和RE模型为例:
```shell ```shell
export CUDA_VISIBLE_DEVICES=0 export CUDA_VISIBLE_DEVICES=0
python3 tools/infer_vqa_token_ser_re.py -c configs/vqa/re/layoutxlm.yml -o Architecture.Backbone.checkpoints=pretrain/re_LayoutXLM_xfun_zh/ Global.infer_img=doc/vqa/input/zh_val_21.jpg -c_ser configs/vqa/ser/layoutxlm.yml -o_ser Architecture.Backbone.checkpoints=pretrain/ser_LayoutXLM_xfun_zh/ python3 tools/infer_vqa_token_ser_re.py -c configs/vqa/re/layoutxlm.yml -o Architecture.Backbone.checkpoints=pretrain/re_LayoutXLM_xfun_zh/ Global.infer_img=ppstructure/docs/vqa/input/zh_val_21.jpg -c_ser configs/vqa/ser/layoutxlm.yml -o_ser Architecture.Backbone.checkpoints=pretrain/ser_LayoutXLM_xfun_zh/
``` ```
最终会在`config.Global.save_res_path`字段所配置的目录下保存预测结果可视化图像以及预测结果文本文件,预测结果文本文件名为`infer_results.txt` 最终会在`config.Global.save_res_path`字段所配置的目录下保存预测结果可视化图像以及预测结果文本文件,预测结果文本文件名为`infer_results.txt`
* 模型导出
coming soon
* 基于预测引擎的`OCR + SER + RE`串联预测
coming soon
## 6. 参考链接 ## 6. 参考链接
- LayoutXLM: Multimodal Pre-training for Multilingual Visually-rich Document Understanding, https://arxiv.org/pdf/2104.08836.pdf - LayoutXLM: Multimodal Pre-training for Multilingual Visually-rich Document Understanding, https://arxiv.org/pdf/2104.08836.pdf
......
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import sys
__dir__ = os.path.dirname(os.path.abspath(__file__))
sys.path.append(__dir__)
sys.path.insert(0, os.path.abspath(os.path.join(__dir__, '../..')))
os.environ["FLAGS_allocator_strategy"] = 'auto_growth'
import cv2
import json
import numpy as np
import time
import tools.infer.utility as utility
from ppocr.data import create_operators, transform
from ppocr.postprocess import build_post_process
from ppocr.utils.logging import get_logger
from ppocr.utils.visual import draw_ser_results
from ppocr.utils.utility import get_image_file_list, check_and_read_gif
from ppstructure.utility import parse_args
from paddleocr import PaddleOCR
logger = get_logger()
class SerPredictor(object):
def __init__(self, args):
self.ocr_engine = PaddleOCR(use_angle_cls=False, show_log=False)
pre_process_list = [{
'VQATokenLabelEncode': {
'algorithm': args.vqa_algorithm,
'class_path': args.ser_dict_path,
'contains_re': False,
'ocr_engine': self.ocr_engine
}
}, {
'VQATokenPad': {
'max_seq_len': 512,
'return_attention_mask': True
}
}, {
'VQASerTokenChunk': {
'max_seq_len': 512,
'return_attention_mask': True
}
}, {
'Resize': {
'size': [224, 224]
}
}, {
'NormalizeImage': {
'std': [58.395, 57.12, 57.375],
'mean': [123.675, 116.28, 103.53],
'scale': '1',
'order': 'hwc'
}
}, {
'ToCHWImage': None
}, {
'KeepKeys': {
'keep_keys': [
'input_ids', 'bbox', 'attention_mask', 'token_type_ids',
'image', 'labels', 'segment_offset_id', 'ocr_info',
'entities'
]
}
}]
postprocess_params = {
'name': 'VQASerTokenLayoutLMPostProcess',
"class_path": args.ser_dict_path,
}
self.preprocess_op = create_operators(pre_process_list,
{'infer_mode': True})
self.postprocess_op = build_post_process(postprocess_params)
self.predictor, self.input_tensor, self.output_tensors, self.config = \
utility.create_predictor(args, 'ser', logger)
def __call__(self, img):
ori_im = img.copy()
data = {'image': img}
data = transform(data, self.preprocess_op)
img = data[0]
if img is None:
return None, 0
img = np.expand_dims(img, axis=0)
img = img.copy()
starttime = time.time()
for idx in range(len(self.input_tensor)):
expand_input = np.expand_dims(data[idx], axis=0)
self.input_tensor[idx].copy_from_cpu(expand_input)
self.predictor.run()
outputs = []
for output_tensor in self.output_tensors:
output = output_tensor.copy_to_cpu()
outputs.append(output)
preds = outputs[0]
post_result = self.postprocess_op(
preds, segment_offset_ids=[data[6]], ocr_infos=[data[7]])
elapse = time.time() - starttime
return post_result, elapse
def main(args):
image_file_list = get_image_file_list(args.image_dir)
ser_predictor = SerPredictor(args)
count = 0
total_time = 0
os.makedirs(args.output, exist_ok=True)
with open(
os.path.join(args.output, 'infer.txt'), mode='w',
encoding='utf-8') as f_w:
for image_file in image_file_list:
img, flag = check_and_read_gif(image_file)
if not flag:
img = cv2.imread(image_file)
img = img[:, :, ::-1]
if img is None:
logger.info("error in loading image:{}".format(image_file))
continue
ser_res, elapse = ser_predictor(img)
ser_res = ser_res[0]
res_str = '{}\t{}\n'.format(
image_file,
json.dumps(
{
"ocr_info": ser_res,
}, ensure_ascii=False))
f_w.write(res_str)
img_res = draw_ser_results(
image_file,
ser_res,
font_path=args.vis_font_path, )
img_save_path = os.path.join(args.output,
os.path.basename(image_file))
cv2.imwrite(img_save_path, img_res)
logger.info("save vis result to {}".format(img_save_path))
if count > 0:
total_time += elapse
count += 1
logger.info("Predict time of {}: {}".format(image_file, elapse))
if __name__ == "__main__":
main(parse_args())
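The `main()` loop above writes one line per image to `infer.txt` in the form `<image path>\t<json>`. A hedged sketch of consuming such a line; the field names (`pred`, `transcription`, `points`) are assumed from the `draw_ser_results` helper and the sample values are invented:

```python
# Parsing one line of the infer.txt written above; path, fields and values are illustrative.
import json

sample_line = ('docs/vqa/input/zh_val_42.jpg\t'
               '{"ocr_info": [{"transcription": "Name", "pred": "QUESTION", '
               '"points": [[10, 10], [80, 10], [80, 40], [10, 40]]}]}')

image_path, payload = sample_line.split('\t', 1)
for item in json.loads(payload)["ocr_info"]:
    print(image_path, item["pred"], item["transcription"])
```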
sentencepiece sentencepiece
yacs yacs
seqeval seqeval
paddlenlp>=2.2.1 paddlenlp>=2.2.1
\ No newline at end of file pypandoc
attrdict
python_docx
\ No newline at end of file
# copyright (c) 2022 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
import os
import sys
import cv2
import numpy as np
from copy import deepcopy
def trans_poly_to_bbox(poly):
x1 = np.min([p[0] for p in poly])
x2 = np.max([p[0] for p in poly])
y1 = np.min([p[1] for p in poly])
y2 = np.max([p[1] for p in poly])
return [x1, y1, x2, y2]
def get_outer_poly(bbox_list):
x1 = min([bbox[0] for bbox in bbox_list])
y1 = min([bbox[1] for bbox in bbox_list])
x2 = max([bbox[2] for bbox in bbox_list])
y2 = max([bbox[3] for bbox in bbox_list])
return [[x1, y1], [x2, y1], [x2, y2], [x1, y2]]
def load_funsd_label(image_dir, anno_dir):
imgs = os.listdir(image_dir)
annos = os.listdir(anno_dir)
imgs = [img.replace(".png", "") for img in imgs]
annos = [anno.replace(".json", "") for anno in annos]
fn_info_map = dict()
for anno_fn in annos:
res = []
with open(os.path.join(anno_dir, anno_fn + ".json"), "r") as fin:
infos = json.load(fin)
infos = infos["form"]
old_id2new_id_map = dict()
global_new_id = 0
for info in infos:
if info["text"] is None:
continue
words = info["words"]
if len(words) <= 0:
continue
word_idx = 1
curr_bboxes = [words[0]["box"]]
curr_texts = [words[0]["text"]]
while word_idx < len(words):
# switch to a new link
if words[word_idx]["box"][0] + 10 <= words[word_idx - 1][
"box"][2]:
if len("".join(curr_texts[0])) > 0:
res.append({
"transcription": " ".join(curr_texts),
"label": info["label"],
"points": get_outer_poly(curr_bboxes),
"linking": info["linking"],
"id": global_new_id,
})
if info["id"] not in old_id2new_id_map:
old_id2new_id_map[info["id"]] = []
old_id2new_id_map[info["id"]].append(global_new_id)
global_new_id += 1
curr_bboxes = [words[word_idx]["box"]]
curr_texts = [words[word_idx]["text"]]
else:
curr_bboxes.append(words[word_idx]["box"])
curr_texts.append(words[word_idx]["text"])
word_idx += 1
if len("".join(curr_texts[0])) > 0:
res.append({
"transcription": " ".join(curr_texts),
"label": info["label"],
"points": get_outer_poly(curr_bboxes),
"linking": info["linking"],
"id": global_new_id,
})
if info["id"] not in old_id2new_id_map:
old_id2new_id_map[info["id"]] = []
old_id2new_id_map[info["id"]].append(global_new_id)
global_new_id += 1
res = sorted(
res, key=lambda r: (r["points"][0][1], r["points"][0][0]))
for i in range(len(res) - 1):
for j in range(i, 0, -1):
if abs(res[j + 1]["points"][0][1] - res[j]["points"][0][1]) < 20 and \
(res[j + 1]["points"][0][0] < res[j]["points"][0][0]):
tmp = deepcopy(res[j])
res[j] = deepcopy(res[j + 1])
res[j + 1] = deepcopy(tmp)
else:
break
# re-generate unique ids
for idx, r in enumerate(res):
new_links = []
for link in r["linking"]:
# illegal links will be removed
if link[0] not in old_id2new_id_map or link[
1] not in old_id2new_id_map:
continue
for src in old_id2new_id_map[link[0]]:
for dst in old_id2new_id_map[link[1]]:
new_links.append([src, dst])
res[idx]["linking"] = deepcopy(new_links)
fn_info_map[anno_fn] = res
return fn_info_map
def main():
test_image_dir = "train_data/FUNSD/testing_data/images/"
test_anno_dir = "train_data/FUNSD/testing_data/annotations/"
test_output_dir = "train_data/FUNSD/test.json"
fn_info_map = load_funsd_label(test_image_dir, test_anno_dir)
with open(test_output_dir, "w") as fout:
for fn in fn_info_map:
fout.write(fn + ".png" + "\t" + json.dumps(
fn_info_map[fn], ensure_ascii=False) + "\n")
train_image_dir = "train_data/FUNSD/training_data/images/"
train_anno_dir = "train_data/FUNSD/training_data/annotations/"
train_output_dir = "train_data/FUNSD/train.json"
fn_info_map = load_funsd_label(train_image_dir, train_anno_dir)
with open(train_output_dir, "w") as fout:
for fn in fn_info_map:
fout.write(fn + ".png" + "\t" + json.dumps(
fn_info_map[fn], ensure_ascii=False) + "\n")
print("====ok====")
return
if __name__ == "__main__":
main()
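Each line written by `main()` above is `<image name>\t<json list of annotations>`. A quick sanity-check sketch, assuming the conversion has already produced `train_data/FUNSD/train.json`:

```python
# Sanity-check sketch for the generated label file (assumes the conversion has been run).
import json

with open("train_data/FUNSD/train.json", "r", encoding="utf-8") as fin:
    for line in fin:
        img_name, anno_str = line.rstrip("\n").split("\t", 1)
        annos = json.loads(anno_str)
        for anno in annos:
            # every record carries the fields emitted by load_funsd_label
            assert {"transcription", "label", "points", "linking", "id"} <= set(anno)
        print(img_name, len(annos), "annotated text regions")
        break
```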
...@@ -21,26 +21,22 @@ def transfer_xfun_data(json_path=None, output_file=None): ...@@ -21,26 +21,22 @@ def transfer_xfun_data(json_path=None, output_file=None):
json_info = json.loads(lines[0]) json_info = json.loads(lines[0])
documents = json_info["documents"] documents = json_info["documents"]
label_info = {}
with open(output_file, "w", encoding='utf-8') as fout: with open(output_file, "w", encoding='utf-8') as fout:
for idx, document in enumerate(documents): for idx, document in enumerate(documents):
label_info = []
img_info = document["img"] img_info = document["img"]
document = document["document"] document = document["document"]
image_path = img_info["fname"] image_path = img_info["fname"]
label_info["height"] = img_info["height"]
label_info["width"] = img_info["width"]
label_info["ocr_info"] = []
for doc in document: for doc in document:
label_info["ocr_info"].append({ x1, y1, x2, y2 = doc["box"]
"text": doc["text"], points = [[x1, y1], [x2, y1], [x2, y2], [x1, y2]]
label_info.append({
"transcription": doc["text"],
"label": doc["label"], "label": doc["label"],
"bbox": doc["box"], "points": points,
"id": doc["id"], "id": doc["id"],
"linking": doc["linking"], "linking": doc["linking"]
"words": doc["words"]
}) })
fout.write(image_path + "\t" + json.dumps( fout.write(image_path + "\t" + json.dumps(
......
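The XFUND converter above replaces the raw `doc["box"]` with a clockwise 4-point polygon. A tiny illustration of that conversion with made-up coordinates:

```python
# Box -> 4-point polygon conversion used for each XFUND record (coordinates invented).
x1, y1, x2, y2 = 38, 52, 206, 86                  # doc["box"] from the raw annotation
points = [[x1, y1], [x2, y1], [x2, y2], [x1, y2]]
print(points)                                     # [[38, 52], [206, 52], [206, 86], [38, 86]]
```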
...@@ -21,6 +21,18 @@ function func_parser_params(){ ...@@ -21,6 +21,18 @@ function func_parser_params(){
echo ${tmp} echo ${tmp}
} }
function set_dynamic_epoch(){
string=$1
num=$2
_str=${string:1:6}
IFS="C"
arr=(${_str})
M=${arr[0]}
P=${arr[1]}
ep=`expr $num \* $M \* $P`
echo $ep
}
function func_sed_params(){ function func_sed_params(){
filename=$1 filename=$1
line=$2 line=$2
...@@ -139,10 +151,11 @@ else ...@@ -139,10 +151,11 @@ else
device_num=${params_list[4]} device_num=${params_list[4]}
IFS=";" IFS=";"
if [ ${precision} = "null" ];then if [ ${precision} = "fp16" ];then
precision="fp32" precision="amp"
fi fi
epoch=$(set_dynamic_epoch $device_num $epoch)
fp_items_list=($precision) fp_items_list=($precision)
batch_size_list=($batch_size) batch_size_list=($batch_size)
device_num_list=($device_num) device_num_list=($device_num)
...@@ -150,10 +163,16 @@ fi ...@@ -150,10 +163,16 @@ fi
IFS="|" IFS="|"
for batch_size in ${batch_size_list[*]}; do for batch_size in ${batch_size_list[*]}; do
for precision in ${fp_items_list[*]}; do for train_precision in ${fp_items_list[*]}; do
for device_num in ${device_num_list[*]}; do for device_num in ${device_num_list[*]}; do
# sed batchsize and precision # sed batchsize and precision
func_sed_params "$FILENAME" "${line_precision}" "$precision" if [ ${train_precision} = "amp" ];then
precision="fp16"
else
precision="fp32"
fi
func_sed_params "$FILENAME" "${line_precision}" "$train_precision"
func_sed_params "$FILENAME" "${line_batchsize}" "$MODE=$batch_size" func_sed_params "$FILENAME" "${line_batchsize}" "$MODE=$batch_size"
func_sed_params "$FILENAME" "${line_epoch}" "$MODE=$epoch" func_sed_params "$FILENAME" "${line_epoch}" "$MODE=$epoch"
gpu_id=$(set_gpu_id $device_num) gpu_id=$(set_gpu_id $device_num)
......
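The `set_dynamic_epoch` helper added to `benchmark_train.sh` scales the configured epoch count by the machine and GPU counts parsed from a device string such as `N1C8`. A rough Python transcription of that shell logic (assuming device strings of that form):

```python
# Python transcription of set_dynamic_epoch: epoch * machines * GPUs per machine.
def set_dynamic_epoch(device_num: str, epoch: int) -> int:
    body = device_num[1:7]                 # mirrors ${string:1:6} in the shell helper
    machines, gpus = body.split("C")
    return epoch * int(machines) * int(gpus)

print(set_dynamic_epoch("N1C8", 2))        # 16
```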
...@@ -13,7 +13,7 @@ train_infer_img_dir:./train_data/icdar2015/text_localization/ch4_test_images/ ...@@ -13,7 +13,7 @@ train_infer_img_dir:./train_data/icdar2015/text_localization/ch4_test_images/
null:null null:null
## ##
trainer:norm_train trainer:norm_train
norm_train:tools/train.py -c configs/det/ch_PP-OCRv2/ch_PP-OCRv2_det_cml.yml -o norm_train:tools/train.py -c configs/det/ch_PP-OCRv2/ch_PP-OCRv2_det_cml.yml -o Global.print_batch_step=1 Train.loader.shuffle=false
pact_train:null pact_train:null
fpgm_train:null fpgm_train:null
distill_train:null distill_train:null
...@@ -51,3 +51,9 @@ null:null ...@@ -51,3 +51,9 @@ null:null
null:null null:null
===========================infer_benchmark_params========================== ===========================infer_benchmark_params==========================
random_infer_input:[{float32,[3,640,640]}];[{float32,[3,960,960]}] random_infer_input:[{float32,[3,640,640]}];[{float32,[3,960,960]}]
===========================train_benchmark_params==========================
batch_size:8
fp_items:fp32|fp16
epoch:2
--profiler_options:batch_range=[10,20];state=GPU;tracer_option=Default;profile_path=model.profile
flags:FLAGS_eager_delete_tensor_gb=0.0;FLAGS_fraction_of_gpu_memory_to_use=0.98;FLAGS_conv_workspace_size_limit=4096
...@@ -6,7 +6,7 @@ Global: ...@@ -6,7 +6,7 @@ Global:
print_batch_step: 10 print_batch_step: 10
save_model_dir: ./output/rec_pp-OCRv2_distillation save_model_dir: ./output/rec_pp-OCRv2_distillation
save_epoch_step: 3 save_epoch_step: 3
eval_batch_step: [0, 2000] eval_batch_step: [0, 200000]
cal_metric_during_train: true cal_metric_during_train: true
pretrained_model: pretrained_model:
checkpoints: checkpoints:
......
...@@ -13,7 +13,7 @@ train_infer_img_dir:./inference/rec_inference ...@@ -13,7 +13,7 @@ train_infer_img_dir:./inference/rec_inference
null:null null:null
## ##
trainer:norm_train trainer:norm_train
norm_train:tools/train.py -c test_tipc/configs/ch_PP-OCRv2_rec/ch_PP-OCRv2_rec_distillation.yml -o norm_train:tools/train.py -c test_tipc/configs/ch_PP-OCRv2_rec/ch_PP-OCRv2_rec_distillation.yml -o Global.print_batch_step=4 Train.loader.shuffle=false
pact_train:null pact_train:null
fpgm_train:null fpgm_train:null
distill_train:null distill_train:null
...@@ -51,3 +51,9 @@ null:null ...@@ -51,3 +51,9 @@ null:null
null:null null:null
===========================infer_benchmark_params========================== ===========================infer_benchmark_params==========================
random_infer_input:[{float32,[3,32,320]}] random_infer_input:[{float32,[3,32,320]}]
===========================train_benchmark_params==========================
batch_size:64
fp_items:fp32|fp16
epoch:1
--profiler_options:batch_range=[10,20];state=GPU;tracer_option=Default;profile_path=model.profile
flags:FLAGS_eager_delete_tensor_gb=0.0;FLAGS_fraction_of_gpu_memory_to_use=0.98;FLAGS_conv_workspace_size_limit=4096
...@@ -13,7 +13,7 @@ train_infer_img_dir:./train_data/icdar2015/text_localization/ch4_test_images/ ...@@ -13,7 +13,7 @@ train_infer_img_dir:./train_data/icdar2015/text_localization/ch4_test_images/
null:null null:null
## ##
trainer:norm_train trainer:norm_train
norm_train:tools/train.py -c configs/det/ch_PP-OCRv3/ch_PP-OCRv3_det_cml.yml -o norm_train:tools/train.py -c configs/det/ch_PP-OCRv3/ch_PP-OCRv3_det_cml.yml -o Global.print_batch_step=1 Train.loader.shuffle=false Global.eval_batch_step=[4000,400]
pact_train:null pact_train:null
fpgm_train:null fpgm_train:null
distill_train:null distill_train:null
...@@ -51,3 +51,9 @@ null:null ...@@ -51,3 +51,9 @@ null:null
null:null null:null
===========================infer_benchmark_params========================== ===========================infer_benchmark_params==========================
random_infer_input:[{float32,[3,640,640]}];[{float32,[3,960,960]}] random_infer_input:[{float32,[3,640,640]}];[{float32,[3,960,960]}]
===========================train_benchmark_params==========================
batch_size:8
fp_items:fp32|fp16
epoch:2
--profiler_options:batch_range=[10,20];state=GPU;tracer_option=Default;profile_path=model.profile
flags:FLAGS_eager_delete_tensor_gb=0.0;FLAGS_fraction_of_gpu_memory_to_use=0.98;FLAGS_conv_workspace_size_limit=4096
...@@ -153,7 +153,7 @@ Train: ...@@ -153,7 +153,7 @@ Train:
data_dir: ./train_data/ic15_data/ data_dir: ./train_data/ic15_data/
ext_op_transform_idx: 1 ext_op_transform_idx: 1
label_file_list: label_file_list:
- ./train_data/ic15_data/rec_gt_train_lite.txt - ./train_data/ic15_data/rec_gt_train.txt
transforms: transforms:
- DecodeImage: - DecodeImage:
img_mode: BGR img_mode: BGR
...@@ -183,7 +183,7 @@ Eval: ...@@ -183,7 +183,7 @@ Eval:
name: SimpleDataSet name: SimpleDataSet
data_dir: ./train_data/ic15_data data_dir: ./train_data/ic15_data
label_file_list: label_file_list:
- ./train_data/ic15_data/rec_gt_test_lite.txt - ./train_data/ic15_data/rec_gt_test.txt
transforms: transforms:
- DecodeImage: - DecodeImage:
img_mode: BGR img_mode: BGR
......
...@@ -13,7 +13,7 @@ train_infer_img_dir:./inference/rec_inference ...@@ -13,7 +13,7 @@ train_infer_img_dir:./inference/rec_inference
null:null null:null
## ##
trainer:norm_train trainer:norm_train
norm_train:tools/train.py -c test_tipc/configs/ch_PP-OCRv3_rec/ch_PP-OCRv3_rec_distillation.yml -o norm_train:tools/train.py -c test_tipc/configs/ch_PP-OCRv3_rec/ch_PP-OCRv3_rec_distillation.yml -o Global.print_batch_step=1 Train.loader.shuffle=false
pact_train:null pact_train:null
fpgm_train:null fpgm_train:null
distill_train:null distill_train:null
...@@ -51,3 +51,10 @@ null:null ...@@ -51,3 +51,10 @@ null:null
null:null null:null
===========================infer_benchmark_params========================== ===========================infer_benchmark_params==========================
random_infer_input:[{float32,[3,48,320]}] random_infer_input:[{float32,[3,48,320]}]
===========================train_benchmark_params==========================
batch_size:64
fp_items:fp32|fp16
epoch:1
--profiler_options:batch_range=[10,20];state=GPU;tracer_option=Default;profile_path=model.profile
flags:FLAGS_eager_delete_tensor_gb=0.0;FLAGS_fraction_of_gpu_memory_to_use=0.98;FLAGS_conv_workspace_size_limit=4096
===========================cpp_infer_params=========================== ===========================cpp_infer_params===========================
model_name:ch_ppocr_mobile_v2.0 model_name:ch_ppocr_mobile_v2_0
use_opencv:True use_opencv:True
infer_model:./inference/ch_ppocr_mobile_v2.0_det_infer/ infer_model:./inference/ch_ppocr_mobile_v2.0_det_infer/
infer_quant:False infer_quant:False
......
===========================ch_ppocr_mobile_v2.0=========================== ===========================ch_ppocr_mobile_v2.0===========================
model_name:ch_ppocr_mobile_v2.0 model_name:ch_ppocr_mobile_v2_0
python:python3.7 python:python3.7
infer_model:./inference/ch_ppocr_mobile_v2.0_det_infer/ infer_model:./inference/ch_ppocr_mobile_v2.0_det_infer/
infer_export:null infer_export:null
......
===========================paddle2onnx_params=========================== ===========================paddle2onnx_params===========================
model_name:ch_ppocr_mobile_v2.0 model_name:ch_ppocr_mobile_v2_0
python:python3.7 python:python3.7
2onnx: paddle2onnx 2onnx: paddle2onnx
--det_model_dir:./inference/ch_ppocr_mobile_v2.0_det_infer/ --det_model_dir:./inference/ch_ppocr_mobile_v2.0_det_infer/
......
===========================serving_params=========================== ===========================serving_params===========================
model_name:ch_ppocr_mobile_v2.0 model_name:ch_ppocr_mobile_v2_0
python:python3.7 python:python3.7
trans_model:-m paddle_serving_client.convert trans_model:-m paddle_serving_client.convert
--det_dirname:./inference/ch_ppocr_mobile_v2.0_det_infer/ --det_dirname:./inference/ch_ppocr_mobile_v2.0_det_infer/
......
===========================serving_params=========================== ===========================serving_params===========================
model_name:ch_ppocr_mobile_v2.0 model_name:ch_ppocr_mobile_v2_0
python:python3.7 python:python3.7
trans_model:-m paddle_serving_client.convert trans_model:-m paddle_serving_client.convert
--det_dirname:./inference/ch_ppocr_mobile_v2.0_det_infer/ --det_dirname:./inference/ch_ppocr_mobile_v2.0_det_infer/
......
===========================cpp_infer_params=========================== ===========================cpp_infer_params===========================
model_name:ch_ppocr_mobile_v2.0_det model_name:ch_ppocr_mobile_v2_0_det
use_opencv:True use_opencv:True
infer_model:./inference/ch_ppocr_mobile_v2.0_det_infer/ infer_model:./inference/ch_ppocr_mobile_v2.0_det_infer/
infer_quant:False infer_quant:False
......
===========================infer_params=========================== ===========================infer_params===========================
model_name:ch_ppocr_mobile_v2.0_det model_name:ch_ppocr_mobile_v2_0_det
python:python python:python
infer_model:./inference/ch_ppocr_mobile_v2.0_det_infer infer_model:./inference/ch_ppocr_mobile_v2.0_det_infer
infer_export:null infer_export:null
......
===========================paddle2onnx_params=========================== ===========================paddle2onnx_params===========================
model_name:ch_ppocr_mobile_v2.0_det model_name:ch_ppocr_mobile_v2_0_det
python:python3.7 python:python3.7
2onnx: paddle2onnx 2onnx: paddle2onnx
--det_model_dir:./inference/ch_ppocr_mobile_v2.0_det_infer/ --det_model_dir:./inference/ch_ppocr_mobile_v2.0_det_infer/
......
===========================serving_params=========================== ===========================serving_params===========================
model_name:ch_ppocr_mobile_v2.0_det model_name:ch_ppocr_mobile_v2_0_det
python:python3.7 python:python3.7
trans_model:-m paddle_serving_client.convert trans_model:-m paddle_serving_client.convert
--det_dirname:./inference/ch_ppocr_mobile_v2.0_det_infer/ --det_dirname:./inference/ch_ppocr_mobile_v2.0_det_infer/
......
===========================train_params=========================== ===========================train_params===========================
model_name:ch_ppocr_mobile_v2.0_det model_name:ch_ppocr_mobile_v2_0_det
python:python3.7 python:python3.7
gpu_list:0|0,1 gpu_list:0|0,1
Global.use_gpu:True|True Global.use_gpu:True|True
...@@ -13,7 +13,7 @@ train_infer_img_dir:./train_data/icdar2015/text_localization/ch4_test_images/ ...@@ -13,7 +13,7 @@ train_infer_img_dir:./train_data/icdar2015/text_localization/ch4_test_images/
null:null null:null
## ##
trainer:norm_train trainer:norm_train
norm_train:tools/train.py -c configs/det/ch_ppocr_v2.0/ch_det_mv3_db_v2.0.yml -o Global.pretrained_model=./pretrain_models/MobileNetV3_large_x0_5_pretrained norm_train:tools/train.py -c configs/det/ch_ppocr_v2.0/ch_det_mv3_db_v2.0.yml -o Global.pretrained_model=./pretrain_models/MobileNetV3_large_x0_5_pretrained Global.print_batch_step=1 Train.loader.shuffle=false
pact_train:null pact_train:null
fpgm_train:null fpgm_train:null
distill_train:null distill_train:null
...@@ -50,4 +50,10 @@ null:null ...@@ -50,4 +50,10 @@ null:null
--benchmark:True --benchmark:True
null:null null:null
===========================infer_benchmark_params========================== ===========================infer_benchmark_params==========================
random_infer_input:[{float32,[3,640,640]}];[{float32,[3,960,960]}] random_infer_input:[{float32,[3,640,640]}];[{float32,[3,960,960]}]
\ No newline at end of file ===========================train_benchmark_params==========================
batch_size:8
fp_items:fp32|fp16
epoch:2
--profiler_options:batch_range=[10,20];state=GPU;tracer_option=Default;profile_path=model.profile
flags:FLAGS_eager_delete_tensor_gb=0.0;FLAGS_fraction_of_gpu_memory_to_use=0.98;FLAGS_conv_workspace_size_limit=4096
===========================train_params=========================== ===========================train_params===========================
model_name:ch_ppocr_mobile_v2.0_det model_name:ch_ppocr_mobile_v2_0_det
python:python3.7 python:python3.7
gpu_list:192.168.0.1,192.168.0.2;0,1 gpu_list:192.168.0.1,192.168.0.2;0,1
Global.use_gpu:True Global.use_gpu:True
......
===========================train_params=========================== ===========================train_params===========================
model_name:ch_ppocr_mobile_v2.0_det model_name:ch_ppocr_mobile_v2_0_det
python:python3.7 python:python3.7
gpu_list:0|0,1 gpu_list:0|0,1
Global.use_gpu:True|True Global.use_gpu:True|True
......
===========================train_params=========================== ===========================train_params===========================
model_name:ch_ppocr_mobile_v2.0_det_PACT model_name:ch_ppocr_mobile_v2_0_det_PACT
python:python3.7 python:python3.7
gpu_list:0|0,1 gpu_list:0|0,1
Global.use_gpu:True|True Global.use_gpu:True|True
......
===========================kl_quant_params=========================== ===========================kl_quant_params===========================
model_name:ch_ppocr_mobile_v2.0_det_KL model_name:ch_ppocr_mobile_v2_0_det_KL
python:python3.7 python:python3.7
Global.pretrained_model:null Global.pretrained_model:null
Global.save_inference_dir:null Global.save_inference_dir:null
......
===========================train_params=========================== ===========================train_params===========================
model_name:ch_ppocr_mobile_v2.0_det_FPGM model_name:ch_ppocr_mobile_v2_0_det_FPGM
python:python3.7 python:python3.7
gpu_list:0|0,1 gpu_list:0|0,1
Global.use_gpu:True|True Global.use_gpu:True|True
......
===========================train_params=========================== ===========================train_params===========================
model_name:ch_ppocr_mobile_v2.0_det_FPGM model_name:ch_ppocr_mobile_v2_0_det_FPGM
python:python3.7 python:python3.7
gpu_list:0|0,1 gpu_list:0|0,1
Global.use_gpu:True|True Global.use_gpu:True|True
......
===========================cpp_infer_params=========================== ===========================cpp_infer_params===========================
model_name:ch_ppocr_mobile_v2.0_det_KL model_name:ch_ppocr_mobile_v2_0_det_KL
use_opencv:True use_opencv:True
infer_model:./inference/ch_ppocr_mobile_v2.0_det_klquant_infer infer_model:./inference/ch_ppocr_mobile_v2.0_det_klquant_infer
infer_quant:False infer_quant:False
......
===========================serving_params=========================== ===========================serving_params===========================
model_name:ch_ppocr_mobile_v2.0_rec_KL model_name:ch_ppocr_mobile_v2_0_det_KL
python:python3.7 python:python3.7
trans_model:-m paddle_serving_client.convert trans_model:-m paddle_serving_client.convert
--det_dirname:./inference/ch_ppocr_mobile_v2.0_det_klquant_infer/ --det_dirname:./inference/ch_ppocr_mobile_v2.0_det_klquant_infer/
......
===========================serving_params=========================== ===========================serving_params===========================
model_name:ch_ppocr_mobile_v2.0_det_KL model_name:ch_ppocr_mobile_v2_0_det_KL
python:python3.7 python:python3.7
trans_model:-m paddle_serving_client.convert trans_model:-m paddle_serving_client.convert
--det_dirname:./inference/ch_ppocr_mobile_v2.0_det_klquant_infer/ --det_dirname:./inference/ch_ppocr_mobile_v2.0_det_klquant_infer/
......
===========================cpp_infer_params=========================== ===========================cpp_infer_params===========================
model_name:ch_ppocr_mobile_v2.0_det_PACT model_name:ch_ppocr_mobile_v2_0_det_PACT
use_opencv:True use_opencv:True
infer_model:./inference/ch_ppocr_mobile_v2.0_det_pact_infer infer_model:./inference/ch_ppocr_mobile_v2.0_det_pact_infer
infer_quant:False infer_quant:False
......
===========================serving_params=========================== ===========================serving_params===========================
model_name:ch_ppocr_mobile_v2.0_rec_PACT model_name:ch_ppocr_mobile_v2_0_det_PACT
python:python3.7 python:python3.7
trans_model:-m paddle_serving_client.convert trans_model:-m paddle_serving_client.convert
--det_dirname:./inference/ch_ppocr_mobile_v2.0_det_pact_infer/ --det_dirname:./inference/ch_ppocr_mobile_v2.0_det_pact_infer/
......
===========================serving_params=========================== ===========================serving_params===========================
model_name:ch_ppocr_mobile_v2.0_det_PACT model_name:ch_ppocr_mobile_v2_0_det_PACT
python:python3.7 python:python3.7
trans_model:-m paddle_serving_client.convert trans_model:-m paddle_serving_client.convert
--det_dirname:./inference/ch_ppocr_mobile_v2.0_det_pact_infer/ --det_dirname:./inference/ch_ppocr_mobile_v2.0_det_pact_infer/
......
===========================cpp_infer_params=========================== ===========================cpp_infer_params===========================
model_name:ch_ppocr_mobile_v2.0_rec model_name:ch_ppocr_mobile_v2_0_rec
use_opencv:True use_opencv:True
infer_model:./inference/ch_ppocr_mobile_v2.0_rec_infer/ infer_model:./inference/ch_ppocr_mobile_v2.0_rec_infer/
infer_quant:False infer_quant:False
......
===========================paddle2onnx_params=========================== ===========================paddle2onnx_params===========================
model_name:ch_ppocr_mobile_v2.0_rec model_name:ch_ppocr_mobile_v2_0_rec
python:python3.7 python:python3.7
2onnx: paddle2onnx 2onnx: paddle2onnx
--det_model_dir: --det_model_dir:
......
===========================serving_params=========================== ===========================serving_params===========================
model_name:ch_ppocr_mobile_v2.0_rec model_name:ch_ppocr_mobile_v2_0_rec
python:python3.7 python:python3.7
trans_model:-m paddle_serving_client.convert trans_model:-m paddle_serving_client.convert
--det_dirname:null --det_dirname:null
......
===========================train_params=========================== ===========================train_params===========================
model_name:ch_ppocr_mobile_v2.0_rec model_name:ch_ppocr_mobile_v2_0_rec
python:python3.7 python:python3.7
gpu_list:0|0,1 gpu_list:0|0,1
Global.use_gpu:True|True Global.use_gpu:True|True
...@@ -13,7 +13,7 @@ train_infer_img_dir:./inference/rec_inference ...@@ -13,7 +13,7 @@ train_infer_img_dir:./inference/rec_inference
null:null null:null
## ##
trainer:norm_train trainer:norm_train
norm_train:tools/train.py -c configs/rec/rec_icdar15_train.yml -o norm_train:tools/train.py -c configs/rec/rec_icdar15_train.yml -o Global.print_batch_step=4 Train.loader.shuffle=false
pact_train:null pact_train:null
fpgm_train:null fpgm_train:null
distill_train:null distill_train:null
...@@ -51,3 +51,9 @@ inference:tools/infer/predict_rec.py ...@@ -51,3 +51,9 @@ inference:tools/infer/predict_rec.py
null:null null:null
===========================infer_benchmark_params========================== ===========================infer_benchmark_params==========================
random_infer_input:[{float32,[3,32,100]}] random_infer_input:[{float32,[3,32,100]}]
===========================train_benchmark_params==========================
batch_size:256
fp_items:fp32|fp16
epoch:3
--profiler_options:batch_range=[10,20];state=GPU;tracer_option=Default;profile_path=model.profile
flags:FLAGS_eager_delete_tensor_gb=0.0;FLAGS_fraction_of_gpu_memory_to_use=0.98;FLAGS_conv_workspace_size_limit=4096
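The --profiler_options value in these blocks is a single semicolon-separated string. A small sketch of how such a string breaks down into key/value pairs (the profiler itself consumes the raw string; this only illustrates the layout):

```python
# Illustration only: the --profiler_options entry is one "key=value;key=value" string.
opts = "batch_range=[10,20];state=GPU;tracer_option=Default;profile_path=model.profile"
parsed = dict(item.split("=", 1) for item in opts.split(";"))
print(parsed["batch_range"], parsed["state"], parsed["profile_path"])  # [10,20] GPU model.profile
```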
===========================train_params=========================== ===========================train_params===========================
model_name:ch_ppocr_mobile_v2.0_rec model_name:ch_ppocr_mobile_v2_0_rec
python:python3.7 python:python3.7
gpu_list:192.168.0.1,192.168.0.2;0,1 gpu_list:192.168.0.1,192.168.0.2;0,1
Global.use_gpu:True Global.use_gpu:True
......
===========================train_params=========================== ===========================train_params===========================
model_name:ch_ppocr_mobile_v2.0_rec model_name:ch_ppocr_mobile_v2_0_rec
python:python3.7 python:python3.7
gpu_list:0|0,1 gpu_list:0|0,1
Global.use_gpu:True|True Global.use_gpu:True|True
......
===========================train_params=========================== ===========================train_params===========================
model_name:ch_ppocr_mobile_v2.0_rec_PACT model_name:ch_ppocr_mobile_v2_0_rec_PACT
python:python3.7 python:python3.7
gpu_list:0 gpu_list:0
Global.use_gpu:True|True Global.use_gpu:True|True
......
===========================kl_quant_params=========================== ===========================kl_quant_params===========================
model_name:ch_ppocr_mobile_v2.0_rec_KL model_name:ch_ppocr_mobile_v2_0_rec_KL
python:python3.7 python:python3.7
Global.pretrained_model:null Global.pretrained_model:null
Global.save_inference_dir:null Global.save_inference_dir:null
......
===========================train_params=========================== ===========================train_params===========================
model_name:ch_ppocr_mobile_v2.0_rec_FPGM model_name:ch_ppocr_mobile_v2_0_rec_FPGM
python:python3.7 python:python3.7
gpu_list:0 gpu_list:0
Global.use_gpu:True|True Global.use_gpu:True|True
......
===========================train_params=========================== ===========================train_params===========================
model_name:ch_ppocr_mobile_v2.0_rec_FPGM model_name:ch_ppocr_mobile_v2_0_rec_FPGM
python:python3.7 python:python3.7
gpu_list:0 gpu_list:0
Global.use_gpu:True|True Global.use_gpu:True|True
......
===========================cpp_infer_params=========================== ===========================cpp_infer_params===========================
model_name:ch_ppocr_mobile_v2.0_rec_KL model_name:ch_ppocr_mobile_v2_0_rec_KL
use_opencv:True use_opencv:True
infer_model:./inference/ch_ppocr_mobile_v2.0_rec_klquant_infer infer_model:./inference/ch_ppocr_mobile_v2.0_rec_klquant_infer
infer_quant:False infer_quant:False
......
===========================serving_params=========================== ===========================serving_params===========================
model_name:ch_ppocr_mobile_v2.0_det_KL model_name:ch_ppocr_mobile_v2_0_rec_KL
python:python3.7 python:python3.7
trans_model:-m paddle_serving_client.convert trans_model:-m paddle_serving_client.convert
--det_dirname:./inference/ch_ppocr_mobile_v2.0_det_klquant_infer/ --det_dirname:./inference/ch_ppocr_mobile_v2.0_det_klquant_infer/
......
===========================serving_params=========================== ===========================serving_params===========================
model_name:ch_ppocr_mobile_v2.0_rec_KL model_name:ch_ppocr_mobile_v2_0_rec_KL
python:python3.7 python:python3.7
trans_model:-m paddle_serving_client.convert trans_model:-m paddle_serving_client.convert
--det_dirname:null --det_dirname:null
......
===========================cpp_infer_params=========================== ===========================cpp_infer_params===========================
model_name:ch_ppocr_mobile_v2.0_rec_PACT model_name:ch_ppocr_mobile_v2_0_rec_PACT
use_opencv:True use_opencv:True
infer_model:./inference/ch_ppocr_mobile_v2.0_rec_pact_infer infer_model:./inference/ch_ppocr_mobile_v2.0_rec_pact_infer
infer_quant:False infer_quant:False
......
===========================serving_params=========================== ===========================serving_params===========================
model_name:ch_ppocr_mobile_v2.0_det_PACT model_name:ch_ppocr_mobile_v2_0_rec_PACT
python:python3.7 python:python3.7
trans_model:-m paddle_serving_client.convert trans_model:-m paddle_serving_client.convert
--det_dirname:./inference/ch_ppocr_mobile_v2.0_det_pact_infer/ --det_dirname:./inference/ch_ppocr_mobile_v2.0_det_pact_infer/
......
===========================serving_params=========================== ===========================serving_params===========================
model_name:ch_ppocr_mobile_v2.0_rec_PACT model_name:ch_ppocr_mobile_v2_0_rec_PACT
python:python3.7 python:python3.7
trans_model:-m paddle_serving_client.convert trans_model:-m paddle_serving_client.convert
--det_dirname:null --det_dirname:null
......
===========================cpp_infer_params=========================== ===========================cpp_infer_params===========================
model_name:ch_ppocr_server_v2.0 model_name:ch_ppocr_server_v2_0
use_opencv:True use_opencv:True
infer_model:./inference/ch_ppocr_server_v2.0_det_infer/ infer_model:./inference/ch_ppocr_server_v2.0_det_infer/
infer_quant:False infer_quant:False
......
===========================ch_ppocr_server_v2.0=========================== ===========================ch_ppocr_server_v2.0===========================
model_name:ch_ppocr_server_v2.0 model_name:ch_ppocr_server_v2_0
python:python3.7 python:python3.7
infer_model:./inference/ch_ppocr_server_v2.0_det_infer/ infer_model:./inference/ch_ppocr_server_v2.0_det_infer/
infer_export:null infer_export:null
......
===========================paddle2onnx_params=========================== ===========================paddle2onnx_params===========================
model_name:ch_ppocr_server_v2.0 model_name:ch_ppocr_server_v2_0
python:python3.7 python:python3.7
2onnx: paddle2onnx 2onnx: paddle2onnx
--det_model_dir:./inference/ch_ppocr_server_v2.0_det_infer/ --det_model_dir:./inference/ch_ppocr_server_v2.0_det_infer/
......
===========================serving_params=========================== ===========================serving_params===========================
model_name:ch_ppocr_server_v2.0 model_name:ch_ppocr_server_v2_0
python:python3.7 python:python3.7
trans_model:-m paddle_serving_client.convert trans_model:-m paddle_serving_client.convert
--det_dirname:./inference/ch_ppocr_server_v2.0_det_infer/ --det_dirname:./inference/ch_ppocr_server_v2.0_det_infer/
......
===========================serving_params=========================== ===========================serving_params===========================
model_name:ch_ppocr_server_v2.0 model_name:ch_ppocr_server_v2_0
python:python3.7 python:python3.7
trans_model:-m paddle_serving_client.convert trans_model:-m paddle_serving_client.convert
--det_dirname:./inference/ch_ppocr_server_v2.0_det_infer/ --det_dirname:./inference/ch_ppocr_server_v2.0_det_infer/
......
...@@ -2,13 +2,13 @@ Global: ...@@ -2,13 +2,13 @@ Global:
use_gpu: false use_gpu: false
epoch_num: 5 epoch_num: 5
log_smooth_window: 20 log_smooth_window: 20
print_batch_step: 1 print_batch_step: 2
save_model_dir: ./output/db_mv3/ save_model_dir: ./output/db_mv3/
save_epoch_step: 1200 save_epoch_step: 1200
# evaluation is run every 2000 iterations # evaluation is run every 2000 iterations
eval_batch_step: [0, 400] eval_batch_step: [0, 30000]
cal_metric_during_train: False cal_metric_during_train: False
pretrained_model: ./pretrain_models/MobileNetV3_large_x0_5_pretrained pretrained_model:
checkpoints: checkpoints:
save_inference_dir: save_inference_dir:
use_visualdl: False use_visualdl: False
......
===========================cpp_infer_params=========================== ===========================cpp_infer_params===========================
model_name:ch_ppocr_server_v2.0_det model_name:ch_ppocr_server_v2_0_det
use_opencv:True use_opencv:True
infer_model:./inference/ch_ppocr_server_v2.0_det_infer/ infer_model:./inference/ch_ppocr_server_v2.0_det_infer/
infer_quant:False infer_quant:False
......
===========================paddle2onnx_params=========================== ===========================paddle2onnx_params===========================
model_name:ch_ppocr_server_v2.0_det model_name:ch_ppocr_server_v2_0_det
python:python3.7 python:python3.7
2onnx: paddle2onnx 2onnx: paddle2onnx
--det_model_dir:./inference/ch_ppocr_server_v2.0_det_infer/ --det_model_dir:./inference/ch_ppocr_server_v2.0_det_infer/
......
===========================serving_params=========================== ===========================serving_params===========================
model_name:ch_ppocr_server_v2.0_det model_name:ch_ppocr_server_v2_0_det
python:python3.7 python:python3.7
trans_model:-m paddle_serving_client.convert trans_model:-m paddle_serving_client.convert
--det_dirname:./inference/ch_ppocr_server_v2.0_det_infer/ --det_dirname:./inference/ch_ppocr_server_v2.0_det_infer/
......
===========================train_params=========================== ===========================train_params===========================
model_name:ch_ppocr_server_v2.0_det model_name:ch_ppocr_server_v2_0_det
python:python3.7 python:python3.7
gpu_list:0|0,1 gpu_list:0|0,1
Global.use_gpu:True|True Global.use_gpu:True|True
...@@ -13,7 +13,7 @@ train_infer_img_dir:./train_data/icdar2015/text_localization/ch4_test_images/ ...@@ -13,7 +13,7 @@ train_infer_img_dir:./train_data/icdar2015/text_localization/ch4_test_images/
null:null null:null
## ##
trainer:norm_train trainer:norm_train
norm_train:tools/train.py -c test_tipc/configs/ch_ppocr_server_v2.0_det/det_r50_vd_db.yml -o norm_train:tools/train.py -c test_tipc/configs/ch_ppocr_server_v2.0_det/det_r50_vd_db.yml -o
quant_train:null quant_train:null
fpgm_train:null fpgm_train:null
distill_train:null distill_train:null
...@@ -50,4 +50,10 @@ inference:tools/infer/predict_det.py ...@@ -50,4 +50,10 @@ inference:tools/infer/predict_det.py
--benchmark:True --benchmark:True
null:null null:null
===========================infer_benchmark_params========================== ===========================infer_benchmark_params==========================
random_infer_input:[{float32,[3,640,640]}];[{float32,[3,960,960]}] random_infer_input:[{float32,[3,640,640]}];[{float32,[3,960,960]}]
\ No newline at end of file ===========================train_benchmark_params==========================
batch_size:8
fp_items:fp32|fp16
epoch:2
--profiler_options:batch_range=[10,20];state=GPU;tracer_option=Default;profile_path=model.profile
flags:FLAGS_eager_delete_tensor_gb=0.0;FLAGS_fraction_of_gpu_memory_to_use=0.98;FLAGS_conv_workspace_size_limit=4096
===========================train_params=========================== ===========================train_params===========================
model_name:ch_ppocr_server_v2.0_det model_name:ch_ppocr_server_v2_0_det
python:python3.7 python:python3.7
gpu_list:192.168.0.1,192.168.0.2;0,1 gpu_list:192.168.0.1,192.168.0.2;0,1
Global.use_gpu:True Global.use_gpu:True
......
===========================train_params=========================== ===========================train_params===========================
model_name:ch_ppocr_server_v2.0_det model_name:ch_ppocr_server_v2_0_det
python:python3.7 python:python3.7
gpu_list:0|0,1 gpu_list:0|0,1
Global.use_gpu:True|True Global.use_gpu:True|True
......
===========================cpp_infer_params=========================== ===========================cpp_infer_params===========================
model_name:ch_ppocr_server_v2.0_rec model_name:ch_ppocr_server_v2_0_rec
use_opencv:True use_opencv:True
infer_model:./inference/ch_ppocr_server_v2.0_rec_infer/ infer_model:./inference/ch_ppocr_server_v2.0_rec_infer/
infer_quant:False infer_quant:False
......
===========================paddle2onnx_params=========================== ===========================paddle2onnx_params===========================
model_name:ch_ppocr_server_v2.0_rec model_name:ch_ppocr_server_v2_0_rec
python:python3.7 python:python3.7
2onnx: paddle2onnx 2onnx: paddle2onnx
--det_model_dir: --det_model_dir:
......
===========================serving_params=========================== ===========================serving_params===========================
model_name:ch_ppocr_server_v2.0_rec model_name:ch_ppocr_server_v2_0_rec
python:python3.7 python:python3.7
trans_model:-m paddle_serving_client.convert trans_model:-m paddle_serving_client.convert
--det_dirname:null --det_dirname:null
......
===========================train_params=========================== ===========================train_params===========================
model_name:ch_ppocr_server_v2.0_rec model_name:ch_ppocr_server_v2_0_rec
python:python3.7 python:python3.7
gpu_list:0|0,1 gpu_list:0|0,1
Global.use_gpu:True|True Global.use_gpu:True|True
...@@ -13,7 +13,7 @@ train_infer_img_dir:./inference/rec_inference ...@@ -13,7 +13,7 @@ train_infer_img_dir:./inference/rec_inference
null:null null:null
## ##
trainer:norm_train trainer:norm_train
norm_train:tools/train.py -c test_tipc/configs/ch_ppocr_server_v2.0_rec/rec_icdar15_train.yml -o norm_train:tools/train.py -c test_tipc/configs/ch_ppocr_server_v2.0_rec/rec_icdar15_train.yml -o Global.print_batch_step=4 Train.loader.shuffle=false
pact_train:null pact_train:null
fpgm_train:null fpgm_train:null
distill_train:null distill_train:null
...@@ -51,3 +51,9 @@ inference:tools/infer/predict_rec.py ...@@ -51,3 +51,9 @@ inference:tools/infer/predict_rec.py
null:null null:null
===========================infer_benchmark_params========================== ===========================infer_benchmark_params==========================
random_infer_input:[{float32,[3,32,100]}] random_infer_input:[{float32,[3,32,100]}]
===========================train_benchmark_params==========================
batch_size:256
fp_items:fp32|fp16
epoch:2
--profiler_options:batch_range=[10,20];state=GPU;tracer_option=Default;profile_path=model.profile
flags:FLAGS_eager_delete_tensor_gb=0.0;FLAGS_fraction_of_gpu_memory_to_use=0.98;FLAGS_conv_workspace_size_limit=4096
===========================train_params=========================== ===========================train_params===========================
model_name:ch_ppocr_server_v2.0_rec model_name:ch_ppocr_server_v2_0_rec
python:python3.7 python:python3.7
gpu_list:192.168.0.1,192.168.0.2;0,1 gpu_list:192.168.0.1,192.168.0.2;0,1
Global.use_gpu:True Global.use_gpu:True
......
===========================train_params=========================== ===========================train_params===========================
model_name:ch_ppocr_server_v2.0_rec model_name:ch_ppocr_server_v2_0_rec
python:python3.7 python:python3.7
gpu_list:0|0,1 gpu_list:0|0,1
Global.use_gpu:True|True Global.use_gpu:True|True
......
...@@ -13,7 +13,7 @@ train_infer_img_dir:./train_data/icdar2015/text_localization/ch4_test_images/ ...@@ -13,7 +13,7 @@ train_infer_img_dir:./train_data/icdar2015/text_localization/ch4_test_images/
null:null null:null
## ##
trainer:norm_train trainer:norm_train
norm_train:tools/train.py -c configs/det/det_mv3_db.yml -o Global.pretrained_model=./pretrain_models/MobileNetV3_large_x0_5_pretrained norm_train:tools/train.py -c configs/det/det_mv3_db.yml -o Global.pretrained_model=./pretrain_models/MobileNetV3_large_x0_5_pretrained Global.print_batch_step=1 Train.loader.shuffle=false
pact_train:null pact_train:null
fpgm_train:null fpgm_train:null
distill_train:null distill_train:null
...@@ -52,8 +52,8 @@ null:null ...@@ -52,8 +52,8 @@ null:null
===========================infer_benchmark_params========================== ===========================infer_benchmark_params==========================
random_infer_input:[{float32,[3,640,640]}];[{float32,[3,960,960]}] random_infer_input:[{float32,[3,640,640]}];[{float32,[3,960,960]}]
===========================train_benchmark_params========================== ===========================train_benchmark_params==========================
batch_size:8|16 batch_size:16
fp_items:fp32|fp16 fp_items:fp32|fp16
epoch:2 epoch:4
--profiler_options:batch_range=[10,20];state=GPU;tracer_option=Default;profile_path=model.profile --profiler_options:batch_range=[10,20];state=GPU;tracer_option=Default;profile_path=model.profile
flags:FLAGS_eager_delete_tensor_gb=0.0;FLAGS_fraction_of_gpu_memory_to_use=0.98;FLAGS_conv_workspace_size_limit=4096 flags:FLAGS_eager_delete_tensor_gb=0.0;FLAGS_fraction_of_gpu_memory_to_use=0.98;FLAGS_conv_workspace_size_limit=4096
\ No newline at end of file
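The flags entry bundles Paddle FLAGS_* settings, which Paddle reads from environment variables. A hedged sketch of exporting them before launching training; the command line below is only an example invocation, not the TIPC runner's actual code.

```python
# Sketch only: export the FLAGS_* values from the "flags" entry as environment
# variables, then launch training. The command line is an example invocation.
import os
import subprocess

flags = ("FLAGS_eager_delete_tensor_gb=0.0;"
         "FLAGS_fraction_of_gpu_memory_to_use=0.98;"
         "FLAGS_conv_workspace_size_limit=4096")

env = dict(os.environ)
for item in flags.split(";"):
    name, value = item.split("=", 1)
    env[name] = value

subprocess.run(
    ["python3.7", "tools/train.py",
     "-c", "configs/det/det_mv3_db.yml",
     "-o", "Global.epoch_num=4"],
    env=env, check=True)
```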
===========================train_params=========================== ===========================train_params===========================
model_name:det_mv3_east_v2.0 model_name:det_mv3_east_v2_0
python:python3.7 python:python3.7
gpu_list:0|0,1 gpu_list:0|0,1
Global.use_gpu:True|True Global.use_gpu:True|True
......
===========================train_params=========================== ===========================train_params===========================
model_name:det_mv3_pse_v2.0 model_name:det_mv3_pse_v2_0
python:python3.7 python:python3.7
gpu_list:0 gpu_list:0
Global.use_gpu:True|True Global.use_gpu:True|True
......
...@@ -54,5 +54,5 @@ random_infer_input:[{float32,[3,640,640]}];[{float32,[3,960,960]}] ...@@ -54,5 +54,5 @@ random_infer_input:[{float32,[3,640,640]}];[{float32,[3,960,960]}]
===========================train_benchmark_params========================== ===========================train_benchmark_params==========================
batch_size:8|16 batch_size:8|16
fp_items:fp32|fp16 fp_items:fp32|fp16
epoch:2 epoch:15
--profiler_options:batch_range=[10,20];state=GPU;tracer_option=Default;profile_path=model.profile --profiler_options:batch_range=[10,20];state=GPU;tracer_option=Default;profile_path=model.profile
===========================train_params===========================
model_name:det_r50_db_plusplus
python:python3.7
gpu_list:0|0,1
Global.use_gpu:True|True
Global.auto_cast:null
Global.epoch_num:lite_train_lite_infer=1|whole_train_whole_infer=300
Global.save_model_dir:./output/
Train.loader.batch_size_per_card:lite_train_lite_infer=2|whole_train_whole_infer=4
Global.pretrained_model:null
train_model_name:latest
train_infer_img_dir:./train_data/icdar2015/text_localization/ch4_test_images/
null:null
##
trainer:norm_train
norm_train:tools/train.py -c configs/det/det_r50_db++_icdar15.yml -o Global.pretrained_model=./pretrain_models/ResNet50_dcn_asf_synthtext_pretrained
pact_train:null
fpgm_train:null
distill_train:null
null:null
null:null
##
===========================eval_params===========================
eval:null
null:null
##
===========================infer_params===========================
Global.save_inference_dir:./output/
Global.checkpoints:
norm_export:tools/export_model.py -c configs/det/det_r50_db++_icdar15.yml -o
quant_export:null
fpgm_export:null
distill_export:null
export1:null
export2:null
inference_dir:null
train_model:./inference/det_r50_db++_train/best_accuracy
infer_export:tools/export_model.py -c configs/det/det_r50_db++_icdar15.yml -o
infer_quant:False
inference:tools/infer/predict_det.py --det_algorithm="DB++"
--use_gpu:True|False
--enable_mkldnn:False
--cpu_threads:6
--rec_batch_num:1
--use_tensorrt:False
--precision:fp32
--det_model_dir:
--image_dir:./inference/ch_det_data_50/all-sum-510/
null:null
--benchmark:True
null:null
===========================infer_benchmark_params==========================
random_infer_input:[{float32,[3,640,640]}];[{float32,[3,960,960]}]
===========================train_params=========================== ===========================train_params===========================
model_name:det_r50_db_v2.0 model_name:det_r50_db_v2_0
python:python3.7 python:python3.7
gpu_list:0|0,1 gpu_list:0|0,1
Global.use_gpu:True|True Global.use_gpu:True|True
...@@ -13,7 +13,7 @@ train_infer_img_dir:./train_data/icdar2015/text_localization/ch4_test_images/ ...@@ -13,7 +13,7 @@ train_infer_img_dir:./train_data/icdar2015/text_localization/ch4_test_images/
null:null null:null
## ##
trainer:norm_train trainer:norm_train
norm_train:tools/train.py -c configs/det/det_r50_vd_db.yml -o norm_train:tools/train.py -c configs/det/det_r50_vd_db.yml -o Global.print_batch_step=1 Train.loader.shuffle=false
quant_export:null quant_export:null
fpgm_export:null fpgm_export:null
distill_train:null distill_train:null
...@@ -50,4 +50,10 @@ inference:tools/infer/predict_det.py ...@@ -50,4 +50,10 @@ inference:tools/infer/predict_det.py
--benchmark:True --benchmark:True
null:null null:null
===========================infer_benchmark_params========================== ===========================infer_benchmark_params==========================
random_infer_input:[{float32,[3,640,640]}];[{float32,[3,960,960]}] random_infer_input:[{float32,[3,640,640]}];[{float32,[3,960,960]}]
\ No newline at end of file ===========================train_benchmark_params==========================
batch_size:8
fp_items:fp32|fp16
epoch:2
--profiler_options:batch_range=[10,20];state=GPU;tracer_option=Default;profile_path=model.profile
flags:FLAGS_eager_delete_tensor_gb=0.0;FLAGS_fraction_of_gpu_memory_to_use=0.98;FLAGS_conv_workspace_size_limit=4096
\ No newline at end of file
Global:
use_gpu: true
epoch_num: 1500
log_smooth_window: 20
print_batch_step: 20
save_model_dir: ./output/det_r50_dcn_fce_ctw/
save_epoch_step: 100
  # evaluation is run every 4000 iterations
eval_batch_step: [0, 4000]
cal_metric_during_train: False
pretrained_model: ./pretrain_models/ResNet50_vd_ssld_pretrained
checkpoints:
save_inference_dir:
use_visualdl: False
infer_img: doc/imgs_en/img_10.jpg
save_res_path: ./output/det_fce/predicts_fce.txt
Architecture:
model_type: det
algorithm: FCE
Transform:
Backbone:
name: ResNet_vd
layers: 50
dcn_stage: [False, True, True, True]
out_indices: [1,2,3]
Neck:
name: FCEFPN
out_channels: 256
has_extra_convs: False
extra_stage: 0
Head:
name: FCEHead
fourier_degree: 5
Loss:
name: FCELoss
fourier_degree: 5
num_sample: 50
Optimizer:
name: Adam
beta1: 0.9
beta2: 0.999
lr:
learning_rate: 0.0001
regularizer:
name: 'L2'
factor: 0
PostProcess:
name: FCEPostProcess
scales: [8, 16, 32]
alpha: 1.0
beta: 1.0
fourier_degree: 5
box_type: 'poly'
Metric:
name: DetFCEMetric
main_indicator: hmean
Train:
dataset:
name: SimpleDataSet
data_dir: ./train_data/icdar2015/text_localization/
label_file_list:
- ./train_data/icdar2015/text_localization/train_icdar2015_label.txt
transforms:
- DecodeImage: # load image
img_mode: BGR
channel_first: False
ignore_orientation: True
- DetLabelEncode: # Class handling label
- ColorJitter:
brightness: 0.142
saturation: 0.5
contrast: 0.5
- RandomScaling:
- RandomCropFlip:
crop_ratio: 0.5
- RandomCropPolyInstances:
crop_ratio: 0.8
min_side_ratio: 0.3
- RandomRotatePolyInstances:
rotate_ratio: 0.5
max_angle: 30
pad_with_fixed_color: False
- SquareResizePad:
target_size: 800
pad_ratio: 0.6
- IaaAugment:
augmenter_args:
- { 'type': Fliplr, 'args': { 'p': 0.5 } }
- FCENetTargets:
fourier_degree: 5
- NormalizeImage:
scale: 1./255.
mean: [0.485, 0.456, 0.406]
std: [0.229, 0.224, 0.225]
order: 'hwc'
- ToCHWImage:
- KeepKeys:
keep_keys: ['image', 'p3_maps', 'p4_maps', 'p5_maps'] # dataloader will return list in this order
loader:
shuffle: True
drop_last: False
batch_size_per_card: 6
num_workers: 8
Eval:
dataset:
name: SimpleDataSet
data_dir: ./train_data/icdar2015/text_localization/
label_file_list:
- ./train_data/icdar2015/text_localization/test_icdar2015_label.txt
transforms:
- DecodeImage: # load image
img_mode: BGR
channel_first: False
ignore_orientation: True
- DetLabelEncode: # Class handling label
- DetResizeForTest:
limit_type: 'min'
limit_side_len: 736
- NormalizeImage:
scale: 1./255.
mean: [0.485, 0.456, 0.406]
std: [0.229, 0.224, 0.225]
order: 'hwc'
- Pad:
- ToCHWImage:
- KeepKeys:
keep_keys: ['image', 'shape', 'polys', 'ignore_tags']
loader:
shuffle: False
drop_last: False
batch_size_per_card: 1 # must be 1
num_workers: 2
\ No newline at end of file
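In the config above, eval_batch_step: [0, 4000] means evaluation starts after iteration 0 and then runs every 4000 iterations. A minimal sketch of that check, mirroring the config comment rather than the training loop's actual code:

```python
# Sketch of the eval_batch_step: [start, interval] semantics described in the
# config comment above; not the actual training-loop implementation.
start_step, interval = 0, 4000

def should_evaluate(global_step: int) -> bool:
    return global_step > start_step and (global_step - start_step) % interval == 0

assert should_evaluate(4000) and should_evaluate(8000)
assert not should_evaluate(4001)
```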
===========================train_params===========================
model_name:det_r50_dcn_fce_ctw_v2_0
python:python3.7
gpu_list:0
Global.use_gpu:True|True
Global.auto_cast:fp32
Global.epoch_num:lite_train_lite_infer=1|whole_train_whole_infer=500
Global.save_model_dir:./output/
Train.loader.batch_size_per_card:lite_train_lite_infer=2|whole_train_whole_infer=4
Global.pretrained_model:null
train_model_name:latest
train_infer_img_dir:./train_data/icdar2015/text_localization/ch4_test_images/
null:null
##
trainer:norm_train
norm_train:tools/train.py -c test_tipc/configs/det_r50_dcn_fce_ctw_v2.0/det_r50_vd_dcn_fce_ctw.yml -o Global.print_batch_step=1 Train.loader.shuffle=false
pact_train:null
fpgm_train:null
distill_train:null
null:null
null:null
##
===========================eval_params===========================
eval:null
null:null
##
===========================infer_params===========================
Global.save_inference_dir:./output/
Global.checkpoints:
norm_export:tools/export_model.py -c test_tipc/configs/det_r50_dcn_fce_ctw_v2.0/det_r50_vd_dcn_fce_ctw.yml -o
quant_export:null
fpgm_export:null
distill_export:null
export1:null
export2:null
##
train_model:./inference/det_r50_dcn_fce_ctw_v2.0_train/best_accuracy
infer_export:tools/export_model.py -c test_tipc/configs/det_r50_dcn_fce_ctw_v2.0/det_r50_vd_dcn_fce_ctw.yml -o
infer_quant:False
inference:tools/infer/predict_det.py
--use_gpu:True|False
--enable_mkldnn:False
--cpu_threads:6
--rec_batch_num:1
--use_tensorrt:False
--precision:fp32
--det_model_dir:
--image_dir:./inference/ch_det_data_50/all-sum-510/
--save_log_path:null
--benchmark:True
--det_algorithm:FCE
===========================infer_benchmark_params==========================
random_infer_input:[{float32,[3,640,640]}];[{float32,[3,960,960]}]
===========================train_benchmark_params==========================
batch_size:6
fp_items:fp32|fp16
epoch:1
--profiler_options:batch_range=[10,20];state=GPU;tracer_option=Default;profile_path=model.profile
flags:FLAGS_eager_delete_tensor_gb=0.0;FLAGS_fraction_of_gpu_memory_to_use=0.98;FLAGS_conv_workspace_size_limit=4096
\ No newline at end of file
...@@ -20,7 +20,7 @@ Architecture: ...@@ -20,7 +20,7 @@ Architecture:
algorithm: EAST algorithm: EAST
Transform: Transform:
Backbone: Backbone:
name: ResNet name: ResNet_vd
layers: 50 layers: 50
Neck: Neck:
name: EASTFPN name: EASTFPN
......
===========================train_params=========================== ===========================train_params===========================
model_name:det_r50_vd_east_v2_0 model_name:det_r50_vd_east_v2_0
python:python3.7 python:python3.7
gpu_list:0 gpu_list:0|0,1
Global.use_gpu:True|True Global.use_gpu:True|True
Global.auto_cast:fp32 Global.auto_cast:fp32
Global.epoch_num:lite_train_lite_infer=1|whole_train_whole_infer=500 Global.epoch_num:lite_train_lite_infer=1|whole_train_whole_infer=500
Global.save_model_dir:./output/ Global.save_model_dir:./output/
Train.loader.batch_size_per_card:lite_train_lite_infer=2|whole_train_whole_infer=4 Train.loader.batch_size_per_card:lite_train_lite_infer=2|whole_train_whole_infer=4
Global.pretrained_model:null Global.pretrained_model:./pretrain_models/det_r50_vd_east_v2.0_train/best_accuracy
train_model_name:latest train_model_name:latest
train_infer_img_dir:./train_data/icdar2015/text_localization/ch4_test_images/ train_infer_img_dir:./train_data/icdar2015/text_localization/ch4_test_images/
null:null null:null
## ##
trainer:norm_train trainer:norm_train
norm_train:tools/train.py -c test_tipc/configs/det_r50_vd_east_v2_0/det_r50_vd_east.yml -o norm_train:tools/train.py -c test_tipc/configs/det_r50_vd_east_v2_0/det_r50_vd_east.yml -o Global.pretrained_model=pretrain_models/det_r50_vd_east_v2.0_train/best_accuracy.pdparams Global.print_batch_step=1 Train.loader.shuffle=false
pact_train:null pact_train:null
fpgm_train:null fpgm_train:null
distill_train:null distill_train:null
...@@ -55,4 +55,5 @@ random_infer_input:[{float32,[3,640,640]}];[{float32,[3,960,960]}] ...@@ -55,4 +55,5 @@ random_infer_input:[{float32,[3,640,640]}];[{float32,[3,960,960]}]
batch_size:8 batch_size:8
fp_items:fp32|fp16 fp_items:fp32|fp16
epoch:2 epoch:2
--profiler_options:batch_range=[10,20];state=GPU;tracer_option=Default;profile_path=model.profile --profiler_options:batch_range=[10,20];state=GPU;tracer_option=Default;profile_path=model.profile
\ No newline at end of file flags:FLAGS_eager_delete_tensor_gb=0.0;FLAGS_fraction_of_gpu_memory_to_use=0.98;FLAGS_conv_workspace_size_limit=4096
\ No newline at end of file
...@@ -8,7 +8,7 @@ Global: ...@@ -8,7 +8,7 @@ Global:
# evaluation is run every 125 iterations # evaluation is run every 125 iterations
eval_batch_step: [ 0,1000 ] eval_batch_step: [ 0,1000 ]
cal_metric_during_train: False cal_metric_during_train: False
pretrained_model: pretrained_model: ./pretrain_models/ResNet50_vd_ssld_pretrained
checkpoints: #./output/det_r50_vd_pse_batch8_ColorJitter/best_accuracy checkpoints: #./output/det_r50_vd_pse_batch8_ColorJitter/best_accuracy
save_inference_dir: save_inference_dir:
use_visualdl: False use_visualdl: False
...@@ -20,7 +20,7 @@ Architecture: ...@@ -20,7 +20,7 @@ Architecture:
algorithm: PSE algorithm: PSE
Transform: Transform:
Backbone: Backbone:
name: ResNet name: ResNet_vd
layers: 50 layers: 50
Neck: Neck:
name: FPN name: FPN
......
...@@ -13,7 +13,7 @@ train_infer_img_dir:./train_data/icdar2015/text_localization/ch4_test_images/ ...@@ -13,7 +13,7 @@ train_infer_img_dir:./train_data/icdar2015/text_localization/ch4_test_images/
null:null null:null
## ##
trainer:norm_train trainer:norm_train
norm_train:tools/train.py -c test_tipc/configs/det_r50_vd_pse_v2_0/det_r50_vd_pse.yml -o norm_train:tools/train.py -c test_tipc/configs/det_r50_vd_pse_v2_0/det_r50_vd_pse.yml -o Global.print_batch_step=1 Train.loader.shuffle=false
pact_train:null pact_train:null
fpgm_train:null fpgm_train:null
distill_train:null distill_train:null
...@@ -56,3 +56,4 @@ batch_size:8 ...@@ -56,3 +56,4 @@ batch_size:8
fp_items:fp32|fp16 fp_items:fp32|fp16
epoch:2 epoch:2
--profiler_options:batch_range=[10,20];state=GPU;tracer_option=Default;profile_path=model.profile --profiler_options:batch_range=[10,20];state=GPU;tracer_option=Default;profile_path=model.profile
flags:FLAGS_eager_delete_tensor_gb=0.0;FLAGS_fraction_of_gpu_memory_to_use=0.98;FLAGS_conv_workspace_size_limit=4096
\ No newline at end of file
===========================train_params=========================== ===========================train_params===========================
model_name:det_r50_vd_sast_icdar15_v2.0 model_name:det_r50_vd_sast_icdar15_v2_0
python:python3.7 python:python3.7
gpu_list:0|0,1 gpu_list:0|0,1
Global.use_gpu:True|True Global.use_gpu:True|True
......
===========================train_params=========================== ===========================train_params===========================
model_name:det_r50_vd_sast_totaltext_v2.0 model_name:det_r50_vd_sast_totaltext_v2_0
python:python3.7 python:python3.7
gpu_list:0|0,1 gpu_list:0|0,1
Global.use_gpu:True|True Global.use_gpu:True|True
......
...@@ -6,19 +6,18 @@ Global: ...@@ -6,19 +6,18 @@ Global:
save_model_dir: ./output/table_mv3/ save_model_dir: ./output/table_mv3/
save_epoch_step: 3 save_epoch_step: 3
# evaluation is run every 400 iterations after the 0th iteration # evaluation is run every 400 iterations after the 0th iteration
eval_batch_step: [0, 400] eval_batch_step: [0, 40000]
cal_metric_during_train: True cal_metric_during_train: True
pretrained_model: pretrained_model:
checkpoints: checkpoints:
save_inference_dir: save_inference_dir:
use_visualdl: False use_visualdl: False
infer_img: doc/table/table.jpg infer_img: ppstructure/docs/table/table.jpg
save_res_path: output/table_mv3
# for data or label process # for data or label process
character_dict_path: ppocr/utils/dict/table_structure_dict.txt character_dict_path: ppocr/utils/dict/table_structure_dict.txt
character_type: en character_type: en
max_text_length: 100 max_text_length: 800
max_elem_length: 800
max_cell_num: 500
infer_mode: False infer_mode: False
process_total_num: 0 process_total_num: 0
process_cut_num: 0 process_cut_num: 0
...@@ -44,11 +43,8 @@ Architecture: ...@@ -44,11 +43,8 @@ Architecture:
Head: Head:
name: TableAttentionHead name: TableAttentionHead
hidden_size: 256 hidden_size: 256
l2_decay: 0.00001
loc_type: 2 loc_type: 2
max_text_length: 100 max_text_length: 800
max_elem_length: 800
max_cell_num: 500
Loss: Loss:
name: TableAttentionLoss name: TableAttentionLoss
...@@ -61,28 +57,34 @@ PostProcess: ...@@ -61,28 +57,34 @@ PostProcess:
Metric: Metric:
name: TableMetric name: TableMetric
main_indicator: acc main_indicator: acc
compute_bbox_metric: false # cost many time, set False for training
Train: Train:
dataset: dataset:
name: PubTabDataSet name: PubTabDataSet
data_dir: ./train_data/pubtabnet/train data_dir: ./train_data/pubtabnet/train
label_file_path: ./train_data/pubtabnet/train.jsonl label_file_list: [./train_data/pubtabnet/train.jsonl]
transforms: transforms:
- DecodeImage: # load image - DecodeImage: # load image
img_mode: BGR img_mode: BGR
channel_first: False channel_first: False
- TableLabelEncode:
learn_empty_box: False
merge_no_span_structure: False
replace_empty_cell_token: False
- TableBoxEncode:
- ResizeTableImage: - ResizeTableImage:
max_len: 488 max_len: 488
- TableLabelEncode:
- NormalizeImage: - NormalizeImage:
scale: 1./255. scale: 1./255.
mean: [0.485, 0.456, 0.406] mean: [0.485, 0.456, 0.406]
std: [0.229, 0.224, 0.225] std: [0.229, 0.224, 0.225]
order: 'hwc' order: 'hwc'
- PaddingTableImage: - PaddingTableImage:
size: [488, 488]
- ToCHWImage: - ToCHWImage:
- KeepKeys: - KeepKeys:
keep_keys: ['image', 'structure', 'bbox_list', 'sp_tokens', 'bbox_list_mask'] keep_keys: [ 'image', 'structure', 'bboxes', 'bbox_masks', 'shape' ]
loader: loader:
shuffle: True shuffle: True
batch_size_per_card: 32 batch_size_per_card: 32
...@@ -93,23 +95,28 @@ Eval: ...@@ -93,23 +95,28 @@ Eval:
dataset: dataset:
name: PubTabDataSet name: PubTabDataSet
data_dir: ./train_data/pubtabnet/test/ data_dir: ./train_data/pubtabnet/test/
label_file_path: ./train_data/pubtabnet/test.jsonl label_file_list: [./train_data/pubtabnet/test.jsonl]
transforms: transforms:
- DecodeImage: # load image - DecodeImage: # load image
img_mode: BGR img_mode: BGR
channel_first: False channel_first: False
- TableLabelEncode:
learn_empty_box: False
merge_no_span_structure: False
replace_empty_cell_token: False
- TableBoxEncode:
- ResizeTableImage: - ResizeTableImage:
max_len: 488 max_len: 488
- TableLabelEncode:
- NormalizeImage: - NormalizeImage:
scale: 1./255. scale: 1./255.
mean: [0.485, 0.456, 0.406] mean: [0.485, 0.456, 0.406]
std: [0.229, 0.224, 0.225] std: [0.229, 0.224, 0.225]
order: 'hwc' order: 'hwc'
- PaddingTableImage: - PaddingTableImage:
size: [488, 488]
- ToCHWImage: - ToCHWImage:
- KeepKeys: - KeepKeys:
keep_keys: ['image', 'structure', 'bbox_list', 'sp_tokens', 'bbox_list_mask'] keep_keys: [ 'image', 'structure', 'bboxes', 'bbox_masks', 'shape' ]
loader: loader:
shuffle: False shuffle: False
drop_last: False drop_last: False
......
...@@ -13,7 +13,7 @@ train_infer_img_dir:./ppstructure/docs/table/table.jpg ...@@ -13,7 +13,7 @@ train_infer_img_dir:./ppstructure/docs/table/table.jpg
null:null null:null
## ##
trainer:norm_train trainer:norm_train
norm_train:tools/train.py -c test_tipc/configs/en_table_structure/table_mv3.yml -o norm_train:tools/train.py -c test_tipc/configs/en_table_structure/table_mv3.yml -o Global.print_batch_step=1 Train.loader.shuffle=false
pact_train:null pact_train:null
fpgm_train:null fpgm_train:null
distill_train:null distill_train:null
...@@ -27,7 +27,7 @@ null:null ...@@ -27,7 +27,7 @@ null:null
===========================infer_params=========================== ===========================infer_params===========================
Global.save_inference_dir:./output/ Global.save_inference_dir:./output/
Global.checkpoints: Global.checkpoints:
norm_export:tools/export_model.py -c test_tipc/configs/en_table_structure/table_mv3.yml -o norm_export:tools/export_model.py -c test_tipc/configs/en_table_structure/table_mv3.yml -o
quant_export: quant_export:
fpgm_export: fpgm_export:
distill_export:null distill_export:null
...@@ -51,3 +51,9 @@ null:null ...@@ -51,3 +51,9 @@ null:null
null:null null:null
===========================infer_benchmark_params========================== ===========================infer_benchmark_params==========================
random_infer_input:[{float32,[3,488,488]}] random_infer_input:[{float32,[3,488,488]}]
===========================train_benchmark_params==========================
batch_size:32
fp_items:fp32|fp16
epoch:1
--profiler_options:batch_range=[10,20];state=GPU;tracer_option=Default;profile_path=model.profile
flags:FLAGS_eager_delete_tensor_gb=0.0;FLAGS_fraction_of_gpu_memory_to_use=0.98;FLAGS_conv_workspace_size_limit=4096
...@@ -6,7 +6,7 @@ Global.use_gpu:True|True ...@@ -6,7 +6,7 @@ Global.use_gpu:True|True
Global.auto_cast:fp32 Global.auto_cast:fp32
Global.epoch_num:lite_train_lite_infer=1|whole_train_whole_infer=50 Global.epoch_num:lite_train_lite_infer=1|whole_train_whole_infer=50
Global.save_model_dir:./output/ Global.save_model_dir:./output/
Train.loader.batch_size_per_card:lite_train_lite_infer=16|whole_train_whole_infer=128 Train.loader.batch_size_per_card:lite_train_lite_infer=2|whole_train_whole_infer=2
Global.pretrained_model:./pretrain_models/en_ppocr_mobile_v2.0_table_structure_train/best_accuracy Global.pretrained_model:./pretrain_models/en_ppocr_mobile_v2.0_table_structure_train/best_accuracy
train_model_name:latest train_model_name:latest
train_infer_img_dir:./ppstructure/docs/table/table.jpg train_infer_img_dir:./ppstructure/docs/table/table.jpg
......
...@@ -6,7 +6,7 @@ Global: ...@@ -6,7 +6,7 @@ Global:
save_model_dir: ./output/rec/mv3_none_bilstm_ctc/ save_model_dir: ./output/rec/mv3_none_bilstm_ctc/
save_epoch_step: 3 save_epoch_step: 3
# evaluation is run every 2000 iterations # evaluation is run every 2000 iterations
eval_batch_step: [0, 2000] eval_batch_step: [0, 20000]
cal_metric_during_train: True cal_metric_during_train: True
pretrained_model: pretrained_model:
checkpoints: checkpoints:
......
===========================train_params=========================== ===========================train_params===========================
model_name:rec_mv3_none_bilstm_ctc_v2.0 model_name:rec_mv3_none_bilstm_ctc_v2_0
python:python3.7 python:python3.7
gpu_list:0|0,1 gpu_list:0|0,1
Global.use_gpu:True|True Global.use_gpu:True|True
...@@ -13,7 +13,7 @@ train_infer_img_dir:./inference/rec_inference ...@@ -13,7 +13,7 @@ train_infer_img_dir:./inference/rec_inference
null:null null:null
## ##
trainer:norm_train trainer:norm_train
norm_train:tools/train.py -c test_tipc/configs/rec_mv3_none_bilstm_ctc_v2.0/rec_icdar15_train.yml -o norm_train:tools/train.py -c test_tipc/configs/rec_mv3_none_bilstm_ctc_v2.0/rec_icdar15_train.yml -o Global.print_batch_step=4 Train.loader.shuffle=false
pact_train:null pact_train:null
fpgm_train:null fpgm_train:null
distill_train:null distill_train:null
...@@ -50,4 +50,10 @@ inference:tools/infer/predict_rec.py --rec_char_dict_path=./ppocr/utils/ic15_dic ...@@ -50,4 +50,10 @@ inference:tools/infer/predict_rec.py --rec_char_dict_path=./ppocr/utils/ic15_dic
--benchmark:True --benchmark:True
null:null null:null
===========================infer_benchmark_params========================== ===========================infer_benchmark_params==========================
random_infer_input:[{float32,[3,32,100]}] random_infer_input:[{float32,[3,32,100]}]
\ No newline at end of file ===========================train_benchmark_params==========================
batch_size:256
fp_items:fp32|fp16
epoch:4
--profiler_options:batch_range=[10,20];state=GPU;tracer_option=Default;profile_path=model.profile
flags:FLAGS_eager_delete_tensor_gb=0.0;FLAGS_fraction_of_gpu_memory_to_use=0.98;FLAGS_conv_workspace_size_limit=4096
===========================train_params=========================== ===========================train_params===========================
model_name:rec_mv3_none_none_ctc_v2.0 model_name:rec_mv3_none_none_ctc_v2_0
python:python3.7 python:python3.7
gpu_list:0|0,1 gpu_list:0|0,1
Global.use_gpu:True|True Global.use_gpu:True|True
......
===========================train_params=========================== ===========================train_params===========================
model_name:rec_mv3_tps_bilstm_att_v2.0 model_name:rec_mv3_tps_bilstm_att_v2_0
python:python3.7 python:python3.7
gpu_list:0|0,1 gpu_list:0|0,1
Global.use_gpu:True|True Global.use_gpu:True|True
......
===========================train_params=========================== ===========================train_params===========================
model_name:rec_mv3_tps_bilstm_ctc_v2.0 model_name:rec_mv3_tps_bilstm_ctc_v2_0
python:python3.7 python:python3.7
gpu_list:0|0,1 gpu_list:0|0,1
Global.use_gpu:True|True Global.use_gpu:True|True
......
Global:
use_gpu: True
epoch_num: 6
log_smooth_window: 50
print_batch_step: 50
save_model_dir: ./output/rec/rec_r32_gaspin_bilstm_att/
save_epoch_step: 3
  # evaluation is run every 2000 iterations after the 0th iteration
eval_batch_step: [0, 2000]
cal_metric_during_train: True
pretrained_model: pretrain_models/rec_r32_gaspin_bilstm_att_train/best_accuracy
checkpoints:
save_inference_dir:
use_visualdl: False
infer_img: doc/imgs_words/ch/word_1.jpg
# for data or label process
character_dict_path: ./ppocr/utils/dict/spin_dict.txt
max_text_length: 25
infer_mode: False
use_space_char: False
save_res_path: ./output/rec/predicts_r32_gaspin_bilstm_att.txt
Optimizer:
name: AdamW
beta1: 0.9
beta2: 0.999
lr:
name: Piecewise
decay_epochs: [3, 4, 5]
values: [0.001, 0.0003, 0.00009, 0.000027]
clip_norm: 5
Architecture:
model_type: rec
algorithm: SPIN
in_channels: 1
Transform:
name: GA_SPIN
offsets: True
default_type: 6
loc_lr: 0.1
stn: True
Backbone:
name: ResNet32
out_channels: 512
Neck:
name: SequenceEncoder
encoder_type: cascadernn
hidden_size: 256
out_channels: [256, 512]
with_linear: True
Head:
name: SPINAttentionHead
hidden_size: 256
Loss:
name: SPINAttentionLoss
ignore_index: 0
PostProcess:
name: SPINLabelDecode
use_space_char: False
Metric:
name: RecMetric
main_indicator: acc
is_filter: True
Train:
dataset:
name: SimpleDataSet
data_dir: ./train_data/ic15_data/
label_file_list: ["./train_data/ic15_data/rec_gt_train.txt"]
transforms:
- DecodeImage: # load image
img_mode: BGR
channel_first: False
- SPINLabelEncode: # Class handling label
- SPINRecResizeImg:
image_shape: [100, 32]
interpolation : 2
mean: [127.5]
std: [127.5]
- KeepKeys:
keep_keys: ['image', 'label', 'length'] # dataloader will return list in this order
loader:
shuffle: True
batch_size_per_card: 128
drop_last: True
num_workers: 4
Eval:
dataset:
name: SimpleDataSet
data_dir: ./train_data/ic15_data
label_file_list: ["./train_data/ic15_data/rec_gt_test.txt"]
transforms:
- DecodeImage: # load image
img_mode: BGR
channel_first: False
- SPINLabelEncode: # Class handling label
- SPINRecResizeImg:
image_shape: [100, 32]
interpolation : 2
mean: [127.5]
std: [127.5]
- KeepKeys:
keep_keys: ['image', 'label', 'length'] # dataloader will return list in this order
loader:
shuffle: False
drop_last: False
batch_size_per_card: 1
num_workers: 1
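For reference, the SPIN config above (assumed to live at test_tipc/configs/rec_r32_gaspin_bilstm_att/rec_r32_gaspin_bilstm_att.yml, the path used by the parameter file that follows) can be exercised directly with the standard PaddleOCR entry points. This is only an illustrative sketch; it assumes the IC15 data and the pretrained weights referenced in Global have already been downloaded, and the override values are arbitrary:

```shell
# Train for a couple of epochs, then evaluate the latest checkpoint (illustrative values).
python3.7 tools/train.py \
    -c test_tipc/configs/rec_r32_gaspin_bilstm_att/rec_r32_gaspin_bilstm_att.yml \
    -o Global.epoch_num=2 Train.loader.batch_size_per_card=16
python3.7 tools/eval.py \
    -c test_tipc/configs/rec_r32_gaspin_bilstm_att/rec_r32_gaspin_bilstm_att.yml \
    -o Global.checkpoints=./output/rec/rec_r32_gaspin_bilstm_att/latest
```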
===========================train_params===========================
model_name:rec_r32_gaspin_bilstm_att
python:python3.7
gpu_list:0|0,1
Global.use_gpu:True|True
Global.auto_cast:null
Global.epoch_num:lite_train_lite_infer=2|whole_train_whole_infer=300
Global.save_model_dir:./output/
Train.loader.batch_size_per_card:lite_train_lite_infer=16|whole_train_whole_infer=64
Global.pretrained_model:null
train_model_name:latest
train_infer_img_dir:./inference/rec_inference
null:null
##
trainer:norm_train
norm_train:tools/train.py -c test_tipc/configs/rec_r32_gaspin_bilstm_att/rec_r32_gaspin_bilstm_att.yml -o
pact_train:null
fpgm_train:null
distill_train:null
null:null
null:null
##
===========================eval_params===========================
eval:tools/eval.py -c test_tipc/configs/rec_r32_gaspin_bilstm_att/rec_r32_gaspin_bilstm_att.yml -o
null:null
##
===========================infer_params===========================
Global.save_inference_dir:./output/
Global.checkpoints:
norm_export:tools/export_model.py -c test_tipc/configs/rec_r32_gaspin_bilstm_att/rec_r32_gaspin_bilstm_att.yml -o
quant_export:null
fpgm_export:null
distill_export:null
export1:null
export2:null
##
train_model:./inference/rec_r32_gaspin_bilstm_att/best_accuracy
infer_export:tools/export_model.py -c test_tipc/configs/rec_r32_gaspin_bilstm_att/rec_r32_gaspin_bilstm_att.yml -o
infer_quant:False
inference:tools/infer/predict_rec.py --rec_char_dict_path=./ppocr/utils/dict/spin_dict.txt --use_space_char=False --rec_image_shape="3,32,100" --rec_algorithm="SPIN"
--use_gpu:True|False
--enable_mkldnn:False
--cpu_threads:6
--rec_batch_num:1|6
--use_tensorrt:False
--precision:fp32
--rec_model_dir:
--image_dir:./inference/rec_inference
--save_log_path:./test/output/
--benchmark:True
null:null
===========================infer_benchmark_params==========================
random_infer_input:[{float32,[3,32,100]}]
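Assuming the parameter file above is saved as test_tipc/configs/rec_r32_gaspin_bilstm_att/train_infer_python.txt (the naming convention used by the other configs in this commit), the usual TIPC flow for it would be roughly:

```shell
# Prepare data/pretrained models, then run the lite-train + lite-infer chain.
bash test_tipc/prepare.sh test_tipc/configs/rec_r32_gaspin_bilstm_att/train_infer_python.txt lite_train_lite_infer
bash test_tipc/test_train_inference_python.sh test_tipc/configs/rec_r32_gaspin_bilstm_att/train_infer_python.txt lite_train_lite_infer
```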
===========================train_params=========================== ===========================train_params===========================
model_name:rec_r34_vd_none_bilstm_ctc_v2.0 model_name:rec_r34_vd_none_bilstm_ctc_v2_0
python:python3.7 python:python3.7
gpu_list:0|0,1 gpu_list:0|0,1
Global.use_gpu:True|True Global.use_gpu:True|True
......
===========================train_params=========================== ===========================train_params===========================
model_name:rec_r34_vd_none_none_ctc_v2.0 model_name:rec_r34_vd_none_none_ctc_v2_0
python:python3.7 python:python3.7
gpu_list:0|0,1 gpu_list:0|0,1
Global.use_gpu:True|True Global.use_gpu:True|True
......
===========================train_params=========================== ===========================train_params===========================
model_name:rec_r34_vd_tps_bilstm_att_v2.0 model_name:rec_r34_vd_tps_bilstm_att_v2_0
python:python3.7 python:python3.7
gpu_list:0|0,1 gpu_list:0|0,1
Global.use_gpu:True|True Global.use_gpu:True|True
......
===========================train_params=========================== ===========================train_params===========================
model_name:rec_r34_vd_tps_bilstm_ctc_v2.0 model_name:rec_r34_vd_tps_bilstm_ctc_v2_0
python:python3.7 python:python3.7
gpu_list:0|0,1 gpu_list:0|0,1
Global.use_gpu:True|True Global.use_gpu:True|True
......
Global:
use_gpu: true
epoch_num: 17
log_smooth_window: 20
print_batch_step: 100
save_model_dir: ./output/table_master/
save_epoch_step: 17
eval_batch_step: [0, 6259]
cal_metric_during_train: true
pretrained_model: null
checkpoints:
save_inference_dir: output/table_master/infer
use_visualdl: false
infer_img: ppstructure/docs/table/table.jpg
save_res_path: ./output/table_master
character_dict_path: ppocr/utils/dict/table_master_structure_dict.txt
infer_mode: false
max_text_length: 500
process_total_num: 0
process_cut_num: 0
Optimizer:
name: Adam
beta1: 0.9
beta2: 0.999
lr:
name: MultiStepDecay
learning_rate: 0.001
milestones: [12, 15]
gamma: 0.1
warmup_epoch: 0.02
regularizer:
name: L2
factor: 0.0
Architecture:
model_type: table
algorithm: TableMaster
Backbone:
name: TableResNetExtra
gcb_config:
ratio: 0.0625
headers: 1
att_scale: False
fusion_type: channel_add
layers: [False, True, True, True]
layers: [1,2,5,3]
Head:
name: TableMasterHead
hidden_size: 512
headers: 8
dropout: 0
d_ff: 2024
max_text_length: 500
Loss:
name: TableMasterLoss
ignore_index: 42 # set to len of dict + 3
PostProcess:
name: TableMasterLabelDecode
box_shape: pad
Metric:
name: TableMetric
main_indicator: acc
compute_bbox_metric: False
Train:
dataset:
name: PubTabDataSet
data_dir: ./train_data/pubtabnet/train
label_file_list: [./train_data/pubtabnet/train.jsonl]
transforms:
- DecodeImage:
img_mode: BGR
channel_first: False
- TableMasterLabelEncode:
learn_empty_box: False
merge_no_span_structure: True
replace_empty_cell_token: True
- ResizeTableImage:
max_len: 480
resize_bboxes: True
- PaddingTableImage:
size: [480, 480]
- TableBoxEncode:
use_xywh: True
- NormalizeImage:
scale: 1./255.
mean: [0.5, 0.5, 0.5]
std: [0.5, 0.5, 0.5]
order: hwc
- ToCHWImage: null
- KeepKeys:
keep_keys: [image, structure, bboxes, bbox_masks, shape]
loader:
shuffle: True
batch_size_per_card: 10
drop_last: True
num_workers: 8
Eval:
dataset:
name: PubTabDataSet
data_dir: ./train_data/pubtabnet/test/
label_file_list: [./train_data/pubtabnet/test.jsonl]
transforms:
- DecodeImage:
img_mode: BGR
channel_first: False
- TableMasterLabelEncode:
learn_empty_box: False
merge_no_span_structure: True
replace_empty_cell_token: True
- ResizeTableImage:
max_len: 480
resize_bboxes: True
- PaddingTableImage:
size: [480, 480]
- TableBoxEncode:
use_xywh: True
- NormalizeImage:
scale: 1./255.
mean: [0.5, 0.5, 0.5]
std: [0.5, 0.5, 0.5]
order: hwc
- ToCHWImage: null
- KeepKeys:
keep_keys: [image, structure, bboxes, bbox_masks, shape]
loader:
shuffle: False
drop_last: False
batch_size_per_card: 10
num_workers: 8
\ No newline at end of file
===========================train_params===========================
model_name:table_master
python:python3.7
gpu_list:0|0,1
Global.use_gpu:True|True
Global.auto_cast:fp32
Global.epoch_num:lite_train_lite_infer=1|whole_train_whole_infer=17
Global.save_model_dir:./output/
Train.loader.batch_size_per_card:lite_train_lite_infer=2|whole_train_whole_infer=4
Global.pretrained_model:./pretrain_models/table_structure_tablemaster_train/best_accuracy
train_model_name:latest
train_infer_img_dir:./ppstructure/docs/table/table.jpg
null:null
##
trainer:norm_train
norm_train:tools/train.py -c test_tipc/configs/table_master/table_master.yml -o Global.print_batch_step=10
pact_train:null
fpgm_train:null
distill_train:null
null:null
null:null
##
===========================eval_params===========================
eval:null
null:null
##
===========================infer_params===========================
Global.save_inference_dir:./output/
Global.checkpoints:
norm_export:tools/export_model.py -c test_tipc/configs/table_master/table_master.yml -o
quant_export:
fpgm_export:
distill_export:null
export1:null
export2:null
##
infer_model:null
infer_export:null
infer_quant:False
inference:ppstructure/table/predict_structure.py --table_char_dict_path=./ppocr/utils/dict/table_master_structure_dict.txt --image_dir=./ppstructure/docs/table/table.jpg --output ./output/table --table_algorithm=TableMaster --table_max_len=480
--use_gpu:True|False
--enable_mkldnn:False
--cpu_threads:6
--rec_batch_num:1
--use_tensorrt:False
--precision:fp32
--table_model_dir:
--image_dir:./ppstructure/docs/table/table.jpg
null:null
--benchmark:False
null:null
===========================infer_benchmark_params==========================
random_infer_input:[{float32,[3,480,480]}]
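Putting the table_master config and parameter file together, a hedged end-to-end sketch (export, then table-structure prediction) looks like the following. Paths are taken from the files above where possible; the pretrained weights are the ones prepare.sh downloads for this model:

```shell
# Export the TableMaster model and run table-structure prediction on the sample image.
python3.7 tools/export_model.py \
    -c test_tipc/configs/table_master/table_master.yml \
    -o Global.pretrained_model=./pretrain_models/table_structure_tablemaster_train/best_accuracy \
       Global.save_inference_dir=./output/table_master/infer
python3.7 ppstructure/table/predict_structure.py \
    --table_model_dir=./output/table_master/infer \
    --table_char_dict_path=./ppocr/utils/dict/table_master_structure_dict.txt \
    --image_dir=./ppstructure/docs/table/table.jpg \
    --output=./output/table --table_algorithm=TableMaster --table_max_len=480
```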
...@@ -9,7 +9,7 @@ ...@@ -9,7 +9,7 @@
```shell ```shell
# Usage: bash test_tipc/prepare.sh train_benchmark.txt mode # Usage: bash test_tipc/prepare.sh train_benchmark.txt mode
bash test_tipc/prepare.sh test_tipc/configs/det_mv3_db_v2_0/train_benchmark.txt benchmark_train bash test_tipc/prepare.sh test_tipc/configs/det_mv3_db_v2_0/train_infer_python.txt benchmark_train
``` ```
## 1.2 Functional testing ## 1.2 Functional testing
...@@ -33,7 +33,7 @@ dynamic_bs8_fp32_DP_N1C1 is the argument passed to test_tipc/benchmark_train.sh; its format ...@@ -33,7 +33,7 @@ dynamic_bs8_fp32_DP_N1C1 is the argument passed to test_tipc/benchmark_train.sh; its format
## 2. Log output ## 2. Log output
After the run, the model's training log and the parsed log are saved; using the `test_tipc/configs/det_mv3_db_v2_0/train_benchmark.txt` parameter file, the parsed training-log result is: After the run, the model's training log and the parsed log are saved; using the `test_tipc/configs/det_mv3_db_v2_0/train_infer_python.txt` parameter file, the parsed training-log result is:
``` ```
{"model_branch": "dygaph", "model_commit": "7c39a1996b19087737c05d883fd346d2f39dbcc0", "model_name": "det_mv3_db_v2_0_bs8_fp32_SingleP_DP", "batch_size": 8, "fp_item": "fp32", "run_process_type": "SingleP", "run_mode": "DP", "convergence_value": "5.413110", "convergence_key": "loss:", "ips": 19.333, "speed_unit": "samples/s", "device_num": "N1C1", "model_run_time": "0", "frame_commit": "8cc09552473b842c651ead3b9848d41827a3dbab", "frame_version": "0.0.0"} {"model_branch": "dygaph", "model_commit": "7c39a1996b19087737c05d883fd346d2f39dbcc0", "model_name": "det_mv3_db_v2_0_bs8_fp32_SingleP_DP", "batch_size": 8, "fp_item": "fp32", "run_process_type": "SingleP", "run_mode": "DP", "convergence_value": "5.413110", "convergence_key": "loss:", "ips": 19.333, "speed_unit": "samples/s", "device_num": "N1C1", "model_run_time": "0", "frame_commit": "8cc09552473b842c651ead3b9848d41827a3dbab", "frame_version": "0.0.0"}
...@@ -51,3 +51,25 @@ train_log/ ...@@ -51,3 +51,25 @@ train_log/
├── PaddleOCR_det_mv3_db_v2_0_bs8_fp32_SingleP_DP_N1C1_log ├── PaddleOCR_det_mv3_db_v2_0_bs8_fp32_SingleP_DP_N1C1_log
└── PaddleOCR_det_mv3_db_v2_0_bs8_fp32_SingleP_DP_N1C4_log └── PaddleOCR_det_mv3_db_v2_0_bs8_fp32_SingleP_DP_N1C4_log
``` ```
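If you only need a single field (for example ips) out of a parsed JSON line like the one shown above, a small helper such as the following works; the log path here is purely illustrative:

```shell
# Print model name, ips and unit from one parsed speed-log JSON line (path is illustrative).
python3.7 -c "import json, sys; d = json.loads(open(sys.argv[1]).read()); print(d['model_name'], d['ips'], d['speed_unit'])" \
    path/to/parsed_speed_log.json
```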
## 3. Single-card performance data for each model
*Note: all speed figures in this section were measured on a single GPU (one NVIDIA V100 16G).
|Model name|Config file|Large-dataset float32 fps|Small-dataset float32 fps|diff|Large-dataset float16 fps|Small-dataset float16 fps|diff|Large-dataset size|Small-dataset size|
|:-:|:-:|:-:|:-:|:-:|:-:|:-:|:-:|:-:|:-:|
| ch_ppocr_mobile_v2.0_det |[config](../configs/ch_ppocr_mobile_v2.0_det/train_infer_python.txt) | 53.836 | 53.343 / 53.914 / 52.785 |0.020940758 | 45.574 | 45.57 / 46.292 / 46.213 | 0.015596647 | 10,000| 2,000|
| ch_ppocr_mobile_v2.0_rec |[config](../configs/ch_ppocr_mobile_v2.0_rec/train_infer_python.txt) | 2083.311 | 2043.194 / 2066.372 / 2093.317 |0.023944295 | 2153.261 | 2167.561 / 2165.726 / 2155.614| 0.005511725 | 600,000| 160,000|
| ch_ppocr_server_v2.0_det |[config](../configs/ch_ppocr_server_v2.0_det/train_infer_python.txt) | 20.716 | 20.739 / 20.807 / 20.755 |0.003268131 | 20.592 | 20.498 / 20.993 / 20.75| 0.023579288 | 10,000| 2,000|
| ch_ppocr_server_v2.0_rec |[config](../configs/ch_ppocr_server_v2.0_rec/train_infer_python.txt) | 528.56 | 528.386 / 528.991 / 528.391 |0.001143687 | 1189.788 | 1190.007 / 1176.332 / 1192.084| 0.013213834 | 600,000| 160,000|
| ch_PP-OCRv2_det |[config](../configs/ch_PP-OCRv2_det/train_infer_python.txt) | 13.87 | 13.386 / 13.529 / 13.428 |0.010569887 | 17.847 | 17.746 / 17.908 / 17.96| 0.011915367 | 10,000| 2,000|
| ch_PP-OCRv2_rec |[config](../configs/ch_PP-OCRv2_rec/train_infer_python.txt) | 109.248 | 106.32 / 106.318 / 108.587 |0.020895687 | 117.491 | 117.62 / 117.757 / 117.726| 0.001163413 | 140,000| 40,000|
| det_mv3_db_v2.0 |[config](../configs/det_mv3_db_v2_0/train_infer_python.txt) | 61.802 | 62.078 / 61.802 / 62.008 |0.00444602 | 82.947 | 84.294 / 84.457 / 84.005| 0.005351836 | 10,000| 2,000|
| det_r50_vd_db_v2.0 |[config](../configs/det_r50_vd_db_v2.0/train_infer_python.txt) | 29.955 | 29.092 / 29.31 / 28.844 |0.015899011 | 51.097 |50.367 / 50.879 / 50.227| 0.012814717 | 10,000| 2,000|
| det_r50_vd_east_v2.0 |[config](../configs/det_r50_vd_east_v2.0/train_infer_python.txt) | 42.485 | 42.624 / 42.663 / 42.561 |0.00239083 | 67.61 |67.825/ 68.299/ 68.51| 0.00999854 | 10,000| 2,000|
| det_r50_vd_pse_v2.0 |[config](../configs/det_r50_vd_pse_v2.0/train_infer_python.txt) | 16.455 | 16.517 / 16.555 / 16.353 |0.012201752 | 27.02 |27.288 / 27.152 / 27.408| 0.009340339 | 10,000| 2,000|
| rec_mv3_none_bilstm_ctc_v2.0 |[config](../configs/rec_mv3_none_bilstm_ctc_v2.0/train_infer_python.txt) | 2288.358 | 2291.906 / 2293.725 / 2290.05 |0.001602197 | 2336.17 |2327.042 / 2328.093 / 2344.915| 0.007622025 | 600,000| 160,000|
| PP-Structure-table |[config](../configs/en_table_structure/train_infer_python.txt) | 14.151 | 14.077 / 14.23 / 14.25 |0.012140351 | 16.285 | 16.595 / 16.878 / 16.531 | 0.020559308 | 20,000| 5,000|
| det_r50_dcn_fce_ctw_v2.0 |[config](../configs/det_r50_dcn_fce_ctw_v2.0/train_infer_python.txt) | 14.057 | 14.029 / 14.02 / 14.014 |0.001069214 | 18.298 |18.411 / 18.376 / 18.331| 0.004345228 | 10,000| 2,000|
| ch_PP-OCRv3_det |[config](../configs/ch_PP-OCRv3_det/train_infer_python.txt) | 8.622 | 8.431 / 8.423 / 8.479|0.006604552 | 14.203 |14.346 14.468 14.23| 0.016450097 | 10,000| 2,000|
| ch_PP-OCRv3_rec |[config](../configs/ch_PP-OCRv3_rec/train_infer_python.txt) | 90.239 | 90.077 / 91.513 / 91.325|0.01569176 | | | | 160,000| 40,000|
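The diff columns appear to be the relative spread (max - min) / max over the three repeated small-dataset runs; this is inferred from the numbers, not documented behaviour, so treat it as an assumption. The first row can be checked like this:

```shell
# ch_ppocr_mobile_v2.0_det, small-dataset float32 runs: 53.343 / 53.914 / 52.785
python3.7 -c "v = [53.343, 53.914, 52.785]; print((max(v) - min(v)) / max(v))"   # ~0.02094, matching the table
```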
...@@ -22,21 +22,89 @@ trainer_list=$(func_parser_value "${lines[14]}") ...@@ -22,21 +22,89 @@ trainer_list=$(func_parser_value "${lines[14]}")
if [ ${MODE} = "benchmark_train" ];then if [ ${MODE} = "benchmark_train" ];then
pip install -r requirements.txt pip install -r requirements.txt
if [[ ${model_name} =~ "det_mv3_db_v2_0" || ${model_name} =~ "det_r50_vd_east_v2_0" || ${model_name} =~ "det_r50_vd_pse_v2_0" || ${model_name} =~ "det_r18_db_v2_0" ]];then if [[ ${model_name} =~ "ch_ppocr_mobile_v2_0_det" || ${model_name} =~ "det_mv3_db_v2_0" ]];then
rm -rf ./train_data/icdar2015
wget -nc -P ./pretrain_models/ https://paddleocr.bj.bcebos.com/pretrained/MobileNetV3_large_x0_5_pretrained.pdparams --no-check-certificate wget -nc -P ./pretrain_models/ https://paddleocr.bj.bcebos.com/pretrained/MobileNetV3_large_x0_5_pretrained.pdparams --no-check-certificate
wget -nc -P ./train_data/ https://paddleocr.bj.bcebos.com/dygraph_v2.0/test/icdar2015.tar --no-check-certificate rm -rf ./train_data/icdar2015
cd ./train_data/ && tar xf icdar2015.tar && cd ../ wget -nc -P ./train_data/ https://paddleocr.bj.bcebos.com/dataset/icdar2015_benckmark.tar --no-check-certificate
cd ./train_data/ && tar xf icdar2015_benckmark.tar
ln -s ./icdar2015_benckmark ./icdar2015
cd ../
fi
if [[ ${model_name} =~ "ch_ppocr_server_v2_0_det" || ${model_name} =~ "ch_PP-OCRv3_det" ]];then
rm -rf ./train_data/icdar2015
wget -nc -P ./train_data/ https://paddleocr.bj.bcebos.com/dataset/icdar2015_benckmark.tar --no-check-certificate
cd ./train_data/ && tar xf icdar2015_benckmark.tar
ln -s ./icdar2015_benckmark ./icdar2015
cd ../
fi fi
if [[ ${model_name} =~ "det_r50_vd_east_v2_0" || ${model_name} =~ "det_r50_vd_pse_v2_0" ]];then if [[ ${model_name} =~ "ch_PP-OCRv2_det" ]];then
wget -nc -P ./pretrain_models/ https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_server_v2.0_det_train.tar --no-check-certificate
cd ./pretrain_models/ && tar xf ch_ppocr_server_v2.0_det_train.tar && cd ../
rm -rf ./train_data/icdar2015
wget -nc -P ./train_data/ https://paddleocr.bj.bcebos.com/dataset/icdar2015_benckmark.tar --no-check-certificate
cd ./train_data/ && tar xf icdar2015_benckmark.tar
ln -s ./icdar2015_benckmark ./icdar2015
cd ../
fi
if [[ ${model_name} =~ "det_r50_vd_east_v2_0" ]]; then
wget -nc -P ./pretrain_models/ https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/det_r50_vd_east_v2.0_train.tar --no-check-certificate
cd ./pretrain_models/ && tar xf det_r50_vd_east_v2.0_train.tar && cd ../
rm -rf ./train_data/icdar2015
wget -nc -P ./train_data/ https://paddleocr.bj.bcebos.com/dataset/icdar2015_benckmark.tar --no-check-certificate
cd ./train_data/ && tar xf icdar2015_benckmark.tar
ln -s ./icdar2015_benckmark ./icdar2015
cd ../
fi
if [[ ${model_name} =~ "det_r50_db_v2_0" || ${model_name} =~ "det_r50_vd_pse_v2_0" ]];then
wget -nc -P ./pretrain_models/ https://paddleocr.bj.bcebos.com/pretrained/ResNet50_vd_ssld_pretrained.pdparams --no-check-certificate wget -nc -P ./pretrain_models/ https://paddleocr.bj.bcebos.com/pretrained/ResNet50_vd_ssld_pretrained.pdparams --no-check-certificate
wget -nc -P ./train_data/ https://paddleocr.bj.bcebos.com/dygraph_v2.0/test/icdar2015.tar --no-check-certificate rm -rf ./train_data/icdar2015
cd ./train_data/ && tar xf icdar2015.tar && cd ../ wget -nc -P ./train_data/ https://paddleocr.bj.bcebos.com/dataset/icdar2015_benckmark.tar --no-check-certificate
cd ./train_data/ && tar xf icdar2015_benckmark.tar
ln -s ./icdar2015_benckmark ./icdar2015
cd ../
fi fi
if [[ ${model_name} =~ "det_r18_db_v2_0" ]];then if [[ ${model_name} =~ "det_r18_db_v2_0" ]];then
wget -nc -P ./pretrain_models/ https://paddleocr.bj.bcebos.com/pretrained/ResNet18_vd_pretrained.pdparams --no-check-certificate wget -nc -P ./pretrain_models/ https://paddleocr.bj.bcebos.com/pretrained/ResNet18_vd_pretrained.pdparams --no-check-certificate
wget -nc -P ./train_data/ https://paddleocr.bj.bcebos.com/dygraph_v2.0/test/icdar2015.tar --no-check-certificate rm -rf ./train_data/icdar2015
cd ./train_data/ && tar xf icdar2015.tar && cd ../ wget -nc -P ./train_data/ https://paddleocr.bj.bcebos.com/dataset/icdar2015_benckmark.tar --no-check-certificate
cd ./train_data/ && tar xf icdar2015_benckmark.tar
ln -s ./icdar2015_benckmark ./icdar2015
cd ../
fi
if [[ ${model_name} =~ "ch_ppocr_mobile_v2_0_rec" || ${model_name} =~ "ch_ppocr_server_v2_0_rec" || ${model_name} =~ "ch_PP-OCRv2_rec" || ${model_name} =~ "rec_mv3_none_bilstm_ctc_v2_0" || ${model_name} =~ "ch_PP-OCRv3_rec" ]];then
rm -rf ./train_data/ic15_data
wget -nc -P ./train_data/ https://paddleocr.bj.bcebos.com/dataset/ic15_data_benckmark.tar --no-check-certificate
cd ./train_data/ && tar xf ic15_data_benckmark.tar
ln -s ./ic15_data_benckmark ./ic15_data
cd ../
fi
if [[ ${model_name} =~ "ch_PP-OCRv2_rec" || ${model_name} =~ "ch_PP-OCRv3_rec" ]];then
rm -rf ./train_data/ic15_data
wget -nc -P ./train_data/ https://paddleocr.bj.bcebos.com/dataset/ic15_data_benckmark.tar --no-check-certificate
cd ./train_data/ && tar xf ic15_data_benckmark.tar
ln -s ./ic15_data_benckmark ./ic15_data
cd ic15_data
mv rec_gt_train4w.txt rec_gt_train.txt
cd ../
cd ../
fi
if [[ ${model_name} == "en_table_structure" ]];then
wget -nc -P ./pretrain_models/ https://paddleocr.bj.bcebos.com/dygraph_v2.1/table/en_ppocr_mobile_v2.0_table_structure_train.tar --no-check-certificate
cd ./pretrain_models/ && tar xf en_ppocr_mobile_v2.0_table_structure_train.tar && cd ../
rm -rf ./train_data/pubtabnet
wget -nc -P ./train_data/ https://paddleocr.bj.bcebos.com/dataset/pubtabnet_benckmark.tar --no-check-certificate
cd ./train_data/ && tar xf pubtabnet_benckmark.tar
ln -s ./pubtabnet_benckmark ./pubtabnet
cd ../
fi
if [[ ${model_name} == "det_r50_dcn_fce_ctw_v2_0" ]]; then
wget -nc -P ./pretrain_models/ https://paddleocr.bj.bcebos.com/contribution/det_r50_dcn_fce_ctw_v2.0_train.tar --no-check-certificate
cd ./pretrain_models/ && tar xf det_r50_dcn_fce_ctw_v2.0_train.tar && cd ../
rm -rf ./train_data/icdar2015
wget -nc -P ./train_data/ https://paddleocr.bj.bcebos.com/dataset/icdar2015_benckmark.tar --no-check-certificate
cd ./train_data/ && tar xf icdar2015_benckmark.tar
ln -s ./icdar2015_benckmark ./icdar2015
cd ../
fi fi
fi fi
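The branches above only stage datasets and pretrained weights; the benchmark itself is then launched per model, roughly as follows. This is a hedged example, and the mode string dynamic_bs8_fp32_DP_N1C1 is simply the example format quoted earlier in this document:

```shell
# Illustrative benchmark run for one of the recognition models handled above.
bash test_tipc/prepare.sh test_tipc/configs/ch_PP-OCRv3_rec/train_infer_python.txt benchmark_train
bash test_tipc/benchmark_train.sh test_tipc/configs/ch_PP-OCRv3_rec/train_infer_python.txt benchmark_train dynamic_bs8_fp32_DP_N1C1
```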
...@@ -52,13 +120,20 @@ if [ ${MODE} = "lite_train_lite_infer" ];then ...@@ -52,13 +120,20 @@ if [ ${MODE} = "lite_train_lite_infer" ];then
wget -nc -P ./pretrain_models/ https://paddleocr.bj.bcebos.com/PP-OCRv3/chinese/ch_PP-OCRv3_det_distill_train.tar --no-check-certificate wget -nc -P ./pretrain_models/ https://paddleocr.bj.bcebos.com/PP-OCRv3/chinese/ch_PP-OCRv3_det_distill_train.tar --no-check-certificate
cd ./pretrain_models/ && tar xf ch_PP-OCRv3_det_distill_train.tar && cd ../ cd ./pretrain_models/ && tar xf ch_PP-OCRv3_det_distill_train.tar && cd ../
fi fi
if [ ${model_name} == "en_table_structure" ];then if [ ${model_name} == "en_table_structure" ] || [ ${model_name} == "en_table_structure_PACT" ];then
wget -nc -P ./pretrain_models/ https://paddleocr.bj.bcebos.com/dygraph_v2.1/table/en_ppocr_mobile_v2.0_table_structure_train.tar --no-check-certificate wget -nc -P ./pretrain_models/ https://paddleocr.bj.bcebos.com/dygraph_v2.1/table/en_ppocr_mobile_v2.0_table_structure_train.tar --no-check-certificate
cd ./pretrain_models/ && tar xf en_ppocr_mobile_v2.0_table_structure_train.tar && cd ../ cd ./pretrain_models/ && tar xf en_ppocr_mobile_v2.0_table_structure_train.tar && cd ../
wget -nc -P ./inference/ https://paddleocr.bj.bcebos.com/dygraph_v2.0/table/en_ppocr_mobile_v2.0_table_det_infer.tar --no-check-certificate wget -nc -P ./inference/ https://paddleocr.bj.bcebos.com/dygraph_v2.0/table/en_ppocr_mobile_v2.0_table_det_infer.tar --no-check-certificate
wget -nc -P ./inference/ https://paddleocr.bj.bcebos.com/dygraph_v2.0/table/en_ppocr_mobile_v2.0_table_rec_infer.tar --no-check-certificate wget -nc -P ./inference/ https://paddleocr.bj.bcebos.com/dygraph_v2.0/table/en_ppocr_mobile_v2.0_table_rec_infer.tar --no-check-certificate
cd ./inference/ && tar xf en_ppocr_mobile_v2.0_table_det_infer.tar && tar xf en_ppocr_mobile_v2.0_table_rec_infer.tar && cd ../ cd ./inference/ && tar xf en_ppocr_mobile_v2.0_table_det_infer.tar && tar xf en_ppocr_mobile_v2.0_table_rec_infer.tar && cd ../
fi fi
if [[ ${model_name} =~ "det_r50_db_plusplus" ]];then
wget -nc -P ./pretrain_models/ https://paddleocr.bj.bcebos.com/dygraph_v2.1/en_det/ResNet50_dcn_asf_synthtext_pretrained.pdparams --no-check-certificate
fi
if [ ${model_name} == "table_master" ];then
wget -nc -P ./pretrain_models/ https://paddleocr.bj.bcebos.com/ppstructure/models/tablemaster/table_structure_tablemaster_train.tar --no-check-certificate
cd ./pretrain_models/ && tar xf table_structure_tablemaster_train.tar && cd ../
fi
cd ./pretrain_models/ && tar xf det_mv3_db_v2.0_train.tar && cd ../ cd ./pretrain_models/ && tar xf det_mv3_db_v2.0_train.tar && cd ../
rm -rf ./train_data/icdar2015 rm -rf ./train_data/icdar2015
rm -rf ./train_data/ic15_data rm -rf ./train_data/ic15_data
...@@ -96,7 +171,7 @@ if [ ${MODE} = "lite_train_lite_infer" ];then ...@@ -96,7 +171,7 @@ if [ ${MODE} = "lite_train_lite_infer" ];then
cd ./pretrain_models/ && tar xf en_server_pgnetA.tar && cd ../ cd ./pretrain_models/ && tar xf en_server_pgnetA.tar && cd ../
cd ./train_data && tar xf total_text_lite.tar && ln -s total_text_lite total_text && cd ../ cd ./train_data && tar xf total_text_lite.tar && ln -s total_text_lite total_text && cd ../
fi fi
if [ ${model_name} == "det_r50_vd_sast_icdar15_v2.0" ] || [ ${model_name} == "det_r50_vd_sast_totaltext_v2.0" ]; then if [ ${model_name} == "det_r50_vd_sast_icdar15_v2_0" ] || [ ${model_name} == "det_r50_vd_sast_totaltext_v2_0" ]; then
wget -nc -P ./pretrain_models/ https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/ResNet50_vd_ssld_pretrained.pdparams --no-check-certificate wget -nc -P ./pretrain_models/ https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/ResNet50_vd_ssld_pretrained.pdparams --no-check-certificate
wget -nc -P ./pretrain_models/ https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/det_r50_vd_sast_icdar15_v2.0_train.tar --no-check-certificate wget -nc -P ./pretrain_models/ https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/det_r50_vd_sast_icdar15_v2.0_train.tar --no-check-certificate
wget -nc -P ./train_data/ https://paddleocr.bj.bcebos.com/dygraph_v2.0/test/total_text_lite.tar --no-check-certificate wget -nc -P ./train_data/ https://paddleocr.bj.bcebos.com/dygraph_v2.0/test/total_text_lite.tar --no-check-certificate
...@@ -107,19 +182,31 @@ if [ ${MODE} = "lite_train_lite_infer" ];then ...@@ -107,19 +182,31 @@ if [ ${MODE} = "lite_train_lite_infer" ];then
wget -nc -P ./inference/ https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/det_mv3_db_v2.0_train.tar --no-check-certificate wget -nc -P ./inference/ https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/det_mv3_db_v2.0_train.tar --no-check-certificate
cd ./inference/ && tar xf det_mv3_db_v2.0_train.tar && cd ../ cd ./inference/ && tar xf det_mv3_db_v2.0_train.tar && cd ../
fi fi
if [ ${model_name} == "det_r50_db_v2.0" ]; then if [ ${model_name} == "det_r50_db_v2_0" ]; then
wget -nc -P ./pretrain_models/ https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/ResNet50_vd_ssld_pretrained.pdparams --no-check-certificate wget -nc -P ./pretrain_models/ https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/ResNet50_vd_ssld_pretrained.pdparams --no-check-certificate
wget -nc -P ./inference/ https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/det_r50_vd_db_v2.0_train.tar --no-check-certificate wget -nc -P ./inference/ https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/det_r50_vd_db_v2.0_train.tar --no-check-certificate
cd ./inference/ && tar xf det_r50_vd_db_v2.0_train.tar && cd ../ cd ./inference/ && tar xf det_r50_vd_db_v2.0_train.tar && cd ../
fi fi
if [ ${model_name} == "ch_ppocr_mobile_v2.0_rec_FPGM" ]; then if [ ${model_name} == "ch_ppocr_mobile_v2_0_rec_FPGM" ]; then
wget -nc -P ./pretrain_models/ https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_mobile_v2.0_rec_train.tar --no-check-certificate wget -nc -P ./pretrain_models/ https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_mobile_v2.0_rec_train.tar --no-check-certificate
cd ./pretrain_models/ && tar xf ch_ppocr_mobile_v2.0_rec_train.tar && cd ../ cd ./pretrain_models/ && tar xf ch_ppocr_mobile_v2.0_rec_train.tar && cd ../
fi fi
if [ ${model_name} == "det_mv3_east_v2.0" ]; then if [ ${model_name} == "det_mv3_east_v2_0" ]; then
wget -nc -P ./pretrain_models/ https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/det_mv3_east_v2.0_train.tar --no-check-certificate wget -nc -P ./pretrain_models/ https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/det_mv3_east_v2.0_train.tar --no-check-certificate
cd ./pretrain_models/ && tar xf det_mv3_east_v2.0_train.tar && cd ../ cd ./pretrain_models/ && tar xf det_mv3_east_v2.0_train.tar && cd ../
fi fi
if [ ${model_name} == "det_r50_vd_east_v2_0" ]; then
wget -nc -P ./pretrain_models/ https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/det_r50_vd_east_v2.0_train.tar --no-check-certificate
cd ./pretrain_models/ && tar xf det_r50_vd_east_v2.0_train.tar && cd ../
fi
if [ ${model_name} == "det_r50_dcn_fce_ctw_v2_0" ]; then
wget -nc -P ./pretrain_models/ https://paddleocr.bj.bcebos.com/contribution/det_r50_dcn_fce_ctw_v2.0_train.tar --no-check-certificate
cd ./pretrain_models/ && tar xf det_r50_dcn_fce_ctw_v2.0_train.tar && cd ../
fi
if [ ${model_name} == "rec_r32_gaspin_bilstm_att" ]; then
wget -nc -P ./pretrain_models/ https://paddleocr.bj.bcebos.com/rec_r32_gaspin_bilstm_att_train.tar --no-check-certificate
cd ./pretrain_models/ && tar xf rec_r32_gaspin_bilstm_att_train.tar && cd ../
fi
elif [ ${MODE} = "whole_train_whole_infer" ];then elif [ ${MODE} = "whole_train_whole_infer" ];then
wget -nc -P ./pretrain_models/ https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/MobileNetV3_large_x0_5_pretrained.pdparams --no-check-certificate wget -nc -P ./pretrain_models/ https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/MobileNetV3_large_x0_5_pretrained.pdparams --no-check-certificate
...@@ -147,7 +234,7 @@ elif [ ${MODE} = "whole_train_whole_infer" ];then ...@@ -147,7 +234,7 @@ elif [ ${MODE} = "whole_train_whole_infer" ];then
cd ./pretrain_models/ && tar xf en_server_pgnetA.tar && cd ../ cd ./pretrain_models/ && tar xf en_server_pgnetA.tar && cd ../
cd ./train_data && tar xf total_text.tar && ln -s total_text_lite total_text && cd ../ cd ./train_data && tar xf total_text.tar && ln -s total_text_lite total_text && cd ../
fi fi
if [ ${model_name} == "det_r50_vd_sast_totaltext_v2.0" ]; then if [ ${model_name} == "det_r50_vd_sast_totaltext_v2_0" ]; then
wget -nc -P ./pretrain_models/ https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/ResNet50_vd_ssld_pretrained.pdparams --no-check-certificate wget -nc -P ./pretrain_models/ https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/ResNet50_vd_ssld_pretrained.pdparams --no-check-certificate
wget -nc -P ./train_data/ https://paddleocr.bj.bcebos.com/dygraph_v2.0/test/total_text_lite.tar --no-check-certificate wget -nc -P ./train_data/ https://paddleocr.bj.bcebos.com/dygraph_v2.0/test/total_text_lite.tar --no-check-certificate
cd ./train_data && tar xf total_text.tar && ln -s total_text_lite total_text && cd ../ cd ./train_data && tar xf total_text.tar && ln -s total_text_lite total_text && cd ../
...@@ -191,32 +278,32 @@ elif [ ${MODE} = "whole_infer" ];then ...@@ -191,32 +278,32 @@ elif [ ${MODE} = "whole_infer" ];then
wget -nc -P ./inference https://paddleocr.bj.bcebos.com/dygraph_v2.0/test/ch_det_data_50.tar --no-check-certificate wget -nc -P ./inference https://paddleocr.bj.bcebos.com/dygraph_v2.0/test/ch_det_data_50.tar --no-check-certificate
wget -nc -P ./inference/ https://paddleocr.bj.bcebos.com/dygraph_v2.0/test/rec_inference.tar --no-check-certificate wget -nc -P ./inference/ https://paddleocr.bj.bcebos.com/dygraph_v2.0/test/rec_inference.tar --no-check-certificate
cd ./inference && tar xf rec_inference.tar && tar xf ch_det_data_50.tar && cd ../ cd ./inference && tar xf rec_inference.tar && tar xf ch_det_data_50.tar && cd ../
if [ ${model_name} = "ch_ppocr_mobile_v2.0_det" ]; then if [ ${model_name} = "ch_ppocr_mobile_v2_0_det" ]; then
eval_model_name="ch_ppocr_mobile_v2.0_det_train" eval_model_name="ch_ppocr_mobile_v2.0_det_train"
rm -rf ./train_data/icdar2015 rm -rf ./train_data/icdar2015
wget -nc -P ./inference https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_mobile_v2.0_det_train.tar --no-check-certificate wget -nc -P ./inference https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_mobile_v2.0_det_train.tar --no-check-certificate
wget -nc -P ./inference https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_mobile_v2.0_det_infer.tar --no-check-certificate wget -nc -P ./inference https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_mobile_v2.0_det_infer.tar --no-check-certificate
cd ./inference && tar xf ${eval_model_name}.tar && tar xf ch_det_data_50.tar && tar xf ch_ppocr_mobile_v2.0_det_infer.tar && cd ../ cd ./inference && tar xf ${eval_model_name}.tar && tar xf ch_det_data_50.tar && tar xf ch_ppocr_mobile_v2.0_det_infer.tar && cd ../
elif [ ${model_name} = "ch_ppocr_mobile_v2.0_det_PACT" ]; then elif [ ${model_name} = "ch_ppocr_mobile_v2_0_det_PACT" ]; then
eval_model_name="ch_ppocr_mobile_v2.0_det_prune_infer" eval_model_name="ch_ppocr_mobile_v2.0_det_prune_infer"
wget -nc -P ./inference https://paddleocr.bj.bcebos.com/dygraph_v2.0/slim/ch_ppocr_mobile_v2.0_det_prune_infer.tar --no-check-certificate wget -nc -P ./inference https://paddleocr.bj.bcebos.com/dygraph_v2.0/slim/ch_ppocr_mobile_v2.0_det_prune_infer.tar --no-check-certificate
cd ./inference && tar xf ${eval_model_name}.tar && tar xf ch_det_data_50.tar && cd ../ cd ./inference && tar xf ${eval_model_name}.tar && tar xf ch_det_data_50.tar && cd ../
elif [ ${model_name} = "ch_ppocr_server_v2.0_det" ]; then elif [ ${model_name} = "ch_ppocr_server_v2_0_det" ]; then
wget -nc -P ./inference https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_server_v2.0_det_train.tar --no-check-certificate wget -nc -P ./inference https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_server_v2.0_det_train.tar --no-check-certificate
cd ./inference && tar xf ch_ppocr_server_v2.0_det_train.tar && tar xf ch_det_data_50.tar && cd ../ cd ./inference && tar xf ch_ppocr_server_v2.0_det_train.tar && tar xf ch_det_data_50.tar && cd ../
elif [ ${model_name} = "ch_ppocr_mobile_v2.0" ]; then elif [ ${model_name} = "ch_ppocr_mobile_v2_0" ]; then
wget -nc -P ./inference https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_mobile_v2.0_det_infer.tar --no-check-certificate wget -nc -P ./inference https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_mobile_v2.0_det_infer.tar --no-check-certificate
wget -nc -P ./inference https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_mobile_v2.0_rec_infer.tar --no-check-certificate wget -nc -P ./inference https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_mobile_v2.0_rec_infer.tar --no-check-certificate
cd ./inference && tar xf ch_ppocr_mobile_v2.0_det_infer.tar && tar xf ch_ppocr_mobile_v2.0_rec_infer.tar && tar xf ch_det_data_50.tar && cd ../ cd ./inference && tar xf ch_ppocr_mobile_v2.0_det_infer.tar && tar xf ch_ppocr_mobile_v2.0_rec_infer.tar && tar xf ch_det_data_50.tar && cd ../
elif [ ${model_name} = "ch_ppocr_server_v2.0" ]; then elif [ ${model_name} = "ch_ppocr_server_v2_0" ]; then
wget -nc -P ./inference https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_server_v2.0_det_infer.tar --no-check-certificate wget -nc -P ./inference https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_server_v2.0_det_infer.tar --no-check-certificate
wget -nc -P ./inference https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_server_v2.0_rec_infer.tar --no-check-certificate wget -nc -P ./inference https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_server_v2.0_rec_infer.tar --no-check-certificate
cd ./inference && tar xf ch_ppocr_server_v2.0_det_infer.tar && tar xf ch_ppocr_server_v2.0_rec_infer.tar && tar xf ch_det_data_50.tar && cd ../ cd ./inference && tar xf ch_ppocr_server_v2.0_det_infer.tar && tar xf ch_ppocr_server_v2.0_rec_infer.tar && tar xf ch_det_data_50.tar && cd ../
elif [ ${model_name} = "ch_ppocr_mobile_v2.0_rec_PACT" ]; then elif [ ${model_name} = "ch_ppocr_mobile_v2_0_rec_PACT" ]; then
eval_model_name="ch_ppocr_mobile_v2.0_rec_slim_infer" eval_model_name="ch_ppocr_mobile_v2.0_rec_slim_infer"
wget -nc -P ./inference https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_mobile_v2.0_rec_slim_infer.tar --no-check-certificate wget -nc -P ./inference https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_mobile_v2.0_rec_slim_infer.tar --no-check-certificate
cd ./inference && tar xf ${eval_model_name}.tar && cd ../ cd ./inference && tar xf ${eval_model_name}.tar && cd ../
elif [ ${model_name} = "ch_ppocr_mobile_v2.0_rec_FPGM" ]; then elif [ ${model_name} = "ch_ppocr_mobile_v2_0_rec_FPGM" ]; then
eval_model_name="ch_PP-OCRv2_rec_infer" eval_model_name="ch_PP-OCRv2_rec_infer"
wget -nc -P ./inference https://paddleocr.bj.bcebos.com/PP-OCRv2/chinese/ch_PP-OCRv2_rec_infer.tar --no-check-certificate wget -nc -P ./inference https://paddleocr.bj.bcebos.com/PP-OCRv2/chinese/ch_PP-OCRv2_rec_infer.tar --no-check-certificate
cd ./inference && tar xf ${eval_model_name}.tar && cd ../ cd ./inference && tar xf ${eval_model_name}.tar && cd ../
...@@ -261,39 +348,39 @@ elif [ ${MODE} = "whole_infer" ];then ...@@ -261,39 +348,39 @@ elif [ ${MODE} = "whole_infer" ];then
wget -nc -P ./inference/ https://paddleocr.bj.bcebos.com/dygraph_v2.0/pgnet/en_server_pgnetA.tar --no-check-certificate wget -nc -P ./inference/ https://paddleocr.bj.bcebos.com/dygraph_v2.0/pgnet/en_server_pgnetA.tar --no-check-certificate
cd ./inference && tar xf en_server_pgnetA.tar && tar xf ch_det_data_50.tar && cd ../ cd ./inference && tar xf en_server_pgnetA.tar && tar xf ch_det_data_50.tar && cd ../
fi fi
if [ ${model_name} == "det_r50_vd_sast_icdar15_v2.0" ]; then if [ ${model_name} == "det_r50_vd_sast_icdar15_v2_0" ]; then
wget -nc -P ./inference/ https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/det_r50_vd_sast_icdar15_v2.0_train.tar --no-check-certificate wget -nc -P ./inference/ https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/det_r50_vd_sast_icdar15_v2.0_train.tar --no-check-certificate
cd ./inference/ && tar xf det_r50_vd_sast_icdar15_v2.0_train.tar && tar xf ch_det_data_50.tar && cd ../ cd ./inference/ && tar xf det_r50_vd_sast_icdar15_v2.0_train.tar && tar xf ch_det_data_50.tar && cd ../
fi fi
if [ ${model_name} == "rec_mv3_none_none_ctc_v2.0" ]; then if [ ${model_name} == "rec_mv3_none_none_ctc_v2_0" ]; then
wget -nc -P ./inference/ https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/rec_mv3_none_none_ctc_v2.0_train.tar --no-check-certificate wget -nc -P ./inference/ https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/rec_mv3_none_none_ctc_v2.0_train.tar --no-check-certificate
cd ./inference/ && tar xf rec_mv3_none_none_ctc_v2.0_train.tar && cd ../ cd ./inference/ && tar xf rec_mv3_none_none_ctc_v2.0_train.tar && cd ../
fi fi
if [ ${model_name} == "rec_r34_vd_none_none_ctc_v2.0" ]; then if [ ${model_name} == "rec_r34_vd_none_none_ctc_v2_0" ]; then
wget -nc -P ./inference/ https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/rec_r34_vd_none_none_ctc_v2.0_train.tar --no-check-certificate wget -nc -P ./inference/ https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/rec_r34_vd_none_none_ctc_v2.0_train.tar --no-check-certificate
cd ./inference/ && tar xf rec_r34_vd_none_none_ctc_v2.0_train.tar && cd ../ cd ./inference/ && tar xf rec_r34_vd_none_none_ctc_v2.0_train.tar && cd ../
fi fi
if [ ${model_name} == "rec_mv3_none_bilstm_ctc_v2.0" ]; then if [ ${model_name} == "rec_mv3_none_bilstm_ctc_v2_0" ]; then
wget -nc -P ./inference/ https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/rec_mv3_none_bilstm_ctc_v2.0_train.tar --no-check-certificate wget -nc -P ./inference/ https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/rec_mv3_none_bilstm_ctc_v2.0_train.tar --no-check-certificate
cd ./inference/ && tar xf rec_mv3_none_bilstm_ctc_v2.0_train.tar && cd ../ cd ./inference/ && tar xf rec_mv3_none_bilstm_ctc_v2.0_train.tar && cd ../
fi fi
if [ ${model_name} == "rec_r34_vd_none_bilstm_ctc_v2.0" ]; then if [ ${model_name} == "rec_r34_vd_none_bilstm_ctc_v2_0" ]; then
wget -nc -P ./inference/ https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/rec_r34_vd_none_bilstm_ctc_v2.0_train.tar --no-check-certificate wget -nc -P ./inference/ https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/rec_r34_vd_none_bilstm_ctc_v2.0_train.tar --no-check-certificate
cd ./inference/ && tar xf rec_r34_vd_none_bilstm_ctc_v2.0_train.tar && cd ../ cd ./inference/ && tar xf rec_r34_vd_none_bilstm_ctc_v2.0_train.tar && cd ../
fi fi
if [ ${model_name} == "rec_mv3_tps_bilstm_ctc_v2.0" ]; then if [ ${model_name} == "rec_mv3_tps_bilstm_ctc_v2_0" ]; then
wget -nc -P ./inference/ https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/rec_mv3_tps_bilstm_ctc_v2.0_train.tar --no-check-certificate wget -nc -P ./inference/ https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/rec_mv3_tps_bilstm_ctc_v2.0_train.tar --no-check-certificate
cd ./inference/ && tar xf rec_mv3_tps_bilstm_ctc_v2.0_train.tar && cd ../ cd ./inference/ && tar xf rec_mv3_tps_bilstm_ctc_v2.0_train.tar && cd ../
fi fi
if [ ${model_name} == "rec_r34_vd_tps_bilstm_ctc_v2.0" ]; then if [ ${model_name} == "rec_r34_vd_tps_bilstm_ctc_v2_0" ]; then
wget -nc -P ./inference/ https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/rec_r34_vd_tps_bilstm_ctc_v2.0_train.tar --no-check-certificate wget -nc -P ./inference/ https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/rec_r34_vd_tps_bilstm_ctc_v2.0_train.tar --no-check-certificate
cd ./inference/ && tar xf rec_r34_vd_tps_bilstm_ctc_v2.0_train.tar && cd ../ cd ./inference/ && tar xf rec_r34_vd_tps_bilstm_ctc_v2.0_train.tar && cd ../
fi fi
if [ ${model_name} == "ch_ppocr_server_v2.0_rec" ]; then if [ ${model_name} == "ch_ppocr_server_v2_0_rec" ]; then
wget -nc -P ./inference/ https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_server_v2.0_rec_train.tar --no-check-certificate wget -nc -P ./inference/ https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_server_v2.0_rec_train.tar --no-check-certificate
cd ./inference/ && tar xf ch_ppocr_server_v2.0_rec_train.tar && cd ../ cd ./inference/ && tar xf ch_ppocr_server_v2.0_rec_train.tar && cd ../
fi fi
if [ ${model_name} == "ch_ppocr_mobile_v2.0_rec" ]; then if [ ${model_name} == "ch_ppocr_mobile_v2_0_rec" ]; then
wget -nc -P ./inference/ https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_mobile_v2.0_rec_train.tar --no-check-certificate wget -nc -P ./inference/ https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_mobile_v2.0_rec_train.tar --no-check-certificate
cd ./inference/ && tar xf ch_ppocr_mobile_v2.0_rec_train.tar && cd ../ cd ./inference/ && tar xf ch_ppocr_mobile_v2.0_rec_train.tar && cd ../
fi fi
...@@ -301,11 +388,11 @@ elif [ ${MODE} = "whole_infer" ];then ...@@ -301,11 +388,11 @@ elif [ ${MODE} = "whole_infer" ];then
wget -nc -P ./inference/ https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/rec_mtb_nrtr_train.tar --no-check-certificate wget -nc -P ./inference/ https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/rec_mtb_nrtr_train.tar --no-check-certificate
cd ./inference/ && tar xf rec_mtb_nrtr_train.tar && cd ../ cd ./inference/ && tar xf rec_mtb_nrtr_train.tar && cd ../
fi fi
if [ ${model_name} == "rec_mv3_tps_bilstm_att_v2.0" ]; then if [ ${model_name} == "rec_mv3_tps_bilstm_att_v2_0" ]; then
wget -nc -P ./inference/ https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/rec_mv3_tps_bilstm_att_v2.0_train.tar --no-check-certificate wget -nc -P ./inference/ https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/rec_mv3_tps_bilstm_att_v2.0_train.tar --no-check-certificate
cd ./inference/ && tar xf rec_mv3_tps_bilstm_att_v2.0_train.tar && cd ../ cd ./inference/ && tar xf rec_mv3_tps_bilstm_att_v2.0_train.tar && cd ../
fi fi
if [ ${model_name} == "rec_r34_vd_tps_bilstm_att_v2.0" ]; then if [ ${model_name} == "rec_r34_vd_tps_bilstm_att_v2_0" ]; then
wget -nc -P ./inference/ https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/rec_r34_vd_tps_bilstm_att_v2.0_train.tar --no-check-certificate wget -nc -P ./inference/ https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/rec_r34_vd_tps_bilstm_att_v2.0_train.tar --no-check-certificate
cd ./inference/ && tar xf rec_r34_vd_tps_bilstm_att_v2.0_train.tar && cd ../ cd ./inference/ && tar xf rec_r34_vd_tps_bilstm_att_v2.0_train.tar && cd ../
fi fi
...@@ -318,7 +405,7 @@ elif [ ${MODE} = "whole_infer" ];then ...@@ -318,7 +405,7 @@ elif [ ${MODE} = "whole_infer" ];then
cd ./inference/ && tar xf rec_r50_vd_srn_train.tar && cd ../ cd ./inference/ && tar xf rec_r50_vd_srn_train.tar && cd ../
fi fi
if [ ${model_name} == "det_r50_vd_sast_totaltext_v2.0" ]; then if [ ${model_name} == "det_r50_vd_sast_totaltext_v2_0" ]; then
wget -nc -P ./inference/ https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/det_r50_vd_sast_totaltext_v2.0_train.tar --no-check-certificate wget -nc -P ./inference/ https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/det_r50_vd_sast_totaltext_v2.0_train.tar --no-check-certificate
cd ./inference/ && tar xf det_r50_vd_sast_totaltext_v2.0_train.tar && cd ../ cd ./inference/ && tar xf det_r50_vd_sast_totaltext_v2.0_train.tar && cd ../
fi fi
...@@ -326,11 +413,11 @@ elif [ ${MODE} = "whole_infer" ];then ...@@ -326,11 +413,11 @@ elif [ ${MODE} = "whole_infer" ];then
wget -nc -P ./inference/ https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/det_mv3_db_v2.0_train.tar --no-check-certificate wget -nc -P ./inference/ https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/det_mv3_db_v2.0_train.tar --no-check-certificate
cd ./inference/ && tar xf det_mv3_db_v2.0_train.tar && tar xf ch_det_data_50.tar && cd ../ cd ./inference/ && tar xf det_mv3_db_v2.0_train.tar && tar xf ch_det_data_50.tar && cd ../
fi fi
if [ ${model_name} == "det_r50_db_v2.0" ]; then if [ ${model_name} == "det_r50_db_v2_0" ]; then
wget -nc -P ./inference/ https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/det_r50_vd_db_v2.0_train.tar --no-check-certificate wget -nc -P ./inference/ https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/det_r50_vd_db_v2.0_train.tar --no-check-certificate
cd ./inference/ && tar xf det_r50_vd_db_v2.0_train.tar && tar xf ch_det_data_50.tar && cd ../ cd ./inference/ && tar xf det_r50_vd_db_v2.0_train.tar && tar xf ch_det_data_50.tar && cd ../
fi fi
if [ ${model_name} == "det_mv3_pse_v2.0" ]; then if [ ${model_name} == "det_mv3_pse_v2_0" ]; then
wget -nc -P ./inference/ https://paddleocr.bj.bcebos.com/dygraph_v2.1/en_det/det_mv3_pse_v2.0_train.tar --no-check-certificate wget -nc -P ./inference/ https://paddleocr.bj.bcebos.com/dygraph_v2.1/en_det/det_mv3_pse_v2.0_train.tar --no-check-certificate
cd ./inference/ && tar xf det_mv3_pse_v2.0_train.tar & cd ../ cd ./inference/ && tar xf det_mv3_pse_v2.0_train.tar & cd ../
fi fi
...@@ -338,7 +425,7 @@ elif [ ${MODE} = "whole_infer" ];then ...@@ -338,7 +425,7 @@ elif [ ${MODE} = "whole_infer" ];then
wget -nc -P ./inference/ https://paddleocr.bj.bcebos.com/dygraph_v2.1/en_det/det_r50_vd_pse_v2.0_train.tar --no-check-certificate wget -nc -P ./inference/ https://paddleocr.bj.bcebos.com/dygraph_v2.1/en_det/det_r50_vd_pse_v2.0_train.tar --no-check-certificate
cd ./inference/ && tar xf det_r50_vd_pse_v2.0_train.tar & cd ../ cd ./inference/ && tar xf det_r50_vd_pse_v2.0_train.tar & cd ../
fi fi
if [ ${model_name} == "det_mv3_east_v2.0" ]; then if [ ${model_name} == "det_mv3_east_v2_0" ]; then
wget -nc -P ./inference/ https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/det_mv3_east_v2.0_train.tar --no-check-certificate wget -nc -P ./inference/ https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/det_mv3_east_v2.0_train.tar --no-check-certificate
cd ./inference/ && tar xf det_mv3_east_v2.0_train.tar & cd ../ cd ./inference/ && tar xf det_mv3_east_v2.0_train.tar & cd ../
fi fi
...@@ -346,6 +433,10 @@ elif [ ${MODE} = "whole_infer" ];then ...@@ -346,6 +433,10 @@ elif [ ${MODE} = "whole_infer" ];then
wget -nc -P ./inference/ https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/det_r50_vd_east_v2.0_train.tar --no-check-certificate wget -nc -P ./inference/ https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/det_r50_vd_east_v2.0_train.tar --no-check-certificate
cd ./inference/ && tar xf det_r50_vd_east_v2.0_train.tar & cd ../ cd ./inference/ && tar xf det_r50_vd_east_v2.0_train.tar & cd ../
fi fi
if [ ${model_name} == "det_r50_dcn_fce_ctw_v2_0" ]; then
wget -nc -P ./inference/ https://paddleocr.bj.bcebos.com/contribution/det_r50_dcn_fce_ctw_v2.0_train.tar --no-check-certificate
cd ./inference/ && tar xf det_r50_dcn_fce_ctw_v2.0_train.tar && cd ../
fi
if [[ ${model_name} =~ "en_table_structure" ]];then if [[ ${model_name} =~ "en_table_structure" ]];then
wget -nc -P ./inference/ https://paddleocr.bj.bcebos.com/dygraph_v2.0/table/en_ppocr_mobile_v2.0_table_structure_infer.tar --no-check-certificate wget -nc -P ./inference/ https://paddleocr.bj.bcebos.com/dygraph_v2.0/table/en_ppocr_mobile_v2.0_table_structure_infer.tar --no-check-certificate
wget -nc -P ./inference/ https://paddleocr.bj.bcebos.com/dygraph_v2.0/table/en_ppocr_mobile_v2.0_table_det_infer.tar --no-check-certificate wget -nc -P ./inference/ https://paddleocr.bj.bcebos.com/dygraph_v2.0/table/en_ppocr_mobile_v2.0_table_det_infer.tar --no-check-certificate
...@@ -357,7 +448,7 @@ fi ...@@ -357,7 +448,7 @@ fi
if [[ ${model_name} =~ "KL" ]]; then if [[ ${model_name} =~ "KL" ]]; then
wget -nc -P ./train_data/ https://paddleocr.bj.bcebos.com/dygraph_v2.0/test/icdar2015_lite.tar --no-check-certificate wget -nc -P ./train_data/ https://paddleocr.bj.bcebos.com/dygraph_v2.0/test/icdar2015_lite.tar --no-check-certificate
cd ./train_data/ && tar xf icdar2015_lite.tar && rm -rf ./icdar2015 && ln -s ./icdar2015_lite ./icdar2015 && cd ../ cd ./train_data/ && tar xf icdar2015_lite.tar && rm -rf ./icdar2015 && ln -s ./icdar2015_lite ./icdar2015 && cd ../
if [ ${model_name} = "ch_ppocr_mobile_v2.0_det_KL" ]; then if [ ${model_name} = "ch_ppocr_mobile_v2_0_det_KL" ]; then
wget -nc -P ./inference https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_mobile_v2.0_det_infer.tar --no-check-certificate wget -nc -P ./inference https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_mobile_v2.0_det_infer.tar --no-check-certificate
wget -nc -P ./inference https://paddleocr.bj.bcebos.com/dygraph_v2.0/test/ch_det_data_50.tar --no-check-certificate wget -nc -P ./inference https://paddleocr.bj.bcebos.com/dygraph_v2.0/test/ch_det_data_50.tar --no-check-certificate
cd ./inference && tar xf ch_ppocr_mobile_v2.0_det_infer.tar && tar xf ch_det_data_50.tar && cd ../ cd ./inference && tar xf ch_ppocr_mobile_v2.0_det_infer.tar && tar xf ch_det_data_50.tar && cd ../
...@@ -389,7 +480,7 @@ if [[ ${model_name} =~ "KL" ]]; then ...@@ -389,7 +480,7 @@ if [[ ${model_name} =~ "KL" ]]; then
wget -nc -P ./inference https://paddleocr.bj.bcebos.com/PP-OCRv3/chinese/ch_PP-OCRv3_det_infer.tar --no-check-certificate wget -nc -P ./inference https://paddleocr.bj.bcebos.com/PP-OCRv3/chinese/ch_PP-OCRv3_det_infer.tar --no-check-certificate
cd ./inference && tar xf ch_PP-OCRv3_det_infer.tar && tar xf ch_det_data_50.tar && cd ../ cd ./inference && tar xf ch_PP-OCRv3_det_infer.tar && tar xf ch_det_data_50.tar && cd ../
fi fi
if [ ${model_name} = "ch_ppocr_mobile_v2.0_rec_KL" ]; then if [ ${model_name} = "ch_ppocr_mobile_v2_0_rec_KL" ]; then
wget -nc -P ./inference/ https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_mobile_v2.0_rec_infer.tar --no-check-certificate wget -nc -P ./inference/ https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_mobile_v2.0_rec_infer.tar --no-check-certificate
wget -nc -P ./inference/ https://paddleocr.bj.bcebos.com/dygraph_v2.0/test/rec_inference.tar --no-check-certificate wget -nc -P ./inference/ https://paddleocr.bj.bcebos.com/dygraph_v2.0/test/rec_inference.tar --no-check-certificate
wget -nc -P ./train_data/ https://paddleocr.bj.bcebos.com/dygraph_v2.0/test/ic15_data.tar --no-check-certificate wget -nc -P ./train_data/ https://paddleocr.bj.bcebos.com/dygraph_v2.0/test/ic15_data.tar --no-check-certificate
...@@ -407,35 +498,35 @@ if [[ ${model_name} =~ "KL" ]]; then ...@@ -407,35 +498,35 @@ if [[ ${model_name} =~ "KL" ]]; then
fi fi
if [ ${MODE} = "cpp_infer" ];then if [ ${MODE} = "cpp_infer" ];then
if [ ${model_name} = "ch_ppocr_mobile_v2.0_det" ]; then if [ ${model_name} = "ch_ppocr_mobile_v2_0_det" ]; then
wget -nc -P ./inference https://paddleocr.bj.bcebos.com/dygraph_v2.0/test/ch_det_data_50.tar --no-check-certificate wget -nc -P ./inference https://paddleocr.bj.bcebos.com/dygraph_v2.0/test/ch_det_data_50.tar --no-check-certificate
wget -nc -P ./inference https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_mobile_v2.0_det_infer.tar --no-check-certificate wget -nc -P ./inference https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_mobile_v2.0_det_infer.tar --no-check-certificate
cd ./inference && tar xf ch_ppocr_mobile_v2.0_det_infer.tar && tar xf ch_det_data_50.tar && cd ../ cd ./inference && tar xf ch_ppocr_mobile_v2.0_det_infer.tar && tar xf ch_det_data_50.tar && cd ../
elif [ ${model_name} = "ch_ppocr_mobile_v2.0_det_KL" ]; then elif [ ${model_name} = "ch_ppocr_mobile_v2_0_det_KL" ]; then
wget -nc -P ./inference https://paddleocr.bj.bcebos.com/dygraph_v2.0/test/ch_det_data_50.tar --no-check-certificate wget -nc -P ./inference https://paddleocr.bj.bcebos.com/dygraph_v2.0/test/ch_det_data_50.tar --no-check-certificate
wget -nc -P ./inference https://paddleocr.bj.bcebos.com/tipc_fake_model/ch_ppocr_mobile_v2.0_det_klquant_infer.tar --no-check-certificate wget -nc -P ./inference https://paddleocr.bj.bcebos.com/tipc_fake_model/ch_ppocr_mobile_v2.0_det_klquant_infer.tar --no-check-certificate
cd ./inference && tar xf ch_ppocr_mobile_v2.0_det_klquant_infer.tar && tar xf ch_det_data_50.tar && cd ../ cd ./inference && tar xf ch_ppocr_mobile_v2.0_det_klquant_infer.tar && tar xf ch_det_data_50.tar && cd ../
elif [ ${model_name} = "ch_ppocr_mobile_v2.0_det_PACT" ]; then elif [ ${model_name} = "ch_ppocr_mobile_v2_0_det_PACT" ]; then
wget -nc -P ./inference https://paddleocr.bj.bcebos.com/dygraph_v2.0/test/ch_det_data_50.tar --no-check-certificate wget -nc -P ./inference https://paddleocr.bj.bcebos.com/dygraph_v2.0/test/ch_det_data_50.tar --no-check-certificate
wget -nc -P ./inference https://paddleocr.bj.bcebos.com/tipc_fake_model/ch_ppocr_mobile_v2.0_det_pact_infer.tar --no-check-certificate wget -nc -P ./inference https://paddleocr.bj.bcebos.com/tipc_fake_model/ch_ppocr_mobile_v2.0_det_pact_infer.tar --no-check-certificate
cd ./inference && tar xf ch_ppocr_mobile_v2.0_det_pact_infer.tar && tar xf ch_det_data_50.tar && cd ../ cd ./inference && tar xf ch_ppocr_mobile_v2.0_det_pact_infer.tar && tar xf ch_det_data_50.tar && cd ../
elif [ ${model_name} = "ch_ppocr_mobile_v2.0_rec" ]; then elif [ ${model_name} = "ch_ppocr_mobile_v2_0_rec" ]; then
wget -nc -P ./inference/ https://paddleocr.bj.bcebos.com/dygraph_v2.0/test/rec_inference.tar --no-check-certificate wget -nc -P ./inference/ https://paddleocr.bj.bcebos.com/dygraph_v2.0/test/rec_inference.tar --no-check-certificate
wget -nc -P ./inference https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_mobile_v2.0_rec_infer.tar --no-check-certificate wget -nc -P ./inference https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_mobile_v2.0_rec_infer.tar --no-check-certificate
cd ./inference && tar xf ch_ppocr_mobile_v2.0_rec_infer.tar && tar xf rec_inference.tar && cd ../ cd ./inference && tar xf ch_ppocr_mobile_v2.0_rec_infer.tar && tar xf rec_inference.tar && cd ../
elif [ ${model_name} = "ch_ppocr_mobile_v2.0_rec_KL" ]; then elif [ ${model_name} = "ch_ppocr_mobile_v2_0_rec_KL" ]; then
wget -nc -P ./inference/ https://paddleocr.bj.bcebos.com/dygraph_v2.0/test/rec_inference.tar --no-check-certificate wget -nc -P ./inference/ https://paddleocr.bj.bcebos.com/dygraph_v2.0/test/rec_inference.tar --no-check-certificate
wget -nc -P ./inference https://paddleocr.bj.bcebos.com/tipc_fake_model/ch_ppocr_mobile_v2.0_rec_klquant_infer.tar --no-check-certificate wget -nc -P ./inference https://paddleocr.bj.bcebos.com/tipc_fake_model/ch_ppocr_mobile_v2.0_rec_klquant_infer.tar --no-check-certificate
cd ./inference && tar xf ch_ppocr_mobile_v2.0_rec_klquant_infer.tar && tar xf rec_inference.tar && cd ../ cd ./inference && tar xf ch_ppocr_mobile_v2.0_rec_klquant_infer.tar && tar xf rec_inference.tar && cd ../
elif [ ${model_name} = "ch_ppocr_mobile_v2.0_rec_PACT" ]; then elif [ ${model_name} = "ch_ppocr_mobile_v2_0_rec_PACT" ]; then
wget -nc -P ./inference/ https://paddleocr.bj.bcebos.com/dygraph_v2.0/test/rec_inference.tar --no-check-certificate wget -nc -P ./inference/ https://paddleocr.bj.bcebos.com/dygraph_v2.0/test/rec_inference.tar --no-check-certificate
wget -nc -P ./inference https://paddleocr.bj.bcebos.com/tipc_fake_model/ch_ppocr_mobile_v2.0_rec_pact_infer.tar --no-check-certificate wget -nc -P ./inference https://paddleocr.bj.bcebos.com/tipc_fake_model/ch_ppocr_mobile_v2.0_rec_pact_infer.tar --no-check-certificate
cd ./inference && tar xf ch_ppocr_mobile_v2.0_rec_pact_infer.tar && tar xf rec_inference.tar && cd ../ cd ./inference && tar xf ch_ppocr_mobile_v2.0_rec_pact_infer.tar && tar xf rec_inference.tar && cd ../
elif [ ${model_name} = "ch_ppocr_server_v2.0_det" ]; then elif [ ${model_name} = "ch_ppocr_server_v2_0_det" ]; then
wget -nc -P ./inference https://paddleocr.bj.bcebos.com/dygraph_v2.0/test/ch_det_data_50.tar --no-check-certificate wget -nc -P ./inference https://paddleocr.bj.bcebos.com/dygraph_v2.0/test/ch_det_data_50.tar --no-check-certificate
wget -nc -P ./inference https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_server_v2.0_det_infer.tar --no-check-certificate wget -nc -P ./inference https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_server_v2.0_det_infer.tar --no-check-certificate
cd ./inference && tar xf ch_ppocr_server_v2.0_det_infer.tar && tar xf ch_det_data_50.tar && cd ../ cd ./inference && tar xf ch_ppocr_server_v2.0_det_infer.tar && tar xf ch_det_data_50.tar && cd ../
elif [ ${model_name} = "ch_ppocr_server_v2.0_rec" ]; then elif [ ${model_name} = "ch_ppocr_server_v2_0_rec" ]; then
wget -nc -P ./inference/ https://paddleocr.bj.bcebos.com/dygraph_v2.0/test/rec_inference.tar --no-check-certificate wget -nc -P ./inference/ https://paddleocr.bj.bcebos.com/dygraph_v2.0/test/rec_inference.tar --no-check-certificate
wget -nc -P ./inference https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_server_v2.0_rec_infer.tar --no-check-certificate wget -nc -P ./inference https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_server_v2.0_rec_infer.tar --no-check-certificate
cd ./inference && tar xf ch_ppocr_server_v2.0_rec_infer.tar && tar xf rec_inference.tar && cd ../ cd ./inference && tar xf ch_ppocr_server_v2.0_rec_infer.tar && tar xf rec_inference.tar && cd ../
...@@ -487,12 +578,12 @@ if [ ${MODE} = "cpp_infer" ];then ...@@ -487,12 +578,12 @@ if [ ${MODE} = "cpp_infer" ];then
wget -nc -P ./inference/ https://paddleocr.bj.bcebos.com/dygraph_v2.0/test/rec_inference.tar --no-check-certificate wget -nc -P ./inference/ https://paddleocr.bj.bcebos.com/dygraph_v2.0/test/rec_inference.tar --no-check-certificate
wget -nc -P ./inference https://paddleocr.bj.bcebos.com/tipc_fake_model/ch_PP-OCRv3_rec_pact_infer.tar --no-check-certificate wget -nc -P ./inference https://paddleocr.bj.bcebos.com/tipc_fake_model/ch_PP-OCRv3_rec_pact_infer.tar --no-check-certificate
cd ./inference && tar xf ch_PP-OCRv3_rec_pact_infer.tar && tar xf rec_inference.tar && cd ../ cd ./inference && tar xf ch_PP-OCRv3_rec_pact_infer.tar && tar xf rec_inference.tar && cd ../
elif [ ${model_name} = "ch_ppocr_mobile_v2.0" ]; then elif [ ${model_name} = "ch_ppocr_mobile_v2_0" ]; then
wget -nc -P ./inference https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_mobile_v2.0_det_infer.tar --no-check-certificate wget -nc -P ./inference https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_mobile_v2.0_det_infer.tar --no-check-certificate
wget -nc -P ./inference https://paddleocr.bj.bcebos.com/dygraph_v2.0/test/ch_det_data_50.tar --no-check-certificate wget -nc -P ./inference https://paddleocr.bj.bcebos.com/dygraph_v2.0/test/ch_det_data_50.tar --no-check-certificate
wget -nc -P ./inference https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_mobile_v2.0_rec_infer.tar --no-check-certificate wget -nc -P ./inference https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_mobile_v2.0_rec_infer.tar --no-check-certificate
cd ./inference && tar xf ch_ppocr_mobile_v2.0_det_infer.tar && tar xf ch_ppocr_mobile_v2.0_rec_infer.tar && tar xf ch_det_data_50.tar && cd ../ cd ./inference && tar xf ch_ppocr_mobile_v2.0_det_infer.tar && tar xf ch_ppocr_mobile_v2.0_rec_infer.tar && tar xf ch_det_data_50.tar && cd ../
elif [ ${model_name} = "ch_ppocr_server_v2.0" ]; then elif [ ${model_name} = "ch_ppocr_server_v2_0" ]; then
wget -nc -P ./inference https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_server_v2.0_det_infer.tar --no-check-certificate wget -nc -P ./inference https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_server_v2.0_det_infer.tar --no-check-certificate
wget -nc -P ./inference https://paddleocr.bj.bcebos.com/dygraph_v2.0/test/ch_det_data_50.tar --no-check-certificate wget -nc -P ./inference https://paddleocr.bj.bcebos.com/dygraph_v2.0/test/ch_det_data_50.tar --no-check-certificate
wget -nc -P ./inference https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_server_v2.0_rec_infer.tar --no-check-certificate wget -nc -P ./inference https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_server_v2.0_rec_infer.tar --no-check-certificate
...@@ -520,7 +611,7 @@ if [ ${MODE} = "serving_infer" ];then ...@@ -520,7 +611,7 @@ if [ ${MODE} = "serving_infer" ];then
${python_name} -m pip install paddle_serving_client ${python_name} -m pip install paddle_serving_client
${python_name} -m pip install paddle-serving-app ${python_name} -m pip install paddle-serving-app
# wget model # wget model
if [ ${model_name} == "ch_ppocr_mobile_v2.0_det_KL" ] || [ ${model_name} == "ch_ppocr_mobile_v2.0_rec_KL" ] ; then if [ ${model_name} == "ch_ppocr_mobile_v2_0_det_KL" ] || [ ${model_name} == "ch_ppocr_mobile_v2.0_rec_KL" ] ; then
wget -nc -P ./inference https://paddleocr.bj.bcebos.com/tipc_fake_model/ch_ppocr_mobile_v2.0_det_klquant_infer.tar --no-check-certificate wget -nc -P ./inference https://paddleocr.bj.bcebos.com/tipc_fake_model/ch_ppocr_mobile_v2.0_det_klquant_infer.tar --no-check-certificate
wget -nc -P ./inference https://paddleocr.bj.bcebos.com/tipc_fake_model/ch_ppocr_mobile_v2.0_rec_klquant_infer.tar --no-check-certificate wget -nc -P ./inference https://paddleocr.bj.bcebos.com/tipc_fake_model/ch_ppocr_mobile_v2.0_rec_klquant_infer.tar --no-check-certificate
cd ./inference && tar xf ch_ppocr_mobile_v2.0_det_klquant_infer.tar && tar xf ch_ppocr_mobile_v2.0_rec_klquant_infer.tar && cd ../ cd ./inference && tar xf ch_ppocr_mobile_v2.0_det_klquant_infer.tar && tar xf ch_ppocr_mobile_v2.0_rec_klquant_infer.tar && cd ../
...@@ -532,7 +623,7 @@ if [ ${MODE} = "serving_infer" ];then ...@@ -532,7 +623,7 @@ if [ ${MODE} = "serving_infer" ];then
wget -nc -P ./inference https://paddleocr.bj.bcebos.com/tipc_fake_model/ch_PP-OCRv3_det_klquant_infer.tar --no-check-certificate wget -nc -P ./inference https://paddleocr.bj.bcebos.com/tipc_fake_model/ch_PP-OCRv3_det_klquant_infer.tar --no-check-certificate
wget -nc -P ./inference https://paddleocr.bj.bcebos.com/tipc_fake_model/ch_PP-OCRv3_rec_klquant_infer.tar --no-check-certificate wget -nc -P ./inference https://paddleocr.bj.bcebos.com/tipc_fake_model/ch_PP-OCRv3_rec_klquant_infer.tar --no-check-certificate
cd ./inference && tar xf ch_PP-OCRv3_det_klquant_infer.tar && tar xf ch_PP-OCRv3_rec_klquant_infer.tar && cd ../ cd ./inference && tar xf ch_PP-OCRv3_det_klquant_infer.tar && tar xf ch_PP-OCRv3_rec_klquant_infer.tar && cd ../
elif [ ${model_name} == "ch_ppocr_mobile_v2.0_det_PACT" ] || [ ${model_name} == "ch_ppocr_mobile_v2.0_rec_PACT" ] ; then elif [ ${model_name} == "ch_ppocr_mobile_v2_0_det_PACT" ] || [ ${model_name} == "ch_ppocr_mobile_v2.0_rec_PACT" ] ; then
wget -nc -P ./inference https://paddleocr.bj.bcebos.com/tipc_fake_model/ch_ppocr_mobile_v2.0_det_pact_infer.tar --no-check-certificate wget -nc -P ./inference https://paddleocr.bj.bcebos.com/tipc_fake_model/ch_ppocr_mobile_v2.0_det_pact_infer.tar --no-check-certificate
wget -nc -P ./inference https://paddleocr.bj.bcebos.com/tipc_fake_model/ch_ppocr_mobile_v2.0_rec_pact_infer.tar --no-check-certificate wget -nc -P ./inference https://paddleocr.bj.bcebos.com/tipc_fake_model/ch_ppocr_mobile_v2.0_rec_pact_infer.tar --no-check-certificate
cd ./inference && tar xf ch_ppocr_mobile_v2.0_det_pact_infer.tar && tar xf ch_ppocr_mobile_v2.0_rec_pact_infer.tar && cd ../ cd ./inference && tar xf ch_ppocr_mobile_v2.0_det_pact_infer.tar && tar xf ch_ppocr_mobile_v2.0_rec_pact_infer.tar && cd ../
...@@ -544,11 +635,11 @@ if [ ${MODE} = "serving_infer" ];then ...@@ -544,11 +635,11 @@ if [ ${MODE} = "serving_infer" ];then
wget -nc -P ./inference https://paddleocr.bj.bcebos.com/tipc_fake_model/ch_PP-OCRv3_det_pact_infer.tar --no-check-certificate wget -nc -P ./inference https://paddleocr.bj.bcebos.com/tipc_fake_model/ch_PP-OCRv3_det_pact_infer.tar --no-check-certificate
wget -nc -P ./inference https://paddleocr.bj.bcebos.com/tipc_fake_model/ch_PP-OCRv3_rec_pact_infer.tar --no-check-certificate wget -nc -P ./inference https://paddleocr.bj.bcebos.com/tipc_fake_model/ch_PP-OCRv3_rec_pact_infer.tar --no-check-certificate
cd ./inference && tar xf ch_PP-OCRv3_det_pact_infer.tar && tar xf ch_PP-OCRv3_rec_pact_infer.tar && cd ../ cd ./inference && tar xf ch_PP-OCRv3_det_pact_infer.tar && tar xf ch_PP-OCRv3_rec_pact_infer.tar && cd ../
elif [[ ${model_name} =~ "ch_ppocr_mobile_v2.0" ]]; then elif [[ ${model_name} =~ "ch_ppocr_mobile_v2_0" ]]; then
wget -nc -P ./inference https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_mobile_v2.0_det_infer.tar --no-check-certificate wget -nc -P ./inference https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_mobile_v2.0_det_infer.tar --no-check-certificate
wget -nc -P ./inference https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_mobile_v2.0_rec_infer.tar --no-check-certificate wget -nc -P ./inference https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_mobile_v2.0_rec_infer.tar --no-check-certificate
cd ./inference && tar xf ch_ppocr_mobile_v2.0_det_infer.tar && tar xf ch_ppocr_mobile_v2.0_rec_infer.tar && cd ../ cd ./inference && tar xf ch_ppocr_mobile_v2.0_det_infer.tar && tar xf ch_ppocr_mobile_v2.0_rec_infer.tar && cd ../
elif [[ ${model_name} =~ "ch_ppocr_server_v2.0" ]]; then elif [[ ${model_name} =~ "ch_ppocr_server_v2_0" ]]; then
wget -nc -P ./inference https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_server_v2.0_det_infer.tar --no-check-certificate wget -nc -P ./inference https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_server_v2.0_det_infer.tar --no-check-certificate
wget -nc -P ./inference https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_server_v2.0_rec_infer.tar --no-check-certificate wget -nc -P ./inference https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_server_v2.0_rec_infer.tar --no-check-certificate
cd ./inference && tar xf ch_ppocr_server_v2.0_det_infer.tar && tar xf ch_ppocr_server_v2.0_rec_infer.tar && cd ../ cd ./inference && tar xf ch_ppocr_server_v2.0_det_infer.tar && tar xf ch_ppocr_server_v2.0_rec_infer.tar && cd ../
...@@ -573,11 +664,11 @@ if [ ${MODE} = "paddle2onnx_infer" ];then ...@@ -573,11 +664,11 @@ if [ ${MODE} = "paddle2onnx_infer" ];then
${python_name} -m pip install paddle2onnx ${python_name} -m pip install paddle2onnx
${python_name} -m pip install onnxruntime ${python_name} -m pip install onnxruntime
# wget model # wget model
if [[ ${model_name} =~ "ch_ppocr_mobile_v2.0" ]]; then if [[ ${model_name} =~ "ch_ppocr_mobile_v2_0" ]]; then
wget -nc -P ./inference https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_mobile_v2.0_det_infer.tar --no-check-certificate wget -nc -P ./inference https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_mobile_v2.0_det_infer.tar --no-check-certificate
wget -nc -P ./inference https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_mobile_v2.0_rec_infer.tar --no-check-certificate wget -nc -P ./inference https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_mobile_v2.0_rec_infer.tar --no-check-certificate
cd ./inference && tar xf ch_ppocr_mobile_v2.0_det_infer.tar && tar xf ch_ppocr_mobile_v2.0_rec_infer.tar && cd ../ cd ./inference && tar xf ch_ppocr_mobile_v2.0_det_infer.tar && tar xf ch_ppocr_mobile_v2.0_rec_infer.tar && cd ../
elif [[ ${model_name} =~ "ch_ppocr_server_v2.0" ]]; then elif [[ ${model_name} =~ "ch_ppocr_server_v2_0" ]]; then
wget -nc -P ./inference https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_server_v2.0_det_infer.tar --no-check-certificate wget -nc -P ./inference https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_server_v2.0_det_infer.tar --no-check-certificate
wget -nc -P ./inference https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_server_v2.0_rec_infer.tar --no-check-certificate wget -nc -P ./inference https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_server_v2.0_rec_infer.tar --no-check-certificate
cd ./inference && tar xf ch_ppocr_server_v2.0_det_infer.tar && tar xf ch_ppocr_server_v2.0_rec_infer.tar && cd ../ cd ./inference && tar xf ch_ppocr_server_v2.0_det_infer.tar && tar xf ch_ppocr_server_v2.0_rec_infer.tar && cd ../
......
...@@ -53,7 +53,9 @@ ...@@ -53,7 +53,9 @@
| SRN |rec_r50fpn_vd_none_srn | Recognition | Supported | Multi-node multi-GPU <br> Mixed precision | - | - | | SRN |rec_r50fpn_vd_none_srn | Recognition | Supported | Multi-node multi-GPU <br> Mixed precision | - | - |
| NRTR |rec_mtb_nrtr | Recognition | Supported | Multi-node multi-GPU <br> Mixed precision | - | - | | NRTR |rec_mtb_nrtr | Recognition | Supported | Multi-node multi-GPU <br> Mixed precision | - | - |
| SAR |rec_r31_sar | Recognition | Supported | Multi-node multi-GPU <br> Mixed precision | - | - | | SAR |rec_r31_sar | Recognition | Supported | Multi-node multi-GPU <br> Mixed precision | - | - |
| SPIN |rec_r32_gaspin_bilstm_att | Recognition | Supported | Multi-node multi-GPU <br> Mixed precision | - | - |
| PGNet |rec_r34_vd_none_none_ctc_v2.0 | End-to-end| Supported | Multi-node multi-GPU <br> Mixed precision | - | - | | PGNet |rec_r34_vd_none_none_ctc_v2.0 | End-to-end| Supported | Multi-node multi-GPU <br> Mixed precision | - | - |
| TableMaster |table_structure_tablemaster_train | Table recognition| Supported | Multi-node multi-GPU <br> Mixed precision | - | - |
......
...@@ -139,7 +139,7 @@ if [ ${MODE} = "whole_infer" ]; then ...@@ -139,7 +139,7 @@ if [ ${MODE} = "whole_infer" ]; then
save_infer_dir="${infer_model}_klquant" save_infer_dir="${infer_model}_klquant"
set_export_weight=$(func_set_params "${export_weight}" "${infer_model}") set_export_weight=$(func_set_params "${export_weight}" "${infer_model}")
set_save_infer_key=$(func_set_params "${save_infer_key}" "${save_infer_dir}") set_save_infer_key=$(func_set_params "${save_infer_key}" "${save_infer_dir}")
export_log_path="${LOG_PATH}/_export_${Count}.log" export_log_path="${LOG_PATH}/${MODE}_export_${Count}.log"
export_cmd="${python} ${infer_run_exports[Count]} ${set_export_weight} ${set_save_infer_key} > ${export_log_path} 2>&1 " export_cmd="${python} ${infer_run_exports[Count]} ${set_export_weight} ${set_save_infer_key} > ${export_log_path} 2>&1 "
echo ${infer_run_exports[Count]} echo ${infer_run_exports[Count]}
echo $export_cmd echo $export_cmd
......
...@@ -87,11 +87,12 @@ function func_serving(){ ...@@ -87,11 +87,12 @@ function func_serving(){
set_image_dir=$(func_set_params "${image_dir_key}" "${image_dir_value}") set_image_dir=$(func_set_params "${image_dir_key}" "${image_dir_value}")
python_list=(${python_list}) python_list=(${python_list})
cd ${serving_dir_value} cd ${serving_dir_value}
# cpp serving # cpp serving
for gpu_id in ${gpu_value[*]}; do for gpu_id in ${gpu_value[*]}; do
if [ ${gpu_id} = "null" ]; then if [ ${gpu_id} = "null" ]; then
server_log_path="${LOG_PATH}/cpp_server_cpu.log" server_log_path="${LOG_PATH}/cpp_server_cpu.log"
web_service_cpp_cmd="${python_list[0]} ${web_service_py} --model ${det_server_value} ${rec_server_value} ${op_key} ${op_value} ${port_key} ${port_value} > ${server_log_path} 2>&1 " web_service_cpp_cmd="nohup ${python_list[0]} ${web_service_py} --model ${det_server_value} ${rec_server_value} ${op_key} ${op_value} ${port_key} ${port_value} > ${server_log_path} 2>&1 &"
eval $web_service_cpp_cmd eval $web_service_cpp_cmd
last_status=${PIPESTATUS[0]} last_status=${PIPESTATUS[0]}
status_check $last_status "${web_service_cpp_cmd}" "${status_log}" "${model_name}" status_check $last_status "${web_service_cpp_cmd}" "${status_log}" "${model_name}"
...@@ -105,7 +106,7 @@ function func_serving(){ ...@@ -105,7 +106,7 @@ function func_serving(){
ps ux | grep -i ${port_value} | awk '{print $2}' | xargs kill -s 9 ps ux | grep -i ${port_value} | awk '{print $2}' | xargs kill -s 9
else else
server_log_path="${LOG_PATH}/cpp_server_gpu.log" server_log_path="${LOG_PATH}/cpp_server_gpu.log"
web_service_cpp_cmd="${python_list[0]} ${web_service_py} --model ${det_server_value} ${rec_server_value} ${op_key} ${op_value} ${port_key} ${port_value} ${gpu_key} ${gpu_id} > ${server_log_path} 2>&1 " web_service_cpp_cmd="nohup ${python_list[0]} ${web_service_py} --model ${det_server_value} ${rec_server_value} ${op_key} ${op_value} ${port_key} ${port_value} ${gpu_key} ${gpu_id} > ${server_log_path} 2>&1 &"
eval $web_service_cpp_cmd eval $web_service_cpp_cmd
sleep 5s sleep 5s
_save_log_path="${LOG_PATH}/cpp_client_gpu.log" _save_log_path="${LOG_PATH}/cpp_client_gpu.log"
......
...@@ -112,7 +112,7 @@ function func_serving(){ ...@@ -112,7 +112,7 @@ function func_serving(){
cd ${serving_dir_value} cd ${serving_dir_value}
python=${python_list[0]} python=${python_list[0]}
# python serving # python serving
for use_gpu in ${web_use_gpu_list[*]}; do for use_gpu in ${web_use_gpu_list[*]}; do
if [ ${use_gpu} = "null" ]; then if [ ${use_gpu} = "null" ]; then
...@@ -123,19 +123,19 @@ function func_serving(){ ...@@ -123,19 +123,19 @@ function func_serving(){
if [ ${model_name} = "ch_PP-OCRv2" ] || [ ${model_name} = "ch_PP-OCRv3" ] || [ ${model_name} = "ch_ppocr_mobile_v2.0" ] || [ ${model_name} = "ch_ppocr_server_v2.0" ]; then if [ ${model_name} = "ch_PP-OCRv2" ] || [ ${model_name} = "ch_PP-OCRv3" ] || [ ${model_name} = "ch_ppocr_mobile_v2.0" ] || [ ${model_name} = "ch_ppocr_server_v2.0" ]; then
set_det_model_config=$(func_set_params "${det_server_key}" "${det_server_value}") set_det_model_config=$(func_set_params "${det_server_key}" "${det_server_value}")
set_rec_model_config=$(func_set_params "${rec_server_key}" "${rec_server_value}") set_rec_model_config=$(func_set_params "${rec_server_key}" "${rec_server_value}")
web_service_cmd="${python} ${web_service_py} ${web_use_gpu_key}="" ${web_use_mkldnn_key}=${use_mkldnn} ${set_cpu_threads} ${set_det_model_config} ${set_rec_model_config} > ${server_log_path} 2>&1 " web_service_cmd="nohup ${python} ${web_service_py} ${web_use_gpu_key}="" ${web_use_mkldnn_key}=${use_mkldnn} ${set_cpu_threads} ${set_det_model_config} ${set_rec_model_config} > ${server_log_path} 2>&1 &"
eval $web_service_cmd eval $web_service_cmd
last_status=${PIPESTATUS[0]} last_status=${PIPESTATUS[0]}
status_check $last_status "${web_service_cmd}" "${status_log}" "${model_name}" status_check $last_status "${web_service_cmd}" "${status_log}" "${model_name}"
elif [[ ${model_name} =~ "det" ]]; then elif [[ ${model_name} =~ "det" ]]; then
set_det_model_config=$(func_set_params "${det_server_key}" "${det_server_value}") set_det_model_config=$(func_set_params "${det_server_key}" "${det_server_value}")
web_service_cmd="${python} ${web_service_py} ${web_use_gpu_key}="" ${web_use_mkldnn_key}=${use_mkldnn} ${set_cpu_threads} ${set_det_model_config} > ${server_log_path} 2>&1 " web_service_cmd="nohup ${python} ${web_service_py} ${web_use_gpu_key}="" ${web_use_mkldnn_key}=${use_mkldnn} ${set_cpu_threads} ${set_det_model_config} > ${server_log_path} 2>&1 &"
eval $web_service_cmd eval $web_service_cmd
last_status=${PIPESTATUS[0]} last_status=${PIPESTATUS[0]}
status_check $last_status "${web_service_cmd}" "${status_log}" "${model_name}" status_check $last_status "${web_service_cmd}" "${status_log}" "${model_name}"
elif [[ ${model_name} =~ "rec" ]]; then elif [[ ${model_name} =~ "rec" ]]; then
set_rec_model_config=$(func_set_params "${rec_server_key}" "${rec_server_value}") set_rec_model_config=$(func_set_params "${rec_server_key}" "${rec_server_value}")
web_service_cmd="${python} ${web_service_py} ${web_use_gpu_key}="" ${web_use_mkldnn_key}=${use_mkldnn} ${set_cpu_threads} ${set_rec_model_config} > ${server_log_path} 2>&1 " web_service_cmd="nohup ${python} ${web_service_py} ${web_use_gpu_key}="" ${web_use_mkldnn_key}=${use_mkldnn} ${set_cpu_threads} ${set_rec_model_config} > ${server_log_path} 2>&1 &"
eval $web_service_cmd eval $web_service_cmd
last_status=${PIPESTATUS[0]} last_status=${PIPESTATUS[0]}
status_check $last_status "${web_service_cmd}" "${status_log}" "${model_name}" status_check $last_status "${web_service_cmd}" "${status_log}" "${model_name}"
...@@ -174,19 +174,19 @@ function func_serving(){ ...@@ -174,19 +174,19 @@ function func_serving(){
if [ ${model_name} = "ch_PP-OCRv2" ] || [ ${model_name} = "ch_PP-OCRv3" ] || [ ${model_name} = "ch_ppocr_mobile_v2.0" ] || [ ${model_name} = "ch_ppocr_server_v2.0" ]; then if [ ${model_name} = "ch_PP-OCRv2" ] || [ ${model_name} = "ch_PP-OCRv3" ] || [ ${model_name} = "ch_ppocr_mobile_v2.0" ] || [ ${model_name} = "ch_ppocr_server_v2.0" ]; then
set_det_model_config=$(func_set_params "${det_server_key}" "${det_server_value}") set_det_model_config=$(func_set_params "${det_server_key}" "${det_server_value}")
set_rec_model_config=$(func_set_params "${rec_server_key}" "${rec_server_value}") set_rec_model_config=$(func_set_params "${rec_server_key}" "${rec_server_value}")
web_service_cmd="${python} ${web_service_py} ${set_tensorrt} ${set_precision} ${set_det_model_config} ${set_rec_model_config} > ${server_log_path} 2>&1 " web_service_cmd="nohup ${python} ${web_service_py} ${set_tensorrt} ${set_precision} ${set_det_model_config} ${set_rec_model_config} > ${server_log_path} 2>&1 &"
eval $web_service_cmd eval $web_service_cmd
last_status=${PIPESTATUS[0]} last_status=${PIPESTATUS[0]}
status_check $last_status "${web_service_cmd}" "${status_log}" "${model_name}" status_check $last_status "${web_service_cmd}" "${status_log}" "${model_name}"
elif [[ ${model_name} =~ "det" ]]; then elif [[ ${model_name} =~ "det" ]]; then
set_det_model_config=$(func_set_params "${det_server_key}" "${det_server_value}") set_det_model_config=$(func_set_params "${det_server_key}" "${det_server_value}")
web_service_cmd="${python} ${web_service_py} ${set_tensorrt} ${set_precision} ${set_det_model_config} > ${server_log_path} 2>&1 " web_service_cmd="nohup ${python} ${web_service_py} ${set_tensorrt} ${set_precision} ${set_det_model_config} > ${server_log_path} 2>&1 &"
eval $web_service_cmd eval $web_service_cmd
last_status=${PIPESTATUS[0]} last_status=${PIPESTATUS[0]}
status_check $last_status "${web_service_cmd}" "${status_log}" "${model_name}" status_check $last_status "${web_service_cmd}" "${status_log}" "${model_name}"
elif [[ ${model_name} =~ "rec" ]]; then elif [[ ${model_name} =~ "rec" ]]; then
set_rec_model_config=$(func_set_params "${rec_server_key}" "${rec_server_value}") set_rec_model_config=$(func_set_params "${rec_server_key}" "${rec_server_value}")
web_service_cmd="${python} ${web_service_py} ${set_tensorrt} ${set_precision} ${set_rec_model_config} > ${server_log_path} 2>&1 " web_service_cmd="nohup ${python} ${web_service_py} ${set_tensorrt} ${set_precision} ${set_rec_model_config} > ${server_log_path} 2>&1 &"
eval $web_service_cmd eval $web_service_cmd
last_status=${PIPESTATUS[0]} last_status=${PIPESTATUS[0]}
status_check $last_status "${web_service_cmd}" "${status_log}" "${model_name}" status_check $last_status "${web_service_cmd}" "${status_log}" "${model_name}"
......
...@@ -101,6 +101,7 @@ function func_inference(){ ...@@ -101,6 +101,7 @@ function func_inference(){
_log_path=$4 _log_path=$4
_img_dir=$5 _img_dir=$5
_flag_quant=$6 _flag_quant=$6
_gpu=$7
# inference # inference
for use_gpu in ${use_gpu_list[*]}; do for use_gpu in ${use_gpu_list[*]}; do
if [ ${use_gpu} = "False" ] || [ ${use_gpu} = "cpu" ]; then if [ ${use_gpu} = "False" ] || [ ${use_gpu} = "cpu" ]; then
...@@ -119,7 +120,7 @@ function func_inference(){ ...@@ -119,7 +120,7 @@ function func_inference(){
fi # skip when quant model inference but precision is not int8 fi # skip when quant model inference but precision is not int8
set_precision=$(func_set_params "${precision_key}" "${precision}") set_precision=$(func_set_params "${precision_key}" "${precision}")
_save_log_path="${_log_path}/python_infer_cpu_usemkldnn_${use_mkldnn}_threads_${threads}_precision_${precision}_batchsize_${batch_size}.log" _save_log_path="${_log_path}/python_infer_cpu_gpus_${_gpu}_usemkldnn_${use_mkldnn}_threads_${threads}_precision_${precision}_batchsize_${batch_size}.log"
set_infer_data=$(func_set_params "${image_dir_key}" "${_img_dir}") set_infer_data=$(func_set_params "${image_dir_key}" "${_img_dir}")
set_benchmark=$(func_set_params "${benchmark_key}" "${benchmark_value}") set_benchmark=$(func_set_params "${benchmark_key}" "${benchmark_value}")
set_batchsize=$(func_set_params "${batch_size_key}" "${batch_size}") set_batchsize=$(func_set_params "${batch_size_key}" "${batch_size}")
...@@ -150,7 +151,7 @@ function func_inference(){ ...@@ -150,7 +151,7 @@ function func_inference(){
continue continue
fi fi
for batch_size in ${batch_size_list[*]}; do for batch_size in ${batch_size_list[*]}; do
_save_log_path="${_log_path}/python_infer_gpu_usetrt_${use_trt}_precision_${precision}_batchsize_${batch_size}.log" _save_log_path="${_log_path}/python_infer_gpu_gpus_${_gpu}_usetrt_${use_trt}_precision_${precision}_batchsize_${batch_size}.log"
set_infer_data=$(func_set_params "${image_dir_key}" "${_img_dir}") set_infer_data=$(func_set_params "${image_dir_key}" "${_img_dir}")
set_benchmark=$(func_set_params "${benchmark_key}" "${benchmark_value}") set_benchmark=$(func_set_params "${benchmark_key}" "${benchmark_value}")
set_batchsize=$(func_set_params "${batch_size_key}" "${batch_size}") set_batchsize=$(func_set_params "${batch_size_key}" "${batch_size}")
...@@ -184,6 +185,7 @@ if [ ${MODE} = "whole_infer" ]; then ...@@ -184,6 +185,7 @@ if [ ${MODE} = "whole_infer" ]; then
# set CUDA_VISIBLE_DEVICES # set CUDA_VISIBLE_DEVICES
eval $env eval $env
export Count=0 export Count=0
gpu=0
IFS="|" IFS="|"
infer_run_exports=(${infer_export_list}) infer_run_exports=(${infer_export_list})
infer_quant_flag=(${infer_is_quant}) infer_quant_flag=(${infer_is_quant})
...@@ -193,7 +195,7 @@ if [ ${MODE} = "whole_infer" ]; then ...@@ -193,7 +195,7 @@ if [ ${MODE} = "whole_infer" ]; then
save_infer_dir="${infer_model}" save_infer_dir="${infer_model}"
set_export_weight=$(func_set_params "${export_weight}" "${infer_model}") set_export_weight=$(func_set_params "${export_weight}" "${infer_model}")
set_save_infer_key=$(func_set_params "${save_infer_key}" "${save_infer_dir}") set_save_infer_key=$(func_set_params "${save_infer_key}" "${save_infer_dir}")
export_log_path="${LOG_PATH}/_export_${Count}.log" export_log_path="${LOG_PATH}_export_${Count}.log"
export_cmd="${python} ${infer_run_exports[Count]} ${set_export_weight} ${set_save_infer_key} > ${export_log_path} 2>&1 " export_cmd="${python} ${infer_run_exports[Count]} ${set_export_weight} ${set_save_infer_key} > ${export_log_path} 2>&1 "
echo ${infer_run_exports[Count]} echo ${infer_run_exports[Count]}
echo $export_cmd echo $export_cmd
...@@ -205,7 +207,7 @@ if [ ${MODE} = "whole_infer" ]; then ...@@ -205,7 +207,7 @@ if [ ${MODE} = "whole_infer" ]; then
fi fi
#run inference #run inference
is_quant=${infer_quant_flag[Count]} is_quant=${infer_quant_flag[Count]}
func_inference "${python}" "${inference_py}" "${save_infer_dir}" "${LOG_PATH}" "${infer_img_dir}" ${is_quant} func_inference "${python}" "${inference_py}" "${save_infer_dir}" "${LOG_PATH}" "${infer_img_dir}" ${is_quant} "${gpu}"
Count=$(($Count + 1)) Count=$(($Count + 1))
done done
else else
...@@ -265,7 +267,7 @@ else ...@@ -265,7 +267,7 @@ else
if [ ${run_train} = "null" ]; then if [ ${run_train} = "null" ]; then
continue continue
fi fi
set_autocast=$(func_set_params "${autocast_key}" "${autocast}")
set_epoch=$(func_set_params "${epoch_key}" "${epoch_num}") set_epoch=$(func_set_params "${epoch_key}" "${epoch_num}")
set_pretrain=$(func_set_params "${pretrain_model_key}" "${pretrain_model_value}") set_pretrain=$(func_set_params "${pretrain_model_key}" "${pretrain_model_value}")
set_batchsize=$(func_set_params "${train_batch_key}" "${train_batch_value}") set_batchsize=$(func_set_params "${train_batch_key}" "${train_batch_value}")
...@@ -287,14 +289,15 @@ else ...@@ -287,14 +289,15 @@ else
set_save_model=$(func_set_params "${save_model_key}" "${save_log}") set_save_model=$(func_set_params "${save_model_key}" "${save_log}")
if [ ${#gpu} -le 2 ];then # train with cpu or single gpu if [ ${#gpu} -le 2 ];then # train with cpu or single gpu
cmd="${python} ${run_train} ${set_use_gpu} ${set_save_model} ${set_epoch} ${set_pretrain} ${set_autocast} ${set_batchsize} ${set_train_params1} ${set_amp_config} " cmd="${python} ${run_train} ${set_use_gpu} ${set_save_model} ${set_epoch} ${set_pretrain} ${set_batchsize} ${set_train_params1} ${set_amp_config} "
elif [ ${#ips} -le 15 ];then # train with multi-gpu elif [ ${#ips} -le 15 ];then # train with multi-gpu
cmd="${python} -m paddle.distributed.launch --gpus=${gpu} ${run_train} ${set_use_gpu} ${set_save_model} ${set_epoch} ${set_pretrain} ${set_autocast} ${set_batchsize} ${set_train_params1} ${set_amp_config}" cmd="${python} -m paddle.distributed.launch --gpus=${gpu} ${run_train} ${set_use_gpu} ${set_save_model} ${set_epoch} ${set_pretrain} ${set_batchsize} ${set_train_params1} ${set_amp_config}"
else # train with multi-machine else # train with multi-machine
cmd="${python} -m paddle.distributed.launch --ips=${ips} --gpus=${gpu} ${run_train} ${set_use_gpu} ${set_save_model} ${set_pretrain} ${set_epoch} ${set_autocast} ${set_batchsize} ${set_train_params1} ${set_amp_config}" cmd="${python} -m paddle.distributed.launch --ips=${ips} --gpus=${gpu} ${run_train} ${set_use_gpu} ${set_save_model} ${set_pretrain} ${set_epoch} ${set_batchsize} ${set_train_params1} ${set_amp_config}"
fi fi
# run train # run train
eval $cmd eval $cmd
eval "cat ${save_log}/train.log >> ${save_log}.log"
status_check $? "${cmd}" "${status_log}" "${model_name}" status_check $? "${cmd}" "${status_log}" "${model_name}"
set_eval_pretrain=$(func_set_params "${pretrain_model_key}" "${save_log}/${train_model_name}") set_eval_pretrain=$(func_set_params "${pretrain_model_key}" "${save_log}/${train_model_name}")
...@@ -327,7 +330,7 @@ else ...@@ -327,7 +330,7 @@ else
else else
infer_model_dir=${save_infer_path} infer_model_dir=${save_infer_path}
fi fi
func_inference "${python}" "${inference_py}" "${infer_model_dir}" "${LOG_PATH}" "${train_infer_img_dir}" "${flag_quant}" func_inference "${python}" "${inference_py}" "${infer_model_dir}" "${LOG_PATH}" "${train_infer_img_dir}" "${flag_quant}" "${gpu}"
eval "unset CUDA_VISIBLE_DEVICES" eval "unset CUDA_VISIBLE_DEVICES"
fi fi
......
...@@ -91,12 +91,28 @@ def export_single_model(model, ...@@ -91,12 +91,28 @@ def export_single_model(model,
] ]
# print([None, 3, 32, 128]) # print([None, 3, 32, 128])
model = to_static(model, input_spec=other_shape) model = to_static(model, input_spec=other_shape)
elif arch_config["algorithm"] == "NRTR": elif arch_config["algorithm"] in ["NRTR", "SPIN"]:
other_shape = [ other_shape = [
paddle.static.InputSpec( paddle.static.InputSpec(
shape=[None, 1, 32, 100], dtype="float32"), shape=[None, 1, 32, 100], dtype="float32"),
] ]
model = to_static(model, input_spec=other_shape) model = to_static(model, input_spec=other_shape)
elif arch_config["algorithm"] in ["LayoutLM", "LayoutLMv2", "LayoutXLM"]:
input_spec = [
paddle.static.InputSpec(
shape=[None, 512], dtype="int64"), # input_ids
paddle.static.InputSpec(
shape=[None, 512, 4], dtype="int64"), # bbox
paddle.static.InputSpec(
shape=[None, 512], dtype="int64"), # attention_mask
paddle.static.InputSpec(
shape=[None, 512], dtype="int64"), # token_type_ids
paddle.static.InputSpec(
shape=[None, 3, 224, 224], dtype="int64"), # image
]
if arch_config["algorithm"] == "LayoutLM":
input_spec.pop(4)
model = to_static(model, input_spec=[input_spec])
else: else:
infer_shape = [3, -1, -1] infer_shape = [3, -1, -1]
if arch_config["model_type"] == "rec": if arch_config["model_type"] == "rec":
...@@ -110,6 +126,8 @@ def export_single_model(model, ...@@ -110,6 +126,8 @@ def export_single_model(model,
infer_shape[-1] = 100 infer_shape[-1] = 100
elif arch_config["model_type"] == "table": elif arch_config["model_type"] == "table":
infer_shape = [3, 488, 488] infer_shape = [3, 488, 488]
if arch_config["algorithm"] == "TableMaster":
infer_shape = [3, 480, 480]
model = to_static( model = to_static(
model, model,
input_spec=[ input_spec=[
...@@ -172,7 +190,7 @@ def main(): ...@@ -172,7 +190,7 @@ def main():
config["Architecture"]["Head"]["out_channels"] = char_num config["Architecture"]["Head"]["out_channels"] = char_num
model = build_model(config["Architecture"]) model = build_model(config["Architecture"])
load_model(config, model) load_model(config, model, model_type=config['Architecture']["model_type"])
model.eval() model.eval()
save_path = config["Global"]["save_inference_dir"] save_path = config["Global"]["save_inference_dir"]
......
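
For context on the `to_static` export pattern used in the hunks above (a fixed-shape `InputSpec` for the NRTR/SPIN branch, and a list of specs for the LayoutLM family), here is a minimal, self-contained sketch with a toy network; the model, shapes, and save path are illustrative only, not PaddleOCR's real architecture.

```python
import paddle
from paddle.jit import to_static
from paddle.static import InputSpec


# Toy recognition-style network standing in for the real PaddleOCR model.
class TinyRec(paddle.nn.Layer):
    def __init__(self):
        super().__init__()
        self.conv = paddle.nn.Conv2D(1, 8, kernel_size=3, padding=1)
        self.fc = paddle.nn.Linear(8 * 32 * 100, 10)

    def forward(self, x):
        x = self.conv(x)
        return self.fc(paddle.flatten(x, start_axis=1))


model = TinyRec()
model.eval()

# Fixed H/W as in the NRTR/SPIN branch; the batch dimension stays dynamic (None).
spec = [InputSpec(shape=[None, 1, 32, 100], dtype="float32")]
static_model = to_static(model, input_spec=spec)
paddle.jit.save(static_model, "./inference_demo/inference")  # placeholder path
```
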
...@@ -67,6 +67,23 @@ class TextDetector(object): ...@@ -67,6 +67,23 @@ class TextDetector(object):
postprocess_params["unclip_ratio"] = args.det_db_unclip_ratio postprocess_params["unclip_ratio"] = args.det_db_unclip_ratio
postprocess_params["use_dilation"] = args.use_dilation postprocess_params["use_dilation"] = args.use_dilation
postprocess_params["score_mode"] = args.det_db_score_mode postprocess_params["score_mode"] = args.det_db_score_mode
elif self.det_algorithm == "DB++":
postprocess_params['name'] = 'DBPostProcess'
postprocess_params["thresh"] = args.det_db_thresh
postprocess_params["box_thresh"] = args.det_db_box_thresh
postprocess_params["max_candidates"] = 1000
postprocess_params["unclip_ratio"] = args.det_db_unclip_ratio
postprocess_params["use_dilation"] = args.use_dilation
postprocess_params["score_mode"] = args.det_db_score_mode
pre_process_list[1] = {
'NormalizeImage': {
'std': [1.0, 1.0, 1.0],
'mean':
[0.48109378172549, 0.45752457890196, 0.40787054090196],
'scale': '1./255.',
'order': 'hwc'
}
}
elif self.det_algorithm == "EAST": elif self.det_algorithm == "EAST":
postprocess_params['name'] = 'EASTPostProcess' postprocess_params['name'] = 'EASTPostProcess'
postprocess_params["score_thresh"] = args.det_east_score_thresh postprocess_params["score_thresh"] = args.det_east_score_thresh
...@@ -231,7 +248,7 @@ class TextDetector(object): ...@@ -231,7 +248,7 @@ class TextDetector(object):
preds['f_score'] = outputs[1] preds['f_score'] = outputs[1]
preds['f_tco'] = outputs[2] preds['f_tco'] = outputs[2]
preds['f_tvo'] = outputs[3] preds['f_tvo'] = outputs[3]
elif self.det_algorithm in ['DB', 'PSE']: elif self.det_algorithm in ['DB', 'PSE', 'DB++']:
preds['maps'] = outputs[0] preds['maps'] = outputs[0]
elif self.det_algorithm == 'FCE': elif self.det_algorithm == 'FCE':
for i, output in enumerate(outputs): for i, output in enumerate(outputs):
......
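
The DB++ branch above reuses `DBPostProcess` but swaps the `NormalizeImage` step to a 1/255 scale with per-channel mean and unit std, applied in HWC order. A minimal NumPy sketch of that normalization on a dummy image, with the values copied from the hunk:

```python
import numpy as np


# DB++ NormalizeImage settings from the hunk above: scale 1/255, per-channel mean,
# std of 1.0, applied in HWC order.
def normalize_dbpp(img_hwc_uint8):
    mean = np.array(
        [0.48109378172549, 0.45752457890196, 0.40787054090196], dtype=np.float32)
    std = np.array([1.0, 1.0, 1.0], dtype=np.float32)
    img = img_hwc_uint8.astype(np.float32) * (1.0 / 255.0)
    return (img - mean) / std


dummy = np.random.randint(0, 256, size=(64, 64, 3), dtype=np.uint8)
print(normalize_dbpp(dummy).shape)  # (64, 64, 3)
```
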
...@@ -81,6 +81,12 @@ class TextRecognizer(object): ...@@ -81,6 +81,12 @@ class TextRecognizer(object):
"character_dict_path": args.rec_char_dict_path, "character_dict_path": args.rec_char_dict_path,
"use_space_char": args.use_space_char "use_space_char": args.use_space_char
} }
elif self.rec_algorithm == "SPIN":
postprocess_params = {
'name': 'SPINLabelDecode',
"character_dict_path": args.rec_char_dict_path,
"use_space_char": args.use_space_char
}
self.postprocess_op = build_post_process(postprocess_params) self.postprocess_op = build_post_process(postprocess_params)
self.predictor, self.input_tensor, self.output_tensors, self.config = \ self.predictor, self.input_tensor, self.output_tensors, self.config = \
utility.create_predictor(args, 'rec', logger) utility.create_predictor(args, 'rec', logger)
...@@ -258,6 +264,22 @@ class TextRecognizer(object): ...@@ -258,6 +264,22 @@ class TextRecognizer(object):
return padding_im, resize_shape, pad_shape, valid_ratio return padding_im, resize_shape, pad_shape, valid_ratio
def resize_norm_img_spin(self, img):
img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
# return padding_im
img = cv2.resize(img, tuple([100, 32]), cv2.INTER_CUBIC)
img = np.array(img, np.float32)
img = np.expand_dims(img, -1)
img = img.transpose((2, 0, 1))
mean = [127.5]
std = [127.5]
mean = np.array(mean, dtype=np.float32)
std = np.array(std, dtype=np.float32)
mean = np.float32(mean.reshape(1, -1))
stdinv = 1 / np.float32(std.reshape(1, -1))
img -= mean
img *= stdinv
return img
def resize_norm_img_svtr(self, img, image_shape): def resize_norm_img_svtr(self, img, image_shape):
imgC, imgH, imgW = image_shape imgC, imgH, imgW = image_shape
...@@ -337,6 +359,10 @@ class TextRecognizer(object): ...@@ -337,6 +359,10 @@ class TextRecognizer(object):
self.rec_image_shape) self.rec_image_shape)
norm_img = norm_img[np.newaxis, :] norm_img = norm_img[np.newaxis, :]
norm_img_batch.append(norm_img) norm_img_batch.append(norm_img)
elif self.rec_algorithm == 'SPIN':
norm_img = self.resize_norm_img_spin(img_list[indices[ino]])
norm_img = norm_img[np.newaxis, :]
norm_img_batch.append(norm_img)
elif self.rec_algorithm == "ABINet": elif self.rec_algorithm == "ABINet":
norm_img = self.resize_norm_img_abinet( norm_img = self.resize_norm_img_abinet(
img_list[indices[ino]], self.rec_image_shape) img_list[indices[ino]], self.rec_image_shape)
......
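
As a quick way to sanity-check the new SPIN preprocessing, here is a standalone rework of `resize_norm_img_spin` run on a dummy image; the helper name and test shapes are made up for illustration.

```python
import cv2
import numpy as np


# Grayscale -> resize to 100x32 -> CHW -> scale to roughly [-1, 1] with mean/std 127.5.
def spin_preprocess(img_bgr):
    gray = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2GRAY)
    gray = cv2.resize(gray, (100, 32), interpolation=cv2.INTER_CUBIC)
    chw = gray.astype(np.float32)[None, :, :]   # 1 x 32 x 100
    return (chw - 127.5) / 127.5


dummy = np.zeros((48, 160, 3), dtype=np.uint8)
batch = spin_preprocess(dummy)[None, ...]        # add batch dim: 1 x 1 x 32 x 100
print(batch.shape)
```
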
...@@ -153,6 +153,8 @@ def create_predictor(args, mode, logger): ...@@ -153,6 +153,8 @@ def create_predictor(args, mode, logger):
model_dir = args.rec_model_dir model_dir = args.rec_model_dir
elif mode == 'table': elif mode == 'table':
model_dir = args.table_model_dir model_dir = args.table_model_dir
elif mode == 'ser':
model_dir = args.ser_model_dir
else: else:
model_dir = args.e2e_model_dir model_dir = args.e2e_model_dir
...@@ -316,8 +318,13 @@ def create_predictor(args, mode, logger): ...@@ -316,8 +318,13 @@ def create_predictor(args, mode, logger):
# create predictor # create predictor
predictor = inference.create_predictor(config) predictor = inference.create_predictor(config)
input_names = predictor.get_input_names() input_names = predictor.get_input_names()
for name in input_names: if mode in ['ser', 're']:
input_tensor = predictor.get_input_handle(name) input_tensor = []
for name in input_names:
input_tensor.append(predictor.get_input_handle(name))
else:
for name in input_names:
input_tensor = predictor.get_input_handle(name)
output_tensors = get_output_tensors(args, mode, predictor) output_tensors = get_output_tensors(args, mode, predictor)
return predictor, input_tensor, output_tensors, config return predictor, input_tensor, output_tensors, config
......
...@@ -44,7 +44,7 @@ def draw_det_res(dt_boxes, config, img, img_name, save_path): ...@@ -44,7 +44,7 @@ def draw_det_res(dt_boxes, config, img, img_name, save_path):
import cv2 import cv2
src_im = img src_im = img
for box in dt_boxes: for box in dt_boxes:
box = box.astype(np.int32).reshape((-1, 1, 2)) box = np.array(box).astype(np.int32).reshape((-1, 1, 2))
cv2.polylines(src_im, [box], True, color=(255, 255, 0), thickness=2) cv2.polylines(src_im, [box], True, color=(255, 255, 0), thickness=2)
if not os.path.exists(save_path): if not os.path.exists(save_path):
os.makedirs(save_path) os.makedirs(save_path)
...@@ -106,7 +106,7 @@ def main(): ...@@ -106,7 +106,7 @@ def main():
dt_boxes_list = [] dt_boxes_list = []
for box in boxes: for box in boxes:
tmp_json = {"transcription": ""} tmp_json = {"transcription": ""}
tmp_json['points'] = box.tolist() tmp_json['points'] = np.array(box).tolist()
dt_boxes_list.append(tmp_json) dt_boxes_list.append(tmp_json)
det_box_json[k] = dt_boxes_list det_box_json[k] = dt_boxes_list
save_det_path = os.path.dirname(config['Global'][ save_det_path = os.path.dirname(config['Global'][
...@@ -118,7 +118,7 @@ def main(): ...@@ -118,7 +118,7 @@ def main():
# write result # write result
for box in boxes: for box in boxes:
tmp_json = {"transcription": ""} tmp_json = {"transcription": ""}
tmp_json['points'] = box.tolist() tmp_json['points'] = np.array(box).tolist()
dt_boxes_json.append(tmp_json) dt_boxes_json.append(tmp_json)
save_det_path = os.path.dirname(config['Global'][ save_det_path = os.path.dirname(config['Global'][
'save_res_path']) + "/det_results/" 'save_res_path']) + "/det_results/"
......
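
The switch to `np.array(box).tolist()` matters because `json.dumps` cannot serialize NumPy arrays and the detection boxes may arrive either as arrays or as plain lists. A small illustration with dummy boxes:

```python
import json
import numpy as np

# Boxes may be ndarrays or plain lists; np.array(...).tolist() handles both uniformly
# and yields JSON-serializable Python numbers.
boxes = [np.array([[0, 0], [10, 0], [10, 5], [0, 5]]),   # ndarray box
         [[1, 1], [9, 1], [9, 4], [1, 4]]]               # list box
records = [{"transcription": "", "points": np.array(b).tolist()} for b in boxes]
print(json.dumps(records))
```
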
...@@ -39,13 +39,12 @@ import time ...@@ -39,13 +39,12 @@ import time
def read_class_list(filepath): def read_class_list(filepath):
dict = {} ret = {}
with open(filepath, "r") as f: with open(filepath, "r") as f:
lines = f.readlines() lines = f.readlines()
for line in lines: for idx, line in enumerate(lines):
key, value = line.split(" ") ret[idx] = line.strip("\n")
dict[key] = value.rstrip() return ret
return dict
def draw_kie_result(batch, node, idx_to_cls, count): def draw_kie_result(batch, node, idx_to_cls, count):
...@@ -71,7 +70,7 @@ def draw_kie_result(batch, node, idx_to_cls, count): ...@@ -71,7 +70,7 @@ def draw_kie_result(batch, node, idx_to_cls, count):
x_min = int(min([point[0] for point in new_box])) x_min = int(min([point[0] for point in new_box]))
y_min = int(min([point[1] for point in new_box])) y_min = int(min([point[1] for point in new_box]))
pred_label = str(node_pred_label[i]) pred_label = node_pred_label[i]
if pred_label in idx_to_cls: if pred_label in idx_to_cls:
pred_label = idx_to_cls[pred_label] pred_label = idx_to_cls[pred_label]
pred_score = '{:.2f}'.format(node_pred_score[i]) pred_score = '{:.2f}'.format(node_pred_score[i])
...@@ -109,8 +108,7 @@ def main(): ...@@ -109,8 +108,7 @@ def main():
save_res_path = config['Global']['save_res_path'] save_res_path = config['Global']['save_res_path']
class_path = config['Global']['class_path'] class_path = config['Global']['class_path']
idx_to_cls = read_class_list(class_path) idx_to_cls = read_class_list(class_path)
if not os.path.exists(os.path.dirname(save_res_path)): os.makedirs(os.path.dirname(save_res_path), exist_ok=True)
os.makedirs(os.path.dirname(save_res_path))
model.eval() model.eval()
......
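
The rewritten `read_class_list` assumes the class file now holds one label per line and maps the line index to the class name (the old format was `key value` pairs). A sketch with an in-memory file and hypothetical labels:

```python
import io


# Maps line index -> class name, mirroring the rewritten read_class_list above.
def read_class_list_demo(fileobj):
    return {idx: line.strip("\n") for idx, line in enumerate(fileobj.readlines())}


demo_file = io.StringIO("other\nquestion\nanswer\nheader\n")  # hypothetical labels
print(read_class_list_demo(demo_file))  # {0: 'other', 1: 'question', 2: 'answer', 3: 'header'}
```
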
...@@ -36,10 +36,12 @@ from ppocr.modeling.architectures import build_model ...@@ -36,10 +36,12 @@ from ppocr.modeling.architectures import build_model
from ppocr.postprocess import build_post_process from ppocr.postprocess import build_post_process
from ppocr.utils.save_load import load_model from ppocr.utils.save_load import load_model
from ppocr.utils.utility import get_image_file_list from ppocr.utils.utility import get_image_file_list
from ppocr.utils.visual import draw_rectangle
import tools.program as program import tools.program as program
import cv2 import cv2
@paddle.no_grad()
def main(config, device, logger, vdl_writer): def main(config, device, logger, vdl_writer):
global_config = config['Global'] global_config = config['Global']
...@@ -53,53 +55,61 @@ def main(config, device, logger, vdl_writer): ...@@ -53,53 +55,61 @@ def main(config, device, logger, vdl_writer):
getattr(post_process_class, 'character')) getattr(post_process_class, 'character'))
model = build_model(config['Architecture']) model = build_model(config['Architecture'])
algorithm = config['Architecture']['algorithm']
use_xywh = algorithm in ['TableMaster']
load_model(config, model) load_model(config, model)
# create data ops # create data ops
transforms = [] transforms = []
use_padding = False
for op in config['Eval']['dataset']['transforms']: for op in config['Eval']['dataset']['transforms']:
op_name = list(op)[0] op_name = list(op)[0]
if 'Label' in op_name: if 'Encode' in op_name:
continue continue
if op_name == 'KeepKeys': if op_name == 'KeepKeys':
op[op_name]['keep_keys'] = ['image'] op[op_name]['keep_keys'] = ['image', 'shape']
if op_name == "ResizeTableImage":
use_padding = True
padding_max_len = op['ResizeTableImage']['max_len']
transforms.append(op) transforms.append(op)
global_config['infer_mode'] = True global_config['infer_mode'] = True
ops = create_operators(transforms, global_config) ops = create_operators(transforms, global_config)
save_res_path = config['Global']['save_res_path']
os.makedirs(save_res_path, exist_ok=True)
model.eval() model.eval()
for file in get_image_file_list(config['Global']['infer_img']): with open(
logger.info("infer_img: {}".format(file)) os.path.join(save_res_path, 'infer.txt'), mode='w',
with open(file, 'rb') as f: encoding='utf-8') as f_w:
img = f.read() for file in get_image_file_list(config['Global']['infer_img']):
data = {'image': img} logger.info("infer_img: {}".format(file))
batch = transform(data, ops) with open(file, 'rb') as f:
images = np.expand_dims(batch[0], axis=0) img = f.read()
images = paddle.to_tensor(images) data = {'image': img}
preds = model(images) batch = transform(data, ops)
post_result = post_process_class(preds) images = np.expand_dims(batch[0], axis=0)
res_html_code = post_result['res_html_code'] shape_list = np.expand_dims(batch[1], axis=0)
res_loc = post_result['res_loc']
img = cv2.imread(file) images = paddle.to_tensor(images)
imgh, imgw = img.shape[0:2] preds = model(images)
res_loc_final = [] post_result = post_process_class(preds, [shape_list])
for rno in range(len(res_loc[0])):
x0, y0, x1, y1 = res_loc[0][rno] structure_str_list = post_result['structure_batch_list'][0]
left = max(int(imgw * x0), 0) bbox_list = post_result['bbox_batch_list'][0]
top = max(int(imgh * y0), 0) structure_str_list = structure_str_list[0]
right = min(int(imgw * x1), imgw - 1) structure_str_list = [
bottom = min(int(imgh * y1), imgh - 1) '<html>', '<body>', '<table>'
cv2.rectangle(img, (left, top), (right, bottom), (0, 0, 255), 2) ] + structure_str_list + ['</table>', '</body>', '</html>']
res_loc_final.append([left, top, right, bottom]) bbox_list_str = json.dumps(bbox_list.tolist())
res_loc_str = json.dumps(res_loc_final)
logger.info("result: {}, {}".format(res_html_code, res_loc_final)) logger.info("result: {}, {}".format(structure_str_list,
logger.info("success!") bbox_list_str))
f_w.write("result: {}, {}\n".format(structure_str_list,
bbox_list_str))
img = draw_rectangle(file, bbox_list, use_xywh)
cv2.imwrite(
os.path.join(save_res_path, os.path.basename(file)), img)
logger.info("success!")
if __name__ == '__main__': if __name__ == '__main__':
......
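
The rewritten table-inference loop wraps the predicted structure tokens in a full HTML skeleton and serializes the box list as JSON before writing `infer.txt`. A minimal sketch with dummy predictions:

```python
import json
import numpy as np

# Dummy structure tokens and boxes standing in for post_result['structure_batch_list']
# and post_result['bbox_batch_list'].
structure_tokens = ["<tr>", "<td>", "</td>", "<td>", "</td>", "</tr>"]
bbox_array = np.array([[10, 12, 80, 40], [90, 12, 160, 40]], dtype=np.float32)

html_tokens = ["<html>", "<body>", "<table>"] + structure_tokens + \
    ["</table>", "</body>", "</html>"]
html_str = "".join(html_tokens)
bbox_json = json.dumps(bbox_array.tolist())

print(html_str)
print(bbox_json)
```
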
...@@ -44,6 +44,7 @@ def to_tensor(data): ...@@ -44,6 +44,7 @@ def to_tensor(data):
from collections import defaultdict from collections import defaultdict
data_dict = defaultdict(list) data_dict = defaultdict(list)
to_tensor_idxs = [] to_tensor_idxs = []
for idx, v in enumerate(data): for idx, v in enumerate(data):
if isinstance(v, (np.ndarray, paddle.Tensor, numbers.Number)): if isinstance(v, (np.ndarray, paddle.Tensor, numbers.Number)):
if idx not in to_tensor_idxs: if idx not in to_tensor_idxs:
...@@ -57,6 +58,7 @@ def to_tensor(data): ...@@ -57,6 +58,7 @@ def to_tensor(data):
class SerPredictor(object): class SerPredictor(object):
def __init__(self, config): def __init__(self, config):
global_config = config['Global'] global_config = config['Global']
self.algorithm = config['Architecture']["algorithm"]
# build post process # build post process
self.post_process_class = build_post_process(config['PostProcess'], self.post_process_class = build_post_process(config['PostProcess'],
...@@ -70,7 +72,10 @@ class SerPredictor(object): ...@@ -70,7 +72,10 @@ class SerPredictor(object):
from paddleocr import PaddleOCR from paddleocr import PaddleOCR
self.ocr_engine = PaddleOCR(use_angle_cls=False, show_log=False) self.ocr_engine = PaddleOCR(
use_angle_cls=False,
show_log=False,
use_gpu=global_config['use_gpu'])
# create data ops # create data ops
transforms = [] transforms = []
...@@ -80,29 +85,30 @@ class SerPredictor(object): ...@@ -80,29 +85,30 @@ class SerPredictor(object):
op[op_name]['ocr_engine'] = self.ocr_engine op[op_name]['ocr_engine'] = self.ocr_engine
elif op_name == 'KeepKeys': elif op_name == 'KeepKeys':
op[op_name]['keep_keys'] = [ op[op_name]['keep_keys'] = [
'input_ids', 'labels', 'bbox', 'image', 'attention_mask', 'input_ids', 'bbox', 'attention_mask', 'token_type_ids',
'token_type_ids', 'segment_offset_id', 'ocr_info', 'image', 'labels', 'segment_offset_id', 'ocr_info',
'entities' 'entities'
] ]
transforms.append(op) transforms.append(op)
global_config['infer_mode'] = True if config["Global"].get("infer_mode", None) is None:
global_config['infer_mode'] = True
self.ops = create_operators(config['Eval']['dataset']['transforms'], self.ops = create_operators(config['Eval']['dataset']['transforms'],
global_config) global_config)
self.model.eval() self.model.eval()
def __call__(self, img_path): def __call__(self, data):
with open(img_path, 'rb') as f: with open(data["img_path"], 'rb') as f:
img = f.read() img = f.read()
data = {'image': img} data["image"] = img
batch = transform(data, self.ops) batch = transform(data, self.ops)
batch = to_tensor(batch) batch = to_tensor(batch)
preds = self.model(batch) preds = self.model(batch)
if self.algorithm in ['LayoutLMv2', 'LayoutXLM']:
preds = preds[0]
post_result = self.post_process_class( post_result = self.post_process_class(
preds, preds, segment_offset_ids=batch[6], ocr_infos=batch[7])
attention_masks=batch[4],
segment_offset_ids=batch[6],
ocr_infos=batch[7])
return post_result, batch return post_result, batch
...@@ -112,20 +118,33 @@ if __name__ == '__main__': ...@@ -112,20 +118,33 @@ if __name__ == '__main__':
ser_engine = SerPredictor(config) ser_engine = SerPredictor(config)
infer_imgs = get_image_file_list(config['Global']['infer_img']) if config["Global"].get("infer_mode", None) is False:
data_dir = config['Eval']['dataset']['data_dir']
with open(config['Global']['infer_img'], "rb") as f:
infer_imgs = f.readlines()
else:
infer_imgs = get_image_file_list(config['Global']['infer_img'])
with open( with open(
os.path.join(config['Global']['save_res_path'], os.path.join(config['Global']['save_res_path'],
"infer_results.txt"), "infer_results.txt"),
"w", "w",
encoding='utf-8') as fout: encoding='utf-8') as fout:
for idx, img_path in enumerate(infer_imgs): for idx, info in enumerate(infer_imgs):
if config["Global"].get("infer_mode", None) is False:
data_line = info.decode('utf-8')
substr = data_line.strip("\n").split("\t")
img_path = os.path.join(data_dir, substr[0])
data = {'img_path': img_path, 'label': substr[1]}
else:
img_path = info
data = {'img_path': img_path}
save_img_path = os.path.join( save_img_path = os.path.join(
config['Global']['save_res_path'], config['Global']['save_res_path'],
os.path.splitext(os.path.basename(img_path))[0] + "_ser.jpg") os.path.splitext(os.path.basename(img_path))[0] + "_ser.jpg")
logger.info("process: [{}/{}], save result to {}".format(
idx, len(infer_imgs), save_img_path))
result, _ = ser_engine(img_path) result, _ = ser_engine(data)
result = result[0] result = result[0]
fout.write(img_path + "\t" + json.dumps( fout.write(img_path + "\t" + json.dumps(
{ {
...@@ -133,3 +152,6 @@ if __name__ == '__main__': ...@@ -133,3 +152,6 @@ if __name__ == '__main__':
}, ensure_ascii=False) + "\n") }, ensure_ascii=False) + "\n")
img_res = draw_ser_results(img_path, result) img_res = draw_ser_results(img_path, result)
cv2.imwrite(save_img_path, img_res) cv2.imwrite(save_img_path, img_res)
logger.info("process: [{}/{}], save result to {}".format(
idx, len(infer_imgs), save_img_path))
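
When `infer_mode` is False, each line of the annotation file is assumed to be `<relative_img_path>\t<label>` and is joined with `data_dir` before being handed to the SER engine as a dict. A parsing sketch; the file name and directory below are placeholders:

```python
import os


# Assumed annotation format when infer_mode is False: "<relative_img_path>\t<label>".
def build_ser_input(data_line, data_dir):
    substr = data_line.strip("\n").split("\t")
    return {"img_path": os.path.join(data_dir, substr[0]), "label": substr[1]}


line = "zh_val_42.jpg\t[]"  # dummy annotation line
print(build_ser_input(line, "./train_data/XFUND/zh_val/image"))  # placeholder data_dir
```
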
...@@ -38,7 +38,7 @@ from ppocr.utils.save_load import load_model ...@@ -38,7 +38,7 @@ from ppocr.utils.save_load import load_model
from ppocr.utils.visual import draw_re_results from ppocr.utils.visual import draw_re_results
from ppocr.utils.logging import get_logger from ppocr.utils.logging import get_logger
from ppocr.utils.utility import get_image_file_list, load_vqa_bio_label_maps, print_dict from ppocr.utils.utility import get_image_file_list, load_vqa_bio_label_maps, print_dict
from tools.program import ArgsParser, load_config, merge_config, check_gpu from tools.program import ArgsParser, load_config, merge_config
from tools.infer_vqa_token_ser import SerPredictor from tools.infer_vqa_token_ser import SerPredictor
...@@ -107,7 +107,7 @@ def make_input(ser_inputs, ser_results): ...@@ -107,7 +107,7 @@ def make_input(ser_inputs, ser_results):
# remove ocr_info segment_offset_id and label in ser input # remove ocr_info segment_offset_id and label in ser input
ser_inputs.pop(7) ser_inputs.pop(7)
ser_inputs.pop(6) ser_inputs.pop(6)
ser_inputs.pop(1) ser_inputs.pop(5)
return ser_inputs, entity_idx_dict_batch return ser_inputs, entity_idx_dict_batch
...@@ -131,9 +131,7 @@ class SerRePredictor(object): ...@@ -131,9 +131,7 @@ class SerRePredictor(object):
self.model.eval() self.model.eval()
def __call__(self, img_path): def __call__(self, img_path):
ser_results, ser_inputs = self.ser_engine(img_path) ser_results, ser_inputs = self.ser_engine({'img_path': img_path})
paddle.save(ser_inputs, 'ser_inputs.npy')
paddle.save(ser_results, 'ser_results.npy')
re_input, entity_idx_dict_batch = make_input(ser_inputs, ser_results) re_input, entity_idx_dict_batch = make_input(ser_inputs, ser_results)
preds = self.model(re_input) preds = self.model(re_input)
post_result = self.post_process_class( post_result = self.post_process_class(
...@@ -155,7 +153,6 @@ def preprocess(): ...@@ -155,7 +153,6 @@ def preprocess():
# check if set use_gpu=True in paddlepaddle cpu version # check if set use_gpu=True in paddlepaddle cpu version
use_gpu = config['Global']['use_gpu'] use_gpu = config['Global']['use_gpu']
check_gpu(use_gpu)
device = 'gpu:{}'.format(dist.ParallelEnv().dev_id) if use_gpu else 'cpu' device = 'gpu:{}'.format(dist.ParallelEnv().dev_id) if use_gpu else 'cpu'
device = paddle.set_device(device) device = paddle.set_device(device)
...@@ -185,9 +182,7 @@ if __name__ == '__main__': ...@@ -185,9 +182,7 @@ if __name__ == '__main__':
for idx, img_path in enumerate(infer_imgs): for idx, img_path in enumerate(infer_imgs):
save_img_path = os.path.join( save_img_path = os.path.join(
config['Global']['save_res_path'], config['Global']['save_res_path'],
os.path.splitext(os.path.basename(img_path))[0] + "_ser.jpg") os.path.splitext(os.path.basename(img_path))[0] + "_ser_re.jpg")
logger.info("process: [{}/{}], save result to {}".format(
idx, len(infer_imgs), save_img_path))
result = ser_re_engine(img_path) result = ser_re_engine(img_path)
result = result[0] result = result[0]
...@@ -197,3 +192,6 @@ if __name__ == '__main__': ...@@ -197,3 +192,6 @@ if __name__ == '__main__':
}, ensure_ascii=False) + "\n") }, ensure_ascii=False) + "\n")
img_res = draw_re_results(img_path, result) img_res = draw_re_results(img_path, result)
cv2.imwrite(save_img_path, img_res) cv2.imwrite(save_img_path, img_res)
logger.info("process: [{}/{}], save result to {}".format(
idx, len(infer_imgs), save_img_path))
...@@ -154,6 +154,24 @@ def check_xpu(use_xpu): ...@@ -154,6 +154,24 @@ def check_xpu(use_xpu):
except Exception as e: except Exception as e:
pass pass
def to_float32(preds):
if isinstance(preds, dict):
for k in preds:
if isinstance(preds[k], dict) or isinstance(preds[k], list):
preds[k] = to_float32(preds[k])
else:
preds[k] = preds[k].astype(paddle.float32)
elif isinstance(preds, list):
for k in range(len(preds)):
if isinstance(preds[k], dict):
preds[k] = to_float32(preds[k])
elif isinstance(preds[k], list):
preds[k] = to_float32(preds[k])
else:
preds[k] = preds[k].astype(paddle.float32)
else:
preds = preds.astype(paddle.float32)
return preds
def train(config, def train(config,
train_dataloader, train_dataloader,
...@@ -207,7 +225,7 @@ def train(config, ...@@ -207,7 +225,7 @@ def train(config,
model.train() model.train()
use_srn = config['Architecture']['algorithm'] == "SRN" use_srn = config['Architecture']['algorithm'] == "SRN"
extra_input_models = ["SRN", "NRTR", "SAR", "SEED", "SVTR"] extra_input_models = ["SRN", "NRTR", "SAR", "SEED", "SVTR", "SPIN"]
extra_input = False extra_input = False
if config['Architecture']['algorithm'] == 'Distillation': if config['Architecture']['algorithm'] == 'Distillation':
for key in config['Architecture']["Models"]: for key in config['Architecture']["Models"]:
...@@ -252,13 +270,19 @@ def train(config, ...@@ -252,13 +270,19 @@ def train(config,
# use amp # use amp
if scaler: if scaler:
with paddle.amp.auto_cast(): with paddle.amp.auto_cast(level='O2'):
if model_type == 'table' or extra_input: if model_type == 'table' or extra_input:
preds = model(images, data=batch[1:]) preds = model(images, data=batch[1:])
elif model_type in ["kie", 'vqa']: elif model_type in ["kie", 'vqa']:
preds = model(batch) preds = model(batch)
else: else:
preds = model(images) preds = model(images)
preds = to_float32(preds)
loss = loss_class(preds, batch)
avg_loss = loss['loss']
scaled_avg_loss = scaler.scale(avg_loss)
scaled_avg_loss.backward()
scaler.minimize(optimizer, scaled_avg_loss)
else: else:
if model_type == 'table' or extra_input: if model_type == 'table' or extra_input:
preds = model(images, data=batch[1:]) preds = model(images, data=batch[1:])
...@@ -266,23 +290,19 @@ def train(config, ...@@ -266,23 +290,19 @@ def train(config,
preds = model(batch) preds = model(batch)
else: else:
preds = model(images) preds = model(images)
loss = loss_class(preds, batch)
loss = loss_class(preds, batch) avg_loss = loss['loss']
avg_loss = loss['loss']
if scaler:
scaled_avg_loss = scaler.scale(avg_loss)
scaled_avg_loss.backward()
scaler.minimize(optimizer, scaled_avg_loss)
else:
avg_loss.backward() avg_loss.backward()
optimizer.step() optimizer.step()
optimizer.clear_grad() optimizer.clear_grad()
if cal_metric_during_train and epoch % calc_epoch_interval == 0: # only rec and cls need if cal_metric_during_train and epoch % calc_epoch_interval == 0: # only rec and cls need
batch = [item.numpy() for item in batch] batch = [item.numpy() for item in batch]
if model_type in ['table', 'kie']: if model_type in ['kie']:
eval_class(preds, batch) eval_class(preds, batch)
elif model_type in ['table']:
post_result = post_process_class(preds, batch)
eval_class(post_result, batch)
else: else:
if config['Loss']['name'] in ['MultiLoss', 'MultiLoss_v2' if config['Loss']['name'] in ['MultiLoss', 'MultiLoss_v2'
]: # for multi head loss ]: # for multi head loss
...@@ -463,7 +483,6 @@ def eval(model, ...@@ -463,7 +483,6 @@ def eval(model,
preds = model(batch) preds = model(batch)
else: else:
preds = model(images) preds = model(images)
batch_numpy = [] batch_numpy = []
for item in batch: for item in batch:
if isinstance(item, paddle.Tensor): if isinstance(item, paddle.Tensor):
...@@ -473,9 +492,9 @@ def eval(model, ...@@ -473,9 +492,9 @@ def eval(model,
# Obtain usable results from post-processing methods # Obtain usable results from post-processing methods
total_time += time.time() - start total_time += time.time() - start
# Evaluate the results of the current batch # Evaluate the results of the current batch
if model_type in ['table', 'kie']: if model_type in ['kie']:
eval_class(preds, batch_numpy) eval_class(preds, batch_numpy)
elif model_type in ['vqa']: elif model_type in ['table', 'vqa']:
post_result = post_process_class(preds, batch_numpy) post_result = post_process_class(preds, batch_numpy)
eval_class(post_result, batch_numpy) eval_class(post_result, batch_numpy)
else: else:
...@@ -576,8 +595,8 @@ def preprocess(is_train=False): ...@@ -576,8 +595,8 @@ def preprocess(is_train=False):
assert alg in [ assert alg in [
'EAST', 'DB', 'SAST', 'Rosetta', 'CRNN', 'STARNet', 'RARE', 'SRN', 'EAST', 'DB', 'SAST', 'Rosetta', 'CRNN', 'STARNet', 'RARE', 'SRN',
'CLS', 'PGNet', 'Distillation', 'NRTR', 'TableAttn', 'SAR', 'PSE', 'CLS', 'PGNet', 'Distillation', 'NRTR', 'TableAttn', 'SAR', 'PSE',
'SEED', 'SDMGR', 'LayoutXLM', 'LayoutLM', 'PREN', 'FCE', 'SVTR', 'SEED', 'SDMGR', 'LayoutXLM', 'LayoutLM', 'LayoutLMv2', 'PREN', 'FCE',
'ViTSTR', 'ABINet' 'SVTR', 'ViTSTR', 'ABINet', 'DB++', 'TableMaster', 'SPIN'
] ]
if use_xpu: if use_xpu:
......
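
The reworked AMP branch runs the forward pass under `auto_cast(level='O2')`, casts predictions back to float32 (the job of `to_float32`), and only then computes and scales the loss before backward. A minimal sketch of that step on a toy model with synthetic data:

```python
import paddle

# Toy model and data; shapes are illustrative only.
model = paddle.nn.Linear(16, 4)
optimizer = paddle.optimizer.Adam(parameters=model.parameters())
scaler = paddle.amp.GradScaler(init_loss_scaling=1024)

x = paddle.randn([8, 16])
y = paddle.randn([8, 4])

with paddle.amp.auto_cast(level='O2'):
    preds = model(x)
preds = preds.astype(paddle.float32)        # analogous to to_float32(preds) above
loss = paddle.nn.functional.mse_loss(preds, y)

scaled_loss = scaler.scale(loss)            # scale, backward, then let the scaler step
scaled_loss.backward()
scaler.minimize(optimizer, scaled_loss)
optimizer.clear_grad()
```
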
...@@ -119,9 +119,6 @@ def main(config, device, logger, vdl_writer): ...@@ -119,9 +119,6 @@ def main(config, device, logger, vdl_writer):
config['Loss']['ignore_index'] = char_num - 1 config['Loss']['ignore_index'] = char_num - 1
model = build_model(config['Architecture']) model = build_model(config['Architecture'])
if config['Global']['distributed']:
model = paddle.DataParallel(model)
model = apply_to_static(model, config, logger) model = apply_to_static(model, config, logger)
# build loss # build loss
...@@ -157,9 +154,13 @@ def main(config, device, logger, vdl_writer): ...@@ -157,9 +154,13 @@ def main(config, device, logger, vdl_writer):
scaler = paddle.amp.GradScaler( scaler = paddle.amp.GradScaler(
init_loss_scaling=scale_loss, init_loss_scaling=scale_loss,
use_dynamic_loss_scaling=use_dynamic_loss_scaling) use_dynamic_loss_scaling=use_dynamic_loss_scaling)
model, optimizer = paddle.amp.decorate(
models=model, optimizers=optimizer, level='O2', master_weight=True)
else: else:
scaler = None scaler = None
if config['Global']['distributed']:
model = paddle.DataParallel(model)
# start train # start train
program.train(config, train_dataloader, valid_dataloader, device, model, program.train(config, train_dataloader, valid_dataloader, device, model,
loss_class, optimizer, lr_scheduler, post_process_class, loss_class, optimizer, lr_scheduler, post_process_class,
......
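
The train.py change reorders setup so that `paddle.amp.decorate` wraps the raw model and optimizer first, and `DataParallel` is applied only afterwards for distributed runs. A sketch of that ordering on a toy model; the world-size check stands in for the config's `distributed` flag:

```python
import paddle
import paddle.distributed as dist

model = paddle.nn.Linear(16, 4)
optimizer = paddle.optimizer.Adam(parameters=model.parameters())

use_amp = True
if use_amp:
    scaler = paddle.amp.GradScaler(
        init_loss_scaling=1024, use_dynamic_loss_scaling=True)
    # Decorate the raw model/optimizer first (O2 keeps fp32 master weights).
    model, optimizer = paddle.amp.decorate(
        models=model, optimizers=optimizer, level='O2', master_weight=True)
else:
    scaler = None

# Wrap in DataParallel only when actually running distributed, and only after decorate.
if dist.get_world_size() > 1:
    dist.init_parallel_env()
    model = paddle.DataParallel(model)
```
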