Commit 41a1b292 authored by qq_25193841

Merge remote-tracking branch 'origin/dygraph' into dygraph

...@@ -61,7 +61,7 @@ from combobox import ComboBox ...@@ -61,7 +61,7 @@ from combobox import ComboBox
from libs.constants import * from libs.constants import *
from libs.utils import * from libs.utils import *
from libs.settings import Settings from libs.settings import Settings
from libs.shape import Shape, DEFAULT_LINE_COLOR, DEFAULT_FILL_COLOR from libs.shape import Shape, DEFAULT_LINE_COLOR, DEFAULT_FILL_COLOR,DEFAULT_LOCK_COLOR
from libs.stringBundle import StringBundle from libs.stringBundle import StringBundle
from libs.canvas import Canvas from libs.canvas import Canvas
from libs.zoomWidget import ZoomWidget from libs.zoomWidget import ZoomWidget
...@@ -101,6 +101,8 @@ class MainWindow(QMainWindow, WindowMixin): ...@@ -101,6 +101,8 @@ class MainWindow(QMainWindow, WindowMixin):
def __init__(self, lang="ch", gpu=False, defaultFilename=None, defaultPrefdefClassFile=None, defaultSaveDir=None): def __init__(self, lang="ch", gpu=False, defaultFilename=None, defaultPrefdefClassFile=None, defaultSaveDir=None):
super(MainWindow, self).__init__() super(MainWindow, self).__init__()
self.setWindowTitle(__appname__) self.setWindowTitle(__appname__)
self.setWindowState(Qt.WindowMaximized) # set window max
self.activateWindow() # PPOCRLabel goes to the front when activate
# Load setting in the main thread # Load setting in the main thread
self.settings = Settings() self.settings = Settings()
...@@ -126,7 +128,7 @@ class MainWindow(QMainWindow, WindowMixin): ...@@ -126,7 +128,7 @@ class MainWindow(QMainWindow, WindowMixin):
self.labelHist = [] self.labelHist = []
self.lastOpenDir = None self.lastOpenDir = None
self.result_dic = [] self.result_dic = []
self.result_dic_locked = []
self.changeFileFolder = False self.changeFileFolder = False
self.haveAutoReced = False self.haveAutoReced = False
self.labelFile = None self.labelFile = None
...@@ -178,7 +180,8 @@ class MainWindow(QMainWindow, WindowMixin): ...@@ -178,7 +180,8 @@ class MainWindow(QMainWindow, WindowMixin):
fileListContainer = QWidget() fileListContainer = QWidget()
fileListContainer.setLayout(filelistLayout) fileListContainer.setLayout(filelistLayout)
self.filedock = QDockWidget(getStr('fileList'), self) self.fileListName = getStr('fileList')
self.filedock = QDockWidget(self.fileListName, self)
self.filedock.setObjectName(getStr('files')) self.filedock.setObjectName(getStr('files'))
self.filedock.setWidget(fileListContainer) self.filedock.setWidget(fileListContainer)
self.addDockWidget(Qt.LeftDockWidgetArea, self.filedock) self.addDockWidget(Qt.LeftDockWidgetArea, self.filedock)
...@@ -394,7 +397,8 @@ class MainWindow(QMainWindow, WindowMixin): ...@@ -394,7 +397,8 @@ class MainWindow(QMainWindow, WindowMixin):
'w', 'objects', getStr('crtBoxDetail'), enabled=False) 'w', 'objects', getStr('crtBoxDetail'), enabled=False)
delete = action(getStr('delBox'), self.deleteSelectedShape, delete = action(getStr('delBox'), self.deleteSelectedShape,
'backspace', 'delete', getStr('delBoxDetail'), enabled=False) 'Alt+X', 'delete', getStr('delBoxDetail'), enabled=False)
copy = action(getStr('dupBox'), self.copySelectedShape, copy = action(getStr('dupBox'), self.copySelectedShape,
'Ctrl+C', 'copy', getStr('dupBoxDetail'), 'Ctrl+C', 'copy', getStr('dupBoxDetail'),
enabled=False) enabled=False)
...@@ -405,6 +409,7 @@ class MainWindow(QMainWindow, WindowMixin): ...@@ -405,6 +409,7 @@ class MainWindow(QMainWindow, WindowMixin):
showAll = action(getStr('showBox'), partial(self.togglePolygons, True), showAll = action(getStr('showBox'), partial(self.togglePolygons, True),
'Ctrl+A', 'hide', getStr('showAllBoxDetail'), 'Ctrl+A', 'hide', getStr('showAllBoxDetail'),
enabled=False) enabled=False)
help = action(getStr('tutorial'), self.showTutorialDialog, None, 'help', getStr('tutorialDetail')) help = action(getStr('tutorial'), self.showTutorialDialog, None, 'help', getStr('tutorialDetail'))
showInfo = action(getStr('info'), self.showInfoDialog, None, 'help', getStr('info')) showInfo = action(getStr('info'), self.showInfoDialog, None, 'help', getStr('info'))
...@@ -476,6 +481,10 @@ class MainWindow(QMainWindow, WindowMixin): ...@@ -476,6 +481,10 @@ class MainWindow(QMainWindow, WindowMixin):
undo = action(getStr("undo"), self.undoShapeEdit, undo = action(getStr("undo"), self.undoShapeEdit,
'Ctrl+Z', "undo", getStr("undo"), enabled=False) 'Ctrl+Z', "undo", getStr("undo"), enabled=False)
lock = action(getStr("lockBox"), self.lockSelectedShape,
None, "lock", getStr("lockBoxDetail"),
enabled=False)
self.editButton.setDefaultAction(edit) self.editButton.setDefaultAction(edit)
self.newButton.setDefaultAction(create) self.newButton.setDefaultAction(create)
...@@ -538,13 +547,13 @@ class MainWindow(QMainWindow, WindowMixin): ...@@ -538,13 +547,13 @@ class MainWindow(QMainWindow, WindowMixin):
fitWindow=fitWindow, fitWidth=fitWidth, fitWindow=fitWindow, fitWidth=fitWidth,
zoomActions=zoomActions, saveLabel=saveLabel, zoomActions=zoomActions, saveLabel=saveLabel,
undo=undo, undoLastPoint=undoLastPoint,open_dataset_dir=open_dataset_dir, undo=undo, undoLastPoint=undoLastPoint,open_dataset_dir=open_dataset_dir,
rotateLeft=rotateLeft,rotateRight=rotateRight, rotateLeft=rotateLeft,rotateRight=rotateRight,lock=lock,
fileMenuActions=( fileMenuActions=(
opendir, open_dataset_dir, saveLabel, resetAll, quit), opendir, open_dataset_dir, saveLabel, resetAll, quit),
beginner=(), advanced=(), beginner=(), advanced=(),
editMenu=(createpoly, edit, copy, delete,singleRere,None, undo, undoLastPoint, editMenu=(createpoly, edit, copy, delete,singleRere,None, undo, undoLastPoint,
None, rotateLeft, rotateRight, None, color1, self.drawSquaresOption), None, rotateLeft, rotateRight, None, color1, self.drawSquaresOption,lock),
beginnerContext=(create, edit, copy, delete, singleRere, rotateLeft, rotateRight,), beginnerContext=(create, edit, copy, delete, singleRere, rotateLeft, rotateRight,lock),
advancedContext=(createMode, editMode, edit, copy, advancedContext=(createMode, editMode, edit, copy,
delete, shapeLineColor, shapeFillColor), delete, shapeLineColor, shapeFillColor),
onLoadActive=( onLoadActive=(
...@@ -998,6 +1007,7 @@ class MainWindow(QMainWindow, WindowMixin): ...@@ -998,6 +1007,7 @@ class MainWindow(QMainWindow, WindowMixin):
self.actions.delete.setEnabled(n_selected) self.actions.delete.setEnabled(n_selected)
self.actions.copy.setEnabled(n_selected) self.actions.copy.setEnabled(n_selected)
self.actions.edit.setEnabled(n_selected == 1) self.actions.edit.setEnabled(n_selected == 1)
self.actions.lock.setEnabled(n_selected)
def addLabel(self, shape): def addLabel(self, shape):
shape.paintLabel = self.displayLabelOption.isChecked() shape.paintLabel = self.displayLabelOption.isChecked()
...@@ -1041,7 +1051,7 @@ class MainWindow(QMainWindow, WindowMixin): ...@@ -1041,7 +1051,7 @@ class MainWindow(QMainWindow, WindowMixin):
def loadLabels(self, shapes): def loadLabels(self, shapes):
s = [] s = []
for label, points, line_color, fill_color, difficult in shapes: for label, points, line_color, fill_color, difficult in shapes:
shape = Shape(label=label) shape = Shape(label=label,line_color=line_color)
for x, y in points: for x, y in points:
# Ensure the labels are within the bounds of the image. If not, fix them. # Ensure the labels are within the bounds of the image. If not, fix them.
...@@ -1051,6 +1061,7 @@ class MainWindow(QMainWindow, WindowMixin): ...@@ -1051,6 +1061,7 @@ class MainWindow(QMainWindow, WindowMixin):
shape.addPoint(QPointF(x, y)) shape.addPoint(QPointF(x, y))
shape.difficult = difficult shape.difficult = difficult
#shape.locked = False
shape.close() shape.close()
s.append(shape) s.append(shape)
...@@ -1063,10 +1074,12 @@ class MainWindow(QMainWindow, WindowMixin): ...@@ -1063,10 +1074,12 @@ class MainWindow(QMainWindow, WindowMixin):
# shape.fill_color = QColor(*fill_color) # shape.fill_color = QColor(*fill_color)
# else: # else:
# shape.fill_color = generateColorByText(label) # shape.fill_color = generateColorByText(label)
self.addLabel(shape) self.addLabel(shape)
self.updateComboBox() self.updateComboBox()
self.canvas.loadShapes(s) self.canvas.loadShapes(s)
def singleLabel(self, shape): def singleLabel(self, shape):
if shape is None: if shape is None:
...@@ -1106,10 +1119,9 @@ class MainWindow(QMainWindow, WindowMixin): ...@@ -1106,10 +1119,9 @@ class MainWindow(QMainWindow, WindowMixin):
difficult=s.difficult) # bool difficult=s.difficult) # bool
shapes = [] if mode == 'Auto' else \ shapes = [] if mode == 'Auto' else \
[format_shape(shape) for shape in self.canvas.shapes] [format_shape(shape) for shape in self.canvas.shapes if shape.line_color != DEFAULT_LOCK_COLOR]
# Can add differrent annotation formats here # Can add differrent annotation formats here
for box in self.result_dic :
for box in self.result_dic:
trans_dic = {"label": box[1][0], "points": box[0], 'difficult': False} trans_dic = {"label": box[1][0], "points": box[0], 'difficult': False}
if trans_dic["label"] == "" and mode == 'Auto': if trans_dic["label"] == "" and mode == 'Auto':
continue continue
...@@ -1120,7 +1132,6 @@ class MainWindow(QMainWindow, WindowMixin): ...@@ -1120,7 +1132,6 @@ class MainWindow(QMainWindow, WindowMixin):
for box in shapes: for box in shapes:
trans_dic.append({"transcription": box['label'], "points": box['points'], 'difficult': box['difficult']}) trans_dic.append({"transcription": box['label'], "points": box['points'], 'difficult': box['difficult']})
self.PPlabel[annotationFilePath] = trans_dic self.PPlabel[annotationFilePath] = trans_dic
if mode == 'Auto': if mode == 'Auto':
self.Cachelabel[annotationFilePath] = trans_dic self.Cachelabel[annotationFilePath] = trans_dic
...@@ -1313,6 +1324,7 @@ class MainWindow(QMainWindow, WindowMixin): ...@@ -1313,6 +1324,7 @@ class MainWindow(QMainWindow, WindowMixin):
# unicodeFilePath = os.path.abspath(unicodeFilePath) # unicodeFilePath = os.path.abspath(unicodeFilePath)
# Tzutalin 20160906 : Add file list and dock to move faster # Tzutalin 20160906 : Add file list and dock to move faster
# Highlight the file item # Highlight the file item
if unicodeFilePath and self.fileListWidget.count() > 0: if unicodeFilePath and self.fileListWidget.count() > 0:
if unicodeFilePath in self.mImgList: if unicodeFilePath in self.mImgList:
index = self.mImgList.index(unicodeFilePath) index = self.mImgList.index(unicodeFilePath)
...@@ -1322,6 +1334,7 @@ class MainWindow(QMainWindow, WindowMixin): ...@@ -1322,6 +1334,7 @@ class MainWindow(QMainWindow, WindowMixin):
### ###
self.iconlist.clear() self.iconlist.clear()
self.additems5(None) self.additems5(None)
for i in range(5): for i in range(5):
item_tooltip = self.iconlist.item(i).toolTip() item_tooltip = self.iconlist.item(i).toolTip()
# print(i,"---",item_tooltip) # print(i,"---",item_tooltip)
...@@ -1340,7 +1353,6 @@ class MainWindow(QMainWindow, WindowMixin): ...@@ -1340,7 +1353,6 @@ class MainWindow(QMainWindow, WindowMixin):
if unicodeFilePath and os.path.exists(unicodeFilePath): if unicodeFilePath and os.path.exists(unicodeFilePath):
self.canvas.verified = False self.canvas.verified = False
cvimg = cv2.imdecode(np.fromfile(unicodeFilePath, dtype=np.uint8), 1) cvimg = cv2.imdecode(np.fromfile(unicodeFilePath, dtype=np.uint8), 1)
height, width, depth = cvimg.shape height, width, depth = cvimg.shape
cvimg = cv2.cvtColor(cvimg, cv2.COLOR_BGR2RGB) cvimg = cv2.cvtColor(cvimg, cv2.COLOR_BGR2RGB)
...@@ -1361,34 +1373,52 @@ class MainWindow(QMainWindow, WindowMixin): ...@@ -1361,34 +1373,52 @@ class MainWindow(QMainWindow, WindowMixin):
else: else:
self.dirty = False self.dirty = False
self.actions.save.setEnabled(True) self.actions.save.setEnabled(True)
if len(self.canvas.lockedShapes) != 0:
self.actions.save.setEnabled(True)
self.setDirty()
self.canvas.setEnabled(True) self.canvas.setEnabled(True)
self.adjustScale(initial=True) self.adjustScale(initial=True)
self.paintCanvas() self.paintCanvas()
self.addRecentFile(self.filePath) self.addRecentFile(self.filePath)
self.toggleActions(True) self.toggleActions(True)
self.showBoundingBoxFromPPlabel(filePath) self.showBoundingBoxFromPPlabel(filePath)
self.setWindowTitle(__appname__ + ' ' + filePath) self.setWindowTitle(__appname__ + ' ' + filePath)
# Default : select last item if there is at least one item # Default : select last item if there is at least one item
if self.labelList.count(): if self.labelList.count():
self.labelList.setCurrentItem(self.labelList.item(self.labelList.count() - 1)) self.labelList.setCurrentItem(self.labelList.item(self.labelList.count() - 1))
self.labelList.item(self.labelList.count() - 1).setSelected(True) self.labelList.item(self.labelList.count() - 1).setSelected(True)
# show file list image count
select_indexes = self.fileListWidget.selectedIndexes()
if len(select_indexes) > 0:
self.filedock.setWindowTitle(self.fileListName + f" ({select_indexes[0].row() + 1}"
f"/{self.fileListWidget.count()})")
self.canvas.setFocus(True) self.canvas.setFocus(True)
return True return True
return False return False
def showBoundingBoxFromPPlabel(self, filePath): def showBoundingBoxFromPPlabel(self, filePath):
width, height = self.image.width(), self.image.height()
imgidx = self.getImglabelidx(filePath) imgidx = self.getImglabelidx(filePath)
if imgidx not in self.PPlabel.keys(): shapes =[]
return #box['ratio'] of the shapes saved in lockedShapes contains the ratio of the
shapes = [] # four corner coordinates of the shapes to the height and width of the image
for box in self.PPlabel[imgidx]: for box in self.canvas.lockedShapes:
shapes.append((box['transcription'], box['points'], None, None, box['difficult'])) if self.canvas.isInTheSameImage:
shapes.append((box['transcription'], [[s[0]*width,s[1]*height]for s in box['ratio']],
DEFAULT_LOCK_COLOR, None, box['difficult']))
else:
shapes.append(('锁定框:待检测', [[s[0]*width,s[1]*height]for s in box['ratio']],
DEFAULT_LOCK_COLOR, None, box['difficult']))
if imgidx in self.PPlabel.keys():
for box in self.PPlabel[imgidx]:
shapes.append((box['transcription'], box['points'], None, None, box['difficult']))
self.loadLabels(shapes) self.loadLabels(shapes)
self.canvas.verified = False self.canvas.verified = False
...@@ -1576,7 +1606,8 @@ class MainWindow(QMainWindow, WindowMixin): ...@@ -1576,7 +1606,8 @@ class MainWindow(QMainWindow, WindowMixin):
self.actions.rotateLeft.setEnabled(True) self.actions.rotateLeft.setEnabled(True)
self.actions.rotateRight.setEnabled(True) self.actions.rotateRight.setEnabled(True)
self.fileListWidget.setCurrentRow(0) # set list index to first
self.filedock.setWindowTitle(self.fileListName + f" (1/{self.fileListWidget.count()})") # show image count
def openPrevImg(self, _value=False): def openPrevImg(self, _value=False):
if len(self.mImgList) <= 0: if len(self.mImgList) <= 0:
...@@ -1646,9 +1677,37 @@ class MainWindow(QMainWindow, WindowMixin): ...@@ -1646,9 +1677,37 @@ class MainWindow(QMainWindow, WindowMixin):
else: else:
return fullFilePath return fullFilePath
return '' return ''
def saveLockedShapes(self):
self.canvas.lockedShapes = []
self.canvas.selectedShapes = []
for s in self.canvas.shapes:
if s.line_color == DEFAULT_LOCK_COLOR:
self.canvas.selectedShapes.append(s)
self.lockSelectedShape()
for s in self.canvas.shapes:
if s.line_color == DEFAULT_LOCK_COLOR:
self.canvas.selectedShapes.remove(s)
self.canvas.shapes.remove(s)
def _saveFile(self, annotationFilePath, mode='Manual'): def _saveFile(self, annotationFilePath, mode='Manual'):
if len(self.canvas.lockedShapes) != 0:
self.saveLockedShapes()
if mode == 'Manual': if mode == 'Manual':
self.result_dic_locked = []
img = cv2.imread(self.filePath)
width, height = self.image.width(), self.image.height()
for shape in self.canvas.lockedShapes:
box = [[int(p[0]*width), int(p[1]*height)] for p in shape['ratio']]
assert len(box) == 4
result = [(shape['transcription'],1)]
result.insert(0, box)
self.result_dic_locked.append(result)
self.result_dic += self.result_dic_locked
self.result_dic_locked = []
if annotationFilePath and self.saveLabels(annotationFilePath, mode=mode): if annotationFilePath and self.saveLabels(annotationFilePath, mode=mode):
self.setClean() self.setClean()
self.statusBar().showMessage('Saved to %s' % annotationFilePath) self.statusBar().showMessage('Saved to %s' % annotationFilePath)
...@@ -1663,13 +1722,13 @@ class MainWindow(QMainWindow, WindowMixin): ...@@ -1663,13 +1722,13 @@ class MainWindow(QMainWindow, WindowMixin):
self.savePPlabel(mode='Auto') self.savePPlabel(mode='Auto')
self.fileListWidget.insertItem(int(currIndex), item) self.fileListWidget.insertItem(int(currIndex), item)
self.openNextImg() if not self.canvas.isInTheSameImage:
self.openNextImg()
self.actions.saveRec.setEnabled(True) self.actions.saveRec.setEnabled(True)
self.actions.saveLabel.setEnabled(True) self.actions.saveLabel.setEnabled(True)
elif mode == 'Auto': elif mode == 'Auto':
if annotationFilePath and self.saveLabels(annotationFilePath, mode=mode): if annotationFilePath and self.saveLabels(annotationFilePath, mode=mode):
self.setClean() self.setClean()
self.statusBar().showMessage('Saved to %s' % annotationFilePath) self.statusBar().showMessage('Saved to %s' % annotationFilePath)
self.statusBar().show() self.statusBar().show()
...@@ -1733,14 +1792,19 @@ class MainWindow(QMainWindow, WindowMixin): ...@@ -1733,14 +1792,19 @@ class MainWindow(QMainWindow, WindowMixin):
if discardChanges == QMessageBox.No: if discardChanges == QMessageBox.No:
return True return True
elif discardChanges == QMessageBox.Yes: elif discardChanges == QMessageBox.Yes:
self.canvas.isInTheSameImage = True
self.saveFile() self.saveFile()
self.canvas.isInTheSameImage = False
return True return True
else: else:
return False return False
def discardChangesDialog(self): def discardChangesDialog(self):
yes, no, cancel = QMessageBox.Yes, QMessageBox.No, QMessageBox.Cancel yes, no, cancel = QMessageBox.Yes, QMessageBox.No, QMessageBox.Cancel
msg = u'You have unsaved changes, would you like to save them and proceed?\nClick "No" to undo all changes.' if self.lang == 'ch':
msg = u'您有未保存的变更, 您想保存再继续吗?\n点击 "No" 丢弃所有未保存的变更.'
else:
msg = u'You have unsaved changes, would you like to save them and proceed?\nClick "No" to undo all changes.'
return QMessageBox.warning(self, u'Attention', msg, yes | no | cancel) return QMessageBox.warning(self, u'Attention', msg, yes | no | cancel)
def errorMessage(self, title, message): def errorMessage(self, title, message):
...@@ -1858,7 +1922,7 @@ class MainWindow(QMainWindow, WindowMixin): ...@@ -1858,7 +1922,7 @@ class MainWindow(QMainWindow, WindowMixin):
uncheckedList = [i for i in self.mImgList if i not in self.fileStatedict.keys()] uncheckedList = [i for i in self.mImgList if i not in self.fileStatedict.keys()]
self.autoDialog = AutoDialog(parent=self, ocr=self.ocr, mImgList=uncheckedList, lenbar=len(uncheckedList)) self.autoDialog = AutoDialog(parent=self, ocr=self.ocr, mImgList=uncheckedList, lenbar=len(uncheckedList))
self.autoDialog.popUp() self.autoDialog.popUp()
self.currIndex=len(self.mImgList) self.currIndex = len(self.mImgList) - 1
self.loadFile(self.filePath) # ADD self.loadFile(self.filePath) # ADD
self.haveAutoReced = True self.haveAutoReced = True
self.AutoRecognition.setEnabled(False) self.AutoRecognition.setEnabled(False)
...@@ -1872,6 +1936,7 @@ class MainWindow(QMainWindow, WindowMixin): ...@@ -1872,6 +1936,7 @@ class MainWindow(QMainWindow, WindowMixin):
# org_box = [dic['points'] for dic in self.PPlabel[self.getImglabelidx(self.filePath)]] # org_box = [dic['points'] for dic in self.PPlabel[self.getImglabelidx(self.filePath)]]
if self.canvas.shapes: if self.canvas.shapes:
self.result_dic = [] self.result_dic = []
self.result_dic_locked = [] # result_dic_locked stores the ocr result of self.canvas.lockedShapes
rec_flag = 0 rec_flag = 0
for shape in self.canvas.shapes: for shape in self.canvas.shapes:
box = [[int(p.x()), int(p.y())] for p in shape.points] box = [[int(p.x()), int(p.y())] for p in shape.points]
...@@ -1883,21 +1948,32 @@ class MainWindow(QMainWindow, WindowMixin): ...@@ -1883,21 +1948,32 @@ class MainWindow(QMainWindow, WindowMixin):
return return
result = self.ocr.ocr(img_crop, cls=True, det=False) result = self.ocr.ocr(img_crop, cls=True, det=False)
if result[0][0] != '': if result[0][0] != '':
result.insert(0, box) if shape.line_color == DEFAULT_LOCK_COLOR:
print('result in reRec is ', result) shape.label = result[0][0]
self.result_dic.append(result) result.insert(0, box)
self.result_dic_locked.append(result)
else:
result.insert(0, box)
self.result_dic.append(result)
else: else:
print('Can not recognise the box') print('Can not recognise the box')
self.result_dic.append([box,(self.noLabelText,0)]) if shape.line_color == DEFAULT_LOCK_COLOR:
shape.label = result[0][0]
if self.noLabelText == shape.label or result[1][0] == shape.label: self.result_dic_locked.append([box,(self.noLabelText,0)])
print('label no change') else:
else: self.result_dic.append([box,(self.noLabelText,0)])
rec_flag += 1 try:
if self.noLabelText == shape.label or result[1][0] == shape.label:
if len(self.result_dic) > 0 and rec_flag > 0: print('label no change')
else:
rec_flag += 1
except IndexError as e:
print('Can not recognise the box')
if (len(self.result_dic) > 0 and rec_flag > 0)or self.canvas.lockedShapes:
self.canvas.isInTheSameImage = True
self.saveFile(mode='Auto') self.saveFile(mode='Auto')
self.loadFile(self.filePath) self.loadFile(self.filePath)
self.canvas.isInTheSameImage = False
self.setDirty() self.setDirty()
elif len(self.result_dic) == len(self.canvas.shapes) and rec_flag == 0: elif len(self.result_dic) == len(self.canvas.shapes) and rec_flag == 0:
QMessageBox.information(self, "Information", "The recognition result remains unchanged!") QMessageBox.information(self, "Information", "The recognition result remains unchanged!")
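For readers skimming this hunk: every entry that reRecognition appends to `self.result_dic` or `self.result_dic_locked` has the shape `[box, (transcription, confidence)]`, with the four corner points prepended via `result.insert(0, box)`. Below is a minimal, self-contained sketch of that bookkeeping; the names `partition_results`, `LOCK_COLOR` and the toy `ocr_fn` are illustrative stand-ins, not part of the codebase.

```python
# Minimal sketch (not the project code): how reRecognition-style bookkeeping splits
# OCR results between ordinary and locked boxes. Each entry has the shape
# [box, (transcription, confidence)], where box is a list of four [x, y] points.
LOCK_COLOR = "lock"  # stand-in for DEFAULT_LOCK_COLOR

def partition_results(shapes, ocr_fn, no_label_text="###"):
    """shapes: iterable of (box, line_color); ocr_fn: box -> (text, score) or None."""
    result_dic, result_dic_locked = [], []
    for box, line_color in shapes:
        rec = ocr_fn(box)                    # pretend recognition of the cropped box
        text, score = rec if rec else (no_label_text, 0)
        entry = [box, (text, score)]         # same layout as result.insert(0, box)
        if line_color == LOCK_COLOR:
            result_dic_locked.append(entry)  # locked boxes are collected separately
        else:
            result_dic.append(entry)
    return result_dic, result_dic_locked

# Toy usage
boxes = [([[0, 0], [10, 0], [10, 5], [0, 5]], "lock"),
         ([[0, 6], [10, 6], [10, 12], [0, 12]], "normal")]
print(partition_results(boxes, lambda b: ("hello", 0.9)))
```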
...@@ -2027,8 +2103,11 @@ class MainWindow(QMainWindow, WindowMixin): ...@@ -2027,8 +2103,11 @@ class MainWindow(QMainWindow, WindowMixin):
f.write(key + '\t') f.write(key + '\t')
f.write(json.dumps(self.PPlabel[key], ensure_ascii=False) + '\n') f.write(json.dumps(self.PPlabel[key], ensure_ascii=False) + '\n')
if mode=='Manual': if mode == 'Manual':
msg = 'Images that have been checked are saved in '+ self.PPlabelpath if self.lang == 'ch':
msg = '已将检查过的图片标签保存在 ' + self.PPlabelpath + " 文件中"
else:
msg = 'Images that have been checked are saved in ' + self.PPlabelpath
QMessageBox.information(self, "Information", msg) QMessageBox.information(self, "Information", msg)
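The savePPlabel loop above writes one line per image in the form `<image path>\t<JSON list of boxes>`. A small sketch of producing and parsing such a line; the path and box values are made up for illustration.

```python
# Sketch of the Label.txt line format produced above: one image per line,
# "<image path>\t<JSON list of boxes>". File name and contents are illustrative.
import json

entry = [{"transcription": "Hello",
          "points": [[10, 20], [110, 20], [110, 50], [10, 50]],
          "difficult": False}]
line = "train/img_1.jpg" + "\t" + json.dumps(entry, ensure_ascii=False) + "\n"

# Parsing it back is the mirror operation.
path, boxes = line.rstrip("\n").split("\t", 1)
boxes = json.loads(boxes)
assert boxes[0]["transcription"] == "Hello"
print(path, boxes)
```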
def saveCacheLabel(self): def saveCacheLabel(self):
...@@ -2107,6 +2186,44 @@ class MainWindow(QMainWindow, WindowMixin): ...@@ -2107,6 +2186,44 @@ class MainWindow(QMainWindow, WindowMixin):
self.labelList.clearSelection() self.labelList.clearSelection()
self._noSelectionSlot = False self._noSelectionSlot = False
self.canvas.loadShapes(shapes, replace=replace) self.canvas.loadShapes(shapes, replace=replace)
print("loadShapes")#1
def lockSelectedShape(self):
"""lock the selsected shapes.
Add self.selectedShapes to lock self.canvas.lockedShapes,
which holds the ratio of the four coordinates of the locked shapes
to the width and height of the image
"""
width, height = self.image.width(), self.image.height()
def format_shape(s):
return dict(label=s.label, # str
line_color=s.line_color.getRgb(),
fill_color=s.fill_color.getRgb(),
ratio=[[int(p.x())/width, int(p.y())/height] for p in s.points], # QPointF
# add chris
difficult=s.difficult) # bool
#lock
if len(self.canvas.lockedShapes) == 0:
for s in self.canvas.selectedShapes:
s.line_color = DEFAULT_LOCK_COLOR
s.locked = True
shapes = [format_shape(shape) for shape in self.canvas.selectedShapes]
trans_dic = []
for box in shapes:
trans_dic.append({"transcription": box['label'], "ratio": box['ratio'], 'difficult': box['difficult']})
self.canvas.lockedShapes = trans_dic
self.actions.save.setEnabled(True)
#unlock
else:
for s in self.canvas.shapes:
s.line_color = DEFAULT_LINE_COLOR
self.canvas.lockedShapes = []
self.result_dic_locked = []
self.setDirty()
self.actions.save.setEnabled(True)
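As the docstring above notes, `lockedShapes` stores each corner as a fraction of the image width and height so the box can be re-projected when the image is reloaded (see `showBoundingBoxFromPPlabel`). A minimal sketch of that round trip, using plain lists instead of `QPointF` for brevity:

```python
# Sketch of the ratio bookkeeping the lock feature relies on: corner coordinates are
# stored as fractions of the image size and re-projected onto the image on reload.
def points_to_ratio(points, width, height):
    return [[int(x) / width, int(y) / height] for x, y in points]

def ratio_to_points(ratio, width, height):
    return [[r[0] * width, r[1] * height] for r in ratio]

w, h = 800, 600
pts = [[40, 30], [200, 30], [200, 90], [40, 90]]
ratio = points_to_ratio(pts, w, h)        # what lockSelectedShape stores in lockedShapes
restored = ratio_to_points(ratio, w, h)   # what showBoundingBoxFromPPlabel reconstructs
print(ratio, restored)
```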
def inverted(color): def inverted(color):
......
...@@ -143,7 +143,7 @@ python PPOCRLabel.py
### 3.1 Shortcut keys
| Shortcut keys | Description |
|--------------------------| ------------------------------------------------ |
| Ctrl + Shift + R | Re-recognize all the labels of the current image |
| W | Create a rect box |
| Q | Create a four-points box |
...@@ -151,7 +151,7 @@ python PPOCRLabel.py
| Ctrl + R | Re-recognize the selected box |
| Ctrl + C | Copy and paste the selected box |
| Ctrl + Left Mouse Button | Multi select the label box |
| Ctrl + X | Delete the selected box |
| Ctrl + V | Check image |
| Ctrl + Shift + d | Delete image |
| D | Next image |
......
...@@ -131,16 +131,16 @@ python PPOCRLabel.py --lang ch
### 3.1 快捷键
| 快捷键 | 说明 |
|------------------| ---------------------------- |
| Ctrl + shift + R | 对当前图片的所有标记重新识别 |
| W | 新建矩形框 |
| Q | 新建四点框 |
| Ctrl + E | 编辑所选框标签 |
| Ctrl + R | 重新识别所选标记 |
| Ctrl + C | 复制并粘贴选中的标记框 |
| Ctrl + 鼠标左键 | 多选标记框 |
| Ctrl + X | 删除所选框 |
| Ctrl + V | 确认本张图片标记 |
| Ctrl + Shift + d | 删除本张图片 |
| D | 下一张图片 |
......
...@@ -6,6 +6,8 @@ except ImportError: ...@@ -6,6 +6,8 @@ except ImportError:
from PyQt4.QtGui import * from PyQt4.QtGui import *
from PyQt4.QtCore import * from PyQt4.QtCore import *
import time
import datetime
import json import json
import cv2 import cv2
import numpy as np import numpy as np
...@@ -80,8 +82,9 @@ class AutoDialog(QDialog): ...@@ -80,8 +82,9 @@ class AutoDialog(QDialog):
self.parent = parent self.parent = parent
self.ocr = ocr self.ocr = ocr
self.mImgList = mImgList self.mImgList = mImgList
self.lender = lenbar
self.pb = QProgressBar() self.pb = QProgressBar()
self.pb.setRange(0, lenbar) self.pb.setRange(0, self.lender)
self.pb.setValue(0) self.pb.setValue(0)
layout = QVBoxLayout() layout = QVBoxLayout()
...@@ -108,10 +111,16 @@ class AutoDialog(QDialog): ...@@ -108,10 +111,16 @@ class AutoDialog(QDialog):
self.thread_1.progressBarValue.connect(self.handleProgressBarSingal) self.thread_1.progressBarValue.connect(self.handleProgressBarSingal)
self.thread_1.listValue.connect(self.handleListWidgetSingal) self.thread_1.listValue.connect(self.handleListWidgetSingal)
self.thread_1.endsignal.connect(self.handleEndsignalSignal) self.thread_1.endsignal.connect(self.handleEndsignalSignal)
self.time_start = time.time() # save start time
def handleProgressBarSingal(self, i): def handleProgressBarSingal(self, i):
self.pb.setValue(i) self.pb.setValue(i)
# calculate time left of auto labeling
avg_time = (time.time() - self.time_start) / i # Use average time to prevent time fluctuations
time_left = str(datetime.timedelta(seconds=avg_time * (self.lender - i))).split(".")[0] # Remove microseconds
self.setWindowTitle("PPOCRLabel -- " + f"Time Left: {time_left}") # show
def handleListWidgetSingal(self, i): def handleListWidgetSingal(self, i):
self.listWidget.addItem(i) self.listWidget.addItem(i)
titem = self.listWidget.item(self.listWidget.count() - 1) titem = self.listWidget.item(self.listWidget.count() - 1)
......
...@@ -87,6 +87,10 @@ class Canvas(QWidget): ...@@ -87,6 +87,10 @@ class Canvas(QWidget):
#initialisation for panning #initialisation for panning
self.pan_initial_pos = QPoint() self.pan_initial_pos = QPoint()
#lockedshapes related
self.lockedShapes = []
self.isInTheSameImage = False
def setDrawingColor(self, qColor): def setDrawingColor(self, qColor):
self.drawingLineColor = qColor self.drawingLineColor = qColor
self.drawingRectColor = qColor self.drawingRectColor = qColor
......
...@@ -30,6 +30,7 @@ DEFAULT_SELECT_LINE_COLOR = QColor(255, 255, 255) ...@@ -30,6 +30,7 @@ DEFAULT_SELECT_LINE_COLOR = QColor(255, 255, 255)
DEFAULT_SELECT_FILL_COLOR = QColor(0, 128, 255, 155) DEFAULT_SELECT_FILL_COLOR = QColor(0, 128, 255, 155)
DEFAULT_VERTEX_FILL_COLOR = QColor(0, 255, 0, 255) DEFAULT_VERTEX_FILL_COLOR = QColor(0, 255, 0, 255)
DEFAULT_HVERTEX_FILL_COLOR = QColor(255, 0, 0) DEFAULT_HVERTEX_FILL_COLOR = QColor(255, 0, 0)
DEFAULT_LOCK_COLOR = QColor(255, 0, 255)
MIN_Y_LABEL = 10 MIN_Y_LABEL = 10
...@@ -57,7 +58,7 @@ class Shape(object): ...@@ -57,7 +58,7 @@ class Shape(object):
self.selected = False self.selected = False
self.difficult = difficult self.difficult = difficult
self.paintLabel = paintLabel self.paintLabel = paintLabel
self.locked = False
self._highlightIndex = None self._highlightIndex = None
self._highlightMode = self.NEAR_VERTEX self._highlightMode = self.NEAR_VERTEX
self._highlightSettings = { self._highlightSettings = {
......
...@@ -104,4 +104,6 @@ singleRe=Re-recognition RectBox ...@@ -104,4 +104,6 @@ singleRe=Re-recognition RectBox
labelDialogOption=Pop-up Label Input Dialog labelDialogOption=Pop-up Label Input Dialog
undo=Undo undo=Undo
undoLastPoint=Undo Last Point undoLastPoint=Undo Last Point
autoSaveMode=Auto Export Label Mode autoSaveMode=Auto Export Label Mode
\ No newline at end of file lockBox=Lock selected box/Unlock all box
lockBoxDetail=Lock selected box/Unlock all box
\ No newline at end of file
...@@ -104,4 +104,6 @@ singleRe=重识别此区块 ...@@ -104,4 +104,6 @@ singleRe=重识别此区块
labelDialogOption=弹出标记输入框 labelDialogOption=弹出标记输入框
undo=撤销 undo=撤销
undoLastPoint=撤销上个点 undoLastPoint=撤销上个点
autoSaveMode=自动导出标记结果 autoSaveMode=自动导出标记结果
\ No newline at end of file lockBox=锁定框/解除锁定框
lockBoxDetail=若当前没有框处于锁定状态则锁定选中的框,若存在锁定框则解除所有锁定框的锁定状态
...@@ -33,17 +33,17 @@ PaddleOCR aims to create multilingual, awesome, leading, and practical OCR tools
- [more](./doc/doc_en/update_en.md)
## Features
- PP-OCR - A series of high-quality pre-trained models, comparable to commercial products
- Ultra lightweight PP-OCRv2 series models: detection (3.1M) + direction classifier (1.4M) + recognition (8.5M) = 13.0M
- Ultra lightweight PP-OCR mobile series models: detection (3.0M) + direction classifier (1.4M) + recognition (5.0M) = 9.4M
- General PP-OCR server series models: detection (47.1M) + direction classifier (1.4M) + recognition (94.9M) = 143.4M
- Support Chinese, English, and digit recognition, vertical text recognition, and long text recognition
- Support multi-lingual recognition: about 80 languages like Korean, Japanese, German, French, etc
- PP-Structure: a document structurize system
- Support layout analysis and table recognition (support export to Excel)
- Support key information extraction
- Support DocVQA
- Rich OCR toolkit
- Semi-automatic data annotation tool, i.e., PPOCRLabel: support fast and efficient data annotation
- Data synthesis tool, i.e., Style-Text: easy to synthesize a large number of images which are similar to the target scene image
- Support user-defined training, provides rich predictive inference deployment solutions
...@@ -62,7 +62,7 @@ The above pictures are the visualizations of the general ppocr_server model. For
<a name="Community"></a>
## Community
- Scan the QR code below with your Wechat, you can join the official technical discussion group. Looking forward to your participation.
<div align="center">
<img src="https://raw.githubusercontent.com/PaddlePaddle/PaddleOCR/dygraph/doc/joinus.PNG" width = "200" height = "200" />
...@@ -120,8 +120,8 @@ For a new language request, please refer to [Guideline for new language_requests
- [PP-Structure: Information Extraction](./ppstructure/README.md)
- [Layout Parser](./ppstructure/layout/README.md)
- [Table Recognition](./ppstructure/table/README.md)
- [DocVQA](./ppstructure/vqa/README.md)
- [Key Information Extraction](./ppstructure/docs/kie.md)
- Academic Circles
- [Two-stage Algorithm](./doc/doc_en/algorithm_overview_en.md)
- [PGNet Algorithm](./doc/doc_en/pgnet_en.md)
......
...@@ -99,8 +99,8 @@ PaddleOCR旨在打造一套丰富、领先、且实用的OCR工具库,助力
- [PP-Structure信息提取](./ppstructure/README_ch.md)
- [版面分析](./ppstructure/layout/README_ch.md)
- [表格识别](./ppstructure/table/README_ch.md)
- [DocVQA](./ppstructure/vqa/README_ch.md)
- [关键信息提取](./ppstructure/docs/kie.md)
- OCR学术圈
- [两阶段模型介绍与下载](./doc/doc_ch/algorithm_overview.md)
- [端到端PGNet算法](./doc/doc_ch/pgnet.md)
......
Global:
use_gpu: True
epoch_num: &epoch_num 200
log_smooth_window: 10
print_batch_step: 10
save_model_dir: ./output/re_layoutxlm/
save_epoch_step: 2000
# evaluation is run every 19 iterations after the 0th iteration
eval_batch_step: [ 0, 19 ]
cal_metric_during_train: False
save_inference_dir:
use_visualdl: False
seed: 2022
infer_img: doc/vqa/input/zh_val_21.jpg
save_res_path: ./output/re/
Architecture:
model_type: vqa
algorithm: &algorithm "LayoutXLM"
Transform:
Backbone:
name: LayoutXLMForRe
pretrained: True
checkpoints:
Loss:
name: LossFromOutput
key: loss
reduction: mean
Optimizer:
name: AdamW
beta1: 0.9
beta2: 0.999
clip_norm: 10
lr:
learning_rate: 0.00005
regularizer:
name: L2
factor: 0.00000
PostProcess:
name: VQAReTokenLayoutLMPostProcess
Metric:
name: VQAReTokenMetric
main_indicator: hmean
Train:
dataset:
name: SimpleDataSet
data_dir: train_data/XFUND/zh_train/image
label_file_list:
- train_data/XFUND/zh_train/xfun_normalize_train.json
ratio_list: [ 1.0 ]
transforms:
- DecodeImage: # load image
img_mode: RGB
channel_first: False
- VQATokenLabelEncode: # Class handling label
contains_re: True
algorithm: *algorithm
class_path: &class_path ppstructure/vqa/labels/labels_ser.txt
- VQATokenPad:
max_seq_len: &max_seq_len 512
return_attention_mask: True
- VQAReTokenRelation:
- VQAReTokenChunk:
max_seq_len: *max_seq_len
- Resize:
size: [224,224]
- NormalizeImage:
scale: 1
mean: [ 123.675, 116.28, 103.53 ]
std: [ 58.395, 57.12, 57.375 ]
order: 'hwc'
- ToCHWImage:
- KeepKeys:
keep_keys: [ 'input_ids', 'bbox', 'image', 'attention_mask', 'token_type_ids','entities', 'relations'] # dataloader will return list in this order
loader:
shuffle: True
drop_last: False
batch_size_per_card: 8
num_workers: 4
collate_fn: ListCollator
Eval:
dataset:
name: SimpleDataSet
data_dir: train_data/XFUND/zh_val/image
label_file_list:
- train_data/XFUND/zh_val/xfun_normalize_val.json
transforms:
- DecodeImage: # load image
img_mode: RGB
channel_first: False
- VQATokenLabelEncode: # Class handling label
contains_re: True
algorithm: *algorithm
class_path: *class_path
- VQATokenPad:
max_seq_len: *max_seq_len
return_attention_mask: True
- VQAReTokenRelation:
- VQAReTokenChunk:
max_seq_len: *max_seq_len
- Resize:
size: [224,224]
- NormalizeImage:
scale: 1
mean: [ 123.675, 116.28, 103.53 ]
std: [ 58.395, 57.12, 57.375 ]
order: 'hwc'
- ToCHWImage:
- KeepKeys:
keep_keys: [ 'input_ids', 'bbox', 'image', 'attention_mask', 'token_type_ids','entities', 'relations'] # dataloader will return list in this order
loader:
shuffle: False
drop_last: False
batch_size_per_card: 8
num_workers: 4
collate_fn: ListCollator
Global:
use_gpu: True
epoch_num: &epoch_num 200
log_smooth_window: 10
print_batch_step: 10
save_model_dir: ./output/ser_layoutlm/
save_epoch_step: 2000
# evaluation is run every 19 iterations after the 0th iteration
eval_batch_step: [ 0, 19 ]
cal_metric_during_train: False
save_inference_dir:
use_visualdl: False
seed: 2022
infer_img: doc/vqa/input/zh_val_0.jpg
save_res_path: ./output/ser/
Architecture:
model_type: vqa
algorithm: &algorithm "LayoutLM"
Transform:
Backbone:
name: LayoutLMForSer
pretrained: True
checkpoints:
num_classes: &num_classes 7
Loss:
name: VQASerTokenLayoutLMLoss
num_classes: *num_classes
Optimizer:
name: AdamW
beta1: 0.9
beta2: 0.999
lr:
name: Linear
learning_rate: 0.00005
epochs: *epoch_num
warmup_epoch: 2
regularizer:
name: L2
factor: 0.00000
PostProcess:
name: VQASerTokenLayoutLMPostProcess
class_path: &class_path ppstructure/vqa/labels/labels_ser.txt
Metric:
name: VQASerTokenMetric
main_indicator: hmean
Train:
dataset:
name: SimpleDataSet
data_dir: train_data/XFUND/zh_train/image
label_file_list:
- train_data/XFUND/zh_train/xfun_normalize_train.json
transforms:
- DecodeImage: # load image
img_mode: RGB
channel_first: False
- VQATokenLabelEncode: # Class handling label
contains_re: False
algorithm: *algorithm
class_path: *class_path
- VQATokenPad:
max_seq_len: &max_seq_len 512
return_attention_mask: True
- VQASerTokenChunk:
max_seq_len: *max_seq_len
- Resize:
size: [224,224]
- NormalizeImage:
scale: 1
mean: [ 123.675, 116.28, 103.53 ]
std: [ 58.395, 57.12, 57.375 ]
order: 'hwc'
- ToCHWImage:
- KeepKeys:
keep_keys: [ 'input_ids','labels', 'bbox', 'image', 'attention_mask', 'token_type_ids'] # dataloader will return list in this order
loader:
shuffle: True
drop_last: False
batch_size_per_card: 8
num_workers: 4
Eval:
dataset:
name: SimpleDataSet
data_dir: train_data/XFUND/zh_val/image
label_file_list:
- train_data/XFUND/zh_val/xfun_normalize_val.json
transforms:
- DecodeImage: # load image
img_mode: RGB
channel_first: False
- VQATokenLabelEncode: # Class handling label
contains_re: False
algorithm: *algorithm
class_path: *class_path
- VQATokenPad:
max_seq_len: *max_seq_len
return_attention_mask: True
- VQASerTokenChunk:
max_seq_len: *max_seq_len
- Resize:
size: [224,224]
- NormalizeImage:
scale: 1
mean: [ 123.675, 116.28, 103.53 ]
std: [ 58.395, 57.12, 57.375 ]
order: 'hwc'
- ToCHWImage:
- KeepKeys:
keep_keys: [ 'input_ids', 'labels', 'bbox', 'image', 'attention_mask', 'token_type_ids'] # dataloader will return list in this order
loader:
shuffle: False
drop_last: False
batch_size_per_card: 8
num_workers: 4
Global:
use_gpu: True
epoch_num: &epoch_num 200
log_smooth_window: 10
print_batch_step: 10
save_model_dir: ./output/ser_layoutxlm/
save_epoch_step: 2000
# evaluation is run every 19 iterations after the 0th iteration
eval_batch_step: [ 0, 19 ]
cal_metric_during_train: False
save_inference_dir:
use_visualdl: False
seed: 2022
infer_img: doc/vqa/input/zh_val_42.jpg
save_res_path: ./output/ser
Architecture:
model_type: vqa
algorithm: &algorithm "LayoutXLM"
Transform:
Backbone:
name: LayoutXLMForSer
pretrained: True
checkpoints:
num_classes: &num_classes 7
Loss:
name: VQASerTokenLayoutLMLoss
num_classes: *num_classes
Optimizer:
name: AdamW
beta1: 0.9
beta2: 0.999
lr:
name: Linear
learning_rate: 0.00005
epochs: *epoch_num
warmup_epoch: 2
regularizer:
name: L2
factor: 0.00000
PostProcess:
name: VQASerTokenLayoutLMPostProcess
class_path: &class_path ppstructure/vqa/labels/labels_ser.txt
Metric:
name: VQASerTokenMetric
main_indicator: hmean
Train:
dataset:
name: SimpleDataSet
data_dir: train_data/XFUND/zh_train/image
label_file_list:
- train_data/XFUND/zh_train/xfun_normalize_train.json
ratio_list: [ 1.0 ]
transforms:
- DecodeImage: # load image
img_mode: RGB
channel_first: False
- VQATokenLabelEncode: # Class handling label
contains_re: False
algorithm: *algorithm
class_path: *class_path
- VQATokenPad:
max_seq_len: &max_seq_len 512
return_attention_mask: True
- VQASerTokenChunk:
max_seq_len: *max_seq_len
- Resize:
size: [224,224]
- NormalizeImage:
scale: 1
mean: [ 123.675, 116.28, 103.53 ]
std: [ 58.395, 57.12, 57.375 ]
order: 'hwc'
- ToCHWImage:
- KeepKeys:
keep_keys: [ 'input_ids','labels', 'bbox', 'image', 'attention_mask', 'token_type_ids'] # dataloader will return list in this order
loader:
shuffle: True
drop_last: False
batch_size_per_card: 8
num_workers: 4
Eval:
dataset:
name: SimpleDataSet
data_dir: train_data/XFUND/zh_val/image
label_file_list:
- train_data/XFUND/zh_val/xfun_normalize_val.json
transforms:
- DecodeImage: # load image
img_mode: RGB
channel_first: False
- VQATokenLabelEncode: # Class handling label
contains_re: False
algorithm: *algorithm
class_path: *class_path
- VQATokenPad:
max_seq_len: *max_seq_len
return_attention_mask: True
- VQASerTokenChunk:
max_seq_len: *max_seq_len
- Resize:
size: [224,224]
- NormalizeImage:
scale: 1
mean: [ 123.675, 116.28, 103.53 ]
std: [ 58.395, 57.12, 57.375 ]
order: 'hwc'
- ToCHWImage:
- KeepKeys:
keep_keys: [ 'input_ids', 'labels', 'bbox', 'image', 'attention_mask', 'token_type_ids'] # dataloader will return list in this order
loader:
shuffle: False
drop_last: False
batch_size_per_card: 8
num_workers: 4
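The three configuration files above lean heavily on YAML anchors (`&name`) and aliases (`*name`) so that values such as `epoch_num`, `max_seq_len`, `algorithm` and `class_path` are declared once and reused by the Train and Eval pipelines. A short illustration of how such references resolve, assuming PyYAML is available; the snippet is not taken from the configs themselves:

```python
# Sketch of how &name / *name pairs in the configs above resolve: an anchored value
# is declared once and reused wherever the alias appears. Assumes PyYAML is installed.
import yaml

doc = """
Global:
  epoch_num: &epoch_num 200
Train:
  VQATokenPad:
    max_seq_len: &max_seq_len 512
Eval:
  VQATokenPad:
    max_seq_len: *max_seq_len     # reuses the value anchored above
Optimizer:
  lr:
    epochs: *epoch_num            # stays in sync with Global.epoch_num
"""
cfg = yaml.safe_load(doc)
print(cfg["Eval"]["VQATokenPad"]["max_seq_len"], cfg["Optimizer"]["lr"]["epochs"])  # 512 200
```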
...@@ -160,6 +160,7 @@ public class Predictor { ...@@ -160,6 +160,7 @@ public class Predictor {
for (String content : contents) { for (String content : contents) {
wordLabels.add(content); wordLabels.add(content);
} }
wordLabels.add(" ");
Log.i(TAG, "Word label size: " + wordLabels.size()); Log.i(TAG, "Word label size: " + wordLabels.size());
} catch (Exception e) { } catch (Exception e) {
Log.e(TAG, e.getMessage()); Log.e(TAG, e.getMessage());
......
# Server-side C++ Inference
This chapter introduces the C++ deployment steps of the PaddleOCR model. The corresponding Python predictive deployment method refers to [document](../../doc/doc_ch/inference.md).
C++ is better than python in terms of performance. Therefore, in CPU and GPU deployment scenarios, C++ deployment is mostly used.
This section will introduce how to configure the C++ environment and deploy PaddleOCR in Linux (CPU\GPU) environment. For Windows deployment please refer to [Windows](./docs/windows_vs2019_build.md) compilation guidelines.
## 1. Prepare the Environment
...@@ -15,7 +14,7 @@ PaddleOCR model deployment.
### 1.1 Compile OpenCV
* First of all, you need to download the source code compiled package in the Linux environment from the OpenCV official website. Taking OpenCV 3.4.7 as an example, the download command is as follows.
```bash
cd deploy/cpp_infer
...@@ -23,9 +22,9 @@ wget https://paddleocr.bj.bcebos.com/libs/opencv/opencv-3.4.7.tar.gz
tar -xf opencv-3.4.7.tar.gz
```
Finally, you will see the folder of `opencv-3.4.7/` in the current directory.
* Compile OpenCV, the OpenCV source path (`root_path`) and installation path (`install_path`) should be set by yourself. Enter the OpenCV source code path and compile it in the following way.
```shell
...@@ -58,11 +57,11 @@ make -j
make install
```
In the above commands, `root_path` is the downloaded OpenCV source code path, and `install_path` is the installation path of OpenCV. After `make install` is completed, the OpenCV header file and library file will be generated in this folder for later OCR source code compilation.
The final file structure under the OpenCV installation path is as follows.
```
opencv3/
...@@ -79,20 +78,20 @@ opencv3/
#### 1.2.1 Direct download and installation
[Paddle inference library official website](https://paddle-inference.readthedocs.io/en/latest/user_guides/download_lib.html). You can review and select the appropriate version of the inference library on the official website.
* After downloading, use the following command to extract files.
```
tar -xf paddle_inference.tgz
```
Finally you will see the folder of `paddle_inference/` in the current path.
#### 1.2.2 Compile the inference source code
* If you want to get the latest Paddle inference library features, you can download the latest code from the Paddle GitHub repository and compile the inference library from the source code. It is recommended to download the inference library with paddle version greater than or equal to 2.0.1.
* You can refer to the [Paddle inference library](https://www.paddlepaddle.org.cn/documentation/docs/en/advanced_guide/inference_deployment/inference/build_and_install_lib_en.html) to get the Paddle source code from GitHub, and then compile to generate the latest inference library. The method of using git to access the code is as follows.
```shell
...@@ -100,7 +99,7 @@ git clone https://github.com/PaddlePaddle/Paddle.git
git checkout develop
```
* Enter the Paddle directory and run the following commands to compile the paddle inference library.
```shell
rm -rf build
...@@ -133,14 +132,14 @@ build/paddle_inference_install_dir/
|-- version.txt
```
`paddle` is the Paddle library required for C++ prediction later, and `version.txt` contains the version information of the current inference library.
## 2. Compile and Run the Demo
### 2.1 Export the inference model
* You can refer to [Model inference](../../doc/doc_ch/inference.md) and export the inference model. After the model is exported, assuming it is placed in the `inference` directory, the directory structure is as follows.
```
inference/
...@@ -171,20 +170,28 @@ CUDA_LIB_DIR=your_cuda_lib_dir
CUDNN_LIB_DIR=your_cudnn_lib_dir
```
`OPENCV_DIR` is the OpenCV installation path; `LIB_DIR` is the download (`paddle_inference` folder)
or the generated Paddle inference library path (`build/paddle_inference_install_dir` folder);
`CUDA_LIB_DIR` is the CUDA library file path, in docker it is `/usr/local/cuda/lib64`; `CUDNN_LIB_DIR` is the cuDNN library file path, in docker it is `/usr/lib/x86_64-linux-gnu/`.
* After the compilation is completed, an executable file named `ppocr` will be generated in the `build` folder.
### Run the demo
Execute the built executable file:
```shell
./build/ppocr <mode> [--param1] [--param2] [...]
```
Specifically, `mode` is a required parameter, and the valid values are
mode value | Model used
-----|------
det | Detection only
rec | Recognition only
system | End-to-end system
Specifically,
##### 1. run det demo:
```shell
...@@ -214,9 +221,9 @@ Here, `mode` is a required parameter,and the value range is ['det', 'rec', 'sy
--image_dir=../../doc/imgs/12.jpg
```
More parameters are as follows,
- Common parameters
|parameter|data type|default|meaning|
| --- | --- | --- | --- |
...@@ -226,7 +233,7 @@ More parameters are as follows,
|cpu_math_library_num_threads|int|10|Number of threads when using CPU inference. When machine cores is enough, the larger the value, the faster the inference speed|
|use_mkldnn|bool|true|Whether to use mkldnn library|
- Detection related parameters
|parameter|data type|default|meaning|
| --- | --- | --- | --- |
...@@ -238,7 +245,7 @@ More parameters are as follows,
|use_polygon_score|bool|false|Whether to use polygon box to calculate bbox score, false means to use rectangle box to calculate. Use rectangular box to calculate faster, and polygonal box more accurate for curved text area.|
|visualize|bool|true|Whether to visualize the results, when it is set as true, the prediction result will be saved in the image file `./ocr_vis.png`.|
- Classifier related parameters
|parameter|data type|default|meaning|
| --- | --- | --- | --- |
...@@ -246,7 +253,7 @@ More parameters are as follows,
|cls_model_dir|string|-|Address of direction classifier inference model|
|cls_thresh|float|0.9|Score threshold of the direction classifier|
- Recognition related parameters
|parameter|data type|default|meaning|
| --- | --- | --- | --- |
...@@ -265,4 +272,4 @@ The detection results will be shown on the screen, which is as follows.
### 2.3 Notes
* Paddle 2.0.0 inference model library is recommended for this tutorial.
English | [简体中文](README_cn.md) English | [简体中文](README_cn.md)
## Introduction ## Introduction
Many users hope package the PaddleOCR service into a docker image, so that it can be quickly released and used in the docker or k8s environment. Many users hope package the PaddleOCR service into a docker image, so that it can be quickly released and used in the docker or K8s environment.
This page provides some standardized code to achieve this goal. You can quickly publish the PaddleOCR project into a callable Restful API service through the following steps. (At present, the deployment based on the HubServing mode is implemented first, and author plans to increase the deployment of the PaddleServing mode in the futrue) This page provides some standardized code to achieve this goal. You can quickly publish the PaddleOCR project into a callable Restful API service through the following steps. (At present, the deployment based on the HubServing mode is implemented first, and author plans to increase the deployment of the PaddleServing mode in the future)
## 1. Prerequisites ## 1. Prerequisites
...@@ -14,7 +14,7 @@ c. NVIDIA Container Toolkit(GPU,Docker 19.03+ can skip this) ...@@ -14,7 +14,7 @@ c. NVIDIA Container Toolkit(GPU,Docker 19.03+ can skip this)
d. cuDNN 7.6+(GPU) d. cuDNN 7.6+(GPU)
## 2. Build Image ## 2. Build Image
a. Goto Dockerfile directory(ps:Need to distinguish between cpu and gpu version, the following takes cpu as an example, gpu version needs to replace the keyword) a. Go to the Dockerfile directory (PS: the CPU and GPU versions need to be distinguished; the following takes CPU as an example, and the GPU version only needs to replace the keyword)
``` ```
cd deploy/docker/hubserving/cpu cd deploy/docker/hubserving/cpu
``` ```
...@@ -42,13 +42,13 @@ docker logs -f paddle_ocr ...@@ -42,13 +42,13 @@ docker logs -f paddle_ocr
``` ```
## 4. Test ## 4. Test
a. Calculate the Base64 encoding of the picture to be recognized (if you just test, you can use a free online tool, like:https://freeonlinetools24.com/base64-image/) a. Calculate the Base64 encoding of the picture to be recognized (for testing purposes, you can use a free online tool such as https://freeonlinetools24.com/base64-image/ )
b. Post a service request (a sample request is provided in sample_request.txt) b. Post a service request (a sample request is provided in sample_request.txt)
``` ```
curl -H "Content-Type:application/json" -X POST --data "{\"images\": [\"Input image Base64 encode(need to delete the code 'data:image/jpg;base64,')\"]}" http://localhost:8868/predict/ocr_system curl -H "Content-Type:application/json" -X POST --data "{\"images\": [\"Input image Base64 encode(need to delete the code 'data:image/jpg;base64,')\"]}" http://localhost:8868/predict/ocr_system
``` ```
c. Get resposne(If the call is successful, the following result will be returned) c. Get the response (if the call is successful, the following result will be returned)
``` ```
{"msg":"","results":[[{"confidence":0.8403433561325073,"text":"约定","text_region":[[345,377],[641,390],[634,540],[339,528]]},{"confidence":0.8131805658340454,"text":"最终相遇","text_region":[[356,532],[624,530],[624,596],[356,598]]}]],"status":"0"} {"msg":"","results":[[{"confidence":0.8403433561325073,"text":"约定","text_region":[[345,377],[641,390],[634,540],[339,528]]},{"confidence":0.8131805658340454,"text":"最终相遇","text_region":[[356,532],[624,530],[624,596],[356,598]]}]],"status":"0"}
``` ```
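For reference, steps a and b can also be scripted. The following is a minimal Python sketch, assuming the `requests` package is installed and using a placeholder image path; the service address and payload format are taken from the curl example above.

```python
import base64
import requests  # assumed to be installed (pip install requests)

# Step a: Base64-encode the image to be recognized (no 'data:image/jpg;base64,' prefix).
with open("./your_image.jpg", "rb") as f:  # placeholder path
    img_b64 = base64.b64encode(f.read()).decode("utf-8")

# Step b: post the request to the ocr_system service listening on port 8868.
resp = requests.post("http://localhost:8868/predict/ocr_system",
                     json={"images": [img_b64]})
print(resp.json())  # on success, the structure matches the JSON result shown above
```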
# Tutorial of PaddleOCR Mobile deployment # Tutorial of PaddleOCR Mobile deployment
This tutorial will introduce how to use [Paddle Lite](https://github.com/PaddlePaddle/Paddle-Lite) to deploy paddleOCR ultra-lightweight Chinese and English detection models on mobile phones. This tutorial will introduce how to use [Paddle Lite](https://github.com/PaddlePaddle/Paddle-Lite) to deploy PaddleOCR ultra-lightweight Chinese and English detection models on mobile phones.
paddle-lite is a lightweight inference engine for PaddlePaddle. It provides efficient inference capabilities for mobile phones and IoTs, and extensively integrates cross-platform hardware to provide lightweight deployment solutions for end-side deployment issues. Paddle Lite is a lightweight inference engine for PaddlePaddle. It provides efficient inference capabilities for mobile phones and IoT devices, and extensively integrates cross-platform hardware to provide lightweight deployment solutions for on-device deployment.
## 1. Preparation ## 1. Preparation
......
...@@ -22,6 +22,7 @@ PaddleOCR提供2种服务部署方式: ...@@ -22,6 +22,7 @@ PaddleOCR提供2种服务部署方式:
- [环境准备](#环境准备) - [环境准备](#环境准备)
- [模型转换](#模型转换) - [模型转换](#模型转换)
- [Paddle Serving pipeline部署](#部署) - [Paddle Serving pipeline部署](#部署)
- [Windows用户](#Windows用户)
- [FAQ](#FAQ) - [FAQ](#FAQ)
<a name="环境准备"></a> <a name="环境准备"></a>
...@@ -187,9 +188,10 @@ python3 -m paddle_serving_client.convert --dirname ./ch_PP-OCRv2_rec_infer/ \ ...@@ -187,9 +188,10 @@ python3 -m paddle_serving_client.convert --dirname ./ch_PP-OCRv2_rec_infer/ \
2021-05-13 03:42:36,979 chl2(In: ['rec'], Out: ['@DAGExecutor']) size[0/0] 2021-05-13 03:42:36,979 chl2(In: ['rec'], Out: ['@DAGExecutor']) size[0/0]
``` ```
## WINDOWS用户 <a name="Windows用户"></a>
## Windows用户
Windows用户不能使用上述的启动方式,需要使用Web Service,详情参见[Windows平台使用Paddle Serving指导](https://github.com/PaddlePaddle/Serving/blob/develop/doc/WINDOWS_TUTORIAL_CN.md) Windows用户不能使用上述的启动方式,需要使用Web Service,详情参见[Windows平台使用Paddle Serving指导](https://github.com/PaddlePaddle/Serving/blob/develop/doc/Windows_Tutorial_CN.md)
**WINDOWS只能使用0.5.0版本的CPU模式** **WINDOWS只能使用0.5.0版本的CPU模式**
......
...@@ -28,14 +28,14 @@ python3 setup.py install ...@@ -28,14 +28,14 @@ python3 setup.py install
``` ```
### 2. Download Pretrain Model ### 2. Download Pre-trained Model
Model prune needs to load pre-trained models. Model prune needs to load pre-trained models.
PaddleOCR also provides a series of [models](../../../doc/doc_en/models_list_en.md). Developers can choose from these models or use their own models according to their needs. PaddleOCR also provides a series of [models](../../../doc/doc_en/models_list_en.md). Developers can choose from these models or use their own models according to their needs.
### 3. Pruning sensitivity analysis ### 3. Pruning sensitivity analysis
After the pre-training model is loaded, sensitivity analysis is performed on each network layer of the model to understand the redundancy of each network layer, and save a sensitivity file which named: sen.pickle. After that, user could load the sensitivity file via the [methods provided by PaddleSlim](https://github.com/PaddlePaddle/PaddleSlim/blob/develop/paddleslim/prune/sensitive.py#L221) and determining the pruning ratio of each network layer automatically. For specific details of sensitivity analysis, see:[Sensitivity analysis](https://github.com/PaddlePaddle/PaddleSlim/blob/develop/docs/zh_cn/tutorials/image_classification_sensitivity_analysis_tutorial.md) After the pre-trained model is loaded, sensitivity analysis is performed on each network layer of the model to understand its redundancy, and a sensitivity file named sen.pickle is saved. After that, the user can load the sensitivity file via the [methods provided by PaddleSlim](https://github.com/PaddlePaddle/PaddleSlim/blob/develop/paddleslim/prune/sensitive.py#L221) and determine the pruning ratio of each network layer automatically. For specific details of sensitivity analysis, see: [Sensitivity analysis](https://github.com/PaddlePaddle/PaddleSlim/blob/develop/docs/zh_cn/tutorials/image_classification_sensitivity_analysis_tutorial.md)
The data format of sensitivity file: The data format of sensitivity file:
sen.pickle(Dict){ sen.pickle(Dict){
'layer_weight_name_0': sens_of_each_ratio(Dict){'pruning_ratio_0': acc_loss, 'pruning_ratio_1': acc_loss} 'layer_weight_name_0': sens_of_each_ratio(Dict){'pruning_ratio_0': acc_loss, 'pruning_ratio_1': acc_loss}
...@@ -47,7 +47,7 @@ PaddleOCR also provides a series of [models](../../../doc/doc_en/models_list_en. ...@@ -47,7 +47,7 @@ PaddleOCR also provides a series of [models](../../../doc/doc_en/models_list_en.
'conv10_expand_weights': {0.1: 0.006509952684312718, 0.2: 0.01827734339798862, 0.3: 0.014528405644659832, 0.6: 0.06536008804270439, 0.8: 0.11798612250664964, 0.7: 0.12391408417493704, 0.4: 0.030615754498018757, 0.5: 0.047105205602406594} 'conv10_expand_weights': {0.1: 0.006509952684312718, 0.2: 0.01827734339798862, 0.3: 0.014528405644659832, 0.6: 0.06536008804270439, 0.8: 0.11798612250664964, 0.7: 0.12391408417493704, 0.4: 0.030615754498018757, 0.5: 0.047105205602406594}
'conv10_linear_weights': {0.1: 0.05113190831455035, 0.2: 0.07705573833558801, 0.3: 0.12096721757739311, 0.6: 0.5135061352930738, 0.8: 0.7908166677143281, 0.7: 0.7272187676899062, 0.4: 0.1819252083008504, 0.5: 0.3728054727792405} 'conv10_linear_weights': {0.1: 0.05113190831455035, 0.2: 0.07705573833558801, 0.3: 0.12096721757739311, 0.6: 0.5135061352930738, 0.8: 0.7908166677143281, 0.7: 0.7272187676899062, 0.4: 0.1819252083008504, 0.5: 0.3728054727792405}
} }
The function would return a dict after loading the sensitivity file. The keys of the dict are name of parameters in each layer. And the value of key is the information about pruning sensitivity of correspoding layer. In example, pruning 10% filter of the layer corresponding to conv10_expand_weights would lead to 0.65% degradation of model performance. The details could be seen at: [Sensitivity analysis](https://github.com/PaddlePaddle/PaddleSlim/blob/develop/docs/zh_cn/algo/algo.md#2-%E5%8D%B7%E7%A7%AF%E6%A0%B8%E5%89%AA%E8%A3%81%E5%8E%9F%E7%90%86) The function returns a dict after loading the sensitivity file. The keys of the dict are the names of the parameters in each layer, and the values contain the pruning sensitivity information of the corresponding layer. For example, pruning 10% of the filters of the layer corresponding to conv10_expand_weights would lead to a 0.65% degradation of model performance. The details can be seen at: [Sensitivity analysis](https://github.com/PaddlePaddle/PaddleSlim/blob/develop/docs/zh_cn/algo/algo.md#2-%E5%8D%B7%E7%A7%AF%E6%A0%B8%E5%89%AA%E8%A3%81%E5%8E%9F%E7%90%86)
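As a rough illustration (not the PaddleSlim helper referenced above), the sensitivity file can also be inspected with plain `pickle`; the file name and the 1% accuracy-loss tolerance below are assumptions made for the example.

```python
import pickle

# Load the sensitivity file produced by the analysis step (file name assumed).
with open("sen.pickle", "rb") as f:
    sensitivities = pickle.load(f)

# For each layer, pick the largest pruning ratio whose accuracy loss stays
# below an assumed tolerance of 1%.
tolerance = 0.01
for layer_name, ratio_to_loss in sensitivities.items():
    feasible = [ratio for ratio, loss in ratio_to_loss.items() if loss < tolerance]
    print(layer_name, max(feasible) if feasible else "no ratio within tolerance")
```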
Enter the PaddleOCR root directory and perform sensitivity analysis on the model with the following command: Enter the PaddleOCR root directory and perform sensitivity analysis on the model with the following command:
......
## Introduction ## Introduction
Generally, a more complex model would achive better performance in the task, but it also leads to some redundancy in the model. Generally, a more complex model would achieve better performance in the task, but it also leads to some redundancy in the model.
Quantization is a technique that reduces this redundancy by reducing full-precision data to fixed-point numbers, Quantization is a technique that reduces this redundancy by reducing full-precision data to fixed-point numbers,
so as to reduce model calculation complexity and improve model inference performance. so as to reduce model calculation complexity and improve model inference performance.
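As a back-of-the-envelope illustration of the idea only (not PaddleSlim's actual quantization scheme), symmetric linear quantization maps float values to 8-bit integers through a single scale factor:

```python
import numpy as np

def quantize_int8(x):
    """Illustrative symmetric linear quantization of a float tensor to int8."""
    scale = max(np.abs(x).max(), 1e-8) / 127.0           # one scale for the whole tensor
    q = np.clip(np.round(x / scale), -127, 127).astype(np.int8)
    return q, scale

def dequantize(q, scale):
    return q.astype(np.float32) * scale

weights = np.random.randn(3, 3).astype(np.float32)
q, scale = quantize_int8(weights)
print("max abs error:", np.abs(weights - dequantize(q, scale)).max())
```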
...@@ -31,14 +31,14 @@ python setup.py install ...@@ -31,14 +31,14 @@ python setup.py install
``` ```
### 2. Download Pretrain Model ### 2. Download Pre-trained Model
PaddleOCR provides a series of trained [models](../../../doc/doc_en/models_list_en.md). PaddleOCR provides a series of pre-trained [models](../../../doc/doc_en/models_list_en.md).
If the model to be quantized is not in the list, you need to follow the [Regular Training](../../../doc/doc_en/quickstart_en.md) method to get a trained model. If the model to be quantized is not in the list, you need to follow the [Regular Training](../../../doc/doc_en/quickstart_en.md) method to get a trained model.
### 3. Quant-Aware Training ### 3. Quant-Aware Training
Quantization training includes offline quantization training and online quantization training. Quantization training includes offline quantization training and online quantization training.
Online quantization training is more effective. It is necessary to load the pre-training model. Online quantization training is more effective. It is necessary to load the pre-trained model.
After the quantization strategy is defined, the model can be quantized. After the quantization strategy is defined, the model can be quantized.
The code for quantization training is located in `slim/quantization/quant.py`. For example, to train a detection model, the training instructions are as follows: The code for quantization training is located in `slim/quantization/quant.py`. For example, to train a detection model, the training instructions are as follows:
...@@ -54,7 +54,7 @@ python deploy/slim/quantization/quant.py -c configs/det/ch_ppocr_v2.0/ch_det_mv3 ...@@ -54,7 +54,7 @@ python deploy/slim/quantization/quant.py -c configs/det/ch_ppocr_v2.0/ch_det_mv3
### 4. Export inference model ### 4. Export inference model
After getting the model after pruning and finetuning we, can export it as inference_model for predictive deployment: Once we have obtained the model after pruning and fine-tuning, we can export it as an inference model for the deployment of predictive tasks:
```bash ```bash
python deploy/slim/quantization/export_model.py -c configs/det/ch_ppocr_v2.0/ch_det_mv3_db_v2.0.yml -o Global.checkpoints=output/quant_model/best_accuracy Global.save_inference_dir=./output/quant_inference_model python deploy/slim/quantization/export_model.py -c configs/det/ch_ppocr_v2.0/ch_det_mv3_db_v2.0.yml -o Global.checkpoints=output/quant_model/best_accuracy Global.save_inference_dir=./output/quant_inference_model
......
...@@ -76,7 +76,7 @@ def main(): ...@@ -76,7 +76,7 @@ def main():
} }
FLAGS = ArgsParser().parse_args() FLAGS = ArgsParser().parse_args()
config = load_config(FLAGS.config) config = load_config(FLAGS.config)
merge_config(FLAGS.opt) config = merge_config(config, FLAGS.opt)
logger = get_logger() logger = get_logger()
# build post process # build post process
......
...@@ -25,8 +25,8 @@ PaddleOCR开源的文本检测算法列表: ...@@ -25,8 +25,8 @@ PaddleOCR开源的文本检测算法列表:
在ICDAR2015文本检测公开数据集上,算法效果如下: 在ICDAR2015文本检测公开数据集上,算法效果如下:
|模型|骨干网络|precision|recall|Hmean|下载链接| |模型|骨干网络|precision|recall|Hmean|下载链接|
| --- | --- | --- | --- | --- | --- | | --- | --- | --- | --- | --- | --- |
|EAST|ResNet50_vd|85.80%|86.71%|86.25%|[训练模型](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/det_r50_vd_east_v2.0_train.tar)| |EAST|ResNet50_vd|88.71%|81.36%|84.88%|[训练模型](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/det_r50_vd_east_v2.0_train.tar)|
|EAST|MobileNetV3|79.42%|80.64%|80.03%|[训练模型](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/det_mv3_east_v2.0_train.tar)| |EAST|MobileNetV3|78.2%|79.1%|78.65%|[训练模型](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/det_mv3_east_v2.0_train.tar)|
|DB|ResNet50_vd|86.41%|78.72%|82.38%|[训练模型](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/det_r50_vd_db_v2.0_train.tar)| |DB|ResNet50_vd|86.41%|78.72%|82.38%|[训练模型](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/det_r50_vd_db_v2.0_train.tar)|
|DB|MobileNetV3|77.29%|73.08%|75.12%|[训练模型](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/det_mv3_db_v2.0_train.tar)| |DB|MobileNetV3|77.29%|73.08%|75.12%|[训练模型](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/det_mv3_db_v2.0_train.tar)|
|SAST|ResNet50_vd|91.39%|83.77%|87.42%|[训练模型](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/det_r50_vd_sast_icdar15_v2.0_train.tar)| |SAST|ResNet50_vd|91.39%|83.77%|87.42%|[训练模型](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/det_r50_vd_sast_icdar15_v2.0_train.tar)|
...@@ -61,18 +61,18 @@ PaddleOCR基于动态图开源的文本识别算法列表: ...@@ -61,18 +61,18 @@ PaddleOCR基于动态图开源的文本识别算法列表:
|模型|骨干网络|Avg Accuracy|模型存储命名|下载链接| |模型|骨干网络|Avg Accuracy|模型存储命名|下载链接|
|---|---|---|---|---| |---|---|---|---|---|
|Rosetta|Resnet34_vd|80.9%|rec_r34_vd_none_none_ctc|[训练模型](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/rec_r34_vd_none_none_ctc_v2.0_train.tar)| |Rosetta|Resnet34_vd|79.11%|rec_r34_vd_none_none_ctc|[训练模型](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/rec_r34_vd_none_none_ctc_v2.0_train.tar)|
|Rosetta|MobileNetV3|78.05%|rec_mv3_none_none_ctc|[训练模型](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/rec_mv3_none_none_ctc_v2.0_train.tar)| |Rosetta|MobileNetV3|75.80%|rec_mv3_none_none_ctc|[训练模型](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/rec_mv3_none_none_ctc_v2.0_train.tar)|
|CRNN|Resnet34_vd|82.76%|rec_r34_vd_none_bilstm_ctc|[训练模型](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/rec_r34_vd_none_bilstm_ctc_v2.0_train.tar)| |CRNN|Resnet34_vd|81.04%|rec_r34_vd_none_bilstm_ctc|[训练模型](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/rec_r34_vd_none_bilstm_ctc_v2.0_train.tar)|
|CRNN|MobileNetV3|79.97%|rec_mv3_none_bilstm_ctc|[训练模型](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/rec_mv3_none_bilstm_ctc_v2.0_train.tar)| |CRNN|MobileNetV3|77.95%|rec_mv3_none_bilstm_ctc|[训练模型](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/rec_mv3_none_bilstm_ctc_v2.0_train.tar)|
|StarNet|Resnet34_vd|84.44%|rec_r34_vd_tps_bilstm_ctc|[训练模型](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/rec_r34_vd_tps_bilstm_ctc_v2.0_train.tar)| |StarNet|Resnet34_vd|82.85%|rec_r34_vd_tps_bilstm_ctc|[训练模型](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/rec_r34_vd_tps_bilstm_ctc_v2.0_train.tar)|
|StarNet|MobileNetV3|81.42%|rec_mv3_tps_bilstm_ctc|[训练模型](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/rec_mv3_tps_bilstm_ctc_v2.0_train.tar)| |StarNet|MobileNetV3|79.28%|rec_mv3_tps_bilstm_ctc|[训练模型](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/rec_mv3_tps_bilstm_ctc_v2.0_train.tar)|
|RARE|MobileNetV3|82.5%|rec_mv3_tps_bilstm_att |[训练模型](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/rec_mv3_tps_bilstm_att_v2.0_train.tar)| |RARE|Resnet34_vd|83.98%|rec_r34_vd_tps_bilstm_att |[训练模型](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/rec_r34_vd_tps_bilstm_att_v2.0_train.tar)|
|RARE|Resnet34_vd|83.6%|rec_r34_vd_tps_bilstm_att |[训练模型](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/rec_r34_vd_tps_bilstm_att_v2.0_train.tar)| |RARE|MobileNetV3|81.76%|rec_mv3_tps_bilstm_att |[训练模型](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/rec_mv3_tps_bilstm_att_v2.0_train.tar)|
|SRN|Resnet50_vd_fpn| 88.52% | rec_r50fpn_vd_none_srn | [训练模型](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/rec_r50_vd_srn_train.tar) | |SRN|Resnet50_vd_fpn| 86.31% | rec_r50fpn_vd_none_srn | [训练模型](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/rec_r50_vd_srn_train.tar) |
|NRTR|NRTR_MTB| 84.3% | rec_mtb_nrtr | [训练模型](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/rec_mtb_nrtr_train.tar) | |NRTR|NRTR_MTB| 84.21% | rec_mtb_nrtr | [训练模型](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/rec_mtb_nrtr_train.tar) |
|SAR|Resnet31| 87.2% | rec_r31_sar | [训练模型](https://paddleocr.bj.bcebos.com/dygraph_v2.1/rec/rec_r31_sar_train.tar) | |SAR|Resnet31| 87.20% | rec_r31_sar | [训练模型](https://paddleocr.bj.bcebos.com/dygraph_v2.1/rec/rec_r31_sar_train.tar) |
|SEED|Aster_Resnet| 85.2% | rec_resnet_stn_bilstm_att | [训练模型](https://paddleocr.bj.bcebos.com/dygraph_v2.1/rec/rec_resnet_stn_bilstm_att.tar) | |SEED|Aster_Resnet| 85.35% | rec_resnet_stn_bilstm_att | [训练模型](https://paddleocr.bj.bcebos.com/dygraph_v2.1/rec/rec_resnet_stn_bilstm_att.tar) |
<a name="2"></a> <a name="2"></a>
......
...@@ -14,12 +14,12 @@ Demo测试的时候使用的是NDK 20b版本,20版本以上均可以支持编 ...@@ -14,12 +14,12 @@ Demo测试的时候使用的是NDK 20b版本,20版本以上均可以支持编
1. Start a new Android Studio project 1. Start a new Android Studio project
在项目模版中选择 Native C++ 选择PaddleOCR/depoly/android_demo 路径 在项目模版中选择 Native C++ 选择PaddleOCR/deploy/android_demo 路径
进入项目后会自动编译,第一次编译会花费较长的时间,建议添加代理加速下载。 进入项目后会自动编译,第一次编译会花费较长的时间,建议添加代理加速下载。
**代理添加:** **代理添加:**
选择 Android Studio -> Perferences -> Appearance & Behavior -> System Settings -> HTTP Proxy -> Manual proxy configuration 选择 Android Studio -> Preferences -> Appearance & Behavior -> System Settings -> HTTP Proxy -> Manual proxy configuration
![](../demo/proxy.png) ![](../demo/proxy.png)
......
...@@ -16,7 +16,7 @@ PaddleOCR的Python代码遵循 [PEP8规范](https://www.python.org/dev/peps/pep- ...@@ -16,7 +16,7 @@ PaddleOCR的Python代码遵循 [PEP8规范](https://www.python.org/dev/peps/pep-
- 空格 - 空格
- 空格应该加在逗号、分号、冒号前,而非他们的后 - 空格应该加在逗号、分号、冒号后,而非他们的前
```python ```python
# 正确: # 正确:
...@@ -334,4 +334,4 @@ git push origin new_branch ...@@ -334,4 +334,4 @@ git push origin new_branch
2)如果评审意见比较多: 2)如果评审意见比较多:
- 请给出总体的修改情况。 - 请给出总体的修改情况。
- 请采用`start a review`进行回复,而非直接回复的方式。原因是每个回复都会发送一封邮件,会造成邮件灾难。 - 请采用`start a review`进行回复,而非直接回复的方式。原因是每个回复都会发送一封邮件,会造成邮件灾难。
\ No newline at end of file
...@@ -78,11 +78,11 @@ json.dumps编码前的图像标注信息是包含多个字典的list,字典中 ...@@ -78,11 +78,11 @@ json.dumps编码前的图像标注信息是包含多个字典的list,字典中
cd PaddleOCR/ cd PaddleOCR/
# 根据backbone的不同选择下载对应的预训练模型 # 根据backbone的不同选择下载对应的预训练模型
# 下载MobileNetV3的预训练模型 # 下载MobileNetV3的预训练模型
wget -P ./pretrain_models/ https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/MobileNetV3_large_x0_5_pretrained.pdparams wget -P ./pretrain_models/ https://paddleocr.bj.bcebos.com/pretrained/MobileNetV3_large_x0_5_pretrained.pdparams
# 或,下载ResNet18_vd的预训练模型 # 或,下载ResNet18_vd的预训练模型
wget -P ./pretrain_models/ https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/ResNet18_vd_pretrained.pdparams wget -P ./pretrain_models/ https://paddleocr.bj.bcebos.com/pretrained/ResNet18_vd_pretrained.pdparams
# 或,下载ResNet50_vd的预训练模型 # 或,下载ResNet50_vd的预训练模型
wget -P ./pretrain_models/ https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/ResNet50_vd_ssld_pretrained.pdparams wget -P ./pretrain_models/ https://paddleocr.bj.bcebos.com/pretrained/ResNet50_vd_ssld_pretrained.pdparams
``` ```
<a name="2-----"></a> <a name="2-----"></a>
......
<a name="0"></a>
# 知识蒸馏 # 知识蒸馏
+ [知识蒸馏](#0)
+ [1. 简介](#1)
- [1.1 知识蒸馏介绍](#11)
- [1.2 PaddleOCR知识蒸馏简介](#12)
+ [2. 配置文件解析](#2)
+ [2.1 识别配置文件解析](#21)
- [2.1.1 模型结构](#211)
- [2.1.2 损失函数](#212)
- [2.1.3 后处理](#213)
- [2.1.4 指标计算](#214)
- [2.1.5 蒸馏模型微调](#215)
+ [2.2 检测配置文件解析](#22)
- [2.2.1 模型结构](#221)
- [2.2.2 损失函数](#222)
- [2.2.3 后处理](#223)
- [2.2.4 蒸馏指标计算](#224)
- [2.2.5 检测蒸馏模型Fine-tune](#225)
<a name="1"></a>
## 1. 简介 ## 1. 简介
<a name="11"></a>
### 1.1 知识蒸馏介绍 ### 1.1 知识蒸馏介绍
近年来,深度神经网络在计算机视觉、自然语言处理等领域被验证是一种极其有效的解决问题的方法。通过构建合适的神经网络,加以训练,最终网络模型的性能指标基本上都会超过传统算法。 近年来,深度神经网络在计算机视觉、自然语言处理等领域被验证是一种极其有效的解决问题的方法。通过构建合适的神经网络,加以训练,最终网络模型的性能指标基本上都会超过传统算法。
...@@ -13,11 +32,12 @@ ...@@ -13,11 +32,12 @@
此外,在知识蒸馏任务中,也衍生出了互学习的模型训练方法,论文[Deep Mutual Learning](https://arxiv.org/abs/1706.00384)中指出,使用两个完全相同的模型在训练的过程中互相监督,可以达到比单个模型训练更好的效果。 此外,在知识蒸馏任务中,也衍生出了互学习的模型训练方法,论文[Deep Mutual Learning](https://arxiv.org/abs/1706.00384)中指出,使用两个完全相同的模型在训练的过程中互相监督,可以达到比单个模型训练更好的效果。
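The paragraph above refers to Deep Mutual Learning, where two identical models supervise each other during training. As a conceptual sketch only (not PaddleOCR's implementation), the mutual supervision can be written as a symmetric KL divergence between the two models' predictions, which each model adds to its own supervised loss:

```python
import numpy as np

def softmax(z):
    e = np.exp(z - z.max(axis=-1, keepdims=True))
    return e / e.sum(axis=-1, keepdims=True)

def kl(p, q, eps=1e-8):
    # mean KL divergence over a batch of probability vectors
    return np.mean(np.sum(p * (np.log(p + eps) - np.log(q + eps)), axis=-1))

def dml_loss(logits_a, logits_b):
    """Symmetric KL between two students' predictions (conceptual sketch only)."""
    pa, pb = softmax(logits_a), softmax(logits_b)
    return kl(pa, pb) + kl(pb, pa)

print(dml_loss(np.random.randn(4, 10), np.random.randn(4, 10)))
```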
<a name="12"></a>
### 1.2 PaddleOCR知识蒸馏简介 ### 1.2 PaddleOCR知识蒸馏简介
无论是大模型蒸馏小模型,还是小模型之间互相学习,更新参数,他们本质上都是不同模型之间输出或者特征图(feature map)之间的相互监督,区别仅在于 (1) 模型是否需要固定参数。(2) 模型是否需要加载预训练模型。 无论是大模型蒸馏小模型,还是小模型之间互相学习,更新参数,他们本质上都是不同模型之间输出或者特征图(feature map)之间的相互监督,区别仅在于 (1) 模型是否需要固定参数。(2) 模型是否需要加载预训练模型。
对于大模型蒸馏小模型的情况,大模型一般需要加载预训练模型并固定参数;对于小模型之间互相蒸馏的情况,小模型一般都不加载预训练模型,参数也都是可学习的状态。 对于大模型蒸馏小模型的情况,大模型一般需要加载预训练模型并固定参数;对于小模型之间互相蒸馏的情况,小模型一般都不加载预训练模型,参数也都是可学习的状态。
在知识蒸馏任务中,不只有2个模型之间进行蒸馏的情况,多个模型之间互相学习的情况也非常普遍。因此在知识蒸馏代码框架中,也有必要支持该种类别的蒸馏方法。 在知识蒸馏任务中,不只有2个模型之间进行蒸馏的情况,多个模型之间互相学习的情况也非常普遍。因此在知识蒸馏代码框架中,也有必要支持该种类别的蒸馏方法。
...@@ -30,17 +50,19 @@ PaddleOCR中集成了知识蒸馏的算法,具体地,有以下几个主要 ...@@ -30,17 +50,19 @@ PaddleOCR中集成了知识蒸馏的算法,具体地,有以下几个主要
通过知识蒸馏,在中英文通用文字识别任务中,不增加任何预测耗时的情况下,可以给模型带来3%以上的精度提升,结合学习率调整策略以及模型结构微调策略,最终提升超过5%。 通过知识蒸馏,在中英文通用文字识别任务中,不增加任何预测耗时的情况下,可以给模型带来3%以上的精度提升,结合学习率调整策略以及模型结构微调策略,最终提升超过5%。
<a name="2"></a>
## 2. 配置文件解析 ## 2. 配置文件解析
在知识蒸馏训练的过程中,数据预处理、优化器、学习率、全局的一些属性没有任何变化。模型结构、损失函数、后处理、指标计算等模块的配置文件需要进行微调。 在知识蒸馏训练的过程中,数据预处理、优化器、学习率、全局的一些属性没有任何变化。模型结构、损失函数、后处理、指标计算等模块的配置文件需要进行微调。
下面以识别与检测的知识蒸馏配置文件为例,对知识蒸馏的训练与配置进行解析。 下面以识别与检测的知识蒸馏配置文件为例,对知识蒸馏的训练与配置进行解析。
<a name="21"></a>
### 2.1 识别配置文件解析 ### 2.1 识别配置文件解析
配置文件在[ch_PP-OCRv2_rec_distillation.yml](../../configs/rec/ch_PP-OCRv2/ch_PP-OCRv2_rec_distillation.yml) 配置文件在[ch_PP-OCRv2_rec_distillation.yml](../../configs/rec/ch_PP-OCRv2/ch_PP-OCRv2_rec_distillation.yml)
<a name="211"></a>
#### 2.1.1 模型结构 #### 2.1.1 模型结构
知识蒸馏任务中,模型结构配置如下所示。 知识蒸馏任务中,模型结构配置如下所示。
...@@ -176,6 +198,7 @@ Architecture: ...@@ -176,6 +198,7 @@ Architecture:
} }
``` ```
<a name="212"></a>
#### 2.1.2 损失函数 #### 2.1.2 损失函数
知识蒸馏任务中,损失函数配置如下所示。 知识蒸馏任务中,损失函数配置如下所示。
...@@ -212,7 +235,7 @@ Loss: ...@@ -212,7 +235,7 @@ Loss:
关于`CombinedLoss`更加具体的实现可以参考: [combined_loss.py](../../ppocr/losses/combined_loss.py#L23)。关于`DistillationCTCLoss`等蒸馏损失函数更加具体的实现可以参考[distillation_loss.py](../../ppocr/losses/distillation_loss.py) 关于`CombinedLoss`更加具体的实现可以参考: [combined_loss.py](../../ppocr/losses/combined_loss.py#L23)。关于`DistillationCTCLoss`等蒸馏损失函数更加具体的实现可以参考[distillation_loss.py](../../ppocr/losses/distillation_loss.py)
<a name="213"></a>
#### 2.1.3 后处理 #### 2.1.3 后处理
知识蒸馏任务中,后处理配置如下所示。 知识蒸馏任务中,后处理配置如下所示。
...@@ -228,7 +251,7 @@ PostProcess: ...@@ -228,7 +251,7 @@ PostProcess:
关于`DistillationCTCLabelDecode`更加具体的实现可以参考: [rec_postprocess.py](../../ppocr/postprocess/rec_postprocess.py#L128) 关于`DistillationCTCLabelDecode`更加具体的实现可以参考: [rec_postprocess.py](../../ppocr/postprocess/rec_postprocess.py#L128)
<a name="214"></a>
#### 2.1.4 指标计算 #### 2.1.4 指标计算
知识蒸馏任务中,指标计算配置如下所示。 知识蒸馏任务中,指标计算配置如下所示。
...@@ -245,7 +268,7 @@ Metric: ...@@ -245,7 +268,7 @@ Metric:
关于`DistillationMetric`更加具体的实现可以参考: [distillation_metric.py](../../ppocr/metrics/distillation_metric.py#L24) 关于`DistillationMetric`更加具体的实现可以参考: [distillation_metric.py](../../ppocr/metrics/distillation_metric.py#L24)
<a name="215"></a>
#### 2.1.5 蒸馏模型微调 #### 2.1.5 蒸馏模型微调
对蒸馏得到的识别蒸馏进行微调有2种方式。 对蒸馏得到的识别蒸馏进行微调有2种方式。
...@@ -279,15 +302,15 @@ paddle.save(s_params, "ch_PP-OCRv2_rec_train/student.pdparams") ...@@ -279,15 +302,15 @@ paddle.save(s_params, "ch_PP-OCRv2_rec_train/student.pdparams")
转化完成之后,使用[ch_PP-OCRv2_rec.yml](../../configs/rec/ch_PP-OCRv2/ch_PP-OCRv2_rec.yml),修改预训练模型的路径(为导出的`student.pdparams`模型路径)以及自己的数据路径,即可进行模型微调。 转化完成之后,使用[ch_PP-OCRv2_rec.yml](../../configs/rec/ch_PP-OCRv2/ch_PP-OCRv2_rec.yml),修改预训练模型的路径(为导出的`student.pdparams`模型路径)以及自己的数据路径,即可进行模型微调。
<a name="22"></a>
### 2.2 检测配置文件解析 ### 2.2 检测配置文件解析
检测模型蒸馏的配置文件在PaddleOCR/configs/det/ch_PP-OCRv2/目录下,包含三个蒸馏配置文件: 检测模型蒸馏的配置文件在PaddleOCR/configs/det/ch_PP-OCRv2/目录下,包含三个蒸馏配置文件:
- ch_PP-OCRv2_det_cml.yml,采用cml蒸馏,采用一个大模型蒸馏两个小模型,且两个小模型互相学习的方法 - ch_PP-OCRv2_det_cml.yml,采用cml蒸馏,采用一个大模型蒸馏两个小模型,且两个小模型互相学习的方法
- ch_PP-OCRv2_det_dml.yml,采用DML的蒸馏,两个Student模型互蒸馏的方法 - ch_PP-OCRv2_det_dml.yml,采用DML的蒸馏,两个Student模型互蒸馏的方法
- ch_PP-OCRv2_det_distill.yml,采用Teacher大模型蒸馏小模型Student的方法 - ch_PP-OCRv2_det_distill.yml,采用Teacher大模型蒸馏小模型Student的方法
<a name="221"></a>
#### 2.2.1 模型结构 #### 2.2.1 模型结构
知识蒸馏任务中,模型结构配置如下所示: 知识蒸馏任务中,模型结构配置如下所示:
...@@ -419,7 +442,8 @@ Architecture: ...@@ -419,7 +442,8 @@ Architecture:
} }
``` ```
#### 2.1.2 损失函数 <a name="222"></a>
#### 2.2.2 损失函数
知识蒸馏任务中,检测ch_PP-OCRv2_det_distill.yml蒸馏损失函数配置如下所示。 知识蒸馏任务中,检测ch_PP-OCRv2_det_distill.yml蒸馏损失函数配置如下所示。
...@@ -484,8 +508,8 @@ Loss: ...@@ -484,8 +508,8 @@ Loss:
关于`DistillationDilaDBLoss`更加具体的实现可以参考: [distillation_loss.py](https://github.com/PaddlePaddle/PaddleOCR/blob/release%2F2.4/ppocr/losses/distillation_loss.py#L185)。关于`DistillationDBLoss`等蒸馏损失函数更加具体的实现可以参考[distillation_loss.py](https://github.com/PaddlePaddle/PaddleOCR/blob/04c44974b13163450dfb6bd2c327863f8a194b3c/ppocr/losses/distillation_loss.py?_pjax=%23js-repo-pjax-container%2C%20div%5Bitemtype%3D%22http%3A%2F%2Fschema.org%2FSoftwareSourceCode%22%5D%20main%2C%20%5Bdata-pjax-container%5D#L148) 关于`DistillationDilaDBLoss`更加具体的实现可以参考: [distillation_loss.py](https://github.com/PaddlePaddle/PaddleOCR/blob/release%2F2.4/ppocr/losses/distillation_loss.py#L185)。关于`DistillationDBLoss`等蒸馏损失函数更加具体的实现可以参考[distillation_loss.py](https://github.com/PaddlePaddle/PaddleOCR/blob/04c44974b13163450dfb6bd2c327863f8a194b3c/ppocr/losses/distillation_loss.py?_pjax=%23js-repo-pjax-container%2C%20div%5Bitemtype%3D%22http%3A%2F%2Fschema.org%2FSoftwareSourceCode%22%5D%20main%2C%20%5Bdata-pjax-container%5D#L148)
<a name="223"></a>
#### 2.1.3 后处理 #### 2.2.3 后处理
知识蒸馏任务中,检测蒸馏后处理配置如下所示。 知识蒸馏任务中,检测蒸馏后处理配置如下所示。
...@@ -503,8 +527,8 @@ PostProcess: ...@@ -503,8 +527,8 @@ PostProcess:
关于`DistillationDBPostProcess`更加具体的实现可以参考: [db_postprocess.py](../../ppocr/postprocess/db_postprocess.py#L195) 关于`DistillationDBPostProcess`更加具体的实现可以参考: [db_postprocess.py](../../ppocr/postprocess/db_postprocess.py#L195)
<a name="224"></a>
#### 2.1.4 蒸馏指标计算 #### 2.2.4 蒸馏指标计算
知识蒸馏任务中,检测蒸馏指标计算配置如下所示。 知识蒸馏任务中,检测蒸馏指标计算配置如下所示。
...@@ -518,15 +542,15 @@ Metric: ...@@ -518,15 +542,15 @@ Metric:
由于蒸馏需要包含多个网络,甚至多个Student网络,在计算指标的时候只需要计算一个Student网络的指标即可,`key`字段设置为`Student`则表示只计算`Student`网络的精度。 由于蒸馏需要包含多个网络,甚至多个Student网络,在计算指标的时候只需要计算一个Student网络的指标即可,`key`字段设置为`Student`则表示只计算`Student`网络的精度。
<a name="225"></a>
#### 2.1.5 检测蒸馏模型finetune #### 2.2.5 检测蒸馏模型finetune
检测蒸馏有三种方式: 检测蒸馏有三种方式:
- 采用ch_PP-OCRv2_det_distill.yml,Teacher模型设置为PaddleOCR提供的模型或者您训练好的大模型 - 采用ch_PP-OCRv2_det_distill.yml,Teacher模型设置为PaddleOCR提供的模型或者您训练好的大模型
- 采用ch_PP-OCRv2_det_cml.yml,采用cml蒸馏,同样Teacher模型设置为PaddleOCR提供的模型或者您训练好的大模型 - 采用ch_PP-OCRv2_det_cml.yml,采用cml蒸馏,同样Teacher模型设置为PaddleOCR提供的模型或者您训练好的大模型
- 采用ch_PP-OCRv2_det_dml.yml,采用DML的蒸馏,两个Student模型互蒸馏的方法,在PaddleOCR采用的数据集上大约有1.7%的精度提升。 - 采用ch_PP-OCRv2_det_dml.yml,采用DML的蒸馏,两个Student模型互蒸馏的方法,在PaddleOCR采用的数据集上大约有1.7%的精度提升。
在具体finetune时,需要在网络结构的`pretrained`参数中设置要加载的预训练模型。 在具体fine-tune时,需要在网络结构的`pretrained`参数中设置要加载的预训练模型。
在精度提升方面,cml的精度>dml的精度>distill蒸馏方法的精度。当数据量不足或者Teacher模型精度与Student精度相差不大的时候,这个结论或许会改变。 在精度提升方面,cml的精度>dml的精度>distill蒸馏方法的精度。当数据量不足或者Teacher模型精度与Student精度相差不大的时候,这个结论或许会改变。
......
...@@ -63,6 +63,17 @@ train_data/rec/train/word_002.jpg 用科技让复杂的世界更简单 ...@@ -63,6 +63,17 @@ train_data/rec/train/word_002.jpg 用科技让复杂的世界更简单
| ... | ...
``` ```
除上述单张图像为一行格式之外,PaddleOCR也支持对离线增广后的数据进行训练,为了防止相同样本在同一个batch中被多次采样,我们可以将相同标签对应的图片路径写在一行中,以列表的形式给出,在训练中,PaddleOCR会随机选择列表中的一张图片进行训练。对应地,标注文件的格式如下。
```
["11.jpg", "12.jpg"] 简单可依赖
["21.jpg", "22.jpg", "23.jpg"] 用科技让复杂的世界更简单
3.jpg ocr
```
上述示例标注文件中,"11.jpg"和"12.jpg"的标签相同,都是`简单可依赖`,在训练的时候,对于该行标注,会随机选择其中的一张图片进行训练。
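A minimal Python parsing sketch for the two label formats shown above (the tab separator between the image field and the label, and the file names, are assumptions made for illustration):

```python
import json
import random

def parse_label_line(line):
    """Split a label line into (image_paths, label); supports both formats shown above."""
    img_field, label = line.rstrip("\n").split("\t", 1)   # tab separator assumed
    if img_field.startswith("["):                         # list form: several images share one label
        img_paths = json.loads(img_field)
    else:                                                  # plain form: a single image path
        img_paths = [img_field]
    return img_paths, label

img_paths, label = parse_label_line('["11.jpg", "12.jpg"]\t简单可依赖')
print(random.choice(img_paths), label)   # one image is sampled at random for training
```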
- 测试集 - 测试集
同训练集类似,测试集也需要提供一个包含所有图片的文件夹(test)和一个rec_gt_test.txt,测试集的结构如下所示: 同训练集类似,测试集也需要提供一个包含所有图片的文件夹(test)和一个rec_gt_test.txt,测试集的结构如下所示:
......
...@@ -60,9 +60,9 @@ PaddleOCR非常欢迎社区贡献以PaddleOCR为核心的各种服务、部署 ...@@ -60,9 +60,9 @@ PaddleOCR非常欢迎社区贡献以PaddleOCR为核心的各种服务、部署
如果您在使用PaddleOCR时遇到了代码bug、功能不符合预期等问题,可以为PaddleOCR贡献您的修改,其中: 如果您在使用PaddleOCR时遇到了代码bug、功能不符合预期等问题,可以为PaddleOCR贡献您的修改,其中:
- Python代码规范可参考[附录1:Python代码规范](./code_and_doc.md/#附录1) - Python代码规范可参考[附录1:Python代码规范](./code_and_doc.md#附录1)
- 提交代码前请再三确认不会引入新的bug,并在PR中描述优化点。如果该PR解决了某个issue,请在PR中连接到该issue。所有的PR都应该遵守附录3中的[3.2.10 提交代码的一些约定。](./code_and_doc.md/#提交代码的一些约定) - 提交代码前请再三确认不会引入新的bug,并在PR中描述优化点。如果该PR解决了某个issue,请在PR中连接到该issue。所有的PR都应该遵守附录3中的[3.2.10 提交代码的一些约定。](./code_and_doc.md#提交代码的一些约定)
- 请在提交之前参考下方的[附录3:Pull Request说明](./code_and_doc.md#附录3)。如果您对git的提交流程不熟悉,同样可以参考附录3的3.2节。 - 请在提交之前参考下方的[附录3:Pull Request说明](./code_and_doc.md#附录3)。如果您对git的提交流程不熟悉,同样可以参考附录3的3.2节。
...@@ -70,7 +70,7 @@ PaddleOCR非常欢迎社区贡献以PaddleOCR为核心的各种服务、部署 ...@@ -70,7 +70,7 @@ PaddleOCR非常欢迎社区贡献以PaddleOCR为核心的各种服务、部署
### 2.3 文档优化 ### 2.3 文档优化
如果您在使用PaddleOCR时遇到了文档表述不清楚、描述缺失、链接失效等问题,可以为PaddleOCR贡献您的修改。文档书写规范请参考[附录2:文档规范](./code_and_doc.md/#附录2)**最后请在PR的题目中加上标签`【third-party】` , 在说明中@Evezerest,拥有此标签的PR将会被高优处理。** 如果您在使用PaddleOCR时遇到了文档表述不清楚、描述缺失、链接失效等问题,可以为PaddleOCR贡献您的修改。文档书写规范请参考[附录2:文档规范](./code_and_doc.md#附录2)**最后请在PR的题目中加上标签`【third-party】` , 在说明中@Evezerest,拥有此标签的PR将会被高优处理。**
## 3. 更多贡献机会 ## 3. 更多贡献机会
......
...@@ -9,7 +9,7 @@ ...@@ -9,7 +9,7 @@
- 2020.12.07 [FAQ](../../doc/doc_ch/FAQ.md)新增5个高频问题,总数124个,并且计划以后每周一都会更新,欢迎大家持续关注。 - 2020.12.07 [FAQ](../../doc/doc_ch/FAQ.md)新增5个高频问题,总数124个,并且计划以后每周一都会更新,欢迎大家持续关注。
- 2020.11.25 更新半自动标注工具[PPOCRLabel](../../PPOCRLabel/README_ch.md),辅助开发者高效完成标注任务,输出格式与PP-OCR训练任务完美衔接。 - 2020.11.25 更新半自动标注工具[PPOCRLabel](../../PPOCRLabel/README_ch.md),辅助开发者高效完成标注任务,输出格式与PP-OCR训练任务完美衔接。
- 2020.9.22 更新PP-OCR技术文章,https://arxiv.org/abs/2009.09941 - 2020.9.22 更新PP-OCR技术文章,https://arxiv.org/abs/2009.09941
- 2020.9.19 更新超轻量压缩ppocr_mobile_slim系列模型,整体模型3.5M(详见PP-OCR Pipline),适合在移动端部署使用。 - 2020.9.19 更新超轻量压缩ppocr_mobile_slim系列模型,整体模型3.5M(详见PP-OCR Pipeline),适合在移动端部署使用。
- 2020.9.17 更新超轻量ppocr_mobile系列和通用ppocr_server系列中英文ocr模型,媲美商业效果。 - 2020.9.17 更新超轻量ppocr_mobile系列和通用ppocr_server系列中英文ocr模型,媲美商业效果。
- 2020.9.17 更新[英文识别模型](./models_list.md#english-recognition-model)[多语种识别模型](./models_list.md#english-recognition-model),已支持`德语、法语、日语、韩语`,更多语种识别模型将持续更新。 - 2020.9.17 更新[英文识别模型](./models_list.md#english-recognition-model)[多语种识别模型](./models_list.md#english-recognition-model),已支持`德语、法语、日语、韩语`,更多语种识别模型将持续更新。
- 2020.8.26 更新OCR相关的84个常见问题及解答,具体参考[FAQ](./FAQ.md) - 2020.8.26 更新OCR相关的84个常见问题及解答,具体参考[FAQ](./FAQ.md)
......
## FAQ ## FAQ
1. **Prediction error: got an unexpected keyword argument 'gradient_clip'** 1. **Prediction error: got an unexpected keyword argument 'gradient_clip'**
The installed version of paddle is incorrect. Currently, this project only supports paddle1.7, which will be adapted to 1.8 in the near future. The installed version of paddle is incorrect. Currently, this project only supports Paddle 1.7, which will be adapted to 1.8 in the near future.
2. **Error when converting attention recognition model: KeyError: 'predict'** 2. **Error when converting attention recognition model: KeyError: 'predict'**
Solved. Please update to the latest version of the code. Solved. Please update to the latest version of the code.
...@@ -31,7 +31,7 @@ At present, PaddleOCR has opensourced two Chinese models, namely 8.6M ultra-ligh ...@@ -31,7 +31,7 @@ At present, PaddleOCR has opensourced two Chinese models, namely 8.6M ultra-ligh
|General Chinese OCR model|Resnet50_vd+Resnet34_vd|det_r50_vd_db.yml|rec_chinese_common_train.yml| |General Chinese OCR model|Resnet50_vd+Resnet34_vd|det_r50_vd_db.yml|rec_chinese_common_train.yml|
8. **Is there a plan to opensource a model that only recognizes numbers or only English + numbers?** 8. **Is there a plan to opensource a model that only recognizes numbers or only English + numbers?**
It is not planned to opensource numbers only, numbers + English only, or other vertical text models. Paddleocr has opensourced a variety of detection and recognition algorithms for customized training. The two Chinese models are also based on the training output of the open-source algorithm library. You can prepare the data according to the tutorial, choose the appropriate configuration file, train yourselves, and we believe that you can get good result. If you have any questions during the training, you are welcome to open issues or ask in the communication group. We will answer them in time. It is not planned to opensource numbers only, numbers + English only, or other vertical text models. PaddleOCR has opensourced a variety of detection and recognition algorithms for customized training. The two Chinese models are also based on the training output of the open-source algorithm library. You can prepare the data according to the tutorial, choose the appropriate configuration file, train yourselves, and we believe that you can get good result. If you have any questions during the training, you are welcome to open issues or ask in the communication group. We will answer them in time.
9. **What is the training data used by the open-source model? Can it be opensourced?** 9. **What is the training data used by the open-source model? Can it be opensourced?**
At present, the open source model, dataset and magnitude are as follows: At present, the open source model, dataset and magnitude are as follows:
...@@ -46,11 +46,11 @@ At present, the open source model, dataset and magnitude are as follows: ...@@ -46,11 +46,11 @@ At present, the open source model, dataset and magnitude are as follows:
10. **Error in using the model with TPS module for prediction** 10. **Error in using the model with TPS module for prediction**
Error message: Input(X) dims[3] and Input(Grid) dims[2] should be equal, but received X dimension[3]\(108) != Grid dimension[2]\(100) Error message: Input(X) dims[3] and Input(Grid) dims[2] should be equal, but received X dimension[3]\(108) != Grid dimension[2]\(100)
SolutionTPS does not support variable shape. Please set --rec_image_shape='3,32,100' and --rec_char_type='en' Solution: TPS does not support variable shape. Please set --rec_image_shape='3,32,100' and --rec_char_type='en'
11. **Custom dictionary used during training, the recognition results show that words do not appear in the dictionary** 11. **Custom dictionary used during training, the recognition results show that words do not appear in the dictionary**
The used custom dictionary path is not set when making prediction. The solution is setting parameter `rec_char_dict_path` to the corresponding dictionary file. The used custom dictionary path is not set when making prediction. The solution is setting parameter `rec_char_dict_path` to the corresponding dictionary file.
12. **Results of cpp_infer and python_inference are very different** 12. **Results of cpp_infer and python_inference are very different**
Versions of exprted inference model and inference libraray should be same. For example, on Windows platform, version of the inference libraray that PaddlePaddle provides is 1.8, but version of the inference model that PaddleOCR provides is 1.7, you should export model yourself(`tools/export_model.py`) on PaddlePaddle1.8 and then use the exported model for inference. The versions of the exported inference model and the inference library should be the same. For example, on the Windows platform, the version of the inference library that PaddlePaddle provides is 1.8, but the version of the inference model that PaddleOCR provides is 1.7, so you should export the model yourself (`tools/export_model.py`) on PaddlePaddle 1.8 and then use the exported model for inference.
...@@ -30,8 +30,8 @@ On the ICDAR2015 dataset, the text detection result is as follows: ...@@ -30,8 +30,8 @@ On the ICDAR2015 dataset, the text detection result is as follows:
|Model|Backbone|Precision|Recall|Hmean|Download link| |Model|Backbone|Precision|Recall|Hmean|Download link|
| --- | --- | --- | --- | --- | --- | | --- | --- | --- | --- | --- | --- |
|EAST|ResNet50_vd|85.80%|86.71%|86.25%|[trained model](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/det_r50_vd_east_v2.0_train.tar)| |EAST|ResNet50_vd|88.71%|81.36%|84.88%|[trained model](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/det_r50_vd_east_v2.0_train.tar)|
|EAST|MobileNetV3|79.42%|80.64%|80.03%|[trained model](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/det_mv3_east_v2.0_train.tar)| |EAST|MobileNetV3|78.2%|79.1%|78.65%|[trained model](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/det_mv3_east_v2.0_train.tar)|
|DB|ResNet50_vd|86.41%|78.72%|82.38%|[trained model](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/det_r50_vd_db_v2.0_train.tar)| |DB|ResNet50_vd|86.41%|78.72%|82.38%|[trained model](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/det_r50_vd_db_v2.0_train.tar)|
|DB|MobileNetV3|77.29%|73.08%|75.12%|[trained model](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/det_mv3_db_v2.0_train.tar)| |DB|MobileNetV3|77.29%|73.08%|75.12%|[trained model](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/det_mv3_db_v2.0_train.tar)|
|SAST|ResNet50_vd|91.39%|83.77%|87.42%|[trained model](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/det_r50_vd_sast_icdar15_v2.0_train.tar)| |SAST|ResNet50_vd|91.39%|83.77%|87.42%|[trained model](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/det_r50_vd_sast_icdar15_v2.0_train.tar)|
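> For reference, Hmean is the harmonic mean (F-measure) of precision and recall: taking the updated EAST + ResNet50_vd numbers above, 2 × 88.71% × 81.36% / (88.71% + 81.36%) ≈ 84.88%.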
...@@ -67,20 +67,20 @@ Refer to [DTRB](https://arxiv.org/abs/1904.01906), the training and evaluation r ...@@ -67,20 +67,20 @@ Refer to [DTRB](https://arxiv.org/abs/1904.01906), the training and evaluation r
|Model|Backbone|Avg Accuracy|Module combination|Download link| |Model|Backbone|Avg Accuracy|Module combination|Download link|
|---|---|---|---|---| |---|---|---|---|---|
|Rosetta|Resnet34_vd|80.9%|rec_r34_vd_none_none_ctc|[trained model](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/rec_r34_vd_none_none_ctc_v2.0_train.tar)| |Rosetta|Resnet34_vd|79.11%|rec_r34_vd_none_none_ctc|[trained model](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/rec_r34_vd_none_none_ctc_v2.0_train.tar)|
|Rosetta|MobileNetV3|78.05%|rec_mv3_none_none_ctc|[trained model](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/rec_mv3_none_none_ctc_v2.0_train.tar)| |Rosetta|MobileNetV3|75.80%|rec_mv3_none_none_ctc|[trained model](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/rec_mv3_none_none_ctc_v2.0_train.tar)|
|CRNN|Resnet34_vd|82.76%|rec_r34_vd_none_bilstm_ctc|[trained model](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/rec_r34_vd_none_bilstm_ctc_v2.0_train.tar)| |CRNN|Resnet34_vd|81.04%|rec_r34_vd_none_bilstm_ctc|[trained model](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/rec_r34_vd_none_bilstm_ctc_v2.0_train.tar)|
|CRNN|MobileNetV3|79.97%|rec_mv3_none_bilstm_ctc|[trained model](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/rec_mv3_none_bilstm_ctc_v2.0_train.tar)| |CRNN|MobileNetV3|77.95%|rec_mv3_none_bilstm_ctc|[trained model](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/rec_mv3_none_bilstm_ctc_v2.0_train.tar)|
|StarNet|Resnet34_vd|84.44%|rec_r34_vd_tps_bilstm_ctc|[trained model](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/rec_r34_vd_tps_bilstm_ctc_v2.0_train.tar)| |StarNet|Resnet34_vd|82.85%|rec_r34_vd_tps_bilstm_ctc|[trained model](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/rec_r34_vd_tps_bilstm_ctc_v2.0_train.tar)|
|StarNet|MobileNetV3|81.42%|rec_mv3_tps_bilstm_ctc|[trained model](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/rec_mv3_tps_bilstm_ctc_v2.0_train.tar)| |StarNet|MobileNetV3|79.28%|rec_mv3_tps_bilstm_ctc|[trained model](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/rec_mv3_tps_bilstm_ctc_v2.0_train.tar)|
|RARE|MobileNetV3|82.5%|rec_mv3_tps_bilstm_att |[trained model](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/rec_mv3_tps_bilstm_att_v2.0_train.tar)| |RARE|Resnet34_vd|83.98%|rec_r34_vd_tps_bilstm_att |[trained model](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/rec_r34_vd_tps_bilstm_att_v2.0_train.tar)|
|RARE|Resnet34_vd|83.6%|rec_r34_vd_tps_bilstm_att |[trained model](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/rec_r34_vd_tps_bilstm_att_v2.0_train.tar)| |RARE|MobileNetV3|81.76%|rec_mv3_tps_bilstm_att |[trained model](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/rec_mv3_tps_bilstm_att_v2.0_train.tar)|
|SRN|Resnet50_vd_fpn| 88.52% | rec_r50fpn_vd_none_srn |[trained model](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/rec_r50_vd_srn_train.tar)| |SRN|Resnet50_vd_fpn| 86.31% | rec_r50fpn_vd_none_srn |[trained model](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/rec_r50_vd_srn_train.tar)|
|NRTR|NRTR_MTB| 84.3% | rec_mtb_nrtr | [trained model](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/rec_mtb_nrtr_train.tar) | |NRTR|NRTR_MTB| 84.21% | rec_mtb_nrtr | [trained model](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/rec_mtb_nrtr_train.tar) |
|SAR|Resnet31| 87.2% | rec_r31_sar | [trained model](https://paddleocr.bj.bcebos.com/dygraph_v2.1/rec/rec_r31_sar_train.tar) | |SAR|Resnet31| 87.20% | rec_r31_sar | [trained model](https://paddleocr.bj.bcebos.com/dygraph_v2.1/rec/rec_r31_sar_train.tar) |
|SEED|Aster_Resnet| 85.2% | rec_resnet_stn_bilstm_att | [trained model](https://paddleocr.bj.bcebos.com/dygraph_v2.1/rec/rec_resnet_stn_bilstm_att.tar) | |SEED|Aster_Resnet| 85.35% | rec_resnet_stn_bilstm_att | [trained model](https://paddleocr.bj.bcebos.com/dygraph_v2.1/rec/rec_resnet_stn_bilstm_att.tar) |
Please refer to the document for the training guide and usage of PaddleOCR Please refer to the document for the training guide and usage of PaddleOCR
## 2. Training ## 2. Training
......
...@@ -20,7 +20,7 @@ File -> New ->New Project to create "Native C++" project ...@@ -20,7 +20,7 @@ File -> New ->New Project to create "Native C++" project
**Add a proxy:** **Add a proxy:**
Android Studio -> Perferences -> Appearance & Behavior -> System Settings -> HTTP Proxy -> Manual proxy configuration Android Studio -> Preferences -> Appearance & Behavior -> System Settings -> HTTP Proxy -> Manual proxy configuration
![](../demo/proxy.png) ![](../demo/proxy.png)
......
...@@ -92,7 +92,7 @@ python3 -m paddle.distributed.launch --gpus '0,1,2,3,4,5,6,7' tools/train.py -c ...@@ -92,7 +92,7 @@ python3 -m paddle.distributed.launch --gpus '0,1,2,3,4,5,6,7' tools/train.py -c
PaddleOCR provides a variety of data augmentation methods. If you want to add disturbance during training, Please uncomment the `RecAug` and `RandAugment` fields under `Train.dataset.transforms` in the configuration file. PaddleOCR provides a variety of data augmentation methods. If you want to add disturbance during training, Please uncomment the `RecAug` and `RandAugment` fields under `Train.dataset.transforms` in the configuration file.
The default perturbation methods are: cvtColor, blur, jitter, Gasuss noise, random crop, perspective, color reverse, RandAugment. The default perturbation methods are: cvtColor, blur, jitter, Gauss noise, random crop, perspective, color reverse, RandAugment.
Except for RandAugment, each disturbance method is selected with a 50% probability during the training process. For specific code implementation, please refer to: Except for RandAugment, each disturbance method is selected with a 50% probability during the training process. For specific code implementation, please refer to:
[rec_img_aug.py](../../ppocr/data/imaug/rec_img_aug.py) [rec_img_aug.py](../../ppocr/data/imaug/rec_img_aug.py)
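A schematic sketch of the selection logic described above (each perturbation applied independently with 50% probability); the operator names are placeholders and the real implementation lives in rec_img_aug.py. As noted above, RandAugment is handled separately rather than with this per-op probability.

```python
import random

def apply_rec_aug(img, ops, prob=0.5):
    """Schematic only: apply each augmentation op independently with probability `prob`."""
    for op in ops:            # e.g. [cvt_color, blur, jitter, add_gauss_noise, random_crop, ...]
        if random.random() < prob:
            img = op(img)
    return img
```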
......
- Appendix
This appendix contains the Python code specification, the document specification and the Pull Request process. Please follow the relevant content.
- [Appendix 1: Python Code Specification](#Appendix1)
- [Appendix 2: Document Specification](#Appendix2)
- [Appendix 3: Pull Request Description](#Appendix3)
<a name="Appendix1"></a>
## Appendix 1: Python Code Specification
The Python code of PaddleOCR follows the [PEP8 Specification](https://www.python.org/dev/peps/pep-0008/); some of the key points are as follows
- Space
- Spaces should be added after commas, semicolons, colons, not before them
```python
# correct:
print(x, y)
# wrong:
print(x , y)
```
- When specifying a keyword argument or a default parameter value in a function, do not use spaces around the `=` sign
```python
# correct:
def complex(real, imag=0.0)
# wrong:
def complex(real, imag = 0.0)
```
- Comments
- Inline comments: inline comments are indicated by the `#` sign. Two spaces should be left between the code and `#`, and one space should be left between `#` and the comment, for example
```python
x = x + 1 # Compensate for border
```
- Functions and methods: The definition of each function should include the following:
- Function description: Utility, input and output of function
- Args: Name and description of each parameter
- Returns: The meaning and type of the return value
```python
def fetch_bigtable_rows(big_table, keys, other_silly_variable=None):
"""Fetches rows from a Bigtable.
Retrieves rows pertaining to the given keys from the Table instance
represented by big_table. Silly things may happen if
other_silly_variable is not None.
Args:
big_table: An open Bigtable Table instance.
keys: A sequence of strings representing the key of each table row
to fetch.
other_silly_variable: Another optional variable, that has a much
longer name than the other args, and which does nothing.
Returns:
A dict mapping keys to the corresponding table row data
fetched. Each row is represented as a tuple of strings. For
example:
{'Serak': ('Rigel VII', 'Preparer'),
'Zim': ('Irk', 'Invader'),
'Lrrr': ('Omicron Persei 8', 'Emperor')}
If a key from the keys argument is missing from the dictionary,
then that row was not found in the table.
"""
pass
```
<a name="Appendix2"></a>
## Appendix 2: Document Specification
### 2.1 Overall Description
- Document Location: If you add new features to an existing Markdown file, please **do not re-create** a new file. If you don't know where to add it, you can submit the code PR first and then ask the official maintainers in the PR.
- New Markdown Document Name: Describe the content of the document in English, typically a combination of lowercase letters and underscores, such as `add_New_Algorithm.md`
- New Markdown Document Format: Table of contents - Body - FAQ
> To generate the table of contents, you can use [this site](https://ecotrust-canada.github.io/markdown-toc/) to automatically extract the headings after copying the Markdown contents, and then add `<a name='XXXX'></a>` before each heading of the Markdown file
- English and Chinese: Any changes or additions to the document need to be made in both Chinese and English documents.
### 2.2 Format Specification
- Title format: Document titles follow the format of an Arabic numeral/decimal-point combination, a space, and then the title (for example, `2.1 XXXX`, `2.XXXX`)
- Code block: Show code that needs to be run in code-block format, and describe the meaning of the command parameters before the code block. For example:
> Pipeline of detection + direction classification + recognition: Vertical text can be recognized after setting the direction classifier parameter `--use_angle_cls true`.
>
> ```
> paddleocr --image_dir ./imgs/11.jpg --use_angle_cls true
> ```
- Variable References: If code variables or command parameters are referenced inline, they need to be written as inline code, for example `--use_angle_cls true` above, with one space before it and one space after it
- Uniform naming: e.g. PP-OCRv2, PP-OCR mobile, `paddleocr` whl package, PPOCRLabel, Paddle Lite, etc.
- Supplementary notes: Add supplementary notes using the quote format `>`.
- Picture: If a picture is added to the description document, specify the naming of the picture (describing its content) and add the picture under `doc/`.
- Title: Capitalize the first letter of each word in the title.
<a name="Appendix3"></a>
## Appendix 3: Pull Request Description
### 3.1 PaddleOCR Branch Description
PaddleOCR will maintain two kinds of branches in the future:
- release/x.x family branches: stable release branches, which also serve as the default branch. PaddleOCR creates a new release branch based on feature updates and adapts it to the corresponding release version of Paddle. As versions iterate, there will be more and more release/x.x branches; by default the latest release branch is the one maintained.
- dygraph branch: the development branch, adapted to the dygraph (dynamic graph) version of Paddle and mainly used to develop new functionality. If you need to do further development, choose the dygraph branch. To ensure that a release/x.x branch can be pulled out of the dygraph branch when needed, the code in the dygraph branch can only use APIs that are available in the latest release branch of Paddle. That is, if a new API has been developed in the Paddle dygraph branch but has not yet appeared in the release branch code, do not use it in PaddleOCR. In addition, performance optimization, parameter tuning and policy updates that do not involve APIs can be developed normally.
The historical branches of PaddleOCR will no longer be actively maintained in the future. Considering that some users may still be using them, these branches will be kept for now:
- develop branch: This branch was used for the development and testing of the static graph version and is currently compatible with Paddle version >=1.7. If you have special needs, you can still use this branch to accommodate older versions of Paddle, but apart from bug fixes the code will no longer be updated.
PaddleOCR welcomes you to actively contribute code to the repo. Here are some basic processes for contributing code.
### 3.2 PaddleOCR Code Submission Process And Specification
> If you are familiar with Git use, you can jump directly to [Some Conventions For Submitting Code in 3.2.10](#Some_conventions_for_submitting_code)
#### 3.2.1 Create Your `Remote Repo`
- On the PaddleOCR [GitHub home page](https://github.com/PaddlePaddle/PaddleOCR), click the `Fork` button to create a `remote repo` in your personal directory, such as `https://github.com/{your_name}/PaddleOCR`.
![banner](../banner.png)
- Clone `Remote repo`
```
# pull code of develop branch
git clone https://github.com/{your_name}/PaddleOCR.git -b dygraph
cd PaddleOCR
```
> Clone failures are mostly due to network reasons, try again later or configure the proxy
#### 3.2.2 Login And Connect Using Token
Start by viewing the information for the current `remote repo`.
```
git remote -v
# origin https://github.com/{your_name}/PaddleOCR.git (fetch)
# origin https://github.com/{your_name}/PaddleOCR.git (push)
```
Only the information of the cloned `remote repo`, i.e. the PaddleOCR under your username, is shown. Due to the change in GitHub's login method, you need to reconfigure the `remote repo` address by means of a token. The token is generated as follows:
1. Find Personal Access Tokens: Click on your avatar in the upper right corner of the Github page and choose Settings --> Developer settings --> Personal access tokens,
2. Click Generate new token: Fill in the token name in Note, such as 'paddle'. In Select scopes, select repo (required), admin:repo_hook, delete_repo, etc. You can check them according to your needs. Then click Generate token to generate the token, and finally copy the generated token.
Delete the original origin configuration
```
git remote rm origin
```
Change the remote branch to `https://oauth2:{token}@github.com/{your_name}/PaddleOCR.git`. For example, if the token value is 12345 and your user name is PPOCR, run the following command
```
git remote add origin https://oauth2:12345@github.com/PPOCR/PaddleOCR.git
```
This establishes a connection to our own `remote repo`. Next we create a remote host of the original PaddleOCR repo, named upstream.
```
git remote add upstream https://github.com/PaddlePaddle/PaddleOCR.git
```
Use `git remote -v` to view the current `remote repo` information. The output is as follows and includes two origin entries and two upstream entries for the `remote repo`.
```
origin https://github.com/{your_name}/PaddleOCR.git (fetch)
origin https://github.com/{your_name}/PaddleOCR.git (push)
upstream https://github.com/PaddlePaddle/PaddleOCR.git (fetch)
upstream https://github.com/PaddlePaddle/PaddleOCR.git (push)
```
This is mainly to keep the local repository up to date when subsequent pull request (PR) submissions are made.
#### 3.2.3 Create Local Branch
First get the latest code of upstream, then create a new_branch branch based on the dygraph branch of the upstream repo (upstream).
```
git fetch upstream
git checkout -b new_branch upstream/dygraph
```
> For a newly forked PaddleOCR project, the user's remote repo (origin) has the same branches as the upstream repository (upstream), so you can also create a new local branch based on the default branch of the origin repo or a specified branch with the following command
>
> ```
> # Create new_branch branch on user remote repo (origin) based on develop branch
> git checkout -b new_branch origin/develop
> # Create new_branch branch based on upstream remote repo develop branch
> # If you need to create a new branch from upstream,
> # you need to first use git fetch upstream to get upstream code
> git checkout -b new_branch upstream/develop
> ```
The final switch to the new branch is displayed with the following output information.
```
Branch new_branch set up to track remote branch develop from upstream.
Switched to a new branch 'new_branch'
```
After switching branches, file changes can be made on this branch
#### 3.2.4 Use Pre-Commit Hook
Paddle developers use the pre-commit tool to manage Git pre-commit hooks. It helps us format the source code (C++, Python) and automatically check for basic issues (such as each file ending with a single EOL, and not adding large files to Git) before committing.
The pre-commit test is part of the unit tests in Travis-CI. A PR that does not satisfy the hooks cannot be submitted to PaddleOCR. Install pre-commit first and run it in the current directory:
```
pip install pre-commit
pre-commit install
```
> 1. Paddle uses clang-format to adjust the C/C++ source code format. Make sure the `clang-format` version is above 3.8.
>
> 2. The yapf installed through `pip install pre-commit` is slightly different from the one installed through `conda install -c conda-forge pre-commit`; PaddleOCR developers use `pip install pre-commit`.
#### 3.2.5 Modify And Submit Code
If you make some changes to `README.md` in PaddleOCR, you can view the changed file through `git status`, and then stage the changed file using `git add`.
```
git status # View change files
git add README.md
pre-commit
```
Repeat these steps until the pre-commit format check passes without errors, as shown below.
![img](../precommit_pass.png)
Use the following command to complete the submission.
```
git commit -m "your commit info"
```
#### 3.2.6 Keep Local Repo Up To Date
Get the latest code for upstream and update the current branch. Here the upstream comes from section 2.2, `Connecting to a remote repo`.
```
git fetch upstream
# If you want to commit to another branch, you need to pull code from another branch of upstream, here is develop
git pull upstream develop
```
#### 3.2.7 Push To Remote Repo
```
git push origin new_branch
```
#### 3.2.7 Submit Pull Request
Click the new pull request to select the local branch and the target branch, as shown in the following figure. In the description of PR, fill in the functions completed by the PR. Next, wait for review, and if you need to modify something, update the corresponding branch in origin with the steps above.
![banner](../pr.png)
#### 3.2.8 Sign CLA Agreement And Pass Unit Tests
- Signing the CLA: When submitting a Pull Request to PaddlePaddle for the first time, you need to sign the CLA (Contributor License Agreement) to ensure that your code can be merged. The steps are as follows:
1. Please check the Check section in PR, find the license/cla, and click on the right detail to enter the CLA website
2. Click Sign in with GitHub to agree on the CLA website and when clicked, it will jump back to your Pull Request page
#### 3.2.9 Delete Branch
- Remove remote branch
After the PR is merged into the main repo, we can delete the branch of the remote repo from the PR page.
You can also use `git push origin :branch_name` to delete a remote branch, for example:
```
git push origin :new_branch
```
- Delete local branch
```
# Switch to the development branch, otherwise the current branch cannot be deleted
git checkout develop
# Delete the new_branch branch
git branch -D new_branch
```
<a name="Some_conventions_for_submitting_code"></a>
#### 3.2.10 Some Conventions For Submitting Code
In order for official maintainers to better focus on the code itself when reviewing it, please follow these conventions each time you submit your code:
1) Please ensure that the unit tests in Travis-CI pass. If not, it indicates that there is a problem with the submitted code, and the official maintainers generally will not review it.
2) Before submitting a Pull Request:
- Note the number of commits.
Reason: If you only modify one file but submit more than a dozen commits, each commit only makes a few modifications, which can be very confusing to the reviewer. The reviewer needs to look at each commit individually to see what changes have been made, and the changes between commits may even overlap each other.
Suggestion: Keep as few commits as possible each time you submit, and fold follow-up changes into the last commit with `git commit --amend` (see the sketch after this list). For multiple commits that have already been pushed to the remote repo, you can refer to [squash commits after push](https://stackoverflow.com/questions/5667884/how-to-squash-commits-in-git-after-they-have-been-pushed).
- Note the name of each commit: it should reflect the content of the current commit, not be too arbitrary.
3) If you have solved an issue, add `fix #issue_number` in the first comment box of the Pull Request. This will automatically close the corresponding issue when the Pull Request is merged. Keywords include close, closes, closed, fix, fixes, fixed, resolve, resolves and resolved; please choose the right one. For details, refer to [Closing issues via commit messages](https://help.github.com/articles/closing-issues-via-commit-messages).
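A sketch of the commands typically used to keep the commit count small; the branch name and file here are just the examples used earlier in this guide:
```
# Fold a small follow-up fix into the previous commit instead of creating a new one
git add README.md
git commit --amend --no-edit

# Squash the last three local commits into one (mark the later commits as "squash" or "fixup")
git rebase -i HEAD~3

# If the squashed commits were already pushed, the rewritten branch must be force-pushed
git push -f origin new_branch
```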
In addition, in response to the reviewer's comments, you are requested to abide by the following conventions:
1) Every review comment from an official maintainer expects a response, which helps the open source community collaborate better.
- If you agree with the review comment and have modified your code accordingly, simply reply "Done".
- If you disagree with the review comment, please explain your reasons.
2) If there are many review comments:
- Please give an overview of the changes.
- Please reply using `Start a review` rather than replying to each comment directly; each direct reply sends an e-mail, which can cause a mail flood.
# COMMUNITY CONTRIBUTION
Thank you for your support and interest in PaddleOCR. The goal of PaddleOCR is to build a professional, harmonious and supportive open source community with developers. This document presents existing community contributions, explanations for various contributions, and new opportunities and processes to make the contribution process more efficient and clear.
PaddleOCR wants to help any developer with a dream realize their vision and enjoy the joy of creating value through the power of AI.
---
<a href="https://github.com/PaddlePaddle/PaddleOCR/graphs/contributors">
<img src="https://contrib.rocks/image?repo=PaddlePaddle/PaddleOCR" />
</a>
> The picture above shows PaddleOCR's current contributors, and it is updated regularly
## 1. COMMUNITY CONTRIBUTION
### 1.1 PaddleOCR BASED COMMUNITY PROJECT
- [The latest] [FastOCRLabel](https://gitee.com/BaoJianQiang/FastOCRLabel): Complete C# version annotation tool (@ [包建强](https://gitee.com/BaoJianQiang))
#### 1.1.1 UNIVERSAL TOOL
- [DangoOCR offline version](https://github.com/PantsuDango/DangoOCR):Universal desktop instant translation tool (@ [PantsuDango](https://github.com/PantsuDango))
- [scr2txt](https://github.com/lstwzd/scr2txt):Screenshot to Text tool (@ [lstwzd](https://github.com/lstwzd))
- [AI Studio project](https://aistudio.baidu.com/aistudio/projectdetail/1054614?channelType=0&channel=0):English video automatically generates subtitles( @ [叶月水狐](https://aistudio.baidu.com/aistudio/personalcenter/thirdview/322052))
#### 1.1.2 VERTICAL SCENE TOOLS
- [id_card_ocr](https://github.com/baseli/id_card_ocr):Identification of copy of ID card(@ [baseli](https://github.com/baseli))
- [Paddle_Table_Image_Reader](https://github.com/thunder95/Paddle_Table_Image_Reader): A data assistant that can read tables and pictures (@ [thunder95](https://github.com/thunder95))
#### 1.1.3 PRE AND POST PROCESSING
- [paddleOCRCorrectOutputs](https://github.com/yuranusduke/paddleOCRCorrectOutputs):Get the key-value of OCR recognition result (@ [yuranusduke](https://github.com/yuranusduke))
### 1.2 NEW FEATURES FOR PaddleOCR
- Thanks [authorfu](https://github.com/authorfu) for contributing the Android demo ([#340](https://github.com/PaddlePaddle/PaddleOCR/pull/340)) and [xiadeye](https://github.com/xiadeye) for contributing the iOS demo code ([#325](https://github.com/PaddlePaddle/PaddleOCR/pull/325)).
- Thanks [tangmq](https://gitee.com/tangmq) for adding a docker deployment service to PaddleOCR to support quick release of callable restful API services ([#507](https://github.com/PaddlePaddle/PaddleOCR/pull/507)).
- Thanks [lijinhan](https://github.com/lijinhan) for adding a Java Spring Boot demo to PaddleOCR that calls the OCR hubserving interface to complete OCR service deployment ([#1027](https://github.com/PaddlePaddle/PaddleOCR/pull/1027)).
- Thanks [Evezerest](https://github.com/Evezerest), [ninetailskim](https://github.com/ninetailskim), [edencfc](https://github.com/edencfc), [BeyondYourself](https://github.com/BeyondYourself), [1084667371](https://github.com/1084667371) for contributing complete code of [PPOCRLabel](https://github.com/PaddlePaddle/PaddleOCR/blob/release/2.3/PPOCRLabel/README_ch.md).
### 1.3 CODE AND DOCUMENT OPTIMIZATION
- Thanks [zhangxin](https://github.com/ZhangXinNan) ([Blog](https://blog.csdn.net/sdlypyzq)) for contributing new visualization methods, adding .gitignore, and handling the problem of manually setting the PYTHONPATH environment variable ([#210](https://github.com/PaddlePaddle/PaddleOCR/pull/210)).
- Thanks [lyl120117](https://github.com/lyl120117) for contributing code to print the network structure ([#304](https://github.com/PaddlePaddle/PaddleOCR/pull/304)).
- Thanks [BeyondYourself](https://github.com/BeyondYourself) for making a lot of great suggestions for PaddleOCR and simplifying some of the code style of paddleocr ([so many commits](https://github.com/PaddlePaddle/PaddleOCR/commits?author=BeyondYourself)).
- Thanks [Khanh Tran](https://github.com/xxxpsyduck) and [Karl Horky](https://github.com/karlhorky) for contributing modifications to the English documents.
### 1.4 MULTILINGUAL CORPUS
- Thanks [xiangyubo](https://github.com/xiangyubo) for contributing a handwritten Chinese OCR dataset ([#321](https://github.com/PaddlePaddle/PaddleOCR/pull/321)).
- Thanks [Mejans](https://github.com/Mejans) for contributing dictionary and corpus of the new language Occitan to PaddleOCR([#954](https://github.com/PaddlePaddle/PaddleOCR/pull/954)).
## 2. CONTRIBUTION ILLUSTRATING
### 2.1 NEW FUNCTION CLASS
PaddleOCR welcomes community contributions of various services, deployment examples and software applications built with PaddleOCR at their core. Certified community contributions will be added to the community contribution table above to increase exposure for developers, which is also an honor for PaddleOCR, including:
- Project form: project code certified by the official community should have good specifications and structure, and should come with a detailed README.md describing how to use the project. By adding the line 'paddleocr' to requirements.txt, the project can be automatically included in the Used By list of paddleocr.
- Integration method: if it is an update to an existing PaddleOCR tool, it will be integrated into the main repo. If it extends PaddleOCR with new functionality, please contact the official team first to confirm whether the project will be integrated into the main repo. *Even if the new functionality is not integrated into the main repo, we will still increase the exposure of your personal project as a community contribution.*
### 2.2 CODE OPTIMIZATION
If you encounter code bugs or unexpected behavior when using PaddleOCR, you can contribute your fixes to PaddleOCR, including:
- Python code specifications are available for reference in [Appendix 1: Python code specifications](./code_and_doc.md/#Appendix1).
- Before submitting the code, please double-check that no new bugs are introduced, and describe the optimization points in the PR. If the PR solves an issue, please link the issue in the PR. All PRs shall comply with the requirements in Appendix [3.2.10 Some conventions for submitting code](./code_and_doc.md/#Some_conventions_for_submitting_code).
- Please refer to the following before submitting. If you are not familiar with the git submission process, you can also refer to Section 3.2 of [Appendix 3: description of Pull Request](./code_and_doc.md/#Appendix3).
**Finally, please add the label `[third-party]` in the title of the PR and @Evezerest in the description; PRs with this label will be treated with high priority.**
### 2.3 DOCUMENT OPTIMIZATION
If you encounter problems such as unclear descriptions, missing descriptions or invalid links in the documentation when using PaddleOCR, you can contribute your modifications to PaddleOCR. For document writing specifications, please refer to [Appendix 2: document specifications](./code_and_doc.md/#Appendix2). **Finally, please add the label `[third-party]` in the title of the PR and @Evezerest in the description; PRs with this label will be treated with high priority.**
## 3. MORE CONTRIBUTION OPPORTUNITIES
We encourage developers to use PaddleOCR to realize their ideas. At the same time, we also list some valuable development directions identified through analysis, which are collected together in the regular season of community projects.
## 4. CONTACT US
We very much welcome developers to contact us before they intend to contribute code, documents, corpus and other contents to PaddleOCR, which can greatly reduce the communication cost in the PR process. At the same time, if you find some ideas difficult to realize personally, we can also recruit like-minded developers for the project in the form of SIG. Projects funded through SIG channels will receive deep R&D support and operational resources (such as official account publicity, live broadcast lessons, etc.).
Our recommended contribution process is:
- Add the `[Third Party]` tag in the title of a GitHub issue, explain the problem encountered (and the ideas to solve it) or the function to be added, and wait for a reply from the person on duty. For example, `[Third Party] Contribute iOS examples to PaddleOCR`.
- After communicating with us and confirming that the technical scheme or bugs and optimization points are correct, add functions or modify them accordingly, and the codes and documents shall comply with relevant specifications.
- Submit a PR that links to the above issue and wait for review.
## 5. THANKS AND FOLLOW-UP
- After the code is merged, your information will be updated in the first section of this document; the default link is your GitHub name and homepage. If you need a different homepage linked, you can also contact us.
- New important function classes will be advertised in the user group and enjoy the honor of the open source community.
- **If you have a PaddleOCR based project that does not appear in the above list, follow `4. CONTACT US` .**
# Configuration
- [1. Optional Parameter List](#1-optional-parameter-list)
- [2. Introduction to Global Parameters of Configuration File](#2-introduction-to-global-parameters-of-configuration-file)
- [3. Multilingual Config File Generation](#3-multilingual-config-file-generation)
<a name="1-optional-parameter-list"></a>
The following list can be viewed through `--help`
| -c | ALL | Specify configuration file to use | None | **Please refer to the parameter introduction for configuration file usage** |
| -o | ALL | set configuration options | None | Configuration using -o has higher priority than the configuration file selected with -c. E.g: -o Global.use_gpu=false |
<a name="2-introduction-to-global-parameters-of-configuration-file"></a>
## 2. Introduction to Global Parameters of Configuration File
Take rec_chinese_lite_train_v2.0.yml as an example
### Global
| print_batch_step | Set print log interval | 10 | \ |
| save_model_dir | Set model save path | output/{algorithm name} | \ |
| save_epoch_step | Set model save interval | 3 | \ |
| eval_batch_step | Set the model evaluation interval | 2000 or [1000, 2000] | running evaluation every 2000 iters, or evaluation is run every 2000 iterations after the 1000th iteration |
| cal_metric_during_train | Set whether to evaluate the metric during the training process. At this time, the metric of the model under the current batch is evaluated | true | \ |
| load_static_weights | Set whether the pre-training model is saved in static graph mode (currently only required by the detection algorithm) | true | \ |
| pretrained_model | Set the path of the pre-trained model | ./pretrain_models/CRNN/best_accuracy | \ |
In PaddleOCR, the network is divided into four stages: Transform, Backbone, Neck and Head.
| Parameter | Use | Defaults | Note |
| :---------------------: | :---------------------: | :--------------: | :--------------------: |
| model_type | Network Type | rec | Currently support `rec`, `det`, `cls` |
| algorithm | Model name | CRNN | See [algorithm_overview](./algorithm_overview_en.md) for the support list |
| **Transform** | Set the transformation method | - | Currently only recognition algorithms are supported, see [ppocr/modeling/transform](../../ppocr/modeling/transform) for details |
| name | Transformation class name | TPS | Currently supports `TPS` |
| num_fiducial | Number of TPS control points | 20 | Ten on the top and bottom |
## 3. Multilingual Config File Generation
PaddleOCR currently supports recognition for 80 languages (besides Chinese). A multi-language configuration file template is
provided under the path `configs/rec/multi_languages`: [rec_multi_language_lite_train.yml](../../configs/rec/multi_language/rec_multi_language_lite_train.yml)
There are two ways to create the required configuration file:
1. Automatically generated by script
Script [generate_multi_language_configs.py](../../configs/rec/multi_language/generate_multi_language_configs.py) can help you generate configuration files for multi-language models.
- Take Italian as an example. If your data has been prepared in the required format, run the generation script for the target language.
Italian is made up of Latin letters, so after executing the command, you will get the generated configuration file:
```
epoch_num: 500
...
character_dict_path: {path/of/dict} # path of dict

Train:
  dataset:
    name: SimpleDataSet
    data_dir: train_data/ # root directory of training data
    label_file_list: ["./train_data/train_list.txt"] # train label path
  ...

Eval:
  dataset:
    name: SimpleDataSet
    data_dir: train_data/ # root directory of val data
    label_file_list: ["./train_data/val_list.txt"] # val label path
  ...
```
......
PaddleOCR provides a concatenation tool for detection and recognition models, which can connect any trained detection model and any recognition model into a two-stage text recognition system. The input image goes through four main stages: text detection, text rectification, text recognition, and score filtering to output the text position and recognition results, and at the same time, you can choose to visualize the results.
When performing prediction, you need to specify the path of a single image or an image folder through the parameter `image_dir`, the parameter `det_model_dir` specifies the path of the detection model, and the parameter `rec_model_dir` specifies the path of the recognition model. The visualized results are saved to the `./inference_results` folder by default.
```
python3 tools/infer/predict_system.py --image_dir="./doc/imgs/11.jpg" --det_model_dir="./inference/det/" --rec_model_dir="./inference/rec/"
```
......
This section uses the icdar2015 dataset as an example to introduce the training, evaluation, and testing of the detection model in PaddleOCR.
- [1. Data and Weights Preparation](#1-data-and-weights-preparatio)
* [1.1 Data Preparation](#11-data-preparation)
* [1.2 Download Pre-trained Model](#12-download-pretrained-model)
- [2. Training](#2-training)
* [2.1 Start Training](#21-start-training)
* [2.2 Load Trained Model and Continue Training](#22-load-trained-model-and-continue-training)
After decompressing the dataset and downloading the annotation file, PaddleOCR/train_data/ contains the icdar2015 images and annotation files, for example:
```
└─ test_icdar2015_label.txt      Test annotation of icdar dataset
```
The provided annotation file format is as follows, separated by "\t":
```
" Image file name             Image annotation information encoded by json.dumps"
ch4_test_images/img_61.jpg    [{"transcription": "MASA", "points": [[310, 104], [416, 141], [418, 216], [312, 179]]}, {...}]
```
The `points` in the dictionary represent the coordinates (x, y) of the four points of the text box, arranged clockwise starting from the upper-left corner.
If you want to train PaddleOCR on other datasets, please build the annotation file according to the above format.
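As a quick check of a custom annotation file, the format above can be parsed with a few lines of Python; a minimal sketch (the file path below is only an example):
```
import json

# Each line: "<image file name>\t<annotation list encoded by json.dumps>"
with open("./train_data/test_icdar2015_label.txt", encoding="utf-8") as f:
    for line in f:
        image_name, annotation = line.rstrip("\n").split("\t", 1)
        for box in json.loads(annotation):
            print(image_name, box["transcription"], box["points"])
```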
### 1.2 Download Pre-trained Model
First download the pre-trained model. The detection model of PaddleOCR currently supports 3 backbones, namely MobileNetV3, ResNet18_vd and ResNet50_vd. You can use the models in [PaddleClas](https://github.com/PaddlePaddle/PaddleClas/tree/release/2.0/ppcls/modeling/architectures) to replace the backbone according to your needs.
The corresponding download links of the backbone pre-trained weights can be found at (https://github.com/PaddlePaddle/PaddleClas/blob/release%2F2.0/README_cn.md#resnet%E5%8F%8A%E5%85%B6vd%E7%B3%BB%E5%88%97).
```shell
cd PaddleOCR/
# Download the pre-trained model of MobileNetV3
wget -P ./pretrain_models/ https://paddleocr.bj.bcebos.com/pretrained/MobileNetV3_large_x0_5_pretrained.pdparams
# or, download the pre-trained model of ResNet18_vd
wget -P ./pretrain_models/ https://paddleocr.bj.bcebos.com/pretrained/ResNet18_vd_pretrained.pdparams
# or, download the pre-trained model of ResNet50_vd
wget -P ./pretrain_models/ https://paddleocr.bj.bcebos.com/pretrained/ResNet50_vd_ssld_pretrained.pdparams
```
......
## Introduction
The high performance of distributed training is one of the core advantages of PaddlePaddle. In classification tasks, distributed training can achieve an almost linear speedup ratio. OCR training tasks generally need massive training data; for example, the PP-OCR v2.0 recognition model is trained on an 18-million-sample dataset, which is very time-consuming on a single machine. Therefore, distributed training is used in PaddleOCR to speed up the training task. For more information about distributed training, please refer to the [distributed training quick start tutorial](https://fleet-x.readthedocs.io/en/latest/paddle_fleet_rst/parameter_server/ps_quick_start.html).
## Quick Start
**Notice:**
* The IP addresses of different machines need to be separated by commas, and can be queried through `ifconfig` or `ipconfig`.
* Different machines need to be configured for password-free (SSH) login and must be able to `ping` each other directly, otherwise communication cannot be established between them.
* The code, data and start command on different machines must be completely consistent, and the start command then needs to be run on all machines. The first machine in the `ip_list` is set to `trainer0`, and so on. A sketch of a typical start command is shown after this list.
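The sketch below assumes two machines; the IP addresses, GPU ids and configuration file are placeholders, and the flag names should be verified against the `paddle.distributed.launch` version in use:
```
# Run the same command on every machine listed in --ips
python3 -m paddle.distributed.launch \
    --ips="192.168.0.1,192.168.0.2" \
    --gpus="0,1,2,3" \
    tools/train.py \
    -c configs/rec/rec_mv3_none_bilstm_ctc.yml
```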
## Performance comparison
......
# Enhanced CTC Loss
In OCR recognition, CRNN is a text recognition algorithm widely applied in industry. In the training phase, it uses CTCLoss to calculate the network loss; in the inference phase, it uses CTCDecode to obtain the decoding result. Although the CRNN algorithm has been proven to achieve reliable recognition results in real business scenarios, users' requirements for recognition accuracy keep growing. So how can the accuracy of text recognition be further improved? Taking CTCLoss as the starting point, this document explores improved fusion schemes for CTCLoss from three different perspectives: Hard Example Mining, Multi-task Learning, and Metric Learning. Based on this exploration, we propose EnhancedCTCLoss, which includes the following 3 components: Focal-CTC Loss, A-CTC Loss, C-CTC Loss.
## 1. Focal-CTC Loss
Focal Loss was proposed in the paper "[Focal Loss for Dense Object Detection](https://arxiv.org/abs/1708.02002)". When the loss was first proposed, it was mainly intended to solve the serious imbalance between positive and negative samples in one-stage object detection. This loss function reduces the weight of the large number of easy negative samples in training, and can also be understood as a kind of hard example mining.
The form of the loss function is as follows:
<div align="center">
<img src="./focal_loss_formula.png" width = "600" />
</div>
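For reference, the focal loss from the paper takes the following standard form (written out here for readability; &alpha; is the balance factor and &gamma; the focusing parameter):
```
FL(y') = -\alpha \, (1 - y')^{\gamma} \, \log(y')
```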
Among them, y' is the output of the activation function, and the value is between 0-1. It adds a modulation factor (1-y’)^&gamma; and a balance factor &alpha; on the basis of the original cross-entropy loss. When &alpha; = 1, y = 1, the comparison between the loss function and the cross-entropy loss is shown in the following figure:
<div align="center">
<img src="./focal_loss_image.png" width = "600" />
</div>
As can be seen from the above figure, when &gamma; > 0, the adjustment coefficient (1-y')^&gamma; gives smaller weight to the easy-to-classify sample loss, making the network pay more attention to difficult and misclassified samples. The adjustment factor &gamma; is used to adjust the rate at which the weight of simple samples decreases. When &gamma; = 0, it is the cross-entropy loss function; when &gamma; increases, the influence of the adjustment factor also increases. Experiments revealed that 2 is the optimal value of &gamma;. The balance factor &alpha; is used to balance the uneven proportions of the positive and negative samples. In the paper, &alpha; is taken as 0.25.
For the classic CTC algorithm, suppose that for a certain feature sequence (f<sub>1</sub>, f<sub>2</sub>, ......f<sub>t</sub>), the probability that the CTC decoding result equals the label is y'; then the probability that the CTC decoding result does not equal the label is (1-y'). It is not difficult to find that the CTCLoss value and y' have the following relationship:
<div align="center">
<img src="./equation_ctcloss.png" width = "250" />
</div>
Combining the idea of Focal Loss, assigning larger weights to difficult samples and smaller weights to simple samples can make the network focus more on the mining of difficult samples and further improve the accuracy of recognition. Therefore, we propose Focal-CTC Loss. Its definition is as follows:
<div align="center">
<img src="./equation_focal_ctc.png" width = "500" />
</div>
In the experiments, &gamma; = 2 and &alpha; = 1; see [rec_ctc_loss.py](../../ppocr/losses/rec_ctc_loss.py) for the specific implementation.
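The following is a minimal sketch of the idea, not the repository implementation; only the weighting logic is shown, and it assumes the un-reduced (per-sample) CTC loss has already been computed by the framework:
```
import paddle

def focal_ctc(ctc_loss_per_sample, alpha=1.0, gamma=2.0):
    # CTC loss = -log(y'), so the probability of a fully correct decode is y' = exp(-loss)
    p = paddle.exp(-ctc_loss_per_sample)
    # Down-weight easy samples (large y') and emphasize hard ones, as in Focal Loss
    weight = alpha * paddle.pow(1.0 - p, gamma)
    return (weight * ctc_loss_per_sample).mean()
```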
## 2. A-CTC Loss
A-CTC Loss is short for CTC Loss + ACE Loss. Among them, ACE Loss was proposed by the paper, “[Aggregation Cross-Entropy for Sequence Recognition](https://arxiv.org/abs/1904.08364)”. Compared with CTCLoss, ACE Loss has the following two advantages:
+ ACE Loss can solve the recognition problem of 2-D text, while CTCLoss can only process 1-D text
+ ACE Loss is better than CTC loss in time complexity and space complexity
The advantages and disadvantages of OCR recognition algorithms summarized in previous work are shown in the following figure:
<div align="center">
<img src="./rec_algo_compare.png" width = "1000" />
</div>
Although ACELoss does handle 2D predictions, as shown in the figure above, and has advantages in memory usage and inference speed, in practice, we found that using ACELoss alone, the recognition effect is not as good as CTCLoss. Consequently, we tried to combine CTCLoss and ACELoss, and CTCLoss is the mainstay while ACELoss acts as an auxiliary supervision loss. This attempt has achieved better results. On our internal experimental data set, compared to using CTCLoss alone, the recognition accuracy can be improved by about 1%.
A-CTC Loss is defined as follows:
<div align="center">
<img src="./equation_a_ctc.png" width = "300" />
</div>
In the experiment, λ = 0.1. See the ACE loss implementation code: [ace_loss.py](../../ppocr/losses/ace_loss.py)
## 3. C-CTC Loss
C-CTC Loss is short for CTC Loss + Center Loss. Among them, Center Loss was proposed by the paper, “[A Discriminative Feature Learning Approach for Deep Face Recognition](https://link.springer.com/chapter/10.1007/978-3-319-46478-7_31)“. It was first used in face recognition tasks to increase the distance between classes and reduce the distance within classes. It is an earlier and also widely used algorithm.
In the task of Chinese OCR recognition, analysis of bad cases shows that a major difficulty in Chinese recognition is that there are many similar characters, which are easy to confuse. From this, we considered whether we could borrow the idea of Metric Learning to increase the inter-class distance between similar characters and thus improve recognition accuracy. However, Metric Learning is mainly used in the field of image recognition, where the label of the training data is a single fixed value; OCR recognition, by contrast, is essentially a sequence recognition task, and there is no explicit alignment between features and labels. Therefore, how to combine the two is still a direction worth exploring.
By trying Arcmargin, Cosmargin and other methods, we finally found that Center Loss can help further improve the accuracy of recognition. C-CTC Loss is defined as follows:
<div align="center">
<img src="./equation_c_ctc.png" width = "300" />
</div>
In the experiment, we set λ=0.25. See the center_loss implementation code: [center_loss.py](../../ppocr/losses/center_loss.py)
It is worth mentioning that in C-CTC Loss, choosing to initialize the Center randomly does not bring significant improvement. Our Center initialization method is as follows:
+ Based on the original CTCLoss, a network N is obtained by training
+ Select the training set, identify the completely correct part, and form the set G
+ Send each sample in G to the network, perform a forward calculation, and extract the correspondence between the input of the last FC layer (i.e., the feature) and the result of the argmax calculation (i.e., the index)
+ Aggregate the features with the same index, calculate their average, and use it as the initial center of each character (a sketch is shown below).
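A minimal sketch of this initialization procedure (pseudo-structure only; the actual extraction is performed by `tools/export_center.py`, and `forward_with_features` is an assumed helper, not a real API):
```
import numpy as np
from collections import defaultdict

def init_centers(correct_samples, network):
    """correct_samples: images recognized completely correctly; network: model trained with plain CTCLoss."""
    feats_by_index = defaultdict(list)
    for image in correct_samples:
        # forward pass: input of the last FC layer (feature) and its argmax index per time step
        features, logits = network.forward_with_features(image)  # assumed helper
        indices = logits.argmax(axis=-1)
        for feat, idx in zip(features, indices):
            feats_by_index[int(idx)].append(feat)
    # the initial center of each character is the mean feature of its index
    return {idx: np.mean(np.stack(feats), axis=0) for idx, feats in feats_by_index.items()}
```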
Taking the configuration file `configs/rec/ch_PP-OCRv2/ch_PP-OCRv2_rec.yml` as an example, the center extraction command is as follows:
```
python tools/export_center.py -c configs/rec/ch_PP-OCRv2/ch_PP-OCRv2_rec.yml -o Global.pretrained_model="./output/rec_mobile_pp-OCRv2/best_accuracy"
```
After running, `train_center.pkl` will be generated in the main directory of PaddleOCR.
## 4. Experiment
For the above three solutions, we conducted training and evaluation based on Baidu's internal dataset. The experimental results are shown in the following table:
| algorithm | Focal-CTC | A-CTC | C-CTC |
| :-------- | :-------- | ----: | :---: |
| gain | +0.3% | +0.7% | +1.7% |
Based on the above experimental conclusions, we adopted the C-CTC strategy in PP-OCRv2. It is worth mentioning that, because PP-OCRv2 deals with the recognition task of 6625 Chinese characters, the character set is relatively large and there are many similar characters, so the C-CTC solution brings a significant improvement on this task. But if you switch to other OCR recognition tasks, the conclusion may be different. You can try Focal-CTC, A-CTC, C-CTC, and the combined solution EnhancedCTC. We believe it will bring different degrees of improvement.
The unified combined plan is shown in the following file: [rec_enhanced_ctc_loss.py](../../ppocr/losses/rec_enhanced_ctc_loss.py)
Windows and Mac users are recommended to use Anaconda to build a Python environment.
Recommended working environment:
- PaddlePaddle >= 2.0.0 (2.1.2)
- Python 3.7
- CUDA 10.1 / CUDA 10.2
- cuDNN 7.6
* [1. Python Environment Setup](#1)
+ [1.1 Windows](#1.1)
#### 1.1.1 Install Anaconda
- Note: To use PaddlePaddle you need to install a Python environment first; here we choose the Python integrated environment Anaconda toolkit
- Anaconda is a common Python package manager
- After installing Anaconda, you can install the Python environment, as well as NumPy and other required toolkits
<img src="../install/windows/anaconda_install_folder.png" alt="install config" width="500" align=" left"/> <img src="../install/windows/anaconda_install_folder.png" alt="install config" width="500" align=" left"/>
- Check conda to add environment variables and ignore the warning that - Check Conda to add environment variables and ignore the warning that
<img src="../install/windows/anaconda_install_env.png" alt="add conda to path" width="500" align="center"/> <img src="../install/windows/anaconda_install_env.png" alt="add conda to path" width="500" align="center"/>
#### 1.1.2 Opening the terminal and creating the conda environment #### 1.1.2 Opening the terminal and creating the Conda environment
- Open Anaconda Prompt terminal: bottom left Windows Start Menu -> Anaconda3 -> Anaconda Prompt start console - Open Anaconda Prompt terminal: bottom left Windows Start Menu -> Anaconda3 -> Anaconda Prompt start console
<img src="../install/windows/anaconda_prompt.png" alt="anaconda download" width="300" align="center"/> <img src="../install/windows/anaconda_prompt.png" alt="anaconda download" width="300" align="center"/>
- Create a new conda environment - Create a new Conda environment
- Enter the `conda create` command at the command line to create an environment named paddle_env (see the screenshot below)
<img src="../install/windows/conda_new_env.png" alt="conda create" width="700" align="center"/> <img src="../install/windows/conda_new_env.png" alt="conda create" width="700" align="center"/>
- To activate the conda environment you just created, enter the following command at the command line. - To activate the Conda environment you just created, enter the following command at the command line.
```shell
# Activate the paddle_env environment
conda activate paddle_env
```
The above Anaconda environment and Python environment are installed.
#### 1.2.1 Installing Anaconda
- Note: To use PaddlePaddle you need to install a Python environment first; here we choose the Python integrated environment Anaconda toolkit
- Anaconda is a common Python package manager
- After installing Anaconda, you can install the Python environment, as well as NumPy and other required toolkits
- Just follow the default settings; it will take a while to install
- It is recommended to install a code editor such as VSCode or PyCharm
#### 1.2.2 Open a terminal and create a Conda environment
- Open the terminal
- Press Command and the spacebar at the same time, type "terminal" in Spotlight search, and double-click to open the terminal
- **Add Conda to the environment variables**
- Environment variables are added so that the system can recognize the Conda command
- Open `~/.bash_profile` in the terminal by typing the following command.
```
vim ~/.bash_profile
```
- Add Conda as an environment variable in `~/.bash_profile`.
- Press `i` first to enter edit mode, then add the Conda environment variable line to the file
- When you are done, press `esc` to exit edit mode, then type `:wq!` and enter to save and exit
- Verify that the Conda command is recognized.
- Enter `source ~/.bash_profile` in the terminal to update the environment variables
- Enter `conda info --envs` in the terminal again, if it shows that there is a base environment, then Conda has been added to the environment variables
- Create a new Conda environment
- Enter the `conda create` command at the command line to create an environment called paddle_env (see the screenshot below)
- <img src="../install/mac/conda_create.png" alt="conda_create" width="600" align="center"/> - <img src="../install/mac/conda_create.png" alt="conda_create" width="600" align="center"/>
- To activate the conda environment you just created, enter the following command at the command line. - To activate the Conda environment you just created, enter the following command at the command line.
```shell
# Activate the paddle_env environment
conda activate paddle_env
```
Linux users can choose to run either Anaconda or Docker.
#### 1.3.1 Anaconda environment configuration
- Note: To use PaddlePaddle you need to install a Python environment first; here we choose the Python integrated environment Anaconda toolkit
- Anaconda is a common Python package manager
- After installing Anaconda, you can install the Python environment, as well as NumPy and other required toolkits
- Select the appropriate version for your operating system
- Type `uname -m` in the terminal to check the command set used by your system
- Download method 1: Download locally, then transfer the installation package to the Linux server
- Download method 2: Directly use the Linux command line to download
- First install `wget`, then use it to download the installation package from the command line
- When you are done, press `esc` to exit edit mode, then type `:wq!` and enter to save and exit
- Verify that the Conda command is recognized.
- Enter `source ~/.bash_profile` in the terminal to update the environment variables
- Enter `conda info --envs` in the terminal again, if it shows that there is a base environment, then Conda has been added to the environment variables
- Create a new Conda environment
- Enter the `conda create` command at the command line to create an environment called paddle_env (see the screenshot below)
<img src="../install/linux/conda_create.png" alt="conda_create" width="500" align="center"/> <img src="../install/linux/conda_create.png" alt="conda_create" width="500" align="center"/>
- To activate the conda environment you just created, enter the following command at the command line. - To activate the Conda environment you just created, enter the following command at the command line.
```shell
# Activate the paddle_env environment
conda activate paddle_env
```
## 2. Install PaddlePaddle 2.0
- If you have CUDA 9 or CUDA 10 installed on your machine, please run the following command to install
```bash
python3 -m pip install paddlepaddle-gpu -i https://mirror.baidu.com/pypi/simple
```
- If you have no available GPU on your machine, please run the following command to install the CPU version
```bash
python3 -m pip install paddlepaddle -i https://mirror.baidu.com/pypi/simple
```
......
```
tar xf ch_ppocr_mobile_v2.0_det_infer.tar
python3 tools/infer/predict_det.py --image_dir="./doc/imgs/00018069.jpg" --det_model_dir="./inference/det_db/"
```
The visual text detection results are saved to the ./inference_results folder by default, and the name of the result file is prefixed with 'det_res'. Examples of results are as follows:
![](../imgs_results/det_res_00018069.jpg)
<a name="RECOGNITION_MODEL_INFERENCE"></a> <a name="RECOGNITION_MODEL_INFERENCE"></a>
## 3. Text Recognition Model Inference ## 3. Text Recognition Model Inference
The following will introduce the lightweight Chinese recognition model inference, other CTC-based and Attention-based text recognition models inference. For Chinese text recognition, it is recommended to choose the recognition model based on CTC loss. In practice, it is also found that the result of the model based on Attention loss is not as good as the one based on CTC loss. In addition, if the characters dictionary is modified during training, make sure that you use the same characters set during inferencing. Please check below for details. The following will introduce the lightweight Chinese recognition model inference, other CTC-based and Attention-based text recognition models inference. For Chinese text recognition, it is recommended to choose the recognition model based on CTC loss. In practice, it is also found that the result of the model based on Attention loss is not as good as the one based on CTC loss. In addition, if the characters dictionary is modified during training, make sure that you use the same characters set during inference. Please check below for details.
<a name="LIGHTWEIGHT_RECOGNITION"></a> <a name="LIGHTWEIGHT_RECOGNITION"></a>
......
This article introduces the use of the Python inference engine for the PP-OCR model library.
- [Text Detection Model Inference](#DETECTION_MODEL_INFERENCE)
- [Text Recognition Model Inference](#RECOGNITION_MODEL_INFERENCE)
- [1. Lightweight Chinese Recognition Model Inference](#LIGHTWEIGHT_RECOGNITION)
- [2. Multilingual Model Inference](#MULTILINGUAL_MODEL_INFERENCE)
- [Angle Classification Model Inference](#ANGLE_CLASS_MODEL_INFERENCE)
- [Text Detection Angle Classification and Recognition Inference Concatenation](#CONCATENATION)
```
tar xf ch_PP-OCRv2_det_infer.tar
python3 tools/infer/predict_det.py --image_dir="./doc/imgs/00018069.jpg" --det_model_dir="./ch_PP-OCRv2_det_infer/"
```
The visual text detection results are saved to the ./inference_results folder by default, and the name of the result file is prefixed with 'det_res'. Examples of results are as follows:
![](../imgs_results/det_res_00018069.jpg)
<a name="MULTILINGUAL_MODEL_INFERENCE"></a> <a name="MULTILINGUAL_MODEL_INFERENCE"></a>
### 2. Multilingaul Model Inference ### 2. Multilingual Model Inference
If you need to predict [other language models](./models_list_en.md#Multilingual), when using inference model prediction, you need to specify the dictionary path used by `--rec_char_dict_path`. At the same time, in order to get the correct visualization results, If you need to predict [other language models](./models_list_en.md#Multilingual), when using inference model prediction, you need to specify the dictionary path used by `--rec_char_dict_path`. At the same time, in order to get the correct visualization results,
You need to specify the visual font path through `--vis_font_path`. There are small language fonts provided by default under the `doc/fonts` path, such as Korean recognition: You need to specify the visual font path through `--vis_font_path`. There are small language fonts provided by default under the `doc/fonts` path, such as Korean recognition:
......
## QUICK INSTALLATION
After testing, PaddleOCR can run on glibc 2.23. You can also test other glibc versions or install glibc 2.23 for the best compatibility.
PaddleOCR working environment:
- PaddlePaddle 2.0.0
- Python 3.7
- glibc 2.23
It is recommended to use the docker provided by us to run PaddleOCR. Please refer to the docker tutorial [link](https://www.runoob.com/docker/docker-tutorial.html/).
*If you want to directly run the prediction code on Mac or Windows, you can start from step 2.*
**1. (Recommended) Prepare a docker environment. The first time you use this docker image, it will be downloaded automatically. Please be patient.**
```
# Switch to the working directory
cd /home/Projects
sudo docker run --name ppocr -v $PWD:/paddle --network=host -it paddlepaddle/paddle:latest-dev-cuda10.1-cudnn7-gcc82 /bin/bash
```
With CUDA10, please run the following command to create a container.
It is recommended to set a shared memory greater than or equal to 32G through the --shm-size parameter:
``` ```
sudo nvidia-docker run --name ppocr -v $PWD:/paddle --shm-size=64G --network=host -it paddlepaddle/paddle:latest-dev-cuda10.1-cudnn7-gcc82 /bin/bash sudo nvidia-docker run --name ppocr -v $PWD:/paddle --shm-size=64G --network=host -it paddlepaddle/paddle:latest-dev-cuda10.1-cudnn7-gcc82 /bin/bash
...@@ -51,11 +51,11 @@ For more software version requirements, please refer to the instructions in [Ins
# Recommend
git clone https://github.com/PaddlePaddle/PaddleOCR
# If you cannot pull successfully due to network problems, you can switch to the mirror hosted on Gitee:
git clone https://gitee.com/paddlepaddle/PaddleOCR
# Note: The mirror on Gitee may not stay in sync with the latest updates of the project on GitHub; there might be a delay of 3-5 days. Please try GitHub first.
```
**4. Install third-party libraries**
...@@ -66,6 +66,6 @@ pip3 install -r requirements.txt
If you get the error `OSError: [WinError 126] The specified module could not be found` when installing shapely on Windows,
please try to download the Shapely whl file from [http://www.lfd.uci.edu/~gohlke/pythonlibs/#shapely](http://www.lfd.uci.edu/~gohlke/pythonlibs/#shapely).
Reference: [Solve shapely installation on windows](https://stackoverflow.com/questions/44398265/install-shapely-oserror-winerror-126-the-specified-module-could-not-be-found)
<a name="0"></a>
# Knowledge Distillation
+ [Knowledge Distillation](#0)
  + [1. Introduction](#1)
    - [1.1 Introduction to Knowledge Distillation](#11)
    - [1.2 Introduction to PaddleOCR Knowledge Distillation](#12)
  + [2. Configuration File Analysis](#2)
    - [2.1 Recognition Model Configuration File Analysis](#21)
      - [2.1.1 Model Structure](#211)
      - [2.1.2 Loss Function](#212)
      - [2.1.3 Post-processing](#213)
      - [2.1.4 Metric Calculation](#214)
      - [2.1.5 Fine-tuning Distillation Model](#215)
    - [2.2 Detection Model Configuration File Analysis](#22)
      - [2.2.1 Model Structure](#221)
      - [2.2.2 Loss Function](#222)
      - [2.2.3 Post-processing](#223)
      - [2.2.4 Metric Calculation](#224)
      - [2.2.5 Fine-tuning Distillation Model](#225)
<a name="1"></a>
## 1. Introduction
<a name="11"></a>
### 1.1 Introduction to Knowledge Distillation
In recent years, deep neural networks have been shown to be extremely effective for solving problems in computer vision and natural language processing.
By constructing a suitable network and training it, the performance of the resulting model generally surpasses that of traditional algorithms.
When the amount of data is large enough, increasing the number of parameters with a reasonably designed network can significantly improve model performance, but this sharply increases model complexity, and large models are expensive to use in real scenarios.
Deep neural networks usually contain a lot of parameter redundancy. There are several main methods to compress a model and reduce its parameter count, such as pruning, quantization and knowledge distillation.
Knowledge distillation uses a teacher model to guide a student model on a specific task, so that the small model obtains a large performance improvement while its parameter count stays unchanged.
In addition, a mutual-learning training method was derived from the knowledge distillation task.
The paper [Deep Mutual Learning](https://arxiv.org/abs/1706.00384) showed that letting two identical models supervise each other during training achieves better results than training a single model alone.
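To make the mutual-learning idea concrete, the following is a minimal sketch of a DML-style loss between two peer networks; the function and variable names are assumptions for illustration only, not PaddleOCR's implementation.
```python
import paddle
import paddle.nn.functional as F

def dml_loss(logits_a, logits_b):
    # Symmetric KL divergence between the softened predictions of two peer models,
    # so that each model is supervised by the other's output distribution.
    log_p_a = F.log_softmax(logits_a, axis=-1)
    log_p_b = F.log_softmax(logits_b, axis=-1)
    p_a = F.softmax(logits_a, axis=-1)
    p_b = F.softmax(logits_b, axis=-1)
    loss_ab = F.kl_div(log_p_a, p_b, reduction="mean")   # A learns from B
    loss_ba = F.kl_div(log_p_b, p_a, reduction="mean")   # B learns from A
    return (loss_ab + loss_ba) / 2
```
In DML each peer also keeps its normal supervised loss against the ground truth, and the term above is added on top of it.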
<a name="12"></a>
### 1.2 Introduction to PaddleOCR Knowledge Distillation
Whether a large model distills a small model, or small models learn from each other and update their parameters, the essence is the same: mutual supervision between the outputs or the feature maps of different models.
The only differences are (1) whether a model's parameters are fixed, and (2) whether a model needs to load a pre-trained checkpoint.
When a large model distills a small model, the large model generally loads a pre-trained checkpoint and its parameters are frozen.
When small models distill each other, they generally do not load pre-trained checkpoints and their parameters remain learnable.
Knowledge distillation is not limited to two models; multiple models can also learn from each other.
Therefore, the knowledge distillation code framework also needs to support this kind of distillation.
The knowledge distillation algorithm integrated in PaddleOCR has the following main features:
- It supports mutual learning between arbitrary networks; the sub-network structures do not need to be identical or to have pre-trained models, and there is no limit to the number of sub-networks: just add them in the configuration file.
- The loss function can be configured arbitrarily through the configuration file; a single loss or a combination of multiple losses can be used.
- All model-related stages such as training, prediction, evaluation and export support knowledge distillation, which is convenient for use and deployment.
Through knowledge distillation, on the common Chinese and English text recognition task, the accuracy of the model can be improved by more than 3% without adding any prediction cost.
Combining the learning rate schedule and the model structure fine-tuning strategy, the final improvement exceeds 5%.
<a name="2"></a>
## 2. Configuration File Analysis
During knowledge distillation training, data preprocessing, the optimizer, the learning rate and most global attributes stay the same;
only the configuration of the model structure, loss function, post-processing and metric calculation modules needs to be adjusted.
The following uses the recognition and detection distillation configuration files as examples to analyze the training and configuration of knowledge distillation.
<a name="21"></a>
### 2.1 Recognition Model Configuration File Analysis
The configuration file is in [ch_PP-OCRv2_rec_distillation.yml](../../configs/rec/ch_PP-OCRv2/ch_PP-OCRv2_rec_distillation.yml).
<a name="211"></a>
#### 2.1.1 Model Structure
In the knowledge distillation task, the model structure configuration is as follows.
```yaml
Architecture:
  model_type: &model_type "rec"    # Model category: recognition, detection, etc.
  name: DistillationModel          # Structure name; in the distillation task, it is DistillationModel
  algorithm: Distillation          # Algorithm name
  Models:                          # Models, containing the configuration of each sub-network
    Teacher:                       # Name of the sub-network; it must contain at least the `pretrained` and `freeze_params` parameters, and the remaining parameters are the construction parameters of the sub-network
      pretrained:                  # Whether this sub-network loads pre-trained weights
      freeze_params: false         # Whether the parameters are frozen
      return_all_feats: true       # Whether to return all features; if False, only the final output is returned
      model_type: *model_type      # Model category
      algorithm: CRNN              # Algorithm name of the sub-network; the remaining parameters are consistent with the general model training configuration
      Transform:
      Backbone:
        name: MobileNetV1Enhance
        scale: 0.5
      Neck:
        name: SequenceEncoder
        encoder_type: rnn
        hidden_size: 64
      Head:
        name: CTCHead
        mid_channels: 96
        fc_decay: 0.00002
    Student:                       # Another sub-network; this is a DML example in which the two sub-networks have the same structure and both learn parameters
      pretrained:                  # The following parameters are the same as above
      freeze_params: false
      return_all_feats: true
      model_type: *model_type
      algorithm: CRNN
      Transform:
      Backbone:
        name: MobileNetV1Enhance
        scale: 0.5
      Neck:
        name: SequenceEncoder
        encoder_type: rnn
        hidden_size: 64
      Head:
        name: CTCHead
        mid_channels: 96
        fc_decay: 0.00002
```
If you want to add more sub-networks for training, you can add the corresponding fields to the configuration file following the pattern of `Student` and `Teacher`.
For example, if you want three models to supervise each other and train together, `Architecture` can be written as follows.
```yaml
Architecture:
  model_type: &model_type "rec"
  name: DistillationModel
  algorithm: Distillation
  Models:
    Teacher:
      pretrained:
      freeze_params: false
      return_all_feats: true
      model_type: *model_type
      algorithm: CRNN
      Transform:
      Backbone:
        name: MobileNetV1Enhance
        scale: 0.5
      Neck:
        name: SequenceEncoder
        encoder_type: rnn
        hidden_size: 64
      Head:
        name: CTCHead
        mid_channels: 96
        fc_decay: 0.00002
    Student:
      pretrained:
      freeze_params: false
      return_all_feats: true
      model_type: *model_type
      algorithm: CRNN
      Transform:
      Backbone:
        name: MobileNetV1Enhance
        scale: 0.5
      Neck:
        name: SequenceEncoder
        encoder_type: rnn
        hidden_size: 64
      Head:
        name: CTCHead
        mid_channels: 96
        fc_decay: 0.00002
    Student2:                      # The new sub-network introduced for this distillation task; its configuration is the same as above
      pretrained:
      freeze_params: false
      return_all_feats: true
      model_type: *model_type
      algorithm: CRNN
      Transform:
      Backbone:
        name: MobileNetV1Enhance
        scale: 0.5
      Neck:
        name: SequenceEncoder
        encoder_type: rnn
        hidden_size: 64
      Head:
        name: CTCHead
        mid_channels: 96
        fc_decay: 0.00002
```
When the model is trained, it contains the three sub-networks `Teacher`, `Student` and `Student2`.
The specific implementation of the `DistillationModel` class can be found in [distillation_model.py](../../ppocr/modeling/architectures/distillation_model.py).
The final model output is a dictionary whose keys are the names of the sub-networks, for example `Student` and `Teacher` here, and whose values are the outputs of the corresponding sub-networks,
which can be a `Tensor` (only the output of the last layer is returned) or a `dict` (intermediate feature information is also returned).
In the recognition task, in order to add more loss functions and keep the distillation method extensible, the output of each sub-network is saved as a `dict` containing the sub-module outputs.
Taking the recognition model as an example, the output of each sub-network is a `dict` whose keys are `backbone_out`, `neck_out` and `head_out`, and whose values are the tensors of the corresponding modules. For the configuration file above, the output format of `DistillationModel` is as follows.
```json
{
  "Teacher": {
    "backbone_out": tensor,
    "neck_out": tensor,
    "head_out": tensor,
  },
  "Student": {
    "backbone_out": tensor,
    "neck_out": tensor,
    "head_out": tensor,
  }
}
```
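For intuition only, here is a minimal sketch of how a training step might read tensors out of this dictionary. The stand-in dict, variable names and tensor shapes below are assumptions for illustration, not PaddleOCR code.
```python
import paddle

# Hedged illustration: build a stand-in for the output dict shown above and pick
# the tensors a distillation training step would typically use. Shapes are made up.
preds = {
    name: {
        "backbone_out": paddle.randn([8, 64, 80]),
        "neck_out": paddle.randn([8, 80, 64]),
        "head_out": paddle.randn([8, 80, 6625]),
    }
    for name in ("Student", "Teacher")
}
student_head = preds["Student"]["head_out"]       # CTC loss against the ground truth
teacher_head = preds["Teacher"]["head_out"]       # DML loss with the student
student_feat = preds["Student"]["backbone_out"]   # feature-level distance loss
```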
<a name="212"></a>
#### 2.1.2 Loss Function
In the knowledge distillation task, the loss function configuration is as follows.
```yaml
Loss:
  name: CombinedLoss                     # Loss function name
  loss_config_list:                      # List of loss function configurations, a mandatory field for CombinedLoss
    - DistillationCTCLoss:               # CTC loss for distillation, inherited from the standard CTC loss
        weight: 1.0                      # Weight of the loss; every loss in loss_config_list must contain this field
        model_name_list: ["Student", "Teacher"]  # Take the outputs of these two sub-networks from the prediction dict and compute the CTC loss against the ground truth
        key: head_out                    # The tensor to take from the sub-network output dict
    - DistillationDMLLoss:               # DML loss, inherited from the standard DMLLoss
        weight: 1.0
        act: "softmax"                   # Activation applied to the inputs; can be softmax, sigmoid or None (default None)
        model_name_pairs:                # Pairs of sub-network names used to compute the DML loss; add more pairs to the list if needed
          - ["Student", "Teacher"]
        key: head_out
    - DistillationDistanceLoss:          # Distance loss for distillation
        weight: 1.0
        mode: "l2"                       # Supports l1, l2 or smooth_l1
        model_name_pairs:                # Pairs of sub-network names for the distance loss
          - ["Student", "Teacher"]
        key: backbone_out
```
All of the distillation loss functions above inherit from the corresponding standard loss classes.
Their main job is to parse the output of the distillation model, find the intermediate nodes (tensors) used to compute the loss,
and then let the standard loss class do the calculation.
Taking the above configuration as an example, the final distillation training loss contains the following three parts.
- The final output `head_out` of `Student` and `Teacher` calculates the CTC loss with gt (loss weight equals 1.0). Here, because both sub-networks need to update the parameters, both of them need to calculate the loss with gt.
- DML loss between `Student` and `Teacher`'s final output `head_out` (loss weight equals 1.0).
- L2 loss between `Student` and `Teacher`'s backbone network output `backbone_out` (loss weight equals 1.0).
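As a rough sketch of how such weighted terms are aggregated (assumed names; the actual logic is in the `CombinedLoss` implementation linked below):
```python
# Hedged sketch: each configured loss returns a dict of named scalar losses;
# the combined loss multiplies each by its weight and sums everything.
def combine_losses(weighted_loss_dicts):
    total = 0.0
    for weight, loss_dict in weighted_loss_dicts:
        for value in loss_dict.values():
            total = total + weight * value
    return total

# Example with plain floats standing in for tensors:
total_loss = combine_losses([
    (1.0, {"ctc_Student": 0.8, "ctc_Teacher": 0.7}),  # CTC losses against gt
    (1.0, {"dml_Student_Teacher": 0.1}),              # DML loss between head outputs
    (1.0, {"l2_backbone": 0.05}),                     # distance loss between backbones
])
```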
For more specific implementation of `CombinedLoss`, please refer to: [combined_loss.py](../../ppocr/losses/combined_loss.py#L23).
For more specific implementations of distillation loss functions such as `DistillationCTCLoss`, please refer to [distillation_loss.py](../../ppocr/losses/distillation_loss.py)
<a name="213"></a>
#### 2.1.3 Post-processing
In the knowledge distillation task, the post-processing configuration is as follows.
```yaml
PostProcess:
  name: DistillationCTCLabelDecode       # CTC decoding post-processing for the distillation task, inherited from the standard CTCLabelDecode class
  model_name: ["Student", "Teacher"]     # Take the outputs of these two sub-networks from the prediction dict and decode them
  key: head_out                          # The tensor to take from the sub-network output dict
```
Taking the above configuration as an example, the CTC decoding results of the two sub-networks `Student` and `Teacher` are computed at the same time.
In the returned result, the key is the name of the sub-network and the value is the list of decoded results of that sub-network.
For more specific implementation of `DistillationCTCLabelDecode`, please refer to: [rec_postprocess.py](../../ppocr/postprocess/rec_postprocess.py#L128)
<a name="214"></a>
#### 2.1.4 Metric Calculation
In the knowledge distillation task, the metric calculation configuration is as follows.
```yaml
Metric:
  name: DistillationMetric               # Metric calculation for the distillation task, inherited from the standard metric class
  base_metric_name: RecMetric            # Base metric class; the output of each sub-network is evaluated with this class
  main_indicator: acc                    # Name of the main metric
  key: "Student"                         # The sub-network whose main_indicator is used as the criterion for saving the best model
```
Taking the above configuration as an example, the accuracy metric of the `Student` subnet will be used as the judgment metric for saving the best model.
At the same time, the accuracy metric of all subnets will be printed out in the log.
For more specific implementation of `DistillationMetric`, please refer to: [distillation_metric.py](../../ppocr/metrics/distillation_metric.py#L24).
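A condensed sketch of this behaviour, using assumed names rather than the real `DistillationMetric` API:
```python
# Hedged sketch: compute a metric for every sub-network, but let only the one
# selected by `key` decide whether the current checkpoint is the best so far.
def evaluate_distillation(preds, labels, metric_fn, key="Student"):
    per_model = {name: metric_fn(out, labels) for name, out in preds.items()}
    print(per_model)            # all sub-network metrics are reported in the log
    return per_model[key]       # the saving criterion follows the `key` sub-network

acc = evaluate_distillation(
    preds={"Student": [1, 1, 0], "Teacher": [1, 1, 1]},
    labels=[1, 1, 1],
    metric_fn=lambda out, gt: sum(int(a == b) for a, b in zip(out, gt)) / len(gt),
)
```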
<a name="215"></a>
#### 2.1.5 Fine-tuning Distillation Model
There are two ways to fine-tune the recognition distillation task.
1. Fine-tuning with knowledge distillation: this case is relatively simple. Download the pre-trained model, then configure the pre-trained model path and your own data path in [ch_PP-OCRv2_rec_distillation.yml](../../configs/rec/ch_PP-OCRv2/ch_PP-OCRv2_rec_distillation.yml) to fine-tune the model.
2. Fine-tuning without knowledge distillation: in this case, you first need to extract the student model parameters from the pre-trained model. The specific steps are as follows.
- First download the pre-trained model and unzip it.
```shell
wget https://paddleocr.bj.bcebos.com/PP-OCRv2/chinese/ch_PP-OCRv2_rec_train.tar
tar -xf ch_PP-OCRv2_rec_train.tar
```
- Then use Python to extract the student model parameters:
```python
import paddle
# Load the pre-trained model
all_params = paddle.load("ch_PP-OCRv2_rec_train/best_accuracy.pdparams")
# View the keys of the weight parameter
print(all_params.keys())
# Weight extraction of student model
s_params = {key[len("Student."):]: all_params[key] for key in all_params if "Student." in key}
# View the keys of the weight parameters of the student model
print(s_params.keys())
# Save weight parameters
paddle.save(s_params, "ch_PP-OCRv2_rec_train/student.pdparams")
```
After the extraction is complete, use [ch_PP-OCRv2_rec.yml](../../configs/rec/ch_PP-OCRv2/ch_PP-OCRv2_rec.yml) to modify the path of the pre-trained model (the path of the exported `student.pdparams` model) and your own data path to fine-tune the model.
<a name="22"></a>
### 2.2 Detection Model Configuration File Analysis
The configuration files for detection model distillation are in the ```PaddleOCR/configs/det/ch_PP-OCRv2/``` directory, which contains three distillation configuration files:
- ```ch_PP-OCRv2_det_cml.yml```: one large model distills two small models, and the two small models also learn from each other
- ```ch_PP-OCRv2_det_dml.yml```: two student models distill each other
- ```ch_PP-OCRv2_det_distill.yml```: a large teacher model distills a small student model
<a name="221"></a>
#### 2.2.1 Model Structure
In the knowledge distillation task, the model structure configuration is as follows:
```yaml
Architecture:
  name: DistillationModel          # Structure name; in the distillation task, it is DistillationModel
  algorithm: Distillation          # Algorithm name
  Models:                          # Models, containing the configuration of each sub-network
    Student:                       # Name of the sub-network; it must contain at least the `pretrained` and `freeze_params` parameters, and the remaining parameters are the construction parameters of the sub-network
      pretrained: ./pretrain_models/MobileNetV3_large_x0_5_pretrained  # Pre-trained weights loaded by this sub-network
      freeze_params: false         # Whether the parameters are frozen
      return_all_feats: false      # Whether to return all features; if False, only the final output is returned
      model_type: det
      algorithm: DB
      Backbone:
        name: MobileNetV3
        scale: 0.5
        model_name: large
        disable_se: True
      Neck:
        name: DBFPN
        out_channels: 96
      Head:
        name: DBHead
        k: 50
    Teacher:                       # Another sub-network; this is an example of a large model distilling a small model
      pretrained: ./pretrain_models/ch_ppocr_server_v2.0_det_train/best_accuracy
      freeze_params: true          # The Teacher model is well trained and does not participate in training
      return_all_feats: false
      model_type: det
      algorithm: DB
      Transform:
      Backbone:
        name: ResNet
        layers: 18
      Neck:
        name: DBFPN
        out_channels: 256
      Head:
        name: DBHead
        k: 50
```
If DML is used, that is, two small models learn from each other, the Teacher network structure in the above configuration file needs to be set to the same configuration as the Student model.
For details, refer to the configuration file [ch_PP-OCRv2_det_dml.yml](https://github.com/PaddlePaddle/PaddleOCR/blob/release/2.4/configs/det/ch_PP-OCRv2/ch_PP-OCRv2_det_dml.yml).
The following describes the configuration file [ch_PP-OCRv2_det_cml.yml](https://github.com/PaddlePaddle/PaddleOCR/blob/release/2.4/configs/det/ch_PP-OCRv2/ch_PP-OCRv2_det_cml.yml):
```yaml
Architecture:
  name: DistillationModel
  algorithm: Distillation
  model_type: det
  Models:
    Teacher:                       # Teacher model configuration of CML distillation
      pretrained: ./pretrain_models/ch_ppocr_server_v2.0_det_train/best_accuracy
      freeze_params: true          # Teacher does not train
      return_all_feats: false
      model_type: det
      algorithm: DB
      Transform:
      Backbone:
        name: ResNet
        layers: 18
      Neck:
        name: DBFPN
        out_channels: 256
      Head:
        name: DBHead
        k: 50
    Student:                       # Student model configuration for CML distillation
      pretrained: ./pretrain_models/MobileNetV3_large_x0_5_pretrained
      freeze_params: false
      return_all_feats: false
      model_type: det
      algorithm: DB
      Backbone:
        name: MobileNetV3
        scale: 0.5
        model_name: large
        disable_se: True
      Neck:
        name: DBFPN
        out_channels: 96
      Head:
        name: DBHead
        k: 50
    Student2:                      # Student2 model configuration for CML distillation
      pretrained: ./pretrain_models/MobileNetV3_large_x0_5_pretrained
      freeze_params: false
      return_all_feats: false
      model_type: det
      algorithm: DB
      Transform:
      Backbone:
        name: MobileNetV3
        scale: 0.5
        model_name: large
        disable_se: True
      Neck:
        name: DBFPN
        out_channels: 96
      Head:
        name: DBHead
        k: 50
```
The specific implementation of the distillation model `DistillationModel` class can be found in [distillation_model.py](../../ppocr/modeling/architectures/distillation_model.py).
The final model output is a dictionary whose keys are the names of the sub-networks, for example `Student` and `Teacher` here, and whose values are the outputs of the corresponding sub-networks,
which can be a `Tensor` (only the output of the last layer is returned) or a `dict` (intermediate feature information is also returned).
In the distillation task, in order to make it easy to add distillation loss functions, the output of each network is saved as a `dict` containing the sub-module outputs.
The keys are `backbone_out`, `neck_out` and `head_out`, and the values are the tensors of the corresponding modules. For the configuration file above, the output format of `DistillationModel` is as follows.
```json
{
  "Teacher": {
    "backbone_out": tensor,
    "neck_out": tensor,
    "head_out": tensor,
  },
  "Student": {
    "backbone_out": tensor,
    "neck_out": tensor,
    "head_out": tensor,
  }
}
```
<a name="222"></a>
#### 2.2.2 Loss Function
In the detection knowledge distillation task ```ch_PP-OCRv2_det_distill.yml```, the distillation loss function configuration is as follows.
```yaml
Loss:
  name: CombinedLoss                     # Loss function name
  loss_config_list:                      # List of loss function configurations, a mandatory field for CombinedLoss
    - DistillationDilaDBLoss:            # DB loss for distillation, inherited from the standard DBLoss
        weight: 1.0                      # Weight of the loss; every loss in loss_config_list must contain this field
        model_name_pairs:                # Take the outputs of these two sub-networks and compute the loss between them
          - ["Student", "Teacher"]
        key: maps                        # The tensor to take from the sub-network output dict
        balance_loss: true               # The following parameters are the configuration parameters of the standard DBLoss
        main_loss_type: DiceLoss
        alpha: 5
        beta: 10
        ohem_ratio: 3
    - DistillationDBLoss:                # Used to compute the loss between the Student and the ground truth
        weight: 1.0
        model_name_list: ["Student"]     # The model name list contains only Student, which means the loss is computed between Student and the ground truth
        name: DBLoss
        balance_loss: true
        main_loss_type: DiceLoss
        alpha: 5
        beta: 10
        ohem_ratio: 3
```
Similarly, the distillation loss function configuration in `ch_PP-OCRv2_det_cml.yml` is shown below. Compared with the loss function configuration of ch_PP-OCRv2_det_distill.yml, there are three changes:
```yaml
Loss:
  name: CombinedLoss
  loss_config_list:
    - DistillationDilaDBLoss:
        weight: 1.0
        model_name_pairs:
          - ["Student", "Teacher"]
          - ["Student2", "Teacher"]      # 1. Calculate the loss between the two Students and the Teacher
        key: maps
        balance_loss: true
        main_loss_type: DiceLoss
        alpha: 5
        beta: 10
        ohem_ratio: 3
    - DistillationDMLLoss:               # 2. Added to calculate the loss between the two Students
        model_name_pairs:
          - ["Student", "Student2"]
        maps_name: "thrink_maps"
        weight: 1.0
        # act: None
        key: maps
    - DistillationDBLoss:
        weight: 1.0
        model_name_list: ["Student", "Student2"]   # 3. Calculate the loss between the two Students and the ground truth
        balance_loss: true
        main_loss_type: DiceLoss
        alpha: 5
        beta: 10
        ohem_ratio: 3
```
For more specific implementation of `DistillationDilaDBLoss`, please refer to: [distillation_loss.py](https://github.com/PaddlePaddle/PaddleOCR/blob/release%2F2.4/ppocr/losses/distillation_loss.py#L185).
For more specific implementations of distillation loss functions such as `DistillationDBLoss`, please refer to: [distillation_loss.py](https://github.com/PaddlePaddle/PaddleOCR/blob/04c44974b13163450dfb6bd2c327863f8a194b3c/ppocr/losses/distillation_loss.py#L148)
<a name="223"></a>
#### 2.2.3 Post-processing
In the detection knowledge distillation task, the post-processing configuration is as follows.
```yaml
PostProcess:
  name: DistillationDBPostProcess        # DB post-processing for the distillation task, inherited from the standard DBPostProcess class
  model_name: ["Student", "Student2", "Teacher"]  # Take the outputs of these sub-networks and post-process them; networks that do not need post-processing are not listed in model_name
  thresh: 0.3
  box_thresh: 0.6
  max_candidates: 1000
  unclip_ratio: 1.5
```
Taking the above configuration as an example, the output of the three subnets `Student`, `Student2` and `Teacher` will be calculated at the same time for post-processing calculations.
Since there are multiple inputs, there are also multiple outputs returned by post-processing.
For a more specific implementation of `DistillationDBPostProcess`, please refer to: [db_postprocess.py](../../ppocr/postprocess/db_postprocess.py#L195)
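As a hedged illustration of what "multiple outputs" means here (the dict layout and variable names below are assumptions for illustration, not the exact `DistillationDBPostProcess` API):
```python
# Hedged sketch: the distillation post-process returns one decoded result per
# sub-network listed in `model_name`; downstream code picks the branch it needs.
results = {
    "Student":  [{"points": [[10, 10], [80, 10], [80, 30], [10, 30]]}],
    "Student2": [{"points": [[12, 11], [82, 11], [82, 31], [12, 31]]}],
    "Teacher":  [{"points": [[11, 10], [81, 10], [81, 30], [11, 30]]}],
}
student_boxes = results["Student"]   # usually the Student branch is the one deployed
print(student_boxes)
```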
<a name="224"></a>
#### 2.2.4 Metric Calculation
In the knowledge distillation task, the metric calculation configuration is as follows.
```yaml
Metric:
  name: DistillationMetric
  base_metric_name: DetMetric
  main_indicator: hmean
  key: "Student"
```
Since distillation involves multiple networks, only one network's metrics need to be computed during evaluation.
The `key` field is set to `Student`, which means that only the metrics of the `Student` network are calculated.
<a name="225"></a>
#### 2.2.5 Fine-tuning Distillation Model
There are three ways to fine-tune the detection distillation task:
- `ch_PP-OCRv2_det_distill.yml`: the teacher model is set to the model provided by PaddleOCR or a large model you have trained yourself.
- `ch_PP-OCRv2_det_cml.yml`: use CML distillation. Similarly, the Teacher model is set to the model provided by PaddleOCR or a large model you have trained yourself.
- `ch_PP-OCRv2_det_dml.yml`: distillation using DML. Mutual distillation of the two Student models improves accuracy by about 1.7% on the dataset used by PaddleOCR.
When fine-tuning, you need to set the pre-trained model to be loaded in the `pretrained` parameter of the network structure.
In terms of accuracy improvement, `cml` > `dml` > `distill`. When the amount of data is insufficient or the accuracy of the teacher model is close to that of the student, this conclusion may change.
In addition, since the distillation pre-trained model provided by PaddleOCR contains the parameters of multiple models, if you want to extract the parameters of the student model, you can refer to the following code:
```sh
# Download the parameters of the distillation training model
wget https://paddleocr.bj.bcebos.com/PP-OCRv2/chinese/ch_PP-OCRv2_det_distill_train.tar
```
```python
import paddle
# Load the pre-trained model
all_params = paddle.load("ch_PP-OCRv2_det_distill_train/best_accuracy.pdparams")
# View the keys of the weight parameter
print(all_params.keys())
# Extract the weights of the student model
s_params = {key[len("Student."):]: all_params[key] for key in all_params if "Student." in key}
# View the keys of the weight parameters of the student model
print(s_params.keys())
# Save
paddle.save(s_params, "ch_PP-OCRv2_det_distill_train/student.pdparams")
```
Finally, the parameters of the student model will be saved in `ch_PP-OCRv2_det_distill_train/student.pdparams` for model fine-tuning.
...@@ -7,13 +7,13 @@ This section contains two parts. Firstly, [PP-OCR Model Download](./models_list_
Let's first understand some basic concepts.
- [Introduction to OCR](#introduction-to-ocr)
  * [Basic Concepts of OCR Detection Model](#basic-concepts-of-ocr-detection-model)
  * [Basic Concepts of OCR Recognition Model](#basic-concepts-of-ocr-recognition-model)
  * [PP-OCR Model](#pp-ocr-model)
## 1. Introduction to OCR
This section briefly introduces the basic concepts of the OCR detection model and recognition model, and introduces PaddleOCR's PP-OCR model.
...
# OCR Model List (V2.1, updated on 2021.9.6)
> **Note**
> 1. Compared with model v2.0, the 2.1 version of the detection model has an improvement in accuracy, and the 2.1 version of the recognition model has optimizations in accuracy and CPU speed.
> 2. Compared with [models 1.1](https://github.com/PaddlePaddle/PaddleOCR/blob/develop/doc/doc_en/models_list_en.md), which are trained with the static graph programming paradigm, models 2.0 are the dynamic-graph trained version and achieve close performance.
> 3. All models in this tutorial are ppocr-series models. For more algorithms and models trained on public datasets, please refer to the [algorithm overview tutorial](./algorithm_overview_en.md).
...@@ -18,7 +18,7 @@ The downloadable models provided by PaddleOCR include `inference model`, `traine
|--- | --- | --- |
|inference model|inference.pdmodel、inference.pdiparams|Used for inference based on Paddle inference engine,[detail](./inference_en.md)|
|trained model, pre-trained model|\*.pdparams、\*.pdopt、\*.states |The checkpoints model saved in the training process, which stores the parameters of the model, mostly used for model evaluation and continuous training.|
|slim model|\*.nb| Model compressed by PaddleSlim (a model compression tool using PaddlePaddle), which is suitable for mobile-side deployment scenarios (Paddle-Lite is needed for slim model deployment). |
The relationship of the above models is as follows.
...@@ -50,7 +50,7 @@ Relationship of the above models is as follows.
|ch_ppocr_server_v2.0_rec|General model, supporting Chinese, English and number recognition|[rec_chinese_common_train_v2.0.yml](../../configs/rec/ch_ppocr_v2.0/rec_chinese_common_train_v2.0.yml)|94.8M|[inference model](https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_server_v2.0_rec_infer.tar) / [trained model](https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_server_v2.0_rec_train.tar) / [pre-trained model](https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_server_v2.0_rec_pre.tar) |
**Note:** The `trained model` is fine-tuned on the `pre-trained model` with real data and synthesized vertical text data, which achieves better performance in real scenes. The `pre-trained model` is directly trained on the full amount of real and synthesized data, which is more suitable for fine-tuning on your own dataset.
<a name="English"></a>
### 2.2 English Recognition Model
...
...@@ -28,12 +28,12 @@ The multilingual models cover Latin, Arabic, Traditional Chinese, Korean, Japane
This document will briefly introduce how to use the multilingual model.
- [1 Installation](#Install)
  - [1.1 Paddle installation](#paddleinstallation)
  - [1.2 PaddleOCR package installation](#paddleocr_package_install)
- [2 Quick Use](#Quick_Use)
  - [2.1 Command line operation](#Command_line_operation)
  - [2.2 Run with Python script](#python_Script_running)
- [3 Custom Training](#Custom_Training)
- [4 Inference and Deployment](#inference)
- [5 Supported languages and abbreviations](#language_abbreviations)
...@@ -42,7 +42,7 @@ This document will briefly introduce how to use the multilingual model.
## 1 Installation
<a name="paddle_install"></a>
### 1.1 Paddle installation
```
# cpu
pip install paddlepaddle
...@@ -52,7 +52,7 @@ pip install paddlepaddle-gpu
```
<a name="paddleocr_package_install"></a>
### 1.2 PaddleOCR package installation
pip install
...@@ -79,8 +79,8 @@ paddleocr -h
* Whole image prediction (detection + recognition)
PaddleOCR currently supports 80 languages, which can be specified by the --lang parameter.
The supported languages are listed in the [table](#language_abbreviations).
``` bash
paddleocr --image_dir doc/imgs_en/254.jpg --lang=en
...@@ -90,7 +90,7 @@ paddleocr --image_dir doc/imgs_en/254.jpg --lang=en
<img src="../imgs_results/multi_lang/img_02.jpg" width="600" height="600">
</div>
The result is a list. Each item contains a text box, text and recognition confidence:
```text
[('PHO CAPITAL', 0.95723116), [[66.0, 50.0], [327.0, 44.0], [327.0, 76.0], [67.0, 82.0]]]
[('107 State Street', 0.96311164), [[72.0, 90.0], [451.0, 84.0], [452.0, 116.0], [73.0, 121.0]]]
...@@ -110,7 +110,7 @@ paddleocr --image_dir doc/imgs_words_en/word_308.png --det false --lang=en
![](https://raw.githubusercontent.com/PaddlePaddle/PaddleOCR/release/2.1/doc/imgs_words_en/word_308.png)
The result is a 2-tuple, which contains the recognition result and recognition confidence:
```text
(0.99879867, 'LITTLE')
...@@ -122,7 +122,7 @@ The result is a tuple, which returns the recognition result and recognition conf
paddleocr --image_dir PaddleOCR/doc/imgs/11.jpg --rec false
```
The result is a list. Each item represents the coordinates of a text box.
```
[[26.0, 457.0], [137.0, 457.0], [137.0, 477.0], [26.0, 477.0]]
...@@ -132,9 +132,9 @@ The result is a list, each item contains only text boxes
```
<a name="python_script_running"></a>
### 2.2 Run with Python script
PaddleOCR can also be run from Python scripts for easy integration into your own code:
* Whole image prediction (detection + recognition)
...@@ -167,12 +167,12 @@ Visualization of results:
![](https://raw.githubusercontent.com/PaddlePaddle/PaddleOCR/release/2.1/doc/imgs_results/korean.jpg)
PaddleOCR also supports direction classification. For more detailed usage, please refer to the [whl package instructions](whl_en.md).
<a name="Custom_training"></a>
## 3 Custom training
PaddleOCR supports using your own data for custom training or fine-tuning. For the recognition model, you can refer to the [French configuration file](../../configs/rec/multi_language/rec_french_lite_train.yml)
and modify the training data path, dictionary and other parameters.
For the specific data preparation and training process, please refer to [Text Detection](../doc_en/detection_en.md) and [Text Recognition](../doc_en/recognition_en.md). For more functions such as prediction deployment,
...@@ -183,7 +183,7 @@ For functions such as data annotation, you can read the complete [Document Tutor
## 4 Inference and Deployment
In addition to installing the whl package for quick prediction,
PaddleOCR also provides a variety of prediction deployment methods.
If necessary, you can read the related documents:
- [Python Inference](./inference_en.md)
...
...@@ -2,7 +2,7 @@
## 1. PaddleOCR Overview
PaddleOCR contains rich text detection, text recognition and end-to-end algorithms. Drawing on real-world scenarios and industrial experience, PaddleOCR chooses DB and CRNN as the basic detection and recognition models and, after a series of optimization strategies, proposes the PP-OCR series of models for industrial applications. The PP-OCR models target general scenarios and form a model library for different languages. Based on the capabilities of PP-OCR, PaddleOCR releases the PP-Structure toolkit for document scene tasks, which covers two major tasks: layout analysis and table recognition. To cover the entire process of industrial deployment, PaddleOCR also provides large-scale data production tools and a variety of prediction deployment tools to help developers quickly turn ideas into reality.
<div align="center">
<img src="../overview_en.png">
...@@ -18,11 +18,11 @@ PaddleOCR contains rich text detection, text recognition and end-to-end algorith
# Recommend
git clone https://github.com/PaddlePaddle/PaddleOCR
# If you cannot pull successfully due to network problems, you can switch to the mirror hosted on Gitee:
git clone https://gitee.com/paddlepaddle/PaddleOCR
# Note: The mirror on Gitee may not stay in sync with the latest project on GitHub; there might be a delay of 3-5 days. Please try GitHub first.
```
### **2.2 Install third-party libraries**
...@@ -34,6,6 @@ pip3 install -r requirements.txt
If you get the error `OSError: [WinError 126] The specified module could not be found` when installing shapely on Windows,
please try to download the Shapely whl file from [http://www.lfd.uci.edu/~gohlke/pythonlibs/#shapely](http://www.lfd.uci.edu/~gohlke/pythonlibs/#shapely).
Reference: [Solve shapely installation on windows](https://stackoverflow.com/questions/44398265/install-shapely-oserror-winerror-126-the-specified-module-could-not-be-found)
\ No newline at end of file
...@@ -6,18 +6,18 @@
<a name="Brief_Introduction"></a>
## 1. Brief Introduction
OCR algorithms can be divided into two categories: two-stage algorithms and end-to-end algorithms. A two-stage OCR algorithm is generally divided into two parts: a text detection algorithm and a text recognition algorithm. The text detection algorithm locates the box of a text line in the image, and the recognition algorithm then identifies the content of the text box. An end-to-end OCR algorithm combines text detection and recognition in one model. Its basic idea is to design a model with both a detection unit and a recognition module, share the CNN features of both, and train them together. Because a single model completes text recognition, the end-to-end model is smaller and faster.
### Introduction of the PGNet Algorithm
In recent years, end-to-end OCR algorithms have developed rapidly, including the MaskTextSpotter series, TextSnake, TextDragon, the PGNet series and so on. Among these algorithms, PGNet has several advantages that the other algorithms lack:
- The PGNet loss is designed to guide training, and no character-level annotations are needed.
- NMS and ROI related operations are not needed, which accelerates prediction.
- A reading order prediction module is proposed.
- A graph based modification module (GRM) is proposed to further improve the performance of model recognition.
- Higher accuracy and faster prediction speed.
For details of the PGNet algorithm, please refer to the [paper](https://www.aaai.org/AAAI21Papers/AAAI-2885.WangP.pdf). The schematic diagram of the algorithm is as follows:
![](../pgnet_framework.png)
After feature extraction, the input image is sent to four branches: the TBO module for text edge offset prediction, the TCL module for text center-line prediction, the TDO module for text direction offset prediction, and the TCC module for text character classification graph prediction.
The outputs of TBO and TCL give the text detection results after post-processing, while TCL, TDO and TCC are responsible for text recognition.
The results of detection and recognition are as follows:
...@@ -40,7 +40,7 @@ Please refer to [Operation Environment Preparation](./environment_en.md) to conf
<a name="Quick_Use"></a>
## 3. Quick Use
### Inference model download
This section takes the trained end-to-end model as an example to quickly run model prediction. First, download the trained end-to-end inference model from the [download address](https://paddleocr.bj.bcebos.com/dygraph_v2.0/pgnet/e2e_server_pgnetA_infer.tar).
```
mkdir inference && cd inference
...@@ -131,7 +131,7 @@ python3 tools/train.py -c configs/e2e/e2e_r50_vd_pg.yml -o Optimizer.base_lr=0.0
```
#### Load trained model and continue training
If you would like to load a trained model and continue training, you can specify the parameter `Global.checkpoints` as the model path to be loaded.
```shell
python3 tools/train.py -c configs/e2e/e2e_r50_vd_pg.yml -o Global.checkpoints=./your/trained/model
```
...
...@@ -12,15 +12,15 @@
* [4. FAQ](#3-faq)
This article introduces the basic concepts needed for model training and the tuning methods used during training.
It also briefly introduces the structure of PaddleOCR training data and how to prepare data to fine-tune the model in vertical scenarios.
<a name="1-Yml-Configuration"></a>
## 1. Yml Configuration
PaddleOCR uses configuration files to control network training and evaluation parameters. In the configuration file, you can set the model, optimizer, loss function, and pre- and post-processing parameters. PaddleOCR reads these parameters from the configuration file and then builds the complete training process to train the model. Fine-tuning can also be done by modifying the parameters in the configuration file, which is simple and convenient.
For the complete configuration file description, please refer to [Configuration File](./config_en.md).
...@@ -28,13 +28,13 @@ For the complete configuration file description, please refer to [Configuration
## 2. Basic Concepts
During model training, some hyper-parameters need to be manually tuned to obtain the best result at the least cost. Different data volumes may require different hyper-parameters. When you want to fine-tune the model on your own data, there are several parameter adjustment strategies for reference:
<a name="11-learning-rate"></a>
### 2.1 Learning Rate
The learning rate is one of the most important hyper-parameters for training neural networks. It represents the step length by which the gradient moves towards the optimal solution of the loss function in each iteration.
A variety of learning rate update strategies are provided by PaddleOCR, which can be specified in the configuration file. For example,
```
Optimizer:
...@@ -46,16 +46,15 @@ Optimizer:
    warmup_epoch: 5
```
`Piecewise` stands for piecewise constant decay: different learning rates are specified for different learning stages, and the learning rate stays the same within each stage.
`warmup_epoch` means that in the first 5 epochs, the learning rate is gradually increased from 0 to base_lr. For all strategies, please refer to the code [learning_rate.py](../../ppocr/optimizer/learning_rate.py).
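For reference, here is a hedged sketch of the same strategy built directly with Paddle's scheduler APIs; the boundaries and learning rates below are example values, not the ones from the yml file.
```python
import paddle

# Piecewise constant decay: the learning rate stays at values[i] between boundaries.
piecewise = paddle.optimizer.lr.PiecewiseDecay(
    boundaries=[700, 800],             # epochs (or steps) at which the rate changes
    values=[0.001, 0.0001, 0.00001])   # one constant rate per stage
# Linear warmup: ramp from 0 to the base rate over the first 5 epochs,
# then hand control over to the piecewise schedule.
scheduler = paddle.optimizer.lr.LinearWarmup(
    learning_rate=piecewise, warmup_steps=5, start_lr=0.0, end_lr=0.001)
```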
<a name="12-regularization"></a>
### 2.2 Regularization
Regularization can effectively avoid algorithm over-fitting. PaddleOCR provides L1 and L2 regularization methods.
L1 and L2 regularization are the most widely used regularization methods.
L1 regularization adds a regularization term to the objective function to reduce the sum of the absolute values of the parameters;
while L2 regularization adds a regularization term to reduce the sum of the squared parameters.
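As an illustration only (the coefficient is an assumed value, and this uses the Paddle 2.x `paddle.regularizer` API directly rather than a PaddleOCR config), L2 regularization could be attached to an optimizer roughly like this:

```
import paddle

# The coefficient 1e-5 is an assumption for illustration;
# paddle.regularizer.L1Decay would be used for L1 regularization instead.
l2 = paddle.regularizer.L2Decay(coeff=1e-5)

# optimizer = paddle.optimizer.Adam(learning_rate=0.001,
#                                   weight_decay=l2,
#                                   parameters=model.parameters())
```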
The configuration method is as follows:
...@@ -95,7 +94,7 @@ The current open source models, data sets and magnitudes are as follows:
- Chinese data set: the LSVT street-view data set is cropped according to the ground truth and position-calibrated, giving a total of 300k images. In addition, 5 million images are synthesized based on the LSVT corpus.
- Minority-language data sets: 1 million synthetic images are generated for each language using different corpora and fonts, and ICDAR-MLT is used as the validation set.
Among them, the public data sets are all open source; users can search for and download them by themselves, or refer to the [Chinese data set](../doc_ch/datasets.md). The synthetic data is not open source; users can synthesize it themselves with open source synthesis tools such as [text_renderer](https://github.com/Sanster/text_renderer), [SynthText](https://github.com/ankush-me/SynthText), and [TextRecognitionDataGenerator](https://github.com/Belval/TextRecognitionDataGenerator).
<a name="22-vertical-scene"></a>
...@@ -129,17 +128,17 @@ There are several experiences for reference when constructing the data set:
**Q**: How to choose a suitable network input shape when training CRNN recognition?
A: The height is generally fixed at 32, and the longest width can be selected by one of two methods (see the sketch after this list):
(1) Calculate the aspect-ratio distribution of the training images and choose the maximum aspect ratio that covers 80% of the training samples.
(2) Count the text lengths of the training samples and choose the longest character count that covers 80% of them. Then, taking the aspect ratio of Chinese characters as roughly 1 and that of English characters as roughly 3:1, estimate the longest width.
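A small sketch of method (1); the image folder, file pattern, and the 80% threshold are assumptions for demonstration:

```
import glob
import cv2
import numpy as np

# "train_images" is a hypothetical folder; point it at the real training images.
ratios = []
for path in glob.glob("train_images/*.jpg"):
    img = cv2.imread(path)
    if img is None:
        continue
    h, w = img.shape[:2]
    ratios.append(w / h)

if ratios:
    max_ratio = np.percentile(ratios, 80)   # aspect ratio covering 80% of samples
    height = 32
    print("suggested input shape: 3 x %d x %d" % (height, int(height * max_ratio)))
```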
**Q**: During recognition training, the accuracy of the training set has reached 90, but the accuracy of the validation set stays at 70. What should I do?
A: If training accuracy is 90 while validation accuracy stays around 70, the model is most likely over-fitting. There are two methods to try:
(1) Add more augmentation methods or increase the [augmentation probability](https://github.com/PaddlePaddle/PaddleOCR/blob/dygraph/ppocr/data/imaug/rec_img_aug.py#L341); the default is 0.4.
(2) Increase the [l2 decay value](https://github.com/PaddlePaddle/PaddleOCR/blob/a501603d54ff5513fc4fc760319472e59da25424/configs/rec/ch_ppocr_v1.1/rec_chinese_lite_train_v1.1.yml#L47) in the config.
**Q**: When the recognition model is trained, the loss drops normally, but acc is always 0.
......
...@@ -5,7 +5,7 @@
- 2021.8.3 released PaddleOCR v2.2, adding a new structured document analysis toolkit, i.e., [PP-Structure](https://github.com/PaddlePaddle/PaddleOCR/blob/release/2.2/ppstructure/README.md), which supports layout analysis and table recognition (one-click export of chart images to Excel files).
- 2021.4.8 released the end-to-end text recognition algorithm [PGNet](https://www.aaai.org/AAAI21Papers/AAAI-2885.WangP.pdf), published in AAAI 2021. Find the tutorial [here](https://github.com/PaddlePaddle/PaddleOCR/blob/release/2.1/doc/doc_en/pgnet_en.md); released multi-language recognition [models](https://github.com/PaddlePaddle/PaddleOCR/blob/release/2.1/doc/doc_en/multi_languages_en.md), supporting recognition of more than 80 languages; especially, the performance of the [English recognition model](https://github.com/PaddlePaddle/PaddleOCR/blob/release/2.1/doc/doc_en/models_list_en.md#English) is optimized.
- 2021.1.21 updated 25+ multilingual recognition models [models list](./models_list_en.md), including English, Chinese, German, French, Japanese, Spanish, Portuguese, Russian, Arabic, and so on. Models for more languages will continue to be updated; see the [Develop Plan](https://github.com/PaddlePaddle/PaddleOCR/issues/1048).
- 2020.12.15 updated the data synthesis tool [Style-Text](../../StyleText/README.md), which makes it easy to synthesize a large number of images similar to the target scene images.
- 2020.11.25 updated a new data annotation tool, [PPOCRLabel](../../PPOCRLabel/README.md), which helps improve labeling efficiency. Moreover, the labeling results can be used directly for training the PP-OCR system.
- 2020.9.22 updated the PP-OCR technical article, https://arxiv.org/abs/2009.09941
......
...@@ -2551,7 +2551,7 @@
"\n",
"Paddle Serving is a tool built by PaddlePaddle to simplify service-oriented deployment for developers. This section mainly introduces the service deployment of the PP-OCRv2 system based on Paddle Serving.\n",
"\n",
"## 4.1 Introduction to Paddle Serving\n",
"\n",
"As the open-source serving framework of PaddlePaddle (飞桨), Paddle Serving's long-term goal is to provide increasingly professional, reliable, and easy-to-use services for the last mile of AI deployment. Paddle Serving currently provides two frameworks: C++ Serving and Python Pipeline. The Python Pipeline framework favors convenient secondary development, while the C++ Serving framework pursues ultimate performance.\n",
"\n",
...@@ -42,12 +42,14 @@ __all__ = [ ...@@ -42,12 +42,14 @@ __all__ = [
] ]
SUPPORT_DET_MODEL = ['DB'] SUPPORT_DET_MODEL = ['DB']
VERSION = '2.3.0.2' VERSION = '2.4'
SUPPORT_REC_MODEL = ['CRNN'] SUPPORT_REC_MODEL = ['CRNN']
BASE_DIR = os.path.expanduser("~/.paddleocr/") BASE_DIR = os.path.expanduser("~/.paddleocr/")
DEFAULT_OCR_MODEL_VERSION = 'PP-OCR' DEFAULT_OCR_MODEL_VERSION = 'PP-OCR'
SUPPORT_OCR_MODEL_VERSION = ['PP-OCR', 'PP-OCRv2']
DEFAULT_STRUCTURE_MODEL_VERSION = 'STRUCTURE' DEFAULT_STRUCTURE_MODEL_VERSION = 'STRUCTURE'
SUPPORT_STRUCTURE_MODEL_VERSION = ['STRUCTURE']
MODEL_URLS = { MODEL_URLS = {
'OCR': { 'OCR': {
'PP-OCRv2': { 'PP-OCRv2': {
...@@ -190,6 +192,7 @@ def parse_args(mMain=True): ...@@ -190,6 +192,7 @@ def parse_args(mMain=True):
parser.add_argument( parser.add_argument(
"--ocr_version", "--ocr_version",
type=str, type=str,
choices=SUPPORT_OCR_MODEL_VERSION,
default='PP-OCRv2', default='PP-OCRv2',
help='OCR Model version, the current model support list is as follows: ' help='OCR Model version, the current model support list is as follows: '
'1. PP-OCRv2 Support Chinese detection and recognition model. ' '1. PP-OCRv2 Support Chinese detection and recognition model. '
...@@ -198,6 +201,7 @@ def parse_args(mMain=True): ...@@ -198,6 +201,7 @@ def parse_args(mMain=True):
parser.add_argument( parser.add_argument(
"--structure_version", "--structure_version",
type=str, type=str,
choices=SUPPORT_STRUCTURE_MODEL_VERSION,
default='STRUCTURE', default='STRUCTURE',
help='Model version, the current model support list is as follows:' help='Model version, the current model support list is as follows:'
' 1. STRUCTURE Support en table structure model.') ' 1. STRUCTURE Support en table structure model.')
...@@ -257,26 +261,20 @@ def get_model_config(type, version, model_type, lang): ...@@ -257,26 +261,20 @@ def get_model_config(type, version, model_type, lang):
DEFAULT_MODEL_VERSION = DEFAULT_STRUCTURE_MODEL_VERSION DEFAULT_MODEL_VERSION = DEFAULT_STRUCTURE_MODEL_VERSION
else: else:
raise NotImplementedError raise NotImplementedError
model_urls = MODEL_URLS[type] model_urls = MODEL_URLS[type]
if version not in model_urls: if version not in model_urls:
logger.warning('version {} not in {}, auto switch to version {}'.format(
version, model_urls.keys(), DEFAULT_MODEL_VERSION))
version = DEFAULT_MODEL_VERSION version = DEFAULT_MODEL_VERSION
if model_type not in model_urls[version]: if model_type not in model_urls[version]:
if model_type in model_urls[DEFAULT_MODEL_VERSION]: if model_type in model_urls[DEFAULT_MODEL_VERSION]:
logger.warning(
'version {} not support {} models, auto switch to version {}'.
format(version, model_type, DEFAULT_MODEL_VERSION))
version = DEFAULT_MODEL_VERSION version = DEFAULT_MODEL_VERSION
else: else:
logger.error('{} models is not support, we only support {}'.format( logger.error('{} models is not support, we only support {}'.format(
model_type, model_urls[DEFAULT_MODEL_VERSION].keys())) model_type, model_urls[DEFAULT_MODEL_VERSION].keys()))
sys.exit(-1) sys.exit(-1)
if lang not in model_urls[version][model_type]: if lang not in model_urls[version][model_type]:
if lang in model_urls[DEFAULT_MODEL_VERSION][model_type]: if lang in model_urls[DEFAULT_MODEL_VERSION][model_type]:
logger.warning(
'lang {} is not support in {}, auto switch to version {}'.
format(lang, version, DEFAULT_MODEL_VERSION))
version = DEFAULT_MODEL_VERSION version = DEFAULT_MODEL_VERSION
else: else:
logger.error( logger.error(
...@@ -296,6 +294,8 @@ class PaddleOCR(predict_system.TextSystem): ...@@ -296,6 +294,8 @@ class PaddleOCR(predict_system.TextSystem):
""" """
params = parse_args(mMain=False) params = parse_args(mMain=False)
params.__dict__.update(**kwargs) params.__dict__.update(**kwargs)
assert params.ocr_version in SUPPORT_OCR_MODEL_VERSION, "ocr_version must in {}, but get {}".format(
SUPPORT_OCR_MODEL_VERSION, params.ocr_version)
params.use_gpu = check_gpu(params.use_gpu) params.use_gpu = check_gpu(params.use_gpu)
if not params.show_log: if not params.show_log:
...@@ -347,8 +347,9 @@ class PaddleOCR(predict_system.TextSystem): ...@@ -347,8 +347,9 @@ class PaddleOCR(predict_system.TextSystem):
ocr with paddleocr
args:
img: img for ocr, support ndarray, img_path and list or ndarray
det: use text detection or not. If false, only rec will be exec. Default is True
rec: use text recognition or not. If false, only det will be exec. Default is True
cls: use angle classifier or not. Default is True. If true, the text with rotation of 180 degrees can be recognized. If no text is rotated by 180 degrees, use cls=False to get better performance. Text with rotation of 90 or 270 degrees can be recognized even if cls=False.
""" """
assert isinstance(img, (np.ndarray, list, str)) assert isinstance(img, (np.ndarray, list, str))
if isinstance(img, list) and det == True: if isinstance(img, list) and det == True:
...@@ -398,6 +399,8 @@ class PPStructure(OCRSystem): ...@@ -398,6 +399,8 @@ class PPStructure(OCRSystem):
def __init__(self, **kwargs): def __init__(self, **kwargs):
params = parse_args(mMain=False) params = parse_args(mMain=False)
params.__dict__.update(**kwargs) params.__dict__.update(**kwargs)
assert params.structure_version in SUPPORT_STRUCTURE_MODEL_VERSION, "structure_version must in {}, but get {}".format(
SUPPORT_STRUCTURE_MODEL_VERSION, params.structure_version)
params.use_gpu = check_gpu(params.use_gpu) params.use_gpu = check_gpu(params.use_gpu)
if not params.show_log: if not params.show_log:
......
...@@ -20,6 +20,7 @@ from __future__ import unicode_literals ...@@ -20,6 +20,7 @@ from __future__ import unicode_literals
import os import os
import sys import sys
import numpy as np import numpy as np
import skimage
import paddle import paddle
import signal import signal
import random import random
...@@ -86,13 +87,19 @@ def build_dataloader(config, mode, device, logger, seed=None): ...@@ -86,13 +87,19 @@ def build_dataloader(config, mode, device, logger, seed=None):
shuffle=shuffle, shuffle=shuffle,
drop_last=drop_last) drop_last=drop_last)
if 'collate_fn' in loader_config:
from . import collate_fn
collate_fn = getattr(collate_fn, loader_config['collate_fn'])()
else:
collate_fn = None
data_loader = DataLoader( data_loader = DataLoader(
dataset=dataset, dataset=dataset,
batch_sampler=batch_sampler, batch_sampler=batch_sampler,
places=device, places=device,
num_workers=num_workers, num_workers=num_workers,
return_list=True, return_list=True,
use_shared_memory=use_shared_memory) use_shared_memory=use_shared_memory,
collate_fn=collate_fn)
# support exit using ctrl+c # support exit using ctrl+c
signal.signal(signal.SIGINT, term_mp) signal.signal(signal.SIGINT, term_mp)
......
...@@ -15,20 +15,20 @@ ...@@ -15,20 +15,20 @@
import paddle import paddle
import numbers import numbers
import numpy as np import numpy as np
from collections import defaultdict
class DataCollator: class DictCollator(object):
""" """
data batch data batch
""" """
def __call__(self, batch): def __call__(self, batch):
data_dict = {} # todo:support batch operators
data_dict = defaultdict(list)
to_tensor_keys = [] to_tensor_keys = []
for sample in batch: for sample in batch:
for k, v in sample.items(): for k, v in sample.items():
if k not in data_dict:
data_dict[k] = []
if isinstance(v, (np.ndarray, paddle.Tensor, numbers.Number)): if isinstance(v, (np.ndarray, paddle.Tensor, numbers.Number)):
if k not in to_tensor_keys: if k not in to_tensor_keys:
to_tensor_keys.append(k) to_tensor_keys.append(k)
...@@ -36,3 +36,23 @@ class DataCollator: ...@@ -36,3 +36,23 @@ class DataCollator:
for k in to_tensor_keys: for k in to_tensor_keys:
data_dict[k] = paddle.to_tensor(data_dict[k]) data_dict[k] = paddle.to_tensor(data_dict[k])
return data_dict return data_dict
class ListCollator(object):
"""
data batch
"""
def __call__(self, batch):
# todo:support batch operators
data_dict = defaultdict(list)
to_tensor_idxs = []
for sample in batch:
for idx, v in enumerate(sample):
if isinstance(v, (np.ndarray, paddle.Tensor, numbers.Number)):
if idx not in to_tensor_idxs:
to_tensor_idxs.append(idx)
data_dict[idx].append(v)
for idx in to_tensor_idxs:
data_dict[idx] = paddle.to_tensor(data_dict[idx])
return list(data_dict.values())
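A hedged usage sketch of the two collators above: `DictCollator` stacks samples keyed by field name, while `ListCollator` stacks them by position. The import path is an assumption based on `from . import collate_fn` in `build_dataloader`, and the shapes are illustrative:

```
import numpy as np
# Assumed module path; the classes are defined in ppocr/data/collate_fn.py per this diff.
from ppocr.data.collate_fn import DictCollator, ListCollator

dict_batch = [{"image": np.zeros((3, 32, 100), dtype="float32"), "label": 1},
              {"image": np.ones((3, 32, 100), dtype="float32"), "label": 0}]
out = DictCollator()(dict_batch)
# out["image"] -> Tensor of shape [2, 3, 32, 100]; out["label"] -> Tensor of shape [2]

list_batch = [(np.zeros((3, 32, 100), dtype="float32"), 1),
              (np.ones((3, 32, 100), dtype="float32"), 0)]
out = ListCollator()(list_batch)
# out -> [Tensor of shape [2, 3, 32, 100], Tensor of shape [2]]
```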
...@@ -34,6 +34,8 @@ from .sast_process import * ...@@ -34,6 +34,8 @@ from .sast_process import *
from .pg_process import * from .pg_process import *
from .gen_table_mask import * from .gen_table_mask import *
from .vqa import *
def transform(data, ops=None): def transform(data, ops=None):
""" transform """ """ transform """
......
...@@ -17,6 +17,7 @@ from __future__ import division ...@@ -17,6 +17,7 @@ from __future__ import division
from __future__ import print_function from __future__ import print_function
from __future__ import unicode_literals from __future__ import unicode_literals
import copy
import numpy as np import numpy as np
import string import string
from shapely.geometry import LineString, Point, Polygon from shapely.geometry import LineString, Point, Polygon
...@@ -736,7 +737,7 @@ class TableLabelEncode(object): ...@@ -736,7 +737,7 @@ class TableLabelEncode(object):
% beg_or_end % beg_or_end
else: else:
assert False, "Unsupport type %s in char_or_elem" \ assert False, "Unsupport type %s in char_or_elem" \
% char_or_elem % char_or_elem
return idx return idx
...@@ -782,3 +783,176 @@ class SARLabelEncode(BaseRecLabelEncode): ...@@ -782,3 +783,176 @@ class SARLabelEncode(BaseRecLabelEncode):
def get_ignored_tokens(self): def get_ignored_tokens(self):
return [self.padding_idx] return [self.padding_idx]
class VQATokenLabelEncode(object):
"""
Label encode for NLP VQA methods
"""
def __init__(self,
class_path,
contains_re=False,
add_special_ids=False,
algorithm='LayoutXLM',
infer_mode=False,
ocr_engine=None,
**kwargs):
super(VQATokenLabelEncode, self).__init__()
from paddlenlp.transformers import LayoutXLMTokenizer, LayoutLMTokenizer
from ppocr.utils.utility import load_vqa_bio_label_maps
tokenizer_dict = {
'LayoutXLM': {
'class': LayoutXLMTokenizer,
'pretrained_model': 'layoutxlm-base-uncased'
},
'LayoutLM': {
'class': LayoutLMTokenizer,
'pretrained_model': 'layoutlm-base-uncased'
}
}
self.contains_re = contains_re
tokenizer_config = tokenizer_dict[algorithm]
self.tokenizer = tokenizer_config['class'].from_pretrained(
tokenizer_config['pretrained_model'])
self.label2id_map, id2label_map = load_vqa_bio_label_maps(class_path)
self.add_special_ids = add_special_ids
self.infer_mode = infer_mode
self.ocr_engine = ocr_engine
def __call__(self, data):
# load bbox and label info
ocr_info = self._load_ocr_info(data)
height, width, _ = data['image'].shape
words_list = []
bbox_list = []
input_ids_list = []
token_type_ids_list = []
segment_offset_id = []
gt_label_list = []
entities = []
# for re
train_re = self.contains_re and not self.infer_mode
if train_re:
relations = []
id2label = {}
entity_id_to_index_map = {}
empty_entity = set()
data['ocr_info'] = copy.deepcopy(ocr_info)
for info in ocr_info:
if train_re:
# for re
if len(info["text"]) == 0:
empty_entity.add(info["id"])
continue
id2label[info["id"]] = info["label"]
relations.extend([tuple(sorted(l)) for l in info["linking"]])
# smooth_box
bbox = self._smooth_box(info["bbox"], height, width)
text = info["text"]
encode_res = self.tokenizer.encode(
text, pad_to_max_seq_len=False, return_attention_mask=True)
if not self.add_special_ids:
# TODO: use tok.all_special_ids to remove
encode_res["input_ids"] = encode_res["input_ids"][1:-1]
encode_res["token_type_ids"] = encode_res["token_type_ids"][1:
-1]
encode_res["attention_mask"] = encode_res["attention_mask"][1:
-1]
# parse label
if not self.infer_mode:
label = info['label']
gt_label = self._parse_label(label, encode_res)
# construct entities for re
if train_re:
if gt_label[0] != self.label2id_map["O"]:
entity_id_to_index_map[info["id"]] = len(entities)
label = label.upper()
entities.append({
"start": len(input_ids_list),
"end":
len(input_ids_list) + len(encode_res["input_ids"]),
"label": label.upper(),
})
else:
entities.append({
"start": len(input_ids_list),
"end": len(input_ids_list) + len(encode_res["input_ids"]),
"label": 'O',
})
input_ids_list.extend(encode_res["input_ids"])
token_type_ids_list.extend(encode_res["token_type_ids"])
bbox_list.extend([bbox] * len(encode_res["input_ids"]))
words_list.append(text)
segment_offset_id.append(len(input_ids_list))
if not self.infer_mode:
gt_label_list.extend(gt_label)
data['input_ids'] = input_ids_list
data['token_type_ids'] = token_type_ids_list
data['bbox'] = bbox_list
data['attention_mask'] = [1] * len(input_ids_list)
data['labels'] = gt_label_list
data['segment_offset_id'] = segment_offset_id
data['tokenizer_params'] = dict(
padding_side=self.tokenizer.padding_side,
pad_token_type_id=self.tokenizer.pad_token_type_id,
pad_token_id=self.tokenizer.pad_token_id)
data['entities'] = entities
if train_re:
data['relations'] = relations
data['id2label'] = id2label
data['empty_entity'] = empty_entity
data['entity_id_to_index_map'] = entity_id_to_index_map
return data
def _load_ocr_info(self, data):
def trans_poly_to_bbox(poly):
x1 = np.min([p[0] for p in poly])
x2 = np.max([p[0] for p in poly])
y1 = np.min([p[1] for p in poly])
y2 = np.max([p[1] for p in poly])
return [x1, y1, x2, y2]
if self.infer_mode:
ocr_result = self.ocr_engine.ocr(data['image'], cls=False)
ocr_info = []
for res in ocr_result:
ocr_info.append({
"text": res[1][0],
"bbox": trans_poly_to_bbox(res[0]),
"poly": res[0],
})
return ocr_info
else:
info = data['label']
# read text info
info_dict = json.loads(info)
return info_dict["ocr_info"]
def _smooth_box(self, bbox, height, width):
bbox[0] = int(bbox[0] * 1000.0 / width)
bbox[2] = int(bbox[2] * 1000.0 / width)
bbox[1] = int(bbox[1] * 1000.0 / height)
bbox[3] = int(bbox[3] * 1000.0 / height)
return bbox
def _parse_label(self, label, encode_res):
gt_label = []
if label.lower() == "other":
gt_label.extend([0] * len(encode_res["input_ids"]))
else:
gt_label.append(self.label2id_map[("b-" + label).upper()])
gt_label.extend([self.label2id_map[("i-" + label).upper()]] *
(len(encode_res["input_ids"]) - 1))
return gt_label
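A quick worked check (values made up) of the 0-1000 normalization performed by `_smooth_box`, which matches the box convention LayoutLM-style models expect; `_parse_label` similarly expands a field label into B-/I- tags over the sub-tokens:

```
def smooth_box(bbox, height, width):
    # same arithmetic as VQATokenLabelEncode._smooth_box above:
    # scale pixel coordinates onto a 0-1000 grid
    return [int(bbox[0] * 1000.0 / width), int(bbox[1] * 1000.0 / height),
            int(bbox[2] * 1000.0 / width), int(bbox[3] * 1000.0 / height)]

assert smooth_box([80, 60, 400, 300], height=600, width=800) == [100, 100, 500, 500]

# _parse_label: a 3-token "QUESTION" span becomes [B-QUESTION, I-QUESTION, I-QUESTION]
# (the numeric ids depend on the class_path mapping, so only the pattern is shown here).
```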
...@@ -23,7 +23,6 @@ import sys ...@@ -23,7 +23,6 @@ import sys
import six import six
import cv2 import cv2
import numpy as np import numpy as np
import fasttext
class DecodeImage(object): class DecodeImage(object):
...@@ -136,6 +135,7 @@ class ToCHWImage(object): ...@@ -136,6 +135,7 @@ class ToCHWImage(object):
class Fasttext(object): class Fasttext(object):
def __init__(self, path="None", **kwargs): def __init__(self, path="None", **kwargs):
import fasttext
self.fast_model = fasttext.load_model(path) self.fast_model = fasttext.load_model(path)
def __call__(self, data): def __call__(self, data):
...@@ -170,17 +170,19 @@ class Resize(object): ...@@ -170,17 +170,19 @@ class Resize(object):
def __call__(self, data): def __call__(self, data):
img = data['image'] img = data['image']
text_polys = data['polys'] if 'polys' in data:
text_polys = data['polys']
img_resize, [ratio_h, ratio_w] = self.resize_image(img) img_resize, [ratio_h, ratio_w] = self.resize_image(img)
new_boxes = [] if 'polys' in data:
for box in text_polys: new_boxes = []
new_box = [] for box in text_polys:
for cord in box: new_box = []
new_box.append([cord[0] * ratio_w, cord[1] * ratio_h]) for cord in box:
new_boxes.append(new_box) new_box.append([cord[0] * ratio_w, cord[1] * ratio_h])
new_boxes.append(new_box)
data['polys'] = np.array(new_boxes, dtype=np.float32)
data['image'] = img_resize data['image'] = img_resize
data['polys'] = np.array(new_boxes, dtype=np.float32)
return data return data
......
# copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .token import VQATokenPad, VQASerTokenChunk, VQAReTokenChunk, VQAReTokenRelation
__all__ = [
'VQATokenPad', 'VQASerTokenChunk', 'VQAReTokenChunk', 'VQAReTokenRelation'
]
# copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .vqa_token_chunk import VQASerTokenChunk, VQAReTokenChunk
from .vqa_token_pad import VQATokenPad
from .vqa_token_relation import VQAReTokenRelation
# copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
class VQASerTokenChunk(object):
def __init__(self, max_seq_len=512, infer_mode=False, **kwargs):
self.max_seq_len = max_seq_len
self.infer_mode = infer_mode
def __call__(self, data):
encoded_inputs_all = []
seq_len = len(data['input_ids'])
for index in range(0, seq_len, self.max_seq_len):
chunk_beg = index
chunk_end = min(index + self.max_seq_len, seq_len)
encoded_inputs_example = {}
for key in data:
if key in [
'label', 'input_ids', 'labels', 'token_type_ids',
'bbox', 'attention_mask'
]:
if self.infer_mode and key == 'labels':
encoded_inputs_example[key] = data[key]
else:
encoded_inputs_example[key] = data[key][chunk_beg:
chunk_end]
else:
encoded_inputs_example[key] = data[key]
encoded_inputs_all.append(encoded_inputs_example)
return encoded_inputs_all[0]
class VQAReTokenChunk(object):
def __init__(self,
max_seq_len=512,
entities_labels=None,
infer_mode=False,
**kwargs):
self.max_seq_len = max_seq_len
self.entities_labels = {
'HEADER': 0,
'QUESTION': 1,
'ANSWER': 2
} if entities_labels is None else entities_labels
self.infer_mode = infer_mode
def __call__(self, data):
# prepare data
entities = data.pop('entities')
relations = data.pop('relations')
encoded_inputs_all = []
for index in range(0, len(data["input_ids"]), self.max_seq_len):
item = {}
for key in data:
if key in [
'label', 'input_ids', 'labels', 'token_type_ids',
'bbox', 'attention_mask'
]:
if self.infer_mode and key == 'labels':
item[key] = data[key]
else:
item[key] = data[key][index:index + self.max_seq_len]
else:
item[key] = data[key]
# select entity in current chunk
entities_in_this_span = []
global_to_local_map = {} #
for entity_id, entity in enumerate(entities):
if (index <= entity["start"] < index + self.max_seq_len and
index <= entity["end"] < index + self.max_seq_len):
entity["start"] = entity["start"] - index
entity["end"] = entity["end"] - index
global_to_local_map[entity_id] = len(entities_in_this_span)
entities_in_this_span.append(entity)
# select relations in current chunk
relations_in_this_span = []
for relation in relations:
if (index <= relation["start_index"] < index + self.max_seq_len
and index <= relation["end_index"] <
index + self.max_seq_len):
relations_in_this_span.append({
"head": global_to_local_map[relation["head"]],
"tail": global_to_local_map[relation["tail"]],
"start_index": relation["start_index"] - index,
"end_index": relation["end_index"] - index,
})
item.update({
"entities": self.reformat(entities_in_this_span),
"relations": self.reformat(relations_in_this_span),
})
item['entities']['label'] = [
self.entities_labels[x] for x in item['entities']['label']
]
encoded_inputs_all.append(item)
return encoded_inputs_all[0]
def reformat(self, data):
new_data = {}
for item in data:
for k, v in item.items():
if k not in new_data:
new_data[k] = []
new_data[k].append(v)
return new_data
# copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle
import numpy as np
class VQATokenPad(object):
def __init__(self,
max_seq_len=512,
pad_to_max_seq_len=True,
return_attention_mask=True,
return_token_type_ids=True,
truncation_strategy="longest_first",
return_overflowing_tokens=False,
return_special_tokens_mask=False,
infer_mode=False,
**kwargs):
self.max_seq_len = max_seq_len
self.pad_to_max_seq_len = max_seq_len
self.return_attention_mask = return_attention_mask
self.return_token_type_ids = return_token_type_ids
self.truncation_strategy = truncation_strategy
self.return_overflowing_tokens = return_overflowing_tokens
self.return_special_tokens_mask = return_special_tokens_mask
self.pad_token_label_id = paddle.nn.CrossEntropyLoss().ignore_index
self.infer_mode = infer_mode
def __call__(self, data):
needs_to_be_padded = self.pad_to_max_seq_len and len(data[
"input_ids"]) < self.max_seq_len
if needs_to_be_padded:
if 'tokenizer_params' in data:
tokenizer_params = data.pop('tokenizer_params')
else:
tokenizer_params = dict(
padding_side='right', pad_token_type_id=0, pad_token_id=1)
difference = self.max_seq_len - len(data["input_ids"])
if tokenizer_params['padding_side'] == 'right':
if self.return_attention_mask:
data["attention_mask"] = [1] * len(data[
"input_ids"]) + [0] * difference
if self.return_token_type_ids:
data["token_type_ids"] = (
data["token_type_ids"] +
[tokenizer_params['pad_token_type_id']] * difference)
if self.return_special_tokens_mask:
data["special_tokens_mask"] = data[
"special_tokens_mask"] + [1] * difference
data["input_ids"] = data["input_ids"] + [
tokenizer_params['pad_token_id']
] * difference
if not self.infer_mode:
data["labels"] = data[
"labels"] + [self.pad_token_label_id] * difference
data["bbox"] = data["bbox"] + [[0, 0, 0, 0]] * difference
elif tokenizer_params['padding_side'] == 'left':
if self.return_attention_mask:
data["attention_mask"] = [0] * difference + [
1
] * len(data["input_ids"])
if self.return_token_type_ids:
data["token_type_ids"] = (
[tokenizer_params['pad_token_type_id']] * difference +
data["token_type_ids"])
if self.return_special_tokens_mask:
data["special_tokens_mask"] = [
1
] * difference + data["special_tokens_mask"]
data["input_ids"] = [tokenizer_params['pad_token_id']
] * difference + data["input_ids"]
if not self.infer_mode:
data["labels"] = [self.pad_token_label_id
] * difference + data["labels"]
data["bbox"] = [[0, 0, 0, 0]] * difference + data["bbox"]
else:
if self.return_attention_mask:
data["attention_mask"] = [1] * len(data["input_ids"])
for key in data:
if key in [
'input_ids', 'labels', 'token_type_ids', 'bbox',
'attention_mask'
]:
if self.infer_mode:
if key != 'labels':
length = min(len(data[key]), self.max_seq_len)
data[key] = data[key][:length]
else:
continue
data[key] = np.array(data[key], dtype='int64')
return data
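A small illustration (token ids, labels, and pad ids are assumed values) of what the right-padding branch above produces for a 3-token sample with `max_seq_len=5`; the ignore index -100 comes from `paddle.nn.CrossEntropyLoss().ignore_index`:

```
# Assumed values for illustration; pad ids normally come from the tokenizer.
max_seq_len, pad_id, pad_type_id, ignore_label = 5, 1, 0, -100
input_ids = [23, 57, 98]
diff = max_seq_len - len(input_ids)

attention_mask = [1] * len(input_ids) + [0] * diff    # [1, 1, 1, 0, 0]
input_ids = input_ids + [pad_id] * diff               # [23, 57, 98, 1, 1]
token_type_ids = [0, 0, 0] + [pad_type_id] * diff     # [0, 0, 0, 0, 0]
labels = [2, 3, 3] + [ignore_label] * diff            # [2, 3, 3, -100, -100]
bbox = [[10, 10, 20, 20]] * 3 + [[0, 0, 0, 0]] * diff
```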
# copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
class VQAReTokenRelation(object):
def __init__(self, **kwargs):
pass
def __call__(self, data):
"""
build relations
"""
entities = data['entities']
relations = data['relations']
id2label = data.pop('id2label')
empty_entity = data.pop('empty_entity')
entity_id_to_index_map = data.pop('entity_id_to_index_map')
relations = list(set(relations))
relations = [
rel for rel in relations
if rel[0] not in empty_entity and rel[1] not in empty_entity
]
kv_relations = []
for rel in relations:
pair = [id2label[rel[0]], id2label[rel[1]]]
if pair == ["question", "answer"]:
kv_relations.append({
"head": entity_id_to_index_map[rel[0]],
"tail": entity_id_to_index_map[rel[1]]
})
elif pair == ["answer", "question"]:
kv_relations.append({
"head": entity_id_to_index_map[rel[1]],
"tail": entity_id_to_index_map[rel[0]]
})
else:
continue
relations = sorted(
[{
"head": rel["head"],
"tail": rel["tail"],
"start_index": self.get_relation_span(rel, entities)[0],
"end_index": self.get_relation_span(rel, entities)[1],
} for rel in kv_relations],
key=lambda x: x["head"], )
data['relations'] = relations
return data
def get_relation_span(self, rel, entities):
bound = []
for entity_index in [rel["head"], rel["tail"]]:
bound.append(entities[entity_index]["start"])
bound.append(entities[entity_index]["end"])
return min(bound), max(bound)
...@@ -38,6 +38,9 @@ class LMDBDataSet(Dataset): ...@@ -38,6 +38,9 @@ class LMDBDataSet(Dataset):
np.random.shuffle(self.data_idx_order_list) np.random.shuffle(self.data_idx_order_list)
self.ops = create_operators(dataset_config['transforms'], global_config) self.ops = create_operators(dataset_config['transforms'], global_config)
ratio_list = dataset_config.get("ratio_list", [1.0])
self.need_reset = True in [x < 1 for x in ratio_list]
def load_hierarchical_lmdb_dataset(self, data_dir): def load_hierarchical_lmdb_dataset(self, data_dir):
lmdb_sets = {} lmdb_sets = {}
dataset_idx = 0 dataset_idx = 0
......
...@@ -49,6 +49,8 @@ class PGDataSet(Dataset): ...@@ -49,6 +49,8 @@ class PGDataSet(Dataset):
self.ops = create_operators(dataset_config['transforms'], global_config) self.ops = create_operators(dataset_config['transforms'], global_config)
self.need_reset = True in [x < 1 for x in ratio_list]
def shuffle_data_random(self): def shuffle_data_random(self):
if self.do_shuffle: if self.do_shuffle:
random.seed(self.seed) random.seed(self.seed)
......
...@@ -53,6 +53,9 @@ class PubTabDataSet(Dataset): ...@@ -53,6 +53,9 @@ class PubTabDataSet(Dataset):
self.shuffle_data_random() self.shuffle_data_random()
self.ops = create_operators(dataset_config['transforms'], global_config) self.ops = create_operators(dataset_config['transforms'], global_config)
ratio_list = dataset_config.get("ratio_list", [1.0])
self.need_reset = True in [x < 1 for x in ratio_list]
def shuffle_data_random(self): def shuffle_data_random(self):
if self.do_shuffle: if self.do_shuffle:
random.seed(self.seed) random.seed(self.seed)
...@@ -70,7 +73,7 @@ class PubTabDataSet(Dataset): ...@@ -70,7 +73,7 @@ class PubTabDataSet(Dataset):
prob = self.img_select_prob[file_name] prob = self.img_select_prob[file_name]
if prob < random.uniform(0, 1): if prob < random.uniform(0, 1):
select_flag = False select_flag = False
if self.table_select_type: if self.table_select_type:
structure = info['html']['structure']['tokens'].copy() structure = info['html']['structure']['tokens'].copy()
structure_str = ''.join(structure) structure_str = ''.join(structure)
...@@ -79,13 +82,17 @@ class PubTabDataSet(Dataset): ...@@ -79,13 +82,17 @@ class PubTabDataSet(Dataset):
table_type = "complex" table_type = "complex"
if table_type == "complex": if table_type == "complex":
if self.table_select_prob < random.uniform(0, 1): if self.table_select_prob < random.uniform(0, 1):
select_flag = False select_flag = False
if select_flag: if select_flag:
cells = info['html']['cells'].copy() cells = info['html']['cells'].copy()
structure = info['html']['structure'].copy() structure = info['html']['structure'].copy()
img_path = os.path.join(self.data_dir, file_name) img_path = os.path.join(self.data_dir, file_name)
data = {'img_path': img_path, 'cells': cells, 'structure':structure} data = {
'img_path': img_path,
'cells': cells,
'structure': structure
}
if not os.path.exists(img_path): if not os.path.exists(img_path):
raise Exception("{} does not exist!".format(img_path)) raise Exception("{} does not exist!".format(img_path))
with open(data['img_path'], 'rb') as f: with open(data['img_path'], 'rb') as f:
......
...@@ -41,7 +41,6 @@ class SimpleDataSet(Dataset): ...@@ -41,7 +41,6 @@ class SimpleDataSet(Dataset):
) == data_source_num, "The length of ratio_list should be the same as the file_list." ) == data_source_num, "The length of ratio_list should be the same as the file_list."
self.data_dir = dataset_config['data_dir'] self.data_dir = dataset_config['data_dir']
self.do_shuffle = loader_config['shuffle'] self.do_shuffle = loader_config['shuffle']
self.seed = seed self.seed = seed
logger.info("Initialize indexs of datasets:%s" % label_file_list) logger.info("Initialize indexs of datasets:%s" % label_file_list)
self.data_lines = self.get_image_info_list(label_file_list, ratio_list) self.data_lines = self.get_image_info_list(label_file_list, ratio_list)
...@@ -50,6 +49,8 @@ class SimpleDataSet(Dataset): ...@@ -50,6 +49,8 @@ class SimpleDataSet(Dataset):
self.shuffle_data_random() self.shuffle_data_random()
self.ops = create_operators(dataset_config['transforms'], global_config) self.ops = create_operators(dataset_config['transforms'], global_config)
self.need_reset = True in [x < 1 for x in ratio_list]
def get_image_info_list(self, file_list, ratio_list): def get_image_info_list(self, file_list, ratio_list):
if isinstance(file_list, str): if isinstance(file_list, str):
file_list = [file_list] file_list = [file_list]
...@@ -69,6 +70,16 @@ class SimpleDataSet(Dataset): ...@@ -69,6 +70,16 @@ class SimpleDataSet(Dataset):
random.shuffle(self.data_lines) random.shuffle(self.data_lines)
return return
def _try_parse_filename_list(self, file_name):
# multiple images -> one gt label
if len(file_name) > 0 and file_name[0] == "[":
try:
info = json.loads(file_name)
file_name = random.choice(info)
except:
pass
return file_name
def get_ext_data(self): def get_ext_data(self):
ext_data_num = 0 ext_data_num = 0
for op in self.ops: for op in self.ops:
...@@ -85,6 +96,7 @@ class SimpleDataSet(Dataset): ...@@ -85,6 +96,7 @@ class SimpleDataSet(Dataset):
data_line = data_line.decode('utf-8') data_line = data_line.decode('utf-8')
substr = data_line.strip("\n").split(self.delimiter) substr = data_line.strip("\n").split(self.delimiter)
file_name = substr[0] file_name = substr[0]
file_name = self._try_parse_filename_list(file_name)
label = substr[1] label = substr[1]
img_path = os.path.join(self.data_dir, file_name) img_path = os.path.join(self.data_dir, file_name)
data = {'img_path': img_path, 'label': label} data = {'img_path': img_path, 'label': label}
...@@ -95,7 +107,7 @@ class SimpleDataSet(Dataset): ...@@ -95,7 +107,7 @@ class SimpleDataSet(Dataset):
data['image'] = img data['image'] = img
data = transform(data, load_data_ops) data = transform(data, load_data_ops)
if data is None or data['polys'].shape[1]!=4: if data is None or data['polys'].shape[1] != 4:
continue continue
ext_data.append(data) ext_data.append(data)
return ext_data return ext_data
...@@ -107,6 +119,7 @@ class SimpleDataSet(Dataset): ...@@ -107,6 +119,7 @@ class SimpleDataSet(Dataset):
data_line = data_line.decode('utf-8') data_line = data_line.decode('utf-8')
substr = data_line.strip("\n").split(self.delimiter) substr = data_line.strip("\n").split(self.delimiter)
file_name = substr[0] file_name = substr[0]
file_name = self._try_parse_filename_list(file_name)
label = substr[1] label = substr[1]
img_path = os.path.join(self.data_dir, file_name) img_path = os.path.join(self.data_dir, file_name)
data = {'img_path': img_path, 'label': label} data = {'img_path': img_path, 'label': label}
......
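The `_try_parse_filename_list` helper added above allows one label line to point at several candidate images, one of which is sampled per read. A hedged sketch (the paths, label text, and tab delimiter are assumptions) of the line format it handles:

```
import json
import random

# One line of the label file: image-path column, delimiter, label column.
# When the path column is a JSON list, one image is chosen at random.
line = '["img/train/word_001_a.jpg", "img/train/word_001_b.jpg"]\tPaddleOCR'
file_name, label = line.strip("\n").split("\t")
if file_name and file_name[0] == "[":
    file_name = random.choice(json.loads(file_name))
print(file_name, label)
```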
...@@ -16,6 +16,9 @@ import copy ...@@ -16,6 +16,9 @@ import copy
import paddle import paddle
import paddle.nn as nn import paddle.nn as nn
# basic_loss
from .basic_loss import LossFromOutput
# det loss # det loss
from .det_db_loss import DBLoss from .det_db_loss import DBLoss
from .det_east_loss import EASTLoss from .det_east_loss import EASTLoss
...@@ -46,12 +49,16 @@ from .combined_loss import CombinedLoss ...@@ -46,12 +49,16 @@ from .combined_loss import CombinedLoss
# table loss # table loss
from .table_att_loss import TableAttentionLoss from .table_att_loss import TableAttentionLoss
# vqa token loss
from .vqa_token_layoutlm_loss import VQASerTokenLayoutLMLoss
def build_loss(config): def build_loss(config):
support_dict = [ support_dict = [
'DBLoss', 'PSELoss', 'EASTLoss', 'SASTLoss', 'CTCLoss', 'ClsLoss', 'DBLoss', 'PSELoss', 'EASTLoss', 'SASTLoss', 'CTCLoss', 'ClsLoss',
'AttentionLoss', 'SRNLoss', 'PGLoss', 'CombinedLoss', 'NRTRLoss', 'AttentionLoss', 'SRNLoss', 'PGLoss', 'CombinedLoss', 'NRTRLoss',
'TableAttentionLoss', 'SARLoss', 'AsterLoss', 'SDMGRLoss' 'TableAttentionLoss', 'SARLoss', 'AsterLoss', 'SDMGRLoss',
'VQASerTokenLayoutLMLoss', 'LossFromOutput'
] ]
config = copy.deepcopy(config) config = copy.deepcopy(config)
module_name = config.pop('name') module_name = config.pop('name')
......
...@@ -133,3 +133,18 @@ class DistanceLoss(nn.Layer): ...@@ -133,3 +133,18 @@ class DistanceLoss(nn.Layer):
def forward(self, x, y): def forward(self, x, y):
return self.loss_func(x, y) return self.loss_func(x, y)
class LossFromOutput(nn.Layer):
def __init__(self, key='loss', reduction='none'):
super().__init__()
self.key = key
self.reduction = reduction
def forward(self, predicts, batch):
loss = predicts[self.key]
if self.reduction == 'mean':
loss = paddle.mean(loss)
elif self.reduction == 'sum':
loss = paddle.sum(loss)
return {'loss': loss}
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. # copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
# You may obtain a copy of the License at # You may obtain a copy of the License at
# #
# http://www.apache.org/licenses/LICENSE-2.0 # http://www.apache.org/licenses/LICENSE-2.0
# #
# Unless required by applicable law or agreed to in writing, software # Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, # distributed under the License is distributed on an "AS IS" BASIS,
...@@ -12,24 +12,31 @@ ...@@ -12,24 +12,31 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from paddle import nn from paddle import nn
class SERLoss(nn.Layer): class VQASerTokenLayoutLMLoss(nn.Layer):
def __init__(self, num_classes): def __init__(self, num_classes):
super().__init__() super().__init__()
self.loss_class = nn.CrossEntropyLoss() self.loss_class = nn.CrossEntropyLoss()
self.num_classes = num_classes self.num_classes = num_classes
self.ignore_index = self.loss_class.ignore_index self.ignore_index = self.loss_class.ignore_index
def forward(self, labels, outputs, attention_mask): def forward(self, predicts, batch):
labels = batch[1]
attention_mask = batch[4]
if attention_mask is not None: if attention_mask is not None:
active_loss = attention_mask.reshape([-1, ]) == 1 active_loss = attention_mask.reshape([-1, ]) == 1
active_outputs = outputs.reshape( active_outputs = predicts.reshape(
[-1, self.num_classes])[active_loss] [-1, self.num_classes])[active_loss]
active_labels = labels.reshape([-1, ])[active_loss] active_labels = labels.reshape([-1, ])[active_loss]
loss = self.loss_class(active_outputs, active_labels) loss = self.loss_class(active_outputs, active_labels)
else: else:
loss = self.loss_class( loss = self.loss_class(
outputs.reshape([-1, self.num_classes]), labels.reshape([-1, ])) predicts.reshape([-1, self.num_classes]),
return loss labels.reshape([-1, ]))
return {'loss': loss}
...@@ -28,12 +28,15 @@ from .e2e_metric import E2EMetric ...@@ -28,12 +28,15 @@ from .e2e_metric import E2EMetric
from .distillation_metric import DistillationMetric from .distillation_metric import DistillationMetric
from .table_metric import TableMetric from .table_metric import TableMetric
from .kie_metric import KIEMetric from .kie_metric import KIEMetric
from .vqa_token_ser_metric import VQASerTokenMetric
from .vqa_token_re_metric import VQAReTokenMetric
def build_metric(config): def build_metric(config):
support_dict = [ support_dict = [
"DetMetric", "RecMetric", "ClsMetric", "E2EMetric", "DetMetric", "RecMetric", "ClsMetric", "E2EMetric",
"DistillationMetric", "TableMetric", 'KIEMetric' "DistillationMetric", "TableMetric", 'KIEMetric', 'VQASerTokenMetric',
'VQAReTokenMetric'
] ]
config = copy.deepcopy(config) config = copy.deepcopy(config)
......
...@@ -16,6 +16,7 @@ ...@@ -16,6 +16,7 @@
class ClsMetric(object): class ClsMetric(object):
def __init__(self, main_indicator='acc', **kwargs): def __init__(self, main_indicator='acc', **kwargs):
self.main_indicator = main_indicator self.main_indicator = main_indicator
self.eps = 1e-5
self.reset() self.reset()
def __call__(self, pred_label, *args, **kwargs): def __call__(self, pred_label, *args, **kwargs):
...@@ -28,7 +29,7 @@ class ClsMetric(object): ...@@ -28,7 +29,7 @@ class ClsMetric(object):
all_num += 1 all_num += 1
self.correct_num += correct_num self.correct_num += correct_num
self.all_num += all_num self.all_num += all_num
return {'acc': correct_num / all_num, } return {'acc': correct_num / (all_num + self.eps), }
def get_metric(self): def get_metric(self):
""" """
...@@ -36,7 +37,7 @@ class ClsMetric(object): ...@@ -36,7 +37,7 @@ class ClsMetric(object):
'acc': 0 'acc': 0
} }
""" """
acc = self.correct_num / self.all_num acc = self.correct_num / (self.all_num + self.eps)
self.reset() self.reset()
return {'acc': acc} return {'acc': acc}
......
...@@ -20,6 +20,7 @@ class RecMetric(object): ...@@ -20,6 +20,7 @@ class RecMetric(object):
def __init__(self, main_indicator='acc', is_filter=False, **kwargs): def __init__(self, main_indicator='acc', is_filter=False, **kwargs):
self.main_indicator = main_indicator self.main_indicator = main_indicator
self.is_filter = is_filter self.is_filter = is_filter
self.eps = 1e-5
self.reset() self.reset()
def _normalize_text(self, text): def _normalize_text(self, text):
...@@ -47,8 +48,8 @@ class RecMetric(object): ...@@ -47,8 +48,8 @@ class RecMetric(object):
self.all_num += all_num self.all_num += all_num
self.norm_edit_dis += norm_edit_dis self.norm_edit_dis += norm_edit_dis
return { return {
'acc': correct_num / all_num, 'acc': correct_num / (all_num + self.eps),
'norm_edit_dis': 1 - norm_edit_dis / (all_num + 1e-3) 'norm_edit_dis': 1 - norm_edit_dis / (all_num + self.eps)
} }
def get_metric(self): def get_metric(self):
...@@ -58,8 +59,8 @@ class RecMetric(object): ...@@ -58,8 +59,8 @@ class RecMetric(object):
'norm_edit_dis': 0, 'norm_edit_dis': 0,
} }
""" """
acc = 1.0 * self.correct_num / (self.all_num + 1e-3) acc = 1.0 * self.correct_num / (self.all_num + self.eps)
norm_edit_dis = 1 - self.norm_edit_dis / (self.all_num + 1e-3) norm_edit_dis = 1 - self.norm_edit_dis / (self.all_num + self.eps)
self.reset() self.reset()
return {'acc': acc, 'norm_edit_dis': norm_edit_dis} return {'acc': acc, 'norm_edit_dis': norm_edit_dis}
......
...@@ -12,9 +12,12 @@ ...@@ -12,9 +12,12 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import numpy as np import numpy as np
class TableMetric(object): class TableMetric(object):
def __init__(self, main_indicator='acc', **kwargs): def __init__(self, main_indicator='acc', **kwargs):
self.main_indicator = main_indicator self.main_indicator = main_indicator
self.eps = 1e-5
self.reset() self.reset()
def __call__(self, pred, batch, *args, **kwargs): def __call__(self, pred, batch, *args, **kwargs):
...@@ -31,9 +34,7 @@ class TableMetric(object): ...@@ -31,9 +34,7 @@ class TableMetric(object):
correct_num += 1 correct_num += 1
self.correct_num += correct_num self.correct_num += correct_num
self.all_num += all_num self.all_num += all_num
return { return {'acc': correct_num * 1.0 / (all_num + self.eps), }
'acc': correct_num * 1.0 / all_num,
}
def get_metric(self): def get_metric(self):
""" """
...@@ -41,7 +42,7 @@ class TableMetric(object): ...@@ -41,7 +42,7 @@ class TableMetric(object):
'acc': 0, 'acc': 0,
} }
""" """
acc = 1.0 * self.correct_num / self.all_num acc = 1.0 * self.correct_num / (self.all_num + self.eps)
self.reset() self.reset()
return {'acc': acc} return {'acc': acc}
......
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
import paddle
__all__ = ['KIEMetric']
class VQAReTokenMetric(object):
def __init__(self, main_indicator='hmean', **kwargs):
self.main_indicator = main_indicator
self.reset()
def __call__(self, preds, batch, **kwargs):
pred_relations, relations, entities = preds
self.pred_relations_list.extend(pred_relations)
self.relations_list.extend(relations)
self.entities_list.extend(entities)
def get_metric(self):
gt_relations = []
for b in range(len(self.relations_list)):
rel_sent = []
for head, tail in zip(self.relations_list[b]["head"],
self.relations_list[b]["tail"]):
rel = {}
rel["head_id"] = head
rel["head"] = (self.entities_list[b]["start"][rel["head_id"]],
self.entities_list[b]["end"][rel["head_id"]])
rel["head_type"] = self.entities_list[b]["label"][rel[
"head_id"]]
rel["tail_id"] = tail
rel["tail"] = (self.entities_list[b]["start"][rel["tail_id"]],
self.entities_list[b]["end"][rel["tail_id"]])
rel["tail_type"] = self.entities_list[b]["label"][rel[
"tail_id"]]
rel["type"] = 1
rel_sent.append(rel)
gt_relations.append(rel_sent)
re_metrics = self.re_score(
self.pred_relations_list, gt_relations, mode="boundaries")
metrics = {
"precision": re_metrics["ALL"]["p"],
"recall": re_metrics["ALL"]["r"],
"hmean": re_metrics["ALL"]["f1"],
}
self.reset()
return metrics
def reset(self):
self.pred_relations_list = []
self.relations_list = []
self.entities_list = []
def re_score(self, pred_relations, gt_relations, mode="strict"):
"""Evaluate RE predictions
Args:
pred_relations (list) : list of list of predicted relations (several relations in each sentence)
gt_relations (list) : list of list of ground truth relations
rel = { "head": (start_idx (inclusive), end_idx (exclusive)),
"tail": (start_idx (inclusive), end_idx (exclusive)),
"head_type": ent_type,
"tail_type": ent_type,
"type": rel_type}
vocab (Vocab) : dataset vocabulary
mode (str) : in 'strict' or 'boundaries'"""
assert mode in ["strict", "boundaries"]
relation_types = [v for v in [0, 1] if not v == 0]
scores = {
rel: {
"tp": 0,
"fp": 0,
"fn": 0
}
for rel in relation_types + ["ALL"]
}
# Count GT relations and Predicted relations
n_sents = len(gt_relations)
n_rels = sum([len([rel for rel in sent]) for sent in gt_relations])
n_found = sum([len([rel for rel in sent]) for sent in pred_relations])
# Count TP, FP and FN per type
for pred_sent, gt_sent in zip(pred_relations, gt_relations):
for rel_type in relation_types:
# strict mode takes argument types into account
if mode == "strict":
pred_rels = {(rel["head"], rel["head_type"], rel["tail"],
rel["tail_type"])
for rel in pred_sent
if rel["type"] == rel_type}
gt_rels = {(rel["head"], rel["head_type"], rel["tail"],
rel["tail_type"])
for rel in gt_sent if rel["type"] == rel_type}
# boundaries mode only takes argument spans into account
elif mode == "boundaries":
pred_rels = {(rel["head"], rel["tail"])
for rel in pred_sent
if rel["type"] == rel_type}
gt_rels = {(rel["head"], rel["tail"])
for rel in gt_sent if rel["type"] == rel_type}
scores[rel_type]["tp"] += len(pred_rels & gt_rels)
scores[rel_type]["fp"] += len(pred_rels - gt_rels)
scores[rel_type]["fn"] += len(gt_rels - pred_rels)
# Compute per entity Precision / Recall / F1
for rel_type in scores.keys():
if scores[rel_type]["tp"]:
scores[rel_type]["p"] = scores[rel_type]["tp"] / (
scores[rel_type]["fp"] + scores[rel_type]["tp"])
scores[rel_type]["r"] = scores[rel_type]["tp"] / (
scores[rel_type]["fn"] + scores[rel_type]["tp"])
else:
scores[rel_type]["p"], scores[rel_type]["r"] = 0, 0
if not scores[rel_type]["p"] + scores[rel_type]["r"] == 0:
scores[rel_type]["f1"] = (
2 * scores[rel_type]["p"] * scores[rel_type]["r"] /
(scores[rel_type]["p"] + scores[rel_type]["r"]))
else:
scores[rel_type]["f1"] = 0
# Compute micro F1 Scores
tp = sum([scores[rel_type]["tp"] for rel_type in relation_types])
fp = sum([scores[rel_type]["fp"] for rel_type in relation_types])
fn = sum([scores[rel_type]["fn"] for rel_type in relation_types])
if tp:
precision = tp / (tp + fp)
recall = tp / (tp + fn)
f1 = 2 * precision * recall / (precision + recall)
else:
precision, recall, f1 = 0, 0, 0
scores["ALL"]["p"] = precision
scores["ALL"]["r"] = recall
scores["ALL"]["f1"] = f1
scores["ALL"]["tp"] = tp
scores["ALL"]["fp"] = fp
scores["ALL"]["fn"] = fn
# Compute Macro F1 Scores
scores["ALL"]["Macro_f1"] = np.mean(
[scores[ent_type]["f1"] for ent_type in relation_types])
scores["ALL"]["Macro_p"] = np.mean(
[scores[ent_type]["p"] for ent_type in relation_types])
scores["ALL"]["Macro_r"] = np.mean(
[scores[ent_type]["r"] for ent_type in relation_types])
return scores
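For reference, a minimal hypothetical sketch of the inputs `re_score` consumes; only the dict layout, the `mode` argument, and the expected perfect-match score come from the code above, the entity types and spans are invented:

```python
# Hypothetical example: one sentence, one ground-truth relation, and a
# prediction that matches it exactly on spans ("boundaries" mode).
gt_relations = [[{
    "head": (0, 2),            # token span: start inclusive, end exclusive
    "tail": (5, 7),
    "head_type": "QUESTION",   # entity types (only checked in "strict" mode)
    "tail_type": "ANSWER",
    "type": 1,                 # the single positive relation type
}]]
pred_relations = [[dict(gt_relations[0][0])]]  # a perfect prediction

# metric.re_score(pred_relations, gt_relations, mode="boundaries")
# would then report scores["ALL"]["p"] == ["r"] == ["f1"] == 1.0
```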
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
import paddle
__all__ = ['VQASerTokenMetric']
class VQASerTokenMetric(object):
def __init__(self, main_indicator='hmean', **kwargs):
self.main_indicator = main_indicator
self.reset()
def __call__(self, preds, batch, **kwargs):
preds, labels = preds
self.pred_list.extend(preds)
self.gt_list.extend(labels)
def get_metric(self):
from seqeval.metrics import f1_score, precision_score, recall_score
metrics = {
"precision": precision_score(self.gt_list, self.pred_list),
"recall": recall_score(self.gt_list, self.pred_list),
"hmean": f1_score(self.gt_list, self.pred_list),
}
self.reset()
return metrics
def reset(self):
self.pred_list = []
self.gt_list = []
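As a quick illustration (the tag strings below are made up; only the seqeval calls mirror `get_metric`), the metric accumulates per-sequence BIO tag lists:

```python
# Hypothetical BIO sequences in the shape that VQASerTokenMetric accumulates.
from seqeval.metrics import f1_score, precision_score, recall_score

gt_list = [["B-QUESTION", "I-QUESTION", "O", "B-ANSWER"]]
pred_list = [["B-QUESTION", "I-QUESTION", "O", "B-ANSWER"]]

print(precision_score(gt_list, pred_list))  # 1.0
print(recall_score(gt_list, pred_list))     # 1.0
print(f1_score(gt_list, pred_list))         # 1.0
```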
...@@ -63,8 +63,12 @@ class BaseModel(nn.Layer): ...@@ -63,8 +63,12 @@ class BaseModel(nn.Layer):
in_channels = self.neck.out_channels in_channels = self.neck.out_channels
# # build head, head is need for det, rec and cls # # build head, head is need for det, rec and cls
config["Head"]['in_channels'] = in_channels if 'Head' not in config or config['Head'] is None:
self.head = build_head(config["Head"]) self.use_head = False
else:
self.use_head = True
config["Head"]['in_channels'] = in_channels
self.head = build_head(config["Head"])
self.return_all_feats = config.get("return_all_feats", False) self.return_all_feats = config.get("return_all_feats", False)
...@@ -77,7 +81,8 @@ class BaseModel(nn.Layer): ...@@ -77,7 +81,8 @@ class BaseModel(nn.Layer):
if self.use_neck: if self.use_neck:
x = self.neck(x) x = self.neck(x)
y["neck_out"] = x y["neck_out"] = x
x = self.head(x, targets=data) if self.use_head:
x = self.head(x, targets=data)
if isinstance(x, dict): if isinstance(x, dict):
y.update(x) y.update(x)
else: else:
......
...@@ -29,9 +29,10 @@ def build_backbone(config, model_type): ...@@ -29,9 +29,10 @@ def build_backbone(config, model_type):
from .rec_nrtr_mtb import MTB from .rec_nrtr_mtb import MTB
from .rec_resnet_31 import ResNet31 from .rec_resnet_31 import ResNet31
from .rec_resnet_aster import ResNet_ASTER from .rec_resnet_aster import ResNet_ASTER
from .rec_micronet import MicroNet
support_dict = [ support_dict = [
'MobileNetV1Enhance', 'MobileNetV3', 'ResNet', 'ResNetFPN', 'MTB', 'MobileNetV1Enhance', 'MobileNetV3', 'ResNet', 'ResNetFPN', 'MTB',
"ResNet31", "ResNet_ASTER" "ResNet31", "ResNet_ASTER", 'MicroNet'
] ]
elif model_type == "e2e": elif model_type == "e2e":
from .e2e_resnet_vd_pg import ResNet from .e2e_resnet_vd_pg import ResNet
...@@ -43,6 +44,9 @@ def build_backbone(config, model_type): ...@@ -43,6 +44,9 @@ def build_backbone(config, model_type):
from .table_resnet_vd import ResNet from .table_resnet_vd import ResNet
from .table_mobilenet_v3 import MobileNetV3 from .table_mobilenet_v3 import MobileNetV3
support_dict = ["ResNet", "MobileNetV3"] support_dict = ["ResNet", "MobileNetV3"]
elif model_type == 'vqa':
from .vqa_layoutlm import LayoutLMForSer, LayoutXLMForSer, LayoutXLMForRe
support_dict = ["LayoutLMForSer", "LayoutXLMForSer", 'LayoutXLMForRe']
else: else:
raise NotImplementedError raise NotImplementedError
......
# copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This code is refer from:
https://github.com/liyunsheng13/micronet/blob/main/backbone/micronet.py
https://github.com/liyunsheng13/micronet/blob/main/backbone/activation.py
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import paddle
import paddle.nn as nn
from ppocr.modeling.backbones.det_mobilenet_v3 import make_divisible
M0_cfgs = [
# s, n, c, ks, c1, c2, g1, g2, c3, g3, g4, y1, y2, y3, r
[2, 1, 8, 3, 2, 2, 0, 4, 8, 2, 2, 2, 0, 1, 1],
[2, 1, 12, 3, 2, 2, 0, 8, 12, 4, 4, 2, 2, 1, 1],
[2, 1, 16, 5, 2, 2, 0, 12, 16, 4, 4, 2, 2, 1, 1],
[1, 1, 32, 5, 1, 4, 4, 4, 32, 4, 4, 2, 2, 1, 1],
[2, 1, 64, 5, 1, 4, 8, 8, 64, 8, 8, 2, 2, 1, 1],
[1, 1, 96, 3, 1, 4, 8, 8, 96, 8, 8, 2, 2, 1, 2],
[1, 1, 384, 3, 1, 4, 12, 12, 0, 0, 0, 2, 2, 1, 2],
]
M1_cfgs = [
# s, n, c, ks, c1, c2, g1, g2, c3, g3, g4
[2, 1, 8, 3, 2, 2, 0, 6, 8, 2, 2, 2, 0, 1, 1],
[2, 1, 16, 3, 2, 2, 0, 8, 16, 4, 4, 2, 2, 1, 1],
[2, 1, 16, 5, 2, 2, 0, 16, 16, 4, 4, 2, 2, 1, 1],
[1, 1, 32, 5, 1, 6, 4, 4, 32, 4, 4, 2, 2, 1, 1],
[2, 1, 64, 5, 1, 6, 8, 8, 64, 8, 8, 2, 2, 1, 1],
[1, 1, 96, 3, 1, 6, 8, 8, 96, 8, 8, 2, 2, 1, 2],
[1, 1, 576, 3, 1, 6, 12, 12, 0, 0, 0, 2, 2, 1, 2],
]
M2_cfgs = [
# s, n, c, ks, c1, c2, g1, g2, c3, g3, g4
[2, 1, 12, 3, 2, 2, 0, 8, 12, 4, 4, 2, 0, 1, 1],
[2, 1, 16, 3, 2, 2, 0, 12, 16, 4, 4, 2, 2, 1, 1],
[1, 1, 24, 3, 2, 2, 0, 16, 24, 4, 4, 2, 2, 1, 1],
[2, 1, 32, 5, 1, 6, 6, 6, 32, 4, 4, 2, 2, 1, 1],
[1, 1, 32, 5, 1, 6, 8, 8, 32, 4, 4, 2, 2, 1, 2],
[1, 1, 64, 5, 1, 6, 8, 8, 64, 8, 8, 2, 2, 1, 2],
[2, 1, 96, 5, 1, 6, 8, 8, 96, 8, 8, 2, 2, 1, 2],
[1, 1, 128, 3, 1, 6, 12, 12, 128, 8, 8, 2, 2, 1, 2],
[1, 1, 768, 3, 1, 6, 16, 16, 0, 0, 0, 2, 2, 1, 2],
]
M3_cfgs = [
# s, n, c, ks, c1, c2, g1, g2, c3, g3, g4
[2, 1, 16, 3, 2, 2, 0, 12, 16, 4, 4, 0, 2, 0, 1],
[2, 1, 24, 3, 2, 2, 0, 16, 24, 4, 4, 0, 2, 0, 1],
[1, 1, 24, 3, 2, 2, 0, 24, 24, 4, 4, 0, 2, 0, 1],
[2, 1, 32, 5, 1, 6, 6, 6, 32, 4, 4, 0, 2, 0, 1],
[1, 1, 32, 5, 1, 6, 8, 8, 32, 4, 4, 0, 2, 0, 2],
[1, 1, 64, 5, 1, 6, 8, 8, 48, 8, 8, 0, 2, 0, 2],
[1, 1, 80, 5, 1, 6, 8, 8, 80, 8, 8, 0, 2, 0, 2],
[1, 1, 80, 5, 1, 6, 10, 10, 80, 8, 8, 0, 2, 0, 2],
[1, 1, 120, 5, 1, 6, 10, 10, 120, 10, 10, 0, 2, 0, 2],
[1, 1, 120, 5, 1, 6, 12, 12, 120, 10, 10, 0, 2, 0, 2],
[1, 1, 144, 3, 1, 6, 12, 12, 144, 12, 12, 0, 2, 0, 2],
[1, 1, 432, 3, 1, 3, 12, 12, 0, 0, 0, 0, 2, 0, 2],
]
def get_micronet_config(mode):
return eval(mode + '_cfgs')
class MaxGroupPooling(nn.Layer):
def __init__(self, channel_per_group=2):
super(MaxGroupPooling, self).__init__()
self.channel_per_group = channel_per_group
def forward(self, x):
if self.channel_per_group == 1:
return x
# max op
b, c, h, w = x.shape
# reshape
y = paddle.reshape(x, [b, c // self.channel_per_group, -1, h, w])
out = paddle.max(y, axis=2)
return out
class SpatialSepConvSF(nn.Layer):
def __init__(self, inp, oups, kernel_size, stride):
super(SpatialSepConvSF, self).__init__()
oup1, oup2 = oups
self.conv = nn.Sequential(
nn.Conv2D(
inp,
oup1, (kernel_size, 1), (stride, 1), (kernel_size // 2, 0),
bias_attr=False,
groups=1),
nn.BatchNorm2D(oup1),
nn.Conv2D(
oup1,
oup1 * oup2, (1, kernel_size), (1, stride),
(0, kernel_size // 2),
bias_attr=False,
groups=oup1),
nn.BatchNorm2D(oup1 * oup2),
ChannelShuffle(oup1), )
def forward(self, x):
out = self.conv(x)
return out
class ChannelShuffle(nn.Layer):
def __init__(self, groups):
super(ChannelShuffle, self).__init__()
self.groups = groups
def forward(self, x):
b, c, h, w = x.shape
channels_per_group = c // self.groups
# reshape
x = paddle.reshape(x, [b, self.groups, channels_per_group, h, w])
x = paddle.transpose(x, (0, 2, 1, 3, 4))
out = paddle.reshape(x, [b, -1, h, w])
return out
class StemLayer(nn.Layer):
def __init__(self, inp, oup, stride, groups=(4, 4)):
super(StemLayer, self).__init__()
g1, g2 = groups
self.stem = nn.Sequential(
SpatialSepConvSF(inp, groups, 3, stride),
MaxGroupPooling(2) if g1 * g2 == 2 * oup else nn.ReLU6())
def forward(self, x):
out = self.stem(x)
return out
class DepthSpatialSepConv(nn.Layer):
def __init__(self, inp, expand, kernel_size, stride):
super(DepthSpatialSepConv, self).__init__()
exp1, exp2 = expand
hidden_dim = inp * exp1
oup = inp * exp1 * exp2
self.conv = nn.Sequential(
nn.Conv2D(
inp,
inp * exp1, (kernel_size, 1), (stride, 1),
(kernel_size // 2, 0),
bias_attr=False,
groups=inp),
nn.BatchNorm2D(inp * exp1),
nn.Conv2D(
hidden_dim,
oup, (1, kernel_size),
1, (0, kernel_size // 2),
bias_attr=False,
groups=hidden_dim),
nn.BatchNorm2D(oup))
def forward(self, x):
x = self.conv(x)
return x
class GroupConv(nn.Layer):
def __init__(self, inp, oup, groups=2):
super(GroupConv, self).__init__()
self.inp = inp
self.oup = oup
self.groups = groups
self.conv = nn.Sequential(
nn.Conv2D(
inp, oup, 1, 1, 0, bias_attr=False, groups=self.groups[0]),
nn.BatchNorm2D(oup))
def forward(self, x):
x = self.conv(x)
return x
class DepthConv(nn.Layer):
def __init__(self, inp, oup, kernel_size, stride):
super(DepthConv, self).__init__()
self.conv = nn.Sequential(
nn.Conv2D(
inp,
oup,
kernel_size,
stride,
kernel_size // 2,
bias_attr=False,
groups=inp),
nn.BatchNorm2D(oup))
def forward(self, x):
out = self.conv(x)
return out
class DYShiftMax(nn.Layer):
def __init__(self,
inp,
oup,
reduction=4,
act_max=1.0,
act_relu=True,
init_a=[0.0, 0.0],
init_b=[0.0, 0.0],
relu_before_pool=False,
g=None,
expansion=False):
super(DYShiftMax, self).__init__()
self.oup = oup
self.act_max = act_max * 2
self.act_relu = act_relu
self.avg_pool = nn.Sequential(nn.ReLU() if relu_before_pool == True else
nn.Sequential(), nn.AdaptiveAvgPool2D(1))
self.exp = 4 if act_relu else 2
self.init_a = init_a
self.init_b = init_b
# determine squeeze
squeeze = make_divisible(inp // reduction, 4)
if squeeze < 4:
squeeze = 4
self.fc = nn.Sequential(
nn.Linear(inp, squeeze),
nn.ReLU(), nn.Linear(squeeze, oup * self.exp), nn.Hardsigmoid())
if g is None:
g = (1, 1)  # guard: keep g indexable below (the previous `g = 1` would crash at g[1])
self.g = g[1]
if self.g != 1 and expansion:
self.g = inp // self.g
self.gc = inp // self.g
index = paddle.to_tensor([range(inp)])
index = paddle.reshape(index, [1, inp, 1, 1])
index = paddle.reshape(index, [1, self.g, self.gc, 1, 1])
indexgs = paddle.split(index, [1, self.g - 1], axis=1)
indexgs = paddle.concat((indexgs[1], indexgs[0]), axis=1)
indexs = paddle.split(indexgs, [1, self.gc - 1], axis=2)
indexs = paddle.concat((indexs[1], indexs[0]), axis=2)
self.index = paddle.reshape(indexs, [inp])
self.expansion = expansion
def forward(self, x):
x_in = x
x_out = x
b, c, _, _ = x_in.shape
y = self.avg_pool(x_in)
y = paddle.reshape(y, [b, c])
y = self.fc(y)
y = paddle.reshape(y, [b, self.oup * self.exp, 1, 1])
y = (y - 0.5) * self.act_max
n2, c2, h2, w2 = x_out.shape
x2 = paddle.to_tensor(x_out.numpy()[:, self.index.numpy(), :, :])
if self.exp == 4:
temp = y.shape
a1, b1, a2, b2 = paddle.split(y, temp[1] // self.oup, axis=1)
a1 = a1 + self.init_a[0]
a2 = a2 + self.init_a[1]
b1 = b1 + self.init_b[0]
b2 = b2 + self.init_b[1]
z1 = x_out * a1 + x2 * b1
z2 = x_out * a2 + x2 * b2
out = paddle.maximum(z1, z2)
elif self.exp == 2:
temp = y.shape
a1, b1 = paddle.split(y, temp[1] // self.oup, axis=1)
a1 = a1 + self.init_a[0]
b1 = b1 + self.init_b[0]
out = x_out * a1 + x2 * b1
return out
class DYMicroBlock(nn.Layer):
def __init__(self,
inp,
oup,
kernel_size=3,
stride=1,
ch_exp=(2, 2),
ch_per_group=4,
groups_1x1=(1, 1),
depthsep=True,
shuffle=False,
activation_cfg=None):
super(DYMicroBlock, self).__init__()
self.identity = stride == 1 and inp == oup
y1, y2, y3 = activation_cfg['dy']
act_reduction = 8 * activation_cfg['ratio']
init_a = activation_cfg['init_a']
init_b = activation_cfg['init_b']
t1 = ch_exp
gs1 = ch_per_group
hidden_fft, g1, g2 = groups_1x1
hidden_dim2 = inp * t1[0] * t1[1]
if gs1[0] == 0:
self.layers = nn.Sequential(
DepthSpatialSepConv(inp, t1, kernel_size, stride),
DYShiftMax(
hidden_dim2,
hidden_dim2,
act_max=2.0,
act_relu=True if y2 == 2 else False,
init_a=init_a,
reduction=act_reduction,
init_b=init_b,
g=gs1,
expansion=False) if y2 > 0 else nn.ReLU6(),
ChannelShuffle(gs1[1]) if shuffle else nn.Sequential(),
ChannelShuffle(hidden_dim2 // 2)
if shuffle and y2 != 0 else nn.Sequential(),
GroupConv(hidden_dim2, oup, (g1, g2)),
DYShiftMax(
oup,
oup,
act_max=2.0,
act_relu=False,
init_a=[1.0, 0.0],
reduction=act_reduction // 2,
init_b=[0.0, 0.0],
g=(g1, g2),
expansion=False) if y3 > 0 else nn.Sequential(),
ChannelShuffle(g2) if shuffle else nn.Sequential(),
ChannelShuffle(oup // 2)
if shuffle and oup % 2 == 0 and y3 != 0 else nn.Sequential(), )
elif g2 == 0:
self.layers = nn.Sequential(
GroupConv(inp, hidden_dim2, gs1),
DYShiftMax(
hidden_dim2,
hidden_dim2,
act_max=2.0,
act_relu=False,
init_a=[1.0, 0.0],
reduction=act_reduction,
init_b=[0.0, 0.0],
g=gs1,
expansion=False) if y3 > 0 else nn.Sequential(), )
else:
self.layers = nn.Sequential(
GroupConv(inp, hidden_dim2, gs1),
DYShiftMax(
hidden_dim2,
hidden_dim2,
act_max=2.0,
act_relu=True if y1 == 2 else False,
init_a=init_a,
reduction=act_reduction,
init_b=init_b,
g=gs1,
expansion=False) if y1 > 0 else nn.ReLU6(),
ChannelShuffle(gs1[1]) if shuffle else nn.Sequential(),
DepthSpatialSepConv(hidden_dim2, (1, 1), kernel_size, stride)
if depthsep else
DepthConv(hidden_dim2, hidden_dim2, kernel_size, stride),
nn.Sequential(),
DYShiftMax(
hidden_dim2,
hidden_dim2,
act_max=2.0,
act_relu=True if y2 == 2 else False,
init_a=init_a,
reduction=act_reduction,
init_b=init_b,
g=gs1,
expansion=True) if y2 > 0 else nn.ReLU6(),
ChannelShuffle(hidden_dim2 // 4)
if shuffle and y1 != 0 and y2 != 0 else nn.Sequential()
if y1 == 0 and y2 == 0 else ChannelShuffle(hidden_dim2 // 2),
GroupConv(hidden_dim2, oup, (g1, g2)),
DYShiftMax(
oup,
oup,
act_max=2.0,
act_relu=False,
init_a=[1.0, 0.0],
reduction=act_reduction // 2
if oup < hidden_dim2 else act_reduction,
init_b=[0.0, 0.0],
g=(g1, g2),
expansion=False) if y3 > 0 else nn.Sequential(),
ChannelShuffle(g2) if shuffle else nn.Sequential(),
ChannelShuffle(oup // 2)
if shuffle and y3 != 0 else nn.Sequential(), )
def forward(self, x):
identity = x
out = self.layers(x)
if self.identity:
out = out + identity
return out
class MicroNet(nn.Layer):
"""
the MicroNet backbone network for recognition module.
Args:
mode(str): {'M0', 'M1', 'M2', 'M3'}
Four models are proposed based on four different computational costs (4M, 6M, 12M, 21M MAdds)
Default: 'M3'.
"""
def __init__(self, mode='M3', **kwargs):
super(MicroNet, self).__init__()
self.cfgs = get_micronet_config(mode)
activation_cfg = {}
if mode == 'M0':
input_channel = 4
stem_groups = 2, 2
out_ch = 384
activation_cfg['init_a'] = 1.0, 1.0
activation_cfg['init_b'] = 0.0, 0.0
elif mode == 'M1':
input_channel = 6
stem_groups = 3, 2
out_ch = 576
activation_cfg['init_a'] = 1.0, 1.0
activation_cfg['init_b'] = 0.0, 0.0
elif mode == 'M2':
input_channel = 8
stem_groups = 4, 2
out_ch = 768
activation_cfg['init_a'] = 1.0, 1.0
activation_cfg['init_b'] = 0.0, 0.0
elif mode == 'M3':
input_channel = 12
stem_groups = 4, 3
out_ch = 432
activation_cfg['init_a'] = 1.0, 0.5
activation_cfg['init_b'] = 0.0, 0.5
else:
raise NotImplementedError("mode[" + mode +
"_model] is not implemented!")
layers = [StemLayer(3, input_channel, stride=2, groups=stem_groups)]
for idx, val in enumerate(self.cfgs):
s, n, c, ks, c1, c2, g1, g2, c3, g3, g4, y1, y2, y3, r = val
t1 = (c1, c2)
gs1 = (g1, g2)
gs2 = (c3, g3, g4)
activation_cfg['dy'] = [y1, y2, y3]
activation_cfg['ratio'] = r
output_channel = c
layers.append(
DYMicroBlock(
input_channel,
output_channel,
kernel_size=ks,
stride=s,
ch_exp=t1,
ch_per_group=gs1,
groups_1x1=gs2,
depthsep=True,
shuffle=True,
activation_cfg=activation_cfg, ))
input_channel = output_channel
for i in range(1, n):
layers.append(
DYMicroBlock(
input_channel,
output_channel,
kernel_size=ks,
stride=1,
ch_exp=t1,
ch_per_group=gs1,
groups_1x1=gs2,
depthsep=True,
shuffle=True,
activation_cfg=activation_cfg, ))
input_channel = output_channel
self.features = nn.Sequential(*layers)
self.pool = nn.MaxPool2D(kernel_size=2, stride=2, padding=0)
self.out_channels = make_divisible(out_ch)
def forward(self, x):
x = self.features(x)
x = self.pool(x)
return x
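A minimal usage sketch, assuming dygraph mode and a 32-pixel-high recognition-style input; the shapes are illustrative only:

```python
# Hypothetical smoke test for the MicroNet backbone defined above.
import paddle

net = MicroNet(mode='M3')
x = paddle.randn([1, 3, 32, 320])   # NCHW input, as fed by the rec pipeline
feat = net(x)                       # features after the final 2x2 max pooling
print(feat.shape, net.out_channels)
```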
# copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
from paddle import nn
from paddlenlp.transformers import LayoutXLMModel, LayoutXLMForTokenClassification, LayoutXLMForRelationExtraction
from paddlenlp.transformers import LayoutLMModel, LayoutLMForTokenClassification
__all__ = ["LayoutXLMForSer", 'LayoutLMForSer']
pretrained_model_dict = {
LayoutXLMModel: 'layoutxlm-base-uncased',
LayoutLMModel: 'layoutlm-base-uncased'
}
class NLPBaseModel(nn.Layer):
def __init__(self,
base_model_class,
model_class,
type='ser',
pretrained=True,
checkpoints=None,
**kwargs):
super(NLPBaseModel, self).__init__()
if checkpoints is not None:
self.model = model_class.from_pretrained(checkpoints)
else:
pretrained_model_name = pretrained_model_dict[base_model_class]
if pretrained:
base_model = base_model_class.from_pretrained(
pretrained_model_name)
else:
base_model = base_model_class(
**base_model_class.pretrained_init_configuration[
pretrained_model_name])
if type == 'ser':
self.model = model_class(
base_model, num_classes=kwargs['num_classes'], dropout=None)
else:
self.model = model_class(base_model, dropout=None)
self.out_channels = 1
class LayoutXLMForSer(NLPBaseModel):
def __init__(self, num_classes, pretrained=True, checkpoints=None,
**kwargs):
super(LayoutXLMForSer, self).__init__(
LayoutXLMModel,
LayoutXLMForTokenClassification,
'ser',
pretrained,
checkpoints,
num_classes=num_classes)
def forward(self, x):
x = self.model(
input_ids=x[0],
bbox=x[2],
image=x[3],
attention_mask=x[4],
token_type_ids=x[5],
position_ids=None,
head_mask=None,
labels=None)
return x[0]
class LayoutLMForSer(NLPBaseModel):
def __init__(self, num_classes, pretrained=True, checkpoints=None,
**kwargs):
super(LayoutLMForSer, self).__init__(
LayoutLMModel,
LayoutLMForTokenClassification,
'ser',
pretrained,
checkpoints,
num_classes=num_classes)
def forward(self, x):
x = self.model(
input_ids=x[0],
bbox=x[2],
attention_mask=x[4],
token_type_ids=x[5],
position_ids=None,
output_hidden_states=False)
return x
class LayoutXLMForRe(NLPBaseModel):
def __init__(self, pretrained=True, checkpoints=None, **kwargs):
super(LayoutXLMForRe, self).__init__(LayoutXLMModel,
LayoutXLMForRelationExtraction,
're', pretrained, checkpoints)
def forward(self, x):
x = self.model(
input_ids=x[0],
bbox=x[1],
labels=None,
image=x[2],
attention_mask=x[3],
token_type_ids=x[4],
position_ids=None,
head_mask=None,
entities=x[5],
relations=x[6])
return x
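A hedged construction sketch for the wrappers above; `num_classes=7` is an assumed value (3 entity categories in BIO form plus "O"), and the batch-layout comments simply restate how the forward methods index their input:

```python
# Hypothetical construction sketch; pretrained=True downloads the base weights.
ser_model = LayoutXLMForSer(num_classes=7, pretrained=True)  # num_classes is an assumption
re_model = LayoutXLMForRe(pretrained=True)

# Positional batch layout consumed by the forward methods above:
#   SER (LayoutXLM/LayoutLM): x[0]=input_ids, x[2]=bbox, x[3]=image (XLM only),
#                             x[4]=attention_mask, x[5]=token_type_ids
#   RE  (LayoutXLM):          x[0]=input_ids, x[1]=bbox, x[2]=image,
#                             x[3]=attention_mask, x[4]=token_type_ids,
#                             x[5]=entities, x[6]=relations
```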
...@@ -42,7 +42,9 @@ def build_optimizer(config, epochs, step_each_epoch, parameters): ...@@ -42,7 +42,9 @@ def build_optimizer(config, epochs, step_each_epoch, parameters):
# step2 build regularization # step2 build regularization
if 'regularizer' in config and config['regularizer'] is not None: if 'regularizer' in config and config['regularizer'] is not None:
reg_config = config.pop('regularizer') reg_config = config.pop('regularizer')
reg_name = reg_config.pop('name') + 'Decay' reg_name = reg_config.pop('name')
if not hasattr(regularizer, reg_name):
reg_name += 'Decay'
reg = getattr(regularizer, reg_name)(**reg_config)() reg = getattr(regularizer, reg_name)(**reg_config)()
else: else:
reg = None reg = None
......
...@@ -18,7 +18,7 @@ from __future__ import print_function ...@@ -18,7 +18,7 @@ from __future__ import print_function
from __future__ import unicode_literals from __future__ import unicode_literals
from paddle.optimizer import lr from paddle.optimizer import lr
from .lr_scheduler import CyclicalCosineDecay from .lr_scheduler import CyclicalCosineDecay, OneCycleDecay
class Linear(object): class Linear(object):
...@@ -226,3 +226,53 @@ class CyclicalCosine(object): ...@@ -226,3 +226,53 @@ class CyclicalCosine(object):
end_lr=self.learning_rate, end_lr=self.learning_rate,
last_epoch=self.last_epoch) last_epoch=self.last_epoch)
return learning_rate return learning_rate
class OneCycle(object):
"""
One Cycle learning rate decay
Args:
max_lr(float): Upper learning rate boundaries
epochs(int): total training epochs
step_each_epoch(int): steps each epoch
anneal_strategy(str): {'cos', 'linear'} Specifies the annealing strategy: 'cos' for cosine annealing, 'linear' for linear annealing.
Default: 'cos'
three_phase(bool): If True, use a third phase of the schedule to annihilate the learning rate according to 'final_div_factor'
instead of modifying the second phase (the first two phases will be symmetrical about the step indicated by 'pct_start').
last_epoch (int, optional): The index of last epoch. Can be set to restart training. Default: -1, means initial learning rate.
"""
def __init__(self,
max_lr,
epochs,
step_each_epoch,
anneal_strategy='cos',
three_phase=False,
warmup_epoch=0,
last_epoch=-1,
**kwargs):
super(OneCycle, self).__init__()
self.max_lr = max_lr
self.epochs = epochs
self.steps_per_epoch = step_each_epoch
self.anneal_strategy = anneal_strategy
self.three_phase = three_phase
self.last_epoch = last_epoch
self.warmup_epoch = round(warmup_epoch * step_each_epoch)
def __call__(self):
learning_rate = OneCycleDecay(
max_lr=self.max_lr,
epochs=self.epochs,
steps_per_epoch=self.steps_per_epoch,
anneal_strategy=self.anneal_strategy,
three_phase=self.three_phase,
last_epoch=self.last_epoch)
if self.warmup_epoch > 0:
learning_rate = lr.LinearWarmup(
learning_rate=learning_rate,
warmup_steps=self.warmup_epoch,
start_lr=0.0,
end_lr=self.max_lr,
last_epoch=self.last_epoch)
return learning_rate
\ No newline at end of file
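A hedged sketch of how the wrapper is used; the hyperparameter values are illustrative, and the classes are assumed to be importable from this module:

```python
# Hypothetical sketch: instantiate the wrapper and call it, as the lr builder does.
one_cycle = OneCycle(max_lr=0.001, epochs=50, step_each_epoch=100,
                     anneal_strategy='cos', warmup_epoch=2)
scheduler = one_cycle()   # OneCycleDecay wrapped in lr.LinearWarmup for the warm-up epochs
```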
...@@ -47,3 +47,116 @@ class CyclicalCosineDecay(LRScheduler): ...@@ -47,3 +47,116 @@ class CyclicalCosineDecay(LRScheduler):
lr = self.eta_min + 0.5 * (self.base_lr - self.eta_min) * \ lr = self.eta_min + 0.5 * (self.base_lr - self.eta_min) * \
(1 + math.cos(math.pi * reletive_epoch / self.cycle)) (1 + math.cos(math.pi * reletive_epoch / self.cycle))
return lr return lr
class OneCycleDecay(LRScheduler):
"""
One Cycle learning rate decay
A learning rate which can be referred in https://arxiv.org/abs/1708.07120
Code refered in https://pytorch.org/docs/stable/_modules/torch/optim/lr_scheduler.html#OneCycleLR
"""
def __init__(self,
max_lr,
epochs=None,
steps_per_epoch=None,
pct_start=0.3,
anneal_strategy='cos',
div_factor=25.,
final_div_factor=1e4,
three_phase=False,
last_epoch=-1,
verbose=False):
# Validate total_steps
if epochs <= 0 or not isinstance(epochs, int):
raise ValueError(
"Expected positive integer epochs, but got {}".format(epochs))
if steps_per_epoch <= 0 or not isinstance(steps_per_epoch, int):
raise ValueError(
"Expected positive integer steps_per_epoch, but got {}".format(
steps_per_epoch))
self.total_steps = epochs * steps_per_epoch
self.max_lr = max_lr
self.initial_lr = self.max_lr / div_factor
self.min_lr = self.initial_lr / final_div_factor
if three_phase:
self._schedule_phases = [
{
'end_step': float(pct_start * self.total_steps) - 1,
'start_lr': self.initial_lr,
'end_lr': self.max_lr,
},
{
'end_step': float(2 * pct_start * self.total_steps) - 2,
'start_lr': self.max_lr,
'end_lr': self.initial_lr,
},
{
'end_step': self.total_steps - 1,
'start_lr': self.initial_lr,
'end_lr': self.min_lr,
},
]
else:
self._schedule_phases = [
{
'end_step': float(pct_start * self.total_steps) - 1,
'start_lr': self.initial_lr,
'end_lr': self.max_lr,
},
{
'end_step': self.total_steps - 1,
'start_lr': self.max_lr,
'end_lr': self.min_lr,
},
]
# Validate pct_start
if pct_start < 0 or pct_start > 1 or not isinstance(pct_start, float):
raise ValueError(
"Expected float between 0 and 1 pct_start, but got {}".format(
pct_start))
# Validate anneal_strategy
if anneal_strategy not in ['cos', 'linear']:
raise ValueError(
"anneal_strategy must by one of 'cos' or 'linear', instead got {}".
format(anneal_strategy))
elif anneal_strategy == 'cos':
self.anneal_func = self._annealing_cos
elif anneal_strategy == 'linear':
self.anneal_func = self._annealing_linear
super(OneCycleDecay, self).__init__(max_lr, last_epoch, verbose)
def _annealing_cos(self, start, end, pct):
"Cosine anneal from `start` to `end` as pct goes from 0.0 to 1.0."
cos_out = math.cos(math.pi * pct) + 1
return end + (start - end) / 2.0 * cos_out
def _annealing_linear(self, start, end, pct):
"Linearly anneal from `start` to `end` as pct goes from 0.0 to 1.0."
return (end - start) * pct + start
def get_lr(self):
computed_lr = 0.0
step_num = self.last_epoch
if step_num > self.total_steps:
raise ValueError(
"Tried to step {} times. The specified number of total steps is {}"
.format(step_num + 1, self.total_steps))
start_step = 0
for i, phase in enumerate(self._schedule_phases):
end_step = phase['end_step']
if step_num <= end_step or i == len(self._schedule_phases) - 1:
pct = (step_num - start_step) / (end_step - start_step)
computed_lr = self.anneal_func(phase['start_lr'],
phase['end_lr'], pct)
break
start_step = phase['end_step']
return computed_lr
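A rough sketch of the resulting curve, assuming dygraph mode; the numbers are illustrative and simply exercise `get_lr`/`step`:

```python
# Hypothetical example: print the learning rate for the first few optimizer steps.
sched = OneCycleDecay(max_lr=0.01, epochs=2, steps_per_epoch=5, pct_start=0.4)
for _ in range(10):
    print(sched.get_lr())   # rises from max_lr / div_factor towards max_lr, then decays
    sched.step()
```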
...@@ -158,3 +158,38 @@ class Adadelta(object): ...@@ -158,3 +158,38 @@ class Adadelta(object):
name=self.name, name=self.name,
parameters=parameters) parameters=parameters)
return opt return opt
class AdamW(object):
def __init__(self,
learning_rate=0.001,
beta1=0.9,
beta2=0.999,
epsilon=1e-08,
weight_decay=0.01,
grad_clip=None,
name=None,
lazy_mode=False,
**kwargs):
self.learning_rate = learning_rate
self.beta1 = beta1
self.beta2 = beta2
self.epsilon = epsilon
self.weight_decay = 0.01 if weight_decay is None else weight_decay
self.grad_clip = grad_clip
self.name = name
self.lazy_mode = lazy_mode
def __call__(self, parameters):
opt = optim.AdamW(
learning_rate=self.learning_rate,
beta1=self.beta1,
beta2=self.beta2,
epsilon=self.epsilon,
weight_decay=self.weight_decay,
grad_clip=self.grad_clip,
name=self.name,
lazy_mode=self.lazy_mode,
parameters=parameters)
return opt
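A hedged usage sketch, assuming `optim` refers to `paddle.optimizer` as it does for the other wrappers in this module; the toy model is invented:

```python
# Hypothetical example: wire the AdamW wrapper to a toy model's parameters.
import paddle

model = paddle.nn.Linear(10, 2)
optimizer = AdamW(learning_rate=0.001, weight_decay=0.01)(model.parameters())
```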
...@@ -29,24 +29,23 @@ class L1Decay(object): ...@@ -29,24 +29,23 @@ class L1Decay(object):
def __init__(self, factor=0.0): def __init__(self, factor=0.0):
super(L1Decay, self).__init__() super(L1Decay, self).__init__()
self.regularization_coeff = factor self.coeff = factor
def __call__(self): def __call__(self):
reg = paddle.regularizer.L1Decay(self.regularization_coeff) reg = paddle.regularizer.L1Decay(self.coeff)
return reg return reg
class L2Decay(object): class L2Decay(object):
""" """
L2 Weight Decay Regularization, which encourages the weights to be sparse. L2 Weight Decay Regularization, which helps to prevent the model over-fitting.
Args: Args:
factor(float): regularization coeff. Default:0.0. factor(float): regularization coeff. Default:0.0.
""" """
def __init__(self, factor=0.0): def __init__(self, factor=0.0):
super(L2Decay, self).__init__() super(L2Decay, self).__init__()
self.regularization_coeff = factor self.coeff = float(factor)
def __call__(self): def __call__(self):
reg = paddle.regularizer.L2Decay(self.regularization_coeff) return self.coeff
return reg \ No newline at end of file
...@@ -28,6 +28,8 @@ from .rec_postprocess import CTCLabelDecode, AttnLabelDecode, SRNLabelDecode, Di ...@@ -28,6 +28,8 @@ from .rec_postprocess import CTCLabelDecode, AttnLabelDecode, SRNLabelDecode, Di
TableLabelDecode, NRTRLabelDecode, SARLabelDecode, SEEDLabelDecode TableLabelDecode, NRTRLabelDecode, SARLabelDecode, SEEDLabelDecode
from .cls_postprocess import ClsPostProcess from .cls_postprocess import ClsPostProcess
from .pg_postprocess import PGPostProcess from .pg_postprocess import PGPostProcess
from .vqa_token_ser_layoutlm_postprocess import VQASerTokenLayoutLMPostProcess
from .vqa_token_re_layoutlm_postprocess import VQAReTokenLayoutLMPostProcess
def build_post_process(config, global_config=None): def build_post_process(config, global_config=None):
...@@ -36,7 +38,8 @@ def build_post_process(config, global_config=None): ...@@ -36,7 +38,8 @@ def build_post_process(config, global_config=None):
'AttnLabelDecode', 'ClsPostProcess', 'SRNLabelDecode', 'PGPostProcess', 'AttnLabelDecode', 'ClsPostProcess', 'SRNLabelDecode', 'PGPostProcess',
'DistillationCTCLabelDecode', 'TableLabelDecode', 'DistillationCTCLabelDecode', 'TableLabelDecode',
'DistillationDBPostProcess', 'NRTRLabelDecode', 'SARLabelDecode', 'DistillationDBPostProcess', 'NRTRLabelDecode', 'SARLabelDecode',
'SEEDLabelDecode' 'SEEDLabelDecode', 'VQASerTokenLayoutLMPostProcess',
'VQAReTokenLayoutLMPostProcess'
] ]
if config['name'] == 'PSEPostProcess': if config['name'] == 'PSEPostProcess':
......
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle
class VQAReTokenLayoutLMPostProcess(object):
""" Convert between text-label and text-index """
def __init__(self, **kwargs):
super(VQAReTokenLayoutLMPostProcess, self).__init__()
def __call__(self, preds, label=None, *args, **kwargs):
if label is not None:
return self._metric(preds, label)
else:
return self._infer(preds, *args, **kwargs)
def _metric(self, preds, label):
return preds['pred_relations'], label[6], label[5]
def _infer(self, preds, *args, **kwargs):
ser_results = kwargs['ser_results']
entity_idx_dict_batch = kwargs['entity_idx_dict_batch']
pred_relations = preds['pred_relations']
# merge relations and ocr info
results = []
for pred_relation, ser_result, entity_idx_dict in zip(
pred_relations, ser_results, entity_idx_dict_batch):
result = []
used_tail_id = []
for relation in pred_relation:
if relation['tail_id'] in used_tail_id:
continue
used_tail_id.append(relation['tail_id'])
ocr_info_head = ser_result[entity_idx_dict[relation['head_id']]]
ocr_info_tail = ser_result[entity_idx_dict[relation['tail_id']]]
result.append((ocr_info_head, ocr_info_tail))
results.append(result)
return results
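A hedged sketch of the structures `_infer` expects; the texts and indices are invented, only the key names come from the code above:

```python
# Hypothetical inputs for one image: two SER results and one predicted relation
# linking entity 0 (a question) to entity 1 (its answer).
ser_results = [[{"text": "Name:", "pred": "QUESTION"},
                {"text": "Alice", "pred": "ANSWER"}]]
entity_idx_dict_batch = [{0: 0, 1: 1}]   # entity id -> index into ser_results[0]
preds = {"pred_relations": [[{"head_id": 0, "tail_id": 1}]]}

post = VQAReTokenLayoutLMPostProcess()
pairs = post(preds, ser_results=ser_results,
             entity_idx_dict_batch=entity_idx_dict_batch)
# pairs[0] -> [({"text": "Name:", ...}, {"text": "Alice", ...})]
```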
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
import paddle
from ppocr.utils.utility import load_vqa_bio_label_maps
class VQASerTokenLayoutLMPostProcess(object):
""" Convert between text-label and text-index """
def __init__(self, class_path, **kwargs):
super(VQASerTokenLayoutLMPostProcess, self).__init__()
label2id_map, self.id2label_map = load_vqa_bio_label_maps(class_path)
self.label2id_map_for_draw = dict()
for key in label2id_map:
if key.startswith("I-"):
self.label2id_map_for_draw[key] = label2id_map["B" + key[1:]]
else:
self.label2id_map_for_draw[key] = label2id_map[key]
self.id2label_map_for_show = dict()
for key in self.label2id_map_for_draw:
val = self.label2id_map_for_draw[key]
if key == "O":
self.id2label_map_for_show[val] = key
if key.startswith("B-") or key.startswith("I-"):
self.id2label_map_for_show[val] = key[2:]
else:
self.id2label_map_for_show[val] = key
def __call__(self, preds, batch=None, *args, **kwargs):
if isinstance(preds, paddle.Tensor):
preds = preds.numpy()
if batch is not None:
return self._metric(preds, batch[1])
else:
return self._infer(preds, **kwargs)
def _metric(self, preds, label):
pred_idxs = preds.argmax(axis=2)
decode_out_list = [[] for _ in range(pred_idxs.shape[0])]
label_decode_out_list = [[] for _ in range(pred_idxs.shape[0])]
for i in range(pred_idxs.shape[0]):
for j in range(pred_idxs.shape[1]):
if label[i, j] != -100:
label_decode_out_list[i].append(self.id2label_map[label[i,
j]])
decode_out_list[i].append(self.id2label_map[pred_idxs[i,
j]])
return decode_out_list, label_decode_out_list
def _infer(self, preds, attention_masks, segment_offset_ids, ocr_infos):
results = []
for pred, attention_mask, segment_offset_id, ocr_info in zip(
preds, attention_masks, segment_offset_ids, ocr_infos):
pred = np.argmax(pred, axis=1)
pred = [self.id2label_map[idx] for idx in pred]
for idx in range(len(segment_offset_id)):
if idx == 0:
start_id = 0
else:
start_id = segment_offset_id[idx - 1]
end_id = segment_offset_id[idx]
curr_pred = pred[start_id:end_id]
curr_pred = [self.label2id_map_for_draw[p] for p in curr_pred]
if len(curr_pred) <= 0:
pred_id = 0
else:
counts = np.bincount(curr_pred)
pred_id = np.argmax(counts)
ocr_info[idx]["pred_id"] = int(pred_id)
ocr_info[idx]["pred"] = self.id2label_map_for_show[int(pred_id)]
results.append(ocr_info)
return results
...@@ -44,7 +44,7 @@ def _mkdir_if_not_exist(path, logger): ...@@ -44,7 +44,7 @@ def _mkdir_if_not_exist(path, logger):
raise OSError('Failed to mkdir {}'.format(path)) raise OSError('Failed to mkdir {}'.format(path))
def load_model(config, model, optimizer=None): def load_model(config, model, optimizer=None, model_type='det'):
""" """
load model from checkpoint or pretrained_model load model from checkpoint or pretrained_model
""" """
...@@ -53,6 +53,33 @@ def load_model(config, model, optimizer=None): ...@@ -53,6 +53,33 @@ def load_model(config, model, optimizer=None):
checkpoints = global_config.get('checkpoints') checkpoints = global_config.get('checkpoints')
pretrained_model = global_config.get('pretrained_model') pretrained_model = global_config.get('pretrained_model')
best_model_dict = {} best_model_dict = {}
if model_type == 'vqa':
checkpoints = config['Architecture']['Backbone']['checkpoints']
# load vqa method metric
if checkpoints:
if os.path.exists(os.path.join(checkpoints, 'metric.states')):
with open(os.path.join(checkpoints, 'metric.states'),
'rb') as f:
states_dict = pickle.load(f) if six.PY2 else pickle.load(
f, encoding='latin1')
best_model_dict = states_dict.get('best_model_dict', {})
if 'epoch' in states_dict:
best_model_dict['start_epoch'] = states_dict['epoch'] + 1
logger.info("resume from {}".format(checkpoints))
if optimizer is not None:
if checkpoints[-1] in ['/', '\\']:
checkpoints = checkpoints[:-1]
if os.path.exists(checkpoints + '.pdopt'):
optim_dict = paddle.load(checkpoints + '.pdopt')
optimizer.set_state_dict(optim_dict)
else:
logger.warning(
"{}.pdopt is not exists, params of optimizer is not loaded".
format(checkpoints))
return best_model_dict
if checkpoints: if checkpoints:
if checkpoints.endswith('.pdparams'): if checkpoints.endswith('.pdparams'):
checkpoints = checkpoints.replace('.pdparams', '') checkpoints = checkpoints.replace('.pdparams', '')
...@@ -111,13 +138,16 @@ def load_pretrained_params(model, path): ...@@ -111,13 +138,16 @@ def load_pretrained_params(model, path):
params = paddle.load(path + '.pdparams') params = paddle.load(path + '.pdparams')
state_dict = model.state_dict() state_dict = model.state_dict()
new_state_dict = {} new_state_dict = {}
for k1, k2 in zip(state_dict.keys(), params.keys()): for k1 in params.keys():
if list(state_dict[k1].shape) == list(params[k2].shape): if k1 not in state_dict.keys():
new_state_dict[k1] = params[k2] logger.warning("The pretrained params {} not in model".format(k1))
else: else:
logger.warning( if list(state_dict[k1].shape) == list(params[k1].shape):
"The shape of model params {} {} not matched with loaded params {} {} !". new_state_dict[k1] = params[k1]
format(k1, state_dict[k1].shape, k2, params[k2].shape)) else:
logger.warning(
"The shape of model params {} {} not matched with loaded params {} {} !".
format(k1, state_dict[k1].shape, k1, params[k1].shape))
model.set_state_dict(new_state_dict) model.set_state_dict(new_state_dict)
logger.info("load pretrain successful from {}".format(path)) logger.info("load pretrain successful from {}".format(path))
return model return model
...@@ -127,6 +157,7 @@ def save_model(model, ...@@ -127,6 +157,7 @@ def save_model(model,
optimizer, optimizer,
model_path, model_path,
logger, logger,
config,
is_best=False, is_best=False,
prefix='ppocr', prefix='ppocr',
**kwargs): **kwargs):
...@@ -135,13 +166,20 @@ def save_model(model, ...@@ -135,13 +166,20 @@ def save_model(model,
""" """
_mkdir_if_not_exist(model_path, logger) _mkdir_if_not_exist(model_path, logger)
model_prefix = os.path.join(model_path, prefix) model_prefix = os.path.join(model_path, prefix)
paddle.save(model.state_dict(), model_prefix + '.pdparams')
paddle.save(optimizer.state_dict(), model_prefix + '.pdopt') paddle.save(optimizer.state_dict(), model_prefix + '.pdopt')
if config['Architecture']["model_type"] != 'vqa':
paddle.save(model.state_dict(), model_prefix + '.pdparams')
metric_prefix = model_prefix
else:
if config['Global']['distributed']:
model._layers.backbone.model.save_pretrained(model_prefix)
else:
model.backbone.model.save_pretrained(model_prefix)
metric_prefix = os.path.join(model_prefix, 'metric')
# save metric and config # save metric and config
with open(model_prefix + '.states', 'wb') as f:
pickle.dump(kwargs, f, protocol=2)
if is_best: if is_best:
with open(metric_prefix + '.states', 'wb') as f:
pickle.dump(kwargs, f, protocol=2)
logger.info('save best model is to {}'.format(model_prefix)) logger.info('save best model is to {}'.format(model_prefix))
else: else:
logger.info("save model in {}".format(model_prefix)) logger.info("save model in {}".format(model_prefix))
...@@ -16,6 +16,9 @@ import logging ...@@ -16,6 +16,9 @@ import logging
import os import os
import imghdr import imghdr
import cv2 import cv2
import random
import numpy as np
import paddle
def print_dict(d, logger, delimiter=0): def print_dict(d, logger, delimiter=0):
...@@ -77,4 +80,28 @@ def check_and_read_gif(img_path): ...@@ -77,4 +80,28 @@ def check_and_read_gif(img_path):
frame = cv2.cvtColor(frame, cv2.COLOR_GRAY2RGB) frame = cv2.cvtColor(frame, cv2.COLOR_GRAY2RGB)
imgvalue = frame[:, :, ::-1] imgvalue = frame[:, :, ::-1]
return imgvalue, True return imgvalue, True
return None, False return None, False
\ No newline at end of file
def load_vqa_bio_label_maps(label_map_path):
with open(label_map_path, "r", encoding='utf-8') as fin:
lines = fin.readlines()
lines = [line.strip() for line in lines]
if "O" not in lines:
lines.insert(0, "O")
labels = []
for line in lines:
if line == "O":
labels.append("O")
else:
labels.append("B-" + line)
labels.append("I-" + line)
label2id_map = {label: idx for idx, label in enumerate(labels)}
id2label_map = {idx: label for idx, label in enumerate(labels)}
return label2id_map, id2label_map
def set_seed(seed=1024):
random.seed(seed)
np.random.seed(seed)
paddle.seed(seed)
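For example (the class file path is hypothetical; only the B-/I- expansion behaviour comes from `load_vqa_bio_label_maps` above):

```python
# Hypothetical: a class list containing the two lines "QUESTION" and "ANSWER"
# gets an implicit "O" label plus B-/I- prefixes.
label2id, id2label = load_vqa_bio_label_maps("path/to/class_list.txt")
# label2id == {"O": 0, "B-QUESTION": 1, "I-QUESTION": 2,
#              "B-ANSWER": 3, "I-ANSWER": 4}
```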
# copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import numpy as np
from PIL import Image, ImageDraw, ImageFont
def draw_ser_results(image,
ocr_results,
font_path="doc/fonts/simfang.ttf",
font_size=18):
np.random.seed(2021)
color = (np.random.permutation(range(255)),
np.random.permutation(range(255)),
np.random.permutation(range(255)))
color_map = {
idx: (color[0][idx], color[1][idx], color[2][idx])
for idx in range(1, 255)
}
if isinstance(image, np.ndarray):
image = Image.fromarray(image)
elif isinstance(image, str) and os.path.isfile(image):
image = Image.open(image).convert('RGB')
img_new = image.copy()
draw = ImageDraw.Draw(img_new)
font = ImageFont.truetype(font_path, font_size, encoding="utf-8")
for ocr_info in ocr_results:
if ocr_info["pred_id"] not in color_map:
continue
color = color_map[ocr_info["pred_id"]]
text = "{}: {}".format(ocr_info["pred"], ocr_info["text"])
draw_box_txt(ocr_info["bbox"], text, draw, font, font_size, color)
img_new = Image.blend(image, img_new, 0.5)
return np.array(img_new)
def draw_box_txt(bbox, text, draw, font, font_size, color):
# draw ocr results outline
bbox = ((bbox[0], bbox[1]), (bbox[2], bbox[3]))
draw.rectangle(bbox, fill=color)
# draw ocr results
start_y = max(0, bbox[0][1] - font_size)
tw = font.getsize(text)[0]
draw.rectangle(
[(bbox[0][0] + 1, start_y), (bbox[0][0] + tw + 1, start_y + font_size)],
fill=(0, 0, 255))
draw.text((bbox[0][0] + 1, start_y), text, fill=(255, 255, 255), font=font)
def draw_re_results(image,
result,
font_path="doc/fonts/simfang.ttf",
font_size=18):
np.random.seed(0)
if isinstance(image, np.ndarray):
image = Image.fromarray(image)
elif isinstance(image, str) and os.path.isfile(image):
image = Image.open(image).convert('RGB')
img_new = image.copy()
draw = ImageDraw.Draw(img_new)
font = ImageFont.truetype(font_path, font_size, encoding="utf-8")
color_head = (0, 0, 255)
color_tail = (255, 0, 0)
color_line = (0, 255, 0)
for ocr_info_head, ocr_info_tail in result:
draw_box_txt(ocr_info_head["bbox"], ocr_info_head["text"], draw, font,
font_size, color_head)
draw_box_txt(ocr_info_tail["bbox"], ocr_info_tail["text"], draw, font,
font_size, color_tail)
center_head = (
(ocr_info_head['bbox'][0] + ocr_info_head['bbox'][2]) // 2,
(ocr_info_head['bbox'][1] + ocr_info_head['bbox'][3]) // 2)
center_tail = (
(ocr_info_tail['bbox'][0] + ocr_info_tail['bbox'][2]) // 2,
(ocr_info_tail['bbox'][1] + ocr_info_tail['bbox'][3]) // 2)
draw.line([center_head, center_tail], fill=color_line, width=5)
img_new = Image.blend(image, img_new, 0.5)
return np.array(img_new)
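A hedged usage sketch; the OCR entry and the image path are invented, while the required keys `bbox`, `text`, `pred`, `pred_id` come from the drawing code above:

```python
# Hypothetical example: draw one SER result on an image file and save it.
import cv2

ocr_results = [{"bbox": [60, 40, 200, 70], "text": "Name:",
                "pred": "QUESTION", "pred_id": 1}]
vis = draw_ser_results("path/to/image.jpg", ocr_results,
                       font_path="doc/fonts/simfang.ttf")
cv2.imwrite("ser_vis.jpg", vis[:, :, ::-1])   # RGB -> BGR for OpenCV
```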
English | [简体中文](README_ch.md) English | [简体中文](README_ch.md)
# PP-Structure - [1. Introduction](#1)
- [2. Update log](#2)
- [3. Features](#3)
- [4. Results](#4)
* [4.1 Layout analysis and table recognition](#41)
* [4.2 DOC-VQA](#42)
- [5. Quick start](#5)
- [6. PP-Structure System](#6)
* [6.1 Layout analysis and table recognition](#61)
* [6.2 DOC-VQA](#62)
- [7. Model List](#7)
PP-Structure is an OCR toolkit that can be used for complex document analysis. The main features are as follows: <a name="1"></a>
- Support the layout analysis of documents, divide the documents into 5 types of areas **text, title, table, image and list** (conjunction with Layout-Parser)
- Support to extract the texts from the text, title, picture and list areas (used in conjunction with PP-OCR)
- Support to extract excel files from the table areas
- Support python whl package and command line usage, easy to use
- Support custom training for layout analysis and table structure tasks
## 1. Visualization ## 1. Introduction
<img src="../doc/table/ppstructure.GIF" width="100%"/>
PP-Structure is an OCR toolkit that can be used for document analysis and processing with complex structures, and is designed to help developers better complete document understanding tasks.
<a name="2"></a>
## 2. Installation ## 2. Update log
* 2021.12.07 add [DOC-VQA SER and RE tasks](vqa/README.md)
### 2.1 Install requirements <a name="3"></a>
- **(1) Install PaddlePaddle** ## 3. Features
```bash The main features of PP-Structure are as follows:
pip3 install --upgrade pip
# GPU - Support the layout analysis of documents, divide the documents into 5 types of areas **text, title, table, image and list** (conjunction with Layout-Parser)
python3 -m pip install paddlepaddle-gpu==2.1.1 -i https://mirror.baidu.com/pypi/simple - Support to extract the texts from the text, title, picture and list areas (used in conjunction with PP-OCR)
- Support to extract excel files from the table areas
# CPU - Support python whl package and command line usage, easy to use
python3 -m pip install paddlepaddle==2.1.1 -i https://mirror.baidu.com/pypi/simple - Support custom training for layout analysis and table structure tasks
- Support Document Visual Question Answering (DOC-VQA) tasks: Semantic Entity Recognition (SER) and Relation Extraction (RE)
```
For more, refer to [Installation](https://www.paddlepaddle.org.cn/install/quick).
- **(2) Install Layout-Parser** <a name="4"></a>
```bash ## 4. Results
pip3 install -U https://paddleocr.bj.bcebos.com/whl/layoutparser-0.0.0-py3-none-any.whl
```
### 2.2 Install PaddleOCR(including PP-OCR and PP-Structure) <a name="41"></a>
- **(1) PIP install PaddleOCR whl package(inference only)** ### 4.1 Layout analysis and table recognition
```bash <img src="../doc/table/ppstructure.GIF" width="100%"/>
pip install "paddleocr>=2.2"
```
- **(2) Clone PaddleOCR(Inference+training)** The figure shows the pipeline of layout analysis + table recognition. The image is first divided into four areas of image, text, title and table by layout analysis, and then OCR detection and recognition is performed on the three areas of image, text and title, and the table is performed table recognition, where the image will also be stored for use.
```bash <a name="42"></a>
git clone https://github.com/PaddlePaddle/PaddleOCR
```
### 4.2 DOC-VQA
## 3. Quick Start * SER
### 3.1 Use by command line ![](./vqa/images/result_ser/zh_val_0_ser.jpg) | ![](./vqa/images/result_ser/zh_val_42_ser.jpg)
---|---
```bash Different colored boxes in the figure represent different categories. For the XFUN dataset, there are three categories: query, answer and header:
paddleocr --image_dir=../doc/table/1.png --type=structure
```
### 3.2 Use by python API * Dark purple: header
* Light purple: query
* Army green: answer
```python The corresponding category and OCR recognition results are also marked at the top left of the OCR detection box.
import os
import cv2
from paddleocr import PPStructure,draw_structure_result,save_structure_res
table_engine = PPStructure(show_log=True)
save_folder = './output/table' * RE
img_path = '../doc/table/1.png'
img = cv2.imread(img_path)
result = table_engine(img)
save_structure_res(result, save_folder,os.path.basename(img_path).split('.')[0])
for line in result: ![](./vqa/images/result_re/zh_val_21_re.jpg) | ![](./vqa/images/result_re/zh_val_40_re.jpg)
line.pop('img') ---|---
print(line)
from PIL import Image
font_path = '../doc/fonts/simfang.ttf' In the figure, the red box represents the question, the blue box represents the answer, and the question and answer are connected by green lines. The corresponding category and OCR recognition results are also marked at the top left of the OCR detection box.
image = Image.open(img_path).convert('RGB')
im_show = draw_structure_result(image, result,font_path=font_path)
im_show = Image.fromarray(im_show)
im_show.save('result.jpg')
```
### 3.3 Returned results format
The returned results of PP-Structure is a list composed of a dict, an example is as follows
```shell
[
{ 'type': 'Text',
'bbox': [34, 432, 345, 462],
'res': ([[36.0, 437.0, 341.0, 437.0, 341.0, 446.0, 36.0, 447.0], [41.0, 454.0, 125.0, 453.0, 125.0, 459.0, 41.0, 460.0]],
[('Tigure-6. The performance of CNN and IPT models using difforen', 0.90060663), ('Tent ', 0.465441)])
}
]
```
The description of each field in dict is as follows
| Parameter | Description | <a name="5"></a>
| --------------- | -------------|
|type|Type of image area|
|bbox|The coordinates of the image area in the original image, respectively [left upper x, left upper y, right bottom x, right bottom y]|
|res|OCR or table recognition result of the image area.<br> Table: HTML string of the table; <br> OCR: A tuple containing the detection coordinates and recognition results of each single line of text|
## 5. Quick start
### 3.4 Parameter description: Start from [Quick Installation](./docs/quickstart.md)
| Parameter | Description | Default value | <a name="6"></a>
| --------------- | ---------------------------------------- | ------------------------------------------- |
| output | The path where excel and recognition results are saved | ./output/table |
| table_max_len | The long side of the image is resized in table structure model | 488 |
| table_model_dir | inference model path of table structure model | None |
| table_char_type | dict path of table structure model | ../ppocr/utils/dict/table_structure_dict.txt |
Most of the parameters are consistent with the paddleocr whl package, see [doc of whl](../doc/doc_en/whl_en.md) ## 6. PP-Structure System
After running, each image will have a directory with the same name under the directory specified in the output field. Each table in the picture will be stored as an Excel file, and each figure area will be cropped and saved; the Excel and image file names are the coordinates of the table in the image. <a name="61"></a>
## 4. PP-Structure Pipeline ### 6.1 Layout analysis and table recognition
![pipeline](../doc/table/pipeline_en.jpg)
In PP-Structure, the image is first analyzed by layoutparser. During layout analysis, each region of the image is classified into one of 5 categories: **text, title, image, list and table**. For the first 4 types of areas, PP-OCR is used directly for text detection and recognition. The table area is converted into an Excel file with the same table style via Table OCR. ![pipeline](../doc/table/pipeline.jpg)
### 4.1 LayoutParser In PP-Structure, the image will be divided into 5 types of areas: **text, title, image, list and table**. For the first 4 types of areas, the PP-OCR system is used directly to complete text detection and recognition. For the table area, after the table structuring process, the table in the image is converted into an Excel file with the same table style.
Layout analysis divides the document data into regions, including the use of Python scripts for layout analysis tools, extraction of special category detection boxes, performance indicators, and custom training layout analysis models. For details, please refer to [document](layout/README_en.md). #### 6.1.1 Layout analysis
### 4.2 Table Recognition Layout analysis classifies image by region, including the use of Python scripts of layout analysis tools, extraction of designated category detection boxes, performance indicators, and custom training layout analysis models. For details, please refer to [document](layout/README_en.md).
Table Recognition converts table image into excel documents, which include the detection and recognition of table text and the prediction of table structure and cell coordinates. For detailed, please refer to [document](table/README.md) #### 6.1.2 Table recognition
## 5. Prediction by inference engine Table recognition converts table images into excel documents, which include the detection and recognition of table text and the prediction of table structure and cell coordinates. For detailed instructions, please refer to [document](table/README.md)
Use the following commands to complete the inference. <a name="62"></a>
```python ### 6.2 DOC-VQA
cd PaddleOCR/ppstructure
# download model Document Visual Question Answering (DOC-VQA) is a type of Visual Question Answering (VQA), which includes Semantic Entity Recognition (SER) and Relation Extraction (RE) tasks. Based on the SER task, text recognition and classification in images can be completed. Based on the RE task, the relations between text contents in the image can be extracted, such as judging question-answer pairs. For details, please refer to [document](vqa/README.md)
mkdir inference && cd inference
# Download the detection model of the ultra-lightweight Chinese OCR model and uncompress it
wget https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_mobile_v2.0_det_infer.tar && tar xf ch_ppocr_mobile_v2.0_det_infer.tar
# Download the recognition model of the ultra-lightweight Chinese OCR model and uncompress it
wget https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_mobile_v2.0_rec_infer.tar && tar xf ch_ppocr_mobile_v2.0_rec_infer.tar
# Download the table structure model of the ultra-lightweight Chinese OCR model and uncompress it
wget https://paddleocr.bj.bcebos.com/dygraph_v2.0/table/en_ppocr_mobile_v2.0_table_structure_infer.tar && tar xf en_ppocr_mobile_v2.0_table_structure_infer.tar
cd ..
python3 predict_system.py --det_model_dir=inference/ch_ppocr_mobile_v2.0_det_infer --rec_model_dir=inference/ch_ppocr_mobile_v2.0_rec_infer --table_model_dir=inference/en_ppocr_mobile_v2.0_table_structure_infer --image_dir=../doc/table/1.png --rec_char_dict_path=../ppocr/utils/ppocr_keys_v1.txt --table_char_dict_path=../ppocr/utils/dict/table_structure_dict.txt --output=../output/table --vis_font_path=../doc/fonts/simfang.ttf
```
After running, each image will have a directory with the same name under the directory specified in the output field. Each table in the picture will be stored as an Excel file, and each figure area will be cropped and saved; the Excel and image file names are the coordinates of the table in the image.
**Model List** <a name="7"></a>
|model name|description|config|model size|download| ## 7. Model List
| --- | --- | --- | --- | --- |
|en_ppocr_mobile_v2.0_table_structure|Table structure prediction for English table scenarios|[table_mv3.yml](../configs/table/table_mv3.yml)|18.6M|[inference model](https://paddleocr.bj.bcebos.com/dygraph_v2.0/table/en_ppocr_mobile_v2.0_table_structure_infer.tar) |
**Model List** PP-Structure series model list (updating)
LayoutParser model * Layout analysis model
|model name|description|download| |model name|description|download|
| --- | --- | --- | | --- | --- | --- |
| ppyolov2_r50vd_dcn_365e_publaynet | The layout analysis model trained on the PubLayNet data set can be divided into 5 types of areas **text, title, table, picture and list** | [PubLayNet](https://paddle-model-ecology.bj.bcebos.com/model/layout-parser/ppyolov2_r50vd_dcn_365e_publaynet.tar) | | ppyolov2_r50vd_dcn_365e_publaynet | The layout analysis model trained on the PubLayNet dataset can divide image into 5 types of areas **text, title, table, picture, and list** | [PubLayNet](https://paddle-model-ecology.bj.bcebos.com/model/layout-parser/ppyolov2_r50vd_dcn_365e_publaynet.tar) |
| ppyolov2_r50vd_dcn_365e_tableBank_word | The layout analysis model trained on the TableBank Word dataset can only detect tables | [TableBank Word](https://paddle-model-ecology.bj.bcebos.com/model/layout-parser/ppyolov2_r50vd_dcn_365e_tableBank_word.tar) |
| ppyolov2_r50vd_dcn_365e_tableBank_latex | The layout analysis model trained on the TableBank Latex dataset can only detect tables | [TableBank Latex](https://paddle-model-ecology.bj.bcebos.com/model/layout-parser/ppyolov2_r50vd_dcn_365e_tableBank_latex.tar) |
* OCR and table recognition models

|model name|description|model size|download|
| --- | --- | --- | --- |
|ch_ppocr_mobile_slim_v2.0_det|Slim pruned lightweight model, supporting Chinese, English, and multilingual text detection|2.6M|[inference model](https://paddleocr.bj.bcebos.com/dygraph_v2.0/slim/ch_ppocr_mobile_v2.0_det_prune_infer.tar) / [trained model](https://paddleocr.bj.bcebos.com/dygraph_v2.0/slim/ch_ppocr_mobile_v2.0_det_prune_infer.tar) |
|ch_ppocr_mobile_slim_v2.0_rec|Slim pruned and quantized lightweight model, supporting Chinese, English, and digit recognition|6M|[inference model](https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_mobile_v2.0_rec_slim_infer.tar) / [trained model](https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_mobile_v2.0_rec_slim_train.tar) |
|en_ppocr_mobile_v2.0_table_det|Text detection for English table scenes, trained on the PubLayNet dataset|4.7M|[inference model](https://paddleocr.bj.bcebos.com/dygraph_v2.0/table/en_ppocr_mobile_v2.0_table_det_infer.tar) / [trained model](https://paddleocr.bj.bcebos.com/dygraph_v2.1/table/en_ppocr_mobile_v2.0_table_det_train.tar) |
|en_ppocr_mobile_v2.0_table_rec|Text recognition for English table scenes, trained on the PubLayNet dataset|6.9M|[inference model](https://paddleocr.bj.bcebos.com/dygraph_v2.0/table/en_ppocr_mobile_v2.0_table_rec_infer.tar) / [trained model](https://paddleocr.bj.bcebos.com/dygraph_v2.1/table/en_ppocr_mobile_v2.0_table_rec_train.tar) |
|en_ppocr_mobile_v2.0_table_structure|Table structure prediction for English table scenes, trained on the PubLayNet dataset|18.6M|[inference model](https://paddleocr.bj.bcebos.com/dygraph_v2.0/table/en_ppocr_mobile_v2.0_table_structure_infer.tar) / [trained model](https://paddleocr.bj.bcebos.com/dygraph_v2.1/table/en_ppocr_mobile_v2.0_table_structure_train.tar) |

* DOC-VQA models
|model name|description|model size|download|
| --- | --- | --- | --- |
|PP-Layout_v1.0_ser_pretrained|SER model trained on xfun Chinese dataset based on LayoutXLM|1.4G|[inference model coming soon]() / [trained model](https://paddleocr.bj.bcebos.com/pplayout/PP-Layout_v1.0_ser_pretrained.tar) |
|PP-Layout_v1.0_re_pretrained|RE model trained on xfun Chinese dataset based on LayoutXLM|1.4G|[inference model coming soon]() / [trained model](https://paddleocr.bj.bcebos.com/pplayout/PP-Layout_v1.0_re_pretrained.tar) |
If you need to use other models, you can download them from the [PPOCR model_list](../doc/doc_en/models_list_en.md) and the [PPStructure model_list](./docs/model_list.md), or use your own trained model and configure it in the `det_model_dir`, `rec_model_dir`, and `table_model_dir` fields.
[English](README.md) | 简体中文

- [1. Introduction](#1)
- [2. Recent updates](#2)
- [3. Features](#3)
- [4. Results](#4)
  * [4.1 Layout analysis and table recognition](#41)
  * [4.2 DOC-VQA](#42)
- [5. Quick start](#5)
- [6. Introduction to PP-Structure](#6)
  * [6.1 Layout analysis + table recognition](#61)
  * [6.2 DOC-VQA](#62)
- [7. Model zoo](#7)

<a name="1"></a>
## 1. Introduction

PP-Structure is an OCR toolkit for analyzing and processing complex document structures, designed to help developers complete document-understanding tasks more effectively.
<a name="2"></a>
## 2. Recent updates
* 2021.12.07 Added the DOC-[VQA tasks SER and RE](vqa/README.md).

<a name="3"></a>
## 3. Features

The main features of PP-Structure are as follows:

- Supports layout analysis of documents in image form, dividing them into 5 types of regions: **text, title, table, image, and list** (used together with Layout-Parser)
- Supports extracting text, title, image, and list regions into text fields (used together with PP-OCR)
- Supports structured analysis of table regions, with the result exported as an Excel file
- Supports the Document Visual Question Answering (DOC-VQA) tasks: Semantic Entity Recognition (SER) and Relation Extraction (RE)
<a name="4"></a>
## 4. Results

<a name="41"></a>
### 4.1 Layout analysis and table recognition

<img src="../doc/table/ppstructure.GIF" width="100%"/>

The figure shows the overall pipeline of layout analysis + table recognition: the image is first divided by layout analysis into image, text, title, and table regions; OCR detection and recognition are then applied to the image, text, and title regions, table recognition is applied to the table regions, and the image regions are also saved for later use.

<a name="42"></a>
### 4.2 DOC-VQA

* SER

In the figures, red boxes indicate questions and blue boxes indicate answers; a question and its answer are connected by a green line. The corresponding category and the OCR recognition result are also shown at the upper-left of each OCR detection box.
<a name="5"></a>
## 5. Quick start

Please refer to the [quick installation](./docs/quickstart.md) tutorial.

<a name="6"></a>
## 6. Introduction to PP-Structure

<a name="61"></a>
### 6.1 Layout analysis + table recognition

![pipeline](../doc/table/pipeline.jpg)

In PP-Structure, the image first goes through Layout-Parser for layout analysis, which classifies the regions in the image into 5 types: **text, title, image, list, and table**. For the first 4 types of regions, PP-OCR is used directly for text detection and recognition. Table regions are structurally parsed, and the table image is converted into an Excel file with the same table layout.

#### 6.1.1 Layout analysis

Layout analysis classifies the regions of a document image. It covers the Python script usage of the layout analysis tool, extracting detection boxes of specified categories, performance metrics, and custom training of layout analysis models; see the [documentation](layout/README_ch.md) for details.

#### 6.1.2 Table recognition

Table recognition converts table images into Excel documents. It includes detection and recognition of the table text as well as prediction of the table structure and cell coordinates; see the [documentation](table/README_ch.md) for details.

<a name="62"></a>
### 6.2 DOC-VQA

DOC-VQA refers to Document Visual Question Answering, which includes the Semantic Entity Recognition (SER) and Relation Extraction (RE) tasks. With the SER task, text in an image can be recognized and classified; with the RE task, relations between pieces of text in the image can be extracted, such as identifying question-answer pairs. See the [documentation](vqa/README.md) for details.

<a name="7"></a>
## 7. Model zoo
PP-Structure series model list (updating)

* Layout analysis models

|model name|description|download|
| --- | --- | --- |

* OCR and table recognition models

|model name|description|model size|download|
| --- | --- | --- | --- |
|ch_ppocr_mobile_slim_v2.0_rec|Slim pruned and quantized ultra-lightweight model, supporting Chinese, English, and digit recognition|6M|[inference model](https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_mobile_v2.0_rec_slim_infer.tar) / [trained model](https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_mobile_v2.0_rec_slim_train.tar) |
|en_ppocr_mobile_v2.0_table_structure|Table structure prediction for English table scenes, trained on the PubLayNet dataset|18.6M|[inference model](https://paddleocr.bj.bcebos.com/dygraph_v2.0/table/en_ppocr_mobile_v2.0_table_structure_infer.tar) / [trained model](https://paddleocr.bj.bcebos.com/dygraph_v2.1/table/en_ppocr_mobile_v2.0_table_structure_train.tar) |

* DOC-VQA models

|model name|description|model size|download|
| --- | --- | --- | --- |
|PP-Layout_v1.0_re_pretrained|RE model trained on the XFUN Chinese dataset based on LayoutXLM|1.4G|[inference model coming soon]() / [trained model](https://paddleocr.bj.bcebos.com/pplayout/PP-Layout_v1.0_re_pretrained.tar) |

For more model downloads, refer to the [PPOCR model_list](../doc/doc_en/models_list.md) and the [PPStructure model_list](./docs/model_list.md).
# Key Information Extraction (KIE)

This section provides a tutorial on how to quickly use, train, and evaluate a key information extraction (KIE) model, [SDMGR](https://arxiv.org/abs/2103.14470), in PaddleOCR.

[SDMGR (Spatial Dual-Modality Graph Reasoning)](https://arxiv.org/abs/2103.14470) is a KIE algorithm that classifies each detected text region into predefined categories, such as order ID, invoice number, and amount.
* [1. Quick Use](#1-----)
* [2. Model Training](#2-----)
* [3. Model Evaluation](#3-----)
<a name="1-----"></a>
## 1. Quick Use
The [Wildreceipt dataset](https://paperswithcode.com/dataset/wildreceipt) is used for this tutorial. It contains 1765 photos with 25 classes and 50000 text boxes, and can be downloaded with wget:
```shell
wget https://paddleocr.bj.bcebos.com/dygraph_v2.1/kie/wildreceipt.tar && tar xf wildreceipt.tar
```
Download the pretrained model and predict the result:
```shell
cd PaddleOCR/
wget https://paddleocr.bj.bcebos.com/dygraph_v2.1/kie/kie_vgg16.tar && tar xf kie_vgg16.tar
python3.7 tools/infer_kie.py -c configs/kie/kie_unet_sdmgr.yml -o Global.checkpoints=kie_vgg16/best_accuracy Global.infer_img=../wildreceipt/1.txt
```
The prediction results are saved as `./output/sdmgr_kie/predicts_kie.txt`, and the visualization results are saved in the folder `./output/sdmgr_kie/kie_results/`.
The visualization results are shown in the figure below:
<div align="center">
<img src="./imgs/0.png" width="800">
</div>
<a name="2-----"></a>
## 2. Model Training
Create a soft link to the dataset folder under `PaddleOCR/train_data`:
```shell
cd PaddleOCR/ && mkdir train_data && cd train_data
ln -s ../../wildreceipt ./
```
The configuration file used for training is `configs/kie/kie_unet_sdmgr.yml`. The default training data path in the configuration file is `train_data/wildreceipt`. After preparing the data, you can execute the model training with the following command:
```shell
python3.7 tools/train.py -c configs/kie/kie_unet_sdmgr.yml -o Global.save_model_dir=./output/kie/
```
<a name="3-----"></a>
## 3. Model Evaluation
After training, you can execute the model evaluation with the following command:
```shell
python3.7 tools/eval.py -c configs/kie/kie_unet_sdmgr.yml -o Global.checkpoints=./output/kie/best_accuracy
```
**Reference:**
<!-- [ALGORITHM] -->
```bibtex
@misc{sun2021spatial,
title={Spatial Dual-Modality Graph Reasoning for Key Information Extraction},
author={Hongbin Sun and Zhanghui Kuang and Xiaoyu Yue and Chenhao Lin and Wayne Zhang},
year={2021},
eprint={2103.14470},
archivePrefix={arXiv},
primaryClass={cs.CV}
}
```
|model name|description|inference model size|download|
| --- | --- | --- | --- |
|PP-Layout_v1.0_ser_pretrained|SER model trained on the XFUN Chinese dataset based on LayoutXLM|1.4G|[inference model coming soon]() / [trained model](https://paddleocr.bj.bcebos.com/pplayout/ser_LayoutXLM_xfun_zh.tar) |
|PP-Layout_v1.0_re_pretrained|RE model trained on the XFUN Chinese dataset based on LayoutXLM|1.4G|[inference model coming soon]() / [trained model](https://paddleocr.bj.bcebos.com/pplayout/re_LayoutXLM_xfun_zh.tar) |

## 3. KIE models
* VQA

Please refer to: [Document Visual Question Answering](../vqa/README.md)

<a name="22"></a>

* VQA

Please refer to: [Document Visual Question Answering](../vqa/README.md)

<a name="23"></a>

* VQA

Please refer to: [Document Visual Question Answering](../vqa/README.md)

<a name="24"></a>

| model_name_or_path | path of the VQA SER model | None |
| max_seq_length | maximum token length supported by the VQA SER model | 512 |
| label_map_path | path of the VQA SER label file | ./vqa/labels/labels_ser.txt |
| mode | pipeline prediction mode; structure: layout analysis + table recognition, VQA: SER document information extraction | structure |

Most parameters are consistent with the PaddleOCR whl package; see the [whl package documentation](../../doc/doc_ch/whl.md).

After running, each image has a directory with the same name under the directory specified by the `output` field. Each table in the image is stored as an Excel file, the image regions are cropped and saved, and the Excel and image files are named after the coordinates of the table in the image.
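For reference, a minimal Python sketch of the same whl-package pipeline is shown below (assuming `paddleocr` is installed; `PPStructure` and `save_structure_res` are the whl-package entry points, and the image and output paths are placeholders):

```python
# Sketch: run layout analysis + table recognition through the paddleocr whl package.
import os

import cv2
from paddleocr import PPStructure, save_structure_res

table_engine = PPStructure(show_log=True)  # mode defaults to 'structure'

save_folder = './output/table'
img_path = '../doc/table/1.png'
img = cv2.imread(img_path)

result = table_engine(img)
# Each table region is written out as an Excel file named after its coordinates.
save_structure_res(result, save_folder, os.path.basename(img_path).split('.')[0])

for region in result:
    region.pop('img')  # drop the raw image crop before printing
    print(region['type'], region['bbox'])
```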
cd ppstructure

# Download models
mkdir inference && cd inference
# Download the PP-OCRv2 text detection model and unzip it
wget https://paddleocr.bj.bcebos.com/PP-OCRv2/chinese/ch_PP-OCRv2_det_slim_quant_infer.tar && tar xf ch_PP-OCRv2_det_slim_quant_infer.tar
# Download the PP-OCRv2 text recognition model and unzip it
wget https://paddleocr.bj.bcebos.com/PP-OCRv2/chinese/ch_PP-OCRv2_rec_slim_quant_infer.tar && tar xf ch_PP-OCRv2_rec_slim_quant_infer.tar
# Download the ultra-lightweight English table prediction model and unzip it
wget https://paddleocr.bj.bcebos.com/dygraph_v2.0/table/en_ppocr_mobile_v2.0_table_structure_infer.tar && tar xf en_ppocr_mobile_v2.0_table_structure_infer.tar
cd ..

python3 predict_system.py --det_model_dir=inference/ch_PP-OCRv2_det_slim_quant_infer \
--rec_model_dir=inference/ch_PP-OCRv2_rec_slim_quant_infer \
--table_model_dir=inference/en_ppocr_mobile_v2.0_table_structure_infer \
--image_dir=../doc/table/1.png \
--rec_char_dict_path=../ppocr/utils/ppocr_keys_v1.txt \
from ppocr.utils.utility import get_image_file_list, check_and_read_gif
from ppocr.utils.logging import get_logger
from tools.infer.predict_system import TextSystem
from ppstructure.table.predict_table import TableSystem, to_excel
from ppstructure.utility import parse_args, draw_structure_result
logger = get_logger()
self.use_angle_cls = args.use_angle_cls
self.drop_score = args.drop_score
elif self.mode == 'vqa':
from ppstructure.vqa.infer_ser_e2e import SerPredictor, draw_ser_results
self.vqa_engine = SerPredictor(args)
def __call__(self, img):
wget https://paddleocr.bj.bcebos.com/dygraph_v2.0/table/en_ppocr_mobile_v2.0_table_structure_infer.tar && tar xf en_ppocr_mobile_v2.0_table_structure_infer.tar
cd ..
# run
python3 table/predict_table.py --det_model_dir=inference/en_ppocr_mobile_v2.0_table_det_infer --rec_model_dir=inference/en_ppocr_mobile_v2.0_table_rec_infer --table_model_dir=inference/en_ppocr_mobile_v2.0_table_structure_infer --image_dir=../doc/table/table.jpg --rec_char_dict_path=../ppocr/utils/dict/table_dict.txt --table_char_dict_path=../ppocr/utils/dict/table_structure_dict.txt --det_limit_side_len=736 --det_limit_type=min --output ../output/table
```
Note: The above model is trained on the PubLayNet dataset and only supports English scanning scenarios. If you need to recognize other scenarios, you need to train a model yourself and replace the three fields `det_model_dir`, `rec_model_dir`, `table_model_dir`.
wget https://paddleocr.bj.bcebos.com/dygraph_v2.0/table/en_ppocr_mobile_v2.0_table_structure_infer.tar && tar xf en_ppocr_mobile_v2.0_table_structure_infer.tar
cd ..
# run prediction
python3 table/predict_table.py --det_model_dir=inference/en_ppocr_mobile_v2.0_table_det_infer --rec_model_dir=inference/en_ppocr_mobile_v2.0_table_rec_infer --table_model_dir=inference/en_ppocr_mobile_v2.0_table_structure_infer --image_dir=../doc/table/table.jpg --rec_char_dict_path=../ppocr/utils/dict/table_dict.txt --table_char_dict_path=../ppocr/utils/dict/table_structure_dict.txt --det_limit_side_len=736 --det_limit_type=min --output ../output/table
```
After running, the Excel table for each image is saved to the directory specified by the `output` field.
We evaluated the algorithms on the Chinese portion of the [XFUN](https://github.com/doc-analysis/XFUND) dataset; the performance is as follows:

| Model | Task | hmean | Download |
|:---:|:---:|:---:| :---:|
| LayoutXLM | RE | 0.7483 | [link](https://paddleocr.bj.bcebos.com/pplayout/re_LayoutXLM_xfun_zh.tar) |
| LayoutXLM | SER | 0.9038 | [link](https://paddleocr.bj.bcebos.com/pplayout/ser_LayoutXLM_xfun_zh.tar) |
| LayoutLM | SER | 0.7731 | [link](https://paddleocr.bj.bcebos.com/pplayout/ser_LayoutLM_xfun_zh.tar) |

### 2.1 SER

![](../../doc/vqa/result_ser/zh_val_0_ser.jpg) | ![](../../doc/vqa/result_ser/zh_val_42_ser.jpg)
---|---

Boxes of different colors in the figures represent different categories; for the XFUN dataset there are 3 categories: `QUESTION`, `ANSWER`, and `HEADER`.

### 2.2 RE

![](../../doc/vqa/result_re/zh_val_21_re.jpg) | ![](../../doc/vqa/result_re/zh_val_40_re.jpg)
---|---
- **(1) Install PaddlePaddle**

```bash
python3 -m pip install --upgrade pip

# GPU installation
python3 -m pip install "paddlepaddle-gpu>=2.2" -i https://mirror.baidu.com/pypi/simple

# CPU installation
python3 -m pip install "paddlepaddle>=2.2" -i https://mirror.baidu.com/pypi/simple
```
For more requirements, please follow the instructions in the [installation documentation](https://www.paddlepaddle.org.cn/install/quick).

- **(1) Quickly install the PaddleOCR whl package with pip (prediction only)**

```bash
python3 -m pip install paddleocr
```

- **(2) Download the VQA source code (prediction + training)**

```bash
git clone https://gitee.com/paddlepaddle/PaddleOCR
# Note: the code hosted on Gitee may not stay synchronized with this GitHub project in real time; there is a delay of 3 to 5 days. Please use the recommended method first.
```
- **(3) Install the VQA `requirements`**

```bash
python3 -m pip install -r ppstructure/vqa/requirements.txt
```
## 4. Usage

### 4.1 Data and pretrained model preparation

If you want to experience the prediction process directly, you can download the pretrained models we provide, skip the training step, and run prediction right away.

* Download the processed dataset

The processed XFUN Chinese dataset can be downloaded from: [https://paddleocr.bj.bcebos.com/dataset/XFUND.tar](https://paddleocr.bj.bcebos.com/dataset/XFUND.tar)

```bash
wget https://paddleocr.bj.bcebos.com/dataset/XFUND.tar
```

* Convert the dataset

If you need to train on other XFUN datasets, you can convert them with the following command:

```bash
python3 ppstructure/vqa/helper/trans_xfun_data.py --ori_gt_path=path/to/json_path --output_path=path/to/save_path
```
### 4.2 SER task

Before starting training, you need to modify the following four fields:

1. `Train.dataset.data_dir`: directory where the training images are stored
2. `Train.dataset.label_file_list`: path of the training label file
3. `Eval.dataset.data_dir`: directory where the validation images are stored
4. `Eval.dataset.label_file_list`: path of the validation label file

* Start training

```shell
CUDA_VISIBLE_DEVICES=0 python3 tools/train.py -c configs/vqa/ser/layoutxlm.yml
```

Finally, metrics such as `precision`, `recall`, and `hmean` will be printed.
The training log, the best model, and the model of the latest epoch are saved in the `./output/ser_layoutxlm/` folder.
* Resume training

To resume training, assign the folder path of the previously trained model to the `Architecture.Backbone.checkpoints` field.

```shell
CUDA_VISIBLE_DEVICES=0 python3 tools/train.py -c configs/vqa/ser/layoutxlm.yml -o Architecture.Backbone.checkpoints=path/to/model_dir
```
* Evaluation

For evaluation, assign the folder path of the model to be evaluated to the `Architecture.Backbone.checkpoints` field.

```shell
CUDA_VISIBLE_DEVICES=0 python3 tools/eval.py -c configs/vqa/ser/layoutxlm.yml -o Architecture.Backbone.checkpoints=path/to/model_dir
```

Finally, metrics such as `precision`, `recall`, and `hmean` will be printed.
* Combined prediction with `OCR engine + SER`

Use the following command to run the combined `OCR engine + SER` prediction:

```shell
CUDA_VISIBLE_DEVICES=0 python3 tools/infer_vqa_token_ser.py -c configs/vqa/ser/layoutxlm.yml -o Architecture.Backbone.checkpoints=PP-Layout_v1.0_ser_pretrained/ Global.infer_img=doc/vqa/input/zh_val_42.jpg
```

The prediction visualizations and the prediction text file (named `infer_results.txt`) are saved in the directory configured by the `config.Global.save_res_path` field.
*`OCR引擎 + SER`预测系统进行端到端评估 *`OCR引擎 + SER`预测系统进行端到端评估
首先使用 `tools/infer_vqa_token_ser.py` 脚本完成数据集的预测,然后使用下面的命令进行评估。
```shell ```shell
export CUDA_VISIBLE_DEVICES=0 export CUDA_VISIBLE_DEVICES=0
python3.7 helper/eval_with_label_end2end.py --gt_json_path XFUND/zh_val/xfun_normalize_val.json --pred_json_path output_res/infer_results.txt python3 helper/eval_with_label_end2end.py --gt_json_path XFUND/zh_val/xfun_normalize_val.json --pred_json_path output_res/infer_results.txt
``` ```
* Start training

Before starting training, you need to modify the following four fields:

1. `Train.dataset.data_dir`: directory where the training images are stored
2. `Train.dataset.label_file_list`: path of the training label file
3. `Eval.dataset.data_dir`: directory where the validation images are stored
4. `Eval.dataset.label_file_list`: path of the validation label file
```shell
CUDA_VISIBLE_DEVICES=0 python3 tools/train.py -c configs/vqa/re/layoutxlm.yml
```

Finally, metrics such as `precision`, `recall`, and `hmean` will be printed.
The training log, the best model, and the model of the latest epoch are saved in the `./output/re_layoutxlm/` folder.
* Resume training

To resume training, assign the folder path of the previously trained model to the `Architecture.Backbone.checkpoints` field.

```shell
CUDA_VISIBLE_DEVICES=0 python3 tools/train.py -c configs/vqa/re/layoutxlm.yml -o Architecture.Backbone.checkpoints=path/to/model_dir
```
* Evaluation

For evaluation, assign the folder path of the model to be evaluated to the `Architecture.Backbone.checkpoints` field.

```shell
CUDA_VISIBLE_DEVICES=0 python3 tools/eval.py -c configs/vqa/re/layoutxlm.yml -o Architecture.Backbone.checkpoints=path/to/model_dir
```

Finally, metrics such as `precision`, `recall`, and `hmean` will be printed.
* Combined prediction with `OCR engine + SER + RE`

Use the following command to run the combined `OCR engine + SER + RE` prediction:

```shell
export CUDA_VISIBLE_DEVICES=0
python3 tools/infer_vqa_token_ser_re.py -c configs/vqa/re/layoutxlm.yml -o Architecture.Backbone.checkpoints=PP-Layout_v1.0_re_pretrained/ Global.infer_img=doc/vqa/input/zh_val_21.jpg -c_ser configs/vqa/ser/layoutxlm.yml -o_ser Architecture.Backbone.checkpoints=PP-Layout_v1.0_ser_pretrained/
```

The prediction visualizations and the prediction text file (named `infer_results.txt`) are saved in the directory configured by the `config.Global.save_res_path` field.
## References

- LayoutXLM: Multimodal Pre-training for Multilingual Visually-rich Document Understanding, https://arxiv.org/pdf/2104.08836.pdf
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import sys
__dir__ = os.path.dirname(os.path.abspath(__file__))
sys.path.append(__dir__)
sys.path.append(os.path.abspath(os.path.join(__dir__, '../..')))
import paddle
from paddlenlp.transformers import LayoutXLMTokenizer, LayoutXLMModel, LayoutXLMForRelationExtraction
from xfun import XFUNDataset
from vqa_utils import parse_args, get_bio_label_maps, print_arguments
from data_collator import DataCollator
from metric import re_score
from ppocr.utils.logging import get_logger
def cal_metric(re_preds, re_labels, entities):
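# Rebuild the ground-truth relation list for every sample from the head/tail indices and the
# entity spans/labels, then score the predicted relations with re_score in "boundaries" mode.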
gt_relations = []
for b in range(len(re_labels)):
rel_sent = []
for head, tail in zip(re_labels[b]["head"], re_labels[b]["tail"]):
rel = {}
rel["head_id"] = head
rel["head"] = (entities[b]["start"][rel["head_id"]],
entities[b]["end"][rel["head_id"]])
rel["head_type"] = entities[b]["label"][rel["head_id"]]
rel["tail_id"] = tail
rel["tail"] = (entities[b]["start"][rel["tail_id"]],
entities[b]["end"][rel["tail_id"]])
rel["tail_type"] = entities[b]["label"][rel["tail_id"]]
rel["type"] = 1
rel_sent.append(rel)
gt_relations.append(rel_sent)
re_metrics = re_score(re_preds, gt_relations, mode="boundaries")
return re_metrics
def evaluate(model, eval_dataloader, logger, prefix=""):
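# Run the RE model over the evaluation dataloader without gradients, collect the predicted and
# ground-truth relations, and return precision/recall/f1 computed by cal_metric.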
# Eval!
logger.info("***** Running evaluation {} *****".format(prefix))
logger.info(" Num examples = {}".format(len(eval_dataloader.dataset)))
re_preds = []
re_labels = []
entities = []
eval_loss = 0.0
model.eval()
for idx, batch in enumerate(eval_dataloader):
with paddle.no_grad():
outputs = model(**batch)
loss = outputs['loss'].mean().item()
if paddle.distributed.get_rank() == 0:
logger.info("[Eval] process: {}/{}, loss: {:.5f}".format(
idx, len(eval_dataloader), loss))
eval_loss += loss
re_preds.extend(outputs['pred_relations'])
re_labels.extend(batch['relations'])
entities.extend(batch['entities'])
re_metrics = cal_metric(re_preds, re_labels, entities)
re_metrics = {
"precision": re_metrics["ALL"]["p"],
"recall": re_metrics["ALL"]["r"],
"f1": re_metrics["ALL"]["f1"],
}
model.train()
return re_metrics
def eval(args):
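# Build the tokenizer, the LayoutXLM relation-extraction model, and the XFUN evaluation
# dataloader from the command-line arguments, then run evaluate() and log the metrics.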
logger = get_logger()
label2id_map, id2label_map = get_bio_label_maps(args.label_map_path)
pad_token_label_id = paddle.nn.CrossEntropyLoss().ignore_index
tokenizer = LayoutXLMTokenizer.from_pretrained(args.model_name_or_path)
model = LayoutXLMForRelationExtraction.from_pretrained(
args.model_name_or_path)
eval_dataset = XFUNDataset(
tokenizer,
data_dir=args.eval_data_dir,
label_path=args.eval_label_path,
label2id_map=label2id_map,
img_size=(224, 224),
max_seq_len=args.max_seq_length,
pad_token_label_id=pad_token_label_id,
contains_re=True,
add_special_ids=False,
return_attention_mask=True,
load_mode='all')
eval_dataloader = paddle.io.DataLoader(
eval_dataset,
batch_size=args.per_gpu_eval_batch_size,
num_workers=args.num_workers,
shuffle=False,
collate_fn=DataCollator())
results = evaluate(model, eval_dataloader, logger)
logger.info("eval results: {}".format(results))
if __name__ == "__main__":
args = parse_args()
eval(args)
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import sys
__dir__ = os.path.dirname(os.path.abspath(__file__))
sys.path.append(__dir__)
sys.path.append(os.path.abspath(os.path.join(__dir__, '../..')))
import random
import time
import copy
import logging
import argparse
import paddle
import numpy as np
from seqeval.metrics import classification_report, f1_score, precision_score, recall_score
from paddlenlp.transformers import LayoutXLMModel, LayoutXLMTokenizer, LayoutXLMForTokenClassification
from paddlenlp.transformers import LayoutLMModel, LayoutLMTokenizer, LayoutLMForTokenClassification
from xfun import XFUNDataset
from losses import SERLoss
from vqa_utils import parse_args, get_bio_label_maps, print_arguments
from ppocr.utils.logging import get_logger
MODELS = {
'LayoutXLM':
(LayoutXLMTokenizer, LayoutXLMModel, LayoutXLMForTokenClassification),
'LayoutLM':
(LayoutLMTokenizer, LayoutLMModel, LayoutLMForTokenClassification)
}
def eval(args):
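# Select the tokenizer/model classes for the chosen SER backbone (LayoutLM or LayoutXLM),
# build the XFUN evaluation dataloader and the SER loss, then run evaluate() and log the results.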
logger = get_logger()
print_arguments(args, logger)
label2id_map, id2label_map = get_bio_label_maps(args.label_map_path)
pad_token_label_id = paddle.nn.CrossEntropyLoss().ignore_index
tokenizer_class, base_model_class, model_class = MODELS[args.ser_model_type]
tokenizer = tokenizer_class.from_pretrained(args.model_name_or_path)
model = model_class.from_pretrained(args.model_name_or_path)
eval_dataset = XFUNDataset(
tokenizer,
data_dir=args.eval_data_dir,
label_path=args.eval_label_path,
label2id_map=label2id_map,
img_size=(224, 224),
pad_token_label_id=pad_token_label_id,
contains_re=False,
add_special_ids=False,
return_attention_mask=True,
load_mode='all')
eval_dataloader = paddle.io.DataLoader(
eval_dataset,
batch_size=args.per_gpu_eval_batch_size,
num_workers=args.num_workers,
use_shared_memory=True,
collate_fn=None, )
loss_class = SERLoss(len(label2id_map))
results, _ = evaluate(args, model, tokenizer, loss_class, eval_dataloader,
label2id_map, id2label_map, pad_token_label_id,
logger)
logger.info(results)
def evaluate(args,
model,
tokenizer,
loss_class,
eval_dataloader,
label2id_map,
id2label_map,
pad_token_label_id,
logger,
prefix=""):
eval_loss = 0.0
nb_eval_steps = 0
preds = None
out_label_ids = None
model.eval()
for idx, batch in enumerate(eval_dataloader):
with paddle.no_grad():
if args.ser_model_type == 'LayoutLM':
if 'image' in batch:
batch.pop('image')
labels = batch.pop('labels')
outputs = model(**batch)
if args.ser_model_type == 'LayoutXLM':
outputs = outputs[0]
loss = loss_class(labels, outputs, batch['attention_mask'])
loss = loss.mean()
if paddle.distributed.get_rank() == 0:
logger.info("[Eval]process: {}/{}, loss: {:.5f}".format(
idx, len(eval_dataloader), loss.numpy()[0]))
eval_loss += loss.item()
nb_eval_steps += 1
if preds is None:
preds = outputs.numpy()
out_label_ids = labels.numpy()
else:
preds = np.append(preds, outputs.numpy(), axis=0)
out_label_ids = np.append(out_label_ids, labels.numpy(), axis=0)
eval_loss = eval_loss / nb_eval_steps
preds = np.argmax(preds, axis=2)
# label_map = {i: label.upper() for i, label in enumerate(labels)}
out_label_list = [[] for _ in range(out_label_ids.shape[0])]
preds_list = [[] for _ in range(out_label_ids.shape[0])]
for i in range(out_label_ids.shape[0]):
for j in range(out_label_ids.shape[1]):
if out_label_ids[i, j] != pad_token_label_id:
out_label_list[i].append(id2label_map[out_label_ids[i][j]])
preds_list[i].append(id2label_map[preds[i][j]])
results = {
"loss": eval_loss,
"precision": precision_score(out_label_list, preds_list),
"recall": recall_score(out_label_list, preds_list),
"f1": f1_score(out_label_list, preds_list),
}
with open(
os.path.join(args.output_dir, "test_gt.txt"), "w",
encoding='utf-8') as fout:
for lbl in out_label_list:
for l in lbl:
fout.write(l + "\t")
fout.write("\n")
with open(
os.path.join(args.output_dir, "test_pred.txt"), "w",
encoding='utf-8') as fout:
for lbl in preds_list:
for l in lbl:
fout.write(l + "\t")
fout.write("\n")
report = classification_report(out_label_list, preds_list)
logger.info("\n" + report)
logger.info("***** Eval results %s *****", prefix)
for key in sorted(results.keys()):
logger.info(" %s = %s", key, str(results[key]))
model.train()
return results, preds_list
if __name__ == "__main__":
args = parse_args()
eval(args)
print("===ok====")


def parser_args():
import argparse
parser = argparse.ArgumentParser(description="args for paddleserving")
parser.add_argument(
"--ori_gt_path", type=str, required=True, help='origin xfun gt path')
parser.add_argument(
"--output_path", type=str, required=True, help='path to save')
args = parser.parse_args()
return args
args = parser_args()
transfer_xfun_data(args.ori_gt_path, args.output_path)
export CUDA_VISIBLE_DEVICES=6
# python3.7 infer_ser_e2e.py \
# --model_name_or_path "output/ser_distributed/best_model" \
# --max_seq_length 512 \
# --output_dir "output_res_e2e/" \
# --infer_imgs "/ssd1/zhoujun20/VQA/data/XFUN_v1.0_data/zh.val/zh_val_0.jpg"
# python3.7 infer_ser_re_e2e.py \
# --model_name_or_path "output/ser_distributed/best_model" \
# --re_model_name_or_path "output/re_test/best_model" \
# --max_seq_length 512 \
# --output_dir "output_ser_re_e2e_train/" \
# --infer_imgs "images/input/zh_val_21.jpg"
# python3.7 infer_ser.py \
# --model_name_or_path "output/ser_LayoutLM/best_model" \
# --ser_model_type "LayoutLM" \
# --output_dir "ser_LayoutLM/" \
# --infer_imgs "images/input/zh_val_21.jpg" \
# --ocr_json_path "/ssd1/zhoujun20/VQA/data/XFUN_v1.0_data/xfun_normalize_val.json"
python3.7 infer_ser.py \
--model_name_or_path "output/ser_new/best_model" \
--ser_model_type "LayoutXLM" \
--output_dir "ser_new/" \
--infer_imgs "images/input/zh_val_21.jpg" \
--ocr_json_path "/ssd1/zhoujun20/VQA/data/XFUN_v1.0_data/xfun_normalize_val.json"
# python3.7 infer_ser_e2e.py \
# --model_name_or_path "output/ser_new/best_model" \
# --ser_model_type "LayoutXLM" \
# --max_seq_length 512 \
# --output_dir "output/ser_new/" \
# --infer_imgs "images/input/zh_val_0.jpg"
# python3.7 infer_ser_e2e.py \
# --model_name_or_path "output/ser_LayoutLM/best_model" \
# --ser_model_type "LayoutLM" \
# --max_seq_length 512 \
# --output_dir "output/ser_LayoutLM/" \
# --infer_imgs "images/input/zh_val_0.jpg"
# python3 infer_re.py \
# --model_name_or_path "/ssd1/zhoujun20/VQA/PaddleOCR/ppstructure/vqa/output/re_test/best_model/" \
# --max_seq_length 512 \
# --eval_data_dir "/ssd1/zhoujun20/VQA/data/XFUN_v1.0_data/zh.val" \
# --eval_label_path "/ssd1/zhoujun20/VQA/data/XFUN_v1.0_data/xfun_normalize_val.json" \
# --label_map_path 'labels/labels_ser.txt' \
# --output_dir "output_res" \
# --per_gpu_eval_batch_size 1 \
# --seed 2048
# python3.7 infer_ser_re_e2e.py \
# --model_name_or_path "output/ser_LayoutLM/best_model" \
# --ser_model_type "LayoutLM" \
# --re_model_name_or_path "output/re_new/best_model" \
# --max_seq_length 512 \
# --output_dir "output_ser_re_e2e/" \
# --infer_imgs "images/input/zh_val_21.jpg"
import os
import sys
__dir__ = os.path.dirname(os.path.abspath(__file__))
sys.path.append(__dir__)
sys.path.append(os.path.abspath(os.path.join(__dir__, '../..')))
import random
import cv2
import matplotlib.pyplot as plt
import numpy as np
import paddle
from paddlenlp.transformers import LayoutXLMTokenizer, LayoutXLMModel, LayoutXLMForRelationExtraction
from xfun import XFUNDataset
from vqa_utils import parse_args, get_bio_label_maps, draw_re_results
from data_collator import DataCollator
from ppocr.utils.logging import get_logger
def infer(args):
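# Run relation-extraction inference on the XFUN validation set: load the ground-truth OCR info,
# predict relations batch by batch, map them back to OCR text regions, and save one
# visualization image per input picture.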
os.makedirs(args.output_dir, exist_ok=True)
logger = get_logger()
label2id_map, id2label_map = get_bio_label_maps(args.label_map_path)
pad_token_label_id = paddle.nn.CrossEntropyLoss().ignore_index
tokenizer = LayoutXLMTokenizer.from_pretrained(args.model_name_or_path)
model = LayoutXLMForRelationExtraction.from_pretrained(
args.model_name_or_path)
eval_dataset = XFUNDataset(
tokenizer,
data_dir=args.eval_data_dir,
label_path=args.eval_label_path,
label2id_map=label2id_map,
img_size=(224, 224),
max_seq_len=args.max_seq_length,
pad_token_label_id=pad_token_label_id,
contains_re=True,
add_special_ids=False,
return_attention_mask=True,
load_mode='all')
eval_dataloader = paddle.io.DataLoader(
eval_dataset,
batch_size=args.per_gpu_eval_batch_size,
num_workers=8,
shuffle=False,
collate_fn=DataCollator())
# load the ground-truth OCR information
ocr_info_list = load_ocr(args.eval_data_dir, args.eval_label_path)
for idx, batch in enumerate(eval_dataloader):
ocr_info = ocr_info_list[idx]
image_path = ocr_info['image_path']
ocr_info = ocr_info['ocr_info']
save_img_path = os.path.join(
args.output_dir,
os.path.splitext(os.path.basename(image_path))[0] + "_re.jpg")
logger.info("[Infer] process: {}/{}, save result to {}".format(
idx, len(eval_dataloader), save_img_path))
with paddle.no_grad():
outputs = model(**batch)
pred_relations = outputs['pred_relations']
# decode the tokens of each entity and use them to filter out unwanted ocr_info entries
ocr_info = filter_bg_by_txt(ocr_info, batch, tokenizer)
# convert the predicted relations into pairs of OCR entries
result = []
used_tail_id = []
for relations in pred_relations:
for relation in relations:
if relation['tail_id'] in used_tail_id:
continue
if relation['head_id'] not in ocr_info or relation[
'tail_id'] not in ocr_info:
continue
used_tail_id.append(relation['tail_id'])
ocr_info_head = ocr_info[relation['head_id']]
ocr_info_tail = ocr_info[relation['tail_id']]
result.append((ocr_info_head, ocr_info_tail))
img = cv2.imread(image_path)
img_show = draw_re_results(img, result)
cv2.imwrite(save_img_path, img_show)
def load_ocr(img_folder, json_path):
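# Read the label file (one "image_name\tjson" line per image) and return a list of OCR info
# dicts with the resolved image path attached.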
import json
d = []
with open(json_path, "r", encoding='utf-8') as fin:
lines = fin.readlines()
for line in lines:
image_name, info_str = line.split("\t")
info_dict = json.loads(info_str)
info_dict['image_path'] = os.path.join(img_folder, image_name)
d.append(info_dict)
return d
def filter_bg_by_txt(ocr_info, batch, tokenizer):
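# Decode the tokens of each predicted entity back to text and keep only the ocr_info entries
# whose text matches, returning a dict keyed by the original OCR index.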
entities = batch['entities'][0]
input_ids = batch['input_ids'][0]
new_info_dict = {}
for i in range(len(entities['start'])):
entitie_head = entities['start'][i]
entitie_tail = entities['end'][i]
word_input_ids = input_ids[entitie_head:entitie_tail].numpy().tolist()
txt = tokenizer.convert_ids_to_tokens(word_input_ids)
txt = tokenizer.convert_tokens_to_string(txt)
for i, info in enumerate(ocr_info):
if info['text'] == txt:
new_info_dict[i] = info
return new_info_dict
def post_process(pred_relations, ocr_info, img):
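# Map every predicted relation's head/tail ids to the corresponding OCR entries and return
# them as (head, tail) pairs.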
result = []
for relations in pred_relations:
for relation in relations:
ocr_info_head = ocr_info[relation['head_id']]
ocr_info_tail = ocr_info[relation['tail_id']]
result.append((ocr_info_head, ocr_info_tail))
return result
def draw_re(result, image_path, output_folder):
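# Draw the head and tail boxes of every predicted relation, connect their centers with a line,
# and save the visualization with matplotlib.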
img = cv2.imread(image_path)
from matplotlib import pyplot as plt
for ocr_info_head, ocr_info_tail in result:
cv2.rectangle(
img,
tuple(ocr_info_head['bbox'][:2]),
tuple(ocr_info_head['bbox'][2:]), (255, 0, 0),
thickness=2)
cv2.rectangle(
img,
tuple(ocr_info_tail['bbox'][:2]),
tuple(ocr_info_tail['bbox'][2:]), (0, 0, 255),
thickness=2)
center_p1 = [(ocr_info_head['bbox'][0] + ocr_info_head['bbox'][2]) // 2,
(ocr_info_head['bbox'][1] + ocr_info_head['bbox'][3]) // 2]
center_p2 = [(ocr_info_tail['bbox'][0] + ocr_info_tail['bbox'][2]) // 2,
(ocr_info_tail['bbox'][1] + ocr_info_tail['bbox'][3]) // 2]
cv2.line(
img, tuple(center_p1), tuple(center_p2), (0, 255, 0), thickness=2)
plt.imshow(img)
plt.savefig(
os.path.join(output_folder, os.path.basename(image_path)), dpi=600)
# plt.show()
if __name__ == "__main__":
args = parse_args()
infer(args)
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import sys
__dir__ = os.path.dirname(os.path.abspath(__file__))
sys.path.append(__dir__)
import json
import cv2
import numpy as np
from copy import deepcopy
import paddle
# relative reference
from vqa_utils import parse_args, get_image_file_list, draw_ser_results, get_bio_label_maps
from paddlenlp.transformers import LayoutXLMModel, LayoutXLMTokenizer, LayoutXLMForTokenClassification
from paddlenlp.transformers import LayoutLMModel, LayoutLMTokenizer, LayoutLMForTokenClassification
MODELS = {
'LayoutXLM':
(LayoutXLMTokenizer, LayoutXLMModel, LayoutXLMForTokenClassification),
'LayoutLM':
(LayoutLMTokenizer, LayoutLMModel, LayoutLMForTokenClassification)
}
def pad_sentences(tokenizer,
encoded_inputs,
max_seq_len=512,
pad_to_max_seq_len=True,
return_attention_mask=True,
return_token_type_ids=True,
return_overflowing_tokens=False,
return_special_tokens_mask=False):
# Padding with larger size, reshape is carried out
max_seq_len = (
len(encoded_inputs["input_ids"]) // max_seq_len + 1) * max_seq_len
needs_to_be_padded = pad_to_max_seq_len and \
max_seq_len and len(encoded_inputs["input_ids"]) < max_seq_len
if needs_to_be_padded:
difference = max_seq_len - len(encoded_inputs["input_ids"])
if tokenizer.padding_side == 'right':
if return_attention_mask:
encoded_inputs["attention_mask"] = [1] * len(encoded_inputs[
"input_ids"]) + [0] * difference
if return_token_type_ids:
encoded_inputs["token_type_ids"] = (
encoded_inputs["token_type_ids"] +
[tokenizer.pad_token_type_id] * difference)
if return_special_tokens_mask:
encoded_inputs["special_tokens_mask"] = encoded_inputs[
"special_tokens_mask"] + [1] * difference
encoded_inputs["input_ids"] = encoded_inputs[
"input_ids"] + [tokenizer.pad_token_id] * difference
encoded_inputs["bbox"] = encoded_inputs["bbox"] + [[0, 0, 0, 0]
] * difference
else:
assert False, "padding_side of tokenizer just supports [\"right\"] but got {}".format(
tokenizer.padding_side)
else:
if return_attention_mask:
encoded_inputs["attention_mask"] = [1] * len(encoded_inputs[
"input_ids"])
return encoded_inputs
def split_page(encoded_inputs, max_seq_len=512):
"""
truncate is often used in training process
"""
for key in encoded_inputs:
encoded_inputs[key] = paddle.to_tensor(encoded_inputs[key])
if encoded_inputs[key].ndim <= 1: # for input_ids, att_mask and so on
encoded_inputs[key] = encoded_inputs[key].reshape([-1, max_seq_len])
else: # for bbox
encoded_inputs[key] = encoded_inputs[key].reshape(
[-1, max_seq_len, 4])
return encoded_inputs
def preprocess(
tokenizer,
ori_img,
ocr_info,
img_size=(224, 224),
pad_token_label_id=-100,
max_seq_len=512,
add_special_ids=False,
return_attention_mask=True, ):
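# Resize the image to 224x224, rescale every OCR bbox to the 0-1000 coordinate range expected
# by LayoutLM/LayoutXLM, tokenize each text region, pad and split the sequence into
# max_seq_len "pages", and attach the image tensor plus per-region segment offsets.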
ocr_info = deepcopy(ocr_info)
height = ori_img.shape[0]
width = ori_img.shape[1]
img = cv2.resize(ori_img,
(224, 224)).transpose([2, 0, 1]).astype(np.float32)
segment_offset_id = []
words_list = []
bbox_list = []
input_ids_list = []
token_type_ids_list = []
for info in ocr_info:
# x1, y1, x2, y2
bbox = info["bbox"]
bbox[0] = int(bbox[0] * 1000.0 / width)
bbox[2] = int(bbox[2] * 1000.0 / width)
bbox[1] = int(bbox[1] * 1000.0 / height)
bbox[3] = int(bbox[3] * 1000.0 / height)
text = info["text"]
encode_res = tokenizer.encode(
text, pad_to_max_seq_len=False, return_attention_mask=True)
if not add_special_ids:
# TODO: use tok.all_special_ids to remove
encode_res["input_ids"] = encode_res["input_ids"][1:-1]
encode_res["token_type_ids"] = encode_res["token_type_ids"][1:-1]
encode_res["attention_mask"] = encode_res["attention_mask"][1:-1]
input_ids_list.extend(encode_res["input_ids"])
token_type_ids_list.extend(encode_res["token_type_ids"])
bbox_list.extend([bbox] * len(encode_res["input_ids"]))
words_list.append(text)
segment_offset_id.append(len(input_ids_list))
encoded_inputs = {
"input_ids": input_ids_list,
"token_type_ids": token_type_ids_list,
"bbox": bbox_list,
"attention_mask": [1] * len(input_ids_list),
}
encoded_inputs = pad_sentences(
tokenizer,
encoded_inputs,
max_seq_len=max_seq_len,
return_attention_mask=return_attention_mask)
encoded_inputs = split_page(encoded_inputs)
fake_bs = encoded_inputs["input_ids"].shape[0]
encoded_inputs["image"] = paddle.to_tensor(img).unsqueeze(0).expand(
[fake_bs] + list(img.shape))
encoded_inputs["segment_offset_id"] = segment_offset_id
return encoded_inputs
def postprocess(attention_mask, preds, label_map_path):
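# Take the argmax over the class dimension and convert prediction ids to label names for every
# position whose attention_mask is 1, keeping the batch structure.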
if isinstance(preds, paddle.Tensor):
preds = preds.numpy()
preds = np.argmax(preds, axis=2)
_, label_map = get_bio_label_maps(label_map_path)
preds_list = [[] for _ in range(preds.shape[0])]
# keep batch info
for i in range(preds.shape[0]):
for j in range(preds.shape[1]):
if attention_mask[i][j] == 1:
preds_list[i].append(label_map[preds[i][j]])
return preds_list
def merge_preds_list_with_ocr_info(label_map_path, ocr_info, segment_offset_id,
preds_list):
# must ensure the preds_list is generated from the same image
preds = [p for pred in preds_list for p in pred]
label2id_map, _ = get_bio_label_maps(label_map_path)
for key in label2id_map:
if key.startswith("I-"):
label2id_map[key] = label2id_map["B" + key[1:]]
id2label_map = dict()
for key in label2id_map:
val = label2id_map[key]
if key == "O":
id2label_map[val] = key
if key.startswith("B-") or key.startswith("I-"):
id2label_map[val] = key[2:]
else:
id2label_map[val] = key
for idx in range(len(segment_offset_id)):
if idx == 0:
start_id = 0
else:
start_id = segment_offset_id[idx - 1]
end_id = segment_offset_id[idx]
curr_pred = preds[start_id:end_id]
curr_pred = [label2id_map[p] for p in curr_pred]
if len(curr_pred) <= 0:
pred_id = 0
else:
counts = np.bincount(curr_pred)
pred_id = np.argmax(counts)
ocr_info[idx]["pred_id"] = int(pred_id)
ocr_info[idx]["pred"] = id2label_map[pred_id]
return ocr_info
@paddle.no_grad()
def infer(args):
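# SER inference from pre-computed OCR results: load the tokenizer/model, look up the OCR info
# for each image, run preprocess -> model -> postprocess, merge the token-level predictions
# back onto the OCR regions, and save infer_results.txt plus a visualization per image.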
os.makedirs(args.output_dir, exist_ok=True)
# init token and model
tokenizer_class, base_model_class, model_class = MODELS[args.ser_model_type]
tokenizer = tokenizer_class.from_pretrained(args.model_name_or_path)
model = model_class.from_pretrained(args.model_name_or_path)
model.eval()
# load ocr results json
ocr_results = dict()
with open(args.ocr_json_path, "r", encoding='utf-8') as fin:
lines = fin.readlines()
for line in lines:
img_name, json_info = line.split("\t")
ocr_results[os.path.basename(img_name)] = json.loads(json_info)
# get infer img list
infer_imgs = get_image_file_list(args.infer_imgs)
# loop for infer
with open(
os.path.join(args.output_dir, "infer_results.txt"),
"w",
encoding='utf-8') as fout:
for idx, img_path in enumerate(infer_imgs):
save_img_path = os.path.join(args.output_dir,
os.path.basename(img_path))
print("process: [{}/{}], save result to {}".format(
idx, len(infer_imgs), save_img_path))
img = cv2.imread(img_path)
ocr_info = ocr_results[os.path.basename(img_path)]["ocr_info"]
inputs = preprocess(
tokenizer=tokenizer,
ori_img=img,
ocr_info=ocr_info,
max_seq_len=args.max_seq_length)
if args.ser_model_type == 'LayoutLM':
preds = model(
input_ids=inputs["input_ids"],
bbox=inputs["bbox"],
token_type_ids=inputs["token_type_ids"],
attention_mask=inputs["attention_mask"])
elif args.ser_model_type == 'LayoutXLM':
preds = model(
input_ids=inputs["input_ids"],
bbox=inputs["bbox"],
image=inputs["image"],
token_type_ids=inputs["token_type_ids"],
attention_mask=inputs["attention_mask"])
preds = preds[0]
preds = postprocess(inputs["attention_mask"], preds,
args.label_map_path)
ocr_info = merge_preds_list_with_ocr_info(
args.label_map_path, ocr_info, inputs["segment_offset_id"],
preds)
fout.write(img_path + "\t" + json.dumps(
{
"ocr_info": ocr_info,
}, ensure_ascii=False) + "\n")
img_res = draw_ser_results(img, ocr_info)
cv2.imwrite(save_img_path, img_res)
return
if __name__ == "__main__":
args = parse_args()
infer(args)
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import sys
__dir__ = os.path.dirname(os.path.abspath(__file__))
sys.path.append(__dir__)
import json
import cv2
import numpy as np
from copy import deepcopy
from PIL import Image
import paddle
from paddlenlp.transformers import LayoutXLMModel, LayoutXLMTokenizer, LayoutXLMForTokenClassification
from paddlenlp.transformers import LayoutLMModel, LayoutLMTokenizer, LayoutLMForTokenClassification
# relative reference
from vqa_utils import parse_args, get_image_file_list, draw_ser_results, get_bio_label_maps
from vqa_utils import pad_sentences, split_page, preprocess, postprocess, merge_preds_list_with_ocr_info
MODELS = {
'LayoutXLM':
(LayoutXLMTokenizer, LayoutXLMModel, LayoutXLMForTokenClassification),
'LayoutLM':
(LayoutLMTokenizer, LayoutLMModel, LayoutLMForTokenClassification)
}
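# Convert a polygon (list of [x, y] points) to its axis-aligned bounding box
# [x_min, y_min, x_max, y_max], e.g. [[10, 20], [50, 22], [48, 40], [9, 38]]
# -> [9, 20, 50, 40].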
def trans_poly_to_bbox(poly):
x1 = np.min([p[0] for p in poly])
x2 = np.max([p[0] for p in poly])
y1 = np.min([p[1] for p in poly])
y2 = np.max([p[1] for p in poly])
return [x1, y1, x2, y2]
def parse_ocr_info_for_ser(ocr_result):
ocr_info = []
for res in ocr_result:
ocr_info.append({
"text": res[1][0],
"bbox": trans_poly_to_bbox(res[0]),
"poly": res[0],
})
return ocr_info
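# End-to-end SER predictor: PaddleOCR handles text detection + recognition,
# then LayoutLM/LayoutXLM classifies every token; __call__ returns the OCR info
# enriched with predicted labels together with the prepared model inputs.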
class SerPredictor(object):
def __init__(self, args):
self.args = args
self.max_seq_length = args.max_seq_length
# init ser token and model
tokenizer_class, base_model_class, model_class = MODELS[
args.ser_model_type]
self.tokenizer = tokenizer_class.from_pretrained(
args.model_name_or_path)
self.model = model_class.from_pretrained(args.model_name_or_path)
self.model.eval()
# init ocr_engine
from paddleocr import PaddleOCR
self.ocr_engine = PaddleOCR(
rec_model_dir=args.rec_model_dir,
det_model_dir=args.det_model_dir,
use_angle_cls=False,
show_log=False)
# init dict
label2id_map, self.id2label_map = get_bio_label_maps(
args.label_map_path)
self.label2id_map_for_draw = dict()
for key in label2id_map:
if key.startswith("I-"):
self.label2id_map_for_draw[key] = label2id_map["B" + key[1:]]
else:
self.label2id_map_for_draw[key] = label2id_map[key]
def __call__(self, img):
ocr_result = self.ocr_engine.ocr(img, cls=False)
ocr_info = parse_ocr_info_for_ser(ocr_result)
inputs = preprocess(
tokenizer=self.tokenizer,
ori_img=img,
ocr_info=ocr_info,
max_seq_len=self.max_seq_length)
if self.args.ser_model_type == 'LayoutLM':
preds = self.model(
input_ids=inputs["input_ids"],
bbox=inputs["bbox"],
token_type_ids=inputs["token_type_ids"],
attention_mask=inputs["attention_mask"])
elif self.args.ser_model_type == 'LayoutXLM':
preds = self.model(
input_ids=inputs["input_ids"],
bbox=inputs["bbox"],
image=inputs["image"],
token_type_ids=inputs["token_type_ids"],
attention_mask=inputs["attention_mask"])
preds = preds[0]
preds = postprocess(inputs["attention_mask"], preds, self.id2label_map)
ocr_info = merge_preds_list_with_ocr_info(
ocr_info, inputs["segment_offset_id"], preds,
self.label2id_map_for_draw)
return ocr_info, inputs
if __name__ == "__main__":
args = parse_args()
os.makedirs(args.output_dir, exist_ok=True)
# get infer img list
infer_imgs = get_image_file_list(args.infer_imgs)
# loop for infer
ser_engine = SerPredictor(args)
with open(
os.path.join(args.output_dir, "infer_results.txt"),
"w",
encoding='utf-8') as fout:
for idx, img_path in enumerate(infer_imgs):
save_img_path = os.path.join(
args.output_dir,
os.path.splitext(os.path.basename(img_path))[0] + "_ser.jpg")
print("process: [{}/{}], save result to {}".format(
idx, len(infer_imgs), save_img_path))
img = cv2.imread(img_path)
result, _ = ser_engine(img)
fout.write(img_path + "\t" + json.dumps(
{
"ser_resule": result,
}, ensure_ascii=False) + "\n")
img_res = draw_ser_results(img, result)
cv2.imwrite(save_img_path, img_res)
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import sys
import json
import cv2
import numpy as np
from copy import deepcopy
from PIL import Image
import paddle
from paddlenlp.transformers import LayoutXLMModel, LayoutXLMTokenizer, LayoutXLMForRelationExtraction
# relative reference
from vqa_utils import parse_args, get_image_file_list, draw_re_results
from infer_ser_e2e import SerPredictor
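# Build relation-extraction inputs from SER outputs: entities predicted as "O"
# are dropped, every (QUESTION -> ANSWER) pair is enumerated as a candidate
# head/tail relation, and entity_idx_dict maps kept entities back to ser_result.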
def make_input(ser_input, ser_result, max_seq_len=512):
entities_labels = {'HEADER': 0, 'QUESTION': 1, 'ANSWER': 2}
entities = ser_input['entities'][0]
assert len(entities) == len(ser_result)
# entities
start = []
end = []
label = []
entity_idx_dict = {}
for i, (res, entity) in enumerate(zip(ser_result, entities)):
if res['pred'] == 'O':
continue
entity_idx_dict[len(start)] = i
start.append(entity['start'])
end.append(entity['end'])
label.append(entities_labels[res['pred']])
entities = dict(start=start, end=end, label=label)
# relations
head = []
tail = []
for i in range(len(entities["label"])):
for j in range(len(entities["label"])):
if entities["label"][i] == 1 and entities["label"][j] == 2:
head.append(i)
tail.append(j)
relations = dict(head=head, tail=tail)
batch_size = ser_input["input_ids"].shape[0]
entities_batch = []
relations_batch = []
for b in range(batch_size):
entities_batch.append(entities)
relations_batch.append(relations)
ser_input['entities'] = entities_batch
ser_input['relations'] = relations_batch
ser_input.pop('segment_offset_id')
return ser_input, entity_idx_dict
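# Cascaded SER + RE system: SerPredictor extracts entities, then
# LayoutXLMForRelationExtraction links them; each predicted relation is
# returned as an (ocr_info_head, ocr_info_tail) pair, using each tail at most once.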
class SerReSystem(object):
def __init__(self, args):
self.ser_engine = SerPredictor(args)
self.tokenizer = LayoutXLMTokenizer.from_pretrained(
args.re_model_name_or_path)
self.model = LayoutXLMForRelationExtraction.from_pretrained(
args.re_model_name_or_path)
self.model.eval()
def __call__(self, img):
ser_result, ser_inputs = self.ser_engine(img)
re_input, entity_idx_dict = make_input(ser_inputs, ser_result)
re_result = self.model(**re_input)
pred_relations = re_result['pred_relations'][0]
        # convert the predicted relations back to pairs of OCR entities
result = []
used_tail_id = []
for relation in pred_relations:
if relation['tail_id'] in used_tail_id:
continue
used_tail_id.append(relation['tail_id'])
ocr_info_head = ser_result[entity_idx_dict[relation['head_id']]]
ocr_info_tail = ser_result[entity_idx_dict[relation['tail_id']]]
result.append((ocr_info_head, ocr_info_tail))
return result
if __name__ == "__main__":
args = parse_args()
os.makedirs(args.output_dir, exist_ok=True)
# get infer img list
infer_imgs = get_image_file_list(args.infer_imgs)
# loop for infer
ser_re_engine = SerReSystem(args)
with open(
os.path.join(args.output_dir, "infer_results.txt"),
"w",
encoding='utf-8') as fout:
for idx, img_path in enumerate(infer_imgs):
save_img_path = os.path.join(
args.output_dir,
os.path.splitext(os.path.basename(img_path))[0] + "_re.jpg")
print("process: [{}/{}], save result to {}".format(
idx, len(infer_imgs), save_img_path))
img = cv2.imread(img_path)
result = ser_re_engine(img)
fout.write(img_path + "\t" + json.dumps(
{
"result": result,
}, ensure_ascii=False) + "\n")
img_res = draw_re_results(img, result)
cv2.imwrite(save_img_path, img_res)
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import re
import numpy as np
import logging
logger = logging.getLogger(__name__)
PREFIX_CHECKPOINT_DIR = "checkpoint"
_re_checkpoint = re.compile(r"^" + PREFIX_CHECKPOINT_DIR + r"\-(\d+)$")
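# Return the newest "checkpoint-<step>" sub-directory of `folder`, e.g. with
# checkpoint-100 and checkpoint-200 present it returns <folder>/checkpoint-200;
# returns None when no checkpoint directory exists.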
def get_last_checkpoint(folder):
content = os.listdir(folder)
checkpoints = [
path for path in content
if _re_checkpoint.search(path) is not None and os.path.isdir(
os.path.join(folder, path))
]
if len(checkpoints) == 0:
return
return os.path.join(
folder,
max(checkpoints,
key=lambda x: int(_re_checkpoint.search(x).groups()[0])))
def re_score(pred_relations, gt_relations, mode="strict"):
"""Evaluate RE predictions
Args:
pred_relations (list) : list of list of predicted relations (several relations in each sentence)
gt_relations (list) : list of list of ground truth relations
rel = { "head": (start_idx (inclusive), end_idx (exclusive)),
"tail": (start_idx (inclusive), end_idx (exclusive)),
"head_type": ent_type,
"tail_type": ent_type,
"type": rel_type}
vocab (Vocab) : dataset vocabulary
mode (str) : in 'strict' or 'boundaries'"""
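    # Example (hypothetical data): a single sentence whose one predicted
    # relation exactly matches the one ground-truth relation yields
    # scores["ALL"]["p"] == scores["ALL"]["r"] == scores["ALL"]["f1"] == 1.0.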
assert mode in ["strict", "boundaries"]
relation_types = [v for v in [0, 1] if not v == 0]
scores = {
rel: {
"tp": 0,
"fp": 0,
"fn": 0
}
for rel in relation_types + ["ALL"]
}
# Count GT relations and Predicted relations
n_sents = len(gt_relations)
n_rels = sum([len([rel for rel in sent]) for sent in gt_relations])
n_found = sum([len([rel for rel in sent]) for sent in pred_relations])
# Count TP, FP and FN per type
for pred_sent, gt_sent in zip(pred_relations, gt_relations):
for rel_type in relation_types:
# strict mode takes argument types into account
if mode == "strict":
pred_rels = {(rel["head"], rel["head_type"], rel["tail"],
rel["tail_type"])
for rel in pred_sent if rel["type"] == rel_type}
gt_rels = {(rel["head"], rel["head_type"], rel["tail"],
rel["tail_type"])
for rel in gt_sent if rel["type"] == rel_type}
# boundaries mode only takes argument spans into account
elif mode == "boundaries":
pred_rels = {(rel["head"], rel["tail"])
for rel in pred_sent if rel["type"] == rel_type}
gt_rels = {(rel["head"], rel["tail"])
for rel in gt_sent if rel["type"] == rel_type}
scores[rel_type]["tp"] += len(pred_rels & gt_rels)
scores[rel_type]["fp"] += len(pred_rels - gt_rels)
scores[rel_type]["fn"] += len(gt_rels - pred_rels)
# Compute per entity Precision / Recall / F1
for rel_type in scores.keys():
if scores[rel_type]["tp"]:
scores[rel_type]["p"] = scores[rel_type]["tp"] / (
scores[rel_type]["fp"] + scores[rel_type]["tp"])
scores[rel_type]["r"] = scores[rel_type]["tp"] / (
scores[rel_type]["fn"] + scores[rel_type]["tp"])
else:
scores[rel_type]["p"], scores[rel_type]["r"] = 0, 0
if not scores[rel_type]["p"] + scores[rel_type]["r"] == 0:
scores[rel_type]["f1"] = (
2 * scores[rel_type]["p"] * scores[rel_type]["r"] /
(scores[rel_type]["p"] + scores[rel_type]["r"]))
else:
scores[rel_type]["f1"] = 0
# Compute micro F1 Scores
tp = sum([scores[rel_type]["tp"] for rel_type in relation_types])
fp = sum([scores[rel_type]["fp"] for rel_type in relation_types])
fn = sum([scores[rel_type]["fn"] for rel_type in relation_types])
if tp:
precision = tp / (tp + fp)
recall = tp / (tp + fn)
f1 = 2 * precision * recall / (precision + recall)
else:
precision, recall, f1 = 0, 0, 0
scores["ALL"]["p"] = precision
scores["ALL"]["r"] = recall
scores["ALL"]["f1"] = f1
scores["ALL"]["tp"] = tp
scores["ALL"]["fp"] = fp
scores["ALL"]["fn"] = fn
# Compute Macro F1 Scores
scores["ALL"]["Macro_f1"] = np.mean(
[scores[ent_type]["f1"] for ent_type in relation_types])
scores["ALL"]["Macro_p"] = np.mean(
[scores[ent_type]["p"] for ent_type in relation_types])
scores["ALL"]["Macro_r"] = np.mean(
[scores[ent_type]["r"] for ent_type in relation_types])
# logger.info(f"RE Evaluation in *** {mode.upper()} *** mode")
# logger.info(
# "processed {} sentences with {} relations; found: {} relations; correct: {}.".format(
# n_sents, n_rels, n_found, tp
# )
# )
# logger.info(
# "\tALL\t TP: {};\tFP: {};\tFN: {}".format(scores["ALL"]["tp"], scores["ALL"]["fp"], scores["ALL"]["fn"])
# )
# logger.info("\t\t(m avg): precision: {:.2f};\trecall: {:.2f};\tf1: {:.2f} (micro)".format(precision, recall, f1))
# logger.info(
# "\t\t(M avg): precision: {:.2f};\trecall: {:.2f};\tf1: {:.2f} (Macro)\n".format(
# scores["ALL"]["Macro_p"], scores["ALL"]["Macro_r"], scores["ALL"]["Macro_f1"]
# )
# )
# for rel_type in relation_types:
# logger.info(
# "\t{}: \tTP: {};\tFP: {};\tFN: {};\tprecision: {:.2f};\trecall: {:.2f};\tf1: {:.2f};\t{}".format(
# rel_type,
# scores[rel_type]["tp"],
# scores[rel_type]["fp"],
# scores[rel_type]["fn"],
# scores[rel_type]["p"],
# scores[rel_type]["r"],
# scores[rel_type]["f1"],
# scores[rel_type]["tp"] + scores[rel_type]["fp"],
# )
# )
return scores
sentencepiece
yacs
seqeval
paddlenlp>=2.2.1
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import sys
__dir__ = os.path.dirname(os.path.abspath(__file__))
sys.path.append(__dir__)
sys.path.append(os.path.abspath(os.path.join(__dir__, '../..')))
import random
import time
import numpy as np
import paddle
from paddlenlp.transformers import LayoutXLMTokenizer, LayoutXLMModel, LayoutXLMForRelationExtraction
from xfun import XFUNDataset
from vqa_utils import parse_args, get_bio_label_maps, print_arguments, set_seed
from data_collator import DataCollator
from eval_re import evaluate
from ppocr.utils.logging import get_logger
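# Fine-tune LayoutXLMForRelationExtraction on XFUND-style data. Note that the
# polynomial-decay / warmup scheduler built below is not passed to the Adam
# optimizer and lr_scheduler.step() is commented out, so training runs at the
# fixed args.learning_rate.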
def train(args):
logger = get_logger(log_file=os.path.join(args.output_dir, "train.log"))
rank = paddle.distributed.get_rank()
distributed = paddle.distributed.get_world_size() > 1
print_arguments(args, logger)
# Added here for reproducibility (even between python 2 and 3)
set_seed(args.seed)
label2id_map, id2label_map = get_bio_label_maps(args.label_map_path)
pad_token_label_id = paddle.nn.CrossEntropyLoss().ignore_index
# dist mode
if distributed:
paddle.distributed.init_parallel_env()
tokenizer = LayoutXLMTokenizer.from_pretrained(args.model_name_or_path)
if not args.resume:
model = LayoutXLMModel.from_pretrained(args.model_name_or_path)
model = LayoutXLMForRelationExtraction(model, dropout=None)
logger.info('train from scratch')
else:
logger.info('resume from {}'.format(args.model_name_or_path))
model = LayoutXLMForRelationExtraction.from_pretrained(
args.model_name_or_path)
# dist mode
if distributed:
model = paddle.DataParallel(model)
train_dataset = XFUNDataset(
tokenizer,
data_dir=args.train_data_dir,
label_path=args.train_label_path,
label2id_map=label2id_map,
img_size=(224, 224),
max_seq_len=args.max_seq_length,
pad_token_label_id=pad_token_label_id,
contains_re=True,
add_special_ids=False,
return_attention_mask=True,
load_mode='all')
eval_dataset = XFUNDataset(
tokenizer,
data_dir=args.eval_data_dir,
label_path=args.eval_label_path,
label2id_map=label2id_map,
img_size=(224, 224),
max_seq_len=args.max_seq_length,
pad_token_label_id=pad_token_label_id,
contains_re=True,
add_special_ids=False,
return_attention_mask=True,
load_mode='all')
train_sampler = paddle.io.DistributedBatchSampler(
train_dataset, batch_size=args.per_gpu_train_batch_size, shuffle=True)
train_dataloader = paddle.io.DataLoader(
train_dataset,
batch_sampler=train_sampler,
num_workers=args.num_workers,
use_shared_memory=True,
collate_fn=DataCollator())
eval_dataloader = paddle.io.DataLoader(
eval_dataset,
batch_size=args.per_gpu_eval_batch_size,
num_workers=args.num_workers,
shuffle=False,
collate_fn=DataCollator())
t_total = len(train_dataloader) * args.num_train_epochs
# build linear decay with warmup lr sch
lr_scheduler = paddle.optimizer.lr.PolynomialDecay(
learning_rate=args.learning_rate,
decay_steps=t_total,
end_lr=0.0,
power=1.0)
if args.warmup_steps > 0:
lr_scheduler = paddle.optimizer.lr.LinearWarmup(
lr_scheduler,
args.warmup_steps,
start_lr=0,
end_lr=args.learning_rate, )
grad_clip = paddle.nn.ClipGradByNorm(clip_norm=10)
optimizer = paddle.optimizer.Adam(
learning_rate=args.learning_rate,
parameters=model.parameters(),
epsilon=args.adam_epsilon,
grad_clip=grad_clip,
weight_decay=args.weight_decay)
# Train!
logger.info("***** Running training *****")
logger.info(" Num examples = {}".format(len(train_dataset)))
logger.info(" Num Epochs = {}".format(args.num_train_epochs))
logger.info(" Instantaneous batch size per GPU = {}".format(
args.per_gpu_train_batch_size))
logger.info(
" Total train batch size (w. parallel, distributed & accumulation) = {}".
format(args.per_gpu_train_batch_size *
paddle.distributed.get_world_size()))
logger.info(" Total optimization steps = {}".format(t_total))
global_step = 0
model.clear_gradients()
train_dataloader_len = len(train_dataloader)
    best_metric = {'f1': 0}
model.train()
train_reader_cost = 0.0
train_run_cost = 0.0
total_samples = 0
reader_start = time.time()
print_step = 1
for epoch in range(int(args.num_train_epochs)):
for step, batch in enumerate(train_dataloader):
train_reader_cost += time.time() - reader_start
train_start = time.time()
outputs = model(**batch)
train_run_cost += time.time() - train_start
# model outputs are always tuple in ppnlp (see doc)
loss = outputs['loss']
loss = loss.mean()
loss.backward()
optimizer.step()
optimizer.clear_grad()
# lr_scheduler.step() # Update learning rate schedule
global_step += 1
total_samples += batch['image'].shape[0]
if rank == 0 and step % print_step == 0:
logger.info(
"epoch: [{}/{}], iter: [{}/{}], global_step:{}, train loss: {:.6f}, lr: {:.6f}, avg_reader_cost: {:.5f} sec, avg_batch_cost: {:.5f} sec, avg_samples: {:.5f}, ips: {:.5f} images/sec".
format(epoch, args.num_train_epochs, step,
train_dataloader_len, global_step,
np.mean(loss.numpy()),
optimizer.get_lr(), train_reader_cost / print_step, (
train_reader_cost + train_run_cost) / print_step,
total_samples / print_step, total_samples / (
train_reader_cost + train_run_cost)))
train_reader_cost = 0.0
train_run_cost = 0.0
total_samples = 0
if rank == 0 and args.eval_steps > 0 and global_step % args.eval_steps == 0 and args.evaluate_during_training:
# Log metrics
# Only evaluate when single GPU otherwise metrics may not average well
results = evaluate(model, eval_dataloader, logger)
                if results['f1'] >= best_metric['f1']:
                    best_metric = results
output_dir = os.path.join(args.output_dir, "best_model")
os.makedirs(output_dir, exist_ok=True)
if distributed:
model._layers.save_pretrained(output_dir)
else:
model.save_pretrained(output_dir)
tokenizer.save_pretrained(output_dir)
paddle.save(args,
os.path.join(output_dir, "training_args.bin"))
logger.info("Saving model checkpoint to {}".format(
output_dir))
logger.info("eval results: {}".format(results))
logger.info("best_metirc: {}".format(best_metirc))
reader_start = time.time()
if rank == 0:
# Save model checkpoint
output_dir = os.path.join(args.output_dir, "latest_model")
os.makedirs(output_dir, exist_ok=True)
if distributed:
model._layers.save_pretrained(output_dir)
else:
model.save_pretrained(output_dir)
tokenizer.save_pretrained(output_dir)
paddle.save(args, os.path.join(output_dir, "training_args.bin"))
logger.info("Saving model checkpoint to {}".format(output_dir))
logger.info("best_metirc: {}".format(best_metirc))
if __name__ == "__main__":
args = parse_args()
os.makedirs(args.output_dir, exist_ok=True)
train(args)
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import sys
__dir__ = os.path.dirname(os.path.abspath(__file__))
sys.path.append(__dir__)
sys.path.append(os.path.abspath(os.path.join(__dir__, '../..')))
import random
import time
import copy
import logging
import argparse
import paddle
import numpy as np
from seqeval.metrics import classification_report, f1_score, precision_score, recall_score
from paddlenlp.transformers import LayoutXLMModel, LayoutXLMTokenizer, LayoutXLMForTokenClassification
from paddlenlp.transformers import LayoutLMModel, LayoutLMTokenizer, LayoutLMForTokenClassification
from xfun import XFUNDataset
from vqa_utils import parse_args, get_bio_label_maps, print_arguments, set_seed
from eval_ser import evaluate
from losses import SERLoss
from ppocr.utils.logging import get_logger
MODELS = {
'LayoutXLM':
(LayoutXLMTokenizer, LayoutXLMModel, LayoutXLMForTokenClassification),
'LayoutLM':
(LayoutLMTokenizer, LayoutLMModel, LayoutLMForTokenClassification)
}
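# Fine-tune LayoutLM / LayoutXLM for semantic entity recognition (SER): AdamW
# with a polynomial-decay + warmup schedule, SERLoss over BIO labels, and
# periodic evaluation that keeps the best-F1 model in <output_dir>/best_model.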
def train(args):
os.makedirs(args.output_dir, exist_ok=True)
rank = paddle.distributed.get_rank()
distributed = paddle.distributed.get_world_size() > 1
logger = get_logger(log_file=os.path.join(args.output_dir, "train.log"))
print_arguments(args, logger)
label2id_map, id2label_map = get_bio_label_maps(args.label_map_path)
loss_class = SERLoss(len(label2id_map))
pad_token_label_id = loss_class.ignore_index
# dist mode
if distributed:
paddle.distributed.init_parallel_env()
tokenizer_class, base_model_class, model_class = MODELS[args.ser_model_type]
tokenizer = tokenizer_class.from_pretrained(args.model_name_or_path)
if not args.resume:
base_model = base_model_class.from_pretrained(args.model_name_or_path)
model = model_class(
base_model, num_classes=len(label2id_map), dropout=None)
logger.info('train from scratch')
else:
logger.info('resume from {}'.format(args.model_name_or_path))
model = model_class.from_pretrained(args.model_name_or_path)
# dist mode
if distributed:
model = paddle.DataParallel(model)
train_dataset = XFUNDataset(
tokenizer,
data_dir=args.train_data_dir,
label_path=args.train_label_path,
label2id_map=label2id_map,
img_size=(224, 224),
pad_token_label_id=pad_token_label_id,
contains_re=False,
add_special_ids=False,
return_attention_mask=True,
load_mode='all')
eval_dataset = XFUNDataset(
tokenizer,
data_dir=args.eval_data_dir,
label_path=args.eval_label_path,
label2id_map=label2id_map,
img_size=(224, 224),
pad_token_label_id=pad_token_label_id,
contains_re=False,
add_special_ids=False,
return_attention_mask=True,
load_mode='all')
train_sampler = paddle.io.DistributedBatchSampler(
train_dataset, batch_size=args.per_gpu_train_batch_size, shuffle=True)
train_dataloader = paddle.io.DataLoader(
train_dataset,
batch_sampler=train_sampler,
num_workers=args.num_workers,
use_shared_memory=True,
collate_fn=None, )
eval_dataloader = paddle.io.DataLoader(
eval_dataset,
batch_size=args.per_gpu_eval_batch_size,
num_workers=args.num_workers,
use_shared_memory=True,
collate_fn=None, )
t_total = len(train_dataloader) * args.num_train_epochs
# build linear decay with warmup lr sch
lr_scheduler = paddle.optimizer.lr.PolynomialDecay(
learning_rate=args.learning_rate,
decay_steps=t_total,
end_lr=0.0,
power=1.0)
if args.warmup_steps > 0:
lr_scheduler = paddle.optimizer.lr.LinearWarmup(
lr_scheduler,
args.warmup_steps,
start_lr=0,
end_lr=args.learning_rate, )
optimizer = paddle.optimizer.AdamW(
learning_rate=lr_scheduler,
parameters=model.parameters(),
epsilon=args.adam_epsilon,
weight_decay=args.weight_decay)
# Train!
logger.info("***** Running training *****")
logger.info(" Num examples = %d", len(train_dataset))
logger.info(" Num Epochs = %d", args.num_train_epochs)
logger.info(" Instantaneous batch size per GPU = %d",
args.per_gpu_train_batch_size)
logger.info(
" Total train batch size (w. parallel, distributed) = %d",
args.per_gpu_train_batch_size * paddle.distributed.get_world_size(), )
logger.info(" Total optimization steps = %d", t_total)
global_step = 0
tr_loss = 0.0
set_seed(args.seed)
best_metrics = None
train_reader_cost = 0.0
train_run_cost = 0.0
total_samples = 0
reader_start = time.time()
print_step = 1
model.train()
for epoch_id in range(args.num_train_epochs):
for step, batch in enumerate(train_dataloader):
train_reader_cost += time.time() - reader_start
if args.ser_model_type == 'LayoutLM':
if 'image' in batch:
batch.pop('image')
labels = batch.pop('labels')
train_start = time.time()
outputs = model(**batch)
train_run_cost += time.time() - train_start
if args.ser_model_type == 'LayoutXLM':
outputs = outputs[0]
loss = loss_class(labels, outputs, batch['attention_mask'])
# model outputs are always tuple in ppnlp (see doc)
loss = loss.mean()
loss.backward()
tr_loss += loss.item()
optimizer.step()
lr_scheduler.step() # Update learning rate schedule
optimizer.clear_grad()
global_step += 1
total_samples += batch['input_ids'].shape[0]
if rank == 0 and step % print_step == 0:
logger.info(
"epoch: [{}/{}], iter: [{}/{}], global_step:{}, train loss: {:.6f}, lr: {:.6f}, avg_reader_cost: {:.5f} sec, avg_batch_cost: {:.5f} sec, avg_samples: {:.5f}, ips: {:.5f} images/sec".
format(epoch_id, args.num_train_epochs, step,
len(train_dataloader), global_step,
loss.numpy()[0],
lr_scheduler.get_lr(), train_reader_cost /
print_step, (train_reader_cost + train_run_cost) /
print_step, total_samples / print_step, total_samples
/ (train_reader_cost + train_run_cost)))
train_reader_cost = 0.0
train_run_cost = 0.0
total_samples = 0
if rank == 0 and args.eval_steps > 0 and global_step % args.eval_steps == 0 and args.evaluate_during_training:
# Log metrics
# Only evaluate when single GPU otherwise metrics may not average well
results, _ = evaluate(args, model, tokenizer, loss_class,
eval_dataloader, label2id_map,
id2label_map, pad_token_label_id, logger)
if best_metrics is None or results["f1"] >= best_metrics["f1"]:
best_metrics = copy.deepcopy(results)
output_dir = os.path.join(args.output_dir, "best_model")
os.makedirs(output_dir, exist_ok=True)
if distributed:
model._layers.save_pretrained(output_dir)
else:
model.save_pretrained(output_dir)
tokenizer.save_pretrained(output_dir)
paddle.save(args,
os.path.join(output_dir, "training_args.bin"))
logger.info("Saving model checkpoint to {}".format(
output_dir))
logger.info("[epoch {}/{}][iter: {}/{}] results: {}".format(
epoch_id, args.num_train_epochs, step,
len(train_dataloader), results))
if best_metrics is not None:
logger.info("best metrics: {}".format(best_metrics))
reader_start = time.time()
if rank == 0:
# Save model checkpoint
output_dir = os.path.join(args.output_dir, "latest_model")
os.makedirs(output_dir, exist_ok=True)
if distributed:
model._layers.save_pretrained(output_dir)
else:
model.save_pretrained(output_dir)
tokenizer.save_pretrained(output_dir)
paddle.save(args, os.path.join(output_dir, "training_args.bin"))
logger.info("Saving model checkpoint to {}".format(output_dir))
return global_step, tr_loss / global_step
if __name__ == "__main__":
args = parse_args()
train(args)
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import argparse
import cv2
import random
import numpy as np
import imghdr
from copy import deepcopy
import paddle
from PIL import Image, ImageDraw, ImageFont
def set_seed(seed):
random.seed(seed)
np.random.seed(seed)
paddle.seed(seed)
def get_bio_label_maps(label_map_path):
with open(label_map_path, "r", encoding='utf-8') as fin:
lines = fin.readlines()
lines = [line.strip() for line in lines]
if "O" not in lines:
lines.insert(0, "O")
labels = []
for line in lines:
if line == "O":
labels.append("O")
else:
labels.append("B-" + line)
labels.append("I-" + line)
label2id_map = {label: idx for idx, label in enumerate(labels)}
id2label_map = {idx: label for idx, label in enumerate(labels)}
return label2id_map, id2label_map
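# Example: a label file containing the lines "QUESTION" and "ANSWER" yields the
# label set ["O", "B-QUESTION", "I-QUESTION", "B-ANSWER", "I-ANSWER"] and the
# corresponding label2id / id2label dictionaries (ids 0..4).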
def get_image_file_list(img_file):
imgs_lists = []
if img_file is None or not os.path.exists(img_file):
raise Exception("not found any img file in {}".format(img_file))
img_end = {'jpg', 'bmp', 'png', 'jpeg', 'rgb', 'tif', 'tiff', 'gif', 'GIF'}
if os.path.isfile(img_file) and imghdr.what(img_file) in img_end:
imgs_lists.append(img_file)
elif os.path.isdir(img_file):
for single_file in os.listdir(img_file):
file_path = os.path.join(img_file, single_file)
if os.path.isfile(file_path) and imghdr.what(file_path) in img_end:
imgs_lists.append(file_path)
if len(imgs_lists) == 0:
raise Exception("not found any img file in {}".format(img_file))
imgs_lists = sorted(imgs_lists)
return imgs_lists
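# Visualization helpers: draw_ser_results overlays each predicted label and its
# text on the image with a per-label color; draw_re_results draws head/tail
# boxes and connects predicted question/answer pairs with a line between centers.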
def draw_ser_results(image,
ocr_results,
font_path="../../doc/fonts/simfang.ttf",
font_size=18):
np.random.seed(2021)
color = (np.random.permutation(range(255)),
np.random.permutation(range(255)),
np.random.permutation(range(255)))
color_map = {
idx: (color[0][idx], color[1][idx], color[2][idx])
for idx in range(1, 255)
}
if isinstance(image, np.ndarray):
image = Image.fromarray(image)
img_new = image.copy()
draw = ImageDraw.Draw(img_new)
font = ImageFont.truetype(font_path, font_size, encoding="utf-8")
for ocr_info in ocr_results:
if ocr_info["pred_id"] not in color_map:
continue
color = color_map[ocr_info["pred_id"]]
text = "{}: {}".format(ocr_info["pred"], ocr_info["text"])
draw_box_txt(ocr_info["bbox"], text, draw, font, font_size, color)
img_new = Image.blend(image, img_new, 0.5)
return np.array(img_new)
def draw_box_txt(bbox, text, draw, font, font_size, color):
# draw ocr results outline
bbox = ((bbox[0], bbox[1]), (bbox[2], bbox[3]))
draw.rectangle(bbox, fill=color)
# draw ocr results
start_y = max(0, bbox[0][1] - font_size)
tw = font.getsize(text)[0]
draw.rectangle(
[(bbox[0][0] + 1, start_y), (bbox[0][0] + tw + 1, start_y + font_size)],
fill=(0, 0, 255))
draw.text((bbox[0][0] + 1, start_y), text, fill=(255, 255, 255), font=font)
def draw_re_results(image,
result,
font_path="../../doc/fonts/simfang.ttf",
font_size=18):
np.random.seed(0)
if isinstance(image, np.ndarray):
image = Image.fromarray(image)
img_new = image.copy()
draw = ImageDraw.Draw(img_new)
font = ImageFont.truetype(font_path, font_size, encoding="utf-8")
color_head = (0, 0, 255)
color_tail = (255, 0, 0)
color_line = (0, 255, 0)
for ocr_info_head, ocr_info_tail in result:
draw_box_txt(ocr_info_head["bbox"], ocr_info_head["text"], draw, font,
font_size, color_head)
draw_box_txt(ocr_info_tail["bbox"], ocr_info_tail["text"], draw, font,
font_size, color_tail)
center_head = (
(ocr_info_head['bbox'][0] + ocr_info_head['bbox'][2]) // 2,
(ocr_info_head['bbox'][1] + ocr_info_head['bbox'][3]) // 2)
center_tail = (
(ocr_info_tail['bbox'][0] + ocr_info_tail['bbox'][2]) // 2,
(ocr_info_tail['bbox'][1] + ocr_info_tail['bbox'][3]) // 2)
draw.line([center_head, center_tail], fill=color_line, width=5)
img_new = Image.blend(image, img_new, 0.5)
return np.array(img_new)
# pad sentences
def pad_sentences(tokenizer,
encoded_inputs,
max_seq_len=512,
pad_to_max_seq_len=True,
return_attention_mask=True,
return_token_type_ids=True,
return_overflowing_tokens=False,
return_special_tokens_mask=False):
    # pad the sequence length up to the next multiple of max_seq_len so that it
    # can later be reshaped into pages of max_seq_len tokens
    max_seq_len = (
        len(encoded_inputs["input_ids"]) // max_seq_len + 1) * max_seq_len
needs_to_be_padded = pad_to_max_seq_len and \
max_seq_len and len(encoded_inputs["input_ids"]) < max_seq_len
if needs_to_be_padded:
difference = max_seq_len - len(encoded_inputs["input_ids"])
if tokenizer.padding_side == 'right':
if return_attention_mask:
encoded_inputs["attention_mask"] = [1] * len(encoded_inputs[
"input_ids"]) + [0] * difference
if return_token_type_ids:
encoded_inputs["token_type_ids"] = (
encoded_inputs["token_type_ids"] +
[tokenizer.pad_token_type_id] * difference)
if return_special_tokens_mask:
encoded_inputs["special_tokens_mask"] = encoded_inputs[
"special_tokens_mask"] + [1] * difference
encoded_inputs["input_ids"] = encoded_inputs[
"input_ids"] + [tokenizer.pad_token_id] * difference
encoded_inputs["bbox"] = encoded_inputs["bbox"] + [[0, 0, 0, 0]
] * difference
else:
if return_attention_mask:
encoded_inputs["attention_mask"] = [1] * len(encoded_inputs[
"input_ids"])
return encoded_inputs
def split_page(encoded_inputs, max_seq_len=512):
"""
truncate is often used in training process
"""
for key in encoded_inputs:
if key == 'entities':
encoded_inputs[key] = [encoded_inputs[key]]
continue
encoded_inputs[key] = paddle.to_tensor(encoded_inputs[key])
if encoded_inputs[key].ndim <= 1: # for input_ids, att_mask and so on
encoded_inputs[key] = encoded_inputs[key].reshape([-1, max_seq_len])
else: # for bbox
encoded_inputs[key] = encoded_inputs[key].reshape(
[-1, max_seq_len, 4])
return encoded_inputs
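# Convert raw OCR results into model inputs: boxes are normalized to a 0-1000
# scale, each text segment is tokenized and its end offset recorded, the
# sequence is padded to a multiple of max_seq_len and split into page-sized
# chunks, and the resized image is replicated once per page.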
def preprocess(
tokenizer,
ori_img,
ocr_info,
img_size=(224, 224),
pad_token_label_id=-100,
max_seq_len=512,
add_special_ids=False,
return_attention_mask=True, ):
ocr_info = deepcopy(ocr_info)
height = ori_img.shape[0]
width = ori_img.shape[1]
img = cv2.resize(ori_img, img_size).transpose([2, 0, 1]).astype(np.float32)
segment_offset_id = []
words_list = []
bbox_list = []
input_ids_list = []
token_type_ids_list = []
entities = []
for info in ocr_info:
# x1, y1, x2, y2
bbox = info["bbox"]
bbox[0] = int(bbox[0] * 1000.0 / width)
bbox[2] = int(bbox[2] * 1000.0 / width)
bbox[1] = int(bbox[1] * 1000.0 / height)
bbox[3] = int(bbox[3] * 1000.0 / height)
text = info["text"]
encode_res = tokenizer.encode(
text, pad_to_max_seq_len=False, return_attention_mask=True)
if not add_special_ids:
# TODO: use tok.all_special_ids to remove
encode_res["input_ids"] = encode_res["input_ids"][1:-1]
encode_res["token_type_ids"] = encode_res["token_type_ids"][1:-1]
encode_res["attention_mask"] = encode_res["attention_mask"][1:-1]
# for re
entities.append({
"start": len(input_ids_list),
"end": len(input_ids_list) + len(encode_res["input_ids"]),
"label": "O",
})
input_ids_list.extend(encode_res["input_ids"])
token_type_ids_list.extend(encode_res["token_type_ids"])
bbox_list.extend([bbox] * len(encode_res["input_ids"]))
words_list.append(text)
segment_offset_id.append(len(input_ids_list))
encoded_inputs = {
"input_ids": input_ids_list,
"token_type_ids": token_type_ids_list,
"bbox": bbox_list,
"attention_mask": [1] * len(input_ids_list),
"entities": entities
}
encoded_inputs = pad_sentences(
tokenizer,
encoded_inputs,
max_seq_len=max_seq_len,
return_attention_mask=return_attention_mask)
encoded_inputs = split_page(encoded_inputs)
fake_bs = encoded_inputs["input_ids"].shape[0]
encoded_inputs["image"] = paddle.to_tensor(img).unsqueeze(0).expand(
[fake_bs] + list(img.shape))
encoded_inputs["segment_offset_id"] = segment_offset_id
return encoded_inputs
def postprocess(attention_mask, preds, id2label_map):
if isinstance(preds, paddle.Tensor):
preds = preds.numpy()
preds = np.argmax(preds, axis=2)
preds_list = [[] for _ in range(preds.shape[0])]
# keep batch info
for i in range(preds.shape[0]):
for j in range(preds.shape[1]):
if attention_mask[i][j] == 1:
preds_list[i].append(id2label_map[preds[i][j]])
return preds_list
def merge_preds_list_with_ocr_info(ocr_info, segment_offset_id, preds_list,
label2id_map_for_draw):
# must ensure the preds_list is generated from the same image
preds = [p for pred in preds_list for p in pred]
id2label_map = dict()
for key in label2id_map_for_draw:
val = label2id_map_for_draw[key]
if key == "O":
id2label_map[val] = key
if key.startswith("B-") or key.startswith("I-"):
id2label_map[val] = key[2:]
else:
id2label_map[val] = key
for idx in range(len(segment_offset_id)):
if idx == 0:
start_id = 0
else:
start_id = segment_offset_id[idx - 1]
end_id = segment_offset_id[idx]
curr_pred = preds[start_id:end_id]
curr_pred = [label2id_map_for_draw[p] for p in curr_pred]
if len(curr_pred) <= 0:
pred_id = 0
else:
counts = np.bincount(curr_pred)
pred_id = np.argmax(counts)
ocr_info[idx]["pred_id"] = int(pred_id)
ocr_info[idx]["pred"] = id2label_map[int(pred_id)]
return ocr_info
def print_arguments(args, logger=None):
    """print arguments"""
    print_func = logger.info if logger is not None else print
print_func('----------- Configuration Arguments -----------')
for arg, value in sorted(vars(args).items()):
print_func('%s: %s' % (arg, value))
print_func('------------------------------------------------')
def parse_args():
parser = argparse.ArgumentParser()
# Required parameters
# yapf: disable
parser.add_argument("--model_name_or_path",
default=None, type=str, required=True,)
parser.add_argument("--ser_model_type",
default='LayoutXLM', type=str)
parser.add_argument("--re_model_name_or_path",
default=None, type=str, required=False,)
parser.add_argument("--train_data_dir", default=None,
type=str, required=False,)
parser.add_argument("--train_label_path", default=None,
type=str, required=False,)
parser.add_argument("--eval_data_dir", default=None,
type=str, required=False,)
parser.add_argument("--eval_label_path", default=None,
type=str, required=False,)
parser.add_argument("--output_dir", default=None, type=str, required=True,)
parser.add_argument("--max_seq_length", default=512, type=int,)
parser.add_argument("--evaluate_during_training", action="store_true",)
parser.add_argument("--num_workers", default=8, type=int,)
parser.add_argument("--per_gpu_train_batch_size", default=8,
type=int, help="Batch size per GPU/CPU for training.",)
parser.add_argument("--per_gpu_eval_batch_size", default=8,
type=int, help="Batch size per GPU/CPU for eval.",)
parser.add_argument("--learning_rate", default=5e-5,
type=float, help="The initial learning rate for Adam.",)
parser.add_argument("--weight_decay", default=0.0,
type=float, help="Weight decay if we apply some.",)
parser.add_argument("--adam_epsilon", default=1e-8,
type=float, help="Epsilon for Adam optimizer.",)
parser.add_argument("--max_grad_norm", default=1.0,
type=float, help="Max gradient norm.",)
parser.add_argument("--num_train_epochs", default=3, type=int,
help="Total number of training epochs to perform.",)
parser.add_argument("--warmup_steps", default=0, type=int,
help="Linear warmup over warmup_steps.",)
parser.add_argument("--eval_steps", type=int, default=10,
help="eval every X updates steps.",)
parser.add_argument("--seed", type=int, default=2048,
help="random seed for initialization",)
parser.add_argument("--rec_model_dir", default=None, type=str, )
parser.add_argument("--det_model_dir", default=None, type=str, )
parser.add_argument(
"--label_map_path", default="./labels/labels_ser.txt", type=str, required=False, )
parser.add_argument("--infer_imgs", default=None, type=str, required=False)
parser.add_argument("--resume", action='store_true')
parser.add_argument("--ocr_json_path", default=None,
type=str, required=False, help="ocr prediction results")
# yapf: enable
args = parser.parse_args()
return args
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
import os
import cv2
import numpy as np
import paddle
import copy
from paddle.io import Dataset
__all__ = ["XFUNDataset"]
class XFUNDataset(Dataset):
"""
Example:
print("=====begin to build dataset=====")
from paddlenlp.transformers import LayoutXLMTokenizer
tokenizer = LayoutXLMTokenizer.from_pretrained("/paddle/models/transformers/layoutxlm-base-paddle/")
tok_res = tokenizer.tokenize("Maribyrnong")
# res = tokenizer.convert_ids_to_tokens(val_data["input_ids"][0])
        dataset = XFUNDataset(
tokenizer,
data_dir="./zh.val/",
label_path="zh.val/xfun_normalize_val.json",
img_size=(224,224))
print(len(dataset))
data = dataset[0]
print(data.keys())
print("input_ids: ", data["input_ids"])
print("labels: ", data["labels"])
print("token_type_ids: ", data["token_type_ids"])
print("words_list: ", data["words_list"])
print("image shape: ", data["image"].shape)
"""
def __init__(self,
tokenizer,
data_dir,
label_path,
contains_re=False,
label2id_map=None,
img_size=(224, 224),
pad_token_label_id=None,
add_special_ids=False,
return_attention_mask=True,
load_mode='all',
max_seq_len=512):
super().__init__()
self.tokenizer = tokenizer
self.data_dir = data_dir
self.label_path = label_path
self.contains_re = contains_re
self.label2id_map = label2id_map
self.img_size = img_size
self.pad_token_label_id = pad_token_label_id
self.add_special_ids = add_special_ids
self.return_attention_mask = return_attention_mask
self.load_mode = load_mode
self.max_seq_len = max_seq_len
if self.pad_token_label_id is None:
self.pad_token_label_id = paddle.nn.CrossEntropyLoss().ignore_index
self.all_lines = self.read_all_lines()
self.entities_labels = {'HEADER': 0, 'QUESTION': 1, 'ANSWER': 2}
self.return_keys = {
'bbox': {
'type': 'np',
'dtype': 'int64'
},
'input_ids': {
'type': 'np',
'dtype': 'int64'
},
'labels': {
'type': 'np',
'dtype': 'int64'
},
'attention_mask': {
'type': 'np',
'dtype': 'int64'
},
'image': {
'type': 'np',
'dtype': 'float32'
},
'token_type_ids': {
'type': 'np',
'dtype': 'int64'
},
'entities': {
'type': 'dict'
},
'relations': {
'type': 'dict'
}
}
if load_mode == "all":
self.encoded_inputs_all = self._parse_label_file_all()
def pad_sentences(self,
encoded_inputs,
max_seq_len=512,
pad_to_max_seq_len=True,
return_attention_mask=True,
return_token_type_ids=True,
truncation_strategy="longest_first",
return_overflowing_tokens=False,
return_special_tokens_mask=False):
# Padding
needs_to_be_padded = pad_to_max_seq_len and \
max_seq_len and len(encoded_inputs["input_ids"]) < max_seq_len
if needs_to_be_padded:
difference = max_seq_len - len(encoded_inputs["input_ids"])
if self.tokenizer.padding_side == 'right':
if return_attention_mask:
encoded_inputs["attention_mask"] = [1] * len(encoded_inputs[
"input_ids"]) + [0] * difference
if return_token_type_ids:
encoded_inputs["token_type_ids"] = (
encoded_inputs["token_type_ids"] +
[self.tokenizer.pad_token_type_id] * difference)
if return_special_tokens_mask:
encoded_inputs["special_tokens_mask"] = encoded_inputs[
"special_tokens_mask"] + [1] * difference
encoded_inputs["input_ids"] = encoded_inputs[
"input_ids"] + [self.tokenizer.pad_token_id] * difference
encoded_inputs["labels"] = encoded_inputs[
"labels"] + [self.pad_token_label_id] * difference
encoded_inputs["bbox"] = encoded_inputs[
"bbox"] + [[0, 0, 0, 0]] * difference
elif self.tokenizer.padding_side == 'left':
if return_attention_mask:
encoded_inputs["attention_mask"] = [0] * difference + [
1
] * len(encoded_inputs["input_ids"])
if return_token_type_ids:
encoded_inputs["token_type_ids"] = (
[self.tokenizer.pad_token_type_id] * difference +
encoded_inputs["token_type_ids"])
if return_special_tokens_mask:
encoded_inputs["special_tokens_mask"] = [
1
] * difference + encoded_inputs["special_tokens_mask"]
encoded_inputs["input_ids"] = [
self.tokenizer.pad_token_id
] * difference + encoded_inputs["input_ids"]
encoded_inputs["labels"] = [
self.pad_token_label_id
] * difference + encoded_inputs["labels"]
encoded_inputs["bbox"] = [
[0, 0, 0, 0]
] * difference + encoded_inputs["bbox"]
else:
if return_attention_mask:
encoded_inputs["attention_mask"] = [1] * len(encoded_inputs[
"input_ids"])
return encoded_inputs
def truncate_inputs(self, encoded_inputs, max_seq_len=512):
for key in encoded_inputs:
if key == "sample_id":
continue
length = min(len(encoded_inputs[key]), max_seq_len)
encoded_inputs[key] = encoded_inputs[key][:length]
return encoded_inputs
def read_all_lines(self, ):
with open(self.label_path, "r", encoding='utf-8') as fin:
lines = fin.readlines()
return lines
def _parse_label_file_all(self):
"""
parse all samples
"""
encoded_inputs_all = []
for line in self.all_lines:
encoded_inputs_all.extend(self._parse_label_file(line))
return encoded_inputs_all
def _parse_label_file(self, line):
"""
parse single sample
"""
image_name, info_str = line.split("\t")
image_path = os.path.join(self.data_dir, image_name)
        def add_image_path(x):
x['image_path'] = image_path
return x
encoded_inputs = self._read_encoded_inputs_sample(info_str)
if self.contains_re:
encoded_inputs = self._chunk_re(encoded_inputs)
else:
encoded_inputs = self._chunk_ser(encoded_inputs)
        encoded_inputs = list(map(add_image_path, encoded_inputs))
return encoded_inputs
def _read_encoded_inputs_sample(self, info_str):
"""
parse label info
"""
# read text info
info_dict = json.loads(info_str)
height = info_dict["height"]
width = info_dict["width"]
words_list = []
bbox_list = []
input_ids_list = []
token_type_ids_list = []
gt_label_list = []
if self.contains_re:
# for re
entities = []
relations = []
id2label = {}
entity_id_to_index_map = {}
empty_entity = set()
for info in info_dict["ocr_info"]:
if self.contains_re:
# for re
if len(info["text"]) == 0:
empty_entity.add(info["id"])
continue
id2label[info["id"]] = info["label"]
relations.extend([tuple(sorted(l)) for l in info["linking"]])
# x1, y1, x2, y2
bbox = info["bbox"]
label = info["label"]
bbox[0] = int(bbox[0] * 1000.0 / width)
bbox[2] = int(bbox[2] * 1000.0 / width)
bbox[1] = int(bbox[1] * 1000.0 / height)
bbox[3] = int(bbox[3] * 1000.0 / height)
text = info["text"]
encode_res = self.tokenizer.encode(
text, pad_to_max_seq_len=False, return_attention_mask=True)
gt_label = []
if not self.add_special_ids:
# TODO: use tok.all_special_ids to remove
encode_res["input_ids"] = encode_res["input_ids"][1:-1]
encode_res["token_type_ids"] = encode_res["token_type_ids"][1:
-1]
encode_res["attention_mask"] = encode_res["attention_mask"][1:
-1]
if label.lower() == "other":
gt_label.extend([0] * len(encode_res["input_ids"]))
else:
gt_label.append(self.label2id_map[("b-" + label).upper()])
gt_label.extend([self.label2id_map[("i-" + label).upper()]] *
(len(encode_res["input_ids"]) - 1))
if self.contains_re:
if gt_label[0] != self.label2id_map["O"]:
entity_id_to_index_map[info["id"]] = len(entities)
entities.append({
"start": len(input_ids_list),
"end":
len(input_ids_list) + len(encode_res["input_ids"]),
"label": label.upper(),
})
input_ids_list.extend(encode_res["input_ids"])
token_type_ids_list.extend(encode_res["token_type_ids"])
bbox_list.extend([bbox] * len(encode_res["input_ids"]))
gt_label_list.extend(gt_label)
words_list.append(text)
encoded_inputs = {
"input_ids": input_ids_list,
"labels": gt_label_list,
"token_type_ids": token_type_ids_list,
"bbox": bbox_list,
"attention_mask": [1] * len(input_ids_list),
# "words_list": words_list,
}
encoded_inputs = self.pad_sentences(
encoded_inputs,
max_seq_len=self.max_seq_len,
return_attention_mask=self.return_attention_mask)
encoded_inputs = self.truncate_inputs(encoded_inputs)
if self.contains_re:
relations = self._relations(entities, relations, id2label,
empty_entity, entity_id_to_index_map)
encoded_inputs['relations'] = relations
encoded_inputs['entities'] = entities
return encoded_inputs
def _chunk_ser(self, encoded_inputs):
encoded_inputs_all = []
seq_len = len(encoded_inputs['input_ids'])
chunk_size = 512
for chunk_id, index in enumerate(range(0, seq_len, chunk_size)):
chunk_beg = index
chunk_end = min(index + chunk_size, seq_len)
encoded_inputs_example = {}
for key in encoded_inputs:
encoded_inputs_example[key] = encoded_inputs[key][chunk_beg:
chunk_end]
encoded_inputs_all.append(encoded_inputs_example)
return encoded_inputs_all
def _chunk_re(self, encoded_inputs):
# prepare data
entities = encoded_inputs.pop('entities')
relations = encoded_inputs.pop('relations')
encoded_inputs_all = []
chunk_size = 512
for chunk_id, index in enumerate(
range(0, len(encoded_inputs["input_ids"]), chunk_size)):
item = {}
for k in encoded_inputs:
item[k] = encoded_inputs[k][index:index + chunk_size]
# select entity in current chunk
entities_in_this_span = []
global_to_local_map = {} #
for entity_id, entity in enumerate(entities):
if (index <= entity["start"] < index + chunk_size and
index <= entity["end"] < index + chunk_size):
entity["start"] = entity["start"] - index
entity["end"] = entity["end"] - index
global_to_local_map[entity_id] = len(entities_in_this_span)
entities_in_this_span.append(entity)
# select relations in current chunk
relations_in_this_span = []
for relation in relations:
if (index <= relation["start_index"] < index + chunk_size and
index <= relation["end_index"] < index + chunk_size):
relations_in_this_span.append({
"head": global_to_local_map[relation["head"]],
"tail": global_to_local_map[relation["tail"]],
"start_index": relation["start_index"] - index,
"end_index": relation["end_index"] - index,
})
item.update({
"entities": reformat(entities_in_this_span),
"relations": reformat(relations_in_this_span),
})
item['entities']['label'] = [
self.entities_labels[x] for x in item['entities']['label']
]
encoded_inputs_all.append(item)
return encoded_inputs_all
def _relations(self, entities, relations, id2label, empty_entity,
entity_id_to_index_map):
"""
build relations
"""
relations = list(set(relations))
relations = [
rel for rel in relations
if rel[0] not in empty_entity and rel[1] not in empty_entity
]
kv_relations = []
for rel in relations:
pair = [id2label[rel[0]], id2label[rel[1]]]
if pair == ["question", "answer"]:
kv_relations.append({
"head": entity_id_to_index_map[rel[0]],
"tail": entity_id_to_index_map[rel[1]]
})
elif pair == ["answer", "question"]:
kv_relations.append({
"head": entity_id_to_index_map[rel[1]],
"tail": entity_id_to_index_map[rel[0]]
})
else:
continue
relations = sorted(
[{
"head": rel["head"],
"tail": rel["tail"],
"start_index": get_relation_span(rel, entities)[0],
"end_index": get_relation_span(rel, entities)[1],
} for rel in kv_relations],
key=lambda x: x["head"], )
return relations
def load_img(self, image_path):
# read img
img = cv2.imread(image_path)
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
resize_h, resize_w = self.img_size
im_shape = img.shape[0:2]
im_scale_y = resize_h / im_shape[0]
im_scale_x = resize_w / im_shape[1]
img_new = cv2.resize(
img, None, None, fx=im_scale_x, fy=im_scale_y, interpolation=2)
mean = np.array([0.485, 0.456, 0.406])[np.newaxis, np.newaxis, :]
std = np.array([0.229, 0.224, 0.225])[np.newaxis, np.newaxis, :]
img_new = img_new / 255.0
img_new -= mean
img_new /= std
img = img_new.transpose((2, 0, 1))
return img
def __getitem__(self, idx):
if self.load_mode == "all":
data = copy.deepcopy(self.encoded_inputs_all[idx])
else:
data = self._parse_label_file(self.all_lines[idx])[0]
image_path = data.pop('image_path')
data["image"] = self.load_img(image_path)
return_data = {}
for k, v in data.items():
if k in self.return_keys:
if self.return_keys[k]['type'] == 'np':
v = np.array(v, dtype=self.return_keys[k]['dtype'])
return_data[k] = v
return return_data
def __len__(self, ):
if self.load_mode == "all":
return len(self.encoded_inputs_all)
else:
return len(self.all_lines)
def get_relation_span(rel, entities):
bound = []
for entity_index in [rel["head"], rel["tail"]]:
bound.append(entities[entity_index]["start"])
bound.append(entities[entity_index]["end"])
return min(bound), max(bound)
def reformat(data):
new_data = {}
for item in data:
for k, v in item.items():
if k not in new_data:
new_data[k] = []
new_data[k].append(v)
return new_data
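# Example: reformat([{"start": 0, "end": 2}, {"start": 3, "end": 5}])
# returns {"start": [0, 3], "end": [2, 5]}.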
...@@ -13,4 +13,3 @@ lxml
premailer
openpyxl
fasttext==0.9.1
...@@ -239,8 +239,7 @@ fi ...@@ -239,8 +239,7 @@ fi
if [ ${MODE} = "klquant_whole_infer" ]; then if [ ${MODE} = "klquant_whole_infer" ]; then
wget -nc -P ./train_data/ https://paddleocr.bj.bcebos.com/dygraph_v2.0/test/icdar2015_lite.tar --no-check-certificate wget -nc -P ./train_data/ https://paddleocr.bj.bcebos.com/dygraph_v2.0/test/icdar2015_lite.tar --no-check-certificate
cd ./train_data/ && tar xf icdar2015_lite.tar cd ./train_data/ && tar xf icdar2015_lite.tar && rm -rf ./icdar2015 && ln -s ./icdar2015_lite ./icdar2015 && cd ../
ln -s ./icdar2015_lite ./icdar2015 && cd ../
if [ ${model_name} = "ch_ppocr_mobile_v2.0_det_KL" ]; then if [ ${model_name} = "ch_ppocr_mobile_v2.0_det_KL" ]; then
wget -nc -P ./inference https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_mobile_v2.0_det_infer.tar --no-check-certificate wget -nc -P ./inference https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_mobile_v2.0_det_infer.tar --no-check-certificate
wget -nc -P ./inference https://paddleocr.bj.bcebos.com/dygraph_v2.0/test/ch_det_data_50.tar --no-check-certificate wget -nc -P ./inference https://paddleocr.bj.bcebos.com/dygraph_v2.0/test/ch_det_data_50.tar --no-check-certificate
...@@ -249,6 +248,8 @@ if [ ${MODE} = "klquant_whole_infer" ]; then ...@@ -249,6 +248,8 @@ if [ ${MODE} = "klquant_whole_infer" ]; then
if [ ${model_name} = "PPOCRv2_ocr_rec_kl" ]; then if [ ${model_name} = "PPOCRv2_ocr_rec_kl" ]; then
wget -nc -P ./inference https://paddleocr.bj.bcebos.com/PP-OCRv2/chinese/ch_PP-OCRv2_rec_infer.tar --no-check-certificate wget -nc -P ./inference https://paddleocr.bj.bcebos.com/PP-OCRv2/chinese/ch_PP-OCRv2_rec_infer.tar --no-check-certificate
wget -nc -P ./inference/ https://paddleocr.bj.bcebos.com/dygraph_v2.0/test/rec_inference.tar --no-check-certificate wget -nc -P ./inference/ https://paddleocr.bj.bcebos.com/dygraph_v2.0/test/rec_inference.tar --no-check-certificate
wget -nc -P ./train_data/ https://paddleocr.bj.bcebos.com/dygraph_v2.0/test/ic15_data.tar --no-check-certificate
cd ./train_data/ && tar xf ic15_data.tar && cd ../
cd ./inference && tar xf rec_inference.tar && tar xf ch_PP-OCRv2_rec_infer.tar && cd ../ cd ./inference && tar xf rec_inference.tar && tar xf ch_PP-OCRv2_rec_infer.tar && cd ../
fi fi
if [ ${model_name} = "PPOCRv2_ocr_det_kl" ]; then if [ ${model_name} = "PPOCRv2_ocr_det_kl" ]; then
......
...@@ -68,14 +68,14 @@ test_tipc/
├── model_linux_gpu_normal_normal_infer_cpp_linux_gpu_cpu.txt   # config file for testing C++ inference on Linux
├── model_linux_gpu_normal_normal_infer_python_jetson.txt       # config file for testing Python inference on Jetson
├── train_linux_gpu_fleet_amp_infer_python_linux_gpu_cpu.txt    # config file for testing multi-machine multi-GPU, mixed-precision training and Python inference on Linux
├── ...
├── ch_ppocr_server_v2.0_det               # directory of test config files for the ch_ppocr_server_v2.0_det model
├── ...
├── ch_ppocr_mobile_v2.0_rec               # directory of test config files for the ch_ppocr_mobile_v2.0_rec model
├── ...
├── ch_ppocr_server_v2.0_det               # directory of test config files for the ch_ppocr_server_v2.0_det model
├── ...
├── ...
├── results/                               # pre-saved prediction results, used for accuracy comparison against the actual predictions
├── python_ppocr_det_mobile_results_fp32.txt   # pre-saved FP32 Python prediction results of the mobile PP-OCR detection model
├── python_ppocr_det_mobile_results_fp16.txt   # pre-saved FP16 Python prediction results of the mobile PP-OCR detection model
...@@ -119,7 +119,7 @@ bash test_tipc/test_train_inference_python.sh configs/[model_name]/[params_file_
bash test_tipc/prepare.sh ./test_tipc/configs/ch_ppocr_mobile_v2.0_det/train_infer_python.txt 'lite_train_lite_infer'
# run the test
bash test_tipc/test_train_inference_python.sh ./test_tipc/configs/ch_ppocr_mobile_v2.0_det/train_infer_python.txt 'lite_train_lite_infer'
```
More information about this example command is available in the [basic training and inference usage doc](https://github.com/PaddlePaddle/PaddleOCR/blob/dygraph/test_tipc/docs/test_train_inference_python.md#22-%E5%8A%9F%E8%83%BD%E6%B5%8B%E8%AF%95)
### Config file naming convention
...@@ -136,9 +136,9 @@ bash test_tipc/test_train_inference_python.sh ./test_tipc/configs/ch_ppocr_mobil
<a name="more"></a>
## 4. Start testing
The feature tests cover training-related options such as mixed precision, pruning and quantization, as well as inference-related options such as MKL-DNN and TensorRT; see the links below for more details and tutorials:
- [test_train_inference_python usage](docs/test_train_inference_python.md): tests basic Python-based model training, evaluation and inference, including pruning, quantization and distillation.
- [test_inference_cpp usage](docs/test_inference_cpp.md): tests C++-based model inference.
- [test_serving usage](docs/test_serving.md): tests service deployment based on Paddle Serving.
- [test_lite_arm_cpp usage](docs/test_lite_arm_cpp.md): tests C++ inference deployment on ARM CPU based on Paddle-Lite.
- [test_paddle2onnx usage](docs/test_paddle2onnx.md): tests Paddle2ONNX model conversion and verifies correctness.
import numpy as np
import os
import sys
import platform
import yaml
import time
import shutil
import paddle
import paddle.distributed as dist
from tqdm import tqdm
from argparse import ArgumentParser, RawDescriptionHelpFormatter
from utils import get_logger, print_dict
class ArgsParser(ArgumentParser):
def __init__(self):
super(ArgsParser, self).__init__(
formatter_class=RawDescriptionHelpFormatter)
self.add_argument("-c", "--config", help="configuration file to use")
self.add_argument(
"-o", "--opt", nargs='+', help="set configuration options")
self.add_argument(
'-p',
'--profiler_options',
type=str,
default=None,
help='The option of profiler, which should be in format \"key1=value1;key2=value2;key3=value3\".'
)
def parse_args(self, argv=None):
args = super(ArgsParser, self).parse_args(argv)
assert args.config is not None, \
"Please specify --config=configure_file_path."
args.opt = self._parse_opt(args.opt)
return args
def _parse_opt(self, opts):
config = {}
if not opts:
return config
for s in opts:
s = s.strip()
k, v = s.split('=')
config[k] = yaml.load(v, Loader=yaml.Loader)
return config
class AttrDict(dict):
"""Single level attribute dict, NOT recursive"""
def __init__(self, **kwargs):
super(AttrDict, self).__init__()
super(AttrDict, self).update(kwargs)
def __getattr__(self, key):
if key in self:
return self[key]
raise AttributeError("object has no attribute '{}'".format(key))
global_config = AttrDict()
default_config = {'Global': {'debug': False, }}
def load_config(file_path):
"""
Load config from yml/yaml file.
Args:
file_path (str): Path of the config file to be loaded.
Returns: global config
"""
merge_config(default_config)
_, ext = os.path.splitext(file_path)
assert ext in ['.yml', '.yaml'], "only support yaml files for now"
merge_config(yaml.load(open(file_path, 'rb'), Loader=yaml.Loader))
return global_config
def merge_config(config):
"""
Merge config into global config.
Args:
config (dict): Config to be merged.
Returns: global config
"""
for key, value in config.items():
if "." not in key:
if isinstance(value, dict) and key in global_config:
global_config[key].update(value)
else:
global_config[key] = value
else:
sub_keys = key.split('.')
assert (
sub_keys[0] in global_config
), "the sub_keys can only be one of global_config: {}, but get: {}, please check your running command".format(
global_config.keys(), sub_keys[0])
cur = global_config[sub_keys[0]]
for idx, sub_key in enumerate(sub_keys[1:]):
if idx == len(sub_keys) - 2:
cur[sub_key] = value
else:
cur = cur[sub_key]
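# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the original file): how "-o" overrides are
# merged. A plain top-level key updates a whole section, while a dotted key
# such as "Global.epoch_num=10" walks into the nested dict. The section and
# key names below are hypothetical.
def _merge_config_demo():
    merge_config({"Global": {"epoch_num": 500, "use_gpu": True}})
    merge_config({"Global.epoch_num": 10})  # dotted key: nested override
    assert global_config["Global"]["epoch_num"] == 10
    assert global_config["Global"]["use_gpu"] is True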
def preprocess(is_train=False):
FLAGS = ArgsParser().parse_args()
profiler_options = FLAGS.profiler_options
config = load_config(FLAGS.config)
merge_config(FLAGS.opt)
profile_dic = {"profiler_options": FLAGS.profiler_options}
merge_config(profile_dic)
if is_train:
# save_config
save_model_dir = config['save_model_dir']
os.makedirs(save_model_dir, exist_ok=True)
with open(os.path.join(save_model_dir, 'config.yml'), 'w') as f:
yaml.dump(
dict(config), f, default_flow_style=False, sort_keys=False)
log_file = '{}/train.log'.format(save_model_dir)
else:
log_file = None
logger = get_logger(name='root', log_file=log_file)
# check if set use_gpu=True in paddlepaddle cpu version
use_gpu = config['use_gpu']
print_dict(config, logger)
return config, logger
if __name__ == "__main__":
config, logger = preprocess(is_train=False)
# print(config)
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// reference from : https://github.com/PaddlePaddle/Paddle-Inference-Demo/blob/master/python/custom-operator/custom_relu_op.cc
#include <iostream>
#include <vector>
#include "paddle/extension.h"
template <typename data_t>
void relu_cpu_forward_kernel(const data_t* x_data,
data_t* out_data,
int64_t x_numel) {
for (int i = 0; i < x_numel; ++i) {
out_data[i] = std::max(static_cast<data_t>(0.), x_data[i]);
}
}
template <typename data_t>
void relu_cpu_backward_kernel(const data_t* grad_out_data,
const data_t* out_data,
data_t* grad_x_data,
int64_t out_numel) {
for (int i = 0; i < out_numel; ++i) {
grad_x_data[i] =
grad_out_data[i] * (out_data[i] > static_cast<data_t>(0) ? 1. : 0.);
}
}
std::vector<paddle::Tensor> relu_cpu_forward(const paddle::Tensor& x) {
auto out = paddle::Tensor(paddle::PlaceType::kCPU);
out.reshape(x.shape());
PD_DISPATCH_FLOATING_TYPES(
x.type(), "relu_cpu_forward", ([&] {
relu_cpu_forward_kernel<data_t>(
x.data<data_t>(), out.mutable_data<data_t>(x.place()), x.size());
}));
return {out};
}
std::vector<paddle::Tensor> relu_cpu_backward(const paddle::Tensor& x,
const paddle::Tensor& out,
const paddle::Tensor& grad_out) {
auto grad_x = paddle::Tensor(paddle::PlaceType::kCPU);
grad_x.reshape(x.shape());
PD_DISPATCH_FLOATING_TYPES(out.type(), "relu_cpu_backward", ([&] {
relu_cpu_backward_kernel<data_t>(
grad_out.data<data_t>(),
out.data<data_t>(),
grad_x.mutable_data<data_t>(x.place()),
out.size());
}));
return {grad_x};
}
std::vector<paddle::Tensor> relu_cuda_forward(const paddle::Tensor& x);
std::vector<paddle::Tensor> relu_cuda_backward(const paddle::Tensor& x,
const paddle::Tensor& out,
const paddle::Tensor& grad_out);
std::vector<paddle::Tensor> ReluForward(const paddle::Tensor& x) {
// TODO(chenweihang): Check Input
if (x.place() == paddle::PlaceType::kCPU) {
return relu_cpu_forward(x);
} else if (x.place() == paddle::PlaceType::kGPU) {
return relu_cuda_forward(x);
} else {
throw std::runtime_error("Not implemented.");
}
}
std::vector<paddle::Tensor> ReluBackward(const paddle::Tensor& x,
const paddle::Tensor& out,
const paddle::Tensor& grad_out) {
// TODO(chenweihang): Check Input
if (x.place() == paddle::PlaceType::kCPU) {
return relu_cpu_backward(x, out, grad_out);
} else if (x.place() == paddle::PlaceType::kGPU) {
return relu_cuda_backward(x, out, grad_out);
} else {
throw std::runtime_error("Not implemented.");
}
}
PD_BUILD_OP(custom_relu)
.Inputs({"X"})
.Outputs({"Out"})
.SetKernelFn(PD_KERNEL(ReluForward));
PD_BUILD_GRAD_OP(custom_relu)
.Inputs({"X", "Out", paddle::Grad("Out")})
.Outputs({paddle::Grad("X")})
.SetKernelFn(PD_KERNEL(ReluBackward));
\ No newline at end of file
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// reference https://github.com/PaddlePaddle/Paddle-Inference-Demo/blob/master/python/custom-operator/custom_relu_op.cu
#include "paddle/extension.h"
template <typename data_t>
__global__ void relu_cuda_forward_kernel(const data_t* x,
data_t* y,
const int num) {
int gid = blockIdx.x * blockDim.x + threadIdx.x;
for (int i = gid; i < num; i += blockDim.x * gridDim.x) {
y[i] = max(x[i], static_cast<data_t>(0.));
}
}
template <typename data_t>
__global__ void relu_cuda_backward_kernel(const data_t* dy,
const data_t* y,
data_t* dx,
const int num) {
int gid = blockIdx.x * blockDim.x + threadIdx.x;
for (int i = gid; i < num; i += blockDim.x * gridDim.x) {
dx[i] = dy[i] * (y[i] > 0 ? 1. : 0.);
}
}
std::vector<paddle::Tensor> relu_cuda_forward(const paddle::Tensor& x) {
auto out = paddle::Tensor(paddle::PlaceType::kGPU);
out.reshape(x.shape());
int numel = x.size();
int block = 512;
int grid = (numel + block - 1) / block;
PD_DISPATCH_FLOATING_TYPES(
x.type(), "relu_cuda_forward_kernel", ([&] {
relu_cuda_forward_kernel<data_t><<<grid, block, 0, x.stream()>>>(
x.data<data_t>(), out.mutable_data<data_t>(x.place()), numel);
}));
return {out};
}
std::vector<paddle::Tensor> relu_cuda_backward(const paddle::Tensor& x,
const paddle::Tensor& out,
const paddle::Tensor& grad_out) {
auto grad_x = paddle::Tensor(paddle::PlaceType::kGPU);
grad_x.reshape(x.shape());
int numel = out.size();
int block = 512;
int grid = (numel + block - 1) / block;
PD_DISPATCH_FLOATING_TYPES(
out.type(), "relu_cuda_backward_kernel", ([&] {
relu_cuda_backward_kernel<data_t><<<grid, block, 0, x.stream()>>>(
grad_out.data<data_t>(),
out.data<data_t>(),
grad_x.mutable_data<data_t>(x.place()),
numel);
}));
return {grad_x};
}
import paddle
import paddle.nn as nn
from paddle.vision.transforms import Compose, Normalize
from paddle.utils.cpp_extension import load
from paddle.inference import Config
from paddle.inference import create_predictor
import numpy as np
EPOCH_NUM = 4
BATCH_SIZE = 64
# jit compile custom op
custom_ops = load(
name="custom_jit_ops", sources=["custom_relu_op.cc", "custom_relu_op.cu"])
class LeNet(nn.Layer):
def __init__(self):
super(LeNet, self).__init__()
self.conv1 = nn.Conv2D(
in_channels=1, out_channels=6, kernel_size=5, stride=1, padding=2)
self.max_pool1 = nn.MaxPool2D(kernel_size=2, stride=2)
self.conv2 = nn.Conv2D(
in_channels=6, out_channels=16, kernel_size=5, stride=1)
self.max_pool2 = nn.MaxPool2D(kernel_size=2, stride=2)
self.linear1 = nn.Linear(in_features=16 * 5 * 5, out_features=120)
self.linear2 = nn.Linear(in_features=120, out_features=84)
self.linear3 = nn.Linear(in_features=84, out_features=10)
def forward(self, x):
x = self.conv1(x)
x = custom_ops.custom_relu(x)
x = self.max_pool1(x)
x = custom_ops.custom_relu(x)
x = self.conv2(x)
x = self.max_pool2(x)
x = paddle.flatten(x, start_axis=1, stop_axis=-1)
x = self.linear1(x)
x = custom_ops.custom_relu(x)
x = self.linear2(x)
x = custom_ops.custom_relu(x)
x = self.linear3(x)
return x
# set device
paddle.set_device("gpu")
# model
net = LeNet()
loss_fn = nn.CrossEntropyLoss()
opt = paddle.optimizer.Adam(learning_rate=0.001, parameters=net.parameters())
# data loader
transform = Compose([Normalize(mean=[127.5], std=[127.5], data_format='CHW')])
train_dataset = paddle.vision.datasets.MNIST(mode='train', transform=transform)
train_loader = paddle.io.DataLoader(
train_dataset,
batch_size=BATCH_SIZE,
shuffle=True,
drop_last=True,
num_workers=2)
# train
for epoch_id in range(EPOCH_NUM):
for batch_id, (image, label) in enumerate(train_loader()):
out = net(image)
loss = loss_fn(out, label)
loss.backward()
if batch_id % 300 == 0:
print("Epoch {} batch {}: loss = {}".format(epoch_id, batch_id,
np.mean(loss.numpy())))
opt.step()
opt.clear_grad()
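# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the original script): the Config /
# create_predictor imports above are only needed if the trained network is
# exported and run with the Paddle Inference API, roughly as below. The export
# path is hypothetical, and the custom relu op must be registered (as done by
# the load() call above) in the process that runs the predictor.
def _export_and_predict_demo(trained_net):
    paddle.jit.save(
        trained_net,
        "./custom_relu_infer_model/inference",
        input_spec=[
            paddle.static.InputSpec(shape=[None, 1, 28, 28], dtype="float32")
        ])
    config = Config("./custom_relu_infer_model/inference.pdmodel",
                    "./custom_relu_infer_model/inference.pdiparams")
    predictor = create_predictor(config)
    input_handle = predictor.get_input_handle(predictor.get_input_names()[0])
    input_handle.copy_from_cpu(np.random.rand(1, 1, 28, 28).astype("float32"))
    predictor.run()
    output_handle = predictor.get_output_handle(predictor.get_output_names()[0])
    return output_handle.copy_to_cpu()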
import numpy as np
import paddle
import os
import cv2
import glob
import six
from paddle.io import Dataset
def transform(data, ops=None):
""" transform """
if ops is None:
ops = []
for op in ops:
data = op(data)
if data is None:
return None
return data
def create_operators(op_param_list, global_config=None):
"""
create operators based on the config
Args:
params(list): a dict list, used to create some operators
"""
assert isinstance(op_param_list, list), ('operator config should be a list')
ops = []
for operator in op_param_list:
assert isinstance(operator,
dict) and len(operator) == 1, "yaml format error"
op_name = list(operator)[0]
param = {} if operator[op_name] is None else operator[op_name]
if global_config is not None:
param.update(global_config)
op = eval(op_name)(**param)
ops.append(op)
return ops
class DecodeImage(object):
""" decode image """
def __init__(self, img_mode='RGB', channel_first=False, **kwargs):
self.img_mode = img_mode
self.channel_first = channel_first
def __call__(self, data):
img = data['image']
if six.PY2:
assert type(img) is str and len(
img) > 0, "invalid input 'img' in DecodeImage"
else:
assert type(img) is bytes and len(
img) > 0, "invalid input 'img' in DecodeImage"
img = np.frombuffer(img, dtype='uint8')
img = cv2.imdecode(img, 1)
if img is None:
return None
if self.img_mode == 'GRAY':
img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
elif self.img_mode == 'RGB':
assert img.shape[2] == 3, 'invalid shape of image[%s]' % (img.shape)
img = img[:, :, ::-1]
if self.channel_first:
img = img.transpose((2, 0, 1))
data['image'] = img
data['src_image'] = img
return data
class NormalizeImage(object):
""" normalize image such as substract mean, divide std
"""
def __init__(self, scale=None, mean=None, std=None, order='chw', **kwargs):
if isinstance(scale, str):
scale = eval(scale)
self.scale = np.float32(scale if scale is not None else 1.0 / 255.0)
mean = mean if mean is not None else [0.485, 0.456, 0.406]
std = std if std is not None else [0.229, 0.224, 0.225]
shape = (3, 1, 1) if order == 'chw' else (1, 1, 3)
self.mean = np.array(mean).reshape(shape).astype('float32')
self.std = np.array(std).reshape(shape).astype('float32')
def __call__(self, data):
img = data['image']
from PIL import Image
if isinstance(img, Image.Image):
img = np.array(img)
assert isinstance(img,
np.ndarray), "invalid input 'img' in NormalizeImage"
data['image'] = (
img.astype('float32') * self.scale - self.mean) / self.std
return data
class ToCHWImage(object):
""" convert hwc image to chw image
"""
def __init__(self, **kwargs):
pass
def __call__(self, data):
img = data['image']
from PIL import Image
if isinstance(img, Image.Image):
img = np.array(img)
data['image'] = img.transpose((2, 0, 1))
        src_img = data['src_image']
        from PIL import Image
        if isinstance(src_img, Image.Image):
            src_img = np.array(src_img)
        data['src_image'] = src_img.transpose((2, 0, 1))
return data
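# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the original file): how the ops above are
# usually chained. create_operators turns a YAML-style list of
# {op_name: params} dicts into op instances, and transform applies them in
# order to a data dict that starts out holding the raw image bytes.
def _transform_demo(image_path):
    ops = create_operators([
        {'DecodeImage': {'img_mode': 'RGB', 'channel_first': False}},
        {'NormalizeImage': {'scale': '1./255.', 'order': 'hwc'}},
        {'ToCHWImage': None},
    ])
    with open(image_path, 'rb') as f:
        data = {'image': f.read()}
    return transform(data, ops)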
class SimpleDataset(Dataset):
def __init__(self, config, mode, logger, seed=None):
self.logger = logger
self.mode = mode.lower()
data_dir = config['Train']['data_dir']
imgs_list = self.get_image_list(data_dir)
        self.ops = create_operators(config['Train']['transforms'], None)
def get_image_list(self, img_dir):
imgs = glob.glob(os.path.join(img_dir, "*.png"))
if len(imgs) == 0:
raise ValueError(f"not any images founded in {img_dir}")
return imgs
def __getitem__(self, idx):
return None
import numpy as np
from paddle.vision.datasets import Cifar100
from paddle.vision.transforms import Normalize
from paddle.fluid.dataloader.collate import default_collate_fn
import signal
import os
from paddle.io import Dataset, DataLoader, DistributedBatchSampler
def term_mp(sig_num, frame):
""" kill all child processes
"""
pid = os.getpid()
pgid = os.getpgid(os.getpid())
print("main proc {} exit, kill process group " "{}".format(pid, pgid))
os.killpg(pgid, signal.SIGKILL)
return
def build_dataloader(mode,
batch_size=4,
seed=None,
num_workers=0,
device='gpu:0'):
normalize = Normalize(
mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], data_format='HWC')
if mode.lower() == "train":
dataset = Cifar100(mode=mode, transform=normalize)
elif mode.lower() in ["test", 'valid', 'eval']:
dataset = Cifar100(mode="test", transform=normalize)
else:
raise ValueError(f"{mode} should be one of ['train', 'test']")
# define batch sampler
batch_sampler = DistributedBatchSampler(
dataset=dataset, batch_size=batch_size, shuffle=False, drop_last=True)
data_loader = DataLoader(
dataset=dataset,
batch_sampler=batch_sampler,
places=device,
num_workers=num_workers,
return_list=True,
use_shared_memory=False)
# support exit using ctrl+c
signal.signal(signal.SIGINT, term_mp)
signal.signal(signal.SIGTERM, term_mp)
return data_loader
# cifar100 = Cifar100(mode='train', transform=normalize)
# data = cifar100[0]
# image, label = data
# reader = build_dataloader('train')
# for idx, data in enumerate(reader):
# print(idx, data[0].shape, data[1].shape)
# if idx >= 10:
# break
import pickle as p
import numpy as np
import os
from PIL import Image
def load_CIFAR_batch(filename):
""" load single batch of cifar """
with open(filename, 'rb') as f:
datadict = p.load(f, encoding='bytes')
        # the batch file stores the data as a dict with byte-string keys
X = datadict[b'data']
Y = datadict[b'fine_labels']
try:
X = X.reshape(10000, 3, 32, 32)
except:
X = X.reshape(50000, 3, 32, 32)
Y = np.array(Y)
print(Y.shape)
return X, Y
if __name__ == "__main__":
mode = "train"
imgX, imgY = load_CIFAR_batch(f"./cifar-100-python/{mode}")
with open(f'./cifar-100-python/{mode}_imgs/img_label.txt', 'a+') as f:
for i in range(imgY.shape[0]):
f.write('img' + str(i) + ' ' + str(imgY[i]) + '\n')
for i in range(imgX.shape[0]):
imgs = imgX[i]
img0 = imgs[0]
img1 = imgs[1]
img2 = imgs[2]
i0 = Image.fromarray(img0)
i1 = Image.fromarray(img1)
i2 = Image.fromarray(img2)
img = Image.merge("RGB", (i0, i1, i2))
name = "img" + str(i) + ".png"
img.save(f"./cifar-100-python/{mode}_imgs/" + name, "png")
print("save successfully!")
import paddle
import paddle.nn.functional as F
class Loss(object):
"""
Loss
"""
def __init__(self, class_dim=1000, epsilon=None):
assert class_dim > 1, "class_dim=%d is not larger than 1" % (class_dim)
self._class_dim = class_dim
if epsilon is not None and epsilon >= 0.0 and epsilon <= 1.0:
self._epsilon = epsilon
self._label_smoothing = True
else:
self._epsilon = None
self._label_smoothing = False
def _labelsmoothing(self, target):
if target.shape[-1] != self._class_dim:
one_hot_target = F.one_hot(target, self._class_dim)
else:
one_hot_target = target
soft_target = F.label_smooth(one_hot_target, epsilon=self._epsilon)
soft_target = paddle.reshape(soft_target, shape=[-1, self._class_dim])
return soft_target
def _crossentropy(self, input, target, use_pure_fp16=False):
if self._label_smoothing:
target = self._labelsmoothing(target)
input = -F.log_softmax(input, axis=-1)
cost = paddle.sum(target * input, axis=-1)
else:
cost = F.cross_entropy(input=input, label=target)
if use_pure_fp16:
avg_cost = paddle.sum(cost)
else:
avg_cost = paddle.mean(cost)
return avg_cost
def __call__(self, input, target):
return self._crossentropy(input, target)
def build_loss(config, epsilon=None):
class_dim = config['class_dim']
loss_func = Loss(class_dim=class_dim, epsilon=epsilon)
return loss_func
class LossDistill(Loss):
def __init__(self, model_name_list, class_dim=1000, epsilon=None):
assert class_dim > 1, "class_dim=%d is not larger than 1" % (class_dim)
self._class_dim = class_dim
if epsilon is not None and epsilon >= 0.0 and epsilon <= 1.0:
self._epsilon = epsilon
self._label_smoothing = True
else:
self._epsilon = None
self._label_smoothing = False
self.model_name_list = model_name_list
assert len(self.model_name_list) > 1, "error"
def __call__(self, input, target):
losses = {}
for k in self.model_name_list:
inp = input[k]
losses[k] = self._crossentropy(inp, target)
return losses
class KLJSLoss(object):
def __init__(self, mode='kl'):
assert mode in ['kl', 'js', 'KL', 'JS'
], "mode can only be one of ['kl', 'js', 'KL', 'JS']"
self.mode = mode
def __call__(self, p1, p2, reduction="mean"):
p1 = F.softmax(p1, axis=-1)
p2 = F.softmax(p2, axis=-1)
loss = paddle.multiply(p2, paddle.log((p2 + 1e-5) / (p1 + 1e-5) + 1e-5))
if self.mode.lower() == "js":
loss += paddle.multiply(
p1, paddle.log((p1 + 1e-5) / (p2 + 1e-5) + 1e-5))
loss *= 0.5
if reduction == "mean":
loss = paddle.mean(loss)
elif reduction == "none" or reduction is None:
return loss
else:
loss = paddle.sum(loss)
return loss
class DMLLoss(object):
def __init__(self, model_name_pairs, mode='js'):
self.model_name_pairs = self._check_model_name_pairs(model_name_pairs)
self.kljs_loss = KLJSLoss(mode=mode)
def _check_model_name_pairs(self, model_name_pairs):
if not isinstance(model_name_pairs, list):
return []
elif isinstance(model_name_pairs[0], list) and isinstance(
model_name_pairs[0][0], str):
return model_name_pairs
else:
return [model_name_pairs]
def __call__(self, predicts, target=None):
loss_dict = dict()
for pairs in self.model_name_pairs:
p1 = predicts[pairs[0]]
p2 = predicts[pairs[1]]
loss_dict[pairs[0] + "_" + pairs[1]] = self.kljs_loss(p1, p2)
return loss_dict
# def build_distill_loss(config, epsilon=None):
# class_dim = config['class_dim']
# loss = LossDistill(model_name_list=['student', 'student1'], )
# return loss_func
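# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the original file): typical wiring of the
# distillation losses above. LossDistill computes a cross-entropy loss per
# sub-model output, while DMLLoss adds a KL/JS consistency term between the
# two heads. Batch size and class count below are arbitrary.
def _distill_loss_demo():
    logits = {
        'student': paddle.randn([8, 100]),
        'student1': paddle.randn([8, 100]),
    }
    labels = paddle.randint(0, 100, shape=[8, 1])
    ce_losses = LossDistill(['student', 'student1'], class_dim=100)(logits, labels)
    dml_losses = DMLLoss([['student', 'student1']], mode='js')(logits)
    return ce_losses, dml_losses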
import paddle
import paddle.nn.functional as F
from collections import OrderedDict
def create_metric(out,
label,
architecture=None,
topk=5,
classes_num=1000,
use_distillation=False,
mode="train"):
"""
Create measures of model accuracy, such as top1 and top5
Args:
out(variable): model output variable
feeds(dict): dict of model input variables(included label)
topk(int): usually top5
classes_num(int): num of classes
use_distillation(bool): whether to use distillation training
mode(str): mode, train/valid
Returns:
fetchs(dict): dict of measures
"""
# if architecture["name"] == "GoogLeNet":
# assert len(out) == 3, "GoogLeNet should have 3 outputs"
# out = out[0]
# else:
# # just need student label to get metrics
# if use_distillation:
# out = out[1]
softmax_out = F.softmax(out)
fetchs = OrderedDict()
# set top1 to fetchs
top1 = paddle.metric.accuracy(softmax_out, label=label, k=1)
# set topk to fetchs
k = min(topk, classes_num)
topk = paddle.metric.accuracy(softmax_out, label=label, k=k)
# multi cards' eval
if mode != "train" and paddle.distributed.get_world_size() > 1:
top1 = paddle.distributed.all_reduce(
top1, op=paddle.distributed.ReduceOp.
SUM) / paddle.distributed.get_world_size()
topk = paddle.distributed.all_reduce(
topk, op=paddle.distributed.ReduceOp.
SUM) / paddle.distributed.get_world_size()
fetchs['top1'] = top1
topk_name = 'top{}'.format(k)
fetchs[topk_name] = topk
return fetchs
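# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the original file): create_metric consumes
# raw logits and integer labels and returns top-1 / top-k accuracy tensors.
# Batch size and class count below are arbitrary.
def _metric_demo():
    logits = paddle.randn([8, 100])
    labels = paddle.randint(0, 100, shape=[8, 1])
    fetchs = create_metric(logits, labels, topk=5, classes_num=100, mode="train")
    return fetchs['top1'], fetchs['top5']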
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
import paddle
from paddle import ParamAttr
import paddle.nn as nn
import paddle.nn.functional as F
from paddle.nn.functional import hardswish, hardsigmoid
from paddle.nn import Conv2D, BatchNorm, Linear, Dropout
from paddle.nn import AdaptiveAvgPool2D, MaxPool2D, AvgPool2D
from paddle.regularizer import L2Decay
import math
from paddle.utils.cpp_extension import load
# jit compile custom op
custom_ops = load(
name="custom_jit_ops",
sources=["./custom_op/custom_relu_op.cc", "./custom_op/custom_relu_op.cu"])
def make_divisible(v, divisor=8, min_value=None):
if min_value is None:
min_value = divisor
new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
if new_v < 0.9 * v:
new_v += divisor
return new_v
class MobileNetV3(nn.Layer):
def __init__(self,
scale=1.0,
model_name="small",
dropout_prob=0.2,
class_dim=1000,
use_custom_relu=False):
super(MobileNetV3, self).__init__()
self.use_custom_relu = use_custom_relu
inplanes = 16
if model_name == "large":
self.cfg = [
# k, exp, c, se, nl, s,
[3, 16, 16, False, "relu", 1],
[3, 64, 24, False, "relu", 2],
[3, 72, 24, False, "relu", 1],
[5, 72, 40, True, "relu", 2],
[5, 120, 40, True, "relu", 1],
[5, 120, 40, True, "relu", 1],
[3, 240, 80, False, "hardswish", 2],
[3, 200, 80, False, "hardswish", 1],
[3, 184, 80, False, "hardswish", 1],
[3, 184, 80, False, "hardswish", 1],
[3, 480, 112, True, "hardswish", 1],
[3, 672, 112, True, "hardswish", 1],
[5, 672, 160, True, "hardswish", 2],
[5, 960, 160, True, "hardswish", 1],
[5, 960, 160, True, "hardswish", 1],
]
self.cls_ch_squeeze = 960
self.cls_ch_expand = 1280
elif model_name == "small":
self.cfg = [
# k, exp, c, se, nl, s,
[3, 16, 16, True, "relu", 2],
[3, 72, 24, False, "relu", 2],
[3, 88, 24, False, "relu", 1],
[5, 96, 40, True, "hardswish", 2],
[5, 240, 40, True, "hardswish", 1],
[5, 240, 40, True, "hardswish", 1],
[5, 120, 48, True, "hardswish", 1],
[5, 144, 48, True, "hardswish", 1],
[5, 288, 96, True, "hardswish", 2],
[5, 576, 96, True, "hardswish", 1],
[5, 576, 96, True, "hardswish", 1],
]
self.cls_ch_squeeze = 576
self.cls_ch_expand = 1280
else:
raise NotImplementedError(
"mode[{}_model] is not implemented!".format(model_name))
self.conv1 = ConvBNLayer(
in_c=3,
out_c=make_divisible(inplanes * scale),
filter_size=3,
stride=2,
padding=1,
num_groups=1,
if_act=True,
act="hardswish",
name="conv1",
use_custom_relu=self.use_custom_relu)
self.block_list = []
i = 0
inplanes = make_divisible(inplanes * scale)
for (k, exp, c, se, nl, s) in self.cfg:
block = self.add_sublayer(
"conv" + str(i + 2),
ResidualUnit(
in_c=inplanes,
mid_c=make_divisible(scale * exp),
out_c=make_divisible(scale * c),
filter_size=k,
stride=s,
use_se=se,
act=nl,
name="conv" + str(i + 2),
use_custom_relu=self.use_custom_relu))
self.block_list.append(block)
inplanes = make_divisible(scale * c)
i += 1
self.last_second_conv = ConvBNLayer(
in_c=inplanes,
out_c=make_divisible(scale * self.cls_ch_squeeze),
filter_size=1,
stride=1,
padding=0,
num_groups=1,
if_act=True,
act="hardswish",
name="conv_last",
use_custom_relu=self.use_custom_relu)
self.pool = AdaptiveAvgPool2D(1)
self.last_conv = Conv2D(
in_channels=make_divisible(scale * self.cls_ch_squeeze),
out_channels=self.cls_ch_expand,
kernel_size=1,
stride=1,
padding=0,
weight_attr=ParamAttr(),
bias_attr=False)
self.dropout = Dropout(p=dropout_prob, mode="downscale_in_infer")
self.out = Linear(
self.cls_ch_expand,
class_dim,
weight_attr=ParamAttr(),
bias_attr=ParamAttr())
def forward(self, inputs, label=None):
x = self.conv1(inputs)
for block in self.block_list:
x = block(x)
x = self.last_second_conv(x)
x = self.pool(x)
x = self.last_conv(x)
x = hardswish(x)
x = self.dropout(x)
x = paddle.flatten(x, start_axis=1, stop_axis=-1)
x = self.out(x)
return x
class ConvBNLayer(nn.Layer):
def __init__(self,
in_c,
out_c,
filter_size,
stride,
padding,
num_groups=1,
if_act=True,
act=None,
use_cudnn=True,
name="",
use_custom_relu=False):
super(ConvBNLayer, self).__init__()
self.if_act = if_act
self.act = act
self.conv = Conv2D(
in_channels=in_c,
out_channels=out_c,
kernel_size=filter_size,
stride=stride,
padding=padding,
groups=num_groups,
weight_attr=ParamAttr(),
bias_attr=False)
self.bn = BatchNorm(
num_channels=out_c,
act=None,
param_attr=ParamAttr(regularizer=L2Decay(0.0)),
bias_attr=ParamAttr(regularizer=L2Decay(0.0)))
# moving_mean_name=name + "_bn_mean",
# moving_variance_name=name + "_bn_variance")
self.use_custom_relu = use_custom_relu
def forward(self, x):
x = self.conv(x)
x = self.bn(x)
if self.if_act:
if self.act == "relu":
if self.use_custom_relu:
x = custom_ops.custom_relu(x)
else:
x = F.relu(x)
elif self.act == "hardswish":
x = hardswish(x)
else:
print("The activation function is selected incorrectly.")
exit()
return x
class ResidualUnit(nn.Layer):
def __init__(self,
in_c,
mid_c,
out_c,
filter_size,
stride,
use_se,
act=None,
name='',
use_custom_relu=False):
super(ResidualUnit, self).__init__()
self.if_shortcut = stride == 1 and in_c == out_c
self.if_se = use_se
self.use_custom_relu = use_custom_relu
self.expand_conv = ConvBNLayer(
in_c=in_c,
out_c=mid_c,
filter_size=1,
stride=1,
padding=0,
if_act=True,
act=act,
name=name + "_expand",
use_custom_relu=self.use_custom_relu)
self.bottleneck_conv = ConvBNLayer(
in_c=mid_c,
out_c=mid_c,
filter_size=filter_size,
stride=stride,
padding=int((filter_size - 1) // 2),
num_groups=mid_c,
if_act=True,
act=act,
name=name + "_depthwise",
use_custom_relu=self.use_custom_relu)
if self.if_se:
self.mid_se = SEModule(mid_c, name=name + "_se")
self.linear_conv = ConvBNLayer(
in_c=mid_c,
out_c=out_c,
filter_size=1,
stride=1,
padding=0,
if_act=False,
act=None,
name=name + "_linear",
use_custom_relu=self.use_custom_relu)
def forward(self, inputs):
x = self.expand_conv(inputs)
x = self.bottleneck_conv(x)
if self.if_se:
x = self.mid_se(x)
x = self.linear_conv(x)
if self.if_shortcut:
x = paddle.add(inputs, x)
return x
class SEModule(nn.Layer):
def __init__(self, channel, reduction=4, name=""):
super(SEModule, self).__init__()
self.avg_pool = AdaptiveAvgPool2D(1)
self.conv1 = Conv2D(
in_channels=channel,
out_channels=channel // reduction,
kernel_size=1,
stride=1,
padding=0,
weight_attr=ParamAttr(),
bias_attr=ParamAttr())
self.conv2 = Conv2D(
in_channels=channel // reduction,
out_channels=channel,
kernel_size=1,
stride=1,
padding=0,
weight_attr=ParamAttr(),
bias_attr=ParamAttr())
def forward(self, inputs):
outputs = self.avg_pool(inputs)
outputs = self.conv1(outputs)
outputs = F.relu(outputs)
outputs = self.conv2(outputs)
outputs = hardsigmoid(outputs, slope=0.2, offset=0.5)
return paddle.multiply(x=inputs, y=outputs)
def MobileNetV3_small_x0_35(**args):
model = MobileNetV3(model_name="small", scale=0.35, **args)
return model
def MobileNetV3_small_x0_5(**args):
model = MobileNetV3(model_name="small", scale=0.5, **args)
return model
def MobileNetV3_small_x0_75(**args):
model = MobileNetV3(model_name="small", scale=0.75, **args)
return model
def MobileNetV3_small_x1_0(**args):
model = MobileNetV3(model_name="small", scale=1.0, **args)
return model
def MobileNetV3_small_x1_25(**args):
model = MobileNetV3(model_name="small", scale=1.25, **args)
return model
def MobileNetV3_large_x0_35(**args):
model = MobileNetV3(model_name="large", scale=0.35, **args)
return model
def MobileNetV3_large_x0_5(**args):
model = MobileNetV3(model_name="large", scale=0.5, **args)
return model
def MobileNetV3_large_x0_75(**args):
model = MobileNetV3(model_name="large", scale=0.75, **args)
return model
def MobileNetV3_large_x1_0(**args):
model = MobileNetV3(model_name="large", scale=1.0, **args)
return model
def MobileNetV3_large_x1_25(**args):
model = MobileNetV3(model_name="large", scale=1.25, **args)
    return model
class DistillMV3(nn.Layer):
def __init__(self,
scale=1.0,
model_name="small",
dropout_prob=0.2,
class_dim=1000,
args=None,
use_custom_relu=False):
super(DistillMV3, self).__init__()
self.student = MobileNetV3(
model_name=model_name,
scale=scale,
class_dim=class_dim,
use_custom_relu=use_custom_relu)
self.student1 = MobileNetV3(
model_name=model_name,
scale=scale,
class_dim=class_dim,
use_custom_relu=use_custom_relu)
def forward(self, inputs, label=None):
predicts = dict()
predicts['student'] = self.student(inputs, label)
predicts['student1'] = self.student1(inputs, label)
return predicts
def distillmv3_large_x0_5(**args):
model = DistillMV3(model_name="large", scale=0.5, **args)
return model
class SiameseMV3(nn.Layer):
def __init__(self,
scale=1.0,
model_name="small",
dropout_prob=0.2,
class_dim=1000,
args=None,
use_custom_relu=False):
super(SiameseMV3, self).__init__()
self.net = MobileNetV3(
model_name=model_name,
scale=scale,
class_dim=class_dim,
use_custom_relu=use_custom_relu)
self.net1 = MobileNetV3(
model_name=model_name,
scale=scale,
class_dim=class_dim,
use_custom_relu=use_custom_relu)
def forward(self, inputs, label=None):
# net
x = self.net.conv1(inputs)
for block in self.net.block_list:
x = block(x)
# net1
x1 = self.net1.conv1(inputs)
for block in self.net1.block_list:
x1 = block(x1)
# add
x = x + x1
x = self.net.last_second_conv(x)
x = self.net.pool(x)
x = self.net.last_conv(x)
x = hardswish(x)
x = self.net.dropout(x)
x = paddle.flatten(x, start_axis=1, stop_axis=-1)
x = self.net.out(x)
return x
def siamese_mv3(class_dim, use_custom_relu):
model = SiameseMV3(
scale=0.5,
model_name="large",
class_dim=class_dim,
use_custom_relu=use_custom_relu)
return model
def build_model(config):
model_type = config['model_type']
if model_type == "cls":
class_dim = config['MODEL']['class_dim']
use_custom_relu = config['MODEL']['use_custom_relu']
if 'siamese' in config['MODEL'] and config['MODEL']['siamese'] is True:
model = siamese_mv3(
class_dim=class_dim, use_custom_relu=use_custom_relu)
else:
model = MobileNetV3_large_x0_5(
class_dim=class_dim, use_custom_relu=use_custom_relu)
elif model_type == "cls_distill":
class_dim = config['MODEL']['class_dim']
use_custom_relu = config['MODEL']['use_custom_relu']
model = distillmv3_large_x0_5(
class_dim=class_dim, use_custom_relu=use_custom_relu)
elif model_type == "cls_distill_multiopt":
class_dim = config['MODEL']['class_dim']
use_custom_relu = config['MODEL']['use_custom_relu']
model = distillmv3_large_x0_5(
class_dim=100, use_custom_relu=use_custom_relu)
else:
raise ValueError("model_type should be one of ['']")
return model
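# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the original file): build_model is driven
# by the YAML config; a minimal 'cls' config looks roughly like this.
def _build_model_demo():
    cfg = {
        'model_type': 'cls',
        'MODEL': {'class_dim': 100, 'use_custom_relu': False, 'siamese': False},
    }
    return build_model(cfg)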
class_dim: 100
total_images: 50000
epochs: 1000
topk: 5
save_model_dir: ./output/
use_gpu: True
model_type: cls_distill
LEARNING_RATE:
function: 'Cosine'
params:
lr: 0.001
warmup_epoch: 5
OPTIMIZER:
function: 'Momentum'
params:
momentum: 0.9
regularizer:
function: 'L2'
factor: 0.00002
TRAIN:
batch_size: 1280
num_workers: 4
VALID:
batch_size: 64
num_workers: 4
class_dim: 100
total_images: 50000
epoch: 1000
topk: 5
save_model_dir: ./output/
use_gpu: True
model_type: cls
use_custom_relu: false
pretrained_model:
checkpoints:
save_model_dir: ./output/cls/
# slim
quant_train: false
prune_train: false
MODEL:
class_dim: 100
use_custom_relu: False
siamese: False
AMP:
use_amp: False
scale_loss: 1024.0
use_dynamic_loss_scale: True
LEARNING_RATE:
function: 'Cosine'
params:
lr: 0.001
warmup_epoch: 5
OPTIMIZER:
function: 'Momentum'
params:
momentum: 0.9
regularizer:
function: 'L2'
factor: 0.00002
TRAIN:
batch_size: 1280
num_workers: 4
VALID:
batch_size: 64
num_workers: 4
import sys
import math
from paddle.optimizer.lr import LinearWarmup
from paddle.optimizer.lr import PiecewiseDecay
from paddle.optimizer.lr import CosineAnnealingDecay
from paddle.optimizer.lr import ExponentialDecay
import paddle
import paddle.regularizer as regularizer
from copy import deepcopy
class Cosine(CosineAnnealingDecay):
"""
Cosine learning rate decay
    lr_t = 0.5 * lr * (math.cos(epoch * (math.pi / epochs)) + 1)
Args:
lr(float): initial learning rate
step_each_epoch(int): steps each epoch
epochs(int): total training epochs
"""
def __init__(self, lr, step_each_epoch, epochs, **kwargs):
super(Cosine, self).__init__(
learning_rate=lr,
T_max=step_each_epoch * epochs, )
self.update_specified = False
class Piecewise(PiecewiseDecay):
"""
Piecewise learning rate decay
Args:
lr(float): initial learning rate
step_each_epoch(int): steps each epoch
decay_epochs(list): piecewise decay epochs
gamma(float): decay factor
"""
def __init__(self, lr, step_each_epoch, decay_epochs, gamma=0.1, **kwargs):
boundaries = [step_each_epoch * e for e in decay_epochs]
lr_values = [lr * (gamma**i) for i in range(len(boundaries) + 1)]
super(Piecewise, self).__init__(boundaries=boundaries, values=lr_values)
self.update_specified = False
class CosineWarmup(LinearWarmup):
"""
Cosine learning rate decay with warmup
[0, warmup_epoch): linear warmup
[warmup_epoch, epochs): cosine decay
Args:
lr(float): initial learning rate
step_each_epoch(int): steps each epoch
epochs(int): total training epochs
warmup_epoch(int): epoch num of warmup
"""
def __init__(self, lr, step_each_epoch, epochs, warmup_epoch=5, **kwargs):
assert epochs > warmup_epoch, "total epoch({}) should be larger than warmup_epoch({}) in CosineWarmup.".format(
epochs, warmup_epoch)
warmup_step = warmup_epoch * step_each_epoch
start_lr = 0.0
end_lr = lr
lr_sch = Cosine(lr, step_each_epoch, epochs - warmup_epoch)
super(CosineWarmup, self).__init__(
learning_rate=lr_sch,
warmup_steps=warmup_step,
start_lr=start_lr,
end_lr=end_lr)
self.update_specified = False
class ExponentialWarmup(LinearWarmup):
"""
Exponential learning rate decay with warmup
[0, warmup_epoch): linear warmup
[warmup_epoch, epochs): Exponential decay
Args:
lr(float): initial learning rate
step_each_epoch(int): steps each epoch
decay_epochs(float): decay epochs
decay_rate(float): decay rate
warmup_epoch(int): epoch num of warmup
"""
def __init__(self,
lr,
step_each_epoch,
decay_epochs=2.4,
decay_rate=0.97,
warmup_epoch=5,
**kwargs):
warmup_step = warmup_epoch * step_each_epoch
start_lr = 0.0
end_lr = lr
lr_sch = ExponentialDecay(lr, decay_rate)
super(ExponentialWarmup, self).__init__(
learning_rate=lr_sch,
warmup_steps=warmup_step,
start_lr=start_lr,
end_lr=end_lr)
        # NOTE: hack method to update exponential lr scheduler
self.update_specified = True
self.update_start_step = warmup_step
self.update_step_interval = int(decay_epochs * step_each_epoch)
self.step_each_epoch = step_each_epoch
class LearningRateBuilder():
"""
Build learning rate variable
https://www.paddlepaddle.org.cn/documentation/docs/zh/api_cn/layers_cn.html
Args:
function(str): class name of learning rate
params(dict): parameters used for init the class
"""
def __init__(self,
function='Linear',
params={'lr': 0.1,
'steps': 100,
'end_lr': 0.0}):
self.function = function
self.params = params
def __call__(self):
mod = sys.modules[__name__]
lr = getattr(mod, self.function)(**self.params)
return lr
class L1Decay(object):
"""
L1 Weight Decay Regularization, which encourages the weights to be sparse.
Args:
factor(float): regularization coeff. Default:0.0.
"""
def __init__(self, factor=0.0):
super(L1Decay, self).__init__()
self.factor = factor
def __call__(self):
reg = regularizer.L1Decay(self.factor)
return reg
class L2Decay(object):
"""
    L2 Weight Decay Regularization, which penalizes large weights and keeps them small.
Args:
factor(float): regularization coeff. Default:0.0.
"""
def __init__(self, factor=0.0):
super(L2Decay, self).__init__()
self.factor = factor
def __call__(self):
reg = regularizer.L2Decay(self.factor)
return reg
class Momentum(object):
"""
Simple Momentum optimizer with velocity state.
Args:
learning_rate (float|Variable) - The learning rate used to update parameters.
Can be a float value or a Variable with one float value as data element.
momentum (float) - Momentum factor.
regularization (WeightDecayRegularizer, optional) - The strategy of regularization.
"""
def __init__(self,
learning_rate,
momentum,
parameter_list=None,
regularization=None,
**args):
super(Momentum, self).__init__()
self.learning_rate = learning_rate
self.momentum = momentum
self.parameter_list = parameter_list
self.regularization = regularization
def __call__(self):
opt = paddle.optimizer.Momentum(
learning_rate=self.learning_rate,
momentum=self.momentum,
parameters=self.parameter_list,
weight_decay=self.regularization)
return opt
class RMSProp(object):
"""
Root Mean Squared Propagation (RMSProp) is an unpublished, adaptive learning rate method.
Args:
learning_rate (float|Variable) - The learning rate used to update parameters.
Can be a float value or a Variable with one float value as data element.
momentum (float) - Momentum factor.
rho (float) - rho value in equation.
epsilon (float) - avoid division by zero, default is 1e-6.
regularization (WeightDecayRegularizer, optional) - The strategy of regularization.
"""
def __init__(self,
learning_rate,
momentum,
rho=0.95,
epsilon=1e-6,
parameter_list=None,
regularization=None,
**args):
super(RMSProp, self).__init__()
self.learning_rate = learning_rate
self.momentum = momentum
self.rho = rho
self.epsilon = epsilon
self.parameter_list = parameter_list
self.regularization = regularization
def __call__(self):
opt = paddle.optimizer.RMSProp(
learning_rate=self.learning_rate,
momentum=self.momentum,
rho=self.rho,
epsilon=self.epsilon,
parameters=self.parameter_list,
weight_decay=self.regularization)
return opt
class OptimizerBuilder(object):
"""
Build optimizer
Args:
function(str): optimizer name of learning rate
params(dict): parameters used for init the class
regularizer (dict): parameters used for create regularization
"""
def __init__(self,
function='Momentum',
params={'momentum': 0.9},
regularizer=None):
self.function = function
self.params = params
# create regularizer
if regularizer is not None:
mod = sys.modules[__name__]
reg_func = regularizer['function'] + 'Decay'
del regularizer['function']
reg = getattr(mod, reg_func)(**regularizer)()
self.params['regularization'] = reg
def __call__(self, learning_rate, parameter_list=None):
mod = sys.modules[__name__]
opt = getattr(mod, self.function)
return opt(learning_rate=learning_rate,
parameter_list=parameter_list,
**self.params)()
def create_optimizer(config, parameter_list=None):
"""
Create an optimizer using config, usually including
learning rate and regularization.
Args:
config(dict): such as
{
'LEARNING_RATE':
{'function': 'Cosine',
'params': {'lr': 0.1}
},
'OPTIMIZER':
{'function': 'Momentum',
'params':{'momentum': 0.9},
'regularizer':
{'function': 'L2', 'factor': 0.0001}
}
}
Returns:
an optimizer instance
"""
# create learning_rate instance
lr_config = config['LEARNING_RATE']
lr_config['params'].update({
'epochs': config['epoch'],
'step_each_epoch':
config['total_images'] // config['TRAIN']['batch_size'],
})
lr = LearningRateBuilder(**lr_config)()
# create optimizer instance
opt_config = deepcopy(config['OPTIMIZER'])
opt = OptimizerBuilder(**opt_config)
return opt(lr, parameter_list), lr
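# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the original file): create_optimizer is
# driven by a config dict shaped like the YAML files in this directory; the
# values below are arbitrary.
def _create_optimizer_demo(model):
    cfg = {
        'epoch': 10,
        'total_images': 50000,
        'TRAIN': {'batch_size': 128},
        'LEARNING_RATE': {'function': 'Cosine',
                          'params': {'lr': 0.001, 'warmup_epoch': 5}},
        'OPTIMIZER': {'function': 'Momentum',
                      'params': {'momentum': 0.9},
                      'regularizer': {'function': 'L2', 'factor': 0.00002}},
    }
    optimizer, lr_scheduler = create_optimizer(cfg, parameter_list=model.parameters())
    return optimizer, lr_scheduler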
def create_multi_optimizer(config, parameter_list=None):
"""
"""
# create learning_rate instance
lr_config = config['LEARNING_RATE']
lr_config['params'].update({
'epochs': config['epoch'],
'step_each_epoch':
config['total_images'] // config['TRAIN']['batch_size'],
})
lr = LearningRateBuilder(**lr_config)()
# create optimizer instance
    opt_config = deepcopy(config['OPTIMIZER'])
opt = OptimizerBuilder(**opt_config)
return opt(lr, parameter_list), lr
# TIPC supplementary training feature test on Linux
The main program of the basic training feature test on Linux is test_train_python.sh, which tests basic Python-based model training and evaluation, including pruning, quantization and distillation training.
![](./tipc_train.png)
The test chain is shown in the figure above. It mainly checks that models with shared weights and custom OPs train normally, and that the slim-related training flows run correctly.
# 2. Test procedure
This section describes the test procedure of the supplementary chain.
## 2.1 Install dependencies
- Install PaddlePaddle >= 2.2
- Install other dependencies
```
pip3 install -r requirements.txt
```
## 2.2 Feature test
`test_train_python.sh` supports two run modes. Each mode uses a different amount of data and is used to check that training runs correctly:
- Mode 1: lite_train_lite_infer, trains with a small amount of data to quickly verify that the training-to-inference pipeline runs; accuracy and speed are not verified;
```
bash test_tipc/test_train_python.sh ./test_tipc/ch_ppocr_mobile_v2.0_det/train_infer_python.txt 'lite_train_lite_infer'
```
- Mode 2: whole_train_whole_infer, trains with the full dataset to verify the training-to-inference pipeline as well as the final training accuracy of the model;
```
bash test_tipc/test_train_python.sh ./test_tipc/ch_ppocr_mobile_v2.0_det/train_infer_python.txt 'whole_train_whole_infer'
```
To run quantization, pruning and other training modes, different config files are needed. The test command for quantization-aware training is:
```
bash test_tipc/test_train_python.sh ./test_tipc/ch_ppocr_mobile_v2.0_det/train_infer_python_PACT.txt 'lite_train_lite_infer'
```
Similarly, FPGM pruning is run as follows:
```
bash test_tipc/test_train_python.sh ./test_tipc/ch_ppocr_mobile_v2.0_det/train_infer_python_FPGM.txt 'lite_train_lite_infer'
```
After running the corresponding command, the run logs are saved automatically under the `test_tipc/extra_output` folder. For example, after running the 'lite_train_lite_infer' mode, the test_tipc/extra_output folder contains the following file:
```
test_tipc/extra_output/
|- results_python.log    # log of the run status of each command
```
results_python.log records the run status of each command; if a command ran successfully, output like the following is written:
```
Run successfully with command - python3.7 train.py -c mv3_large_x0_5.yml -o use_gpu=True epoch=20 AMP.use_amp=True TRAIN.batch_size=1280 use_custom_relu=False model_type=cls MODEL.siamese=False !
Run successfully with command - python3.7 train.py -c mv3_large_x0_5.yml -o use_gpu=True epoch=2 AMP.use_amp=True TRAIN.batch_size=1280 use_custom_relu=False model_type=cls MODEL.siamese=False !
Run successfully with command - python3.7 train.py -c mv3_large_x0_5.yml -o use_gpu=True epoch=2 AMP.use_amp=True TRAIN.batch_size=1280 use_custom_relu=False model_type=cls MODEL.siamese=True !
Run successfully with command - python3.7 train.py -c mv3_large_x0_5.yml -o use_gpu=True epoch=2 AMP.use_amp=True TRAIN.batch_size=1280 use_custom_relu=False model_type=cls_distill MODEL.siamese=False !
Run successfully with command - python3.7 train.py -c mv3_large_x0_5.yml -o use_gpu=True epoch=2 AMP.use_amp=True TRAIN.batch_size=1280 use_custom_relu=False model_type=cls_distill MODEL.siamese=True !
Run successfully with command - python3.7 train.py -c mv3_large_x0_5.yml -o use_gpu=True epoch=2 AMP.use_amp=True TRAIN.batch_size=1280 use_custom_relu=False model_type=cls_distill_multiopt MODEL.siamese=False !
```
import paddleslim
import paddle
import numpy as np
from paddleslim.dygraph import FPGMFilterPruner
def prune_model(model, input_shape, prune_ratio=0.1):
flops = paddle.flops(model, input_shape)
pruner = FPGMFilterPruner(model, input_shape)
params_sensitive = {}
for param in model.parameters():
if 'transpose' not in param.name and 'linear' not in param.name:
# set prune ratio as 10%. The larger the value, the more convolution weights will be cropped
params_sensitive[param.name] = prune_ratio
plan = pruner.prune_vars(params_sensitive, [0])
flops = paddle.flops(model, input_shape)
return model
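# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the original file): prune_model is applied
# to a dygraph model before training; here it is shown with the MobileNetV3
# defined in mv3.py (note that importing mv3 JIT-compiles the custom relu op).
def _prune_demo():
    from mv3 import MobileNetV3_large_x0_5
    model = MobileNetV3_large_x0_5(class_dim=100, use_custom_relu=False)
    return prune_model(model, [1, 3, 32, 32], prune_ratio=0.1)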
import paddle
import numpy as np
import os
import paddle.nn as nn
import paddleslim
class PACT(paddle.nn.Layer):
def __init__(self):
super(PACT, self).__init__()
alpha_attr = paddle.ParamAttr(
name=self.full_name() + ".pact",
initializer=paddle.nn.initializer.Constant(value=20),
learning_rate=1.0,
regularizer=paddle.regularizer.L2Decay(2e-5))
self.alpha = self.create_parameter(
shape=[1], attr=alpha_attr, dtype='float32')
def forward(self, x):
out_left = paddle.nn.functional.relu(x - self.alpha)
out_right = paddle.nn.functional.relu(-self.alpha - x)
x = x - out_left + out_right
return x
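# ---------------------------------------------------------------------------
# Worked example (not part of the original file): the forward above clips
# activations into [-alpha, alpha] with a learnable alpha (initialized to 20),
# e.g. x = 30 -> 30 - relu(30 - 20) + relu(-20 - 30) = 20, and x = -30 -> -20.
def _pact_demo():
    pact = PACT()
    x = paddle.to_tensor([30.0, -30.0, 5.0])
    return pact(x)  # expected: [20., -20., 5.]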
quant_config = {
# weight preprocess type, default is None and no preprocessing is performed.
'weight_preprocess_type': None,
# activation preprocess type, default is None and no preprocessing is performed.
'activation_preprocess_type': None,
# weight quantize type, default is 'channel_wise_abs_max'
'weight_quantize_type': 'channel_wise_abs_max',
# activation quantize type, default is 'moving_average_abs_max'
'activation_quantize_type': 'moving_average_abs_max',
# weight quantize bit num, default is 8
'weight_bits': 8,
# activation quantize bit num, default is 8
'activation_bits': 8,
# data type after quantization, such as 'uint8', 'int8', etc. default is 'int8'
'dtype': 'int8',
# window size for 'range_abs_max' quantization. default is 10000
'window_size': 10000,
# The decay coefficient of moving average, default is 0.9
'moving_rate': 0.9,
# for dygraph quantization, layers of type in quantizable_layer_type will be quantized
'quantizable_layer_type': ['Conv2D', 'Linear'],
}
#!/bin/bash
function func_parser_key(){
strs=$1
IFS=":"
array=(${strs})
tmp=${array[0]}
echo ${tmp}
}
function func_parser_value(){
strs=$1
IFS=":"
array=(${strs})
tmp=${array[1]}
echo ${tmp}
}
function func_set_params(){
key=$1
value=$2
if [ ${key}x = "null"x ];then
echo " "
elif [[ ${value} = "null" ]] || [[ ${value} = " " ]] || [ ${#value} -le 0 ];then
echo " "
else
echo "${key}=${value}"
fi
}
function func_parser_params(){
strs=$1
MODE=$2
IFS=":"
array=(${strs})
key=${array[0]}
tmp=${array[1]}
IFS="|"
res=""
for _params in ${tmp[*]}; do
IFS="="
array=(${_params})
mode=${array[0]}
value=${array[1]}
if [[ ${mode} = ${MODE} ]]; then
IFS="|"
#echo $(func_set_params "${mode}" "${value}")
echo $value
break
fi
IFS="|"
done
echo ${res}
}
function status_check(){
last_status=$1 # the exit code
run_command=$2
run_log=$3
if [ $last_status -eq 0 ]; then
echo -e "\033[33m Run successfully with command - ${run_command}! \033[0m" | tee -a ${run_log}
else
echo -e "\033[33m Run failed with command - ${run_command}! \033[0m" | tee -a ${run_log}
fi
}
\ No newline at end of file
#!/bin/bash
source test_tipc/common_func.sh
FILENAME=$1
# MODE must be one of ['lite_train_lite_infer', 'whole_train_whole_infer']
MODE=$2
dataline=$(awk 'NR==1, NR==51{print}' $FILENAME)
# parser params
IFS=$'\n'
lines=(${dataline})
model_name=$(func_parser_value "${lines[1]}")
python=$(func_parser_value "${lines[2]}")
gpu_list=$(func_parser_value "${lines[3]}")
train_use_gpu_key=$(func_parser_key "${lines[4]}")
train_use_gpu_value=$(func_parser_value "${lines[4]}")
autocast_list=$(func_parser_value "${lines[5]}")
autocast_key=$(func_parser_key "${lines[5]}")
epoch_key=$(func_parser_key "${lines[6]}")
epoch_num=$(func_parser_params "${lines[6]}" "${MODE}")
save_model_key=$(func_parser_key "${lines[7]}")
train_batch_key=$(func_parser_key "${lines[8]}")
train_batch_value=$(func_parser_params "${lines[8]}" "${MODE}")
pretrain_model_key=$(func_parser_key "${lines[9]}")
pretrain_model_value=$(func_parser_value "${lines[9]}")
checkpoints_key=$(func_parser_key "${lines[10]}")
checkpoints_value=$(func_parser_value "${lines[10]}")
use_custom_key=$(func_parser_key "${lines[11]}")
use_custom_list=$(func_parser_value "${lines[11]}")
model_type_key=$(func_parser_key "${lines[12]}")
model_type_list=$(func_parser_value "${lines[12]}")
use_share_conv_key=$(func_parser_key "${lines[13]}")
use_share_conv_list=$(func_parser_value "${lines[13]}")
run_train_py=$(func_parser_value "${lines[14]}")
LOG_PATH="./test_tipc/extra_output"
mkdir -p ${LOG_PATH}
status_log="${LOG_PATH}/results_python.log"
if [ ${MODE} = "lite_train_lite_infer" ] || [ ${MODE} = "whole_train_whole_infer" ]; then
IFS="|"
export Count=0
USE_GPU_KEY=(${train_use_gpu_value})
# select cpu\gpu\distribute training
for gpu in ${gpu_list[*]}; do
train_use_gpu=${USE_GPU_KEY[Count]}
Count=$(($Count + 1))
ips=""
if [ ${gpu} = "-1" ];then
env=""
elif [ ${#gpu} -le 1 ];then
env="export CUDA_VISIBLE_DEVICES=${gpu}"
eval ${env}
elif [ ${#gpu} -le 15 ];then
IFS=","
array=(${gpu})
env="export CUDA_VISIBLE_DEVICES=${array[0]}"
IFS="|"
else
IFS=";"
array=(${gpu})
ips=${array[0]}
gpu=${array[1]}
IFS="|"
env=" "
fi
for autocast in ${autocast_list[*]}; do
# set amp
if [ ${autocast} = "amp" ]; then
set_amp_config="AMP.use_amp=True"
else
set_amp_config=" "
fi
if [ ${run_train_py} = "null" ]; then
continue
fi
set_autocast=$(func_set_params "${autocast_key}" "${autocast}")
set_epoch=$(func_set_params "${epoch_key}" "${epoch_num}")
set_pretrain=$(func_set_params "${pretrain_model_key}" "${pretrain_model_value}")
set_checkpoints=$(func_set_params "${checkpoints_key}" "${checkpoints_value}")
set_batchsize=$(func_set_params "${train_batch_key}" "${train_batch_value}")
set_use_gpu=$(func_set_params "${train_use_gpu_key}" "${train_use_gpu}")
for custom_op in ${use_custom_list[*]}; do
for model_type in ${model_type_list[*]}; do
for share_conv in ${use_share_conv_list[*]}; do
set_use_custom_op=$(func_set_params "${use_custom_key}" "${custom_op}")
set_model_type=$(func_set_params "${model_type_key}" "${model_type}")
set_use_share_conv=$(func_set_params "${use_share_conv_key}" "${share_conv}")
set_save_model=$(func_set_params "${save_model_key}" "${save_log}")
if [ ${#gpu} -le 2 ];then # train with cpu or single gpu
cmd="${python} ${run_train_py} ${set_use_gpu} ${set_save_model} ${set_epoch} ${set_pretrain} ${set_checkpoints} ${set_autocast} ${set_batchsize} ${set_use_custom_op} ${set_model_type} ${set_use_share_conv} ${set_amp_config}"
elif [ ${#ips} -le 26 ];then # train with multi-gpu
cmd="${python} -m paddle.distributed.launch --gpus=${gpu} ${run_train_py} ${set_use_gpu} ${set_save_model} ${set_epoch} ${set_pretrain} ${set_checkpoints} ${set_autocast} ${set_batchsize} ${set_use_custom_op} ${set_model_type} ${set_use_share_conv} ${set_amp_config}"
fi
# run train
eval "unset CUDA_VISIBLE_DEVICES"
# echo $cmd
eval $cmd
status_check $? "${cmd}" "${status_log}"
done
done
done
done
done
fi
===========================train_params===========================
model_name:ch_PPOCRv2_det
python:python3.7
gpu_list:0|0,1
use_gpu:True|True
AMP.use_amp:True|False
epoch:lite_train_lite_infer=2|whole_train_whole_infer=1000
save_model_dir:./output/
TRAIN.batch_size:lite_train_lite_infer=1280|whole_train_whole_infer=1280
pretrained_model:null
checkpoints:null
use_custom_relu:False|True
model_type:cls|cls_distill|cls_distill_multiopt
MODEL.siamese:False|True
norm_train:train.py -c mv3_large_x0_5.yml -o
quant_train:False
prune_train:False
===========================train_params===========================
model_name:ch_PPOCRv2_det
python:python3.7
gpu_list:0|0,1
use_gpu:True|True
AMP.use_amp:True|False
epoch:lite_train_lite_infer=20|whole_train_whole_infer=1000
save_model_dir:./output/
TRAIN.batch_size:lite_train_lite_infer=2|whole_train_whole_infer=4
pretrained_model:null
checkpoints:null
use_custom_relu:False|True
model_type:cls|cls_distill|cls_distill_multiopt
MODEL.siamese:False|True
norm_train:train.py -c mv3_large_x0_5.yml -o prune_train=True
quant_train:False
prune_train:False
===========================train_params===========================
model_name:ch_PPOCRv2_det
python:python3.7
gpu_list:0|0,1
use_gpu:True|True
AMP.use_amp:True|False
epoch:lite_train_lite_infer=20|whole_train_whole_infer=1000
save_model_dir:./output/
TRAIN.batch_size:lite_train_lite_infer=2|whole_train_whole_infer=4
pretrained_model:null
checkpoints:null
use_custom_relu:False|True
model_type:cls|cls_distill|cls_distill_multiopt
MODEL.siamese:False|True
norm_train:train.py -c mv3_large_x0_5.yml -o quant_train=True
quant_train:False
prune_train:False
import paddle
import numpy as np
import os
import errno
import paddle.nn as nn
import paddle.distributed as dist
dist.get_world_size()
dist.init_parallel_env()
from loss import build_loss, LossDistill, DMLLoss, KLJSLoss
from optimizer import create_optimizer
from data_loader import build_dataloader
from metric import create_metric
from mv3 import MobileNetV3_large_x0_5, distillmv3_large_x0_5, build_model
from config import preprocess
import time
from paddleslim.dygraph.quant import QAT
from slim.slim_quant import PACT, quant_config
from slim.slim_fpgm import prune_model
from utils import load_model
def _mkdir_if_not_exist(path, logger):
"""
mkdir if not exists, ignore the exception when multiprocess mkdir together
"""
if not os.path.exists(path):
try:
os.makedirs(path)
except OSError as e:
if e.errno == errno.EEXIST and os.path.isdir(path):
logger.warning(
'be happy if some process has already created {}'.format(
path))
else:
raise OSError('Failed to mkdir {}'.format(path))
def save_model(model,
optimizer,
model_path,
logger,
is_best=False,
prefix='ppocr',
**kwargs):
"""
save model to the target path
"""
_mkdir_if_not_exist(model_path, logger)
model_prefix = os.path.join(model_path, prefix)
paddle.save(model.state_dict(), model_prefix + '.pdparams')
if type(optimizer) is list:
paddle.save(optimizer[0].state_dict(), model_prefix + '.pdopt')
paddle.save(optimizer[1].state_dict(), model_prefix + "_1" + '.pdopt')
else:
paddle.save(optimizer.state_dict(), model_prefix + '.pdopt')
# # save metric and config
# with open(model_prefix + '.states', 'wb') as f:
# pickle.dump(kwargs, f, protocol=2)
if is_best:
logger.info('save best model is to {}'.format(model_prefix))
else:
logger.info("save model in {}".format(model_prefix))
def amp_scaler(config):
if 'AMP' in config and config['AMP']['use_amp'] is True:
AMP_RELATED_FLAGS_SETTING = {
'FLAGS_cudnn_batchnorm_spatial_persistent': 1,
'FLAGS_max_inplace_grad_add': 8,
}
paddle.fluid.set_flags(AMP_RELATED_FLAGS_SETTING)
scale_loss = config["AMP"].get("scale_loss", 1.0)
use_dynamic_loss_scaling = config["AMP"].get("use_dynamic_loss_scaling",
False)
scaler = paddle.amp.GradScaler(
init_loss_scaling=scale_loss,
use_dynamic_loss_scaling=use_dynamic_loss_scaling)
return scaler
else:
return None
def set_seed(seed):
paddle.seed(seed)
np.random.seed(seed)
def train(config, scaler=None):
EPOCH = config['epoch']
topk = config['topk']
batch_size = config['TRAIN']['batch_size']
num_workers = config['TRAIN']['num_workers']
train_loader = build_dataloader(
'train', batch_size=batch_size, num_workers=num_workers)
# build metric
metric_func = create_metric
# build model
# model = MobileNetV3_large_x0_5(class_dim=100)
model = build_model(config)
# build_optimizer
optimizer, lr_scheduler = create_optimizer(
config, parameter_list=model.parameters())
# load model
pre_best_model_dict = load_model(config, model, optimizer)
if len(pre_best_model_dict) > 0:
        pre_str = 'The metrics of the loaded model are as follows: {}'.format(', '.join(
            ['{}: {}'.format(k, v) for k, v in pre_best_model_dict.items()]))
logger.info(pre_str)
# about slim prune and quant
if "quant_train" in config and config['quant_train'] is True:
quanter = QAT(config=quant_config, act_preprocess=PACT)
quanter.quantize(model)
elif "prune_train" in config and config['prune_train'] is True:
model = prune_model(model, [1, 3, 32, 32], 0.1)
else:
pass
# distribution
model.train()
model = paddle.DataParallel(model)
# build loss function
loss_func = build_loss(config)
data_num = len(train_loader)
best_acc = {}
for epoch in range(EPOCH):
st = time.time()
for idx, data in enumerate(train_loader):
img_batch, label = data
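            # images arrive as NHWC; transpose to NCHW before the forward pass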
img_batch = paddle.transpose(img_batch, [0, 3, 1, 2])
label = paddle.unsqueeze(label, -1)
if scaler is not None:
with paddle.amp.auto_cast():
outs = model(img_batch)
else:
outs = model(img_batch)
# cal metric
acc = metric_func(outs, label)
# cal loss
avg_loss = loss_func(outs, label)
if scaler is None:
# backward
avg_loss.backward()
optimizer.step()
optimizer.clear_grad()
else:
scaled_avg_loss = scaler.scale(avg_loss)
scaled_avg_loss.backward()
scaler.minimize(optimizer, scaled_avg_loss)
if not isinstance(lr_scheduler, float):
lr_scheduler.step()
if idx % 10 == 0:
et = time.time()
strs = f"epoch: [{epoch}/{EPOCH}], iter: [{idx}/{data_num}], "
strs += f"loss: {avg_loss.numpy()[0]}"
strs += f", acc_topk1: {acc['top1'].numpy()[0]}, acc_top5: {acc['top5'].numpy()[0]}"
strs += f", batch_time: {round(et-st, 4)} s"
logger.info(strs)
st = time.time()
if epoch % 10 == 0:
acc = eval(config, model)
if len(best_acc) < 1 or acc['top5'].numpy()[0] > best_acc['top5']:
best_acc = acc
best_acc['epoch'] = epoch
is_best = True
else:
is_best = False
logger.info(
f"The best acc: acc_topk1: {best_acc['top1'].numpy()[0]}, acc_top5: {best_acc['top5'].numpy()[0]}, best_epoch: {best_acc['epoch']}"
)
save_model(
model,
optimizer,
config['save_model_dir'],
logger,
is_best,
prefix="cls")
def train_distill(config, scaler=None):
EPOCH = config['epoch']
topk = config['topk']
batch_size = config['TRAIN']['batch_size']
num_workers = config['TRAIN']['num_workers']
train_loader = build_dataloader(
'train', batch_size=batch_size, num_workers=num_workers)
# build metric
metric_func = create_metric
# model = distillmv3_large_x0_5(class_dim=100)
model = build_model(config)
# pact quant train
if "quant_train" in config and config['quant_train'] is True:
quanter = QAT(config=quant_config, act_preprocess=PACT)
quanter.quantize(model)
elif "prune_train" in config and config['prune_train'] is True:
model = prune_model(model, [1, 3, 32, 32], 0.1)
else:
pass
# build_optimizer
optimizer, lr_scheduler = create_optimizer(
config, parameter_list=model.parameters())
# load model
pre_best_model_dict = load_model(config, model, optimizer)
if len(pre_best_model_dict) > 0:
        pre_str = 'The metrics of the loaded model are as follows: {}'.format(', '.join(
            ['{}: {}'.format(k, v) for k, v in pre_best_model_dict.items()]))
logger.info(pre_str)
model.train()
model = paddle.DataParallel(model)
# build loss function
loss_func_distill = LossDistill(model_name_list=['student', 'student1'])
loss_func_dml = DMLLoss(model_name_pairs=['student', 'student1'])
loss_func_js = KLJSLoss(mode='js')
data_num = len(train_loader)
best_acc = {}
for epoch in range(EPOCH):
st = time.time()
for idx, data in enumerate(train_loader):
img_batch, label = data
img_batch = paddle.transpose(img_batch, [0, 3, 1, 2])
label = paddle.unsqueeze(label, -1)
if scaler is not None:
with paddle.amp.auto_cast():
outs = model(img_batch)
else:
outs = model(img_batch)
# cal metric
acc = metric_func(outs['student'], label)
# cal loss
avg_loss = loss_func_distill(outs, label)['student'] + \
loss_func_distill(outs, label)['student1'] + \
loss_func_dml(outs, label)['student_student1']
# backward
if scaler is None:
avg_loss.backward()
optimizer.step()
optimizer.clear_grad()
else:
scaled_avg_loss = scaler.scale(avg_loss)
scaled_avg_loss.backward()
scaler.minimize(optimizer, scaled_avg_loss)
if not isinstance(lr_scheduler, float):
lr_scheduler.step()
if idx % 10 == 0:
et = time.time()
strs = f"epoch: [{epoch}/{EPOCH}], iter: [{idx}/{data_num}], "
strs += f"loss: {avg_loss.numpy()[0]}"
strs += f", acc_topk1: {acc['top1'].numpy()[0]}, acc_top5: {acc['top5'].numpy()[0]}"
strs += f", batch_time: {round(et-st, 4)} s"
logger.info(strs)
st = time.time()
if epoch % 10 == 0:
acc = eval(config, model._layers.student)
if len(best_acc) < 1 or acc['top5'].numpy()[0] > best_acc['top5']:
best_acc = acc
best_acc['epoch'] = epoch
is_best = True
else:
is_best = False
logger.info(
f"The best acc: acc_topk1: {best_acc['top1'].numpy()[0]}, acc_top5: {best_acc['top5'].numpy()[0]}, best_epoch: {best_acc['epoch']}"
)
save_model(
model,
optimizer,
config['save_model_dir'],
logger,
is_best,
prefix="cls_distill")
def train_distill_multiopt(config, scaler=None):
EPOCH = config['epoch']
topk = config['topk']
batch_size = config['TRAIN']['batch_size']
num_workers = config['TRAIN']['num_workers']
train_loader = build_dataloader(
'train', batch_size=batch_size, num_workers=num_workers)
# build metric
metric_func = create_metric
# model = distillmv3_large_x0_5(class_dim=100)
model = build_model(config)
# build_optimizer
optimizer, lr_scheduler = create_optimizer(
config, parameter_list=model.student.parameters())
optimizer1, lr_scheduler1 = create_optimizer(
config, parameter_list=model.student1.parameters())
# load model
pre_best_model_dict = load_model(config, model, optimizer)
if len(pre_best_model_dict) > 0:
        pre_str = 'The metrics of the loaded model are as follows: {}'.format(', '.join(
            ['{}: {}'.format(k, v) for k, v in pre_best_model_dict.items()]))
logger.info(pre_str)
# quant train
if "quant_train" in config and config['quant_train'] is True:
quanter = QAT(config=quant_config, act_preprocess=PACT)
quanter.quantize(model)
elif "prune_train" in config and config['prune_train'] is True:
model = prune_model(model, [1, 3, 32, 32], 0.1)
else:
pass
model.train()
model = paddle.DataParallel(model)
# build loss function
loss_func_distill = LossDistill(model_name_list=['student', 'student1'])
loss_func_dml = DMLLoss(model_name_pairs=['student', 'student1'])
loss_func_js = KLJSLoss(mode='js')
data_num = len(train_loader)
best_acc = {}
for epoch in range(EPOCH):
st = time.time()
for idx, data in enumerate(train_loader):
img_batch, label = data
img_batch = paddle.transpose(img_batch, [0, 3, 1, 2])
label = paddle.unsqueeze(label, -1)
if scaler is not None:
with paddle.amp.auto_cast():
outs = model(img_batch)
else:
outs = model(img_batch)
# cal metric
acc = metric_func(outs['student'], label)
# cal loss
avg_loss = loss_func_distill(outs,
label)['student'] + loss_func_dml(
outs, label)['student_student1']
avg_loss1 = loss_func_distill(outs,
label)['student1'] + loss_func_dml(
outs, label)['student_student1']
if scaler is None:
# backward
avg_loss.backward(retain_graph=True)
optimizer.step()
optimizer.clear_grad()
avg_loss1.backward()
optimizer1.step()
optimizer1.clear_grad()
else:
scaled_avg_loss = scaler.scale(avg_loss)
scaled_avg_loss.backward()
scaler.minimize(optimizer, scaled_avg_loss)
scaled_avg_loss = scaler.scale(avg_loss1)
scaled_avg_loss.backward()
scaler.minimize(optimizer1, scaled_avg_loss)
if not isinstance(lr_scheduler, float):
lr_scheduler.step()
if not isinstance(lr_scheduler1, float):
lr_scheduler1.step()
if idx % 10 == 0:
et = time.time()
strs = f"epoch: [{epoch}/{EPOCH}], iter: [{idx}/{data_num}], "
strs += f"loss: {avg_loss.numpy()[0]}, loss1: {avg_loss1.numpy()[0]}"
strs += f", acc_topk1: {acc['top1'].numpy()[0]}, acc_top5: {acc['top5'].numpy()[0]}"
strs += f", batch_time: {round(et-st, 4)} s"
logger.info(strs)
st = time.time()
if epoch % 10 == 0:
acc = eval(config, model._layers.student)
if len(best_acc) < 1 or acc['top5'].numpy()[0] > best_acc['top5']:
best_acc = acc
best_acc['epoch'] = epoch
is_best = True
else:
is_best = False
logger.info(
f"The best acc: acc_topk1: {best_acc['top1'].numpy()[0]}, acc_top5: {best_acc['top5'].numpy()[0]}, best_epoch: {best_acc['epoch']}"
)
save_model(
model, [optimizer, optimizer1],
config['save_model_dir'],
logger,
is_best,
prefix="cls_distill_multiopt")
def eval(config, model):
batch_size = config['VALID']['batch_size']
num_workers = config['VALID']['num_workers']
valid_loader = build_dataloader(
'test', batch_size=batch_size, num_workers=num_workers)
# build metric
metric_func = create_metric
outs = []
labels = []
for idx, data in enumerate(valid_loader):
img_batch, label = data
img_batch = paddle.transpose(img_batch, [0, 3, 1, 2])
label = paddle.unsqueeze(label, -1)
out = model(img_batch)
outs.append(out)
labels.append(label)
outs = paddle.concat(outs, axis=0)
labels = paddle.concat(labels, axis=0)
acc = metric_func(outs, labels)
strs = f"The metric are as follows: acc_topk1: {acc['top1'].numpy()[0]}, acc_top5: {acc['top5'].numpy()[0]}"
logger.info(strs)
return acc
if __name__ == "__main__":
config, logger = preprocess(is_train=False)
# AMP scaler
scaler = amp_scaler(config)
model_type = config['model_type']
    if model_type == "cls":
        train(config, scaler)
    elif model_type == "cls_distill":
        train_distill(config, scaler)
    elif model_type == "cls_distill_multiopt":
        train_distill_multiopt(config, scaler)
    else:
        raise ValueError(
            "model_type should be one of ['cls', 'cls_distill', 'cls_distill_multiopt']")
# single GPU
python3.7 train.py -c mv3_large_x0_5.yml
# distribute training
python3.7 -m paddle.distributed.launch --log_dir=./debug/ --gpus '0,1' train.py -c mv3_large_x0_5.yml
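# The train_params blocks above also exercise the slim modes through "-o" overrides;
# a sketch of those variants (same flags as in the configs):
# quant-aware training (PACT)
python3.7 train.py -c mv3_large_x0_5.yml -o quant_train=True
# FPGM channel pruning
python3.7 train.py -c mv3_large_x0_5.yml -o prune_train=True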
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import sys
import pickle
import logging
import functools
import six
import paddle
import paddle.distributed as dist
logger_initialized = {}
def print_dict(d, logger, delimiter=0):
"""
    Recursively visualize a dict and
    indent according to the relationship of keys.
"""
for k, v in sorted(d.items()):
if isinstance(v, dict):
logger.info("{}{} : ".format(delimiter * " ", str(k)))
print_dict(v, logger, delimiter + 4)
elif isinstance(v, list) and len(v) >= 1 and isinstance(v[0], dict):
logger.info("{}{} : ".format(delimiter * " ", str(k)))
for value in v:
print_dict(value, logger, delimiter + 4)
else:
logger.info("{}{} : {}".format(delimiter * " ", k, v))
@functools.lru_cache()
def get_logger(name='root', log_file=None, log_level=logging.DEBUG):
"""Initialize and get a logger by name.
If the logger has not been initialized, this method will initialize the
logger by adding one or two handlers, otherwise the initialized logger will
be directly returned. During initialization, a StreamHandler will always be
added. If `log_file` is specified a FileHandler will also be added.
Args:
name (str): Logger name.
log_file (str | None): The log filename. If specified, a FileHandler
will be added to the logger.
        log_level (int): The logger level. Note that only the process of
            rank 0 is affected; other processes have their level set to
            ERROR and are thus silent most of the time.
Returns:
logging.Logger: The expected logger.
"""
logger = logging.getLogger(name)
if name in logger_initialized:
return logger
for logger_name in logger_initialized:
if name.startswith(logger_name):
return logger
formatter = logging.Formatter(
'[%(asctime)s] %(name)s %(levelname)s: %(message)s',
datefmt="%Y/%m/%d %H:%M:%S")
stream_handler = logging.StreamHandler(stream=sys.stdout)
stream_handler.setFormatter(formatter)
logger.addHandler(stream_handler)
if log_file is not None and dist.get_rank() == 0:
log_file_folder = os.path.split(log_file)[0]
os.makedirs(log_file_folder, exist_ok=True)
file_handler = logging.FileHandler(log_file, 'a')
file_handler.setFormatter(formatter)
logger.addHandler(file_handler)
if dist.get_rank() == 0:
logger.setLevel(log_level)
else:
logger.setLevel(logging.ERROR)
logger_initialized[name] = True
return logger
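# Usage sketch (the log path is an assumption): rank 0 logs at `log_level` to stdout and,
# when `log_file` is given, to that file; all other ranks are raised to ERROR and stay quiet.
# logger = get_logger(name='root', log_file='./output/train.log')
# logger.info('training started')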
def load_model(config, model, optimizer=None):
"""
load model from checkpoint or pretrained_model
"""
logger = get_logger()
checkpoints = config.get('checkpoints')
pretrained_model = config.get('pretrained_model')
best_model_dict = {}
if checkpoints:
if checkpoints.endswith('.pdparams'):
checkpoints = checkpoints.replace('.pdparams', '')
        assert os.path.exists(checkpoints + ".pdparams"), \
            "The {}.pdparams does not exist!".format(checkpoints)
# load params from trained model
params = paddle.load(checkpoints + '.pdparams')
state_dict = model.state_dict()
new_state_dict = {}
for key, value in state_dict.items():
if key not in params:
logger.warning("{} not in loaded params {} !".format(
key, params.keys()))
continue
pre_value = params[key]
if list(value.shape) == list(pre_value.shape):
new_state_dict[key] = pre_value
else:
                logger.warning(
                    "The shape of model param {} {} does not match the loaded param shape {}!".
                    format(key, value.shape, pre_value.shape))
model.set_state_dict(new_state_dict)
if optimizer is not None:
if os.path.exists(checkpoints + '.pdopt'):
optim_dict = paddle.load(checkpoints + '.pdopt')
optimizer.set_state_dict(optim_dict)
else:
                logger.warning(
                    "{}.pdopt does not exist, optimizer params are not loaded".
                    format(checkpoints))
if os.path.exists(checkpoints + '.states'):
with open(checkpoints + '.states', 'rb') as f:
states_dict = pickle.load(f) if six.PY2 else pickle.load(
f, encoding='latin1')
best_model_dict = states_dict.get('best_model_dict', {})
if 'epoch' in states_dict:
best_model_dict['start_epoch'] = states_dict['epoch'] + 1
logger.info("resume from {}".format(checkpoints))
elif pretrained_model:
load_pretrained_params(model, pretrained_model)
else:
logger.info('train from scratch')
return best_model_dict
def load_pretrained_params(model, path):
logger = get_logger()
if path.endswith('.pdparams'):
path = path.replace('.pdparams', '')
    assert os.path.exists(path + ".pdparams"), \
        "The {}.pdparams does not exist!".format(path)
params = paddle.load(path + '.pdparams')
state_dict = model.state_dict()
new_state_dict = {}
for k1 in params.keys():
if k1 not in state_dict.keys():
logger.warning("The pretrained params {} not in model".format(k1))
else:
if list(state_dict[k1].shape) == list(params[k1].shape):
new_state_dict[k1] = params[k1]
else:
                logger.warning(
                    "The shape of model param {} {} does not match the loaded param {} {}!".
                    format(k1, state_dict[k1].shape, k1, params[k1].shape))
model.set_state_dict(new_state_dict)
logger.info("load pretrain successful from {}".format(path))
return model
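# Usage sketch (paths are assumptions): resume training from a checkpoint prefix, or
# warm-start from pretrained weights; only the 'checkpoints' / 'pretrained_model' keys
# of `config` are read here.
# best_model_dict = load_model({'checkpoints': './output/cls/ppocr'}, model, optimizer)
# model = load_pretrained_params(model, './pretrain/mv3_large_x0_5.pdparams')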
...@@ -183,7 +183,7 @@ function func_inference(){ ...@@ -183,7 +183,7 @@ function func_inference(){
if [[ ${precision} =~ "fp16" || ${precision} =~ "int8" ]] && [ ${use_trt} = "False" ]; then if [[ ${precision} =~ "fp16" || ${precision} =~ "int8" ]] && [ ${use_trt} = "False" ]; then
continue continue
fi fi
if [[ ${use_trt} = "False" || ${precision} =~ "int8" ]] && [ ${_flag_quant} = "True" ]; then if [[ ${use_trt} = "False" && ${precision} =~ "int8" ]] && [ ${_flag_quant} = "True" ]; then
continue continue
fi fi
for batch_size in ${batch_size_list[*]}; do for batch_size in ${batch_size_list[*]}; do
...@@ -227,7 +227,12 @@ if [ ${MODE} = "whole_infer" ] || [ ${MODE} = "klquant_whole_infer" ]; then ...@@ -227,7 +227,12 @@ if [ ${MODE} = "whole_infer" ] || [ ${MODE} = "klquant_whole_infer" ]; then
for infer_model in ${infer_model_dir_list[*]}; do for infer_model in ${infer_model_dir_list[*]}; do
# run export # run export
if [ ${infer_run_exports[Count]} != "null" ];then if [ ${infer_run_exports[Count]} != "null" ];then
save_infer_dir=$(dirname $infer_model) if [ ${MODE} = "klquant_whole_infer" ]; then
save_infer_dir="${infer_model}_klquant"
fi
if [ ${MODE} = "whole_infer" ]; then
save_infer_dir="${infer_model}"
fi
set_export_weight=$(func_set_params "${export_weight}" "${infer_model}") set_export_weight=$(func_set_params "${export_weight}" "${infer_model}")
set_save_infer_key=$(func_set_params "${save_infer_key}" "${save_infer_dir}") set_save_infer_key=$(func_set_params "${save_infer_key}" "${save_infer_dir}")
export_cmd="${python} ${infer_run_exports[Count]} ${set_export_weight} ${set_save_infer_key}" export_cmd="${python} ${infer_run_exports[Count]} ${set_export_weight} ${set_save_infer_key}"
......
...@@ -61,7 +61,8 @@ def main(): ...@@ -61,7 +61,8 @@ def main():
else: else:
model_type = None model_type = None
best_model_dict = load_model(config, model) best_model_dict = load_model(
config, model, model_type=config['Architecture']["model_type"])
if len(best_model_dict): if len(best_model_dict):
logger.info('metric in ckpt ***************') logger.info('metric in ckpt ***************')
for k, v in best_model_dict.items(): for k, v in best_model_dict.items():
......
...@@ -85,7 +85,7 @@ def export_single_model(model, arch_config, save_path, logger): ...@@ -85,7 +85,7 @@ def export_single_model(model, arch_config, save_path, logger):
def main(): def main():
FLAGS = ArgsParser().parse_args() FLAGS = ArgsParser().parse_args()
config = load_config(FLAGS.config) config = load_config(FLAGS.config)
merge_config(FLAGS.opt) config = merge_config(config, FLAGS.opt)
logger = get_logger() logger = get_logger()
# build post process # build post process
......
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
import os
import sys
__dir__ = os.path.dirname(os.path.abspath(__file__))
sys.path.append(__dir__)
sys.path.append(os.path.abspath(os.path.join(__dir__, '..')))
os.environ["FLAGS_allocator_strategy"] = 'auto_growth'
import cv2
import json
import paddle
from ppocr.data import create_operators, transform
from ppocr.modeling.architectures import build_model
from ppocr.postprocess import build_post_process
from ppocr.utils.save_load import load_model
from ppocr.utils.visual import draw_ser_results
from ppocr.utils.utility import get_image_file_list, load_vqa_bio_label_maps
import tools.program as program
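# to_tensor (below) wraps every field of one preprocessed sample in a length-1 list and
# converts the ndarray/number fields to paddle Tensors, so the model sees a batch of size one.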
def to_tensor(data):
import numbers
from collections import defaultdict
data_dict = defaultdict(list)
to_tensor_idxs = []
for idx, v in enumerate(data):
if isinstance(v, (np.ndarray, paddle.Tensor, numbers.Number)):
if idx not in to_tensor_idxs:
to_tensor_idxs.append(idx)
data_dict[idx].append(v)
for idx in to_tensor_idxs:
data_dict[idx] = paddle.to_tensor(data_dict[idx])
return list(data_dict.values())
class SerPredictor(object):
def __init__(self, config):
global_config = config['Global']
# build post process
self.post_process_class = build_post_process(config['PostProcess'],
global_config)
# build model
self.model = build_model(config['Architecture'])
load_model(
config, self.model, model_type=config['Architecture']["model_type"])
from paddleocr import PaddleOCR
self.ocr_engine = PaddleOCR(use_angle_cls=False, show_log=False)
# create data ops
transforms = []
for op in config['Eval']['dataset']['transforms']:
op_name = list(op)[0]
if 'Label' in op_name:
op[op_name]['ocr_engine'] = self.ocr_engine
elif op_name == 'KeepKeys':
op[op_name]['keep_keys'] = [
'input_ids', 'labels', 'bbox', 'image', 'attention_mask',
'token_type_ids', 'segment_offset_id', 'ocr_info',
'entities'
]
transforms.append(op)
global_config['infer_mode'] = True
self.ops = create_operators(config['Eval']['dataset']['transforms'],
global_config)
self.model.eval()
def __call__(self, img_path):
with open(img_path, 'rb') as f:
img = f.read()
data = {'image': img}
batch = transform(data, self.ops)
batch = to_tensor(batch)
preds = self.model(batch)
post_result = self.post_process_class(
preds,
attention_masks=batch[4],
segment_offset_ids=batch[6],
ocr_infos=batch[7])
return post_result, batch
if __name__ == '__main__':
config, device, logger, vdl_writer = program.preprocess()
os.makedirs(config['Global']['save_res_path'], exist_ok=True)
ser_engine = SerPredictor(config)
infer_imgs = get_image_file_list(config['Global']['infer_img'])
with open(
os.path.join(config['Global']['save_res_path'],
"infer_results.txt"),
"w",
encoding='utf-8') as fout:
for idx, img_path in enumerate(infer_imgs):
save_img_path = os.path.join(
config['Global']['save_res_path'],
os.path.splitext(os.path.basename(img_path))[0] + "_ser.jpg")
logger.info("process: [{}/{}], save result to {}".format(
idx, len(infer_imgs), save_img_path))
result, _ = ser_engine(img_path)
result = result[0]
fout.write(img_path + "\t" + json.dumps(
{
"ocr_info": result,
}, ensure_ascii=False) + "\n")
img_res = draw_ser_results(img_path, result)
cv2.imwrite(save_img_path, img_res)
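# Usage sketch (the config path and image are assumptions, not taken from this commit):
# python3 tools/infer_vqa_token_ser.py -c configs/vqa/ser/layoutxlm.yml \
#     -o Global.infer_img=doc/vqa/input/zh_val_0.jpg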
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
import os
import sys
__dir__ = os.path.dirname(os.path.abspath(__file__))
sys.path.append(__dir__)
sys.path.append(os.path.abspath(os.path.join(__dir__, '..')))
os.environ["FLAGS_allocator_strategy"] = 'auto_growth'
import cv2
import json
import paddle
import paddle.distributed as dist
from ppocr.data import create_operators, transform
from ppocr.modeling.architectures import build_model
from ppocr.postprocess import build_post_process
from ppocr.utils.save_load import load_model
from ppocr.utils.visual import draw_re_results
from ppocr.utils.logging import get_logger
from ppocr.utils.utility import get_image_file_list, load_vqa_bio_label_maps, print_dict
from tools.program import ArgsParser, load_config, merge_config, check_gpu
from tools.infer_vqa_token_ser import SerPredictor
class ReArgsParser(ArgsParser):
def __init__(self):
super(ReArgsParser, self).__init__()
self.add_argument(
"-c_ser", "--config_ser", help="ser configuration file to use")
self.add_argument(
"-o_ser",
"--opt_ser",
nargs='+',
help="set ser configuration options ")
def parse_args(self, argv=None):
args = super(ReArgsParser, self).parse_args(argv)
assert args.config_ser is not None, \
"Please specify --config_ser=ser_configure_file_path."
args.opt_ser = self._parse_opt(args.opt_ser)
return args
def make_input(ser_inputs, ser_results):
entities_labels = {'HEADER': 0, 'QUESTION': 1, 'ANSWER': 2}
entities = ser_inputs[8][0]
ser_results = ser_results[0]
assert len(entities) == len(ser_results)
# entities
start = []
end = []
label = []
entity_idx_dict = {}
for i, (res, entity) in enumerate(zip(ser_results, entities)):
if res['pred'] == 'O':
continue
entity_idx_dict[len(start)] = i
start.append(entity['start'])
end.append(entity['end'])
label.append(entities_labels[res['pred']])
entities = dict(start=start, end=end, label=label)
# relations
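    # every QUESTION entity (label 1) is paired with every ANSWER entity (label 2) as a relation candidate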
head = []
tail = []
for i in range(len(entities["label"])):
for j in range(len(entities["label"])):
if entities["label"][i] == 1 and entities["label"][j] == 2:
head.append(i)
tail.append(j)
relations = dict(head=head, tail=tail)
batch_size = ser_inputs[0].shape[0]
entities_batch = []
relations_batch = []
entity_idx_dict_batch = []
for b in range(batch_size):
entities_batch.append(entities)
relations_batch.append(relations)
entity_idx_dict_batch.append(entity_idx_dict)
ser_inputs[8] = entities_batch
ser_inputs.append(relations_batch)
# remove ocr_info segment_offset_id and label in ser input
ser_inputs.pop(7)
ser_inputs.pop(6)
ser_inputs.pop(1)
return ser_inputs, entity_idx_dict_batch
class SerRePredictor(object):
def __init__(self, config, ser_config):
self.ser_engine = SerPredictor(ser_config)
# init re model
global_config = config['Global']
# build post process
self.post_process_class = build_post_process(config['PostProcess'],
global_config)
# build model
self.model = build_model(config['Architecture'])
load_model(
config, self.model, model_type=config['Architecture']["model_type"])
self.model.eval()
def __call__(self, img_path):
ser_results, ser_inputs = self.ser_engine(img_path)
paddle.save(ser_inputs, 'ser_inputs.npy')
paddle.save(ser_results, 'ser_results.npy')
re_input, entity_idx_dict_batch = make_input(ser_inputs, ser_results)
preds = self.model(re_input)
post_result = self.post_process_class(
preds,
ser_results=ser_results,
entity_idx_dict_batch=entity_idx_dict_batch)
return post_result
def preprocess():
FLAGS = ReArgsParser().parse_args()
config = load_config(FLAGS.config)
config = merge_config(config, FLAGS.opt)
ser_config = load_config(FLAGS.config_ser)
ser_config = merge_config(ser_config, FLAGS.opt_ser)
logger = get_logger(name='root')
    # check whether use_gpu=True is set while running the CPU build of paddlepaddle
use_gpu = config['Global']['use_gpu']
check_gpu(use_gpu)
device = 'gpu:{}'.format(dist.ParallelEnv().dev_id) if use_gpu else 'cpu'
device = paddle.set_device(device)
logger.info('{} re config {}'.format('*' * 10, '*' * 10))
print_dict(config, logger)
logger.info('\n')
logger.info('{} ser config {}'.format('*' * 10, '*' * 10))
print_dict(ser_config, logger)
logger.info('train with paddle {} and device {}'.format(paddle.__version__,
device))
return config, ser_config, device, logger
if __name__ == '__main__':
config, ser_config, device, logger = preprocess()
os.makedirs(config['Global']['save_res_path'], exist_ok=True)
ser_re_engine = SerRePredictor(config, ser_config)
infer_imgs = get_image_file_list(config['Global']['infer_img'])
with open(
os.path.join(config['Global']['save_res_path'],
"infer_results.txt"),
"w",
encoding='utf-8') as fout:
for idx, img_path in enumerate(infer_imgs):
save_img_path = os.path.join(
config['Global']['save_res_path'],
os.path.splitext(os.path.basename(img_path))[0] + "_ser.jpg")
logger.info("process: [{}/{}], save result to {}".format(
idx, len(infer_imgs), save_img_path))
result = ser_re_engine(img_path)
result = result[0]
fout.write(img_path + "\t" + json.dumps(
{
"ser_resule": result,
}, ensure_ascii=False) + "\n")
img_res = draw_re_results(img_path, result)
cv2.imwrite(save_img_path, img_res)
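# Usage sketch (script and config paths are assumptions): the RE predictor needs a RE config
# via -c and a SER config via -c_ser, as required by ReArgsParser above.
# python3 tools/infer_vqa_token_ser_re.py -c configs/vqa/re/layoutxlm.yml \
#     -c_ser configs/vqa/ser/layoutxlm.yml -o Global.infer_img=doc/vqa/input/zh_val_21.jpg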
...@@ -69,24 +69,6 @@ class ArgsParser(ArgumentParser): ...@@ -69,24 +69,6 @@ class ArgsParser(ArgumentParser):
return config return config
class AttrDict(dict):
"""Single level attribute dict, NOT recursive"""
def __init__(self, **kwargs):
super(AttrDict, self).__init__()
super(AttrDict, self).update(kwargs)
def __getattr__(self, key):
if key in self:
return self[key]
raise AttributeError("object has no attribute '{}'".format(key))
global_config = AttrDict()
default_config = {'Global': {'debug': False, }}
def load_config(file_path): def load_config(file_path):
""" """
Load config from yml/yaml file. Load config from yml/yaml file.
...@@ -94,38 +76,38 @@ def load_config(file_path): ...@@ -94,38 +76,38 @@ def load_config(file_path):
file_path (str): Path of the config file to be loaded. file_path (str): Path of the config file to be loaded.
Returns: global config Returns: global config
""" """
merge_config(default_config)
_, ext = os.path.splitext(file_path) _, ext = os.path.splitext(file_path)
assert ext in ['.yml', '.yaml'], "only support yaml files for now" assert ext in ['.yml', '.yaml'], "only support yaml files for now"
merge_config(yaml.load(open(file_path, 'rb'), Loader=yaml.Loader)) config = yaml.load(open(file_path, 'rb'), Loader=yaml.Loader)
return global_config return config
def merge_config(config): def merge_config(config, opts):
""" """
Merge config into global config. Merge config into global config.
Args: Args:
config (dict): Config to be merged. config (dict): Config to be merged.
Returns: global config Returns: global config
""" """
for key, value in config.items(): for key, value in opts.items():
if "." not in key: if "." not in key:
if isinstance(value, dict) and key in global_config: if isinstance(value, dict) and key in config:
global_config[key].update(value) config[key].update(value)
else: else:
global_config[key] = value config[key] = value
else: else:
sub_keys = key.split('.') sub_keys = key.split('.')
assert ( assert (
sub_keys[0] in global_config sub_keys[0] in config
), "the sub_keys can only be one of global_config: {}, but get: {}, please check your running command".format( ), "the sub_keys can only be one of global_config: {}, but get: {}, please check your running command".format(
global_config.keys(), sub_keys[0]) config.keys(), sub_keys[0])
cur = global_config[sub_keys[0]] cur = config[sub_keys[0]]
for idx, sub_key in enumerate(sub_keys[1:]): for idx, sub_key in enumerate(sub_keys[1:]):
if idx == len(sub_keys) - 2: if idx == len(sub_keys) - 2:
cur[sub_key] = value cur[sub_key] = value
else: else:
cur = cur[sub_key] cur = cur[sub_key]
return config
def check_gpu(use_gpu): def check_gpu(use_gpu):
...@@ -204,20 +186,24 @@ def train(config, ...@@ -204,20 +186,24 @@ def train(config,
model_type = None model_type = None
algorithm = config['Architecture']['algorithm'] algorithm = config['Architecture']['algorithm']
if 'start_epoch' in best_model_dict: start_epoch = best_model_dict[
start_epoch = best_model_dict['start_epoch'] 'start_epoch'] if 'start_epoch' in best_model_dict else 1
else:
start_epoch = 1 train_reader_cost = 0.0
train_run_cost = 0.0
total_samples = 0
reader_start = time.time()
max_iter = len(train_dataloader) - 1 if platform.system(
) == "Windows" else len(train_dataloader)
for epoch in range(start_epoch, epoch_num + 1): for epoch in range(start_epoch, epoch_num + 1):
train_dataloader = build_dataloader( if train_dataloader.dataset.need_reset:
config, 'Train', device, logger, seed=epoch) train_dataloader = build_dataloader(
train_reader_cost = 0.0 config, 'Train', device, logger, seed=epoch)
train_run_cost = 0.0 max_iter = len(train_dataloader) - 1 if platform.system(
total_samples = 0 ) == "Windows" else len(train_dataloader)
reader_start = time.time()
max_iter = len(train_dataloader) - 1 if platform.system(
) == "Windows" else len(train_dataloader)
for idx, batch in enumerate(train_dataloader): for idx, batch in enumerate(train_dataloader):
profiler.add_profiler_step(profiler_options) profiler.add_profiler_step(profiler_options)
train_reader_cost += time.time() - reader_start train_reader_cost += time.time() - reader_start
...@@ -239,10 +225,11 @@ def train(config, ...@@ -239,10 +225,11 @@ def train(config,
else: else:
if model_type == 'table' or extra_input: if model_type == 'table' or extra_input:
preds = model(images, data=batch[1:]) preds = model(images, data=batch[1:])
elif model_type == "kie": elif model_type in ["kie", 'vqa']:
preds = model(batch) preds = model(batch)
else: else:
preds = model(images) preds = model(images)
loss = loss_class(preds, batch) loss = loss_class(preds, batch)
avg_loss = loss['loss'] avg_loss = loss['loss']
...@@ -256,6 +243,7 @@ def train(config, ...@@ -256,6 +243,7 @@ def train(config,
optimizer.clear_grad() optimizer.clear_grad()
train_run_cost += time.time() - train_start train_run_cost += time.time() - train_start
global_step += 1
total_samples += len(images) total_samples += len(images)
if not isinstance(lr_scheduler, float): if not isinstance(lr_scheduler, float):
...@@ -285,12 +273,13 @@ def train(config, ...@@ -285,12 +273,13 @@ def train(config,
(global_step > 0 and global_step % print_batch_step == 0) or (global_step > 0 and global_step % print_batch_step == 0) or
(idx >= len(train_dataloader) - 1)): (idx >= len(train_dataloader) - 1)):
logs = train_stats.log() logs = train_stats.log()
strs = 'epoch: [{}/{}], iter: {}, {}, reader_cost: {:.5f} s, batch_cost: {:.5f} s, samples: {}, ips: {:.5f}'.format( strs = 'epoch: [{}/{}], global_step: {}, {}, avg_reader_cost: {:.5f} s, avg_batch_cost: {:.5f} s, avg_samples: {}, ips: {:.5f}'.format(
epoch, epoch_num, global_step, logs, train_reader_cost / epoch, epoch_num, global_step, logs, train_reader_cost /
print_batch_step, (train_reader_cost + train_run_cost) / print_batch_step, (train_reader_cost + train_run_cost) /
print_batch_step, total_samples, print_batch_step, total_samples / print_batch_step,
total_samples / (train_reader_cost + train_run_cost)) total_samples / (train_reader_cost + train_run_cost))
logger.info(strs) logger.info(strs)
train_reader_cost = 0.0 train_reader_cost = 0.0
train_run_cost = 0.0 train_run_cost = 0.0
total_samples = 0 total_samples = 0
...@@ -330,6 +319,7 @@ def train(config, ...@@ -330,6 +319,7 @@ def train(config,
optimizer, optimizer,
save_model_dir, save_model_dir,
logger, logger,
config,
is_best=True, is_best=True,
prefix='best_accuracy', prefix='best_accuracy',
best_model_dict=best_model_dict, best_model_dict=best_model_dict,
...@@ -344,8 +334,7 @@ def train(config, ...@@ -344,8 +334,7 @@ def train(config,
vdl_writer.add_scalar('EVAL/best_{}'.format(main_indicator), vdl_writer.add_scalar('EVAL/best_{}'.format(main_indicator),
best_model_dict[main_indicator], best_model_dict[main_indicator],
global_step) global_step)
global_step += 1
optimizer.clear_grad()
reader_start = time.time() reader_start = time.time()
if dist.get_rank() == 0: if dist.get_rank() == 0:
save_model( save_model(
...@@ -353,6 +342,7 @@ def train(config, ...@@ -353,6 +342,7 @@ def train(config,
optimizer, optimizer,
save_model_dir, save_model_dir,
logger, logger,
config,
is_best=False, is_best=False,
prefix='latest', prefix='latest',
best_model_dict=best_model_dict, best_model_dict=best_model_dict,
...@@ -364,6 +354,7 @@ def train(config, ...@@ -364,6 +354,7 @@ def train(config,
optimizer, optimizer,
save_model_dir, save_model_dir,
logger, logger,
config,
is_best=False, is_best=False,
prefix='iter_epoch_{}'.format(epoch), prefix='iter_epoch_{}'.format(epoch),
best_model_dict=best_model_dict, best_model_dict=best_model_dict,
...@@ -401,19 +392,28 @@ def eval(model, ...@@ -401,19 +392,28 @@ def eval(model,
start = time.time() start = time.time()
if model_type == 'table' or extra_input: if model_type == 'table' or extra_input:
preds = model(images, data=batch[1:]) preds = model(images, data=batch[1:])
elif model_type == "kie": elif model_type in ["kie", 'vqa']:
preds = model(batch) preds = model(batch)
else: else:
preds = model(images) preds = model(images)
batch = [item.numpy() for item in batch]
batch_numpy = []
for item in batch:
if isinstance(item, paddle.Tensor):
batch_numpy.append(item.numpy())
else:
batch_numpy.append(item)
# Obtain usable results from post-processing methods # Obtain usable results from post-processing methods
total_time += time.time() - start total_time += time.time() - start
# Evaluate the results of the current batch # Evaluate the results of the current batch
if model_type in ['table', 'kie']: if model_type in ['table', 'kie']:
eval_class(preds, batch) eval_class(preds, batch_numpy)
elif model_type in ['vqa']:
post_result = post_process_class(preds, batch_numpy)
eval_class(post_result, batch_numpy)
else: else:
post_result = post_process_class(preds, batch[1]) post_result = post_process_class(preds, batch_numpy[1])
eval_class(post_result, batch) eval_class(post_result, batch_numpy)
pbar.update(1) pbar.update(1)
total_frame += len(images) total_frame += len(images)
...@@ -479,9 +479,9 @@ def preprocess(is_train=False): ...@@ -479,9 +479,9 @@ def preprocess(is_train=False):
FLAGS = ArgsParser().parse_args() FLAGS = ArgsParser().parse_args()
profiler_options = FLAGS.profiler_options profiler_options = FLAGS.profiler_options
config = load_config(FLAGS.config) config = load_config(FLAGS.config)
merge_config(FLAGS.opt) config = merge_config(config, FLAGS.opt)
profile_dic = {"profiler_options": FLAGS.profiler_options} profile_dic = {"profiler_options": FLAGS.profiler_options}
merge_config(profile_dic) config = merge_config(config, profile_dic)
if is_train: if is_train:
# save_config # save_config
...@@ -503,20 +503,15 @@ def preprocess(is_train=False): ...@@ -503,20 +503,15 @@ def preprocess(is_train=False):
assert alg in [ assert alg in [
'EAST', 'DB', 'SAST', 'Rosetta', 'CRNN', 'STARNet', 'RARE', 'SRN', 'EAST', 'DB', 'SAST', 'Rosetta', 'CRNN', 'STARNet', 'RARE', 'SRN',
'CLS', 'PGNet', 'Distillation', 'NRTR', 'TableAttn', 'SAR', 'PSE', 'CLS', 'PGNet', 'Distillation', 'NRTR', 'TableAttn', 'SAR', 'PSE',
'SEED', 'SDMGR' 'SEED', 'SDMGR', 'LayoutXLM', 'LayoutLM'
] ]
windows_not_support_list = ['PSE']
if platform.system() == "Windows" and alg in windows_not_support_list:
logger.warning('{} is not support in Windows now'.format(
windows_not_support_list))
sys.exit()
device = 'gpu:{}'.format(dist.ParallelEnv().dev_id) if use_gpu else 'cpu' device = 'gpu:{}'.format(dist.ParallelEnv().dev_id) if use_gpu else 'cpu'
device = paddle.set_device(device) device = paddle.set_device(device)
config['Global']['distributed'] = dist.get_world_size() != 1 config['Global']['distributed'] = dist.get_world_size() != 1
if config['Global']['use_visualdl']: if config['Global']['use_visualdl'] and dist.get_rank() == 0:
from visualdl import LogWriter from visualdl import LogWriter
save_model_dir = config['Global']['save_model_dir'] save_model_dir = config['Global']['save_model_dir']
vdl_writer_path = '{}/vdl/'.format(save_model_dir) vdl_writer_path = '{}/vdl/'.format(save_model_dir)
......
...@@ -27,8 +27,6 @@ import yaml ...@@ -27,8 +27,6 @@ import yaml
import paddle import paddle
import paddle.distributed as dist import paddle.distributed as dist
paddle.seed(2)
from ppocr.data import build_dataloader from ppocr.data import build_dataloader
from ppocr.modeling.architectures import build_model from ppocr.modeling.architectures import build_model
from ppocr.losses import build_loss from ppocr.losses import build_loss
...@@ -36,6 +34,7 @@ from ppocr.optimizer import build_optimizer ...@@ -36,6 +34,7 @@ from ppocr.optimizer import build_optimizer
from ppocr.postprocess import build_post_process from ppocr.postprocess import build_post_process
from ppocr.metrics import build_metric from ppocr.metrics import build_metric
from ppocr.utils.save_load import load_model from ppocr.utils.save_load import load_model
from ppocr.utils.utility import set_seed
import tools.program as program import tools.program as program
dist.get_world_size() dist.get_world_size()
...@@ -97,7 +96,8 @@ def main(config, device, logger, vdl_writer): ...@@ -97,7 +96,8 @@ def main(config, device, logger, vdl_writer):
# build metric # build metric
eval_class = build_metric(config['Metric']) eval_class = build_metric(config['Metric'])
# load pretrain model # load pretrain model
pre_best_model_dict = load_model(config, model, optimizer) pre_best_model_dict = load_model(config, model, optimizer,
config['Architecture']["model_type"])
logger.info('train dataloader has {} iters'.format(len(train_dataloader))) logger.info('train dataloader has {} iters'.format(len(train_dataloader)))
if valid_dataloader is not None: if valid_dataloader is not None:
logger.info('valid dataloader has {} iters'.format( logger.info('valid dataloader has {} iters'.format(
...@@ -145,5 +145,7 @@ def test_reader(config, device, logger): ...@@ -145,5 +145,7 @@ def test_reader(config, device, logger):
if __name__ == '__main__': if __name__ == '__main__':
config, device, logger, vdl_writer = program.preprocess(is_train=True) config, device, logger, vdl_writer = program.preprocess(is_train=True)
seed = config['Global']['seed'] if 'seed' in config['Global'] else 1024
set_seed(seed)
main(config, device, logger, vdl_writer) main(config, device, logger, vdl_writer)
# test_reader(config, device, logger) # test_reader(config, device, logger)