diff --git a/PPOCRLabel/PPOCRLabel.py b/PPOCRLabel/PPOCRLabel.py
index 6477ea07025c09303e255ba0118f1b9a4d7fbb8a..3ee5eb60450be0c806316f70dea9b8d4c5f31503 100644
--- a/PPOCRLabel/PPOCRLabel.py
+++ b/PPOCRLabel/PPOCRLabel.py
@@ -61,7 +61,7 @@ from combobox import ComboBox
from libs.constants import *
from libs.utils import *
from libs.settings import Settings
-from libs.shape import Shape, DEFAULT_LINE_COLOR, DEFAULT_FILL_COLOR
+from libs.shape import Shape, DEFAULT_LINE_COLOR, DEFAULT_FILL_COLOR, DEFAULT_LOCK_COLOR
from libs.stringBundle import StringBundle
from libs.canvas import Canvas
from libs.zoomWidget import ZoomWidget
@@ -101,6 +101,8 @@ class MainWindow(QMainWindow, WindowMixin):
def __init__(self, lang="ch", gpu=False, defaultFilename=None, defaultPrefdefClassFile=None, defaultSaveDir=None):
super(MainWindow, self).__init__()
self.setWindowTitle(__appname__)
+        self.setWindowState(Qt.WindowMaximized)  # start with the window maximized
+        self.activateWindow()  # bring PPOCRLabel to the front when it is activated
# Load setting in the main thread
self.settings = Settings()
@@ -126,7 +128,7 @@ class MainWindow(QMainWindow, WindowMixin):
self.labelHist = []
self.lastOpenDir = None
self.result_dic = []
-
+ self.result_dic_locked = []
self.changeFileFolder = False
self.haveAutoReced = False
self.labelFile = None
@@ -178,7 +180,8 @@ class MainWindow(QMainWindow, WindowMixin):
fileListContainer = QWidget()
fileListContainer.setLayout(filelistLayout)
- self.filedock = QDockWidget(getStr('fileList'), self)
+ self.fileListName = getStr('fileList')
+ self.filedock = QDockWidget(self.fileListName, self)
self.filedock.setObjectName(getStr('files'))
self.filedock.setWidget(fileListContainer)
self.addDockWidget(Qt.LeftDockWidgetArea, self.filedock)
@@ -394,7 +397,8 @@ class MainWindow(QMainWindow, WindowMixin):
'w', 'objects', getStr('crtBoxDetail'), enabled=False)
delete = action(getStr('delBox'), self.deleteSelectedShape,
- 'backspace', 'delete', getStr('delBoxDetail'), enabled=False)
+ 'Alt+X', 'delete', getStr('delBoxDetail'), enabled=False)
+
copy = action(getStr('dupBox'), self.copySelectedShape,
'Ctrl+C', 'copy', getStr('dupBoxDetail'),
enabled=False)
@@ -405,6 +409,7 @@ class MainWindow(QMainWindow, WindowMixin):
showAll = action(getStr('showBox'), partial(self.togglePolygons, True),
'Ctrl+A', 'hide', getStr('showAllBoxDetail'),
enabled=False)
+
help = action(getStr('tutorial'), self.showTutorialDialog, None, 'help', getStr('tutorialDetail'))
showInfo = action(getStr('info'), self.showInfoDialog, None, 'help', getStr('info'))
@@ -476,6 +481,10 @@ class MainWindow(QMainWindow, WindowMixin):
undo = action(getStr("undo"), self.undoShapeEdit,
'Ctrl+Z', "undo", getStr("undo"), enabled=False)
+
+ lock = action(getStr("lockBox"), self.lockSelectedShape,
+ None, "lock", getStr("lockBoxDetail"),
+ enabled=False)
self.editButton.setDefaultAction(edit)
self.newButton.setDefaultAction(create)
@@ -538,13 +547,13 @@ class MainWindow(QMainWindow, WindowMixin):
fitWindow=fitWindow, fitWidth=fitWidth,
zoomActions=zoomActions, saveLabel=saveLabel,
undo=undo, undoLastPoint=undoLastPoint,open_dataset_dir=open_dataset_dir,
- rotateLeft=rotateLeft,rotateRight=rotateRight,
+ rotateLeft=rotateLeft,rotateRight=rotateRight,lock=lock,
fileMenuActions=(
opendir, open_dataset_dir, saveLabel, resetAll, quit),
beginner=(), advanced=(),
editMenu=(createpoly, edit, copy, delete,singleRere,None, undo, undoLastPoint,
- None, rotateLeft, rotateRight, None, color1, self.drawSquaresOption),
- beginnerContext=(create, edit, copy, delete, singleRere, rotateLeft, rotateRight,),
+ None, rotateLeft, rotateRight, None, color1, self.drawSquaresOption,lock),
+ beginnerContext=(create, edit, copy, delete, singleRere, rotateLeft, rotateRight,lock),
advancedContext=(createMode, editMode, edit, copy,
delete, shapeLineColor, shapeFillColor),
onLoadActive=(
@@ -998,6 +1007,7 @@ class MainWindow(QMainWindow, WindowMixin):
self.actions.delete.setEnabled(n_selected)
self.actions.copy.setEnabled(n_selected)
self.actions.edit.setEnabled(n_selected == 1)
+ self.actions.lock.setEnabled(n_selected)
def addLabel(self, shape):
shape.paintLabel = self.displayLabelOption.isChecked()
@@ -1041,7 +1051,7 @@ class MainWindow(QMainWindow, WindowMixin):
def loadLabels(self, shapes):
s = []
for label, points, line_color, fill_color, difficult in shapes:
- shape = Shape(label=label)
+            shape = Shape(label=label, line_color=line_color)
for x, y in points:
# Ensure the labels are within the bounds of the image. If not, fix them.
@@ -1051,6 +1061,7 @@ class MainWindow(QMainWindow, WindowMixin):
shape.addPoint(QPointF(x, y))
shape.difficult = difficult
+ #shape.locked = False
shape.close()
s.append(shape)
@@ -1063,10 +1074,12 @@ class MainWindow(QMainWindow, WindowMixin):
# shape.fill_color = QColor(*fill_color)
# else:
# shape.fill_color = generateColorByText(label)
-
+
self.addLabel(shape)
+
self.updateComboBox()
self.canvas.loadShapes(s)
+
def singleLabel(self, shape):
if shape is None:
@@ -1106,10 +1119,9 @@ class MainWindow(QMainWindow, WindowMixin):
difficult=s.difficult) # bool
shapes = [] if mode == 'Auto' else \
- [format_shape(shape) for shape in self.canvas.shapes]
+ [format_shape(shape) for shape in self.canvas.shapes if shape.line_color != DEFAULT_LOCK_COLOR]
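+        # shapes drawn in DEFAULT_LOCK_COLOR are locked; they are kept in canvas.lockedShapes and skipped here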
# Can add differrent annotation formats here
-
- for box in self.result_dic:
+        for box in self.result_dic:
trans_dic = {"label": box[1][0], "points": box[0], 'difficult': False}
if trans_dic["label"] == "" and mode == 'Auto':
continue
@@ -1120,7 +1132,6 @@ class MainWindow(QMainWindow, WindowMixin):
for box in shapes:
trans_dic.append({"transcription": box['label'], "points": box['points'], 'difficult': box['difficult']})
self.PPlabel[annotationFilePath] = trans_dic
-
if mode == 'Auto':
self.Cachelabel[annotationFilePath] = trans_dic
@@ -1313,6 +1324,7 @@ class MainWindow(QMainWindow, WindowMixin):
# unicodeFilePath = os.path.abspath(unicodeFilePath)
# Tzutalin 20160906 : Add file list and dock to move faster
# Highlight the file item
+
if unicodeFilePath and self.fileListWidget.count() > 0:
if unicodeFilePath in self.mImgList:
index = self.mImgList.index(unicodeFilePath)
@@ -1322,6 +1334,7 @@ class MainWindow(QMainWindow, WindowMixin):
###
self.iconlist.clear()
self.additems5(None)
+
for i in range(5):
item_tooltip = self.iconlist.item(i).toolTip()
# print(i,"---",item_tooltip)
@@ -1340,7 +1353,6 @@ class MainWindow(QMainWindow, WindowMixin):
if unicodeFilePath and os.path.exists(unicodeFilePath):
self.canvas.verified = False
-
cvimg = cv2.imdecode(np.fromfile(unicodeFilePath, dtype=np.uint8), 1)
height, width, depth = cvimg.shape
cvimg = cv2.cvtColor(cvimg, cv2.COLOR_BGR2RGB)
@@ -1361,34 +1373,52 @@ class MainWindow(QMainWindow, WindowMixin):
else:
self.dirty = False
self.actions.save.setEnabled(True)
-
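+            # when locked shapes exist, keep Save enabled and mark the file dirty so the locked boxes can be saved for the newly loaded image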
+ if len(self.canvas.lockedShapes) != 0:
+ self.actions.save.setEnabled(True)
+ self.setDirty()
self.canvas.setEnabled(True)
self.adjustScale(initial=True)
self.paintCanvas()
self.addRecentFile(self.filePath)
self.toggleActions(True)
+
self.showBoundingBoxFromPPlabel(filePath)
self.setWindowTitle(__appname__ + ' ' + filePath)
-
+
# Default : select last item if there is at least one item
if self.labelList.count():
self.labelList.setCurrentItem(self.labelList.item(self.labelList.count() - 1))
self.labelList.item(self.labelList.count() - 1).setSelected(True)
+ # show file list image count
+ select_indexes = self.fileListWidget.selectedIndexes()
+ if len(select_indexes) > 0:
+ self.filedock.setWindowTitle(self.fileListName + f" ({select_indexes[0].row() + 1}"
+ f"/{self.fileListWidget.count()})")
+
+
self.canvas.setFocus(True)
return True
return False
-
def showBoundingBoxFromPPlabel(self, filePath):
+ width, height = self.image.width(), self.image.height()
imgidx = self.getImglabelidx(filePath)
- if imgidx not in self.PPlabel.keys():
- return
- shapes = []
- for box in self.PPlabel[imgidx]:
- shapes.append((box['transcription'], box['points'], None, None, box['difficult']))
-
+        shapes = []
+        # box['ratio'] of a shape saved in lockedShapes stores the four corner
+        # coordinates of the shape as ratios of the image width and height
+ for box in self.canvas.lockedShapes:
+ if self.canvas.isInTheSameImage:
+                shapes.append((box['transcription'], [[s[0] * width, s[1] * height] for s in box['ratio']],
+                               DEFAULT_LOCK_COLOR, None, box['difficult']))
+            else:
+                # '锁定框:待检测' means "locked box: to be recognized"
+                shapes.append(('锁定框:待检测', [[s[0] * width, s[1] * height] for s in box['ratio']],
+                               DEFAULT_LOCK_COLOR, None, box['difficult']))
+ if imgidx in self.PPlabel.keys():
+ for box in self.PPlabel[imgidx]:
+ shapes.append((box['transcription'], box['points'], None, None, box['difficult']))
+
self.loadLabels(shapes)
self.canvas.verified = False
@@ -1576,7 +1606,8 @@ class MainWindow(QMainWindow, WindowMixin):
self.actions.rotateLeft.setEnabled(True)
self.actions.rotateRight.setEnabled(True)
-
+ self.fileListWidget.setCurrentRow(0) # set list index to first
+ self.filedock.setWindowTitle(self.fileListName + f" (1/{self.fileListWidget.count()})") # show image count
def openPrevImg(self, _value=False):
if len(self.mImgList) <= 0:
@@ -1646,9 +1677,37 @@ class MainWindow(QMainWindow, WindowMixin):
else:
return fullFilePath
return ''
-
+
+
+ def saveLockedShapes(self):
+ self.canvas.lockedShapes = []
+ self.canvas.selectedShapes = []
+ for s in self.canvas.shapes:
+ if s.line_color == DEFAULT_LOCK_COLOR:
+ self.canvas.selectedShapes.append(s)
+ self.lockSelectedShape()
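+        # the locked shapes are now tracked in canvas.lockedShapes, so drop them from canvas.shapes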
+        for s in self.canvas.shapes[:]:  # iterate over a copy because shapes are removed inside the loop
+ if s.line_color == DEFAULT_LOCK_COLOR:
+ self.canvas.selectedShapes.remove(s)
+ self.canvas.shapes.remove(s)
+
+
def _saveFile(self, annotationFilePath, mode='Manual'):
+ if len(self.canvas.lockedShapes) != 0:
+ self.saveLockedShapes()
+
if mode == 'Manual':
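+            # convert the locked shapes' stored ratios back to pixel coordinates and merge them into result_dic before the labels are written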
+ self.result_dic_locked = []
+ img = cv2.imread(self.filePath)
+ width, height = self.image.width(), self.image.height()
+ for shape in self.canvas.lockedShapes:
+ box = [[int(p[0]*width), int(p[1]*height)] for p in shape['ratio']]
+ assert len(box) == 4
+                result = [(shape['transcription'], 1)]
+ result.insert(0, box)
+ self.result_dic_locked.append(result)
+ self.result_dic += self.result_dic_locked
+ self.result_dic_locked = []
if annotationFilePath and self.saveLabels(annotationFilePath, mode=mode):
self.setClean()
self.statusBar().showMessage('Saved to %s' % annotationFilePath)
@@ -1663,13 +1722,13 @@ class MainWindow(QMainWindow, WindowMixin):
self.savePPlabel(mode='Auto')
self.fileListWidget.insertItem(int(currIndex), item)
- self.openNextImg()
+ if not self.canvas.isInTheSameImage:
+ self.openNextImg()
self.actions.saveRec.setEnabled(True)
self.actions.saveLabel.setEnabled(True)
elif mode == 'Auto':
if annotationFilePath and self.saveLabels(annotationFilePath, mode=mode):
-
self.setClean()
self.statusBar().showMessage('Saved to %s' % annotationFilePath)
self.statusBar().show()
@@ -1733,14 +1792,19 @@ class MainWindow(QMainWindow, WindowMixin):
if discardChanges == QMessageBox.No:
return True
elif discardChanges == QMessageBox.Yes:
+ self.canvas.isInTheSameImage = True
self.saveFile()
+ self.canvas.isInTheSameImage = False
return True
else:
return False
def discardChangesDialog(self):
yes, no, cancel = QMessageBox.Yes, QMessageBox.No, QMessageBox.Cancel
- msg = u'You have unsaved changes, would you like to save them and proceed?\nClick "No" to undo all changes.'
+ if self.lang == 'ch':
+ msg = u'您有未保存的变更, 您想保存再继续吗?\n点击 "No" 丢弃所有未保存的变更.'
+ else:
+ msg = u'You have unsaved changes, would you like to save them and proceed?\nClick "No" to undo all changes.'
return QMessageBox.warning(self, u'Attention', msg, yes | no | cancel)
def errorMessage(self, title, message):
@@ -1858,7 +1922,7 @@ class MainWindow(QMainWindow, WindowMixin):
uncheckedList = [i for i in self.mImgList if i not in self.fileStatedict.keys()]
self.autoDialog = AutoDialog(parent=self, ocr=self.ocr, mImgList=uncheckedList, lenbar=len(uncheckedList))
self.autoDialog.popUp()
- self.currIndex=len(self.mImgList)
+ self.currIndex = len(self.mImgList) - 1
self.loadFile(self.filePath) # ADD
self.haveAutoReced = True
self.AutoRecognition.setEnabled(False)
@@ -1872,6 +1936,7 @@ class MainWindow(QMainWindow, WindowMixin):
# org_box = [dic['points'] for dic in self.PPlabel[self.getImglabelidx(self.filePath)]]
if self.canvas.shapes:
self.result_dic = []
+ self.result_dic_locked = [] # result_dic_locked stores the ocr result of self.canvas.lockedShapes
rec_flag = 0
for shape in self.canvas.shapes:
box = [[int(p.x()), int(p.y())] for p in shape.points]
@@ -1883,21 +1948,32 @@ class MainWindow(QMainWindow, WindowMixin):
return
result = self.ocr.ocr(img_crop, cls=True, det=False)
if result[0][0] != '':
- result.insert(0, box)
- print('result in reRec is ', result)
- self.result_dic.append(result)
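+                    # results for locked shapes are collected in result_dic_locked, separately from the normal results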
+ if shape.line_color == DEFAULT_LOCK_COLOR:
+ shape.label = result[0][0]
+ result.insert(0, box)
+ self.result_dic_locked.append(result)
+ else:
+ result.insert(0, box)
+ self.result_dic.append(result)
else:
print('Can not recognise the box')
- self.result_dic.append([box,(self.noLabelText,0)])
-
- if self.noLabelText == shape.label or result[1][0] == shape.label:
- print('label no change')
- else:
- rec_flag += 1
-
- if len(self.result_dic) > 0 and rec_flag > 0:
+ if shape.line_color == DEFAULT_LOCK_COLOR:
+ shape.label = result[0][0]
+                        self.result_dic_locked.append([box, (self.noLabelText, 0)])
+                    else:
+                        self.result_dic.append([box, (self.noLabelText, 0)])
+ try:
+ if self.noLabelText == shape.label or result[1][0] == shape.label:
+ print('label no change')
+ else:
+ rec_flag += 1
+ except IndexError as e:
+ print('Can not recognise the box')
+            if (len(self.result_dic) > 0 and rec_flag > 0) or self.canvas.lockedShapes:
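+                # isInTheSameImage keeps the current image open while saving and reloading, instead of jumping to the next image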
+ self.canvas.isInTheSameImage = True
self.saveFile(mode='Auto')
self.loadFile(self.filePath)
+ self.canvas.isInTheSameImage = False
self.setDirty()
elif len(self.result_dic) == len(self.canvas.shapes) and rec_flag == 0:
QMessageBox.information(self, "Information", "The recognition result remains unchanged!")
@@ -2027,8 +2103,11 @@ class MainWindow(QMainWindow, WindowMixin):
f.write(key + '\t')
f.write(json.dumps(self.PPlabel[key], ensure_ascii=False) + '\n')
- if mode=='Manual':
- msg = 'Images that have been checked are saved in '+ self.PPlabelpath
+ if mode == 'Manual':
+ if self.lang == 'ch':
+ msg = '已将检查过的图片标签保存在 ' + self.PPlabelpath + " 文件中"
+ else:
+ msg = 'Images that have been checked are saved in ' + self.PPlabelpath
QMessageBox.information(self, "Information", msg)
def saveCacheLabel(self):
@@ -2107,6 +2186,44 @@ class MainWindow(QMainWindow, WindowMixin):
self.labelList.clearSelection()
self._noSelectionSlot = False
self.canvas.loadShapes(shapes, replace=replace)
+
+
+ def lockSelectedShape(self):
+        """Lock the selected shapes.
+
+        Add the shapes in self.canvas.selectedShapes to self.canvas.lockedShapes,
+        which stores the four corner coordinates of each locked shape as ratios
+        of the image width and height.
+ """
+ width, height = self.image.width(), self.image.height()
+ def format_shape(s):
+ return dict(label=s.label, # str
+ line_color=s.line_color.getRgb(),
+ fill_color=s.fill_color.getRgb(),
+                        ratio=[[int(p.x())/width, int(p.y())/height] for p in s.points],  # QPointF
+ # add chris
+ difficult=s.difficult) # bool
+ #lock
+ if len(self.canvas.lockedShapes) == 0:
+ for s in self.canvas.selectedShapes:
+ s.line_color = DEFAULT_LOCK_COLOR
+ s.locked = True
+ shapes = [format_shape(shape) for shape in self.canvas.selectedShapes]
+ trans_dic = []
+ for box in shapes:
+ trans_dic.append({"transcription": box['label'], "ratio": box['ratio'], 'difficult': box['difficult']})
+ self.canvas.lockedShapes = trans_dic
+ self.actions.save.setEnabled(True)
+
+ #unlock
+ else:
+ for s in self.canvas.shapes:
+ s.line_color = DEFAULT_LINE_COLOR
+ self.canvas.lockedShapes = []
+ self.result_dic_locked = []
+ self.setDirty()
+ self.actions.save.setEnabled(True)
def inverted(color):
diff --git a/PPOCRLabel/README.md b/PPOCRLabel/README.md
index e8634ef8c06feae1f0adffb22c5694084dab78cd..10bfa4699d0141c94131a7cb5b4860f7a1edd03f 100644
--- a/PPOCRLabel/README.md
+++ b/PPOCRLabel/README.md
@@ -143,7 +143,7 @@ python PPOCRLabel.py
### 3.1 Shortcut keys
| Shortcut keys | Description |
-| ------------------------ | ------------------------------------------------ |
+|--------------------------| ------------------------------------------------ |
| Ctrl + Shift + R | Re-recognize all the labels of the current image |
| W | Create a rect box |
| Q | Create a four-points box |
@@ -151,7 +151,7 @@ python PPOCRLabel.py
| Ctrl + R | Re-recognize the selected box |
| Ctrl + C | Copy and paste the selected box |
| Ctrl + Left Mouse Button | Multi select the label box |
-| Backspace | Delete the selected box |
+| Alt + X                  | Delete the selected box                          |
| Ctrl + V | Check image |
| Ctrl + Shift + d | Delete image |
| D | Next image |
diff --git a/PPOCRLabel/README_ch.md b/PPOCRLabel/README_ch.md
index e1c391bc8637baa4adfa8852d805ed0f4bf04d6d..03e1b2da2eac7c71f029ce612a89fb4e8ccae993 100644
--- a/PPOCRLabel/README_ch.md
+++ b/PPOCRLabel/README_ch.md
@@ -131,16 +131,16 @@ python PPOCRLabel.py --lang ch
### 3.1 快捷键
-| 快捷键 | 说明 |
-| ---------------- | ---------------------------- |
+| 快捷键 | 说明 |
+|------------------| ---------------------------- |
| Ctrl + shift + R | 对当前图片的所有标记重新识别 |
| W | 新建矩形框 |
| Q | 新建四点框 |
| Ctrl + E | 编辑所选框标签 |
| Ctrl + R | 重新识别所选标记 |
| Ctrl + C | 复制并粘贴选中的标记框 |
-| Ctrl + 鼠标左键 | 多选标记框 |
-| Backspace | 删除所选框 |
+| Ctrl + 鼠标左键 | 多选标记框 |
+| Alt + X           | 删除所选框                    |
| Ctrl + V | 确认本张图片标记 |
| Ctrl + Shift + d | 删除本张图片 |
| D | 下一张图片 |
diff --git a/PPOCRLabel/libs/autoDialog.py b/PPOCRLabel/libs/autoDialog.py
index 3374e92cc587baa7e8bab5c7d8e8dc34eb6366b6..189a590de851228e08d71f1dd2c00c823b9c2b0c 100644
--- a/PPOCRLabel/libs/autoDialog.py
+++ b/PPOCRLabel/libs/autoDialog.py
@@ -6,6 +6,8 @@ except ImportError:
from PyQt4.QtGui import *
from PyQt4.QtCore import *
+import time
+import datetime
import json
import cv2
import numpy as np
@@ -80,8 +82,9 @@ class AutoDialog(QDialog):
self.parent = parent
self.ocr = ocr
self.mImgList = mImgList
+ self.lender = lenbar
self.pb = QProgressBar()
- self.pb.setRange(0, lenbar)
+ self.pb.setRange(0, self.lender)
self.pb.setValue(0)
layout = QVBoxLayout()
@@ -108,10 +111,16 @@ class AutoDialog(QDialog):
self.thread_1.progressBarValue.connect(self.handleProgressBarSingal)
self.thread_1.listValue.connect(self.handleListWidgetSingal)
self.thread_1.endsignal.connect(self.handleEndsignalSignal)
+ self.time_start = time.time() # save start time
def handleProgressBarSingal(self, i):
self.pb.setValue(i)
+ # calculate time left of auto labeling
+ avg_time = (time.time() - self.time_start) / i # Use average time to prevent time fluctuations
+ time_left = str(datetime.timedelta(seconds=avg_time * (self.lender - i))).split(".")[0] # Remove microseconds
+ self.setWindowTitle("PPOCRLabel -- " + f"Time Left: {time_left}") # show
+
def handleListWidgetSingal(self, i):
self.listWidget.addItem(i)
titem = self.listWidget.item(self.listWidget.count() - 1)
diff --git a/PPOCRLabel/libs/canvas.py b/PPOCRLabel/libs/canvas.py
index 6ac1f28b85e65c3776d310136352b70c45628db6..6116f357d6efb91a5a9d9cdc6ba757fbd06df60e 100644
--- a/PPOCRLabel/libs/canvas.py
+++ b/PPOCRLabel/libs/canvas.py
@@ -87,6 +87,10 @@ class Canvas(QWidget):
#initialisation for panning
self.pan_initial_pos = QPoint()
+        # locked shapes related
+ self.lockedShapes = []
+ self.isInTheSameImage = False
+
def setDrawingColor(self, qColor):
self.drawingLineColor = qColor
self.drawingRectColor = qColor
diff --git a/PPOCRLabel/libs/shape.py b/PPOCRLabel/libs/shape.py
index ef8e09be061927d39403cc0cdc0727fff69854a7..e2cdcb322790c9b6edd3c504405ad65097a7bc49 100644
--- a/PPOCRLabel/libs/shape.py
+++ b/PPOCRLabel/libs/shape.py
@@ -30,6 +30,7 @@ DEFAULT_SELECT_LINE_COLOR = QColor(255, 255, 255)
DEFAULT_SELECT_FILL_COLOR = QColor(0, 128, 255, 155)
DEFAULT_VERTEX_FILL_COLOR = QColor(0, 255, 0, 255)
DEFAULT_HVERTEX_FILL_COLOR = QColor(255, 0, 0)
+DEFAULT_LOCK_COLOR = QColor(255, 0, 255)
MIN_Y_LABEL = 10
@@ -57,7 +58,7 @@ class Shape(object):
self.selected = False
self.difficult = difficult
self.paintLabel = paintLabel
-
+ self.locked = False
self._highlightIndex = None
self._highlightMode = self.NEAR_VERTEX
self._highlightSettings = {
diff --git a/PPOCRLabel/resources/icons/lock.png b/PPOCRLabel/resources/icons/lock.png
new file mode 100644
index 0000000000000000000000000000000000000000..f4d50d70b43ae91cdb60ddc73ffa5385e6253ea1
Binary files /dev/null and b/PPOCRLabel/resources/icons/lock.png differ
diff --git a/PPOCRLabel/resources/strings/strings-en.properties b/PPOCRLabel/resources/strings/strings-en.properties
index 70036e6560c53bd35e2c5f29b6912092701fe4ae..f59e43aa92ff9ccd04686e9c16db181983b57b2c 100644
--- a/PPOCRLabel/resources/strings/strings-en.properties
+++ b/PPOCRLabel/resources/strings/strings-en.properties
@@ -104,4 +104,6 @@ singleRe=Re-recognition RectBox
labelDialogOption=Pop-up Label Input Dialog
undo=Undo
undoLastPoint=Undo Last Point
-autoSaveMode=Auto Export Label Mode
\ No newline at end of file
+autoSaveMode=Auto Export Label Mode
+lockBox=Lock selected box/Unlock all boxes
+lockBoxDetail=Lock the selected boxes if no box is locked; if locked boxes exist, unlock all of them
\ No newline at end of file
diff --git a/PPOCRLabel/resources/strings/strings-zh-CN.properties b/PPOCRLabel/resources/strings/strings-zh-CN.properties
index 1cd4da7611c72cf37f9c3febe57522f2e38c7f9b..d8bd9d4bff02748397d7a57a6205e67ff69779c2 100644
--- a/PPOCRLabel/resources/strings/strings-zh-CN.properties
+++ b/PPOCRLabel/resources/strings/strings-zh-CN.properties
@@ -104,4 +104,6 @@ singleRe=重识别此区块
labelDialogOption=弹出标记输入框
undo=撤销
undoLastPoint=撤销上个点
-autoSaveMode=自动导出标记结果
\ No newline at end of file
+autoSaveMode=自动导出标记结果
+lockBox=锁定框/解除锁定框
+lockBoxDetail=若当前没有框处于锁定状态则锁定选中的框,若存在锁定框则解除所有锁定框的锁定状态
diff --git a/README.md b/README.md
index 8002e2b03ef63afbeb4de435b8ce1960375c2bd5..8936fbaa27c92fc64a7098a9e79cc0fe923910fb 100644
--- a/README.md
+++ b/README.md
@@ -33,17 +33,17 @@ PaddleOCR aims to create multilingual, awesome, leading, and practical OCR tools
- [more](./doc/doc_en/update_en.md)
## Features
-- PP-OCR series of high-quality pre-trained models, comparable to commercial effects
+- PP-OCR - A series of high-quality pre-trained models, comparable to commercial products
- Ultra lightweight PP-OCRv2 series models: detection (3.1M) + direction classifier (1.4M) + recognition 8.5M) = 13.0M
- Ultra lightweight PP-OCR mobile series models: detection (3.0M) + direction classifier (1.4M) + recognition (5.0M) = 9.4M
- General PP-OCR server series models: detection (47.1M) + direction classifier (1.4M) + recognition (94.9M) = 143.4M
- Support Chinese, English, and digit recognition, vertical text recognition, and long text recognition
- - Support multi-language recognition: about 80 languages like Korean, Japanese, German, French, etc
+ - Support multi-lingual recognition: about 80 languages like Korean, Japanese, German, French, etc
- PP-Structure: a document structurize system
- - support layout analysis and table recognition (support export to Excel)
- - support key information extraction
- - support DocVQA
-- Rich toolkits related to the OCR areas
+ - Support layout analysis and table recognition (support export to Excel)
+ - Support key information extraction
+ - Support DocVQA
+- Rich OCR toolkit
- Semi-automatic data annotation tool, i.e., PPOCRLabel: support fast and efficient data annotation
- Data synthesis tool, i.e., Style-Text: easy to synthesize a large number of images which are similar to the target scene image
- Support user-defined training, provides rich predictive inference deployment solutions
@@ -62,7 +62,7 @@ The above pictures are the visualizations of the general ppocr_server model. For
## Community
-- Scan the QR code below with your Wechat, you can access to official technical exchange group. Look forward to your participation.
+- Scan the QR code below with your WeChat to join the official technical discussion group. We look forward to your participation.
@@ -18,11 +18,11 @@ PaddleOCR contains rich text detection, text recognition and end-to-end algorith
# Recommend
git clone https://github.com/PaddlePaddle/PaddleOCR
-# If you cannot pull successfully due to network problems, you can also choose to use the code hosting on the cloud:
+# If you cannot pull successfully due to network problems, you can switch to the mirror hosted on Gitee:
git clone https://gitee.com/paddlepaddle/PaddleOCR
-# Note: The cloud-hosting code may not be able to synchronize the update with this GitHub project in real time. There might be a delay of 3-5 days. Please give priority to the recommended method.
+# Note: The mirror on Gitee may not stay in sync with the latest project on GitHub; there might be a delay of 3-5 days. Please use GitHub first.
```
### **2.2 Install third-party libraries**
@@ -34,6 +34,6 @@ pip3 install -r requirements.txt
If you getting this error `OSError: [WinError 126] The specified module could not be found` when you install shapely on windows.
-Please try to download Shapely whl file using [http://www.lfd.uci.edu/~gohlke/pythonlibs/#shapely](http://www.lfd.uci.edu/~gohlke/pythonlibs/#shapely).
+Please try downloading the Shapely whl file from [http://www.lfd.uci.edu/~gohlke/pythonlibs/#shapely](http://www.lfd.uci.edu/~gohlke/pythonlibs/#shapely).
-Reference: [Solve shapely installation on windows](https://stackoverflow.com/questions/44398265/install-shapely-oserror-winerror-126-the-specified-module-could-not-be-found)
\ No newline at end of file
+Reference: [Solve shapely installation on windows](https://stackoverflow.com/questions/44398265/install-shapely-oserror-winerror-126-the-specified-module-could-not-be-found)
diff --git a/doc/doc_en/pgnet_en.md b/doc/doc_en/pgnet_en.md
index e176a1260c734974e2dad843faeb3e5532176629..c7cb3221ccfd897e2fd9062a828c2fe0ceb42024 100644
--- a/doc/doc_en/pgnet_en.md
+++ b/doc/doc_en/pgnet_en.md
@@ -6,18 +6,18 @@
## 1. Brief Introduction
-OCR algorithm can be divided into two-stage algorithm and end-to-end algorithm. The two-stage OCR algorithm is generally divided into two parts, text detection and text recognition algorithm. The text detection algorithm gets the detection box of the text line from the image, and then the recognition algorithm identifies the content of the text box. The end-to-end OCR algorithm can complete text detection and recognition in one algorithm. Its basic idea is to design a model with both detection unit and recognition module, share the CNN features of both and train them together. Because one algorithm can complete character recognition, the end-to-end model is smaller and faster.
+OCR algorithms can be divided into two categories: two-stage algorithms and end-to-end algorithms. A two-stage OCR algorithm generally consists of two parts, a text detection algorithm and a text recognition algorithm: the detection algorithm locates the boxes of the text lines in the image, and the recognition algorithm then identifies the content of each text box. An end-to-end OCR algorithm completes text detection and recognition within a single algorithm. Its basic idea is to design a model with both a detection unit and a recognition module, share the CNN features of both, and train them together. Because a single model handles both tasks, the end-to-end model is smaller and faster.
### Introduction Of PGNet Algorithm
-In recent years, the end-to-end OCR algorithm has been well developed, including MaskTextSpotter series, TextSnake, TextDragon, PGNet series and so on. Among these algorithms, PGNet algorithm has the advantages that other algorithms do not
-- Pgnet loss is designed to guide training, and no character-level annotations is needed
-- NMS and ROI related operations are not needed, It can accelerate the prediction
+In recent years, end-to-end OCR algorithms have developed rapidly, including the MaskTextSpotter series, TextSnake, TextDragon, the PGNet series and so on. Among these algorithms, PGNet has several advantages over the others:
+- The PGNet loss is designed to guide training, and no character-level annotations are needed.
+- NMS and ROI related operations are not needed, which accelerates prediction.
- The reading order prediction module is proposed
- A graph based modification module (GRM) is proposed to further improve the performance of model recognition
- Higher accuracy and faster prediction speed
-For details of PGNet algorithm, please refer to [paper](https://www.aaai.org/AAAI21Papers/AAAI-2885.WangP.pdf) ,The schematic diagram of the algorithm is as follows:
+For details of PGNet algorithm, please refer to [paper](https://www.aaai.org/AAAI21Papers/AAAI-2885.WangP.pdf). The schematic diagram of the algorithm is as follows:
![](../pgnet_framework.png)
-After feature extraction, the input image is sent to four branches: TBO module for text edge offset prediction, TCL module for text centerline prediction, TDO module for text direction offset prediction, and TCC module for text character classification graph prediction.
+After feature extraction, the input image is sent to four branches: TBO module for text edge offset prediction, TCL module for text center-line prediction, TDO module for text direction offset prediction, and TCC module for text character classification graph prediction.
The output of TBO and TCL can get text detection results after post-processing, and TCL, TDO and TCC are responsible for text recognition.
The results of detection and recognition are as follows:
@@ -40,7 +40,7 @@ Please refer to [Operation Environment Preparation](./environment_en.md) to conf
## 3. Quick Use
-### inference model download
+### Inference model download
This section takes the trained end-to-end model as an example to quickly use the model prediction. First, download the trained end-to-end inference model [download address](https://paddleocr.bj.bcebos.com/dygraph_v2.0/pgnet/e2e_server_pgnetA_infer.tar)
```
mkdir inference && cd inference
@@ -131,7 +131,7 @@ python3 tools/train.py -c configs/e2e/e2e_r50_vd_pg.yml -o Optimizer.base_lr=0.0
```
#### Load trained model and continue training
-If you expect to load trained model and continue the training again, you can specify the parameter `Global.checkpoints` as the model path to be loaded.
+If you would like to load a trained model and continue training, you can specify the parameter `Global.checkpoints` as the model path to be loaded.
```shell
python3 tools/train.py -c configs/e2e/e2e_r50_vd_pg.yml -o Global.checkpoints=./your/trained/model
```
diff --git a/doc/doc_en/training_en.md b/doc/doc_en/training_en.md
index d013f5ac706a2a2b4a5b58ba0a6dff09ab0b4654..1a3165d0ab226d7cbeef356ee750594c759cfe23 100644
--- a/doc/doc_en/training_en.md
+++ b/doc/doc_en/training_en.md
@@ -12,15 +12,15 @@
* [4. FAQ](#3-faq)
-This article will introduce the basic concepts that need to be mastered during model training and the tuning methods during training.
+This article will introduce the basic concepts that are necessary for model training and tuning.
-At the same time, it will briefly introduce the components of the PaddleOCR model training data and how to prepare the data finetune model in the vertical scene.
+At the same time, it will briefly introduce the structure of the training data and how to prepare the data to fine-tune the model in vertical scenes.
## 1. Yml Configuration
-The PaddleOCR model uses configuration files to manage network training and evaluation parameters. In the configuration file, you can set the model, optimizer, loss function, and pre- and post-processing parameters of the model. PaddleOCR reads these parameters from the configuration file, and then builds a complete training process to complete the model training. When optimized, the configuration can be completed by modifying the parameters in the configuration file, which is simple to use and convenient to modify.
+PaddleOCR uses configuration files to control network training and evaluation parameters. In the configuration file, you can set the model, optimizer, loss function, and pre- and post-processing parameters of the model. PaddleOCR reads these parameters from the configuration file and then builds a complete training process to train the model. Fine-tuning can also be completed by modifying the parameters in the configuration file, which is simple and convenient.
For the complete configuration file description, please refer to [Configuration File](./config_en.md)
@@ -28,13 +28,13 @@ For the complete configuration file description, please refer to [Configuration
## 2. Basic Concepts
-In the process of model training, some hyperparameters need to be manually adjusted to help the model obtain the optimal index at the least loss. Different data volumes may require different hyper-parameters. When you want to finetune your own data or tune the model effect, there are several parameter adjustment strategies for reference:
+During the model training process, some hyper-parameters can be manually specified to obtain the optimal result at the least cost. Different data volumes may require different hyper-parameters. When you want to fine-tune the model based on your own data, there are several parameter adjustment strategies for reference:
### 2.1 Learning Rate
-The learning rate is one of the important hyperparameters for training neural networks. It represents the step length of the gradient moving to the optimal solution of the loss function in each iteration.
-A variety of learning rate update strategies are provided in PaddleOCR, which can be modified through configuration files, for example:
+The learning rate is one of the most important hyper-parameters for training neural networks. It represents the step length of the gradient moving towards the optimal solution of the loss function in each iteration.
+A variety of learning rate update strategies are provided by PaddleOCR, which can be specified in the configuration file. For example:
```
Optimizer:
@@ -46,16 +46,15 @@ Optimizer:
warmup_epoch: 5
```
-Piecewise stands for piecewise constant attenuation. Different learning rates are specified in different learning stages,
-and the learning rate is the same in each stage.
+`Piecewise` stands for piecewise constant decay: different learning rates are specified for different learning stages, and the learning rate stays the same within each stage.
-warmup_epoch means that in the first 5 epochs, the learning rate will gradually increase from 0 to base_lr. For all strategies, please refer to the code [learning_rate.py](../../ppocr/optimizer/learning_rate.py).
+`warmup_epoch` means that in the first 5 epochs, the learning rate will gradually increase from 0 to base_lr. For all strategies, please refer to the code [learning_rate.py](../../ppocr/optimizer/learning_rate.py).
### 2.2 Regularization
-Regularization can effectively avoid algorithm overfitting. PaddleOCR provides L1 and L2 regularization methods.
-L1 and L2 regularization are the most commonly used regularization methods.
+Regularization can effectively avoid algorithm over-fitting. PaddleOCR provides L1 and L2 regularization methods.
+L1 and L2 regularization are the most widely used regularization methods.
L1 regularization adds a regularization term to the objective function to reduce the sum of absolute values of the parameters;
while in L2 regularization, the purpose of adding a regularization term is to reduce the sum of squared parameters.
The configuration method is as follows:
@@ -95,7 +94,7 @@ The current open source models, data sets and magnitudes are as follows:
- Chinese data set, LSVT street view data set crops the image according to the truth value, and performs position calibration, a total of 30w images. In addition, based on the LSVT corpus, 500w of synthesized data.
- Small language data set, using different corpora and fonts, respectively generated 100w synthetic data set, and using ICDAR-MLT as the verification set.
-Among them, the public data sets are all open source, users can search and download by themselves, or refer to [Chinese data set](./datasets.md), synthetic data is not open source, users can use open source synthesis tools to synthesize by themselves. Synthesis tools include [text_renderer](https://github.com/Sanster/text_renderer), [SynthText](https://github.com/ankush-me/SynthText), [TextRecognitionDataGenerator](https://github.com/Belval/TextRecognitionDataGenerator) etc.
+Among them, the public datasets are all open source; users can search and download them by themselves, or refer to [Chinese data set](../doc_ch/datasets.md). The synthetic data is not open source; users can generate it themselves with open-source synthesis tools such as [text_renderer](https://github.com/Sanster/text_renderer), [SynthText](https://github.com/ankush-me/SynthText), [TextRecognitionDataGenerator](https://github.com/Belval/TextRecognitionDataGenerator), etc.
@@ -129,17 +128,17 @@ There are several experiences for reference when constructing the data set:
**Q**: How to choose a suitable network input shape when training CRNN recognition?
A: The general height is 32, the longest width is selected, there are two methods:
-
+
(1) Calculate the aspect ratio distribution of training sample images. The selection of the maximum aspect ratio considers 80% of the training samples.
-
+
(2) Count the number of texts in training samples. The selection of the longest number of characters considers the training sample that satisfies 80%. Then the aspect ratio of Chinese characters is approximately considered to be 1, and that of English is 3:1, and the longest width is estimated.
**Q**: During the recognition training, the accuracy of the training set has reached 90, but the accuracy of the verification set has been kept at 70, what should I do?
A: If the accuracy of the training set is 90 and the test set is more than 70, it should be over-fitting. There are two methods to try:
-
+
(1) Add more augmentation methods or increase the [probability] of augmented prob (https://github.com/PaddlePaddle/PaddleOCR/blob/dygraph/ppocr/data/imaug/rec_img_aug.py#L341), The default is 0.4.
-
+
(2) Increase the [l2 dcay value] of the system (https://github.com/PaddlePaddle/PaddleOCR/blob/a501603d54ff5513fc4fc760319472e59da25424/configs/rec/ch_ppocr_v1.1/rec_chinese_lite_train_v1.1.yml#L47)
**Q**: When the recognition model is trained, loss can drop normally, but acc is always 0
diff --git a/doc/doc_en/update_en.md b/doc/doc_en/update_en.md
index 6a95b5be279d7a0b8a204cadd46b283b5eb26690..39fd936d1bd4e5f8d8535805f865792820ee1199 100644
--- a/doc/doc_en/update_en.md
+++ b/doc/doc_en/update_en.md
@@ -5,7 +5,7 @@
- 2021.8.3 released PaddleOCR v2.2, add a new structured documents analysis toolkit, i.e., [PP-Structure](https://github.com/PaddlePaddle/PaddleOCR/blob/release/2.2/ppstructure/README.md), support layout analysis and table recognition (One-key to export chart images to Excel files).
- 2021.4.8 release end-to-end text recognition algorithm [PGNet](https://www.aaai.org/AAAI21Papers/AAAI-2885.WangP.pdf) which is published in AAAI 2021. Find tutorial [here](https://github.com/PaddlePaddle/PaddleOCR/blob/release/2.1/doc/doc_en/pgnet_en.md);release multi language recognition [models](https://github.com/PaddlePaddle/PaddleOCR/blob/release/2.1/doc/doc_en/multi_languages_en.md), support more than 80 languages recognition; especically, the performance of [English recognition model](https://github.com/PaddlePaddle/PaddleOCR/blob/release/2.1/doc/doc_en/models_list_en.md#English) is Optimized.
-- 2021.1.21 update more than 25+ multilingual recognition models [models list](./doc/doc_en/models_list_en.md), including:English, Chinese, German, French, Japanese,Spanish,Portuguese Russia Arabic and so on. Models for more languages will continue to be updated [Develop Plan](https://github.com/PaddlePaddle/PaddleOCR/issues/1048).
+- 2021.1.21 update 25+ multilingual recognition models [models list](./models_list_en.md), including: English, Chinese, German, French, Japanese, Spanish, Portuguese, Russian, Arabic and so on. Models for more languages will continue to be updated ([Develop Plan](https://github.com/PaddlePaddle/PaddleOCR/issues/1048)).
- 2020.12.15 update Data synthesis tool, i.e., [Style-Text](../../StyleText/README.md),easy to synthesize a large number of images which are similar to the target scene image.
- 2020.11.25 Update a new data annotation tool, i.e., [PPOCRLabel](../../PPOCRLabel/README.md), which is helpful to improve the labeling efficiency. Moreover, the labeling results can be used in training of the PP-OCR system directly.
- 2020.9.22 Update the PP-OCR technical article, https://arxiv.org/abs/2009.09941
diff --git a/ppstructure/vqa/images/input/zh_val_0.jpg b/doc/vqa/input/zh_val_0.jpg
similarity index 100%
rename from ppstructure/vqa/images/input/zh_val_0.jpg
rename to doc/vqa/input/zh_val_0.jpg
diff --git a/ppstructure/vqa/images/input/zh_val_21.jpg b/doc/vqa/input/zh_val_21.jpg
similarity index 100%
rename from ppstructure/vqa/images/input/zh_val_21.jpg
rename to doc/vqa/input/zh_val_21.jpg
diff --git a/ppstructure/vqa/images/input/zh_val_40.jpg b/doc/vqa/input/zh_val_40.jpg
similarity index 100%
rename from ppstructure/vqa/images/input/zh_val_40.jpg
rename to doc/vqa/input/zh_val_40.jpg
diff --git a/ppstructure/vqa/images/input/zh_val_42.jpg b/doc/vqa/input/zh_val_42.jpg
similarity index 100%
rename from ppstructure/vqa/images/input/zh_val_42.jpg
rename to doc/vqa/input/zh_val_42.jpg
diff --git a/ppstructure/vqa/images/result_re/zh_val_21_re.jpg b/doc/vqa/result_re/zh_val_21_re.jpg
similarity index 100%
rename from ppstructure/vqa/images/result_re/zh_val_21_re.jpg
rename to doc/vqa/result_re/zh_val_21_re.jpg
diff --git a/ppstructure/vqa/images/result_re/zh_val_40_re.jpg b/doc/vqa/result_re/zh_val_40_re.jpg
similarity index 100%
rename from ppstructure/vqa/images/result_re/zh_val_40_re.jpg
rename to doc/vqa/result_re/zh_val_40_re.jpg
diff --git a/ppstructure/vqa/images/result_ser/zh_val_0_ser.jpg b/doc/vqa/result_ser/zh_val_0_ser.jpg
similarity index 100%
rename from ppstructure/vqa/images/result_ser/zh_val_0_ser.jpg
rename to doc/vqa/result_ser/zh_val_0_ser.jpg
diff --git a/ppstructure/vqa/images/result_ser/zh_val_42_ser.jpg b/doc/vqa/result_ser/zh_val_42_ser.jpg
similarity index 100%
rename from ppstructure/vqa/images/result_ser/zh_val_42_ser.jpg
rename to doc/vqa/result_ser/zh_val_42_ser.jpg
diff --git "a/notebook/notebook_ch/5.ppocrv2_inference_deployment/PP-OCRv2\351\242\204\346\265\213\351\203\250\347\275\262\345\256\236\346\210\230.ipynb" "b/notebook/notebook_ch/5.ppocrv2_inference_deployment/PP-OCRv2\351\242\204\346\265\213\351\203\250\347\275\262\345\256\236\346\210\230.ipynb"
index 11626518d5a8e1a6b62227cbdf81d50ce2b0eee5..400f93c257356e45b7c0bfeb1cc0e9109b9d85be 100644
--- "a/notebook/notebook_ch/5.ppocrv2_inference_deployment/PP-OCRv2\351\242\204\346\265\213\351\203\250\347\275\262\345\256\236\346\210\230.ipynb"
+++ "b/notebook/notebook_ch/5.ppocrv2_inference_deployment/PP-OCRv2\351\242\204\346\265\213\351\203\250\347\275\262\345\256\236\346\210\230.ipynb"
@@ -2551,7 +2551,7 @@
"\n",
"Paddle Serving是飞桨为方便开发者进行服务化部署而打造的工具,本节主要介绍基于Paddle Serving的PP-OCRv2系统服务化部署过程。\n",
"\n",
- "## 4.1 Padde Serving简介\n",
+ "## 4.1 Paddle Serving简介\n",
"\n",
"Paddle Serving作为飞桨(PaddlePaddle)开源的服务化部署框架,长期目标就是围绕着人工智能落地的最后一公里提供越来越专业、可靠、易用的服务。Paddle Serving目前提供了两套框架C++ Serving和Python Pipeline。Python Pipeline框架倾向于二次开发的便捷性,C++ Serving框架更倾向于追求极致性能。\n",
"\n",
diff --git a/paddleocr.py b/paddleocr.py
index 733c83d1b4faa23212e7186148a5a9e1154ba891..f0938c6740606bdb2a96a6f9836602c0fb670650 100644
--- a/paddleocr.py
+++ b/paddleocr.py
@@ -42,12 +42,14 @@ __all__ = [
]
SUPPORT_DET_MODEL = ['DB']
-VERSION = '2.3.0.2'
+VERSION = '2.4'
SUPPORT_REC_MODEL = ['CRNN']
BASE_DIR = os.path.expanduser("~/.paddleocr/")
DEFAULT_OCR_MODEL_VERSION = 'PP-OCR'
+SUPPORT_OCR_MODEL_VERSION = ['PP-OCR', 'PP-OCRv2']
DEFAULT_STRUCTURE_MODEL_VERSION = 'STRUCTURE'
+SUPPORT_STRUCTURE_MODEL_VERSION = ['STRUCTURE']
MODEL_URLS = {
'OCR': {
'PP-OCRv2': {
@@ -190,6 +192,7 @@ def parse_args(mMain=True):
parser.add_argument(
"--ocr_version",
type=str,
+ choices=SUPPORT_OCR_MODEL_VERSION,
default='PP-OCRv2',
help='OCR Model version, the current model support list is as follows: '
'1. PP-OCRv2 Support Chinese detection and recognition model. '
@@ -198,6 +201,7 @@ def parse_args(mMain=True):
parser.add_argument(
"--structure_version",
type=str,
+ choices=SUPPORT_STRUCTURE_MODEL_VERSION,
default='STRUCTURE',
help='Model version, the current model support list is as follows:'
' 1. STRUCTURE Support en table structure model.')
@@ -257,26 +261,20 @@ def get_model_config(type, version, model_type, lang):
DEFAULT_MODEL_VERSION = DEFAULT_STRUCTURE_MODEL_VERSION
else:
raise NotImplementedError
+
model_urls = MODEL_URLS[type]
if version not in model_urls:
- logger.warning('version {} not in {}, auto switch to version {}'.format(
- version, model_urls.keys(), DEFAULT_MODEL_VERSION))
version = DEFAULT_MODEL_VERSION
if model_type not in model_urls[version]:
if model_type in model_urls[DEFAULT_MODEL_VERSION]:
- logger.warning(
- 'version {} not support {} models, auto switch to version {}'.
- format(version, model_type, DEFAULT_MODEL_VERSION))
version = DEFAULT_MODEL_VERSION
else:
logger.error('{} models is not support, we only support {}'.format(
model_type, model_urls[DEFAULT_MODEL_VERSION].keys()))
sys.exit(-1)
+
if lang not in model_urls[version][model_type]:
if lang in model_urls[DEFAULT_MODEL_VERSION][model_type]:
- logger.warning(
- 'lang {} is not support in {}, auto switch to version {}'.
- format(lang, version, DEFAULT_MODEL_VERSION))
version = DEFAULT_MODEL_VERSION
else:
logger.error(
@@ -296,6 +294,8 @@ class PaddleOCR(predict_system.TextSystem):
"""
params = parse_args(mMain=False)
params.__dict__.update(**kwargs)
+        assert params.ocr_version in SUPPORT_OCR_MODEL_VERSION, "ocr_version must be in {}, but got {}".format(
+ SUPPORT_OCR_MODEL_VERSION, params.ocr_version)
params.use_gpu = check_gpu(params.use_gpu)
if not params.show_log:
@@ -347,8 +347,9 @@ class PaddleOCR(predict_system.TextSystem):
ocr with paddleocr
args:
img: img for ocr, support ndarray, img_path and list or ndarray
- det: use text detection or not, if false, only rec will be exec. default is True
- rec: use text recognition or not, if false, only det will be exec. default is True
+            det: use text detection or not. If False, only rec will be executed. Default is True
+            rec: use text recognition or not. If False, only det will be executed. Default is True
+            cls: use angle classifier or not. Default is True. If True, text rotated by 180 degrees can be recognized. If no text is rotated by 180 degrees, use cls=False for better performance. Text rotated by 90 or 270 degrees can be recognized even if cls=False.
"""
assert isinstance(img, (np.ndarray, list, str))
if isinstance(img, list) and det == True:
@@ -398,6 +399,8 @@ class PPStructure(OCRSystem):
def __init__(self, **kwargs):
params = parse_args(mMain=False)
params.__dict__.update(**kwargs)
+        assert params.structure_version in SUPPORT_STRUCTURE_MODEL_VERSION, "structure_version must be in {}, but got {}".format(
+ SUPPORT_STRUCTURE_MODEL_VERSION, params.structure_version)
params.use_gpu = check_gpu(params.use_gpu)
if not params.show_log:
diff --git a/ppocr/data/__init__.py b/ppocr/data/__init__.py
index 0bb3d506483a331fba48feafeff9ca2d439f3782..60ab7bd0b4ceab846982c8744d5b277ee17185df 100644
--- a/ppocr/data/__init__.py
+++ b/ppocr/data/__init__.py
@@ -20,6 +20,7 @@ from __future__ import unicode_literals
import os
import sys
import numpy as np
+import skimage
import paddle
import signal
import random
@@ -86,13 +87,19 @@ def build_dataloader(config, mode, device, logger, seed=None):
shuffle=shuffle,
drop_last=drop_last)
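+    # optionally build a custom collate_fn (e.g. DictCollator or ListCollator from ppocr/data/collate_fn.py) named in the loader config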
+ if 'collate_fn' in loader_config:
+ from . import collate_fn
+ collate_fn = getattr(collate_fn, loader_config['collate_fn'])()
+ else:
+ collate_fn = None
data_loader = DataLoader(
dataset=dataset,
batch_sampler=batch_sampler,
places=device,
num_workers=num_workers,
return_list=True,
- use_shared_memory=use_shared_memory)
+ use_shared_memory=use_shared_memory,
+ collate_fn=collate_fn)
# support exit using ctrl+c
signal.signal(signal.SIGINT, term_mp)
diff --git a/ppstructure/vqa/data_collator.py b/ppocr/data/collate_fn.py
similarity index 59%
rename from ppstructure/vqa/data_collator.py
rename to ppocr/data/collate_fn.py
index a969935b487e3d22ea5c4a3527028aa2cfe1a797..89c6b4fd5ae151e1d703ea5c59abf0177dfc3a8b 100644
--- a/ppstructure/vqa/data_collator.py
+++ b/ppocr/data/collate_fn.py
@@ -15,20 +15,20 @@
import paddle
import numbers
import numpy as np
+from collections import defaultdict
-class DataCollator:
+class DictCollator(object):
"""
data batch
"""
def __call__(self, batch):
- data_dict = {}
+        # TODO: support batch operators
+ data_dict = defaultdict(list)
to_tensor_keys = []
for sample in batch:
for k, v in sample.items():
- if k not in data_dict:
- data_dict[k] = []
if isinstance(v, (np.ndarray, paddle.Tensor, numbers.Number)):
if k not in to_tensor_keys:
to_tensor_keys.append(k)
@@ -36,3 +36,23 @@ class DataCollator:
for k in to_tensor_keys:
data_dict[k] = paddle.to_tensor(data_dict[k])
return data_dict
+
+
+class ListCollator(object):
+ """
+    collate a batch of list/tuple samples; ndarray/Tensor/number fields are converted to paddle tensors
+ """
+
+ def __call__(self, batch):
+        # TODO: support batch operators
+ data_dict = defaultdict(list)
+ to_tensor_idxs = []
+ for sample in batch:
+ for idx, v in enumerate(sample):
+ if isinstance(v, (np.ndarray, paddle.Tensor, numbers.Number)):
+ if idx not in to_tensor_idxs:
+ to_tensor_idxs.append(idx)
+ data_dict[idx].append(v)
+ for idx in to_tensor_idxs:
+ data_dict[idx] = paddle.to_tensor(data_dict[idx])
+ return list(data_dict.values())
diff --git a/ppocr/data/imaug/__init__.py b/ppocr/data/imaug/__init__.py
index 5aaa1cd71eb791efa94e6bd812f3ab76632c96c6..90a70875b9def5a1300e26dec277e888235f8237 100644
--- a/ppocr/data/imaug/__init__.py
+++ b/ppocr/data/imaug/__init__.py
@@ -34,6 +34,8 @@ from .sast_process import *
from .pg_process import *
from .gen_table_mask import *
+from .vqa import *
+
def transform(data, ops=None):
""" transform """
diff --git a/ppocr/data/imaug/label_ops.py b/ppocr/data/imaug/label_ops.py
index f83255b732f5990de6a99d1149bab77d682c85b3..786647f1f655dd40be1117df912f59c42108539e 100644
--- a/ppocr/data/imaug/label_ops.py
+++ b/ppocr/data/imaug/label_ops.py
@@ -17,6 +17,7 @@ from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals
+import copy
import numpy as np
import string
from shapely.geometry import LineString, Point, Polygon
@@ -736,7 +737,7 @@ class TableLabelEncode(object):
% beg_or_end
else:
assert False, "Unsupport type %s in char_or_elem" \
- % char_or_elem
+ % char_or_elem
return idx
@@ -782,3 +783,176 @@ class SARLabelEncode(BaseRecLabelEncode):
def get_ignored_tokens(self):
return [self.padding_idx]
+
+
+class VQATokenLabelEncode(object):
+ """
+ Label encode for NLP VQA methods
+ """
+
+ def __init__(self,
+ class_path,
+ contains_re=False,
+ add_special_ids=False,
+ algorithm='LayoutXLM',
+ infer_mode=False,
+ ocr_engine=None,
+ **kwargs):
+ super(VQATokenLabelEncode, self).__init__()
+ from paddlenlp.transformers import LayoutXLMTokenizer, LayoutLMTokenizer
+ from ppocr.utils.utility import load_vqa_bio_label_maps
+ tokenizer_dict = {
+ 'LayoutXLM': {
+ 'class': LayoutXLMTokenizer,
+ 'pretrained_model': 'layoutxlm-base-uncased'
+ },
+ 'LayoutLM': {
+ 'class': LayoutLMTokenizer,
+ 'pretrained_model': 'layoutlm-base-uncased'
+ }
+ }
+ self.contains_re = contains_re
+ tokenizer_config = tokenizer_dict[algorithm]
+ self.tokenizer = tokenizer_config['class'].from_pretrained(
+ tokenizer_config['pretrained_model'])
+ self.label2id_map, id2label_map = load_vqa_bio_label_maps(class_path)
+ self.add_special_ids = add_special_ids
+ self.infer_mode = infer_mode
+ self.ocr_engine = ocr_engine
+
+ def __call__(self, data):
+ # load bbox and label info
+ ocr_info = self._load_ocr_info(data)
+
+ height, width, _ = data['image'].shape
+
+ words_list = []
+ bbox_list = []
+ input_ids_list = []
+ token_type_ids_list = []
+ segment_offset_id = []
+ gt_label_list = []
+
+ entities = []
+
+ # for re
+ train_re = self.contains_re and not self.infer_mode
+ if train_re:
+ relations = []
+ id2label = {}
+ entity_id_to_index_map = {}
+ empty_entity = set()
+
+ data['ocr_info'] = copy.deepcopy(ocr_info)
+
+ for info in ocr_info:
+ if train_re:
+ # for re
+ if len(info["text"]) == 0:
+ empty_entity.add(info["id"])
+ continue
+ id2label[info["id"]] = info["label"]
+ relations.extend([tuple(sorted(l)) for l in info["linking"]])
+ # smooth_box
+ bbox = self._smooth_box(info["bbox"], height, width)
+
+ text = info["text"]
+ encode_res = self.tokenizer.encode(
+ text, pad_to_max_seq_len=False, return_attention_mask=True)
+
+ if not self.add_special_ids:
+ # TODO: use tok.all_special_ids to remove
+ encode_res["input_ids"] = encode_res["input_ids"][1:-1]
+ encode_res["token_type_ids"] = encode_res["token_type_ids"][1:
+ -1]
+ encode_res["attention_mask"] = encode_res["attention_mask"][1:
+ -1]
+ # parse label
+ if not self.infer_mode:
+ label = info['label']
+ gt_label = self._parse_label(label, encode_res)
+
+ # construct entities for re
+ if train_re:
+ if gt_label[0] != self.label2id_map["O"]:
+ entity_id_to_index_map[info["id"]] = len(entities)
+ label = label.upper()
+ entities.append({
+ "start": len(input_ids_list),
+ "end":
+ len(input_ids_list) + len(encode_res["input_ids"]),
+ "label": label.upper(),
+ })
+ else:
+ entities.append({
+ "start": len(input_ids_list),
+ "end": len(input_ids_list) + len(encode_res["input_ids"]),
+ "label": 'O',
+ })
+ input_ids_list.extend(encode_res["input_ids"])
+ token_type_ids_list.extend(encode_res["token_type_ids"])
+ bbox_list.extend([bbox] * len(encode_res["input_ids"]))
+ words_list.append(text)
+ segment_offset_id.append(len(input_ids_list))
+ if not self.infer_mode:
+ gt_label_list.extend(gt_label)
+
+ data['input_ids'] = input_ids_list
+ data['token_type_ids'] = token_type_ids_list
+ data['bbox'] = bbox_list
+ data['attention_mask'] = [1] * len(input_ids_list)
+ data['labels'] = gt_label_list
+ data['segment_offset_id'] = segment_offset_id
+ data['tokenizer_params'] = dict(
+ padding_side=self.tokenizer.padding_side,
+ pad_token_type_id=self.tokenizer.pad_token_type_id,
+ pad_token_id=self.tokenizer.pad_token_id)
+ data['entities'] = entities
+
+ if train_re:
+ data['relations'] = relations
+ data['id2label'] = id2label
+ data['empty_entity'] = empty_entity
+ data['entity_id_to_index_map'] = entity_id_to_index_map
+ return data
+
+ def _load_ocr_info(self, data):
+ def trans_poly_to_bbox(poly):
+ x1 = np.min([p[0] for p in poly])
+ x2 = np.max([p[0] for p in poly])
+ y1 = np.min([p[1] for p in poly])
+ y2 = np.max([p[1] for p in poly])
+ return [x1, y1, x2, y2]
+
+ if self.infer_mode:
+ ocr_result = self.ocr_engine.ocr(data['image'], cls=False)
+ ocr_info = []
+ for res in ocr_result:
+ ocr_info.append({
+ "text": res[1][0],
+ "bbox": trans_poly_to_bbox(res[0]),
+ "poly": res[0],
+ })
+ return ocr_info
+ else:
+ info = data['label']
+ # read text info
+ info_dict = json.loads(info)
+ return info_dict["ocr_info"]
+
+ def _smooth_box(self, bbox, height, width):
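+        # scale the box to the 0~1000 range relative to the image size, as expected by LayoutLM/LayoutXLM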
+ bbox[0] = int(bbox[0] * 1000.0 / width)
+ bbox[2] = int(bbox[2] * 1000.0 / width)
+ bbox[1] = int(bbox[1] * 1000.0 / height)
+ bbox[3] = int(bbox[3] * 1000.0 / height)
+ return bbox
+
+ def _parse_label(self, label, encode_res):
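+        # map the label to BIO tag ids: "other" maps to 0, otherwise the first token gets the B- tag and the remaining tokens get the I- tag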
+ gt_label = []
+ if label.lower() == "other":
+ gt_label.extend([0] * len(encode_res["input_ids"]))
+ else:
+ gt_label.append(self.label2id_map[("b-" + label).upper()])
+ gt_label.extend([self.label2id_map[("i-" + label).upper()]] *
+ (len(encode_res["input_ids"]) - 1))
+ return gt_label
diff --git a/ppocr/data/imaug/operators.py b/ppocr/data/imaug/operators.py
index c3dfd316f86d88b5c7fd52eb6ae23d22a4dd32eb..f6568affc861acb7e8de195e9c47b39168108723 100644
--- a/ppocr/data/imaug/operators.py
+++ b/ppocr/data/imaug/operators.py
@@ -23,7 +23,6 @@ import sys
import six
import cv2
import numpy as np
-import fasttext
class DecodeImage(object):
@@ -136,6 +135,7 @@ class ToCHWImage(object):
class Fasttext(object):
def __init__(self, path="None", **kwargs):
+ import fasttext
self.fast_model = fasttext.load_model(path)
def __call__(self, data):
@@ -170,17 +170,19 @@ class Resize(object):
def __call__(self, data):
img = data['image']
- text_polys = data['polys']
+ if 'polys' in data:
+ text_polys = data['polys']
img_resize, [ratio_h, ratio_w] = self.resize_image(img)
- new_boxes = []
- for box in text_polys:
- new_box = []
- for cord in box:
- new_box.append([cord[0] * ratio_w, cord[1] * ratio_h])
- new_boxes.append(new_box)
+ if 'polys' in data:
+ new_boxes = []
+ for box in text_polys:
+ new_box = []
+ for cord in box:
+ new_box.append([cord[0] * ratio_w, cord[1] * ratio_h])
+ new_boxes.append(new_box)
+ data['polys'] = np.array(new_boxes, dtype=np.float32)
data['image'] = img_resize
- data['polys'] = np.array(new_boxes, dtype=np.float32)
return data
diff --git a/ppocr/data/imaug/vqa/__init__.py b/ppocr/data/imaug/vqa/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..a5025e7985198e7ee40d6c92d8e1814eb1797032
--- /dev/null
+++ b/ppocr/data/imaug/vqa/__init__.py
@@ -0,0 +1,19 @@
+# copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from .token import VQATokenPad, VQASerTokenChunk, VQAReTokenChunk, VQAReTokenRelation
+
+__all__ = [
+ 'VQATokenPad', 'VQASerTokenChunk', 'VQAReTokenChunk', 'VQAReTokenRelation'
+]
diff --git a/ppocr/data/imaug/vqa/token/__init__.py b/ppocr/data/imaug/vqa/token/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..7c115661753cd031b16ec34697157e2fcdcf2dec
--- /dev/null
+++ b/ppocr/data/imaug/vqa/token/__init__.py
@@ -0,0 +1,17 @@
+# copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from .vqa_token_chunk import VQASerTokenChunk, VQAReTokenChunk
+from .vqa_token_pad import VQATokenPad
+from .vqa_token_relation import VQAReTokenRelation
diff --git a/ppocr/data/imaug/vqa/token/vqa_token_chunk.py b/ppocr/data/imaug/vqa/token/vqa_token_chunk.py
new file mode 100644
index 0000000000000000000000000000000000000000..deb55b4d55b81d5949ed834693e45c3b40c4b762
--- /dev/null
+++ b/ppocr/data/imaug/vqa/token/vqa_token_chunk.py
@@ -0,0 +1,117 @@
+# copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+class VQASerTokenChunk(object):
+ def __init__(self, max_seq_len=512, infer_mode=False, **kwargs):
+ self.max_seq_len = max_seq_len
+ self.infer_mode = infer_mode
+
+ def __call__(self, data):
+ encoded_inputs_all = []
+ seq_len = len(data['input_ids'])
+ for index in range(0, seq_len, self.max_seq_len):
+ chunk_beg = index
+ chunk_end = min(index + self.max_seq_len, seq_len)
+ encoded_inputs_example = {}
+ for key in data:
+ if key in [
+ 'label', 'input_ids', 'labels', 'token_type_ids',
+ 'bbox', 'attention_mask'
+ ]:
+ if self.infer_mode and key == 'labels':
+ encoded_inputs_example[key] = data[key]
+ else:
+ encoded_inputs_example[key] = data[key][chunk_beg:
+ chunk_end]
+ else:
+ encoded_inputs_example[key] = data[key]
+
+ encoded_inputs_all.append(encoded_inputs_example)
+ return encoded_inputs_all[0]
+
+
+class VQAReTokenChunk(object):
+ def __init__(self,
+ max_seq_len=512,
+ entities_labels=None,
+ infer_mode=False,
+ **kwargs):
+ self.max_seq_len = max_seq_len
+ self.entities_labels = {
+ 'HEADER': 0,
+ 'QUESTION': 1,
+ 'ANSWER': 2
+ } if entities_labels is None else entities_labels
+ self.infer_mode = infer_mode
+
+ def __call__(self, data):
+ # prepare data
+ entities = data.pop('entities')
+ relations = data.pop('relations')
+ encoded_inputs_all = []
+ for index in range(0, len(data["input_ids"]), self.max_seq_len):
+ item = {}
+ for key in data:
+ if key in [
+ 'label', 'input_ids', 'labels', 'token_type_ids',
+ 'bbox', 'attention_mask'
+ ]:
+ if self.infer_mode and key == 'labels':
+ item[key] = data[key]
+ else:
+ item[key] = data[key][index:index + self.max_seq_len]
+ else:
+ item[key] = data[key]
+ # select entity in current chunk
+ entities_in_this_span = []
+            global_to_local_map = {}
+ for entity_id, entity in enumerate(entities):
+ if (index <= entity["start"] < index + self.max_seq_len and
+ index <= entity["end"] < index + self.max_seq_len):
+ entity["start"] = entity["start"] - index
+ entity["end"] = entity["end"] - index
+ global_to_local_map[entity_id] = len(entities_in_this_span)
+ entities_in_this_span.append(entity)
+
+ # select relations in current chunk
+ relations_in_this_span = []
+ for relation in relations:
+ if (index <= relation["start_index"] < index + self.max_seq_len
+ and index <= relation["end_index"] <
+ index + self.max_seq_len):
+ relations_in_this_span.append({
+ "head": global_to_local_map[relation["head"]],
+ "tail": global_to_local_map[relation["tail"]],
+ "start_index": relation["start_index"] - index,
+ "end_index": relation["end_index"] - index,
+ })
+ item.update({
+ "entities": self.reformat(entities_in_this_span),
+ "relations": self.reformat(relations_in_this_span),
+ })
+ item['entities']['label'] = [
+ self.entities_labels[x] for x in item['entities']['label']
+ ]
+ encoded_inputs_all.append(item)
+ return encoded_inputs_all[0]
+
+ def reformat(self, data):
+ new_data = {}
+ for item in data:
+ for k, v in item.items():
+ if k not in new_data:
+ new_data[k] = []
+ new_data[k].append(v)
+ return new_data
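Both chunk transforms slice only the sequence-level fields and copy everything else through, then keep just the first chunk. A simplified, illustrative sketch of that slicing with a hypothetical max_seq_len:

# Hedged sketch: per-chunk slicing as done by VQASerTokenChunk / VQAReTokenChunk.
data = {"input_ids": list(range(10)), "attention_mask": [1] * 10, "image_path": "doc.jpg"}
max_seq_len = 4
chunks = []
for beg in range(0, len(data["input_ids"]), max_seq_len):
    end = min(beg + max_seq_len, len(data["input_ids"]))
    chunk = {k: (v[beg:end] if k in ("input_ids", "attention_mask") else v)
             for k, v in data.items()}
    chunks.append(chunk)
print(len(chunks), chunks[0]["input_ids"])  # 3 [0, 1, 2, 3] -- only chunks[0] is returned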
diff --git a/ppocr/data/imaug/vqa/token/vqa_token_pad.py b/ppocr/data/imaug/vqa/token/vqa_token_pad.py
new file mode 100644
index 0000000000000000000000000000000000000000..8e5a20f95f0159e5c57072dd86eff0f25cf49eac
--- /dev/null
+++ b/ppocr/data/imaug/vqa/token/vqa_token_pad.py
@@ -0,0 +1,104 @@
+# copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import paddle
+import numpy as np
+
+
+class VQATokenPad(object):
+ def __init__(self,
+ max_seq_len=512,
+ pad_to_max_seq_len=True,
+ return_attention_mask=True,
+ return_token_type_ids=True,
+ truncation_strategy="longest_first",
+ return_overflowing_tokens=False,
+ return_special_tokens_mask=False,
+ infer_mode=False,
+ **kwargs):
+ self.max_seq_len = max_seq_len
+        self.pad_to_max_seq_len = pad_to_max_seq_len
+ self.return_attention_mask = return_attention_mask
+ self.return_token_type_ids = return_token_type_ids
+ self.truncation_strategy = truncation_strategy
+ self.return_overflowing_tokens = return_overflowing_tokens
+ self.return_special_tokens_mask = return_special_tokens_mask
+ self.pad_token_label_id = paddle.nn.CrossEntropyLoss().ignore_index
+ self.infer_mode = infer_mode
+
+ def __call__(self, data):
+ needs_to_be_padded = self.pad_to_max_seq_len and len(data[
+ "input_ids"]) < self.max_seq_len
+
+ if needs_to_be_padded:
+ if 'tokenizer_params' in data:
+ tokenizer_params = data.pop('tokenizer_params')
+ else:
+ tokenizer_params = dict(
+ padding_side='right', pad_token_type_id=0, pad_token_id=1)
+
+ difference = self.max_seq_len - len(data["input_ids"])
+ if tokenizer_params['padding_side'] == 'right':
+ if self.return_attention_mask:
+ data["attention_mask"] = [1] * len(data[
+ "input_ids"]) + [0] * difference
+ if self.return_token_type_ids:
+ data["token_type_ids"] = (
+ data["token_type_ids"] +
+ [tokenizer_params['pad_token_type_id']] * difference)
+ if self.return_special_tokens_mask:
+ data["special_tokens_mask"] = data[
+ "special_tokens_mask"] + [1] * difference
+ data["input_ids"] = data["input_ids"] + [
+ tokenizer_params['pad_token_id']
+ ] * difference
+ if not self.infer_mode:
+ data["labels"] = data[
+ "labels"] + [self.pad_token_label_id] * difference
+ data["bbox"] = data["bbox"] + [[0, 0, 0, 0]] * difference
+ elif tokenizer_params['padding_side'] == 'left':
+ if self.return_attention_mask:
+ data["attention_mask"] = [0] * difference + [
+ 1
+ ] * len(data["input_ids"])
+ if self.return_token_type_ids:
+ data["token_type_ids"] = (
+ [tokenizer_params['pad_token_type_id']] * difference +
+ data["token_type_ids"])
+ if self.return_special_tokens_mask:
+ data["special_tokens_mask"] = [
+ 1
+ ] * difference + data["special_tokens_mask"]
+ data["input_ids"] = [tokenizer_params['pad_token_id']
+ ] * difference + data["input_ids"]
+ if not self.infer_mode:
+ data["labels"] = [self.pad_token_label_id
+ ] * difference + data["labels"]
+ data["bbox"] = [[0, 0, 0, 0]] * difference + data["bbox"]
+ else:
+ if self.return_attention_mask:
+ data["attention_mask"] = [1] * len(data["input_ids"])
+
+ for key in data:
+ if key in [
+ 'input_ids', 'labels', 'token_type_ids', 'bbox',
+ 'attention_mask'
+ ]:
+ if self.infer_mode:
+ if key != 'labels':
+ length = min(len(data[key]), self.max_seq_len)
+ data[key] = data[key][:length]
+ else:
+ continue
+ data[key] = np.array(data[key], dtype='int64')
+ return data
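VQATokenPad pads every sequence-level field out to max_seq_len on the side reported by the tokenizer. A compact sketch of the right-padding branch, with assumed pad ids (pad_token_id=1, pad_token_type_id=0):

# Hedged sketch of right-side padding for a short sequence.
max_seq_len = 8
input_ids = [101, 57, 98, 102]
difference = max_seq_len - len(input_ids)
attention_mask = [1] * len(input_ids) + [0] * difference
input_ids = input_ids + [1] * difference
bbox = [[10, 10, 50, 20]] * 4 + [[0, 0, 0, 0]] * difference
print(input_ids, attention_mask)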
diff --git a/ppocr/data/imaug/vqa/token/vqa_token_relation.py b/ppocr/data/imaug/vqa/token/vqa_token_relation.py
new file mode 100644
index 0000000000000000000000000000000000000000..293988ff85aecb39bac84b412f3466abecc6db4d
--- /dev/null
+++ b/ppocr/data/imaug/vqa/token/vqa_token_relation.py
@@ -0,0 +1,67 @@
+# copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+class VQAReTokenRelation(object):
+ def __init__(self, **kwargs):
+ pass
+
+ def __call__(self, data):
+ """
+ build relations
+ """
+ entities = data['entities']
+ relations = data['relations']
+ id2label = data.pop('id2label')
+ empty_entity = data.pop('empty_entity')
+ entity_id_to_index_map = data.pop('entity_id_to_index_map')
+
+ relations = list(set(relations))
+ relations = [
+ rel for rel in relations
+ if rel[0] not in empty_entity and rel[1] not in empty_entity
+ ]
+ kv_relations = []
+ for rel in relations:
+ pair = [id2label[rel[0]], id2label[rel[1]]]
+ if pair == ["question", "answer"]:
+ kv_relations.append({
+ "head": entity_id_to_index_map[rel[0]],
+ "tail": entity_id_to_index_map[rel[1]]
+ })
+ elif pair == ["answer", "question"]:
+ kv_relations.append({
+ "head": entity_id_to_index_map[rel[1]],
+ "tail": entity_id_to_index_map[rel[0]]
+ })
+ else:
+ continue
+ relations = sorted(
+ [{
+ "head": rel["head"],
+ "tail": rel["tail"],
+ "start_index": self.get_relation_span(rel, entities)[0],
+ "end_index": self.get_relation_span(rel, entities)[1],
+ } for rel in kv_relations],
+ key=lambda x: x["head"], )
+
+ data['relations'] = relations
+ return data
+
+ def get_relation_span(self, rel, entities):
+ bound = []
+ for entity_index in [rel["head"], rel["tail"]]:
+ bound.append(entities[entity_index]["start"])
+ bound.append(entities[entity_index]["end"])
+ return min(bound), max(bound)
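VQAReTokenRelation keeps only question/answer pairs and always orients a relation as head=question, tail=answer. A toy sketch of that filtering, with hypothetical entity ids:

# Hedged sketch: build head/tail pairs from raw (id, id) relations and an id->label map.
id2label = {1: "question", 2: "answer", 3: "header"}
entity_id_to_index_map = {1: 0, 2: 1, 3: 2}
relations = [(1, 2), (2, 1), (3, 1)]  # duplicate directions and non-QA pairs are dropped
kv_relations = []
for a, b in set(relations):
    pair = [id2label[a], id2label[b]]
    if pair == ["question", "answer"]:
        kv_relations.append({"head": entity_id_to_index_map[a], "tail": entity_id_to_index_map[b]})
    elif pair == ["answer", "question"]:
        kv_relations.append({"head": entity_id_to_index_map[b], "tail": entity_id_to_index_map[a]})
print(kv_relations)  # both orderings resolve to head=question, tail=answer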
diff --git a/ppocr/data/lmdb_dataset.py b/ppocr/data/lmdb_dataset.py
index e2d6dc9327bf3725d2fb6c32d18c0b71bd6ac408..e1b49809d199096ad06b90c4562aa5dbfa634db1 100644
--- a/ppocr/data/lmdb_dataset.py
+++ b/ppocr/data/lmdb_dataset.py
@@ -38,6 +38,9 @@ class LMDBDataSet(Dataset):
np.random.shuffle(self.data_idx_order_list)
self.ops = create_operators(dataset_config['transforms'], global_config)
+ ratio_list = dataset_config.get("ratio_list", [1.0])
+ self.need_reset = True in [x < 1 for x in ratio_list]
+
def load_hierarchical_lmdb_dataset(self, data_dir):
lmdb_sets = {}
dataset_idx = 0
diff --git a/ppocr/data/pgnet_dataset.py b/ppocr/data/pgnet_dataset.py
index 5adcd02c4a24074c0252a8590fd89f015a6ff152..6f80179c4eb971ace360edb5368f6a2acd5a6322 100644
--- a/ppocr/data/pgnet_dataset.py
+++ b/ppocr/data/pgnet_dataset.py
@@ -49,6 +49,8 @@ class PGDataSet(Dataset):
self.ops = create_operators(dataset_config['transforms'], global_config)
+ self.need_reset = True in [x < 1 for x in ratio_list]
+
def shuffle_data_random(self):
if self.do_shuffle:
random.seed(self.seed)
diff --git a/ppocr/data/pubtab_dataset.py b/ppocr/data/pubtab_dataset.py
index 78b76c5afb8c96bc96730c7b8ad76b4bafa31c67..671cda76fb4c36f3ac6bcc7da5a7fc4de241c0e2 100644
--- a/ppocr/data/pubtab_dataset.py
+++ b/ppocr/data/pubtab_dataset.py
@@ -53,6 +53,9 @@ class PubTabDataSet(Dataset):
self.shuffle_data_random()
self.ops = create_operators(dataset_config['transforms'], global_config)
+ ratio_list = dataset_config.get("ratio_list", [1.0])
+ self.need_reset = True in [x < 1 for x in ratio_list]
+
def shuffle_data_random(self):
if self.do_shuffle:
random.seed(self.seed)
@@ -70,7 +73,7 @@ class PubTabDataSet(Dataset):
prob = self.img_select_prob[file_name]
if prob < random.uniform(0, 1):
select_flag = False
-
+
if self.table_select_type:
structure = info['html']['structure']['tokens'].copy()
structure_str = ''.join(structure)
@@ -79,13 +82,17 @@ class PubTabDataSet(Dataset):
table_type = "complex"
if table_type == "complex":
if self.table_select_prob < random.uniform(0, 1):
- select_flag = False
-
+ select_flag = False
+
if select_flag:
cells = info['html']['cells'].copy()
structure = info['html']['structure'].copy()
img_path = os.path.join(self.data_dir, file_name)
- data = {'img_path': img_path, 'cells': cells, 'structure':structure}
+ data = {
+ 'img_path': img_path,
+ 'cells': cells,
+ 'structure': structure
+ }
if not os.path.exists(img_path):
raise Exception("{} does not exist!".format(img_path))
with open(data['img_path'], 'rb') as f:
diff --git a/ppocr/data/simple_dataset.py b/ppocr/data/simple_dataset.py
index ee8571b8c452bbd834fc5dbcf01ce390562163d6..10b6b7a891f99edfac3e824458238848a2ab5b51 100644
--- a/ppocr/data/simple_dataset.py
+++ b/ppocr/data/simple_dataset.py
@@ -41,7 +41,6 @@ class SimpleDataSet(Dataset):
) == data_source_num, "The length of ratio_list should be the same as the file_list."
self.data_dir = dataset_config['data_dir']
self.do_shuffle = loader_config['shuffle']
-
self.seed = seed
logger.info("Initialize indexs of datasets:%s" % label_file_list)
self.data_lines = self.get_image_info_list(label_file_list, ratio_list)
@@ -50,6 +49,8 @@ class SimpleDataSet(Dataset):
self.shuffle_data_random()
self.ops = create_operators(dataset_config['transforms'], global_config)
+ self.need_reset = True in [x < 1 for x in ratio_list]
+
def get_image_info_list(self, file_list, ratio_list):
if isinstance(file_list, str):
file_list = [file_list]
@@ -69,6 +70,16 @@ class SimpleDataSet(Dataset):
random.shuffle(self.data_lines)
return
+ def _try_parse_filename_list(self, file_name):
+ # multiple images -> one gt label
+ if len(file_name) > 0 and file_name[0] == "[":
+ try:
+ info = json.loads(file_name)
+ file_name = random.choice(info)
+ except:
+ pass
+ return file_name
+
def get_ext_data(self):
ext_data_num = 0
for op in self.ops:
@@ -85,6 +96,7 @@ class SimpleDataSet(Dataset):
data_line = data_line.decode('utf-8')
substr = data_line.strip("\n").split(self.delimiter)
file_name = substr[0]
+ file_name = self._try_parse_filename_list(file_name)
label = substr[1]
img_path = os.path.join(self.data_dir, file_name)
data = {'img_path': img_path, 'label': label}
@@ -95,7 +107,7 @@ class SimpleDataSet(Dataset):
data['image'] = img
data = transform(data, load_data_ops)
- if data is None or data['polys'].shape[1]!=4:
+ if data is None or data['polys'].shape[1] != 4:
continue
ext_data.append(data)
return ext_data
@@ -107,6 +119,7 @@ class SimpleDataSet(Dataset):
data_line = data_line.decode('utf-8')
substr = data_line.strip("\n").split(self.delimiter)
file_name = substr[0]
+ file_name = self._try_parse_filename_list(file_name)
label = substr[1]
img_path = os.path.join(self.data_dir, file_name)
data = {'img_path': img_path, 'label': label}
diff --git a/ppocr/losses/__init__.py b/ppocr/losses/__init__.py
index 62ad2b6ad86edf9b5446aea03f9333f9d4981336..56e6d25d4b10bd224e357828c5355ebceef59634 100755
--- a/ppocr/losses/__init__.py
+++ b/ppocr/losses/__init__.py
@@ -16,6 +16,9 @@ import copy
import paddle
import paddle.nn as nn
+# basic_loss
+from .basic_loss import LossFromOutput
+
# det loss
from .det_db_loss import DBLoss
from .det_east_loss import EASTLoss
@@ -46,12 +49,16 @@ from .combined_loss import CombinedLoss
# table loss
from .table_att_loss import TableAttentionLoss
+# vqa token loss
+from .vqa_token_layoutlm_loss import VQASerTokenLayoutLMLoss
+
def build_loss(config):
support_dict = [
'DBLoss', 'PSELoss', 'EASTLoss', 'SASTLoss', 'CTCLoss', 'ClsLoss',
'AttentionLoss', 'SRNLoss', 'PGLoss', 'CombinedLoss', 'NRTRLoss',
- 'TableAttentionLoss', 'SARLoss', 'AsterLoss', 'SDMGRLoss'
+ 'TableAttentionLoss', 'SARLoss', 'AsterLoss', 'SDMGRLoss',
+ 'VQASerTokenLayoutLMLoss', 'LossFromOutput'
]
config = copy.deepcopy(config)
module_name = config.pop('name')
diff --git a/ppocr/losses/basic_loss.py b/ppocr/losses/basic_loss.py
index d2ef5e5ac9692eec5bc30774c4451eab7706705d..fc64c133a4ad5a97530e2ad259ad38267188f6d3 100644
--- a/ppocr/losses/basic_loss.py
+++ b/ppocr/losses/basic_loss.py
@@ -133,3 +133,18 @@ class DistanceLoss(nn.Layer):
def forward(self, x, y):
return self.loss_func(x, y)
+
+
+class LossFromOutput(nn.Layer):
+ def __init__(self, key='loss', reduction='none'):
+ super().__init__()
+ self.key = key
+ self.reduction = reduction
+
+ def forward(self, predicts, batch):
+ loss = predicts[self.key]
+ if self.reduction == 'mean':
+ loss = paddle.mean(loss)
+ elif self.reduction == 'sum':
+ loss = paddle.sum(loss)
+ return {'loss': loss}
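LossFromOutput simply pulls an already-computed loss out of the model's output dict, which is useful when the backbone (e.g. a paddlenlp relation-extraction head) returns its own loss. A minimal sketch of its behaviour:

import paddle

# Hedged sketch: the 'loss' entry is read from the prediction dict and optionally reduced.
predicts = {"loss": paddle.to_tensor([0.4, 0.6])}
reduction = "mean"            # 'none' | 'mean' | 'sum', as in LossFromOutput
loss = predicts["loss"]
if reduction == "mean":
    loss = paddle.mean(loss)
elif reduction == "sum":
    loss = paddle.sum(loss)
print({"loss": loss})         # Tensor(0.5)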
diff --git a/ppstructure/vqa/losses.py b/ppocr/losses/vqa_token_layoutlm_loss.py
old mode 100644
new mode 100755
similarity index 66%
rename from ppstructure/vqa/losses.py
rename to ppocr/losses/vqa_token_layoutlm_loss.py
index e8dad01c3198f200788c7898d1b77b38d917d1ca..244893d97d0e422c5ca270bdece689e13aba2b07
--- a/ppstructure/vqa/losses.py
+++ b/ppocr/losses/vqa_token_layoutlm_loss.py
@@ -1,10 +1,10 @@
-# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+# copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
-# http://www.apache.org/licenses/LICENSE-2.0
+# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
@@ -12,24 +12,31 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
from paddle import nn
-class SERLoss(nn.Layer):
+class VQASerTokenLayoutLMLoss(nn.Layer):
def __init__(self, num_classes):
super().__init__()
self.loss_class = nn.CrossEntropyLoss()
self.num_classes = num_classes
self.ignore_index = self.loss_class.ignore_index
- def forward(self, labels, outputs, attention_mask):
+ def forward(self, predicts, batch):
+ labels = batch[1]
+ attention_mask = batch[4]
if attention_mask is not None:
active_loss = attention_mask.reshape([-1, ]) == 1
- active_outputs = outputs.reshape(
+ active_outputs = predicts.reshape(
[-1, self.num_classes])[active_loss]
active_labels = labels.reshape([-1, ])[active_loss]
loss = self.loss_class(active_outputs, active_labels)
else:
loss = self.loss_class(
- outputs.reshape([-1, self.num_classes]), labels.reshape([-1, ]))
- return loss
+ predicts.reshape([-1, self.num_classes]),
+ labels.reshape([-1, ]))
+ return {'loss': loss}
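The renamed SER loss now receives the whole batch and masks out padded positions before computing cross entropy. A shape-only sketch of that masking, assuming num_classes=7 and the batch layout used by the new forward (labels at index 1, attention_mask at index 4):

import paddle
import paddle.nn as nn

# Hedged sketch: only positions with attention_mask == 1 contribute to the loss.
num_classes, seq_len = 7, 5
predicts = paddle.rand([1, seq_len, num_classes])
labels = paddle.to_tensor([[1, 2, 2, 0, 0]])
attention_mask = paddle.to_tensor([[1, 1, 1, 0, 0]])

active = attention_mask.reshape([-1]) == 1
active_outputs = predicts.reshape([-1, num_classes])[active]
active_labels = labels.reshape([-1])[active]
loss = nn.CrossEntropyLoss()(active_outputs, active_labels)
print(loss)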
diff --git a/ppocr/metrics/__init__.py b/ppocr/metrics/__init__.py
index 28bff3cb4eb7784db876940f761208f1b084f0e2..604ae548df5f54fecdf22de756741da554cec17e 100644
--- a/ppocr/metrics/__init__.py
+++ b/ppocr/metrics/__init__.py
@@ -28,12 +28,15 @@ from .e2e_metric import E2EMetric
from .distillation_metric import DistillationMetric
from .table_metric import TableMetric
from .kie_metric import KIEMetric
+from .vqa_token_ser_metric import VQASerTokenMetric
+from .vqa_token_re_metric import VQAReTokenMetric
def build_metric(config):
support_dict = [
"DetMetric", "RecMetric", "ClsMetric", "E2EMetric",
- "DistillationMetric", "TableMetric", 'KIEMetric'
+ "DistillationMetric", "TableMetric", 'KIEMetric', 'VQASerTokenMetric',
+ 'VQAReTokenMetric'
]
config = copy.deepcopy(config)
diff --git a/ppocr/metrics/cls_metric.py b/ppocr/metrics/cls_metric.py
index 09817200234dc8d8b5d091ebbe33f07f4aad2cf6..6c077518ce205d4ec4d426aaedb8c0af880122ee 100644
--- a/ppocr/metrics/cls_metric.py
+++ b/ppocr/metrics/cls_metric.py
@@ -16,6 +16,7 @@
class ClsMetric(object):
def __init__(self, main_indicator='acc', **kwargs):
self.main_indicator = main_indicator
+ self.eps = 1e-5
self.reset()
def __call__(self, pred_label, *args, **kwargs):
@@ -28,7 +29,7 @@ class ClsMetric(object):
all_num += 1
self.correct_num += correct_num
self.all_num += all_num
- return {'acc': correct_num / all_num, }
+ return {'acc': correct_num / (all_num + self.eps), }
def get_metric(self):
"""
@@ -36,7 +37,7 @@ class ClsMetric(object):
'acc': 0
}
"""
- acc = self.correct_num / self.all_num
+ acc = self.correct_num / (self.all_num + self.eps)
self.reset()
return {'acc': acc}
diff --git a/ppocr/metrics/rec_metric.py b/ppocr/metrics/rec_metric.py
index b0ccd974f24f1c7e0c9a8e1d414373021c4288e6..b047bbcb972cadf227daaeb8797c46095ac0af43 100644
--- a/ppocr/metrics/rec_metric.py
+++ b/ppocr/metrics/rec_metric.py
@@ -20,6 +20,7 @@ class RecMetric(object):
def __init__(self, main_indicator='acc', is_filter=False, **kwargs):
self.main_indicator = main_indicator
self.is_filter = is_filter
+ self.eps = 1e-5
self.reset()
def _normalize_text(self, text):
@@ -47,8 +48,8 @@ class RecMetric(object):
self.all_num += all_num
self.norm_edit_dis += norm_edit_dis
return {
- 'acc': correct_num / all_num,
- 'norm_edit_dis': 1 - norm_edit_dis / (all_num + 1e-3)
+ 'acc': correct_num / (all_num + self.eps),
+ 'norm_edit_dis': 1 - norm_edit_dis / (all_num + self.eps)
}
def get_metric(self):
@@ -58,8 +59,8 @@ class RecMetric(object):
'norm_edit_dis': 0,
}
"""
- acc = 1.0 * self.correct_num / (self.all_num + 1e-3)
- norm_edit_dis = 1 - self.norm_edit_dis / (self.all_num + 1e-3)
+ acc = 1.0 * self.correct_num / (self.all_num + self.eps)
+ norm_edit_dis = 1 - self.norm_edit_dis / (self.all_num + self.eps)
self.reset()
return {'acc': acc, 'norm_edit_dis': norm_edit_dis}
diff --git a/ppocr/metrics/table_metric.py b/ppocr/metrics/table_metric.py
index 80d1c789ecc3979bd4c33620af91ccd28012f7a8..ca4d6474202b4e85cadf86ccb2fe2726c7fa9aeb 100644
--- a/ppocr/metrics/table_metric.py
+++ b/ppocr/metrics/table_metric.py
@@ -12,9 +12,12 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
+
+
class TableMetric(object):
def __init__(self, main_indicator='acc', **kwargs):
self.main_indicator = main_indicator
+ self.eps = 1e-5
self.reset()
def __call__(self, pred, batch, *args, **kwargs):
@@ -31,9 +34,7 @@ class TableMetric(object):
correct_num += 1
self.correct_num += correct_num
self.all_num += all_num
- return {
- 'acc': correct_num * 1.0 / all_num,
- }
+ return {'acc': correct_num * 1.0 / (all_num + self.eps), }
def get_metric(self):
"""
@@ -41,7 +42,7 @@ class TableMetric(object):
'acc': 0,
}
"""
- acc = 1.0 * self.correct_num / self.all_num
+ acc = 1.0 * self.correct_num / (self.all_num + self.eps)
self.reset()
return {'acc': acc}
diff --git a/ppocr/metrics/vqa_token_re_metric.py b/ppocr/metrics/vqa_token_re_metric.py
new file mode 100644
index 0000000000000000000000000000000000000000..8a13bc081298284194d365933cd67d5633957ee8
--- /dev/null
+++ b/ppocr/metrics/vqa_token_re_metric.py
@@ -0,0 +1,176 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+import paddle
+
+__all__ = ['VQAReTokenMetric']
+
+
+class VQAReTokenMetric(object):
+ def __init__(self, main_indicator='hmean', **kwargs):
+ self.main_indicator = main_indicator
+ self.reset()
+
+ def __call__(self, preds, batch, **kwargs):
+ pred_relations, relations, entities = preds
+ self.pred_relations_list.extend(pred_relations)
+ self.relations_list.extend(relations)
+ self.entities_list.extend(entities)
+
+ def get_metric(self):
+ gt_relations = []
+ for b in range(len(self.relations_list)):
+ rel_sent = []
+ for head, tail in zip(self.relations_list[b]["head"],
+ self.relations_list[b]["tail"]):
+ rel = {}
+ rel["head_id"] = head
+ rel["head"] = (self.entities_list[b]["start"][rel["head_id"]],
+ self.entities_list[b]["end"][rel["head_id"]])
+ rel["head_type"] = self.entities_list[b]["label"][rel[
+ "head_id"]]
+
+ rel["tail_id"] = tail
+ rel["tail"] = (self.entities_list[b]["start"][rel["tail_id"]],
+ self.entities_list[b]["end"][rel["tail_id"]])
+ rel["tail_type"] = self.entities_list[b]["label"][rel[
+ "tail_id"]]
+
+ rel["type"] = 1
+ rel_sent.append(rel)
+ gt_relations.append(rel_sent)
+ re_metrics = self.re_score(
+ self.pred_relations_list, gt_relations, mode="boundaries")
+ metrics = {
+ "precision": re_metrics["ALL"]["p"],
+ "recall": re_metrics["ALL"]["r"],
+ "hmean": re_metrics["ALL"]["f1"],
+ }
+ self.reset()
+ return metrics
+
+ def reset(self):
+ self.pred_relations_list = []
+ self.relations_list = []
+ self.entities_list = []
+
+ def re_score(self, pred_relations, gt_relations, mode="strict"):
+ """Evaluate RE predictions
+
+ Args:
+ pred_relations (list) : list of list of predicted relations (several relations in each sentence)
+ gt_relations (list) : list of list of ground truth relations
+
+ rel = { "head": (start_idx (inclusive), end_idx (exclusive)),
+ "tail": (start_idx (inclusive), end_idx (exclusive)),
+ "head_type": ent_type,
+ "tail_type": ent_type,
+ "type": rel_type}
+
+            mode (str) : 'strict' or 'boundaries'"""
+
+ assert mode in ["strict", "boundaries"]
+
+        relation_types = [v for v in [0, 1] if v != 0]
+ scores = {
+ rel: {
+ "tp": 0,
+ "fp": 0,
+ "fn": 0
+ }
+ for rel in relation_types + ["ALL"]
+ }
+
+ # Count GT relations and Predicted relations
+ n_sents = len(gt_relations)
+ n_rels = sum([len([rel for rel in sent]) for sent in gt_relations])
+ n_found = sum([len([rel for rel in sent]) for sent in pred_relations])
+
+ # Count TP, FP and FN per type
+ for pred_sent, gt_sent in zip(pred_relations, gt_relations):
+ for rel_type in relation_types:
+ # strict mode takes argument types into account
+ if mode == "strict":
+ pred_rels = {(rel["head"], rel["head_type"], rel["tail"],
+ rel["tail_type"])
+ for rel in pred_sent
+ if rel["type"] == rel_type}
+ gt_rels = {(rel["head"], rel["head_type"], rel["tail"],
+ rel["tail_type"])
+ for rel in gt_sent if rel["type"] == rel_type}
+
+ # boundaries mode only takes argument spans into account
+ elif mode == "boundaries":
+ pred_rels = {(rel["head"], rel["tail"])
+ for rel in pred_sent
+ if rel["type"] == rel_type}
+ gt_rels = {(rel["head"], rel["tail"])
+ for rel in gt_sent if rel["type"] == rel_type}
+
+ scores[rel_type]["tp"] += len(pred_rels & gt_rels)
+ scores[rel_type]["fp"] += len(pred_rels - gt_rels)
+ scores[rel_type]["fn"] += len(gt_rels - pred_rels)
+
+ # Compute per entity Precision / Recall / F1
+ for rel_type in scores.keys():
+ if scores[rel_type]["tp"]:
+ scores[rel_type]["p"] = scores[rel_type]["tp"] / (
+ scores[rel_type]["fp"] + scores[rel_type]["tp"])
+ scores[rel_type]["r"] = scores[rel_type]["tp"] / (
+ scores[rel_type]["fn"] + scores[rel_type]["tp"])
+ else:
+ scores[rel_type]["p"], scores[rel_type]["r"] = 0, 0
+
+ if not scores[rel_type]["p"] + scores[rel_type]["r"] == 0:
+ scores[rel_type]["f1"] = (
+ 2 * scores[rel_type]["p"] * scores[rel_type]["r"] /
+ (scores[rel_type]["p"] + scores[rel_type]["r"]))
+ else:
+ scores[rel_type]["f1"] = 0
+
+ # Compute micro F1 Scores
+ tp = sum([scores[rel_type]["tp"] for rel_type in relation_types])
+ fp = sum([scores[rel_type]["fp"] for rel_type in relation_types])
+ fn = sum([scores[rel_type]["fn"] for rel_type in relation_types])
+
+ if tp:
+ precision = tp / (tp + fp)
+ recall = tp / (tp + fn)
+ f1 = 2 * precision * recall / (precision + recall)
+
+ else:
+ precision, recall, f1 = 0, 0, 0
+
+ scores["ALL"]["p"] = precision
+ scores["ALL"]["r"] = recall
+ scores["ALL"]["f1"] = f1
+ scores["ALL"]["tp"] = tp
+ scores["ALL"]["fp"] = fp
+ scores["ALL"]["fn"] = fn
+
+ # Compute Macro F1 Scores
+ scores["ALL"]["Macro_f1"] = np.mean(
+ [scores[ent_type]["f1"] for ent_type in relation_types])
+ scores["ALL"]["Macro_p"] = np.mean(
+ [scores[ent_type]["p"] for ent_type in relation_types])
+ scores["ALL"]["Macro_r"] = np.mean(
+ [scores[ent_type]["r"] for ent_type in relation_types])
+
+ return scores
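A tiny end-to-end check of re_score in "boundaries" mode, where head and tail are (start, end) spans; this is an illustrative sketch, not part of the patch:

# Hedged sketch: one predicted relation matches the single gt relation -> p = r = 1.0.
gt = [[{"head": (0, 3), "tail": (4, 8), "head_type": 1, "tail_type": 2, "type": 1}]]
pred = [[{"head": (0, 3), "tail": (4, 8), "head_type": 1, "tail_type": 2, "type": 1}]]

tp = fp = fn = 0
for pred_sent, gt_sent in zip(pred, gt):
    pred_rels = {(r["head"], r["tail"]) for r in pred_sent if r["type"] == 1}
    gt_rels = {(r["head"], r["tail"]) for r in gt_sent if r["type"] == 1}
    tp += len(pred_rels & gt_rels)
    fp += len(pred_rels - gt_rels)
    fn += len(gt_rels - pred_rels)
precision = tp / (tp + fp) if tp else 0
recall = tp / (tp + fn) if tp else 0
print(precision, recall)  # 1.0 1.0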
diff --git a/ppocr/metrics/vqa_token_ser_metric.py b/ppocr/metrics/vqa_token_ser_metric.py
new file mode 100644
index 0000000000000000000000000000000000000000..92d80d0970dc2eab1d3fb82e2b4cfb8d930a60a0
--- /dev/null
+++ b/ppocr/metrics/vqa_token_ser_metric.py
@@ -0,0 +1,47 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+import paddle
+
+__all__ = ['VQASerTokenMetric']
+
+
+class VQASerTokenMetric(object):
+ def __init__(self, main_indicator='hmean', **kwargs):
+ self.main_indicator = main_indicator
+ self.reset()
+
+ def __call__(self, preds, batch, **kwargs):
+ preds, labels = preds
+ self.pred_list.extend(preds)
+ self.gt_list.extend(labels)
+
+ def get_metric(self):
+ from seqeval.metrics import f1_score, precision_score, recall_score
+        metrics = {
+            "precision": precision_score(self.gt_list, self.pred_list),
+            "recall": recall_score(self.gt_list, self.pred_list),
+            "hmean": f1_score(self.gt_list, self.pred_list),
+        }
+        self.reset()
+        return metrics
+
+ def reset(self):
+ self.pred_list = []
+ self.gt_list = []
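VQASerTokenMetric defers to seqeval, which expects one tag sequence per image. A quick sketch of the expected list-of-lists format, assuming seqeval is installed:

from seqeval.metrics import f1_score, precision_score, recall_score

# Hedged sketch: predictions and ground truth are lists of BIO tag sequences.
gt_list = [["B-QUESTION", "I-QUESTION", "O", "B-ANSWER"]]
pred_list = [["B-QUESTION", "I-QUESTION", "O", "B-ANSWER"]]
print({
    "precision": precision_score(gt_list, pred_list),
    "recall": recall_score(gt_list, pred_list),
    "hmean": f1_score(gt_list, pred_list),
})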
diff --git a/ppocr/modeling/architectures/base_model.py b/ppocr/modeling/architectures/base_model.py
index c498d9862abcfc85eaf29ed1d949230a1dc1629c..e622db25677069f9a4470db4966b7523def35472 100644
--- a/ppocr/modeling/architectures/base_model.py
+++ b/ppocr/modeling/architectures/base_model.py
@@ -63,8 +63,12 @@ class BaseModel(nn.Layer):
in_channels = self.neck.out_channels
# # build head, head is need for det, rec and cls
- config["Head"]['in_channels'] = in_channels
- self.head = build_head(config["Head"])
+ if 'Head' not in config or config['Head'] is None:
+ self.use_head = False
+ else:
+ self.use_head = True
+ config["Head"]['in_channels'] = in_channels
+ self.head = build_head(config["Head"])
self.return_all_feats = config.get("return_all_feats", False)
@@ -77,7 +81,8 @@ class BaseModel(nn.Layer):
if self.use_neck:
x = self.neck(x)
y["neck_out"] = x
- x = self.head(x, targets=data)
+ if self.use_head:
+ x = self.head(x, targets=data)
if isinstance(x, dict):
y.update(x)
else:
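With this BaseModel change, the 'Head' section of an Architecture config may be omitted or set to None, e.g. for the VQA backbones whose paddlenlp heads already produce the final output. A hedged illustration of such a config fragment in Python dict form (field names follow the usual ppocr layout):

# Hedged sketch: an Architecture config without a Head.
arch_config = {
    "model_type": "vqa",
    "Backbone": {"name": "LayoutXLMForSer", "pretrained": True, "num_classes": 7},
    "Transform": None,
    "Neck": None,
    "Head": None,   # BaseModel now sets use_head = False and skips build_head
}
use_head = not ("Head" not in arch_config or arch_config["Head"] is None)
print(use_head)  # False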
diff --git a/ppocr/modeling/backbones/__init__.py b/ppocr/modeling/backbones/__init__.py
index d10983487bedb0fc4278095db08d1f234ef5c595..a7db52d26704e0c8426e313b8788b656085983d6 100755
--- a/ppocr/modeling/backbones/__init__.py
+++ b/ppocr/modeling/backbones/__init__.py
@@ -29,9 +29,10 @@ def build_backbone(config, model_type):
from .rec_nrtr_mtb import MTB
from .rec_resnet_31 import ResNet31
from .rec_resnet_aster import ResNet_ASTER
+ from .rec_micronet import MicroNet
support_dict = [
'MobileNetV1Enhance', 'MobileNetV3', 'ResNet', 'ResNetFPN', 'MTB',
- "ResNet31", "ResNet_ASTER"
+ "ResNet31", "ResNet_ASTER", 'MicroNet'
]
elif model_type == "e2e":
from .e2e_resnet_vd_pg import ResNet
@@ -43,6 +44,9 @@ def build_backbone(config, model_type):
from .table_resnet_vd import ResNet
from .table_mobilenet_v3 import MobileNetV3
support_dict = ["ResNet", "MobileNetV3"]
+ elif model_type == 'vqa':
+ from .vqa_layoutlm import LayoutLMForSer, LayoutXLMForSer, LayoutXLMForRe
+ support_dict = ["LayoutLMForSer", "LayoutXLMForSer", 'LayoutXLMForRe']
else:
raise NotImplementedError
diff --git a/ppocr/modeling/backbones/rec_micronet.py b/ppocr/modeling/backbones/rec_micronet.py
new file mode 100644
index 0000000000000000000000000000000000000000..b0ae5a14c3004f63d39dff32cc737a0d96155593
--- /dev/null
+++ b/ppocr/modeling/backbones/rec_micronet.py
@@ -0,0 +1,528 @@
+# copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+This code is refer from:
+https://github.com/liyunsheng13/micronet/blob/main/backbone/micronet.py
+https://github.com/liyunsheng13/micronet/blob/main/backbone/activation.py
+"""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import paddle
+import paddle.nn as nn
+
+from ppocr.modeling.backbones.det_mobilenet_v3 import make_divisible
+
+M0_cfgs = [
+ # s, n, c, ks, c1, c2, g1, g2, c3, g3, g4, y1, y2, y3, r
+ [2, 1, 8, 3, 2, 2, 0, 4, 8, 2, 2, 2, 0, 1, 1],
+ [2, 1, 12, 3, 2, 2, 0, 8, 12, 4, 4, 2, 2, 1, 1],
+ [2, 1, 16, 5, 2, 2, 0, 12, 16, 4, 4, 2, 2, 1, 1],
+ [1, 1, 32, 5, 1, 4, 4, 4, 32, 4, 4, 2, 2, 1, 1],
+ [2, 1, 64, 5, 1, 4, 8, 8, 64, 8, 8, 2, 2, 1, 1],
+ [1, 1, 96, 3, 1, 4, 8, 8, 96, 8, 8, 2, 2, 1, 2],
+ [1, 1, 384, 3, 1, 4, 12, 12, 0, 0, 0, 2, 2, 1, 2],
+]
+M1_cfgs = [
+    # s, n, c, ks, c1, c2, g1, g2, c3, g3, g4, y1, y2, y3, r
+ [2, 1, 8, 3, 2, 2, 0, 6, 8, 2, 2, 2, 0, 1, 1],
+ [2, 1, 16, 3, 2, 2, 0, 8, 16, 4, 4, 2, 2, 1, 1],
+ [2, 1, 16, 5, 2, 2, 0, 16, 16, 4, 4, 2, 2, 1, 1],
+ [1, 1, 32, 5, 1, 6, 4, 4, 32, 4, 4, 2, 2, 1, 1],
+ [2, 1, 64, 5, 1, 6, 8, 8, 64, 8, 8, 2, 2, 1, 1],
+ [1, 1, 96, 3, 1, 6, 8, 8, 96, 8, 8, 2, 2, 1, 2],
+ [1, 1, 576, 3, 1, 6, 12, 12, 0, 0, 0, 2, 2, 1, 2],
+]
+M2_cfgs = [
+    # s, n, c, ks, c1, c2, g1, g2, c3, g3, g4, y1, y2, y3, r
+ [2, 1, 12, 3, 2, 2, 0, 8, 12, 4, 4, 2, 0, 1, 1],
+ [2, 1, 16, 3, 2, 2, 0, 12, 16, 4, 4, 2, 2, 1, 1],
+ [1, 1, 24, 3, 2, 2, 0, 16, 24, 4, 4, 2, 2, 1, 1],
+ [2, 1, 32, 5, 1, 6, 6, 6, 32, 4, 4, 2, 2, 1, 1],
+ [1, 1, 32, 5, 1, 6, 8, 8, 32, 4, 4, 2, 2, 1, 2],
+ [1, 1, 64, 5, 1, 6, 8, 8, 64, 8, 8, 2, 2, 1, 2],
+ [2, 1, 96, 5, 1, 6, 8, 8, 96, 8, 8, 2, 2, 1, 2],
+ [1, 1, 128, 3, 1, 6, 12, 12, 128, 8, 8, 2, 2, 1, 2],
+ [1, 1, 768, 3, 1, 6, 16, 16, 0, 0, 0, 2, 2, 1, 2],
+]
+M3_cfgs = [
+    # s, n, c, ks, c1, c2, g1, g2, c3, g3, g4, y1, y2, y3, r
+ [2, 1, 16, 3, 2, 2, 0, 12, 16, 4, 4, 0, 2, 0, 1],
+ [2, 1, 24, 3, 2, 2, 0, 16, 24, 4, 4, 0, 2, 0, 1],
+ [1, 1, 24, 3, 2, 2, 0, 24, 24, 4, 4, 0, 2, 0, 1],
+ [2, 1, 32, 5, 1, 6, 6, 6, 32, 4, 4, 0, 2, 0, 1],
+ [1, 1, 32, 5, 1, 6, 8, 8, 32, 4, 4, 0, 2, 0, 2],
+ [1, 1, 64, 5, 1, 6, 8, 8, 48, 8, 8, 0, 2, 0, 2],
+ [1, 1, 80, 5, 1, 6, 8, 8, 80, 8, 8, 0, 2, 0, 2],
+ [1, 1, 80, 5, 1, 6, 10, 10, 80, 8, 8, 0, 2, 0, 2],
+ [1, 1, 120, 5, 1, 6, 10, 10, 120, 10, 10, 0, 2, 0, 2],
+ [1, 1, 120, 5, 1, 6, 12, 12, 120, 10, 10, 0, 2, 0, 2],
+ [1, 1, 144, 3, 1, 6, 12, 12, 144, 12, 12, 0, 2, 0, 2],
+ [1, 1, 432, 3, 1, 3, 12, 12, 0, 0, 0, 0, 2, 0, 2],
+]
+
+
+def get_micronet_config(mode):
+ return eval(mode + '_cfgs')
+
+
+class MaxGroupPooling(nn.Layer):
+ def __init__(self, channel_per_group=2):
+ super(MaxGroupPooling, self).__init__()
+ self.channel_per_group = channel_per_group
+
+ def forward(self, x):
+ if self.channel_per_group == 1:
+ return x
+ # max op
+ b, c, h, w = x.shape
+
+ # reshape
+ y = paddle.reshape(x, [b, c // self.channel_per_group, -1, h, w])
+ out = paddle.max(y, axis=2)
+ return out
+
+
+class SpatialSepConvSF(nn.Layer):
+ def __init__(self, inp, oups, kernel_size, stride):
+ super(SpatialSepConvSF, self).__init__()
+
+ oup1, oup2 = oups
+ self.conv = nn.Sequential(
+ nn.Conv2D(
+ inp,
+ oup1, (kernel_size, 1), (stride, 1), (kernel_size // 2, 0),
+ bias_attr=False,
+ groups=1),
+ nn.BatchNorm2D(oup1),
+ nn.Conv2D(
+ oup1,
+ oup1 * oup2, (1, kernel_size), (1, stride),
+ (0, kernel_size // 2),
+ bias_attr=False,
+ groups=oup1),
+ nn.BatchNorm2D(oup1 * oup2),
+ ChannelShuffle(oup1), )
+
+ def forward(self, x):
+ out = self.conv(x)
+ return out
+
+
+class ChannelShuffle(nn.Layer):
+ def __init__(self, groups):
+ super(ChannelShuffle, self).__init__()
+ self.groups = groups
+
+ def forward(self, x):
+ b, c, h, w = x.shape
+
+ channels_per_group = c // self.groups
+
+ # reshape
+ x = paddle.reshape(x, [b, self.groups, channels_per_group, h, w])
+
+ x = paddle.transpose(x, (0, 2, 1, 3, 4))
+ out = paddle.reshape(x, [b, -1, h, w])
+
+ return out
+
+
+class StemLayer(nn.Layer):
+ def __init__(self, inp, oup, stride, groups=(4, 4)):
+ super(StemLayer, self).__init__()
+
+ g1, g2 = groups
+ self.stem = nn.Sequential(
+ SpatialSepConvSF(inp, groups, 3, stride),
+ MaxGroupPooling(2) if g1 * g2 == 2 * oup else nn.ReLU6())
+
+ def forward(self, x):
+ out = self.stem(x)
+ return out
+
+
+class DepthSpatialSepConv(nn.Layer):
+ def __init__(self, inp, expand, kernel_size, stride):
+ super(DepthSpatialSepConv, self).__init__()
+
+ exp1, exp2 = expand
+
+ hidden_dim = inp * exp1
+ oup = inp * exp1 * exp2
+
+ self.conv = nn.Sequential(
+ nn.Conv2D(
+ inp,
+ inp * exp1, (kernel_size, 1), (stride, 1),
+ (kernel_size // 2, 0),
+ bias_attr=False,
+ groups=inp),
+ nn.BatchNorm2D(inp * exp1),
+ nn.Conv2D(
+ hidden_dim,
+ oup, (1, kernel_size),
+ 1, (0, kernel_size // 2),
+ bias_attr=False,
+ groups=hidden_dim),
+ nn.BatchNorm2D(oup))
+
+ def forward(self, x):
+ x = self.conv(x)
+ return x
+
+
+class GroupConv(nn.Layer):
+ def __init__(self, inp, oup, groups=2):
+ super(GroupConv, self).__init__()
+ self.inp = inp
+ self.oup = oup
+ self.groups = groups
+ self.conv = nn.Sequential(
+ nn.Conv2D(
+ inp, oup, 1, 1, 0, bias_attr=False, groups=self.groups[0]),
+ nn.BatchNorm2D(oup))
+
+ def forward(self, x):
+ x = self.conv(x)
+ return x
+
+
+class DepthConv(nn.Layer):
+ def __init__(self, inp, oup, kernel_size, stride):
+ super(DepthConv, self).__init__()
+ self.conv = nn.Sequential(
+ nn.Conv2D(
+ inp,
+ oup,
+ kernel_size,
+ stride,
+ kernel_size // 2,
+ bias_attr=False,
+ groups=inp),
+ nn.BatchNorm2D(oup))
+
+ def forward(self, x):
+ out = self.conv(x)
+ return out
+
+
+class DYShiftMax(nn.Layer):
+ def __init__(self,
+ inp,
+ oup,
+ reduction=4,
+ act_max=1.0,
+ act_relu=True,
+ init_a=[0.0, 0.0],
+ init_b=[0.0, 0.0],
+ relu_before_pool=False,
+ g=None,
+ expansion=False):
+ super(DYShiftMax, self).__init__()
+ self.oup = oup
+ self.act_max = act_max * 2
+ self.act_relu = act_relu
+ self.avg_pool = nn.Sequential(nn.ReLU() if relu_before_pool == True else
+ nn.Sequential(), nn.AdaptiveAvgPool2D(1))
+
+ self.exp = 4 if act_relu else 2
+ self.init_a = init_a
+ self.init_b = init_b
+
+ # determine squeeze
+ squeeze = make_divisible(inp // reduction, 4)
+ if squeeze < 4:
+ squeeze = 4
+
+ self.fc = nn.Sequential(
+ nn.Linear(inp, squeeze),
+ nn.ReLU(), nn.Linear(squeeze, oup * self.exp), nn.Hardsigmoid())
+
+ if g is None:
+ g = 1
+ self.g = g[1]
+ if self.g != 1 and expansion:
+ self.g = inp // self.g
+
+ self.gc = inp // self.g
+ index = paddle.to_tensor([range(inp)])
+ index = paddle.reshape(index, [1, inp, 1, 1])
+ index = paddle.reshape(index, [1, self.g, self.gc, 1, 1])
+ indexgs = paddle.split(index, [1, self.g - 1], axis=1)
+ indexgs = paddle.concat((indexgs[1], indexgs[0]), axis=1)
+ indexs = paddle.split(indexgs, [1, self.gc - 1], axis=2)
+ indexs = paddle.concat((indexs[1], indexs[0]), axis=2)
+ self.index = paddle.reshape(indexs, [inp])
+ self.expansion = expansion
+
+ def forward(self, x):
+ x_in = x
+ x_out = x
+
+ b, c, _, _ = x_in.shape
+ y = self.avg_pool(x_in)
+ y = paddle.reshape(y, [b, c])
+ y = self.fc(y)
+ y = paddle.reshape(y, [b, self.oup * self.exp, 1, 1])
+ y = (y - 0.5) * self.act_max
+
+ n2, c2, h2, w2 = x_out.shape
+ x2 = paddle.to_tensor(x_out.numpy()[:, self.index.numpy(), :, :])
+
+ if self.exp == 4:
+ temp = y.shape
+ a1, b1, a2, b2 = paddle.split(y, temp[1] // self.oup, axis=1)
+
+ a1 = a1 + self.init_a[0]
+ a2 = a2 + self.init_a[1]
+
+ b1 = b1 + self.init_b[0]
+ b2 = b2 + self.init_b[1]
+
+ z1 = x_out * a1 + x2 * b1
+ z2 = x_out * a2 + x2 * b2
+
+ out = paddle.maximum(z1, z2)
+
+ elif self.exp == 2:
+ temp = y.shape
+ a1, b1 = paddle.split(y, temp[1] // self.oup, axis=1)
+ a1 = a1 + self.init_a[0]
+ b1 = b1 + self.init_b[0]
+ out = x_out * a1 + x2 * b1
+
+ return out
+
+
+class DYMicroBlock(nn.Layer):
+ def __init__(self,
+ inp,
+ oup,
+ kernel_size=3,
+ stride=1,
+ ch_exp=(2, 2),
+ ch_per_group=4,
+ groups_1x1=(1, 1),
+ depthsep=True,
+ shuffle=False,
+ activation_cfg=None):
+ super(DYMicroBlock, self).__init__()
+
+ self.identity = stride == 1 and inp == oup
+
+ y1, y2, y3 = activation_cfg['dy']
+ act_reduction = 8 * activation_cfg['ratio']
+ init_a = activation_cfg['init_a']
+ init_b = activation_cfg['init_b']
+
+ t1 = ch_exp
+ gs1 = ch_per_group
+ hidden_fft, g1, g2 = groups_1x1
+ hidden_dim2 = inp * t1[0] * t1[1]
+
+ if gs1[0] == 0:
+ self.layers = nn.Sequential(
+ DepthSpatialSepConv(inp, t1, kernel_size, stride),
+ DYShiftMax(
+ hidden_dim2,
+ hidden_dim2,
+ act_max=2.0,
+ act_relu=True if y2 == 2 else False,
+ init_a=init_a,
+ reduction=act_reduction,
+ init_b=init_b,
+ g=gs1,
+ expansion=False) if y2 > 0 else nn.ReLU6(),
+ ChannelShuffle(gs1[1]) if shuffle else nn.Sequential(),
+ ChannelShuffle(hidden_dim2 // 2)
+ if shuffle and y2 != 0 else nn.Sequential(),
+ GroupConv(hidden_dim2, oup, (g1, g2)),
+ DYShiftMax(
+ oup,
+ oup,
+ act_max=2.0,
+ act_relu=False,
+ init_a=[1.0, 0.0],
+ reduction=act_reduction // 2,
+ init_b=[0.0, 0.0],
+ g=(g1, g2),
+ expansion=False) if y3 > 0 else nn.Sequential(),
+ ChannelShuffle(g2) if shuffle else nn.Sequential(),
+ ChannelShuffle(oup // 2)
+ if shuffle and oup % 2 == 0 and y3 != 0 else nn.Sequential(), )
+ elif g2 == 0:
+ self.layers = nn.Sequential(
+ GroupConv(inp, hidden_dim2, gs1),
+ DYShiftMax(
+ hidden_dim2,
+ hidden_dim2,
+ act_max=2.0,
+ act_relu=False,
+ init_a=[1.0, 0.0],
+ reduction=act_reduction,
+ init_b=[0.0, 0.0],
+ g=gs1,
+ expansion=False) if y3 > 0 else nn.Sequential(), )
+ else:
+ self.layers = nn.Sequential(
+ GroupConv(inp, hidden_dim2, gs1),
+ DYShiftMax(
+ hidden_dim2,
+ hidden_dim2,
+ act_max=2.0,
+ act_relu=True if y1 == 2 else False,
+ init_a=init_a,
+ reduction=act_reduction,
+ init_b=init_b,
+ g=gs1,
+ expansion=False) if y1 > 0 else nn.ReLU6(),
+ ChannelShuffle(gs1[1]) if shuffle else nn.Sequential(),
+ DepthSpatialSepConv(hidden_dim2, (1, 1), kernel_size, stride)
+ if depthsep else
+ DepthConv(hidden_dim2, hidden_dim2, kernel_size, stride),
+ nn.Sequential(),
+ DYShiftMax(
+ hidden_dim2,
+ hidden_dim2,
+ act_max=2.0,
+ act_relu=True if y2 == 2 else False,
+ init_a=init_a,
+ reduction=act_reduction,
+ init_b=init_b,
+ g=gs1,
+ expansion=True) if y2 > 0 else nn.ReLU6(),
+ ChannelShuffle(hidden_dim2 // 4)
+ if shuffle and y1 != 0 and y2 != 0 else nn.Sequential()
+ if y1 == 0 and y2 == 0 else ChannelShuffle(hidden_dim2 // 2),
+ GroupConv(hidden_dim2, oup, (g1, g2)),
+ DYShiftMax(
+ oup,
+ oup,
+ act_max=2.0,
+ act_relu=False,
+ init_a=[1.0, 0.0],
+ reduction=act_reduction // 2
+ if oup < hidden_dim2 else act_reduction,
+ init_b=[0.0, 0.0],
+ g=(g1, g2),
+ expansion=False) if y3 > 0 else nn.Sequential(),
+ ChannelShuffle(g2) if shuffle else nn.Sequential(),
+ ChannelShuffle(oup // 2)
+ if shuffle and y3 != 0 else nn.Sequential(), )
+
+ def forward(self, x):
+ identity = x
+ out = self.layers(x)
+
+ if self.identity:
+ out = out + identity
+
+ return out
+
+
+class MicroNet(nn.Layer):
+ """
+    The MicroNet backbone network for the recognition module.
+ Args:
+ mode(str): {'M0', 'M1', 'M2', 'M3'}
+ Four models are proposed based on four different computational costs (4M, 6M, 12M, 21M MAdds)
+ Default: 'M3'.
+ """
+
+ def __init__(self, mode='M3', **kwargs):
+ super(MicroNet, self).__init__()
+
+ self.cfgs = get_micronet_config(mode)
+
+ activation_cfg = {}
+ if mode == 'M0':
+ input_channel = 4
+ stem_groups = 2, 2
+ out_ch = 384
+ activation_cfg['init_a'] = 1.0, 1.0
+ activation_cfg['init_b'] = 0.0, 0.0
+ elif mode == 'M1':
+ input_channel = 6
+ stem_groups = 3, 2
+ out_ch = 576
+ activation_cfg['init_a'] = 1.0, 1.0
+ activation_cfg['init_b'] = 0.0, 0.0
+ elif mode == 'M2':
+ input_channel = 8
+ stem_groups = 4, 2
+ out_ch = 768
+ activation_cfg['init_a'] = 1.0, 1.0
+ activation_cfg['init_b'] = 0.0, 0.0
+ elif mode == 'M3':
+ input_channel = 12
+ stem_groups = 4, 3
+ out_ch = 432
+ activation_cfg['init_a'] = 1.0, 0.5
+ activation_cfg['init_b'] = 0.0, 0.5
+ else:
+ raise NotImplementedError("mode[" + mode +
+ "_model] is not implemented!")
+
+ layers = [StemLayer(3, input_channel, stride=2, groups=stem_groups)]
+
+ for idx, val in enumerate(self.cfgs):
+ s, n, c, ks, c1, c2, g1, g2, c3, g3, g4, y1, y2, y3, r = val
+
+ t1 = (c1, c2)
+ gs1 = (g1, g2)
+ gs2 = (c3, g3, g4)
+ activation_cfg['dy'] = [y1, y2, y3]
+ activation_cfg['ratio'] = r
+
+ output_channel = c
+ layers.append(
+ DYMicroBlock(
+ input_channel,
+ output_channel,
+ kernel_size=ks,
+ stride=s,
+ ch_exp=t1,
+ ch_per_group=gs1,
+ groups_1x1=gs2,
+ depthsep=True,
+ shuffle=True,
+ activation_cfg=activation_cfg, ))
+ input_channel = output_channel
+ for i in range(1, n):
+ layers.append(
+ DYMicroBlock(
+ input_channel,
+ output_channel,
+ kernel_size=ks,
+ stride=1,
+ ch_exp=t1,
+ ch_per_group=gs1,
+ groups_1x1=gs2,
+ depthsep=True,
+ shuffle=True,
+ activation_cfg=activation_cfg, ))
+ input_channel = output_channel
+ self.features = nn.Sequential(*layers)
+
+ self.pool = nn.MaxPool2D(kernel_size=2, stride=2, padding=0)
+
+ self.out_channels = make_divisible(out_ch)
+
+ def forward(self, x):
+ x = self.features(x)
+ x = self.pool(x)
+ return x
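A smoke-test sketch for the new MicroNet recognition backbone, assuming the repository root is on PYTHONPATH; it only checks that a rec-style input runs through and that the channel dimension matches out_channels:

import paddle
from ppocr.modeling.backbones.rec_micronet import MicroNet

# Hedged sketch: 32-high recognition input, as typically used by rec backbones.
model = MicroNet(mode="M3")
x = paddle.rand([1, 3, 32, 320])
y = model(x)
print(y.shape, model.out_channels)  # channel dim of y should equal out_channels (432 for M3)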
diff --git a/ppocr/modeling/backbones/vqa_layoutlm.py b/ppocr/modeling/backbones/vqa_layoutlm.py
new file mode 100644
index 0000000000000000000000000000000000000000..0e98155514cdd055680f32b529fdce631384a37f
--- /dev/null
+++ b/ppocr/modeling/backbones/vqa_layoutlm.py
@@ -0,0 +1,125 @@
+# copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import os
+from paddle import nn
+
+from paddlenlp.transformers import LayoutXLMModel, LayoutXLMForTokenClassification, LayoutXLMForRelationExtraction
+from paddlenlp.transformers import LayoutLMModel, LayoutLMForTokenClassification
+
+__all__ = ["LayoutXLMForSer", "LayoutLMForSer", "LayoutXLMForRe"]
+
+pretrained_model_dict = {
+ LayoutXLMModel: 'layoutxlm-base-uncased',
+ LayoutLMModel: 'layoutlm-base-uncased'
+}
+
+
+class NLPBaseModel(nn.Layer):
+ def __init__(self,
+ base_model_class,
+ model_class,
+ type='ser',
+ pretrained=True,
+ checkpoints=None,
+ **kwargs):
+ super(NLPBaseModel, self).__init__()
+ if checkpoints is not None:
+ self.model = model_class.from_pretrained(checkpoints)
+ else:
+ pretrained_model_name = pretrained_model_dict[base_model_class]
+ if pretrained:
+ base_model = base_model_class.from_pretrained(
+ pretrained_model_name)
+ else:
+ base_model = base_model_class(
+ **base_model_class.pretrained_init_configuration[
+ pretrained_model_name])
+ if type == 'ser':
+ self.model = model_class(
+ base_model, num_classes=kwargs['num_classes'], dropout=None)
+ else:
+ self.model = model_class(base_model, dropout=None)
+ self.out_channels = 1
+
+
+class LayoutXLMForSer(NLPBaseModel):
+ def __init__(self, num_classes, pretrained=True, checkpoints=None,
+ **kwargs):
+ super(LayoutXLMForSer, self).__init__(
+ LayoutXLMModel,
+ LayoutXLMForTokenClassification,
+ 'ser',
+ pretrained,
+ checkpoints,
+ num_classes=num_classes)
+
+ def forward(self, x):
+ x = self.model(
+ input_ids=x[0],
+ bbox=x[2],
+ image=x[3],
+ attention_mask=x[4],
+ token_type_ids=x[5],
+ position_ids=None,
+ head_mask=None,
+ labels=None)
+ return x[0]
+
+
+class LayoutLMForSer(NLPBaseModel):
+ def __init__(self, num_classes, pretrained=True, checkpoints=None,
+ **kwargs):
+ super(LayoutLMForSer, self).__init__(
+ LayoutLMModel,
+ LayoutLMForTokenClassification,
+ 'ser',
+ pretrained,
+ checkpoints,
+ num_classes=num_classes)
+
+ def forward(self, x):
+ x = self.model(
+ input_ids=x[0],
+ bbox=x[2],
+ attention_mask=x[4],
+ token_type_ids=x[5],
+ position_ids=None,
+ output_hidden_states=False)
+ return x
+
+
+class LayoutXLMForRe(NLPBaseModel):
+ def __init__(self, pretrained=True, checkpoints=None, **kwargs):
+ super(LayoutXLMForRe, self).__init__(LayoutXLMModel,
+ LayoutXLMForRelationExtraction,
+ 're', pretrained, checkpoints)
+
+ def forward(self, x):
+ x = self.model(
+ input_ids=x[0],
+ bbox=x[1],
+ labels=None,
+ image=x[2],
+ attention_mask=x[3],
+ token_type_ids=x[4],
+ position_ids=None,
+ head_mask=None,
+ entities=x[5],
+ relations=x[6])
+ return x
diff --git a/ppocr/optimizer/__init__.py b/ppocr/optimizer/__init__.py
index c729103a700a59764bda4f53dd68d3958172ca57..e0c6b90371cb4b09fb894ceeaeb8595e51c6c557 100644
--- a/ppocr/optimizer/__init__.py
+++ b/ppocr/optimizer/__init__.py
@@ -42,7 +42,9 @@ def build_optimizer(config, epochs, step_each_epoch, parameters):
# step2 build regularization
if 'regularizer' in config and config['regularizer'] is not None:
reg_config = config.pop('regularizer')
- reg_name = reg_config.pop('name') + 'Decay'
+ reg_name = reg_config.pop('name')
+ if not hasattr(regularizer, reg_name):
+ reg_name += 'Decay'
reg = getattr(regularizer, reg_name)(**reg_config)()
else:
reg = None
diff --git a/ppocr/optimizer/learning_rate.py b/ppocr/optimizer/learning_rate.py
index e1b10992676cfdf73fb7573e5289c133981d1474..b1879f3ee509761043c1797d8b67e4e0988af130 100644
--- a/ppocr/optimizer/learning_rate.py
+++ b/ppocr/optimizer/learning_rate.py
@@ -18,7 +18,7 @@ from __future__ import print_function
from __future__ import unicode_literals
from paddle.optimizer import lr
-from .lr_scheduler import CyclicalCosineDecay
+from .lr_scheduler import CyclicalCosineDecay, OneCycleDecay
class Linear(object):
@@ -226,3 +226,53 @@ class CyclicalCosine(object):
end_lr=self.learning_rate,
last_epoch=self.last_epoch)
return learning_rate
+
+
+class OneCycle(object):
+ """
+ One Cycle learning rate decay
+ Args:
+ max_lr(float): Upper learning rate boundaries
+ epochs(int): total training epochs
+ step_each_epoch(int): steps each epoch
+        anneal_strategy(str): {'cos', 'linear'} Specifies the annealing strategy: 'cos' for cosine annealing, 'linear' for linear annealing.
+            Default: 'cos'
+        three_phase(bool): If True, use a third phase of the schedule to annihilate the learning rate according to 'final_div_factor'
+            instead of modifying the second phase (the first two phases will be symmetrical about the step indicated by 'pct_start').
+        last_epoch (int, optional): The index of the last epoch. Can be set to restart training. Default: -1, meaning the initial learning rate.
+ """
+
+ def __init__(self,
+ max_lr,
+ epochs,
+ step_each_epoch,
+ anneal_strategy='cos',
+ three_phase=False,
+ warmup_epoch=0,
+ last_epoch=-1,
+ **kwargs):
+ super(OneCycle, self).__init__()
+ self.max_lr = max_lr
+ self.epochs = epochs
+ self.steps_per_epoch = step_each_epoch
+ self.anneal_strategy = anneal_strategy
+ self.three_phase = three_phase
+ self.last_epoch = last_epoch
+ self.warmup_epoch = round(warmup_epoch * step_each_epoch)
+
+ def __call__(self):
+ learning_rate = OneCycleDecay(
+ max_lr=self.max_lr,
+ epochs=self.epochs,
+ steps_per_epoch=self.steps_per_epoch,
+ anneal_strategy=self.anneal_strategy,
+ three_phase=self.three_phase,
+ last_epoch=self.last_epoch)
+ if self.warmup_epoch > 0:
+ learning_rate = lr.LinearWarmup(
+ learning_rate=learning_rate,
+ warmup_steps=self.warmup_epoch,
+ start_lr=0.0,
+ end_lr=self.max_lr,
+ last_epoch=self.last_epoch)
+ return learning_rate
\ No newline at end of file
diff --git a/ppocr/optimizer/lr_scheduler.py b/ppocr/optimizer/lr_scheduler.py
index 21aec737d0005e3dcd814ad7eff88988ab2c0796..f62f1f3b0adbd8df0e03a66faa4565f2f7df28bc 100644
--- a/ppocr/optimizer/lr_scheduler.py
+++ b/ppocr/optimizer/lr_scheduler.py
@@ -47,3 +47,116 @@ class CyclicalCosineDecay(LRScheduler):
lr = self.eta_min + 0.5 * (self.base_lr - self.eta_min) * \
(1 + math.cos(math.pi * reletive_epoch / self.cycle))
return lr
+
+
+class OneCycleDecay(LRScheduler):
+ """
+ One Cycle learning rate decay
+    A learning rate schedule described in https://arxiv.org/abs/1708.07120
+    Code referenced from https://pytorch.org/docs/stable/_modules/torch/optim/lr_scheduler.html#OneCycleLR
+ """
+
+ def __init__(self,
+ max_lr,
+ epochs=None,
+ steps_per_epoch=None,
+ pct_start=0.3,
+ anneal_strategy='cos',
+ div_factor=25.,
+ final_div_factor=1e4,
+ three_phase=False,
+ last_epoch=-1,
+ verbose=False):
+
+ # Validate total_steps
+ if epochs <= 0 or not isinstance(epochs, int):
+ raise ValueError(
+ "Expected positive integer epochs, but got {}".format(epochs))
+ if steps_per_epoch <= 0 or not isinstance(steps_per_epoch, int):
+ raise ValueError(
+ "Expected positive integer steps_per_epoch, but got {}".format(
+ steps_per_epoch))
+ self.total_steps = epochs * steps_per_epoch
+
+ self.max_lr = max_lr
+ self.initial_lr = self.max_lr / div_factor
+ self.min_lr = self.initial_lr / final_div_factor
+
+ if three_phase:
+ self._schedule_phases = [
+ {
+ 'end_step': float(pct_start * self.total_steps) - 1,
+ 'start_lr': self.initial_lr,
+ 'end_lr': self.max_lr,
+ },
+ {
+ 'end_step': float(2 * pct_start * self.total_steps) - 2,
+ 'start_lr': self.max_lr,
+ 'end_lr': self.initial_lr,
+ },
+ {
+ 'end_step': self.total_steps - 1,
+ 'start_lr': self.initial_lr,
+ 'end_lr': self.min_lr,
+ },
+ ]
+ else:
+ self._schedule_phases = [
+ {
+ 'end_step': float(pct_start * self.total_steps) - 1,
+ 'start_lr': self.initial_lr,
+ 'end_lr': self.max_lr,
+ },
+ {
+ 'end_step': self.total_steps - 1,
+ 'start_lr': self.max_lr,
+ 'end_lr': self.min_lr,
+ },
+ ]
+
+ # Validate pct_start
+ if pct_start < 0 or pct_start > 1 or not isinstance(pct_start, float):
+ raise ValueError(
+ "Expected float between 0 and 1 pct_start, but got {}".format(
+ pct_start))
+
+ # Validate anneal_strategy
+ if anneal_strategy not in ['cos', 'linear']:
+ raise ValueError(
+                "anneal_strategy must be one of 'cos' or 'linear', instead got {}".
+ format(anneal_strategy))
+ elif anneal_strategy == 'cos':
+ self.anneal_func = self._annealing_cos
+ elif anneal_strategy == 'linear':
+ self.anneal_func = self._annealing_linear
+
+ super(OneCycleDecay, self).__init__(max_lr, last_epoch, verbose)
+
+ def _annealing_cos(self, start, end, pct):
+ "Cosine anneal from `start` to `end` as pct goes from 0.0 to 1.0."
+ cos_out = math.cos(math.pi * pct) + 1
+ return end + (start - end) / 2.0 * cos_out
+
+ def _annealing_linear(self, start, end, pct):
+ "Linearly anneal from `start` to `end` as pct goes from 0.0 to 1.0."
+ return (end - start) * pct + start
+
+ def get_lr(self):
+ computed_lr = 0.0
+ step_num = self.last_epoch
+
+ if step_num > self.total_steps:
+ raise ValueError(
+ "Tried to step {} times. The specified number of total steps is {}"
+ .format(step_num + 1, self.total_steps))
+ start_step = 0
+ for i, phase in enumerate(self._schedule_phases):
+ end_step = phase['end_step']
+ if step_num <= end_step or i == len(self._schedule_phases) - 1:
+ pct = (step_num - start_step) / (end_step - start_step)
+ computed_lr = self.anneal_func(phase['start_lr'],
+ phase['end_lr'], pct)
+ break
+ start_step = phase['end_step']
+
+ return computed_lr
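
For readers who want to sanity-check the schedule, the standalone snippet below reproduces the two-phase (default) cosine interpolation of `get_lr` in plain Python; it has no paddle dependency and the numbers are illustrative.

```python
import math

def one_cycle_lr(step, total_steps, max_lr, pct_start=0.3,
                 div_factor=25., final_div_factor=1e4):
    """Reproduce OneCycleDecay.get_lr for the default two-phase cosine case."""
    initial_lr = max_lr / div_factor
    min_lr = initial_lr / final_div_factor
    phases = [
        {"end": pct_start * total_steps - 1, "start_lr": initial_lr, "end_lr": max_lr},
        {"end": total_steps - 1, "start_lr": max_lr, "end_lr": min_lr},
    ]
    start = 0
    for i, phase in enumerate(phases):
        if step <= phase["end"] or i == len(phases) - 1:
            pct = (step - start) / (phase["end"] - start)
            cos_out = math.cos(math.pi * pct) + 1        # cosine annealing factor
            return phase["end_lr"] + (phase["start_lr"] - phase["end_lr"]) / 2.0 * cos_out
        start = phase["end"]

total = 10000
print(one_cycle_lr(0, total, max_lr=0.001))     # ~4e-05  (= max_lr / div_factor)
print(one_cycle_lr(2999, total, max_lr=0.001))  # ~0.001  (peak at the end of phase 1)
print(one_cycle_lr(9999, total, max_lr=0.001))  # ~4e-09  (min_lr at the end of training)
```
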
diff --git a/ppocr/optimizer/optimizer.py b/ppocr/optimizer/optimizer.py
index 34098c0fad553f7d39f6b5341e4da70a263eeaea..b98081227e180edbf023a8b5b7a0b82bb7c631e5 100644
--- a/ppocr/optimizer/optimizer.py
+++ b/ppocr/optimizer/optimizer.py
@@ -158,3 +158,38 @@ class Adadelta(object):
name=self.name,
parameters=parameters)
return opt
+
+
+class AdamW(object):
+ def __init__(self,
+ learning_rate=0.001,
+ beta1=0.9,
+ beta2=0.999,
+ epsilon=1e-08,
+ weight_decay=0.01,
+ grad_clip=None,
+ name=None,
+ lazy_mode=False,
+ **kwargs):
+ self.learning_rate = learning_rate
+ self.beta1 = beta1
+ self.beta2 = beta2
+ self.epsilon = epsilon
+ self.weight_decay = 0.01 if weight_decay is None else weight_decay
+ self.grad_clip = grad_clip
+ self.name = name
+ self.lazy_mode = lazy_mode
+
+ def __call__(self, parameters):
+ opt = optim.AdamW(
+ learning_rate=self.learning_rate,
+ beta1=self.beta1,
+ beta2=self.beta2,
+ epsilon=self.epsilon,
+ weight_decay=self.weight_decay,
+ grad_clip=self.grad_clip,
+ name=self.name,
+ lazy_mode=self.lazy_mode,
+ parameters=parameters)
+ return opt
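
A minimal sketch of wiring the `AdamW` builder above to a model's parameters. The network and hyper-parameters are stand-ins, the gradient clip is optional, and the decoupled weight-decay coefficient is assumed to arrive as a float (see the `L2Decay` change below).

```python
# Illustrative only; assumes paddle and the PaddleOCR repo are installed.
import paddle
from ppocr.optimizer.optimizer import AdamW

model = paddle.nn.Linear(16, 4)                     # stand-in for a real OCR model
opt_builder = AdamW(
    learning_rate=0.001,
    beta1=0.9,
    beta2=0.999,
    weight_decay=0.05,                              # decoupled weight decay coefficient
    grad_clip=paddle.nn.ClipGradByGlobalNorm(10.0), # optional gradient clipping
)
optimizer = opt_builder(parameters=model.parameters())
```
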
diff --git a/ppocr/optimizer/regularizer.py b/ppocr/optimizer/regularizer.py
index c6396f338d9d40fc444083e205fd55329e7dfd59..2ce68f7139e21f9e3e1dcc155254b7a92b0e7270 100644
--- a/ppocr/optimizer/regularizer.py
+++ b/ppocr/optimizer/regularizer.py
@@ -29,24 +29,23 @@ class L1Decay(object):
def __init__(self, factor=0.0):
super(L1Decay, self).__init__()
- self.regularization_coeff = factor
+ self.coeff = factor
def __call__(self):
- reg = paddle.regularizer.L1Decay(self.regularization_coeff)
+ reg = paddle.regularizer.L1Decay(self.coeff)
return reg
class L2Decay(object):
"""
- L2 Weight Decay Regularization, which encourages the weights to be sparse.
+    L2 Weight Decay Regularization, which helps to prevent the model from over-fitting.
Args:
factor(float): regularization coeff. Default:0.0.
"""
def __init__(self, factor=0.0):
super(L2Decay, self).__init__()
- self.regularization_coeff = factor
+ self.coeff = float(factor)
def __call__(self):
- reg = paddle.regularizer.L2Decay(self.regularization_coeff)
- return reg
+ return self.coeff
\ No newline at end of file
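
Note the behavioural change above: `L2Decay.__call__` now returns the raw float coefficient instead of a `paddle.regularizer.L2Decay` object, so it can be fed straight into optimizers that take a numeric `weight_decay` (such as the `AdamW` builder added above). A hedged sketch of the intended use:

```python
# Illustrative only; assumes the PaddleOCR repo is importable.
from ppocr.optimizer.regularizer import L2Decay

weight_decay = L2Decay(factor=3.0e-05)()   # now a plain float, not a regularizer object
# this value can be passed as `weight_decay` to an optimizer builder, e.g. AdamW above
```
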
diff --git a/ppocr/postprocess/__init__.py b/ppocr/postprocess/__init__.py
index 37dadd12d3f628b1802b6a31f611f49f3ac600c2..811bf57b6435530b8b1361cc7e0c8acd4ba3a724 100644
--- a/ppocr/postprocess/__init__.py
+++ b/ppocr/postprocess/__init__.py
@@ -28,6 +28,8 @@ from .rec_postprocess import CTCLabelDecode, AttnLabelDecode, SRNLabelDecode, Di
TableLabelDecode, NRTRLabelDecode, SARLabelDecode, SEEDLabelDecode
from .cls_postprocess import ClsPostProcess
from .pg_postprocess import PGPostProcess
+from .vqa_token_ser_layoutlm_postprocess import VQASerTokenLayoutLMPostProcess
+from .vqa_token_re_layoutlm_postprocess import VQAReTokenLayoutLMPostProcess
def build_post_process(config, global_config=None):
@@ -36,7 +38,8 @@ def build_post_process(config, global_config=None):
'AttnLabelDecode', 'ClsPostProcess', 'SRNLabelDecode', 'PGPostProcess',
'DistillationCTCLabelDecode', 'TableLabelDecode',
'DistillationDBPostProcess', 'NRTRLabelDecode', 'SARLabelDecode',
- 'SEEDLabelDecode'
+ 'SEEDLabelDecode', 'VQASerTokenLayoutLMPostProcess',
+ 'VQAReTokenLayoutLMPostProcess'
]
if config['name'] == 'PSEPostProcess':
diff --git a/ppocr/postprocess/vqa_token_re_layoutlm_postprocess.py b/ppocr/postprocess/vqa_token_re_layoutlm_postprocess.py
new file mode 100644
index 0000000000000000000000000000000000000000..1d55d13d76b496ba0a5b540ba915889ce9146a8e
--- /dev/null
+++ b/ppocr/postprocess/vqa_token_re_layoutlm_postprocess.py
@@ -0,0 +1,51 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import paddle
+
+
+class VQAReTokenLayoutLMPostProcess(object):
+ """ Convert between text-label and text-index """
+
+ def __init__(self, **kwargs):
+ super(VQAReTokenLayoutLMPostProcess, self).__init__()
+
+ def __call__(self, preds, label=None, *args, **kwargs):
+ if label is not None:
+ return self._metric(preds, label)
+ else:
+ return self._infer(preds, *args, **kwargs)
+
+ def _metric(self, preds, label):
+ return preds['pred_relations'], label[6], label[5]
+
+ def _infer(self, preds, *args, **kwargs):
+ ser_results = kwargs['ser_results']
+ entity_idx_dict_batch = kwargs['entity_idx_dict_batch']
+ pred_relations = preds['pred_relations']
+
+ # merge relations and ocr info
+ results = []
+ for pred_relation, ser_result, entity_idx_dict in zip(
+ pred_relations, ser_results, entity_idx_dict_batch):
+ result = []
+ used_tail_id = []
+ for relation in pred_relation:
+ if relation['tail_id'] in used_tail_id:
+ continue
+ used_tail_id.append(relation['tail_id'])
+ ocr_info_head = ser_result[entity_idx_dict[relation['head_id']]]
+ ocr_info_tail = ser_result[entity_idx_dict[relation['tail_id']]]
+ result.append((ocr_info_head, ocr_info_tail))
+ results.append(result)
+ return results
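
To clarify what `_infer` produces, here is a small self-contained walkthrough with mocked inputs; the field names follow the code above, but the structures are simplified stand-ins for real SER results and relation predictions.

```python
# Mocked shapes only; field names follow the postprocess code above.
pred_relations = [[{"head_id": 0, "tail_id": 1}, {"head_id": 0, "tail_id": 1}]]  # duplicate tail dropped
ser_results = [[{"text": "Name:", "pred": "QUESTION"}, {"text": "Alice", "pred": "ANSWER"}]]
entity_idx_dict_batch = [{0: 0, 1: 1}]  # entity id -> index into ser_results[i]

results = []
for pred_relation, ser_result, entity_idx_dict in zip(
        pred_relations, ser_results, entity_idx_dict_batch):
    result, used_tail_id = [], []
    for relation in pred_relation:
        if relation["tail_id"] in used_tail_id:
            continue
        used_tail_id.append(relation["tail_id"])
        head = ser_result[entity_idx_dict[relation["head_id"]]]
        tail = ser_result[entity_idx_dict[relation["tail_id"]]]
        result.append((head, tail))
    results.append(result)

print(results)  # [[({'text': 'Name:', ...}, {'text': 'Alice', ...})]]
```
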
diff --git a/ppocr/postprocess/vqa_token_ser_layoutlm_postprocess.py b/ppocr/postprocess/vqa_token_ser_layoutlm_postprocess.py
new file mode 100644
index 0000000000000000000000000000000000000000..782cdea6c58c69e0d728787e0e21e200c9e13790
--- /dev/null
+++ b/ppocr/postprocess/vqa_token_ser_layoutlm_postprocess.py
@@ -0,0 +1,93 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import numpy as np
+import paddle
+from ppocr.utils.utility import load_vqa_bio_label_maps
+
+
+class VQASerTokenLayoutLMPostProcess(object):
+ """ Convert between text-label and text-index """
+
+ def __init__(self, class_path, **kwargs):
+ super(VQASerTokenLayoutLMPostProcess, self).__init__()
+ label2id_map, self.id2label_map = load_vqa_bio_label_maps(class_path)
+
+ self.label2id_map_for_draw = dict()
+ for key in label2id_map:
+ if key.startswith("I-"):
+ self.label2id_map_for_draw[key] = label2id_map["B" + key[1:]]
+ else:
+ self.label2id_map_for_draw[key] = label2id_map[key]
+
+ self.id2label_map_for_show = dict()
+ for key in self.label2id_map_for_draw:
+ val = self.label2id_map_for_draw[key]
+ if key == "O":
+ self.id2label_map_for_show[val] = key
+ if key.startswith("B-") or key.startswith("I-"):
+ self.id2label_map_for_show[val] = key[2:]
+ else:
+ self.id2label_map_for_show[val] = key
+
+ def __call__(self, preds, batch=None, *args, **kwargs):
+ if isinstance(preds, paddle.Tensor):
+ preds = preds.numpy()
+
+ if batch is not None:
+ return self._metric(preds, batch[1])
+ else:
+ return self._infer(preds, **kwargs)
+
+ def _metric(self, preds, label):
+ pred_idxs = preds.argmax(axis=2)
+ decode_out_list = [[] for _ in range(pred_idxs.shape[0])]
+ label_decode_out_list = [[] for _ in range(pred_idxs.shape[0])]
+
+ for i in range(pred_idxs.shape[0]):
+ for j in range(pred_idxs.shape[1]):
+ if label[i, j] != -100:
+ label_decode_out_list[i].append(self.id2label_map[label[i,
+ j]])
+ decode_out_list[i].append(self.id2label_map[pred_idxs[i,
+ j]])
+ return decode_out_list, label_decode_out_list
+
+ def _infer(self, preds, attention_masks, segment_offset_ids, ocr_infos):
+ results = []
+
+ for pred, attention_mask, segment_offset_id, ocr_info in zip(
+ preds, attention_masks, segment_offset_ids, ocr_infos):
+ pred = np.argmax(pred, axis=1)
+ pred = [self.id2label_map[idx] for idx in pred]
+
+ for idx in range(len(segment_offset_id)):
+ if idx == 0:
+ start_id = 0
+ else:
+ start_id = segment_offset_id[idx - 1]
+
+ end_id = segment_offset_id[idx]
+
+ curr_pred = pred[start_id:end_id]
+ curr_pred = [self.label2id_map_for_draw[p] for p in curr_pred]
+
+ if len(curr_pred) <= 0:
+ pred_id = 0
+ else:
+ counts = np.bincount(curr_pred)
+ pred_id = np.argmax(counts)
+ ocr_info[idx]["pred_id"] = int(pred_id)
+ ocr_info[idx]["pred"] = self.id2label_map_for_show[int(pred_id)]
+ results.append(ocr_info)
+ return results
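
The per-segment decision in `_infer` above is essentially a majority vote over token-level label ids; the toy snippet below isolates that step (the ids are made up for illustration).

```python
import numpy as np

# Token-level label ids predicted for one text segment (already mapped so that
# "I-" labels share the id of their "B-" counterpart, as in label2id_map_for_draw).
curr_pred = [3, 3, 1, 3]

pred_id = 0 if len(curr_pred) <= 0 else int(np.argmax(np.bincount(curr_pred)))
print(pred_id)  # 3 -> the segment is assigned its most frequent label id
```
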
diff --git a/ppocr/utils/save_load.py b/ppocr/utils/save_load.py
index f6013a406634ed110ea5af613a5f31e56ce90ead..b09f1db6e938e8eb99148d69efce016f1cbe8628 100644
--- a/ppocr/utils/save_load.py
+++ b/ppocr/utils/save_load.py
@@ -44,7 +44,7 @@ def _mkdir_if_not_exist(path, logger):
raise OSError('Failed to mkdir {}'.format(path))
-def load_model(config, model, optimizer=None):
+def load_model(config, model, optimizer=None, model_type='det'):
"""
load model from checkpoint or pretrained_model
"""
@@ -53,6 +53,33 @@ def load_model(config, model, optimizer=None):
checkpoints = global_config.get('checkpoints')
pretrained_model = global_config.get('pretrained_model')
best_model_dict = {}
+
+ if model_type == 'vqa':
+ checkpoints = config['Architecture']['Backbone']['checkpoints']
+ # load vqa method metric
+ if checkpoints:
+ if os.path.exists(os.path.join(checkpoints, 'metric.states')):
+ with open(os.path.join(checkpoints, 'metric.states'),
+ 'rb') as f:
+ states_dict = pickle.load(f) if six.PY2 else pickle.load(
+ f, encoding='latin1')
+ best_model_dict = states_dict.get('best_model_dict', {})
+ if 'epoch' in states_dict:
+ best_model_dict['start_epoch'] = states_dict['epoch'] + 1
+ logger.info("resume from {}".format(checkpoints))
+
+ if optimizer is not None:
+ if checkpoints[-1] in ['/', '\\']:
+ checkpoints = checkpoints[:-1]
+ if os.path.exists(checkpoints + '.pdopt'):
+ optim_dict = paddle.load(checkpoints + '.pdopt')
+ optimizer.set_state_dict(optim_dict)
+ else:
+ logger.warning(
+                        "{}.pdopt does not exist, the params of the optimizer are not loaded".
+ format(checkpoints))
+ return best_model_dict
+
if checkpoints:
if checkpoints.endswith('.pdparams'):
checkpoints = checkpoints.replace('.pdparams', '')
@@ -111,13 +138,16 @@ def load_pretrained_params(model, path):
params = paddle.load(path + '.pdparams')
state_dict = model.state_dict()
new_state_dict = {}
- for k1, k2 in zip(state_dict.keys(), params.keys()):
- if list(state_dict[k1].shape) == list(params[k2].shape):
- new_state_dict[k1] = params[k2]
+ for k1 in params.keys():
+ if k1 not in state_dict.keys():
+            logger.warning("The pretrained param {} is not in the model".format(k1))
else:
- logger.warning(
- "The shape of model params {} {} not matched with loaded params {} {} !".
- format(k1, state_dict[k1].shape, k2, params[k2].shape))
+ if list(state_dict[k1].shape) == list(params[k1].shape):
+ new_state_dict[k1] = params[k1]
+ else:
+ logger.warning(
+                    "The shape of model param {} {} does not match the loaded param {} {} !".
+ format(k1, state_dict[k1].shape, k1, params[k1].shape))
model.set_state_dict(new_state_dict)
logger.info("load pretrain successful from {}".format(path))
return model
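
The revised matching above keys on parameter names instead of zipping two ordered dicts, so pretrained weights load correctly even when the state dicts are ordered differently. A toy illustration of the filtering behaviour (plain lists stand in for tensor shapes):

```python
# Toy stand-ins: lists of shapes instead of real tensors.
params = {"backbone.w": [3, 3], "head.w": [10, 5], "extra.w": [1]}
state_dict = {"backbone.w": [3, 3], "head.w": [12, 5]}

new_state_dict = {}
for k in params:
    if k not in state_dict:
        print("pretrained param {} not in model, skipped".format(k))
    elif list(state_dict[k]) == list(params[k]):
        new_state_dict[k] = params[k]
    else:
        print("shape of {} {} does not match model shape {}".format(k, params[k], state_dict[k]))

print(sorted(new_state_dict))  # ['backbone.w']
```
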
@@ -127,6 +157,7 @@ def save_model(model,
optimizer,
model_path,
logger,
+ config,
is_best=False,
prefix='ppocr',
**kwargs):
@@ -135,13 +166,20 @@ def save_model(model,
"""
_mkdir_if_not_exist(model_path, logger)
model_prefix = os.path.join(model_path, prefix)
- paddle.save(model.state_dict(), model_prefix + '.pdparams')
paddle.save(optimizer.state_dict(), model_prefix + '.pdopt')
-
+ if config['Architecture']["model_type"] != 'vqa':
+ paddle.save(model.state_dict(), model_prefix + '.pdparams')
+ metric_prefix = model_prefix
+ else:
+ if config['Global']['distributed']:
+ model._layers.backbone.model.save_pretrained(model_prefix)
+ else:
+ model.backbone.model.save_pretrained(model_prefix)
+ metric_prefix = os.path.join(model_prefix, 'metric')
# save metric and config
- with open(model_prefix + '.states', 'wb') as f:
- pickle.dump(kwargs, f, protocol=2)
if is_best:
+ with open(metric_prefix + '.states', 'wb') as f:
+ pickle.dump(kwargs, f, protocol=2)
logger.info('save best model is to {}'.format(model_prefix))
else:
logger.info("save model in {}".format(model_prefix))
diff --git a/ppocr/utils/utility.py b/ppocr/utils/utility.py
index 7bb4c906d298af54ed56e2805f487a2c22d1894b..76484dfd3d3caaa03731368cf4eace1715121874 100755
--- a/ppocr/utils/utility.py
+++ b/ppocr/utils/utility.py
@@ -16,6 +16,9 @@ import logging
import os
import imghdr
import cv2
+import random
+import numpy as np
+import paddle
def print_dict(d, logger, delimiter=0):
@@ -77,4 +80,28 @@ def check_and_read_gif(img_path):
frame = cv2.cvtColor(frame, cv2.COLOR_GRAY2RGB)
imgvalue = frame[:, :, ::-1]
return imgvalue, True
- return None, False
\ No newline at end of file
+ return None, False
+
+
+def load_vqa_bio_label_maps(label_map_path):
+ with open(label_map_path, "r", encoding='utf-8') as fin:
+ lines = fin.readlines()
+ lines = [line.strip() for line in lines]
+ if "O" not in lines:
+ lines.insert(0, "O")
+ labels = []
+ for line in lines:
+ if line == "O":
+ labels.append("O")
+ else:
+ labels.append("B-" + line)
+ labels.append("I-" + line)
+ label2id_map = {label: idx for idx, label in enumerate(labels)}
+ id2label_map = {idx: label for idx, label in enumerate(labels)}
+ return label2id_map, id2label_map
+
+
+def set_seed(seed=1024):
+ random.seed(seed)
+ np.random.seed(seed)
+ paddle.seed(seed)
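
For reference, a hedged sketch of how the `load_vqa_bio_label_maps` helper added above expands a class list into BIO tags; the three class names written to the temporary label file are assumptions for illustration.

```python
# Illustrative only; assumes the PaddleOCR repo and its dependencies are importable.
from ppocr.utils.utility import load_vqa_bio_label_maps

with open("labels_demo.txt", "w", encoding="utf-8") as f:
    f.write("QUESTION\nANSWER\nHEADER\n")

label2id_map, id2label_map = load_vqa_bio_label_maps("labels_demo.txt")
print(label2id_map)
# {'O': 0, 'B-QUESTION': 1, 'I-QUESTION': 2, 'B-ANSWER': 3,
#  'I-ANSWER': 4, 'B-HEADER': 5, 'I-HEADER': 6}
# id2label_map is simply the inverse mapping.
```
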
diff --git a/ppocr/utils/visual.py b/ppocr/utils/visual.py
new file mode 100644
index 0000000000000000000000000000000000000000..7a8c1674a74f89299de59f7cd120b4577a7499d8
--- /dev/null
+++ b/ppocr/utils/visual.py
@@ -0,0 +1,98 @@
+# copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import os
+import numpy as np
+from PIL import Image, ImageDraw, ImageFont
+
+
+def draw_ser_results(image,
+ ocr_results,
+ font_path="doc/fonts/simfang.ttf",
+ font_size=18):
+ np.random.seed(2021)
+ color = (np.random.permutation(range(255)),
+ np.random.permutation(range(255)),
+ np.random.permutation(range(255)))
+ color_map = {
+ idx: (color[0][idx], color[1][idx], color[2][idx])
+ for idx in range(1, 255)
+ }
+ if isinstance(image, np.ndarray):
+ image = Image.fromarray(image)
+ elif isinstance(image, str) and os.path.isfile(image):
+ image = Image.open(image).convert('RGB')
+ img_new = image.copy()
+ draw = ImageDraw.Draw(img_new)
+
+ font = ImageFont.truetype(font_path, font_size, encoding="utf-8")
+ for ocr_info in ocr_results:
+ if ocr_info["pred_id"] not in color_map:
+ continue
+ color = color_map[ocr_info["pred_id"]]
+ text = "{}: {}".format(ocr_info["pred"], ocr_info["text"])
+
+ draw_box_txt(ocr_info["bbox"], text, draw, font, font_size, color)
+
+ img_new = Image.blend(image, img_new, 0.5)
+ return np.array(img_new)
+
+
+def draw_box_txt(bbox, text, draw, font, font_size, color):
+ # draw ocr results outline
+ bbox = ((bbox[0], bbox[1]), (bbox[2], bbox[3]))
+ draw.rectangle(bbox, fill=color)
+
+ # draw ocr results
+ start_y = max(0, bbox[0][1] - font_size)
+ tw = font.getsize(text)[0]
+ draw.rectangle(
+ [(bbox[0][0] + 1, start_y), (bbox[0][0] + tw + 1, start_y + font_size)],
+ fill=(0, 0, 255))
+ draw.text((bbox[0][0] + 1, start_y), text, fill=(255, 255, 255), font=font)
+
+
+def draw_re_results(image,
+ result,
+ font_path="doc/fonts/simfang.ttf",
+ font_size=18):
+ np.random.seed(0)
+ if isinstance(image, np.ndarray):
+ image = Image.fromarray(image)
+ elif isinstance(image, str) and os.path.isfile(image):
+ image = Image.open(image).convert('RGB')
+ img_new = image.copy()
+ draw = ImageDraw.Draw(img_new)
+
+ font = ImageFont.truetype(font_path, font_size, encoding="utf-8")
+ color_head = (0, 0, 255)
+ color_tail = (255, 0, 0)
+ color_line = (0, 255, 0)
+
+ for ocr_info_head, ocr_info_tail in result:
+ draw_box_txt(ocr_info_head["bbox"], ocr_info_head["text"], draw, font,
+ font_size, color_head)
+ draw_box_txt(ocr_info_tail["bbox"], ocr_info_tail["text"], draw, font,
+ font_size, color_tail)
+
+ center_head = (
+ (ocr_info_head['bbox'][0] + ocr_info_head['bbox'][2]) // 2,
+ (ocr_info_head['bbox'][1] + ocr_info_head['bbox'][3]) // 2)
+ center_tail = (
+ (ocr_info_tail['bbox'][0] + ocr_info_tail['bbox'][2]) // 2,
+ (ocr_info_tail['bbox'][1] + ocr_info_tail['bbox'][3]) // 2)
+
+ draw.line([center_head, center_tail], fill=color_line, width=5)
+
+ img_new = Image.blend(image, img_new, 0.5)
+ return np.array(img_new)
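
A minimal, hedged example of calling the new visualization helper: it assumes a Pillow version where `ImageFont.getsize` is still available and that the repo's `doc/fonts/simfang.ttf` exists, and the OCR results are mocked.

```python
# Illustrative only; field names follow draw_ser_results above.
import numpy as np
from ppocr.utils.visual import draw_ser_results

image = np.full((200, 400, 3), 255, dtype=np.uint8)   # blank white canvas
ocr_results = [
    {"pred_id": 1, "pred": "QUESTION", "text": "Name:", "bbox": [10, 60, 120, 90]},
    {"pred_id": 2, "pred": "ANSWER", "text": "Alice", "bbox": [130, 60, 230, 90]},
]
vis = draw_ser_results(image, ocr_results, font_path="doc/fonts/simfang.ttf")
print(vis.shape)   # (200, 400, 3), a blended visualization image
```
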
diff --git a/ppstructure/README.md b/ppstructure/README.md
index a09a43299b11dccf99897d5a6c69704191253aaf..1d201a7c6e54f6ed71be6d1872b7f4b226ad35ad 100644
--- a/ppstructure/README.md
+++ b/ppstructure/README.md
@@ -1,187 +1,140 @@
English | [简体中文](README_ch.md)
-# PP-Structure
+- [1. Introduction](#1)
+- [2. Update log](#2)
+- [3. Features](#3)
+- [4. Results](#4)
+ * [4.1 Layout analysis and table recognition](#41)
+ * [4.2 DOC-VQA](#42)
+- [5. Quick start](#5)
+- [6. PP-Structure System](#6)
+ * [6.1 Layout analysis and table recognition](#61)
+ * [6.2 DOC-VQA](#62)
+- [7. Model List](#7)
-PP-Structure is an OCR toolkit that can be used for complex documents analysis. The main features are as follows:
-- Support the layout analysis of documents, divide the documents into 5 types of areas **text, title, table, image and list** (conjunction with Layout-Parser)
-- Support to extract the texts from the text, title, picture and list areas (used in conjunction with PP-OCR)
-- Support to extract excel files from the table areas
-- Support python whl package and command line usage, easy to use
-- Support custom training for layout analysis and table structure tasks
+
-## 1. Visualization
-
-
+## 1. Introduction
+PP-Structure is an OCR toolkit for the analysis and processing of documents with complex structures, designed to help developers better complete document understanding tasks.
+
-## 2. Installation
+## 2. Update log
+* 2021.12.07 Add the [DOC-VQA SER and RE tasks](vqa/README.md).
-### 2.1 Install requirements
+
-- **(1) Install PaddlePaddle**
+## 3. Features
-```bash
-pip3 install --upgrade pip
+The main features of PP-Structure are as follows:
-# GPU
-python3 -m pip install paddlepaddle-gpu==2.1.1 -i https://mirror.baidu.com/pypi/simple
-
-# CPU
- python3 -m pip install paddlepaddle==2.1.1 -i https://mirror.baidu.com/pypi/simple
+- Support the layout analysis of documents, dividing them into 5 types of areas: **text, title, table, image and list** (in conjunction with Layout-Parser)
+- Support to extract the texts from the text, title, picture and list areas (used in conjunction with PP-OCR)
+- Support to extract excel files from the table areas
+- Support python whl package and command line usage, easy to use
+- Support custom training for layout analysis and table structure tasks
+- Support Document Visual Question Answering (DOC-VQA) tasks: Semantic Entity Recognition (SER) and Relation Extraction (RE)
-```
-For more,refer [Installation](https://www.paddlepaddle.org.cn/install/quick) .
-- **(2) Install Layout-Parser**
+
-```bash
-pip3 install -U https://paddleocr.bj.bcebos.com/whl/layoutparser-0.0.0-py3-none-any.whl
-```
+## 4. Results
-### 2.2 Install PaddleOCR(including PP-OCR and PP-Structure)
+
-- **(1) PIP install PaddleOCR whl package(inference only)**
+### 4.1 Layout analysis and table recognition
-```bash
-pip install "paddleocr>=2.2"
-```
+
-- **(2) Clone PaddleOCR(Inference+training)**
+The figure shows the pipeline of layout analysis + table recognition. The image is first divided into four types of areas (image, text, title and table) by layout analysis. OCR detection and recognition is then performed on the image, text and title areas, table recognition is performed on the table areas, and the image areas are also cropped and stored for later use.
-```bash
-git clone https://github.com/PaddlePaddle/PaddleOCR
-```
+
+### 4.2 DOC-VQA
-## 3. Quick Start
+* SER
-### 3.1 Use by command line
+![](./vqa/images/result_ser/zh_val_0_ser.jpg) | ![](./vqa/images/result_ser/zh_val_42_ser.jpg)
+---|---
-```bash
-paddleocr --image_dir=../doc/table/1.png --type=structure
-```
+Different colored boxes in the figure represent different categories. For the XFUN dataset, there are three categories: query, answer and header:
-### 3.2 Use by python API
+* Dark purple: header
+* Light purple: query
+* Army green: answer
-```python
-import os
-import cv2
-from paddleocr import PPStructure,draw_structure_result,save_structure_res
+The corresponding category and OCR recognition results are also marked at the top left of the OCR detection box.
-table_engine = PPStructure(show_log=True)
-save_folder = './output/table'
-img_path = '../doc/table/1.png'
-img = cv2.imread(img_path)
-result = table_engine(img)
-save_structure_res(result, save_folder,os.path.basename(img_path).split('.')[0])
+* RE
-for line in result:
- line.pop('img')
- print(line)
+![](./vqa/images/result_re/zh_val_21_re.jpg) | ![](./vqa/images/result_re/zh_val_40_re.jpg)
+---|---
-from PIL import Image
-font_path = '../doc/fonts/simfang.ttf'
-image = Image.open(img_path).convert('RGB')
-im_show = draw_structure_result(image, result,font_path=font_path)
-im_show = Image.fromarray(im_show)
-im_show.save('result.jpg')
-```
-### 3.3 Returned results format
-The returned results of PP-Structure is a list composed of a dict, an example is as follows
+In the figure, the red box represents the question, the blue box represents the answer, and the question and answer are connected by green lines. The corresponding category and OCR recognition results are also marked at the top left of the OCR detection box.
-```shell
-[
- { 'type': 'Text',
- 'bbox': [34, 432, 345, 462],
- 'res': ([[36.0, 437.0, 341.0, 437.0, 341.0, 446.0, 36.0, 447.0], [41.0, 454.0, 125.0, 453.0, 125.0, 459.0, 41.0, 460.0]],
- [('Tigure-6. The performance of CNN and IPT models using difforen', 0.90060663), ('Tent ', 0.465441)])
- }
-]
-```
-The description of each field in dict is as follows
-| Parameter | Description |
-| --------------- | -------------|
-|type|Type of image area|
-|bbox|The coordinates of the image area in the original image, respectively [left upper x, left upper y, right bottom x, right bottom y]|
-|res|OCR or table recognition result of image area。
Table: HTML string of the table;
OCR: A tuple containing the detection coordinates and recognition results of each single line of text|
+
+## 5. Quick start
-### 3.4 Parameter description:
+Start from [Quick Installation](./docs/quickstart.md)
-| Parameter | Description | Default value |
-| --------------- | ---------------------------------------- | ------------------------------------------- |
-| output | The path where excel and recognition results are saved | ./output/table |
-| table_max_len | The long side of the image is resized in table structure model | 488 |
-| table_model_dir | inference model path of table structure model | None |
-| table_char_type | dict path of table structure model | ../ppocr/utils/dict/table_structure_dict.tx |
+
-Most of the parameters are consistent with the paddleocr whl package, see [doc of whl](../doc/doc_en/whl_en.md)
+## 6. PP-Structure System
-After running, each image will have a directory with the same name under the directory specified in the output field. Each table in the picture will be stored as an excel and figure area will be cropped and saved, the excel and image file name will be the coordinates of the table in the image.
+
-## 4. PP-Structure Pipeline
-![pipeline](../doc/table/pipeline_en.jpg)
+### 6.1 Layout analysis and table recognition
-In PP-Structure, the image will be analyzed by layoutparser first. In the layout analysis, the area in the image will be classified, including **text, title, image, list and table** 5 categories. For the first 4 types of areas, directly use the PP-OCR to complete the text detection and recognition. The table area will be converted to an excel file of the same table style via Table OCR.
+![pipeline](../doc/table/pipeline.jpg)
-### 4.1 LayoutParser
+In PP-Structure, the image will be divided into 5 types of areas: **text, title, image, list and table**. For the first 4 types of areas, the PP-OCR system is used directly to complete text detection and recognition. For the table area, after the table structuring process, the table in the image is converted into an Excel file with the same table style.
-Layout analysis divides the document data into regions, including the use of Python scripts for layout analysis tools, extraction of special category detection boxes, performance indicators, and custom training layout analysis models. For details, please refer to [document](layout/README_en.md).
+#### 6.1.1 Layout analysis
-### 4.2 Table Recognition
+Layout analysis divides an image into regions by category. This part covers the Python scripts of the layout analysis tool, extraction of detection boxes of designated categories, performance indicators, and custom training of layout analysis models. For details, please refer to the [document](layout/README_en.md).
-Table Recognition converts table image into excel documents, which include the detection and recognition of table text and the prediction of table structure and cell coordinates. For detailed, please refer to [document](table/README.md)
+#### 6.1.2 Table recognition
-## 5. Prediction by inference engine
+Table recognition converts table images into excel documents, which includes the detection and recognition of table text and the prediction of table structure and cell coordinates. For detailed instructions, please refer to the [document](table/README.md).
-Use the following commands to complete the inference.
+
-```python
-cd PaddleOCR/ppstructure
+### 6.2 DOC-VQA
-# download model
-mkdir inference && cd inference
-# Download the detection model of the ultra-lightweight Chinese OCR model and uncompress it
-wget https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_mobile_v2.0_det_infer.tar && tar xf ch_ppocr_mobile_v2.0_det_infer.tar
-# Download the recognition model of the ultra-lightweight Chinese OCR model and uncompress it
-wget https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_mobile_v2.0_rec_infer.tar && tar xf ch_ppocr_mobile_v2.0_rec_infer.tar
-# Download the table structure model of the ultra-lightweight Chinese OCR model and uncompress it
-wget https://paddleocr.bj.bcebos.com/dygraph_v2.0/table/en_ppocr_mobile_v2.0_table_structure_infer.tar && tar xf en_ppocr_mobile_v2.0_table_structure_infer.tar
-cd ..
+Document Visual Question Answering (DOC-VQA) is a type of Visual Question Answering (VQA), which includes the Semantic Entity Recognition (SER) and Relation Extraction (RE) tasks. Based on the SER task, text recognition and classification in images can be completed. Based on the RE task, relations between text contents in the image can be extracted, such as matching question-answer pairs. For details, please refer to the [document](vqa/README.md)
-python3 predict_system.py --det_model_dir=inference/ch_ppocr_mobile_v2.0_det_infer --rec_model_dir=inference/ch_ppocr_mobile_v2.0_rec_infer --table_model_dir=inference/en_ppocr_mobile_v2.0_table_structure_infer --image_dir=../doc/table/1.png --rec_char_dict_path=../ppocr/utils/ppocr_keys_v1.txt --table_char_dict_path=../ppocr/utils/dict/table_structure_dict.txt --output=../output/table --vis_font_path=../doc/fonts/simfang.ttf
-```
-After running, each image will have a directory with the same name under the directory specified in the output field. Each table in the picture will be stored as an excel and figure area will be cropped and saved, the excel and image file name will be the coordinates of the table in the image.
-**Model List**
+
-|model name|description|config|model size|download|
-| --- | --- | --- | --- | --- |
-|en_ppocr_mobile_v2.0_table_structure|Table structure prediction for English table scenarios|[table_mv3.yml](../configs/table/table_mv3.yml)|18.6M|[inference model](https://paddleocr.bj.bcebos.com/dygraph_v2.0/table/en_ppocr_mobile_v2.0_table_structure_infer.tar) |
+## 7. Model List
-**Model List**
+A list of PP-Structure series models (continuously updated)
-LayoutParser model
+* Layout analysis model
|model name|description|download|
| --- | --- | --- |
-| ppyolov2_r50vd_dcn_365e_publaynet | The layout analysis model trained on the PubLayNet data set can be divided into 5 types of areas **text, title, table, picture and list** | [PubLayNet](https://paddle-model-ecology.bj.bcebos.com/model/layout-parser/ppyolov2_r50vd_dcn_365e_publaynet.tar) |
-| ppyolov2_r50vd_dcn_365e_tableBank_word | The layout analysis model trained on the TableBank Word dataset can only detect tables | [TableBank Word](https://paddle-model-ecology.bj.bcebos.com/model/layout-parser/ppyolov2_r50vd_dcn_365e_tableBank_word.tar) |
-| ppyolov2_r50vd_dcn_365e_tableBank_latex | The layout analysis model trained on the TableBank Latex dataset can only detect tables | [TableBank Latex](https://paddle-model-ecology.bj.bcebos.com/model/layout-parser/ppyolov2_r50vd_dcn_365e_tableBank_latex.tar) |
+| ppyolov2_r50vd_dcn_365e_publaynet | The layout analysis model trained on the PubLayNet dataset can divide an image into 5 types of areas **text, title, table, picture, and list** | [PubLayNet](https://paddle-model-ecology.bj.bcebos.com/model/layout-parser/ppyolov2_r50vd_dcn_365e_publaynet.tar) |
-OCR and table recognition model
+
+* OCR and table recognition model
|model name|description|model size|download|
| --- | --- | --- | --- |
|ch_ppocr_mobile_slim_v2.0_det|Slim pruned lightweight model, supporting Chinese, English, multilingual text detection|2.6M|[inference model](https://paddleocr.bj.bcebos.com/dygraph_v2.0/slim/ch_ppocr_mobile_v2.0_det_prune_infer.tar) / [trained model](https://paddleocr.bj.bcebos.com/dygraph_v2.0/slim/ch_ppocr_mobile_v2.0_det_prune_infer.tar) |
|ch_ppocr_mobile_slim_v2.0_rec|Slim pruned and quantized lightweight model, supporting Chinese, English and number recognition|6M|[inference model](https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_mobile_v2.0_rec_slim_infer.tar) / [trained model](https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_mobile_v2.0_rec_slim_train.tar) |
-|en_ppocr_mobile_v2.0_table_det|Text detection of English table scenes trained on PubLayNet dataset|4.7M|[inference model](https://paddleocr.bj.bcebos.com/dygraph_v2.0/table/en_ppocr_mobile_v2.0_table_det_infer.tar) / [trained model](https://paddleocr.bj.bcebos.com/dygraph_v2.1/table/en_ppocr_mobile_v2.0_table_det_train.tar) |
-|en_ppocr_mobile_v2.0_table_rec|Text recognition of English table scene trained on PubLayNet dataset|6.9M|[inference model](https://paddleocr.bj.bcebos.com/dygraph_v2.0/table/en_ppocr_mobile_v2.0_table_rec_infer.tar) [trained model](https://paddleocr.bj.bcebos.com/dygraph_v2.1/table/en_ppocr_mobile_v2.0_table_rec_train.tar) |
-|en_ppocr_mobile_v2.0_table_structure|Table structure prediction of English table scene trained on PubLayNet dataset|18.6M|[inference model](https://paddleocr.bj.bcebos.com/dygraph_v2.0/table/en_ppocr_mobile_v2.0_table_structure_infer.tar) / [trained model](https://paddleocr.bj.bcebos.com/dygraph_v2.1/table/en_ppocr_mobile_v2.0_table_structure_train.tar) |
+|en_ppocr_mobile_v2.0_table_structure|Table structure prediction of English table scene trained on PubLayNet dataset|18.6M|[inference model](https://paddleocr.bj.bcebos.com/dygraph_v2.0/table/en_ppocr_mobile_v2.0_table_structure_infer.tar) / [trained model](https://paddleocr.bj.bcebos.com/dygraph_v2.1/table/en_ppocr_mobile_v2.0_table_structure_train.tar) |
+
+* DOC-VQA model
+|model name|description|model size|download|
+| --- | --- | --- | --- |
+|PP-Layout_v1.0_ser_pretrained|SER model trained on xfun Chinese dataset based on LayoutXLM|1.4G|[inference model coming soon]() / [trained model](https://paddleocr.bj.bcebos.com/pplayout/PP-Layout_v1.0_ser_pretrained.tar) |
+|PP-Layout_v1.0_re_pretrained|RE model trained on xfun Chinese dataset based on LayoutXLM|1.4G|[inference model coming soon]() / [trained model](https://paddleocr.bj.bcebos.com/pplayout/PP-Layout_v1.0_re_pretrained.tar) |
-If you need to use other models, you can download the model in [model_list](../doc/doc_en/models_list_en.md) or use your own trained model to configure it to the three fields of `det_model_dir`, `rec_model_dir`, `table_model_dir` .
+If you need to use other models, you can download them from the [PPOCR model_list](../doc/doc_en/models_list_en.md) and the [PPStructure model_list](./docs/model_list.md).
diff --git a/ppstructure/README_ch.md b/ppstructure/README_ch.md
index 607efac1bf6bfaa58f0e96ceef1a0ee344189e9c..808a5c68d18df625bedeae4706da7f985d6caecd 100644
--- a/ppstructure/README_ch.md
+++ b/ppstructure/README_ch.md
@@ -1,14 +1,32 @@
[English](README.md) | 简体中文
-## 简介
+- [1. 简介](#1)
+- [2. 近期更新](#2)
+- [3. 特性](#3)
+- [4. 效果展示](#4)
+ * [4.1 版面分析和表格识别](#41)
+ * [4.2 DOC-VQA](#42)
+- [5. 快速体验](#5)
+- [6. PP-Structure 介绍](#6)
+ * [6.1 版面分析+表格识别](#61)
+ * [6.2 DOC-VQA](#62)
+- [7. 模型库](#7)
+
+
+
+## 1. 简介
PP-Structure是一个可用于复杂文档结构分析和处理的OCR工具包,旨在帮助开发者更好的完成文档理解相关任务。
-## 近期更新
-* 2021.12.07 新增VQA任务-SER和RE。
+
-## 特性
+## 2. 近期更新
+* 2021.12.07 新增[DOC-VQA任务SER和RE](vqa/README.md)。
-PP-Structure是一个可用于复杂文档结构分析和处理的OCR工具包,主要特性如下:
+
+
+## 3. 特性
+
+PP-Structure的主要特性如下:
- 支持对图片形式的文档进行版面分析,可以划分**文字、标题、表格、图片以及列表**5类区域(与Layout-Parser联合使用)
- 支持文字、标题、图片以及列表区域提取为文字字段(与PP-OCR联合使用)
- 支持表格区域进行结构化分析,最终结果输出Excel文件
@@ -17,13 +35,22 @@ PP-Structure是一个可用于复杂文档结构分析和处理的OCR工具包
- 支持文档视觉问答(Document Visual Question Answering,DOC-VQA)任务-语义实体识别(Semantic Entity Recognition,SER)和关系抽取(Relation Extraction,RE)
-## 1. 效果展示
+
+
+## 4. 效果展示
-### 1.1 版面分析和表格识别
+
+
+### 4.1 版面分析和表格识别
-### 1.2 VQA
+图中展示了版面分析+表格识别的整体流程,图片先由版面分析划分为图像、文本、标题和表格四种区域,然后对图像、文本和标题三种区域进行OCR的检测识别,对表格进行表格识别,其中图像还会被存储下来以便使用。
+
+
+
+
+### 4.2 DOC-VQA
* SER
@@ -46,36 +73,45 @@ PP-Structure是一个可用于复杂文档结构分析和处理的OCR工具包
图中红色框表示问题,蓝色框表示答案,问题和答案之间使用绿色线连接。在OCR检测框的左上方也标出了对应的类别和OCR识别结果。
-## 2. 快速体验
+
+
+## 5. 快速体验
+
+请参考[快速安装](./docs/quickstart.md)教程。
-代码体验:从 [快速安装](./docs/quickstart.md) 开始
+
-## 3. PP-Structure Pipeline介绍
+## 6. PP-Structure 介绍
-### 3.1 版面分析+表格识别
+
+
+### 6.1 版面分析+表格识别
![pipeline](../doc/table/pipeline.jpg)
在PP-Structure中,图片会先经由Layout-Parser进行版面分析,在版面分析中,会对图片里的区域进行分类,包括**文字、标题、图片、列表和表格**5类。对于前4类区域,直接使用PP-OCR完成对应区域文字检测与识别。对于表格类区域,经过表格结构化处理后,表格图片转换为相同表格样式的Excel文件。
-#### 3.1.1 版面分析
+#### 6.1.1 版面分析
版面分析对文档数据进行区域分类,其中包括版面分析工具的Python脚本使用、提取指定类别检测框、性能指标以及自定义训练版面分析模型,详细内容可以参考[文档](layout/README_ch.md)。
-#### 3.1.2 表格识别
+#### 6.1.2 表格识别
+
+表格识别将表格图片转换为excel文档,其中包含对于表格文本的检测和识别以及对于表格结构和单元格坐标的预测,详细说明参考[文档](table/README_ch.md)。
-表格识别将表格图片转换为excel文档,其中包含对于表格文本的检测和识别以及对于表格结构和单元格坐标的预测,详细说明参考[文档](table/README_ch.md)
+
+### 6.2 DOC-VQA
-### 3.2 VQA
+DOC-VQA指文档视觉问答,其中包括语义实体识别 (Semantic Entity Recognition, SER) 和关系抽取 (Relation Extraction, RE) 任务。基于 SER 任务,可以完成对图像中的文本识别与分类;基于 RE 任务,可以完成对图像中的文本内容的关系提取,如判断问题对(pair),详细说明参考[文档](vqa/README.md)。
-coming soon
+
-## 4. 模型库
+## 7. 模型库
PP-Structure系列模型列表(更新中)
-* LayoutParser 模型
+* 版面分析模型
|模型名称|模型简介|下载地址|
| --- | --- | --- |
@@ -90,7 +126,7 @@ PP-Structure系列模型列表(更新中)
|ch_ppocr_mobile_slim_v2.0_rec|slim裁剪量化版超轻量模型,支持中英文、数字识别|6M|[推理模型](https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_mobile_v2.0_rec_slim_infer.tar) / [训练模型](https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_mobile_v2.0_rec_slim_train.tar) |
|en_ppocr_mobile_v2.0_table_structure|PubLayNet数据集训练的英文表格场景的表格结构预测|18.6M|[推理模型](https://paddleocr.bj.bcebos.com/dygraph_v2.0/table/en_ppocr_mobile_v2.0_table_structure_infer.tar) / [训练模型](https://paddleocr.bj.bcebos.com/dygraph_v2.1/table/en_ppocr_mobile_v2.0_table_structure_train.tar) |
-* VQA模型
+* DOC-VQA 模型
|模型名称|模型简介|模型大小|下载地址|
| --- | --- | --- | --- |
@@ -98,4 +134,4 @@ PP-Structure系列模型列表(更新中)
|PP-Layout_v1.0_re_pretrained|基于LayoutXLM在xfun中文数据集上训练的RE模型|1.4G|[推理模型 coming soon]() / [训练模型](https://paddleocr.bj.bcebos.com/pplayout/PP-Layout_v1.0_re_pretrained.tar) |
-更多模型下载,可以参考 [模型库](./docs/model_list.md)
+更多模型下载,可以参考 [PPOCR model_list](../doc/doc_en/models_list.md) 和 [PPStructure model_list](./docs/model_list.md)。
\ No newline at end of file
diff --git a/ppstructure/docs/kie_en.md b/ppstructure/docs/kie_en.md
new file mode 100644
index 0000000000000000000000000000000000000000..a424968a9b5a33132afe52a4850cfe541919ae1c
--- /dev/null
+++ b/ppstructure/docs/kie_en.md
@@ -0,0 +1,77 @@
+
+
+# Key Information Extraction(KIE)
+
+This section provides a tutorial example on how to quickly use, train, and evaluate a key information extraction(KIE) model, [SDMGR](https://arxiv.org/abs/2103.14470), in PaddleOCR.
+
+[SDMGR (Spatial Dual-Modality Graph Reasoning)](https://arxiv.org/abs/2103.14470) is a KIE algorithm that classifies each detected text region into predefined categories such as order ID, invoice number and amount.
+
+
+* [1. Quick Use](#1-----)
+* [2. Model Training](#2-----)
+* [3. Model Evaluation](#3-----)
+
+
+
+## 1. Quick Use
+
+The [Wildreceipt dataset](https://paperswithcode.com/dataset/wildreceipt) is used in this tutorial. It contains 1,765 photos, 25 classes and 50,000 text boxes, and can be downloaded with wget:
+
+```shell
+wget https://paddleocr.bj.bcebos.com/dygraph_v2.1/kie/wildreceipt.tar && tar xf wildreceipt.tar
+```
+
+Download the pretrained model and predict the result:
+
+```shell
+cd PaddleOCR/
+wget https://paddleocr.bj.bcebos.com/dygraph_v2.1/kie/kie_vgg16.tar && tar xf kie_vgg16.tar
+python3.7 tools/infer_kie.py -c configs/kie/kie_unet_sdmgr.yml -o Global.checkpoints=kie_vgg16/best_accuracy Global.infer_img=../wildreceipt/1.txt
+```
+
+The prediction results are saved as `./output/sdmgr_kie/predicts_kie.txt`, and the visualization results are saved in the folder `./output/sdmgr_kie/kie_results/`.
+
+The visualization results are shown in the figure below:
+
+
+
+
+
+
+## 2. Model Training
+
+Create a soft link to the downloaded dataset in the folder `PaddleOCR/train_data`:
+```shell
+cd PaddleOCR/ && mkdir train_data && cd train_data
+
+ln -s ../../wildreceipt ./
+```
+
+The configuration file used for training is `configs/kie/kie_unet_sdmgr.yml`. The default training data path in the configuration file is `train_data/wildreceipt`. After preparing the data, you can execute the model training with the following command:
+```shell
+python3.7 tools/train.py -c configs/kie/kie_unet_sdmgr.yml -o Global.save_model_dir=./output/kie/
+```
+
+
+## 3. Model Evaluation
+
+After training, you can execute the model evaluation with the following command:
+
+```shell
+python3.7 tools/eval.py -c configs/kie/kie_unet_sdmgr.yml -o Global.checkpoints=./output/kie/best_accuracy
+```
+
+**Reference:**
+
+
+
+```bibtex
+@misc{sun2021spatial,
+ title={Spatial Dual-Modality Graph Reasoning for Key Information Extraction},
+ author={Hongbin Sun and Zhanghui Kuang and Xiaoyu Yue and Chenhao Lin and Wayne Zhang},
+ year={2021},
+ eprint={2103.14470},
+ archivePrefix={arXiv},
+ primaryClass={cs.CV}
+}
+```
diff --git a/ppstructure/docs/model_list.md b/ppstructure/docs/model_list.md
index 45004490c1c4b0ea01a5fb409024f1eeb922f1a3..baec2a2fd08a5b8d51e4c68bc62902feb04de977 100644
--- a/ppstructure/docs/model_list.md
+++ b/ppstructure/docs/model_list.md
@@ -24,8 +24,8 @@
|模型名称|模型简介|推理模型大小|下载地址|
| --- | --- | --- | --- |
-|PP-Layout_v1.0_ser_pretrained|基于LayoutXLM在xfun中文数据集上训练的SER模型|1.4G|[推理模型 coming soon]() / [训练模型](https://paddleocr.bj.bcebos.com/pplayout/PP-Layout_v1.0_ser_pretrained.tar) |
-|PP-Layout_v1.0_re_pretrained|基于LayoutXLM在xfun中文数据集上训练的RE模型|1.4G|[推理模型 coming soon]() / [训练模型](https://paddleocr.bj.bcebos.com/pplayout/PP-Layout_v1.0_re_pretrained.tar) |
+|PP-Layout_v1.0_ser_pretrained|基于LayoutXLM在xfun中文数据集上训练的SER模型|1.4G|[推理模型 coming soon]() / [训练模型](https://paddleocr.bj.bcebos.com/pplayout/ser_LayoutXLM_xfun_zh.tar) |
+|PP-Layout_v1.0_re_pretrained|基于LayoutXLM在xfun中文数据集上训练的RE模型|1.4G|[推理模型 coming soon]() / [训练模型](https://paddleocr.bj.bcebos.com/pplayout/re_LayoutXLM_xfun_zh.tar) |
## 3. KIE模型
diff --git a/ppstructure/docs/quickstart.md b/ppstructure/docs/quickstart.md
index 446c577ec39cf24dd4b8699558c633a1308fa444..668775c6da2b06d973f69a9ce81a37396460cbdf 100644
--- a/ppstructure/docs/quickstart.md
+++ b/ppstructure/docs/quickstart.md
@@ -39,7 +39,7 @@ paddleocr --image_dir=../doc/table/1.png --type=structure
* VQA
-coming soon
+请参考:[文档视觉问答](../vqa/README.md)。
@@ -74,7 +74,7 @@ im_show.save('result.jpg')
* VQA
-comming soon
+请参考:[文档视觉问答](../vqa/README.md)。
@@ -101,7 +101,7 @@ dict 里各个字段说明如下
* VQA
-comming soon
+请参考:[文档视觉问答](../vqa/README.md)。
@@ -116,9 +116,9 @@ comming soon
| model_name_or_path | VQA SER模型地址 | None |
| max_seq_length | VQA SER模型最大支持token长度 | 512 |
| label_map_path | VQA SER 标签文件地址 | ./vqa/labels/labels_ser.txt |
-| mode | pipeline预测模式,structure: 版面分析+表格识别; vqa: ser文档信息抽取 | structure |
+| mode | pipeline预测模式,structure: 版面分析+表格识别; VQA: SER文档信息抽取 | structure |
-大部分参数和paddleocr whl包保持一致,见 [whl包文档](../doc/doc_ch/whl.md)
+大部分参数和PaddleOCR whl包保持一致,见 [whl包文档](../../doc/doc_ch/whl.md)
运行完成后,每张图片会在`output`字段指定的目录下有一个同名目录,图片里的每个表格会存储为一个excel,图片区域会被裁剪之后保存下来,excel文件和图片的文件名为表格在图片里的坐标。
@@ -133,16 +133,16 @@ cd ppstructure
# 下载模型
mkdir inference && cd inference
-# 下载超轻量级中文OCR模型的检测模型并解压
-wget https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_mobile_v2.0_det_infer.tar && tar xf ch_ppocr_mobile_v2.0_det_infer.tar
-# 下载超轻量级中文OCR模型的识别模型并解压
-wget https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_mobile_v2.0_rec_infer.tar && tar xf ch_ppocr_mobile_v2.0_rec_infer.tar
-# 下载超轻量级英文表格英寸模型并解压
+# 下载PP-OCRv2文本检测模型并解压
+wget https://paddleocr.bj.bcebos.com/PP-OCRv2/chinese/ch_PP-OCRv2_det_slim_quant_infer.tar && tar xf ch_PP-OCRv2_det_slim_quant_infer.tar
+# 下载PP-OCRv2文本识别模型并解压
+wget https://paddleocr.bj.bcebos.com/PP-OCRv2/chinese/ch_PP-OCRv2_rec_slim_quant_infer.tar && tar xf ch_PP-OCRv2_rec_slim_quant_infer.tar
+# 下载超轻量级英文表格预测模型并解压
wget https://paddleocr.bj.bcebos.com/dygraph_v2.0/table/en_ppocr_mobile_v2.0_table_structure_infer.tar && tar xf en_ppocr_mobile_v2.0_table_structure_infer.tar
cd ..
-python3 predict_system.py --det_model_dir=inference/ch_ppocr_mobile_v2.0_det_infer \
- --rec_model_dir=inference/ch_ppocr_mobile_v2.0_rec_infer \
+python3 predict_system.py --det_model_dir=inference/ch_PP-OCRv2_det_slim_quant_infer \
+ --rec_model_dir=inference/ch_PP-OCRv2_rec_slim_quant_infer \
--table_model_dir=inference/en_ppocr_mobile_v2.0_table_structure_infer \
--image_dir=../doc/table/1.png \
--rec_char_dict_path=../ppocr/utils/ppocr_keys_v1.txt \
diff --git a/ppstructure/predict_system.py b/ppstructure/predict_system.py
index e87499ccc410ae67a170f63301e5a99ef948b161..3f3dc65875a20b3f66403afecfd60f04e3d83d61 100644
--- a/ppstructure/predict_system.py
+++ b/ppstructure/predict_system.py
@@ -30,7 +30,6 @@ from ppocr.utils.utility import get_image_file_list, check_and_read_gif
from ppocr.utils.logging import get_logger
from tools.infer.predict_system import TextSystem
from ppstructure.table.predict_table import TableSystem, to_excel
-from ppstructure.vqa.infer_ser_e2e import SerPredictor, draw_ser_results
from ppstructure.utility import parse_args, draw_structure_result
logger = get_logger()
@@ -66,6 +65,7 @@ class OCRSystem(object):
self.use_angle_cls = args.use_angle_cls
self.drop_score = args.drop_score
elif self.mode == 'vqa':
+ from ppstructure.vqa.infer_ser_e2e import SerPredictor, draw_ser_results
self.vqa_engine = SerPredictor(args)
def __call__(self, img):
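
The `SerPredictor` import is deferred in the hunk above so that the VQA-specific dependencies are only required when the pipeline is actually run with `mode='vqa'`. A minimal, generic illustration of the lazy-import pattern (the wrapper class below is a stand-in, not the real `OCRSystem`):

```python
class PipelineSketch:
    def __init__(self, mode="structure"):
        self.mode = mode
        if self.mode == "vqa":
            # optional / heavy dependency imported only when this mode is used
            from ppstructure.vqa.infer_ser_e2e import SerPredictor
            self.engine_cls = SerPredictor   # real code instantiates it with parsed args
```
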
diff --git a/ppstructure/table/README.md b/ppstructure/table/README.md
index 30a11a20e5de90500d1408f671ba914f336a0b43..94fa76055b93cefab0ac507a6007ec148aa12945 100644
--- a/ppstructure/table/README.md
+++ b/ppstructure/table/README.md
@@ -41,7 +41,7 @@ wget https://paddleocr.bj.bcebos.com/dygraph_v2.0/table/en_ppocr_mobile_v2.0_tab
wget https://paddleocr.bj.bcebos.com/dygraph_v2.0/table/en_ppocr_mobile_v2.0_table_structure_infer.tar && tar xf en_ppocr_mobile_v2.0_table_structure_infer.tar
cd ..
# run
-python3 table/predict_table.py --det_model_dir=inference/en_ppocr_mobile_v2.0_table_det_infer --rec_model_dir=inference/en_ppocr_mobile_v2.0_table_rec_infer --table_model_dir=inference/en_ppocr_mobile_v2.0_table_structure_infer --image_dir=../doc/table/table.jpg --rec_char_dict_path=../ppocr/utils/dict/table_dict.txt --table_char_dict_path=../ppocr/utils/dict/table_structure_dict.txt --rec_char_dict_path=../ppocr/utils/dict/en_dict.txt --det_limit_side_len=736 --det_limit_type=min --output ../output/table
+python3 table/predict_table.py --det_model_dir=inference/en_ppocr_mobile_v2.0_table_det_infer --rec_model_dir=inference/en_ppocr_mobile_v2.0_table_rec_infer --table_model_dir=inference/en_ppocr_mobile_v2.0_table_structure_infer --image_dir=../doc/table/table.jpg --rec_char_dict_path=../ppocr/utils/dict/table_dict.txt --table_char_dict_path=../ppocr/utils/dict/table_structure_dict.txt --det_limit_side_len=736 --det_limit_type=min --output ../output/table
```
Note: The above model is trained on the PubLayNet dataset and only supports English scanning scenarios. If you need to identify other scenarios, you need to train the model yourself and replace the three fields `det_model_dir`, `rec_model_dir`, `table_model_dir`.
diff --git a/ppstructure/table/README_ch.md b/ppstructure/table/README_ch.md
index 33276b36e4973e83d7efa673b90013cf5727dfe2..ef0f1ae5c4554e69e4cbeb0fcd783e6d98f96a41 100644
--- a/ppstructure/table/README_ch.md
+++ b/ppstructure/table/README_ch.md
@@ -56,7 +56,7 @@ wget https://paddleocr.bj.bcebos.com/dygraph_v2.0/table/en_ppocr_mobile_v2.0_tab
wget https://paddleocr.bj.bcebos.com/dygraph_v2.0/table/en_ppocr_mobile_v2.0_table_structure_infer.tar && tar xf en_ppocr_mobile_v2.0_table_structure_infer.tar
cd ..
# 执行预测
-python3 table/predict_table.py --det_model_dir=inference/en_ppocr_mobile_v2.0_table_det_infer --rec_model_dir=inference/en_ppocr_mobile_v2.0_table_rec_infer --table_model_dir=inference/en_ppocr_mobile_v2.0_table_structure_infer --image_dir=../doc/table/table.jpg --rec_char_dict_path=../ppocr/utils/dict/table_dict.txt --table_char_dict_path=../ppocr/utils/dict/table_structure_dict.txt --rec_char_dict_path=../ppocr/utils/dict/en_dict.txt --det_limit_side_len=736 --det_limit_type=min --output ../output/table
+python3 table/predict_table.py --det_model_dir=inference/en_ppocr_mobile_v2.0_table_det_infer --rec_model_dir=inference/en_ppocr_mobile_v2.0_table_rec_infer --table_model_dir=inference/en_ppocr_mobile_v2.0_table_structure_infer --image_dir=../doc/table/table.jpg --rec_char_dict_path=../ppocr/utils/dict/table_dict.txt --table_char_dict_path=../ppocr/utils/dict/table_structure_dict.txt --det_limit_side_len=736 --det_limit_type=min --output ../output/table
```
运行完成后,每张图片的excel表格会保存到output字段指定的目录下
diff --git a/ppstructure/vqa/README.md b/ppstructure/vqa/README.md
index 4cf2432f40979e17bef8d8f631a963e641a02591..7f4ca119f70592e59e4a8ed946bddd589b348b97 100644
--- a/ppstructure/vqa/README.md
+++ b/ppstructure/vqa/README.md
@@ -20,11 +20,11 @@ PP-Structure 里的 DOC-VQA算法基于PaddleNLP自然语言处理算法库进
我们在 [XFUN](https://github.com/doc-analysis/XFUND) 的中文数据集上对算法进行了评估,性能如下
-| 模型 | 任务 | f1 | 模型下载地址 |
+| 模型 | 任务 | hmean | 模型下载地址 |
|:---:|:---:|:---:| :---:|
-| LayoutXLM | RE | 0.7113 | [链接](https://paddleocr.bj.bcebos.com/pplayout/PP-Layout_v1.0_re_pretrained.tar) |
-| LayoutXLM | SER | 0.9056 | [链接](https://paddleocr.bj.bcebos.com/pplayout/PP-Layout_v1.0_ser_pretrained.tar) |
-| LayoutLM | SER | 0.78 | [链接](https://paddleocr.bj.bcebos.com/pplayout/LayoutLM_ser_pretrained.tar) |
+| LayoutXLM | RE | 0.7483 | [链接](https://paddleocr.bj.bcebos.com/pplayout/re_LayoutXLM_xfun_zh.tar) |
+| LayoutXLM | SER | 0.9038 | [链接](https://paddleocr.bj.bcebos.com/pplayout/ser_LayoutXLM_xfun_zh.tar) |
+| LayoutLM | SER | 0.7731 | [链接](https://paddleocr.bj.bcebos.com/pplayout/ser_LayoutLM_xfun_zh.tar) |
@@ -34,7 +34,7 @@ PP-Structure 里的 DOC-VQA算法基于PaddleNLP自然语言处理算法库进
### 2.1 SER
-![](./images/result_ser/zh_val_0_ser.jpg) | ![](./images/result_ser/zh_val_42_ser.jpg)
+![](../../doc/vqa/result_ser/zh_val_0_ser.jpg) | ![](../../doc/vqa/result_ser/zh_val_42_ser.jpg)
---|---
图中不同颜色的框表示不同的类别,对于XFUN数据集,有`QUESTION`, `ANSWER`, `HEADER` 3种类别
@@ -48,7 +48,7 @@ PP-Structure 里的 DOC-VQA算法基于PaddleNLP自然语言处理算法库进
### 2.2 RE
-![](./images/result_re/zh_val_21_re.jpg) | ![](./images/result_re/zh_val_40_re.jpg)
+![](../../doc/vqa/result_re/zh_val_21_re.jpg) | ![](../../doc/vqa/result_re/zh_val_40_re.jpg)
---|---
@@ -62,13 +62,13 @@ PP-Structure 里的 DOC-VQA算法基于PaddleNLP自然语言处理算法库进
- **(1) 安装PaddlePaddle**
```bash
-pip3 install --upgrade pip
+python3 -m pip install --upgrade pip
# GPU安装
-python3 -m pip install paddlepaddle-gpu==2.2 -i https://mirror.baidu.com/pypi/simple
+python3 -m pip install "paddlepaddle-gpu>=2.2" -i https://mirror.baidu.com/pypi/simple
# CPU安装
-python3 -m pip install paddlepaddle==2.2 -i https://mirror.baidu.com/pypi/simple
+python3 -m pip install "paddlepaddle>=2.2" -i https://mirror.baidu.com/pypi/simple
```
更多需求,请参照[安装文档](https://www.paddlepaddle.org.cn/install/quick)中的说明进行操作。
@@ -79,7 +79,7 @@ python3 -m pip install paddlepaddle==2.2 -i https://mirror.baidu.com/pypi/simple
- **(1)pip快速安装PaddleOCR whl包(仅预测)**
```bash
-pip install paddleocr
+python3 -m pip install paddleocr
```
- **(2)下载VQA源码(预测+训练)**
@@ -93,21 +93,10 @@ git clone https://gitee.com/paddlepaddle/PaddleOCR
# 注:码云托管代码可能无法实时同步本github项目更新,存在3~5天延时,请优先使用推荐方式。
```
-- **(3)安装PaddleNLP**
+- **(3)安装VQA的`requirements`**
```bash
-# 需要使用PaddleNLP最新的代码版本进行安装
-git clone https://github.com/PaddlePaddle/PaddleNLP -b develop
-cd PaddleNLP
-pip3 install -e .
-```
-
-
-- **(4)安装VQA的`requirements`**
-
-```bash
-cd ppstructure/vqa
-pip install -r requirements.txt
+python3 -m pip install -r ppstructure/vqa/requirements.txt
```
## 4. 使用
@@ -115,6 +104,10 @@ pip install -r requirements.txt
### 4.1 数据和预训练模型准备
+如果希望直接体验预测过程,可以下载我们提供的预训练模型,跳过训练过程,直接预测即可。
+
+* 下载处理好的数据集
+
处理好的XFUN中文数据集下载地址:[https://paddleocr.bj.bcebos.com/dataset/XFUND.tar](https://paddleocr.bj.bcebos.com/dataset/XFUND.tar)。
@@ -124,101 +117,65 @@ pip install -r requirements.txt
wget https://paddleocr.bj.bcebos.com/dataset/XFUND.tar
```
-如果希望转换XFUN中其他语言的数据集,可以参考[XFUN数据转换脚本](helper/trans_xfun_data.py)。
+* 转换数据集
-如果希望直接体验预测过程,可以下载我们提供的预训练模型,跳过训练过程,直接预测即可。
+若需进行其他XFUN数据集的训练,可使用下面的命令进行数据集的转换
+```bash
+python3 ppstructure/vqa/helper/trans_xfun_data.py --ori_gt_path=path/to/json_path --output_path=path/to/save_path
+```
### 4.2 SER任务
-* 启动训练
+启动训练之前,需要修改下面的四个字段
+
+1. `Train.dataset.data_dir`:指向训练集图片存放目录
+2. `Train.dataset.label_file_list`:指向训练集标注文件
+3. `Eval.dataset.data_dir`:指向验证集图片存放目录
+4. `Eval.dataset.label_file_list`:指向验证集标注文件
+* 启动训练
```shell
-python3.7 train_ser.py \
- --model_name_or_path "layoutxlm-base-uncased" \
- --ser_model_type "LayoutXLM" \
- --train_data_dir "XFUND/zh_train/image" \
- --train_label_path "XFUND/zh_train/xfun_normalize_train.json" \
- --eval_data_dir "XFUND/zh_val/image" \
- --eval_label_path "XFUND/zh_val/xfun_normalize_val.json" \
- --num_train_epochs 200 \
- --eval_steps 10 \
- --output_dir "./output/ser/" \
- --learning_rate 5e-5 \
- --warmup_steps 50 \
- --evaluate_during_training \
- --seed 2048
+CUDA_VISIBLE_DEVICES=0 python3 tools/train.py -c configs/vqa/ser/layoutxlm.yml
```
-最终会打印出`precision`, `recall`, `f1`等指标,模型和训练日志会保存在`./output/ser/`文件夹中。
+最终会打印出`precision`, `recall`, `hmean`等指标。
+在`./output/ser_layoutxlm/`文件夹中会保存训练日志,最优的模型和最新epoch的模型。
* 恢复训练
+恢复训练需要将之前训练好的模型所在文件夹路径赋值给 `Architecture.Backbone.checkpoints` 字段。
+
```shell
-python3.7 train_ser.py \
- --model_name_or_path "model_path" \
- --ser_model_type "LayoutXLM" \
- --train_data_dir "XFUND/zh_train/image" \
- --train_label_path "XFUND/zh_train/xfun_normalize_train.json" \
- --eval_data_dir "XFUND/zh_val/image" \
- --eval_label_path "XFUND/zh_val/xfun_normalize_val.json" \
- --num_train_epochs 200 \
- --eval_steps 10 \
- --output_dir "./output/ser/" \
- --learning_rate 5e-5 \
- --warmup_steps 50 \
- --evaluate_during_training \
- --num_workers 8 \
- --seed 2048 \
- --resume
+CUDA_VISIBLE_DEVICES=0 python3 tools/train.py -c configs/vqa/ser/layoutxlm.yml -o Architecture.Backbone.checkpoints=path/to/model_dir
```
* 评估
-```shell
-export CUDA_VISIBLE_DEVICES=0
-python3 eval_ser.py \
- --model_name_or_path "PP-Layout_v1.0_ser_pretrained/" \
- --ser_model_type "LayoutXLM" \
- --eval_data_dir "XFUND/zh_val/image" \
- --eval_label_path "XFUND/zh_val/xfun_normalize_val.json" \
- --per_gpu_eval_batch_size 8 \
- --num_workers 8 \
- --output_dir "output/ser/" \
- --seed 2048
-```
-最终会打印出`precision`, `recall`, `f1`等指标
-* 使用评估集合中提供的OCR识别结果进行预测
+评估需要将待评估的模型所在文件夹路径赋值给 `Architecture.Backbone.checkpoints` 字段。
```shell
-export CUDA_VISIBLE_DEVICES=0
-python3.7 infer_ser.py \
- --model_name_or_path "PP-Layout_v1.0_ser_pretrained/" \
- --ser_model_type "LayoutXLM" \
- --output_dir "output/ser/" \
- --infer_imgs "XFUND/zh_val/image/" \
- --ocr_json_path "XFUND/zh_val/xfun_normalize_val.json"
+CUDA_VISIBLE_DEVICES=0 python3 tools/eval.py -c configs/vqa/ser/layoutxlm.yml -o Architecture.Backbone.checkpoints=path/to/model_dir
```
+最终会打印出`precision`, `recall`, `hmean`等指标
-最终会在`output_res`目录下保存预测结果可视化图像以及预测结果文本文件,文件名为`infer_results.txt`。
+* 使用`OCR引擎 + SER`串联预测
-* 使用`OCR引擎 + SER`串联结果
+使用如下命令即可完成`OCR引擎 + SER`的串联预测
```shell
-export CUDA_VISIBLE_DEVICES=0
-python3.7 infer_ser_e2e.py \
- --model_name_or_path "PP-Layout_v1.0_ser_pretrained/" \
- --ser_model_type "LayoutXLM" \
- --max_seq_length 512 \
- --output_dir "output/ser_e2e/" \
- --infer_imgs "images/input/zh_val_0.jpg"
+CUDA_VISIBLE_DEVICES=0 python3 tools/infer_vqa_token_ser.py -c configs/vqa/ser/layoutxlm.yml -o Architecture.Backbone.checkpoints=PP-Layout_v1.0_ser_pretrained/ Global.infer_img=doc/vqa/input/zh_val_42.jpg
```
+最终会在`config.Global.save_res_path`字段所配置的目录下保存预测结果可视化图像以及预测结果文本文件,预测结果文本文件名为`infer_results.txt`。
+
* 对`OCR引擎 + SER`预测系统进行端到端评估
+首先使用 `tools/infer_vqa_token_ser.py` 脚本完成数据集的预测,然后使用下面的命令进行评估。
+
```shell
export CUDA_VISIBLE_DEVICES=0
-python3.7 helper/eval_with_label_end2end.py --gt_json_path XFUND/zh_val/xfun_normalize_val.json --pred_json_path output_res/infer_results.txt
+python3 helper/eval_with_label_end2end.py --gt_json_path XFUND/zh_val/xfun_normalize_val.json --pred_json_path output_res/infer_results.txt
```
@@ -226,102 +183,48 @@ python3.7 helper/eval_with_label_end2end.py --gt_json_path XFUND/zh_val/xfun_nor
* Start training
-```shell
-export CUDA_VISIBLE_DEVICES=0
-python3 train_re.py \
- --model_name_or_path "layoutxlm-base-uncased" \
- --train_data_dir "XFUND/zh_train/image" \
- --train_label_path "XFUND/zh_train/xfun_normalize_train.json" \
- --eval_data_dir "XFUND/zh_val/image" \
- --eval_label_path "XFUND/zh_val/xfun_normalize_val.json" \
- --label_map_path "labels/labels_ser.txt" \
- --num_train_epochs 200 \
- --eval_steps 10 \
- --output_dir "output/re/" \
- --learning_rate 5e-5 \
- --warmup_steps 50 \
- --per_gpu_train_batch_size 8 \
- --per_gpu_eval_batch_size 8 \
- --num_workers 8 \
- --evaluate_during_training \
- --seed 2048
-
-```
+Before starting training, the following four fields need to be modified (alternatively, they can be overridden on the command line, as sketched after this list):
-* 恢复训练
+1. `Train.dataset.data_dir`: points to the directory where the training images are stored
+2. `Train.dataset.label_file_list`: points to the training annotation file
+3. `Eval.dataset.data_dir`: points to the directory where the validation images are stored
+4. `Eval.dataset.label_file_list`: points to the validation annotation file
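+
+As an alternative to editing the YAML file, the same fields can be overridden on the command line; a minimal sketch (dataset paths are placeholders, and it assumes `-o` accepts list-valued fields written in brackets):
+
+```shell
+CUDA_VISIBLE_DEVICES=0 python3 tools/train.py -c configs/vqa/re/layoutxlm.yml \
+    -o Train.dataset.data_dir=XFUND/zh_train/image \
+       Train.dataset.label_file_list=[XFUND/zh_train/xfun_normalize_train.json] \
+       Eval.dataset.data_dir=XFUND/zh_val/image \
+       Eval.dataset.label_file_list=[XFUND/zh_val/xfun_normalize_val.json]
+```
+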
```shell
-export CUDA_VISIBLE_DEVICES=0
-python3 train_re.py \
- --model_name_or_path "model_path" \
- --train_data_dir "XFUND/zh_train/image" \
- --train_label_path "XFUND/zh_train/xfun_normalize_train.json" \
- --eval_data_dir "XFUND/zh_val/image" \
- --eval_label_path "XFUND/zh_val/xfun_normalize_val.json" \
- --label_map_path "labels/labels_ser.txt" \
- --num_train_epochs 2 \
- --eval_steps 10 \
- --output_dir "output/re/" \
- --learning_rate 5e-5 \
- --warmup_steps 50 \
- --per_gpu_train_batch_size 8 \
- --per_gpu_eval_batch_size 8 \
- --num_workers 8 \
- --evaluate_during_training \
- --seed 2048 \
- --resume
-
+CUDA_VISIBLE_DEVICES=0 python3 tools/train.py -c configs/vqa/re/layoutxlm.yml
```
-最终会打印出`precision`, `recall`, `f1`等指标,模型和训练日志会保存在`./output/re/`文件夹中。
+Finally, metrics such as `precision`, `recall` and `hmean` will be printed.
+The training log, the best model and the model of the latest epoch will be saved in the `./output/re_layoutxlm/` folder.
+
+* Resume training
+
+To resume training, assign the path of the folder containing the previously trained model to the `Architecture.Backbone.checkpoints` field.
-* 评估
```shell
-export CUDA_VISIBLE_DEVICES=0
-python3 eval_re.py \
- --model_name_or_path "PP-Layout_v1.0_re_pretrained/" \
- --max_seq_length 512 \
- --eval_data_dir "XFUND/zh_val/image" \
- --eval_label_path "XFUND/zh_val/xfun_normalize_val.json" \
- --label_map_path "labels/labels_ser.txt" \
- --output_dir "output/re/" \
- --per_gpu_eval_batch_size 8 \
- --num_workers 8 \
- --seed 2048
+CUDA_VISIBLE_DEVICES=0 python3 tools/train.py -c configs/vqa/re/layoutxlm.yml -o Architecture.Backbone.checkpoints=path/to/model_dir
```
-最终会打印出`precision`, `recall`, `f1`等指标
+* Evaluation
-* 使用评估集合中提供的OCR识别结果进行预测
+For evaluation, assign the path of the folder containing the model to be evaluated to the `Architecture.Backbone.checkpoints` field.
```shell
-export CUDA_VISIBLE_DEVICES=0
-python3 infer_re.py \
- --model_name_or_path "PP-Layout_v1.0_re_pretrained/" \
- --max_seq_length 512 \
- --eval_data_dir "XFUND/zh_val/image" \
- --eval_label_path "XFUND/zh_val/xfun_normalize_val.json" \
- --label_map_path "labels/labels_ser.txt" \
- --output_dir "output/re/" \
- --per_gpu_eval_batch_size 1 \
- --seed 2048
+CUDA_VISIBLE_DEVICES=0 python3 tools/eval.py -c configs/vqa/re/layoutxlm.yml -o Architecture.Backbone.checkpoints=path/to/model_dir
```
+Finally, metrics such as `precision`, `recall` and `hmean` will be printed.
-最终会在`output_res`目录下保存预测结果可视化图像以及预测结果文本文件,文件名为`infer_results.txt`。
-
-* 使用`OCR引擎 + SER + RE`串联结果
+* Chained prediction with the `OCR engine + SER + RE`
+Use the following command to run the chained `OCR engine + SER + RE` prediction:
```shell
export CUDA_VISIBLE_DEVICES=0
-python3.7 infer_ser_re_e2e.py \
- --model_name_or_path "PP-Layout_v1.0_ser_pretrained/" \
- --re_model_name_or_path "PP-Layout_v1.0_re_pretrained/" \
- --ser_model_type "LayoutXLM" \
- --max_seq_length 512 \
- --output_dir "output/ser_re_e2e/" \
- --infer_imgs "images/input/zh_val_21.jpg"
+python3 tools/infer_vqa_token_ser_re.py -c configs/vqa/re/layoutxlm.yml -o Architecture.Backbone.checkpoints=PP-Layout_v1.0_re_pretrained/ Global.infer_img=doc/vqa/input/zh_val_21.jpg -c_ser configs/vqa/ser/layoutxlm.yml -o_ser Architecture.Backbone.checkpoints=PP-Layout_v1.0_ser_pretrained/
```
+The visualization images and the prediction result text file will be saved in the directory configured by the `config.Global.save_res_path` field; the prediction result text file is named `infer_results.txt`.
+
+
## References
- LayoutXLM: Multimodal Pre-training for Multilingual Visually-rich Document Understanding, https://arxiv.org/pdf/2104.08836.pdf
diff --git a/ppstructure/vqa/eval_re.py b/ppstructure/vqa/eval_re.py
deleted file mode 100644
index 68c27bad8a8e236fc16ffad21acefe7a55fde561..0000000000000000000000000000000000000000
--- a/ppstructure/vqa/eval_re.py
+++ /dev/null
@@ -1,125 +0,0 @@
-# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-import os
-import sys
-
-__dir__ = os.path.dirname(os.path.abspath(__file__))
-sys.path.append(__dir__)
-sys.path.append(os.path.abspath(os.path.join(__dir__, '../..')))
-
-import paddle
-
-from paddlenlp.transformers import LayoutXLMTokenizer, LayoutXLMModel, LayoutXLMForRelationExtraction
-
-from xfun import XFUNDataset
-from vqa_utils import parse_args, get_bio_label_maps, print_arguments
-from data_collator import DataCollator
-from metric import re_score
-
-from ppocr.utils.logging import get_logger
-
-
-def cal_metric(re_preds, re_labels, entities):
- gt_relations = []
- for b in range(len(re_labels)):
- rel_sent = []
- for head, tail in zip(re_labels[b]["head"], re_labels[b]["tail"]):
- rel = {}
- rel["head_id"] = head
- rel["head"] = (entities[b]["start"][rel["head_id"]],
- entities[b]["end"][rel["head_id"]])
- rel["head_type"] = entities[b]["label"][rel["head_id"]]
-
- rel["tail_id"] = tail
- rel["tail"] = (entities[b]["start"][rel["tail_id"]],
- entities[b]["end"][rel["tail_id"]])
- rel["tail_type"] = entities[b]["label"][rel["tail_id"]]
-
- rel["type"] = 1
- rel_sent.append(rel)
- gt_relations.append(rel_sent)
- re_metrics = re_score(re_preds, gt_relations, mode="boundaries")
- return re_metrics
-
-
-def evaluate(model, eval_dataloader, logger, prefix=""):
- # Eval!
- logger.info("***** Running evaluation {} *****".format(prefix))
- logger.info(" Num examples = {}".format(len(eval_dataloader.dataset)))
-
- re_preds = []
- re_labels = []
- entities = []
- eval_loss = 0.0
- model.eval()
- for idx, batch in enumerate(eval_dataloader):
- with paddle.no_grad():
- outputs = model(**batch)
- loss = outputs['loss'].mean().item()
- if paddle.distributed.get_rank() == 0:
- logger.info("[Eval] process: {}/{}, loss: {:.5f}".format(
- idx, len(eval_dataloader), loss))
-
- eval_loss += loss
- re_preds.extend(outputs['pred_relations'])
- re_labels.extend(batch['relations'])
- entities.extend(batch['entities'])
- re_metrics = cal_metric(re_preds, re_labels, entities)
- re_metrics = {
- "precision": re_metrics["ALL"]["p"],
- "recall": re_metrics["ALL"]["r"],
- "f1": re_metrics["ALL"]["f1"],
- }
- model.train()
- return re_metrics
-
-
-def eval(args):
- logger = get_logger()
- label2id_map, id2label_map = get_bio_label_maps(args.label_map_path)
- pad_token_label_id = paddle.nn.CrossEntropyLoss().ignore_index
-
- tokenizer = LayoutXLMTokenizer.from_pretrained(args.model_name_or_path)
-
- model = LayoutXLMForRelationExtraction.from_pretrained(
- args.model_name_or_path)
-
- eval_dataset = XFUNDataset(
- tokenizer,
- data_dir=args.eval_data_dir,
- label_path=args.eval_label_path,
- label2id_map=label2id_map,
- img_size=(224, 224),
- max_seq_len=args.max_seq_length,
- pad_token_label_id=pad_token_label_id,
- contains_re=True,
- add_special_ids=False,
- return_attention_mask=True,
- load_mode='all')
-
- eval_dataloader = paddle.io.DataLoader(
- eval_dataset,
- batch_size=args.per_gpu_eval_batch_size,
- num_workers=args.num_workers,
- shuffle=False,
- collate_fn=DataCollator())
-
- results = evaluate(model, eval_dataloader, logger)
- logger.info("eval results: {}".format(results))
-
-
-if __name__ == "__main__":
- args = parse_args()
- eval(args)
diff --git a/ppstructure/vqa/eval_ser.py b/ppstructure/vqa/eval_ser.py
deleted file mode 100644
index 95f428c721e103748d5cd722ff0e1d2bf0f09e52..0000000000000000000000000000000000000000
--- a/ppstructure/vqa/eval_ser.py
+++ /dev/null
@@ -1,177 +0,0 @@
-# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-import os
-import sys
-
-__dir__ = os.path.dirname(os.path.abspath(__file__))
-sys.path.append(__dir__)
-sys.path.append(os.path.abspath(os.path.join(__dir__, '../..')))
-
-import random
-import time
-import copy
-import logging
-
-import argparse
-import paddle
-import numpy as np
-from seqeval.metrics import classification_report, f1_score, precision_score, recall_score
-from paddlenlp.transformers import LayoutXLMModel, LayoutXLMTokenizer, LayoutXLMForTokenClassification
-from paddlenlp.transformers import LayoutLMModel, LayoutLMTokenizer, LayoutLMForTokenClassification
-
-from xfun import XFUNDataset
-from losses import SERLoss
-from vqa_utils import parse_args, get_bio_label_maps, print_arguments
-
-from ppocr.utils.logging import get_logger
-
-MODELS = {
- 'LayoutXLM':
- (LayoutXLMTokenizer, LayoutXLMModel, LayoutXLMForTokenClassification),
- 'LayoutLM':
- (LayoutLMTokenizer, LayoutLMModel, LayoutLMForTokenClassification)
-}
-
-
-def eval(args):
- logger = get_logger()
- print_arguments(args, logger)
-
- label2id_map, id2label_map = get_bio_label_maps(args.label_map_path)
- pad_token_label_id = paddle.nn.CrossEntropyLoss().ignore_index
-
- tokenizer_class, base_model_class, model_class = MODELS[args.ser_model_type]
- tokenizer = tokenizer_class.from_pretrained(args.model_name_or_path)
- model = model_class.from_pretrained(args.model_name_or_path)
-
- eval_dataset = XFUNDataset(
- tokenizer,
- data_dir=args.eval_data_dir,
- label_path=args.eval_label_path,
- label2id_map=label2id_map,
- img_size=(224, 224),
- pad_token_label_id=pad_token_label_id,
- contains_re=False,
- add_special_ids=False,
- return_attention_mask=True,
- load_mode='all')
-
- eval_dataloader = paddle.io.DataLoader(
- eval_dataset,
- batch_size=args.per_gpu_eval_batch_size,
- num_workers=args.num_workers,
- use_shared_memory=True,
- collate_fn=None, )
-
- loss_class = SERLoss(len(label2id_map))
-
- results, _ = evaluate(args, model, tokenizer, loss_class, eval_dataloader,
- label2id_map, id2label_map, pad_token_label_id,
- logger)
-
- logger.info(results)
-
-
-def evaluate(args,
- model,
- tokenizer,
- loss_class,
- eval_dataloader,
- label2id_map,
- id2label_map,
- pad_token_label_id,
- logger,
- prefix=""):
-
- eval_loss = 0.0
- nb_eval_steps = 0
- preds = None
- out_label_ids = None
- model.eval()
- for idx, batch in enumerate(eval_dataloader):
- with paddle.no_grad():
- if args.ser_model_type == 'LayoutLM':
- if 'image' in batch:
- batch.pop('image')
- labels = batch.pop('labels')
- outputs = model(**batch)
- if args.ser_model_type == 'LayoutXLM':
- outputs = outputs[0]
- loss = loss_class(labels, outputs, batch['attention_mask'])
-
- loss = loss.mean()
-
- if paddle.distributed.get_rank() == 0:
- logger.info("[Eval]process: {}/{}, loss: {:.5f}".format(
- idx, len(eval_dataloader), loss.numpy()[0]))
-
- eval_loss += loss.item()
- nb_eval_steps += 1
- if preds is None:
- preds = outputs.numpy()
- out_label_ids = labels.numpy()
- else:
- preds = np.append(preds, outputs.numpy(), axis=0)
- out_label_ids = np.append(out_label_ids, labels.numpy(), axis=0)
-
- eval_loss = eval_loss / nb_eval_steps
- preds = np.argmax(preds, axis=2)
-
- # label_map = {i: label.upper() for i, label in enumerate(labels)}
-
- out_label_list = [[] for _ in range(out_label_ids.shape[0])]
- preds_list = [[] for _ in range(out_label_ids.shape[0])]
-
- for i in range(out_label_ids.shape[0]):
- for j in range(out_label_ids.shape[1]):
- if out_label_ids[i, j] != pad_token_label_id:
- out_label_list[i].append(id2label_map[out_label_ids[i][j]])
- preds_list[i].append(id2label_map[preds[i][j]])
-
- results = {
- "loss": eval_loss,
- "precision": precision_score(out_label_list, preds_list),
- "recall": recall_score(out_label_list, preds_list),
- "f1": f1_score(out_label_list, preds_list),
- }
-
- with open(
- os.path.join(args.output_dir, "test_gt.txt"), "w",
- encoding='utf-8') as fout:
- for lbl in out_label_list:
- for l in lbl:
- fout.write(l + "\t")
- fout.write("\n")
- with open(
- os.path.join(args.output_dir, "test_pred.txt"), "w",
- encoding='utf-8') as fout:
- for lbl in preds_list:
- for l in lbl:
- fout.write(l + "\t")
- fout.write("\n")
-
- report = classification_report(out_label_list, preds_list)
- logger.info("\n" + report)
-
- logger.info("***** Eval results %s *****", prefix)
- for key in sorted(results.keys()):
- logger.info(" %s = %s", key, str(results[key]))
- model.train()
- return results, preds_list
-
-
-if __name__ == "__main__":
- args = parse_args()
- eval(args)
diff --git a/ppstructure/vqa/helper/trans_xfun_data.py b/ppstructure/vqa/helper/trans_xfun_data.py
index 25b3963d8362d28ea1df4c62d1491095b8c49253..93ec98163c6cec96ec93399c1d41524200ddc499 100644
--- a/ppstructure/vqa/helper/trans_xfun_data.py
+++ b/ppstructure/vqa/helper/trans_xfun_data.py
@@ -49,4 +49,16 @@ def transfer_xfun_data(json_path=None, output_file=None):
print("===ok====")
-transfer_xfun_data("./xfun/zh.val.json", "./xfun_normalize_val.json")
+def parser_args():
+ import argparse
+    parser = argparse.ArgumentParser(description="args for XFUND data conversion")
+ parser.add_argument(
+ "--ori_gt_path", type=str, required=True, help='origin xfun gt path')
+ parser.add_argument(
+ "--output_path", type=str, required=True, help='path to save')
+ args = parser.parse_args()
+ return args
+
+
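+# Example invocation (paths taken from the previously hard-coded call; adjust to your own data):
+#   python3 trans_xfun_data.py --ori_gt_path ./xfun/zh.val.json --output_path ./xfun_normalize_val.json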
+args = parser_args()
+transfer_xfun_data(args.ori_gt_path, args.output_path)
diff --git a/ppstructure/vqa/infer.sh b/ppstructure/vqa/infer.sh
deleted file mode 100644
index 2cd1cea4476672732b3a7f9ad97a3e42172dbb92..0000000000000000000000000000000000000000
--- a/ppstructure/vqa/infer.sh
+++ /dev/null
@@ -1,61 +0,0 @@
-export CUDA_VISIBLE_DEVICES=6
-# python3.7 infer_ser_e2e.py \
-# --model_name_or_path "output/ser_distributed/best_model" \
-# --max_seq_length 512 \
-# --output_dir "output_res_e2e/" \
-# --infer_imgs "/ssd1/zhoujun20/VQA/data/XFUN_v1.0_data/zh.val/zh_val_0.jpg"
-
-
-# python3.7 infer_ser_re_e2e.py \
-# --model_name_or_path "output/ser_distributed/best_model" \
-# --re_model_name_or_path "output/re_test/best_model" \
-# --max_seq_length 512 \
-# --output_dir "output_ser_re_e2e_train/" \
-# --infer_imgs "images/input/zh_val_21.jpg"
-
-# python3.7 infer_ser.py \
-# --model_name_or_path "output/ser_LayoutLM/best_model" \
-# --ser_model_type "LayoutLM" \
-# --output_dir "ser_LayoutLM/" \
-# --infer_imgs "images/input/zh_val_21.jpg" \
-# --ocr_json_path "/ssd1/zhoujun20/VQA/data/XFUN_v1.0_data/xfun_normalize_val.json"
-
-python3.7 infer_ser.py \
- --model_name_or_path "output/ser_new/best_model" \
- --ser_model_type "LayoutXLM" \
- --output_dir "ser_new/" \
- --infer_imgs "images/input/zh_val_21.jpg" \
- --ocr_json_path "/ssd1/zhoujun20/VQA/data/XFUN_v1.0_data/xfun_normalize_val.json"
-
-# python3.7 infer_ser_e2e.py \
-# --model_name_or_path "output/ser_new/best_model" \
-# --ser_model_type "LayoutXLM" \
-# --max_seq_length 512 \
-# --output_dir "output/ser_new/" \
-# --infer_imgs "images/input/zh_val_0.jpg"
-
-
-# python3.7 infer_ser_e2e.py \
-# --model_name_or_path "output/ser_LayoutLM/best_model" \
-# --ser_model_type "LayoutLM" \
-# --max_seq_length 512 \
-# --output_dir "output/ser_LayoutLM/" \
-# --infer_imgs "images/input/zh_val_0.jpg"
-
-# python3 infer_re.py \
-# --model_name_or_path "/ssd1/zhoujun20/VQA/PaddleOCR/ppstructure/vqa/output/re_test/best_model/" \
-# --max_seq_length 512 \
-# --eval_data_dir "/ssd1/zhoujun20/VQA/data/XFUN_v1.0_data/zh.val" \
-# --eval_label_path "/ssd1/zhoujun20/VQA/data/XFUN_v1.0_data/xfun_normalize_val.json" \
-# --label_map_path 'labels/labels_ser.txt' \
-# --output_dir "output_res" \
-# --per_gpu_eval_batch_size 1 \
-# --seed 2048
-
-# python3.7 infer_ser_re_e2e.py \
-# --model_name_or_path "output/ser_LayoutLM/best_model" \
-# --ser_model_type "LayoutLM" \
-# --re_model_name_or_path "output/re_new/best_model" \
-# --max_seq_length 512 \
-# --output_dir "output_ser_re_e2e/" \
-# --infer_imgs "images/input/zh_val_21.jpg"
\ No newline at end of file
diff --git a/ppstructure/vqa/infer_re.py b/ppstructure/vqa/infer_re.py
deleted file mode 100644
index b6774e77befe6ba8954d5f552bcade86cb44e644..0000000000000000000000000000000000000000
--- a/ppstructure/vqa/infer_re.py
+++ /dev/null
@@ -1,165 +0,0 @@
-import os
-import sys
-
-__dir__ = os.path.dirname(os.path.abspath(__file__))
-sys.path.append(__dir__)
-sys.path.append(os.path.abspath(os.path.join(__dir__, '../..')))
-
-import random
-
-import cv2
-import matplotlib.pyplot as plt
-import numpy as np
-import paddle
-
-from paddlenlp.transformers import LayoutXLMTokenizer, LayoutXLMModel, LayoutXLMForRelationExtraction
-
-from xfun import XFUNDataset
-from vqa_utils import parse_args, get_bio_label_maps, draw_re_results
-from data_collator import DataCollator
-
-from ppocr.utils.logging import get_logger
-
-
-def infer(args):
- os.makedirs(args.output_dir, exist_ok=True)
- logger = get_logger()
- label2id_map, id2label_map = get_bio_label_maps(args.label_map_path)
- pad_token_label_id = paddle.nn.CrossEntropyLoss().ignore_index
-
- tokenizer = LayoutXLMTokenizer.from_pretrained(args.model_name_or_path)
-
- model = LayoutXLMForRelationExtraction.from_pretrained(
- args.model_name_or_path)
-
- eval_dataset = XFUNDataset(
- tokenizer,
- data_dir=args.eval_data_dir,
- label_path=args.eval_label_path,
- label2id_map=label2id_map,
- img_size=(224, 224),
- max_seq_len=args.max_seq_length,
- pad_token_label_id=pad_token_label_id,
- contains_re=True,
- add_special_ids=False,
- return_attention_mask=True,
- load_mode='all')
-
- eval_dataloader = paddle.io.DataLoader(
- eval_dataset,
- batch_size=args.per_gpu_eval_batch_size,
- num_workers=8,
- shuffle=False,
- collate_fn=DataCollator())
-
- # 读取gt的oct数据
- ocr_info_list = load_ocr(args.eval_data_dir, args.eval_label_path)
-
- for idx, batch in enumerate(eval_dataloader):
- ocr_info = ocr_info_list[idx]
- image_path = ocr_info['image_path']
- ocr_info = ocr_info['ocr_info']
-
- save_img_path = os.path.join(
- args.output_dir,
- os.path.splitext(os.path.basename(image_path))[0] + "_re.jpg")
- logger.info("[Infer] process: {}/{}, save result to {}".format(
- idx, len(eval_dataloader), save_img_path))
- with paddle.no_grad():
- outputs = model(**batch)
- pred_relations = outputs['pred_relations']
-
- # 根据entity里的信息,做token解码后去过滤不要的ocr_info
- ocr_info = filter_bg_by_txt(ocr_info, batch, tokenizer)
-
- # 进行 relations 到 ocr信息的转换
- result = []
- used_tail_id = []
- for relations in pred_relations:
- for relation in relations:
- if relation['tail_id'] in used_tail_id:
- continue
- if relation['head_id'] not in ocr_info or relation[
- 'tail_id'] not in ocr_info:
- continue
- used_tail_id.append(relation['tail_id'])
- ocr_info_head = ocr_info[relation['head_id']]
- ocr_info_tail = ocr_info[relation['tail_id']]
- result.append((ocr_info_head, ocr_info_tail))
-
- img = cv2.imread(image_path)
- img_show = draw_re_results(img, result)
- cv2.imwrite(save_img_path, img_show)
-
-
-def load_ocr(img_folder, json_path):
- import json
- d = []
- with open(json_path, "r", encoding='utf-8') as fin:
- lines = fin.readlines()
- for line in lines:
- image_name, info_str = line.split("\t")
- info_dict = json.loads(info_str)
- info_dict['image_path'] = os.path.join(img_folder, image_name)
- d.append(info_dict)
- return d
-
-
-def filter_bg_by_txt(ocr_info, batch, tokenizer):
- entities = batch['entities'][0]
- input_ids = batch['input_ids'][0]
-
- new_info_dict = {}
- for i in range(len(entities['start'])):
- entitie_head = entities['start'][i]
- entitie_tail = entities['end'][i]
- word_input_ids = input_ids[entitie_head:entitie_tail].numpy().tolist()
- txt = tokenizer.convert_ids_to_tokens(word_input_ids)
- txt = tokenizer.convert_tokens_to_string(txt)
-
- for i, info in enumerate(ocr_info):
- if info['text'] == txt:
- new_info_dict[i] = info
- return new_info_dict
-
-
-def post_process(pred_relations, ocr_info, img):
- result = []
- for relations in pred_relations:
- for relation in relations:
- ocr_info_head = ocr_info[relation['head_id']]
- ocr_info_tail = ocr_info[relation['tail_id']]
- result.append((ocr_info_head, ocr_info_tail))
- return result
-
-
-def draw_re(result, image_path, output_folder):
- img = cv2.imread(image_path)
-
- from matplotlib import pyplot as plt
- for ocr_info_head, ocr_info_tail in result:
- cv2.rectangle(
- img,
- tuple(ocr_info_head['bbox'][:2]),
- tuple(ocr_info_head['bbox'][2:]), (255, 0, 0),
- thickness=2)
- cv2.rectangle(
- img,
- tuple(ocr_info_tail['bbox'][:2]),
- tuple(ocr_info_tail['bbox'][2:]), (0, 0, 255),
- thickness=2)
- center_p1 = [(ocr_info_head['bbox'][0] + ocr_info_head['bbox'][2]) // 2,
- (ocr_info_head['bbox'][1] + ocr_info_head['bbox'][3]) // 2]
- center_p2 = [(ocr_info_tail['bbox'][0] + ocr_info_tail['bbox'][2]) // 2,
- (ocr_info_tail['bbox'][1] + ocr_info_tail['bbox'][3]) // 2]
- cv2.line(
- img, tuple(center_p1), tuple(center_p2), (0, 255, 0), thickness=2)
- plt.imshow(img)
- plt.savefig(
- os.path.join(output_folder, os.path.basename(image_path)), dpi=600)
- # plt.show()
-
-
-if __name__ == "__main__":
- args = parse_args()
- infer(args)
diff --git a/ppstructure/vqa/infer_ser.py b/ppstructure/vqa/infer_ser.py
deleted file mode 100644
index f5fb581fa7e7613216d2e4feb8e39ed8c2541dc9..0000000000000000000000000000000000000000
--- a/ppstructure/vqa/infer_ser.py
+++ /dev/null
@@ -1,302 +0,0 @@
-# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-import os
-import sys
-
-__dir__ = os.path.dirname(os.path.abspath(__file__))
-sys.path.append(__dir__)
-
-import json
-import cv2
-import numpy as np
-from copy import deepcopy
-
-import paddle
-
-# relative reference
-from vqa_utils import parse_args, get_image_file_list, draw_ser_results, get_bio_label_maps
-from paddlenlp.transformers import LayoutXLMModel, LayoutXLMTokenizer, LayoutXLMForTokenClassification
-from paddlenlp.transformers import LayoutLMModel, LayoutLMTokenizer, LayoutLMForTokenClassification
-
-MODELS = {
- 'LayoutXLM':
- (LayoutXLMTokenizer, LayoutXLMModel, LayoutXLMForTokenClassification),
- 'LayoutLM':
- (LayoutLMTokenizer, LayoutLMModel, LayoutLMForTokenClassification)
-}
-
-
-def pad_sentences(tokenizer,
- encoded_inputs,
- max_seq_len=512,
- pad_to_max_seq_len=True,
- return_attention_mask=True,
- return_token_type_ids=True,
- return_overflowing_tokens=False,
- return_special_tokens_mask=False):
- # Padding with larger size, reshape is carried out
- max_seq_len = (
- len(encoded_inputs["input_ids"]) // max_seq_len + 1) * max_seq_len
-
- needs_to_be_padded = pad_to_max_seq_len and \
- max_seq_len and len(encoded_inputs["input_ids"]) < max_seq_len
-
- if needs_to_be_padded:
- difference = max_seq_len - len(encoded_inputs["input_ids"])
- if tokenizer.padding_side == 'right':
- if return_attention_mask:
- encoded_inputs["attention_mask"] = [1] * len(encoded_inputs[
- "input_ids"]) + [0] * difference
- if return_token_type_ids:
- encoded_inputs["token_type_ids"] = (
- encoded_inputs["token_type_ids"] +
- [tokenizer.pad_token_type_id] * difference)
- if return_special_tokens_mask:
- encoded_inputs["special_tokens_mask"] = encoded_inputs[
- "special_tokens_mask"] + [1] * difference
- encoded_inputs["input_ids"] = encoded_inputs[
- "input_ids"] + [tokenizer.pad_token_id] * difference
- encoded_inputs["bbox"] = encoded_inputs["bbox"] + [[0, 0, 0, 0]
- ] * difference
- else:
- assert False, "padding_side of tokenizer just supports [\"right\"] but got {}".format(
- tokenizer.padding_side)
- else:
- if return_attention_mask:
- encoded_inputs["attention_mask"] = [1] * len(encoded_inputs[
- "input_ids"])
-
- return encoded_inputs
-
-
-def split_page(encoded_inputs, max_seq_len=512):
- """
- truncate is often used in training process
- """
- for key in encoded_inputs:
- encoded_inputs[key] = paddle.to_tensor(encoded_inputs[key])
- if encoded_inputs[key].ndim <= 1: # for input_ids, att_mask and so on
- encoded_inputs[key] = encoded_inputs[key].reshape([-1, max_seq_len])
- else: # for bbox
- encoded_inputs[key] = encoded_inputs[key].reshape(
- [-1, max_seq_len, 4])
- return encoded_inputs
-
-
-def preprocess(
- tokenizer,
- ori_img,
- ocr_info,
- img_size=(224, 224),
- pad_token_label_id=-100,
- max_seq_len=512,
- add_special_ids=False,
- return_attention_mask=True, ):
- ocr_info = deepcopy(ocr_info)
- height = ori_img.shape[0]
- width = ori_img.shape[1]
-
- img = cv2.resize(ori_img,
- (224, 224)).transpose([2, 0, 1]).astype(np.float32)
-
- segment_offset_id = []
- words_list = []
- bbox_list = []
- input_ids_list = []
- token_type_ids_list = []
-
- for info in ocr_info:
- # x1, y1, x2, y2
- bbox = info["bbox"]
- bbox[0] = int(bbox[0] * 1000.0 / width)
- bbox[2] = int(bbox[2] * 1000.0 / width)
- bbox[1] = int(bbox[1] * 1000.0 / height)
- bbox[3] = int(bbox[3] * 1000.0 / height)
-
- text = info["text"]
- encode_res = tokenizer.encode(
- text, pad_to_max_seq_len=False, return_attention_mask=True)
-
- if not add_special_ids:
- # TODO: use tok.all_special_ids to remove
- encode_res["input_ids"] = encode_res["input_ids"][1:-1]
- encode_res["token_type_ids"] = encode_res["token_type_ids"][1:-1]
- encode_res["attention_mask"] = encode_res["attention_mask"][1:-1]
-
- input_ids_list.extend(encode_res["input_ids"])
- token_type_ids_list.extend(encode_res["token_type_ids"])
- bbox_list.extend([bbox] * len(encode_res["input_ids"]))
- words_list.append(text)
- segment_offset_id.append(len(input_ids_list))
-
- encoded_inputs = {
- "input_ids": input_ids_list,
- "token_type_ids": token_type_ids_list,
- "bbox": bbox_list,
- "attention_mask": [1] * len(input_ids_list),
- }
-
- encoded_inputs = pad_sentences(
- tokenizer,
- encoded_inputs,
- max_seq_len=max_seq_len,
- return_attention_mask=return_attention_mask)
-
- encoded_inputs = split_page(encoded_inputs)
-
- fake_bs = encoded_inputs["input_ids"].shape[0]
-
- encoded_inputs["image"] = paddle.to_tensor(img).unsqueeze(0).expand(
- [fake_bs] + list(img.shape))
-
- encoded_inputs["segment_offset_id"] = segment_offset_id
-
- return encoded_inputs
-
-
-def postprocess(attention_mask, preds, label_map_path):
- if isinstance(preds, paddle.Tensor):
- preds = preds.numpy()
- preds = np.argmax(preds, axis=2)
-
- _, label_map = get_bio_label_maps(label_map_path)
-
- preds_list = [[] for _ in range(preds.shape[0])]
-
- # keep batch info
- for i in range(preds.shape[0]):
- for j in range(preds.shape[1]):
- if attention_mask[i][j] == 1:
- preds_list[i].append(label_map[preds[i][j]])
-
- return preds_list
-
-
-def merge_preds_list_with_ocr_info(label_map_path, ocr_info, segment_offset_id,
- preds_list):
- # must ensure the preds_list is generated from the same image
- preds = [p for pred in preds_list for p in pred]
- label2id_map, _ = get_bio_label_maps(label_map_path)
- for key in label2id_map:
- if key.startswith("I-"):
- label2id_map[key] = label2id_map["B" + key[1:]]
-
- id2label_map = dict()
- for key in label2id_map:
- val = label2id_map[key]
- if key == "O":
- id2label_map[val] = key
- if key.startswith("B-") or key.startswith("I-"):
- id2label_map[val] = key[2:]
- else:
- id2label_map[val] = key
-
- for idx in range(len(segment_offset_id)):
- if idx == 0:
- start_id = 0
- else:
- start_id = segment_offset_id[idx - 1]
-
- end_id = segment_offset_id[idx]
-
- curr_pred = preds[start_id:end_id]
- curr_pred = [label2id_map[p] for p in curr_pred]
-
- if len(curr_pred) <= 0:
- pred_id = 0
- else:
- counts = np.bincount(curr_pred)
- pred_id = np.argmax(counts)
- ocr_info[idx]["pred_id"] = int(pred_id)
- ocr_info[idx]["pred"] = id2label_map[pred_id]
- return ocr_info
-
-
-@paddle.no_grad()
-def infer(args):
- os.makedirs(args.output_dir, exist_ok=True)
-
- # init token and model
- tokenizer_class, base_model_class, model_class = MODELS[args.ser_model_type]
- tokenizer = tokenizer_class.from_pretrained(args.model_name_or_path)
- model = model_class.from_pretrained(args.model_name_or_path)
-
- model.eval()
-
- # load ocr results json
- ocr_results = dict()
- with open(args.ocr_json_path, "r", encoding='utf-8') as fin:
- lines = fin.readlines()
- for line in lines:
- img_name, json_info = line.split("\t")
- ocr_results[os.path.basename(img_name)] = json.loads(json_info)
-
- # get infer img list
- infer_imgs = get_image_file_list(args.infer_imgs)
-
- # loop for infer
- with open(
- os.path.join(args.output_dir, "infer_results.txt"),
- "w",
- encoding='utf-8') as fout:
- for idx, img_path in enumerate(infer_imgs):
- save_img_path = os.path.join(args.output_dir,
- os.path.basename(img_path))
- print("process: [{}/{}], save result to {}".format(
- idx, len(infer_imgs), save_img_path))
-
- img = cv2.imread(img_path)
-
- ocr_info = ocr_results[os.path.basename(img_path)]["ocr_info"]
- inputs = preprocess(
- tokenizer=tokenizer,
- ori_img=img,
- ocr_info=ocr_info,
- max_seq_len=args.max_seq_length)
- if args.ser_model_type == 'LayoutLM':
- preds = model(
- input_ids=inputs["input_ids"],
- bbox=inputs["bbox"],
- token_type_ids=inputs["token_type_ids"],
- attention_mask=inputs["attention_mask"])
- elif args.ser_model_type == 'LayoutXLM':
- preds = model(
- input_ids=inputs["input_ids"],
- bbox=inputs["bbox"],
- image=inputs["image"],
- token_type_ids=inputs["token_type_ids"],
- attention_mask=inputs["attention_mask"])
- preds = preds[0]
-
- preds = postprocess(inputs["attention_mask"], preds,
- args.label_map_path)
- ocr_info = merge_preds_list_with_ocr_info(
- args.label_map_path, ocr_info, inputs["segment_offset_id"],
- preds)
-
- fout.write(img_path + "\t" + json.dumps(
- {
- "ocr_info": ocr_info,
- }, ensure_ascii=False) + "\n")
-
- img_res = draw_ser_results(img, ocr_info)
- cv2.imwrite(save_img_path, img_res)
-
- return
-
-
-if __name__ == "__main__":
- args = parse_args()
- infer(args)
diff --git a/ppstructure/vqa/infer_ser_e2e.py b/ppstructure/vqa/infer_ser_e2e.py
deleted file mode 100644
index 33fe4dbb5e809388b135ee467d7e7c230f0eabcc..0000000000000000000000000000000000000000
--- a/ppstructure/vqa/infer_ser_e2e.py
+++ /dev/null
@@ -1,156 +0,0 @@
-# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-import os
-import sys
-
-__dir__ = os.path.dirname(os.path.abspath(__file__))
-sys.path.append(__dir__)
-
-import json
-import cv2
-import numpy as np
-from copy import deepcopy
-from PIL import Image
-
-import paddle
-from paddlenlp.transformers import LayoutXLMModel, LayoutXLMTokenizer, LayoutXLMForTokenClassification
-from paddlenlp.transformers import LayoutLMModel, LayoutLMTokenizer, LayoutLMForTokenClassification
-
-# relative reference
-from vqa_utils import parse_args, get_image_file_list, draw_ser_results, get_bio_label_maps
-
-from vqa_utils import pad_sentences, split_page, preprocess, postprocess, merge_preds_list_with_ocr_info
-
-MODELS = {
- 'LayoutXLM':
- (LayoutXLMTokenizer, LayoutXLMModel, LayoutXLMForTokenClassification),
- 'LayoutLM':
- (LayoutLMTokenizer, LayoutLMModel, LayoutLMForTokenClassification)
-}
-
-
-def trans_poly_to_bbox(poly):
- x1 = np.min([p[0] for p in poly])
- x2 = np.max([p[0] for p in poly])
- y1 = np.min([p[1] for p in poly])
- y2 = np.max([p[1] for p in poly])
- return [x1, y1, x2, y2]
-
-
-def parse_ocr_info_for_ser(ocr_result):
- ocr_info = []
- for res in ocr_result:
- ocr_info.append({
- "text": res[1][0],
- "bbox": trans_poly_to_bbox(res[0]),
- "poly": res[0],
- })
- return ocr_info
-
-
-class SerPredictor(object):
- def __init__(self, args):
- self.args = args
- self.max_seq_length = args.max_seq_length
-
- # init ser token and model
- tokenizer_class, base_model_class, model_class = MODELS[
- args.ser_model_type]
- self.tokenizer = tokenizer_class.from_pretrained(
- args.model_name_or_path)
- self.model = model_class.from_pretrained(args.model_name_or_path)
- self.model.eval()
-
- # init ocr_engine
- from paddleocr import PaddleOCR
-
- self.ocr_engine = PaddleOCR(
- rec_model_dir=args.rec_model_dir,
- det_model_dir=args.det_model_dir,
- use_angle_cls=False,
- show_log=False)
- # init dict
- label2id_map, self.id2label_map = get_bio_label_maps(
- args.label_map_path)
- self.label2id_map_for_draw = dict()
- for key in label2id_map:
- if key.startswith("I-"):
- self.label2id_map_for_draw[key] = label2id_map["B" + key[1:]]
- else:
- self.label2id_map_for_draw[key] = label2id_map[key]
-
- def __call__(self, img):
- ocr_result = self.ocr_engine.ocr(img, cls=False)
-
- ocr_info = parse_ocr_info_for_ser(ocr_result)
-
- inputs = preprocess(
- tokenizer=self.tokenizer,
- ori_img=img,
- ocr_info=ocr_info,
- max_seq_len=self.max_seq_length)
-
- if self.args.ser_model_type == 'LayoutLM':
- preds = self.model(
- input_ids=inputs["input_ids"],
- bbox=inputs["bbox"],
- token_type_ids=inputs["token_type_ids"],
- attention_mask=inputs["attention_mask"])
- elif self.args.ser_model_type == 'LayoutXLM':
- preds = self.model(
- input_ids=inputs["input_ids"],
- bbox=inputs["bbox"],
- image=inputs["image"],
- token_type_ids=inputs["token_type_ids"],
- attention_mask=inputs["attention_mask"])
- preds = preds[0]
-
- preds = postprocess(inputs["attention_mask"], preds, self.id2label_map)
- ocr_info = merge_preds_list_with_ocr_info(
- ocr_info, inputs["segment_offset_id"], preds,
- self.label2id_map_for_draw)
- return ocr_info, inputs
-
-
-if __name__ == "__main__":
- args = parse_args()
- os.makedirs(args.output_dir, exist_ok=True)
-
- # get infer img list
- infer_imgs = get_image_file_list(args.infer_imgs)
-
- # loop for infer
- ser_engine = SerPredictor(args)
- with open(
- os.path.join(args.output_dir, "infer_results.txt"),
- "w",
- encoding='utf-8') as fout:
- for idx, img_path in enumerate(infer_imgs):
- save_img_path = os.path.join(
- args.output_dir,
- os.path.splitext(os.path.basename(img_path))[0] + "_ser.jpg")
- print("process: [{}/{}], save result to {}".format(
- idx, len(infer_imgs), save_img_path))
-
- img = cv2.imread(img_path)
-
- result, _ = ser_engine(img)
- fout.write(img_path + "\t" + json.dumps(
- {
- "ser_resule": result,
- }, ensure_ascii=False) + "\n")
-
- img_res = draw_ser_results(img, result)
- cv2.imwrite(save_img_path, img_res)
diff --git a/ppstructure/vqa/infer_ser_re_e2e.py b/ppstructure/vqa/infer_ser_re_e2e.py
deleted file mode 100644
index e24c9f69e0836d64fbe67609623e4b6409f7658c..0000000000000000000000000000000000000000
--- a/ppstructure/vqa/infer_ser_re_e2e.py
+++ /dev/null
@@ -1,135 +0,0 @@
-# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-import os
-import sys
-import json
-import cv2
-import numpy as np
-from copy import deepcopy
-from PIL import Image
-
-import paddle
-from paddlenlp.transformers import LayoutXLMModel, LayoutXLMTokenizer, LayoutXLMForRelationExtraction
-
-# relative reference
-from vqa_utils import parse_args, get_image_file_list, draw_re_results
-from infer_ser_e2e import SerPredictor
-
-
-def make_input(ser_input, ser_result, max_seq_len=512):
- entities_labels = {'HEADER': 0, 'QUESTION': 1, 'ANSWER': 2}
-
- entities = ser_input['entities'][0]
- assert len(entities) == len(ser_result)
-
- # entities
- start = []
- end = []
- label = []
- entity_idx_dict = {}
- for i, (res, entity) in enumerate(zip(ser_result, entities)):
- if res['pred'] == 'O':
- continue
- entity_idx_dict[len(start)] = i
- start.append(entity['start'])
- end.append(entity['end'])
- label.append(entities_labels[res['pred']])
- entities = dict(start=start, end=end, label=label)
-
- # relations
- head = []
- tail = []
- for i in range(len(entities["label"])):
- for j in range(len(entities["label"])):
- if entities["label"][i] == 1 and entities["label"][j] == 2:
- head.append(i)
- tail.append(j)
-
- relations = dict(head=head, tail=tail)
-
- batch_size = ser_input["input_ids"].shape[0]
- entities_batch = []
- relations_batch = []
- for b in range(batch_size):
- entities_batch.append(entities)
- relations_batch.append(relations)
-
- ser_input['entities'] = entities_batch
- ser_input['relations'] = relations_batch
-
- ser_input.pop('segment_offset_id')
- return ser_input, entity_idx_dict
-
-
-class SerReSystem(object):
- def __init__(self, args):
- self.ser_engine = SerPredictor(args)
- self.tokenizer = LayoutXLMTokenizer.from_pretrained(
- args.re_model_name_or_path)
- self.model = LayoutXLMForRelationExtraction.from_pretrained(
- args.re_model_name_or_path)
- self.model.eval()
-
- def __call__(self, img):
- ser_result, ser_inputs = self.ser_engine(img)
- re_input, entity_idx_dict = make_input(ser_inputs, ser_result)
-
- re_result = self.model(**re_input)
-
- pred_relations = re_result['pred_relations'][0]
- # 进行 relations 到 ocr信息的转换
- result = []
- used_tail_id = []
- for relation in pred_relations:
- if relation['tail_id'] in used_tail_id:
- continue
- used_tail_id.append(relation['tail_id'])
- ocr_info_head = ser_result[entity_idx_dict[relation['head_id']]]
- ocr_info_tail = ser_result[entity_idx_dict[relation['tail_id']]]
- result.append((ocr_info_head, ocr_info_tail))
-
- return result
-
-
-if __name__ == "__main__":
- args = parse_args()
- os.makedirs(args.output_dir, exist_ok=True)
-
- # get infer img list
- infer_imgs = get_image_file_list(args.infer_imgs)
-
- # loop for infer
- ser_re_engine = SerReSystem(args)
- with open(
- os.path.join(args.output_dir, "infer_results.txt"),
- "w",
- encoding='utf-8') as fout:
- for idx, img_path in enumerate(infer_imgs):
- save_img_path = os.path.join(
- args.output_dir,
- os.path.splitext(os.path.basename(img_path))[0] + "_re.jpg")
- print("process: [{}/{}], save result to {}".format(
- idx, len(infer_imgs), save_img_path))
-
- img = cv2.imread(img_path)
-
- result = ser_re_engine(img)
- fout.write(img_path + "\t" + json.dumps(
- {
- "result": result,
- }, ensure_ascii=False) + "\n")
-
- img_res = draw_re_results(img, result)
- cv2.imwrite(save_img_path, img_res)
diff --git a/ppstructure/vqa/metric.py b/ppstructure/vqa/metric.py
deleted file mode 100644
index cb58370521296886670486982caf1202cf99a489..0000000000000000000000000000000000000000
--- a/ppstructure/vqa/metric.py
+++ /dev/null
@@ -1,175 +0,0 @@
-# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-import os
-import re
-
-import numpy as np
-
-import logging
-
-logger = logging.getLogger(__name__)
-
-PREFIX_CHECKPOINT_DIR = "checkpoint"
-_re_checkpoint = re.compile(r"^" + PREFIX_CHECKPOINT_DIR + r"\-(\d+)$")
-
-
-def get_last_checkpoint(folder):
- content = os.listdir(folder)
- checkpoints = [
- path for path in content
- if _re_checkpoint.search(path) is not None and os.path.isdir(
- os.path.join(folder, path))
- ]
- if len(checkpoints) == 0:
- return
- return os.path.join(
- folder,
- max(checkpoints,
- key=lambda x: int(_re_checkpoint.search(x).groups()[0])))
-
-
-def re_score(pred_relations, gt_relations, mode="strict"):
- """Evaluate RE predictions
-
- Args:
- pred_relations (list) : list of list of predicted relations (several relations in each sentence)
- gt_relations (list) : list of list of ground truth relations
-
- rel = { "head": (start_idx (inclusive), end_idx (exclusive)),
- "tail": (start_idx (inclusive), end_idx (exclusive)),
- "head_type": ent_type,
- "tail_type": ent_type,
- "type": rel_type}
-
- vocab (Vocab) : dataset vocabulary
- mode (str) : in 'strict' or 'boundaries'"""
-
- assert mode in ["strict", "boundaries"]
-
- relation_types = [v for v in [0, 1] if not v == 0]
- scores = {
- rel: {
- "tp": 0,
- "fp": 0,
- "fn": 0
- }
- for rel in relation_types + ["ALL"]
- }
-
- # Count GT relations and Predicted relations
- n_sents = len(gt_relations)
- n_rels = sum([len([rel for rel in sent]) for sent in gt_relations])
- n_found = sum([len([rel for rel in sent]) for sent in pred_relations])
-
- # Count TP, FP and FN per type
- for pred_sent, gt_sent in zip(pred_relations, gt_relations):
- for rel_type in relation_types:
- # strict mode takes argument types into account
- if mode == "strict":
- pred_rels = {(rel["head"], rel["head_type"], rel["tail"],
- rel["tail_type"])
- for rel in pred_sent if rel["type"] == rel_type}
- gt_rels = {(rel["head"], rel["head_type"], rel["tail"],
- rel["tail_type"])
- for rel in gt_sent if rel["type"] == rel_type}
-
- # boundaries mode only takes argument spans into account
- elif mode == "boundaries":
- pred_rels = {(rel["head"], rel["tail"])
- for rel in pred_sent if rel["type"] == rel_type}
- gt_rels = {(rel["head"], rel["tail"])
- for rel in gt_sent if rel["type"] == rel_type}
-
- scores[rel_type]["tp"] += len(pred_rels & gt_rels)
- scores[rel_type]["fp"] += len(pred_rels - gt_rels)
- scores[rel_type]["fn"] += len(gt_rels - pred_rels)
-
- # Compute per entity Precision / Recall / F1
- for rel_type in scores.keys():
- if scores[rel_type]["tp"]:
- scores[rel_type]["p"] = scores[rel_type]["tp"] / (
- scores[rel_type]["fp"] + scores[rel_type]["tp"])
- scores[rel_type]["r"] = scores[rel_type]["tp"] / (
- scores[rel_type]["fn"] + scores[rel_type]["tp"])
- else:
- scores[rel_type]["p"], scores[rel_type]["r"] = 0, 0
-
- if not scores[rel_type]["p"] + scores[rel_type]["r"] == 0:
- scores[rel_type]["f1"] = (
- 2 * scores[rel_type]["p"] * scores[rel_type]["r"] /
- (scores[rel_type]["p"] + scores[rel_type]["r"]))
- else:
- scores[rel_type]["f1"] = 0
-
- # Compute micro F1 Scores
- tp = sum([scores[rel_type]["tp"] for rel_type in relation_types])
- fp = sum([scores[rel_type]["fp"] for rel_type in relation_types])
- fn = sum([scores[rel_type]["fn"] for rel_type in relation_types])
-
- if tp:
- precision = tp / (tp + fp)
- recall = tp / (tp + fn)
- f1 = 2 * precision * recall / (precision + recall)
-
- else:
- precision, recall, f1 = 0, 0, 0
-
- scores["ALL"]["p"] = precision
- scores["ALL"]["r"] = recall
- scores["ALL"]["f1"] = f1
- scores["ALL"]["tp"] = tp
- scores["ALL"]["fp"] = fp
- scores["ALL"]["fn"] = fn
-
- # Compute Macro F1 Scores
- scores["ALL"]["Macro_f1"] = np.mean(
- [scores[ent_type]["f1"] for ent_type in relation_types])
- scores["ALL"]["Macro_p"] = np.mean(
- [scores[ent_type]["p"] for ent_type in relation_types])
- scores["ALL"]["Macro_r"] = np.mean(
- [scores[ent_type]["r"] for ent_type in relation_types])
-
- # logger.info(f"RE Evaluation in *** {mode.upper()} *** mode")
-
- # logger.info(
- # "processed {} sentences with {} relations; found: {} relations; correct: {}.".format(
- # n_sents, n_rels, n_found, tp
- # )
- # )
- # logger.info(
- # "\tALL\t TP: {};\tFP: {};\tFN: {}".format(scores["ALL"]["tp"], scores["ALL"]["fp"], scores["ALL"]["fn"])
- # )
- # logger.info("\t\t(m avg): precision: {:.2f};\trecall: {:.2f};\tf1: {:.2f} (micro)".format(precision, recall, f1))
- # logger.info(
- # "\t\t(M avg): precision: {:.2f};\trecall: {:.2f};\tf1: {:.2f} (Macro)\n".format(
- # scores["ALL"]["Macro_p"], scores["ALL"]["Macro_r"], scores["ALL"]["Macro_f1"]
- # )
- # )
-
- # for rel_type in relation_types:
- # logger.info(
- # "\t{}: \tTP: {};\tFP: {};\tFN: {};\tprecision: {:.2f};\trecall: {:.2f};\tf1: {:.2f};\t{}".format(
- # rel_type,
- # scores[rel_type]["tp"],
- # scores[rel_type]["fp"],
- # scores[rel_type]["fn"],
- # scores[rel_type]["p"],
- # scores[rel_type]["r"],
- # scores[rel_type]["f1"],
- # scores[rel_type]["tp"] + scores[rel_type]["fp"],
- # )
- # )
-
- return scores
diff --git a/ppstructure/vqa/requirements.txt b/ppstructure/vqa/requirements.txt
index 9c935ae619024c9f47ced820eae35a3a1c976953..0042ec0baedcc3e7bbecb922d10b93c95219219d 100644
--- a/ppstructure/vqa/requirements.txt
+++ b/ppstructure/vqa/requirements.txt
@@ -1,3 +1,4 @@
sentencepiece
yacs
-seqeval
\ No newline at end of file
+seqeval
+paddlenlp>=2.2.1
\ No newline at end of file
diff --git a/ppstructure/vqa/train_re.py b/ppstructure/vqa/train_re.py
deleted file mode 100644
index eeff2bfbbe466b29b8b46e83058e2199fd5cafed..0000000000000000000000000000000000000000
--- a/ppstructure/vqa/train_re.py
+++ /dev/null
@@ -1,229 +0,0 @@
-# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-import os
-import sys
-
-__dir__ = os.path.dirname(os.path.abspath(__file__))
-sys.path.append(__dir__)
-sys.path.append(os.path.abspath(os.path.join(__dir__, '../..')))
-
-import random
-import time
-import numpy as np
-import paddle
-
-from paddlenlp.transformers import LayoutXLMTokenizer, LayoutXLMModel, LayoutXLMForRelationExtraction
-
-from xfun import XFUNDataset
-from vqa_utils import parse_args, get_bio_label_maps, print_arguments, set_seed
-from data_collator import DataCollator
-from eval_re import evaluate
-
-from ppocr.utils.logging import get_logger
-
-
-def train(args):
- logger = get_logger(log_file=os.path.join(args.output_dir, "train.log"))
- rank = paddle.distributed.get_rank()
- distributed = paddle.distributed.get_world_size() > 1
-
- print_arguments(args, logger)
-
- # Added here for reproducibility (even between python 2 and 3)
- set_seed(args.seed)
-
- label2id_map, id2label_map = get_bio_label_maps(args.label_map_path)
- pad_token_label_id = paddle.nn.CrossEntropyLoss().ignore_index
-
- # dist mode
- if distributed:
- paddle.distributed.init_parallel_env()
-
- tokenizer = LayoutXLMTokenizer.from_pretrained(args.model_name_or_path)
- if not args.resume:
- model = LayoutXLMModel.from_pretrained(args.model_name_or_path)
- model = LayoutXLMForRelationExtraction(model, dropout=None)
- logger.info('train from scratch')
- else:
- logger.info('resume from {}'.format(args.model_name_or_path))
- model = LayoutXLMForRelationExtraction.from_pretrained(
- args.model_name_or_path)
-
- # dist mode
- if distributed:
- model = paddle.DataParallel(model)
-
- train_dataset = XFUNDataset(
- tokenizer,
- data_dir=args.train_data_dir,
- label_path=args.train_label_path,
- label2id_map=label2id_map,
- img_size=(224, 224),
- max_seq_len=args.max_seq_length,
- pad_token_label_id=pad_token_label_id,
- contains_re=True,
- add_special_ids=False,
- return_attention_mask=True,
- load_mode='all')
-
- eval_dataset = XFUNDataset(
- tokenizer,
- data_dir=args.eval_data_dir,
- label_path=args.eval_label_path,
- label2id_map=label2id_map,
- img_size=(224, 224),
- max_seq_len=args.max_seq_length,
- pad_token_label_id=pad_token_label_id,
- contains_re=True,
- add_special_ids=False,
- return_attention_mask=True,
- load_mode='all')
-
- train_sampler = paddle.io.DistributedBatchSampler(
- train_dataset, batch_size=args.per_gpu_train_batch_size, shuffle=True)
-
- train_dataloader = paddle.io.DataLoader(
- train_dataset,
- batch_sampler=train_sampler,
- num_workers=args.num_workers,
- use_shared_memory=True,
- collate_fn=DataCollator())
-
- eval_dataloader = paddle.io.DataLoader(
- eval_dataset,
- batch_size=args.per_gpu_eval_batch_size,
- num_workers=args.num_workers,
- shuffle=False,
- collate_fn=DataCollator())
-
- t_total = len(train_dataloader) * args.num_train_epochs
-
- # build linear decay with warmup lr sch
- lr_scheduler = paddle.optimizer.lr.PolynomialDecay(
- learning_rate=args.learning_rate,
- decay_steps=t_total,
- end_lr=0.0,
- power=1.0)
- if args.warmup_steps > 0:
- lr_scheduler = paddle.optimizer.lr.LinearWarmup(
- lr_scheduler,
- args.warmup_steps,
- start_lr=0,
- end_lr=args.learning_rate, )
- grad_clip = paddle.nn.ClipGradByNorm(clip_norm=10)
- optimizer = paddle.optimizer.Adam(
- learning_rate=args.learning_rate,
- parameters=model.parameters(),
- epsilon=args.adam_epsilon,
- grad_clip=grad_clip,
- weight_decay=args.weight_decay)
-
- # Train!
- logger.info("***** Running training *****")
- logger.info(" Num examples = {}".format(len(train_dataset)))
- logger.info(" Num Epochs = {}".format(args.num_train_epochs))
- logger.info(" Instantaneous batch size per GPU = {}".format(
- args.per_gpu_train_batch_size))
- logger.info(
- " Total train batch size (w. parallel, distributed & accumulation) = {}".
- format(args.per_gpu_train_batch_size *
- paddle.distributed.get_world_size()))
- logger.info(" Total optimization steps = {}".format(t_total))
-
- global_step = 0
- model.clear_gradients()
- train_dataloader_len = len(train_dataloader)
- best_metirc = {'f1': 0}
- model.train()
-
- train_reader_cost = 0.0
- train_run_cost = 0.0
- total_samples = 0
- reader_start = time.time()
-
- print_step = 1
-
- for epoch in range(int(args.num_train_epochs)):
- for step, batch in enumerate(train_dataloader):
- train_reader_cost += time.time() - reader_start
- train_start = time.time()
- outputs = model(**batch)
- train_run_cost += time.time() - train_start
- # model outputs are always tuple in ppnlp (see doc)
- loss = outputs['loss']
- loss = loss.mean()
-
- loss.backward()
- optimizer.step()
- optimizer.clear_grad()
- # lr_scheduler.step() # Update learning rate schedule
-
- global_step += 1
- total_samples += batch['image'].shape[0]
-
- if rank == 0 and step % print_step == 0:
- logger.info(
- "epoch: [{}/{}], iter: [{}/{}], global_step:{}, train loss: {:.6f}, lr: {:.6f}, avg_reader_cost: {:.5f} sec, avg_batch_cost: {:.5f} sec, avg_samples: {:.5f}, ips: {:.5f} images/sec".
- format(epoch, args.num_train_epochs, step,
- train_dataloader_len, global_step,
- np.mean(loss.numpy()),
- optimizer.get_lr(), train_reader_cost / print_step, (
- train_reader_cost + train_run_cost) / print_step,
- total_samples / print_step, total_samples / (
- train_reader_cost + train_run_cost)))
-
- train_reader_cost = 0.0
- train_run_cost = 0.0
- total_samples = 0
-
- if rank == 0 and args.eval_steps > 0 and global_step % args.eval_steps == 0 and args.evaluate_during_training:
- # Log metrics
- # Only evaluate when single GPU otherwise metrics may not average well
- results = evaluate(model, eval_dataloader, logger)
- if results['f1'] >= best_metirc['f1']:
- best_metirc = results
- output_dir = os.path.join(args.output_dir, "best_model")
- os.makedirs(output_dir, exist_ok=True)
- if distributed:
- model._layers.save_pretrained(output_dir)
- else:
- model.save_pretrained(output_dir)
- tokenizer.save_pretrained(output_dir)
- paddle.save(args,
- os.path.join(output_dir, "training_args.bin"))
- logger.info("Saving model checkpoint to {}".format(
- output_dir))
- logger.info("eval results: {}".format(results))
- logger.info("best_metirc: {}".format(best_metirc))
- reader_start = time.time()
-
- if rank == 0:
- # Save model checkpoint
- output_dir = os.path.join(args.output_dir, "latest_model")
- os.makedirs(output_dir, exist_ok=True)
- if distributed:
- model._layers.save_pretrained(output_dir)
- else:
- model.save_pretrained(output_dir)
- tokenizer.save_pretrained(output_dir)
- paddle.save(args, os.path.join(output_dir, "training_args.bin"))
- logger.info("Saving model checkpoint to {}".format(output_dir))
- logger.info("best_metirc: {}".format(best_metirc))
-
-
-if __name__ == "__main__":
- args = parse_args()
- os.makedirs(args.output_dir, exist_ok=True)
- train(args)
diff --git a/ppstructure/vqa/train_ser.py b/ppstructure/vqa/train_ser.py
deleted file mode 100644
index 226172050e9a5ea3b7c6534444ef24278de07043..0000000000000000000000000000000000000000
--- a/ppstructure/vqa/train_ser.py
+++ /dev/null
@@ -1,248 +0,0 @@
-# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-import os
-import sys
-
-__dir__ = os.path.dirname(os.path.abspath(__file__))
-sys.path.append(__dir__)
-sys.path.append(os.path.abspath(os.path.join(__dir__, '../..')))
-
-import random
-import time
-import copy
-import logging
-
-import argparse
-import paddle
-import numpy as np
-from seqeval.metrics import classification_report, f1_score, precision_score, recall_score
-from paddlenlp.transformers import LayoutXLMModel, LayoutXLMTokenizer, LayoutXLMForTokenClassification
-from paddlenlp.transformers import LayoutLMModel, LayoutLMTokenizer, LayoutLMForTokenClassification
-
-from xfun import XFUNDataset
-from vqa_utils import parse_args, get_bio_label_maps, print_arguments, set_seed
-from eval_ser import evaluate
-from losses import SERLoss
-from ppocr.utils.logging import get_logger
-
-MODELS = {
- 'LayoutXLM':
- (LayoutXLMTokenizer, LayoutXLMModel, LayoutXLMForTokenClassification),
- 'LayoutLM':
- (LayoutLMTokenizer, LayoutLMModel, LayoutLMForTokenClassification)
-}
-
-
-def train(args):
- os.makedirs(args.output_dir, exist_ok=True)
- rank = paddle.distributed.get_rank()
- distributed = paddle.distributed.get_world_size() > 1
-
- logger = get_logger(log_file=os.path.join(args.output_dir, "train.log"))
- print_arguments(args, logger)
-
- label2id_map, id2label_map = get_bio_label_maps(args.label_map_path)
- loss_class = SERLoss(len(label2id_map))
-
- pad_token_label_id = loss_class.ignore_index
-
- # dist mode
- if distributed:
- paddle.distributed.init_parallel_env()
-
- tokenizer_class, base_model_class, model_class = MODELS[args.ser_model_type]
- tokenizer = tokenizer_class.from_pretrained(args.model_name_or_path)
- if not args.resume:
- base_model = base_model_class.from_pretrained(args.model_name_or_path)
- model = model_class(
- base_model, num_classes=len(label2id_map), dropout=None)
- logger.info('train from scratch')
- else:
- logger.info('resume from {}'.format(args.model_name_or_path))
- model = model_class.from_pretrained(args.model_name_or_path)
-
- # dist mode
- if distributed:
- model = paddle.DataParallel(model)
-
- train_dataset = XFUNDataset(
- tokenizer,
- data_dir=args.train_data_dir,
- label_path=args.train_label_path,
- label2id_map=label2id_map,
- img_size=(224, 224),
- pad_token_label_id=pad_token_label_id,
- contains_re=False,
- add_special_ids=False,
- return_attention_mask=True,
- load_mode='all')
- eval_dataset = XFUNDataset(
- tokenizer,
- data_dir=args.eval_data_dir,
- label_path=args.eval_label_path,
- label2id_map=label2id_map,
- img_size=(224, 224),
- pad_token_label_id=pad_token_label_id,
- contains_re=False,
- add_special_ids=False,
- return_attention_mask=True,
- load_mode='all')
-
- train_sampler = paddle.io.DistributedBatchSampler(
- train_dataset, batch_size=args.per_gpu_train_batch_size, shuffle=True)
-
- train_dataloader = paddle.io.DataLoader(
- train_dataset,
- batch_sampler=train_sampler,
- num_workers=args.num_workers,
- use_shared_memory=True,
- collate_fn=None, )
-
- eval_dataloader = paddle.io.DataLoader(
- eval_dataset,
- batch_size=args.per_gpu_eval_batch_size,
- num_workers=args.num_workers,
- use_shared_memory=True,
- collate_fn=None, )
-
- t_total = len(train_dataloader) * args.num_train_epochs
-
- # build linear decay with warmup lr sch
- lr_scheduler = paddle.optimizer.lr.PolynomialDecay(
- learning_rate=args.learning_rate,
- decay_steps=t_total,
- end_lr=0.0,
- power=1.0)
- if args.warmup_steps > 0:
- lr_scheduler = paddle.optimizer.lr.LinearWarmup(
- lr_scheduler,
- args.warmup_steps,
- start_lr=0,
- end_lr=args.learning_rate, )
-
- optimizer = paddle.optimizer.AdamW(
- learning_rate=lr_scheduler,
- parameters=model.parameters(),
- epsilon=args.adam_epsilon,
- weight_decay=args.weight_decay)
-
- # Train!
- logger.info("***** Running training *****")
- logger.info(" Num examples = %d", len(train_dataset))
- logger.info(" Num Epochs = %d", args.num_train_epochs)
- logger.info(" Instantaneous batch size per GPU = %d",
- args.per_gpu_train_batch_size)
- logger.info(
- " Total train batch size (w. parallel, distributed) = %d",
- args.per_gpu_train_batch_size * paddle.distributed.get_world_size(), )
- logger.info(" Total optimization steps = %d", t_total)
-
- global_step = 0
- tr_loss = 0.0
- set_seed(args.seed)
- best_metrics = None
-
- train_reader_cost = 0.0
- train_run_cost = 0.0
- total_samples = 0
- reader_start = time.time()
-
- print_step = 1
- model.train()
- for epoch_id in range(args.num_train_epochs):
- for step, batch in enumerate(train_dataloader):
- train_reader_cost += time.time() - reader_start
-
- if args.ser_model_type == 'LayoutLM':
- if 'image' in batch:
- batch.pop('image')
- labels = batch.pop('labels')
-
- train_start = time.time()
- outputs = model(**batch)
- train_run_cost += time.time() - train_start
- if args.ser_model_type == 'LayoutXLM':
- outputs = outputs[0]
- loss = loss_class(labels, outputs, batch['attention_mask'])
-
- # model outputs are always tuple in ppnlp (see doc)
- loss = loss.mean()
- loss.backward()
- tr_loss += loss.item()
- optimizer.step()
- lr_scheduler.step() # Update learning rate schedule
- optimizer.clear_grad()
- global_step += 1
- total_samples += batch['input_ids'].shape[0]
-
- if rank == 0 and step % print_step == 0:
- logger.info(
- "epoch: [{}/{}], iter: [{}/{}], global_step:{}, train loss: {:.6f}, lr: {:.6f}, avg_reader_cost: {:.5f} sec, avg_batch_cost: {:.5f} sec, avg_samples: {:.5f}, ips: {:.5f} images/sec".
- format(epoch_id, args.num_train_epochs, step,
- len(train_dataloader), global_step,
- loss.numpy()[0],
- lr_scheduler.get_lr(), train_reader_cost /
- print_step, (train_reader_cost + train_run_cost) /
- print_step, total_samples / print_step, total_samples
- / (train_reader_cost + train_run_cost)))
-
- train_reader_cost = 0.0
- train_run_cost = 0.0
- total_samples = 0
-
- if rank == 0 and args.eval_steps > 0 and global_step % args.eval_steps == 0 and args.evaluate_during_training:
- # Log metrics
- # Only evaluate when single GPU otherwise metrics may not average well
- results, _ = evaluate(args, model, tokenizer, loss_class,
- eval_dataloader, label2id_map,
- id2label_map, pad_token_label_id, logger)
-
- if best_metrics is None or results["f1"] >= best_metrics["f1"]:
- best_metrics = copy.deepcopy(results)
- output_dir = os.path.join(args.output_dir, "best_model")
- os.makedirs(output_dir, exist_ok=True)
- if distributed:
- model._layers.save_pretrained(output_dir)
- else:
- model.save_pretrained(output_dir)
- tokenizer.save_pretrained(output_dir)
- paddle.save(args,
- os.path.join(output_dir, "training_args.bin"))
- logger.info("Saving model checkpoint to {}".format(
- output_dir))
-
- logger.info("[epoch {}/{}][iter: {}/{}] results: {}".format(
- epoch_id, args.num_train_epochs, step,
- len(train_dataloader), results))
- if best_metrics is not None:
- logger.info("best metrics: {}".format(best_metrics))
- reader_start = time.time()
- if rank == 0:
- # Save model checkpoint
- output_dir = os.path.join(args.output_dir, "latest_model")
- os.makedirs(output_dir, exist_ok=True)
- if distributed:
- model._layers.save_pretrained(output_dir)
- else:
- model.save_pretrained(output_dir)
- tokenizer.save_pretrained(output_dir)
- paddle.save(args, os.path.join(output_dir, "training_args.bin"))
- logger.info("Saving model checkpoint to {}".format(output_dir))
- return global_step, tr_loss / global_step
-
-
-if __name__ == "__main__":
- args = parse_args()
- train(args)
diff --git a/ppstructure/vqa/vqa_utils.py b/ppstructure/vqa/vqa_utils.py
deleted file mode 100644
index b9f2edc860b1ce48c22bf602cef48466c357834f..0000000000000000000000000000000000000000
--- a/ppstructure/vqa/vqa_utils.py
+++ /dev/null
@@ -1,400 +0,0 @@
-# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-import os
-import argparse
-import cv2
-import random
-import numpy as np
-import imghdr
-from copy import deepcopy
-
-import paddle
-
-from PIL import Image, ImageDraw, ImageFont
-
-
-def set_seed(seed):
- random.seed(seed)
- np.random.seed(seed)
- paddle.seed(seed)
-
-
-def get_bio_label_maps(label_map_path):
- with open(label_map_path, "r", encoding='utf-8') as fin:
- lines = fin.readlines()
- lines = [line.strip() for line in lines]
- if "O" not in lines:
- lines.insert(0, "O")
- labels = []
- for line in lines:
- if line == "O":
- labels.append("O")
- else:
- labels.append("B-" + line)
- labels.append("I-" + line)
- label2id_map = {label: idx for idx, label in enumerate(labels)}
- id2label_map = {idx: label for idx, label in enumerate(labels)}
- return label2id_map, id2label_map
-
-
-def get_image_file_list(img_file):
- imgs_lists = []
- if img_file is None or not os.path.exists(img_file):
- raise Exception("not found any img file in {}".format(img_file))
-
- img_end = {'jpg', 'bmp', 'png', 'jpeg', 'rgb', 'tif', 'tiff', 'gif', 'GIF'}
- if os.path.isfile(img_file) and imghdr.what(img_file) in img_end:
- imgs_lists.append(img_file)
- elif os.path.isdir(img_file):
- for single_file in os.listdir(img_file):
- file_path = os.path.join(img_file, single_file)
- if os.path.isfile(file_path) and imghdr.what(file_path) in img_end:
- imgs_lists.append(file_path)
- if len(imgs_lists) == 0:
- raise Exception("not found any img file in {}".format(img_file))
- imgs_lists = sorted(imgs_lists)
- return imgs_lists
-
-
-def draw_ser_results(image,
- ocr_results,
- font_path="../../doc/fonts/simfang.ttf",
- font_size=18):
- np.random.seed(2021)
- color = (np.random.permutation(range(255)),
- np.random.permutation(range(255)),
- np.random.permutation(range(255)))
- color_map = {
- idx: (color[0][idx], color[1][idx], color[2][idx])
- for idx in range(1, 255)
- }
- if isinstance(image, np.ndarray):
- image = Image.fromarray(image)
- img_new = image.copy()
- draw = ImageDraw.Draw(img_new)
-
- font = ImageFont.truetype(font_path, font_size, encoding="utf-8")
- for ocr_info in ocr_results:
- if ocr_info["pred_id"] not in color_map:
- continue
- color = color_map[ocr_info["pred_id"]]
- text = "{}: {}".format(ocr_info["pred"], ocr_info["text"])
-
- draw_box_txt(ocr_info["bbox"], text, draw, font, font_size, color)
-
- img_new = Image.blend(image, img_new, 0.5)
- return np.array(img_new)
-
-
-def draw_box_txt(bbox, text, draw, font, font_size, color):
- # draw ocr results outline
- bbox = ((bbox[0], bbox[1]), (bbox[2], bbox[3]))
- draw.rectangle(bbox, fill=color)
-
- # draw ocr results
- start_y = max(0, bbox[0][1] - font_size)
- tw = font.getsize(text)[0]
- draw.rectangle(
- [(bbox[0][0] + 1, start_y), (bbox[0][0] + tw + 1, start_y + font_size)],
- fill=(0, 0, 255))
- draw.text((bbox[0][0] + 1, start_y), text, fill=(255, 255, 255), font=font)
-
-
-def draw_re_results(image,
- result,
- font_path="../../doc/fonts/simfang.ttf",
- font_size=18):
- np.random.seed(0)
- if isinstance(image, np.ndarray):
- image = Image.fromarray(image)
- img_new = image.copy()
- draw = ImageDraw.Draw(img_new)
-
- font = ImageFont.truetype(font_path, font_size, encoding="utf-8")
- color_head = (0, 0, 255)
- color_tail = (255, 0, 0)
- color_line = (0, 255, 0)
-
- for ocr_info_head, ocr_info_tail in result:
- draw_box_txt(ocr_info_head["bbox"], ocr_info_head["text"], draw, font,
- font_size, color_head)
- draw_box_txt(ocr_info_tail["bbox"], ocr_info_tail["text"], draw, font,
- font_size, color_tail)
-
- center_head = (
- (ocr_info_head['bbox'][0] + ocr_info_head['bbox'][2]) // 2,
- (ocr_info_head['bbox'][1] + ocr_info_head['bbox'][3]) // 2)
- center_tail = (
- (ocr_info_tail['bbox'][0] + ocr_info_tail['bbox'][2]) // 2,
- (ocr_info_tail['bbox'][1] + ocr_info_tail['bbox'][3]) // 2)
-
- draw.line([center_head, center_tail], fill=color_line, width=5)
-
- img_new = Image.blend(image, img_new, 0.5)
- return np.array(img_new)
-
-
-# pad sentences
-def pad_sentences(tokenizer,
- encoded_inputs,
- max_seq_len=512,
- pad_to_max_seq_len=True,
- return_attention_mask=True,
- return_token_type_ids=True,
- return_overflowing_tokens=False,
- return_special_tokens_mask=False):
- # Padding with larger size, reshape is carried out
- max_seq_len = (
- len(encoded_inputs["input_ids"]) // max_seq_len + 1) * max_seq_len
-
- needs_to_be_padded = pad_to_max_seq_len and \
- max_seq_len and len(encoded_inputs["input_ids"]) < max_seq_len
-
- if needs_to_be_padded:
- difference = max_seq_len - len(encoded_inputs["input_ids"])
- if tokenizer.padding_side == 'right':
- if return_attention_mask:
- encoded_inputs["attention_mask"] = [1] * len(encoded_inputs[
- "input_ids"]) + [0] * difference
- if return_token_type_ids:
- encoded_inputs["token_type_ids"] = (
- encoded_inputs["token_type_ids"] +
- [tokenizer.pad_token_type_id] * difference)
- if return_special_tokens_mask:
- encoded_inputs["special_tokens_mask"] = encoded_inputs[
- "special_tokens_mask"] + [1] * difference
- encoded_inputs["input_ids"] = encoded_inputs[
- "input_ids"] + [tokenizer.pad_token_id] * difference
- encoded_inputs["bbox"] = encoded_inputs["bbox"] + [[0, 0, 0, 0]
- ] * difference
- else:
- if return_attention_mask:
- encoded_inputs["attention_mask"] = [1] * len(encoded_inputs[
- "input_ids"])
-
- return encoded_inputs
-
-
-def split_page(encoded_inputs, max_seq_len=512):
- """
- truncate is often used in training process
- """
- for key in encoded_inputs:
- if key == 'entities':
- encoded_inputs[key] = [encoded_inputs[key]]
- continue
- encoded_inputs[key] = paddle.to_tensor(encoded_inputs[key])
- if encoded_inputs[key].ndim <= 1: # for input_ids, att_mask and so on
- encoded_inputs[key] = encoded_inputs[key].reshape([-1, max_seq_len])
- else: # for bbox
- encoded_inputs[key] = encoded_inputs[key].reshape(
- [-1, max_seq_len, 4])
- return encoded_inputs
-
-
-def preprocess(
- tokenizer,
- ori_img,
- ocr_info,
- img_size=(224, 224),
- pad_token_label_id=-100,
- max_seq_len=512,
- add_special_ids=False,
- return_attention_mask=True, ):
- ocr_info = deepcopy(ocr_info)
- height = ori_img.shape[0]
- width = ori_img.shape[1]
-
- img = cv2.resize(ori_img, img_size).transpose([2, 0, 1]).astype(np.float32)
-
- segment_offset_id = []
- words_list = []
- bbox_list = []
- input_ids_list = []
- token_type_ids_list = []
- entities = []
-
- for info in ocr_info:
- # x1, y1, x2, y2
- bbox = info["bbox"]
- bbox[0] = int(bbox[0] * 1000.0 / width)
- bbox[2] = int(bbox[2] * 1000.0 / width)
- bbox[1] = int(bbox[1] * 1000.0 / height)
- bbox[3] = int(bbox[3] * 1000.0 / height)
-
- text = info["text"]
- encode_res = tokenizer.encode(
- text, pad_to_max_seq_len=False, return_attention_mask=True)
-
- if not add_special_ids:
- # TODO: use tok.all_special_ids to remove
- encode_res["input_ids"] = encode_res["input_ids"][1:-1]
- encode_res["token_type_ids"] = encode_res["token_type_ids"][1:-1]
- encode_res["attention_mask"] = encode_res["attention_mask"][1:-1]
-
- # for re
- entities.append({
- "start": len(input_ids_list),
- "end": len(input_ids_list) + len(encode_res["input_ids"]),
- "label": "O",
- })
-
- input_ids_list.extend(encode_res["input_ids"])
- token_type_ids_list.extend(encode_res["token_type_ids"])
- bbox_list.extend([bbox] * len(encode_res["input_ids"]))
- words_list.append(text)
- segment_offset_id.append(len(input_ids_list))
-
- encoded_inputs = {
- "input_ids": input_ids_list,
- "token_type_ids": token_type_ids_list,
- "bbox": bbox_list,
- "attention_mask": [1] * len(input_ids_list),
- "entities": entities
- }
-
- encoded_inputs = pad_sentences(
- tokenizer,
- encoded_inputs,
- max_seq_len=max_seq_len,
- return_attention_mask=return_attention_mask)
-
- encoded_inputs = split_page(encoded_inputs)
-
- fake_bs = encoded_inputs["input_ids"].shape[0]
-
- encoded_inputs["image"] = paddle.to_tensor(img).unsqueeze(0).expand(
- [fake_bs] + list(img.shape))
-
- encoded_inputs["segment_offset_id"] = segment_offset_id
-
- return encoded_inputs
-
-
-def postprocess(attention_mask, preds, id2label_map):
- if isinstance(preds, paddle.Tensor):
- preds = preds.numpy()
- preds = np.argmax(preds, axis=2)
-
- preds_list = [[] for _ in range(preds.shape[0])]
-
- # keep batch info
- for i in range(preds.shape[0]):
- for j in range(preds.shape[1]):
- if attention_mask[i][j] == 1:
- preds_list[i].append(id2label_map[preds[i][j]])
-
- return preds_list
-
-
-def merge_preds_list_with_ocr_info(ocr_info, segment_offset_id, preds_list,
- label2id_map_for_draw):
- # must ensure the preds_list is generated from the same image
- preds = [p for pred in preds_list for p in pred]
-
- id2label_map = dict()
- for key in label2id_map_for_draw:
- val = label2id_map_for_draw[key]
- if key == "O":
- id2label_map[val] = key
- if key.startswith("B-") or key.startswith("I-"):
- id2label_map[val] = key[2:]
- else:
- id2label_map[val] = key
-
- for idx in range(len(segment_offset_id)):
- if idx == 0:
- start_id = 0
- else:
- start_id = segment_offset_id[idx - 1]
-
- end_id = segment_offset_id[idx]
-
- curr_pred = preds[start_id:end_id]
- curr_pred = [label2id_map_for_draw[p] for p in curr_pred]
-
- if len(curr_pred) <= 0:
- pred_id = 0
- else:
- counts = np.bincount(curr_pred)
- pred_id = np.argmax(counts)
- ocr_info[idx]["pred_id"] = int(pred_id)
- ocr_info[idx]["pred"] = id2label_map[int(pred_id)]
- return ocr_info
-
-
-def print_arguments(args, logger=None):
- print_func = logger.info if logger is not None else print
- """print arguments"""
- print_func('----------- Configuration Arguments -----------')
- for arg, value in sorted(vars(args).items()):
- print_func('%s: %s' % (arg, value))
- print_func('------------------------------------------------')
-
-
-def parse_args():
- parser = argparse.ArgumentParser()
- # Required parameters
- # yapf: disable
- parser.add_argument("--model_name_or_path",
- default=None, type=str, required=True,)
- parser.add_argument("--ser_model_type",
- default='LayoutXLM', type=str)
- parser.add_argument("--re_model_name_or_path",
- default=None, type=str, required=False,)
- parser.add_argument("--train_data_dir", default=None,
- type=str, required=False,)
- parser.add_argument("--train_label_path", default=None,
- type=str, required=False,)
- parser.add_argument("--eval_data_dir", default=None,
- type=str, required=False,)
- parser.add_argument("--eval_label_path", default=None,
- type=str, required=False,)
- parser.add_argument("--output_dir", default=None, type=str, required=True,)
- parser.add_argument("--max_seq_length", default=512, type=int,)
- parser.add_argument("--evaluate_during_training", action="store_true",)
- parser.add_argument("--num_workers", default=8, type=int,)
- parser.add_argument("--per_gpu_train_batch_size", default=8,
- type=int, help="Batch size per GPU/CPU for training.",)
- parser.add_argument("--per_gpu_eval_batch_size", default=8,
- type=int, help="Batch size per GPU/CPU for eval.",)
- parser.add_argument("--learning_rate", default=5e-5,
- type=float, help="The initial learning rate for Adam.",)
- parser.add_argument("--weight_decay", default=0.0,
- type=float, help="Weight decay if we apply some.",)
- parser.add_argument("--adam_epsilon", default=1e-8,
- type=float, help="Epsilon for Adam optimizer.",)
- parser.add_argument("--max_grad_norm", default=1.0,
- type=float, help="Max gradient norm.",)
- parser.add_argument("--num_train_epochs", default=3, type=int,
- help="Total number of training epochs to perform.",)
- parser.add_argument("--warmup_steps", default=0, type=int,
- help="Linear warmup over warmup_steps.",)
- parser.add_argument("--eval_steps", type=int, default=10,
- help="eval every X updates steps.",)
- parser.add_argument("--seed", type=int, default=2048,
- help="random seed for initialization",)
-
- parser.add_argument("--rec_model_dir", default=None, type=str, )
- parser.add_argument("--det_model_dir", default=None, type=str, )
- parser.add_argument(
- "--label_map_path", default="./labels/labels_ser.txt", type=str, required=False, )
- parser.add_argument("--infer_imgs", default=None, type=str, required=False)
- parser.add_argument("--resume", action='store_true')
- parser.add_argument("--ocr_json_path", default=None,
- type=str, required=False, help="ocr prediction results")
- # yapf: enable
- args = parser.parse_args()
- return args
diff --git a/ppstructure/vqa/xfun.py b/ppstructure/vqa/xfun.py
deleted file mode 100644
index f5dbe507e8f6d22087d7913241f7365cbede9bdf..0000000000000000000000000000000000000000
--- a/ppstructure/vqa/xfun.py
+++ /dev/null
@@ -1,464 +0,0 @@
-# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-import json
-import os
-import cv2
-import numpy as np
-import paddle
-import copy
-from paddle.io import Dataset
-
-__all__ = ["XFUNDataset"]
-
-
-class XFUNDataset(Dataset):
- """
- Example:
- print("=====begin to build dataset=====")
- from paddlenlp.transformers import LayoutXLMTokenizer
- tokenizer = LayoutXLMTokenizer.from_pretrained("/paddle/models/transformers/layoutxlm-base-paddle/")
- tok_res = tokenizer.tokenize("Maribyrnong")
- # res = tokenizer.convert_ids_to_tokens(val_data["input_ids"][0])
- dataset = XfunDatasetForSer(
- tokenizer,
- data_dir="./zh.val/",
- label_path="zh.val/xfun_normalize_val.json",
- img_size=(224,224))
- print(len(dataset))
-
- data = dataset[0]
- print(data.keys())
- print("input_ids: ", data["input_ids"])
- print("labels: ", data["labels"])
- print("token_type_ids: ", data["token_type_ids"])
- print("words_list: ", data["words_list"])
- print("image shape: ", data["image"].shape)
- """
-
- def __init__(self,
- tokenizer,
- data_dir,
- label_path,
- contains_re=False,
- label2id_map=None,
- img_size=(224, 224),
- pad_token_label_id=None,
- add_special_ids=False,
- return_attention_mask=True,
- load_mode='all',
- max_seq_len=512):
- super().__init__()
- self.tokenizer = tokenizer
- self.data_dir = data_dir
- self.label_path = label_path
- self.contains_re = contains_re
- self.label2id_map = label2id_map
- self.img_size = img_size
- self.pad_token_label_id = pad_token_label_id
- self.add_special_ids = add_special_ids
- self.return_attention_mask = return_attention_mask
- self.load_mode = load_mode
- self.max_seq_len = max_seq_len
-
- if self.pad_token_label_id is None:
- self.pad_token_label_id = paddle.nn.CrossEntropyLoss().ignore_index
-
- self.all_lines = self.read_all_lines()
-
- self.entities_labels = {'HEADER': 0, 'QUESTION': 1, 'ANSWER': 2}
- self.return_keys = {
- 'bbox': {
- 'type': 'np',
- 'dtype': 'int64'
- },
- 'input_ids': {
- 'type': 'np',
- 'dtype': 'int64'
- },
- 'labels': {
- 'type': 'np',
- 'dtype': 'int64'
- },
- 'attention_mask': {
- 'type': 'np',
- 'dtype': 'int64'
- },
- 'image': {
- 'type': 'np',
- 'dtype': 'float32'
- },
- 'token_type_ids': {
- 'type': 'np',
- 'dtype': 'int64'
- },
- 'entities': {
- 'type': 'dict'
- },
- 'relations': {
- 'type': 'dict'
- }
- }
-
- if load_mode == "all":
- self.encoded_inputs_all = self._parse_label_file_all()
-
- def pad_sentences(self,
- encoded_inputs,
- max_seq_len=512,
- pad_to_max_seq_len=True,
- return_attention_mask=True,
- return_token_type_ids=True,
- truncation_strategy="longest_first",
- return_overflowing_tokens=False,
- return_special_tokens_mask=False):
- # Padding
- needs_to_be_padded = pad_to_max_seq_len and \
- max_seq_len and len(encoded_inputs["input_ids"]) < max_seq_len
-
- if needs_to_be_padded:
- difference = max_seq_len - len(encoded_inputs["input_ids"])
- if self.tokenizer.padding_side == 'right':
- if return_attention_mask:
- encoded_inputs["attention_mask"] = [1] * len(encoded_inputs[
- "input_ids"]) + [0] * difference
- if return_token_type_ids:
- encoded_inputs["token_type_ids"] = (
- encoded_inputs["token_type_ids"] +
- [self.tokenizer.pad_token_type_id] * difference)
- if return_special_tokens_mask:
- encoded_inputs["special_tokens_mask"] = encoded_inputs[
- "special_tokens_mask"] + [1] * difference
- encoded_inputs["input_ids"] = encoded_inputs[
- "input_ids"] + [self.tokenizer.pad_token_id] * difference
- encoded_inputs["labels"] = encoded_inputs[
- "labels"] + [self.pad_token_label_id] * difference
- encoded_inputs["bbox"] = encoded_inputs[
- "bbox"] + [[0, 0, 0, 0]] * difference
- elif self.tokenizer.padding_side == 'left':
- if return_attention_mask:
- encoded_inputs["attention_mask"] = [0] * difference + [
- 1
- ] * len(encoded_inputs["input_ids"])
- if return_token_type_ids:
- encoded_inputs["token_type_ids"] = (
- [self.tokenizer.pad_token_type_id] * difference +
- encoded_inputs["token_type_ids"])
- if return_special_tokens_mask:
- encoded_inputs["special_tokens_mask"] = [
- 1
- ] * difference + encoded_inputs["special_tokens_mask"]
- encoded_inputs["input_ids"] = [
- self.tokenizer.pad_token_id
- ] * difference + encoded_inputs["input_ids"]
- encoded_inputs["labels"] = [
- self.pad_token_label_id
- ] * difference + encoded_inputs["labels"]
- encoded_inputs["bbox"] = [
- [0, 0, 0, 0]
- ] * difference + encoded_inputs["bbox"]
- else:
- if return_attention_mask:
- encoded_inputs["attention_mask"] = [1] * len(encoded_inputs[
- "input_ids"])
-
- return encoded_inputs
-
- def truncate_inputs(self, encoded_inputs, max_seq_len=512):
- for key in encoded_inputs:
- if key == "sample_id":
- continue
- length = min(len(encoded_inputs[key]), max_seq_len)
- encoded_inputs[key] = encoded_inputs[key][:length]
- return encoded_inputs
-
- def read_all_lines(self, ):
- with open(self.label_path, "r", encoding='utf-8') as fin:
- lines = fin.readlines()
- return lines
-
- def _parse_label_file_all(self):
- """
- parse all samples
- """
- encoded_inputs_all = []
- for line in self.all_lines:
- encoded_inputs_all.extend(self._parse_label_file(line))
- return encoded_inputs_all
-
- def _parse_label_file(self, line):
- """
- parse single sample
- """
-
- image_name, info_str = line.split("\t")
- image_path = os.path.join(self.data_dir, image_name)
-
- def add_imgge_path(x):
- x['image_path'] = image_path
- return x
-
- encoded_inputs = self._read_encoded_inputs_sample(info_str)
- if self.contains_re:
- encoded_inputs = self._chunk_re(encoded_inputs)
- else:
- encoded_inputs = self._chunk_ser(encoded_inputs)
- encoded_inputs = list(map(add_imgge_path, encoded_inputs))
- return encoded_inputs
-
- def _read_encoded_inputs_sample(self, info_str):
- """
- parse label info
- """
- # read text info
- info_dict = json.loads(info_str)
- height = info_dict["height"]
- width = info_dict["width"]
-
- words_list = []
- bbox_list = []
- input_ids_list = []
- token_type_ids_list = []
- gt_label_list = []
-
- if self.contains_re:
- # for re
- entities = []
- relations = []
- id2label = {}
- entity_id_to_index_map = {}
- empty_entity = set()
- for info in info_dict["ocr_info"]:
- if self.contains_re:
- # for re
- if len(info["text"]) == 0:
- empty_entity.add(info["id"])
- continue
- id2label[info["id"]] = info["label"]
- relations.extend([tuple(sorted(l)) for l in info["linking"]])
-
- # x1, y1, x2, y2
- bbox = info["bbox"]
- label = info["label"]
- bbox[0] = int(bbox[0] * 1000.0 / width)
- bbox[2] = int(bbox[2] * 1000.0 / width)
- bbox[1] = int(bbox[1] * 1000.0 / height)
- bbox[3] = int(bbox[3] * 1000.0 / height)
-
- text = info["text"]
- encode_res = self.tokenizer.encode(
- text, pad_to_max_seq_len=False, return_attention_mask=True)
-
- gt_label = []
- if not self.add_special_ids:
- # TODO: use tok.all_special_ids to remove
- encode_res["input_ids"] = encode_res["input_ids"][1:-1]
- encode_res["token_type_ids"] = encode_res["token_type_ids"][1:
- -1]
- encode_res["attention_mask"] = encode_res["attention_mask"][1:
- -1]
- if label.lower() == "other":
- gt_label.extend([0] * len(encode_res["input_ids"]))
- else:
- gt_label.append(self.label2id_map[("b-" + label).upper()])
- gt_label.extend([self.label2id_map[("i-" + label).upper()]] *
- (len(encode_res["input_ids"]) - 1))
- if self.contains_re:
- if gt_label[0] != self.label2id_map["O"]:
- entity_id_to_index_map[info["id"]] = len(entities)
- entities.append({
- "start": len(input_ids_list),
- "end":
- len(input_ids_list) + len(encode_res["input_ids"]),
- "label": label.upper(),
- })
- input_ids_list.extend(encode_res["input_ids"])
- token_type_ids_list.extend(encode_res["token_type_ids"])
- bbox_list.extend([bbox] * len(encode_res["input_ids"]))
- gt_label_list.extend(gt_label)
- words_list.append(text)
-
- encoded_inputs = {
- "input_ids": input_ids_list,
- "labels": gt_label_list,
- "token_type_ids": token_type_ids_list,
- "bbox": bbox_list,
- "attention_mask": [1] * len(input_ids_list),
- # "words_list": words_list,
- }
- encoded_inputs = self.pad_sentences(
- encoded_inputs,
- max_seq_len=self.max_seq_len,
- return_attention_mask=self.return_attention_mask)
- encoded_inputs = self.truncate_inputs(encoded_inputs)
-
- if self.contains_re:
- relations = self._relations(entities, relations, id2label,
- empty_entity, entity_id_to_index_map)
- encoded_inputs['relations'] = relations
- encoded_inputs['entities'] = entities
- return encoded_inputs
-
- def _chunk_ser(self, encoded_inputs):
- encoded_inputs_all = []
- seq_len = len(encoded_inputs['input_ids'])
- chunk_size = 512
- for chunk_id, index in enumerate(range(0, seq_len, chunk_size)):
- chunk_beg = index
- chunk_end = min(index + chunk_size, seq_len)
- encoded_inputs_example = {}
- for key in encoded_inputs:
- encoded_inputs_example[key] = encoded_inputs[key][chunk_beg:
- chunk_end]
-
- encoded_inputs_all.append(encoded_inputs_example)
- return encoded_inputs_all
-
- def _chunk_re(self, encoded_inputs):
- # prepare data
- entities = encoded_inputs.pop('entities')
- relations = encoded_inputs.pop('relations')
- encoded_inputs_all = []
- chunk_size = 512
- for chunk_id, index in enumerate(
- range(0, len(encoded_inputs["input_ids"]), chunk_size)):
- item = {}
- for k in encoded_inputs:
- item[k] = encoded_inputs[k][index:index + chunk_size]
-
- # select entity in current chunk
- entities_in_this_span = []
- global_to_local_map = {} #
- for entity_id, entity in enumerate(entities):
- if (index <= entity["start"] < index + chunk_size and
- index <= entity["end"] < index + chunk_size):
- entity["start"] = entity["start"] - index
- entity["end"] = entity["end"] - index
- global_to_local_map[entity_id] = len(entities_in_this_span)
- entities_in_this_span.append(entity)
-
- # select relations in current chunk
- relations_in_this_span = []
- for relation in relations:
- if (index <= relation["start_index"] < index + chunk_size and
- index <= relation["end_index"] < index + chunk_size):
- relations_in_this_span.append({
- "head": global_to_local_map[relation["head"]],
- "tail": global_to_local_map[relation["tail"]],
- "start_index": relation["start_index"] - index,
- "end_index": relation["end_index"] - index,
- })
- item.update({
- "entities": reformat(entities_in_this_span),
- "relations": reformat(relations_in_this_span),
- })
- item['entities']['label'] = [
- self.entities_labels[x] for x in item['entities']['label']
- ]
- encoded_inputs_all.append(item)
- return encoded_inputs_all
-
- def _relations(self, entities, relations, id2label, empty_entity,
- entity_id_to_index_map):
- """
- build relations
- """
- relations = list(set(relations))
- relations = [
- rel for rel in relations
- if rel[0] not in empty_entity and rel[1] not in empty_entity
- ]
- kv_relations = []
- for rel in relations:
- pair = [id2label[rel[0]], id2label[rel[1]]]
- if pair == ["question", "answer"]:
- kv_relations.append({
- "head": entity_id_to_index_map[rel[0]],
- "tail": entity_id_to_index_map[rel[1]]
- })
- elif pair == ["answer", "question"]:
- kv_relations.append({
- "head": entity_id_to_index_map[rel[1]],
- "tail": entity_id_to_index_map[rel[0]]
- })
- else:
- continue
- relations = sorted(
- [{
- "head": rel["head"],
- "tail": rel["tail"],
- "start_index": get_relation_span(rel, entities)[0],
- "end_index": get_relation_span(rel, entities)[1],
- } for rel in kv_relations],
- key=lambda x: x["head"], )
- return relations
-
- def load_img(self, image_path):
- # read img
- img = cv2.imread(image_path)
- img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
- resize_h, resize_w = self.img_size
- im_shape = img.shape[0:2]
- im_scale_y = resize_h / im_shape[0]
- im_scale_x = resize_w / im_shape[1]
- img_new = cv2.resize(
- img, None, None, fx=im_scale_x, fy=im_scale_y, interpolation=2)
- mean = np.array([0.485, 0.456, 0.406])[np.newaxis, np.newaxis, :]
- std = np.array([0.229, 0.224, 0.225])[np.newaxis, np.newaxis, :]
- img_new = img_new / 255.0
- img_new -= mean
- img_new /= std
- img = img_new.transpose((2, 0, 1))
- return img
-
- def __getitem__(self, idx):
- if self.load_mode == "all":
- data = copy.deepcopy(self.encoded_inputs_all[idx])
- else:
- data = self._parse_label_file(self.all_lines[idx])[0]
-
- image_path = data.pop('image_path')
- data["image"] = self.load_img(image_path)
-
- return_data = {}
- for k, v in data.items():
- if k in self.return_keys:
- if self.return_keys[k]['type'] == 'np':
- v = np.array(v, dtype=self.return_keys[k]['dtype'])
- return_data[k] = v
- return return_data
-
- def __len__(self, ):
- if self.load_mode == "all":
- return len(self.encoded_inputs_all)
- else:
- return len(self.all_lines)
-
-
-def get_relation_span(rel, entities):
- bound = []
- for entity_index in [rel["head"], rel["tail"]]:
- bound.append(entities[entity_index]["start"])
- bound.append(entities[entity_index]["end"])
- return min(bound), max(bound)
-
-
-def reformat(data):
- new_data = {}
- for item in data:
- for k, v in item.items():
- if k not in new_data:
- new_data[k] = []
- new_data[k].append(v)
- return new_data
diff --git a/requirements.txt b/requirements.txt
index 9900588b25df99e0853ec4521f0632578c55f530..1d9522aa0167c60ffce263a35b86640efb1438b2 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -13,4 +13,3 @@ lxml
premailer
openpyxl
fasttext==0.9.1
-
diff --git a/test_tipc/prepare.sh b/test_tipc/prepare.sh
index 5c67642ef5a29307d112009e18ee4277df216fb0..570c6832e7a6682f634d7ab7538a228256446372 100644
--- a/test_tipc/prepare.sh
+++ b/test_tipc/prepare.sh
@@ -239,8 +239,7 @@ fi
if [ ${MODE} = "klquant_whole_infer" ]; then
wget -nc -P ./train_data/ https://paddleocr.bj.bcebos.com/dygraph_v2.0/test/icdar2015_lite.tar --no-check-certificate
- cd ./train_data/ && tar xf icdar2015_lite.tar
- ln -s ./icdar2015_lite ./icdar2015 && cd ../
+ cd ./train_data/ && tar xf icdar2015_lite.tar && rm -rf ./icdar2015 && ln -s ./icdar2015_lite ./icdar2015 && cd ../
if [ ${model_name} = "ch_ppocr_mobile_v2.0_det_KL" ]; then
wget -nc -P ./inference https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_mobile_v2.0_det_infer.tar --no-check-certificate
wget -nc -P ./inference https://paddleocr.bj.bcebos.com/dygraph_v2.0/test/ch_det_data_50.tar --no-check-certificate
@@ -249,6 +248,8 @@ if [ ${MODE} = "klquant_whole_infer" ]; then
if [ ${model_name} = "PPOCRv2_ocr_rec_kl" ]; then
wget -nc -P ./inference https://paddleocr.bj.bcebos.com/PP-OCRv2/chinese/ch_PP-OCRv2_rec_infer.tar --no-check-certificate
wget -nc -P ./inference/ https://paddleocr.bj.bcebos.com/dygraph_v2.0/test/rec_inference.tar --no-check-certificate
+ wget -nc -P ./train_data/ https://paddleocr.bj.bcebos.com/dygraph_v2.0/test/ic15_data.tar --no-check-certificate
+ cd ./train_data/ && tar xf ic15_data.tar && cd ../
cd ./inference && tar xf rec_inference.tar && tar xf ch_PP-OCRv2_rec_infer.tar && cd ../
fi
if [ ${model_name} = "PPOCRv2_ocr_det_kl" ]; then
diff --git a/test_tipc/readme.md b/test_tipc/readme.md
index 8b2489f3445ddfa87c1e587d6da81992fdb90e64..7b7548cd7296760d4caec0ed741c47137d86ece1 100644
--- a/test_tipc/readme.md
+++ b/test_tipc/readme.md
@@ -68,14 +68,14 @@ test_tipc/
├── model_linux_gpu_normal_normal_infer_cpp_linux_gpu_cpu.txt # 测试Linux上c++预测的配置文件
├── model_linux_gpu_normal_normal_infer_python_jetson.txt # 测试Jetson上python预测的配置文件
├── train_linux_gpu_fleet_amp_infer_python_linux_gpu_cpu.txt # 测试Linux上多机多卡、混合精度训练和python预测的配置文件
- ├── ...
+ ├── ...
├── ch_ppocr_server_v2.0_det # ch_ppocr_server_v2.0_det模型的测试配置文件目录
- ├── ...
+ ├── ...
├── ch_ppocr_mobile_v2.0_rec # ch_ppocr_mobile_v2.0_rec模型的测试配置文件目录
- ├── ...
+ ├── ...
├── ch_ppocr_server_v2.0_det # ch_ppocr_server_v2.0_det模型的测试配置文件目录
- ├── ...
- ├── ...
+ ├── ...
+ ├── ...
├── results/ # 预先保存的预测结果,用于和实际预测结果进行精读比对
├── python_ppocr_det_mobile_results_fp32.txt # 预存的mobile版ppocr检测模型python预测fp32精度的结果
├── python_ppocr_det_mobile_results_fp16.txt # 预存的mobile版ppocr检测模型python预测fp16精度的结果
@@ -119,7 +119,7 @@ bash test_tipc/test_train_inference_python.sh configs/[model_name]/[params_file_
bash test_tipc/prepare.sh ./test_tipc/configs/ch_ppocr_mobile_v2.0_det/train_infer_python.txt 'lite_train_lite_infer'
# 运行测试
bash test_tipc/test_train_inference_python.sh ./test_tipc/configs/ch_ppocr_mobile_v2.0_det/train_infer_python.txt 'lite_train_lite_infer'
-```
+```
关于本示例命令的更多信息可查看[基础训练预测使用文档](https://github.com/PaddlePaddle/PaddleOCR/blob/dygraph/test_tipc/docs/test_train_inference_python.md#22-%E5%8A%9F%E8%83%BD%E6%B5%8B%E8%AF%95)。
### 配置文件命名规范
@@ -136,9 +136,9 @@ bash test_tipc/test_train_inference_python.sh ./test_tipc/configs/ch_ppocr_mobil
## 4. 开始测试
-各功能测试中涉及混合精度、裁剪、量化等训练相关,及mkldnn、Tensorrt等多种预测相关参数配置,请点击下方相应链接了解更多细节和使用教程:
-- [test_train_inference_python 使用](docs/test_train_inference_python.md) :测试基于Python的模型训练、评估、推理等基本功能,包括裁剪、量化、蒸馏。
+各功能测试中涉及混合精度、裁剪、量化等训练相关,及mkldnn、Tensorrt等多种预测相关参数配置,请点击下方相应链接了解更多细节和使用教程:
+- [test_train_inference_python 使用](docs/test_train_inference_python.md) :测试基于Python的模型训练、评估、推理等基本功能,包括裁剪、量化、蒸馏。
- [test_inference_cpp 使用](docs/test_inference_cpp.md):测试基于C++的模型推理。
- [test_serving 使用](docs/test_serving.md):测试基于Paddle Serving的服务化部署功能。
-- [test_lite_arm_cpu_cpp 使用](docs/test_lite_arm_cpu_cpp.md):测试基于Paddle-Lite的ARM CPU端c++预测部署功能。
+- [test_lite_arm_cpp 使用](docs/test_lite_arm_cpp.md):测试基于Paddle-Lite的ARM CPU端c++预测部署功能。
- [test_paddle2onnx 使用](docs/test_paddle2onnx.md):测试Paddle2ONNX的模型转化功能,并验证正确性。
diff --git a/test_tipc/supplementary/__init__.py b/test_tipc/supplementary/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..8b137891791fe96927ad78e64b0aad7bded08bdc
--- /dev/null
+++ b/test_tipc/supplementary/__init__.py
@@ -0,0 +1 @@
+
diff --git a/test_tipc/supplementary/config.py b/test_tipc/supplementary/config.py
new file mode 100644
index 0000000000000000000000000000000000000000..d0dce227ef1f1a57780b36cb7f9f60acfe6afc36
--- /dev/null
+++ b/test_tipc/supplementary/config.py
@@ -0,0 +1,137 @@
+import numpy as np
+import os
+import sys
+import platform
+import yaml
+import time
+import shutil
+import paddle
+import paddle.distributed as dist
+from tqdm import tqdm
+from argparse import ArgumentParser, RawDescriptionHelpFormatter
+from utils import get_logger, print_dict
+
+
+class ArgsParser(ArgumentParser):
+ def __init__(self):
+ super(ArgsParser, self).__init__(
+ formatter_class=RawDescriptionHelpFormatter)
+ self.add_argument("-c", "--config", help="configuration file to use")
+ self.add_argument(
+ "-o", "--opt", nargs='+', help="set configuration options")
+ self.add_argument(
+ '-p',
+ '--profiler_options',
+ type=str,
+ default=None,
+ help='The option of profiler, which should be in format \"key1=value1;key2=value2;key3=value3\".'
+ )
+
+ def parse_args(self, argv=None):
+ args = super(ArgsParser, self).parse_args(argv)
+ assert args.config is not None, \
+ "Please specify --config=configure_file_path."
+ args.opt = self._parse_opt(args.opt)
+ return args
+
+ def _parse_opt(self, opts):
+ config = {}
+ if not opts:
+ return config
+ for s in opts:
+ s = s.strip()
+ k, v = s.split('=')
+ config[k] = yaml.load(v, Loader=yaml.Loader)
+ return config
+
+
+class AttrDict(dict):
+ """Single level attribute dict, NOT recursive"""
+
+ def __init__(self, **kwargs):
+ super(AttrDict, self).__init__()
+ super(AttrDict, self).update(kwargs)
+
+ def __getattr__(self, key):
+ if key in self:
+ return self[key]
+ raise AttributeError("object has no attribute '{}'".format(key))
+
+
+global_config = AttrDict()
+
+default_config = {'Global': {'debug': False, }}
+
+
+def load_config(file_path):
+ """
+ Load config from yml/yaml file.
+ Args:
+ file_path (str): Path of the config file to be loaded.
+ Returns: global config
+ """
+ merge_config(default_config)
+ _, ext = os.path.splitext(file_path)
+ assert ext in ['.yml', '.yaml'], "only support yaml files for now"
+ merge_config(yaml.load(open(file_path, 'rb'), Loader=yaml.Loader))
+ return global_config
+
+
+def merge_config(config):
+ """
+ Merge config into global config.
+ Args:
+ config (dict): Config to be merged.
+ Returns: global config
+ """
+ for key, value in config.items():
+ if "." not in key:
+ if isinstance(value, dict) and key in global_config:
+ global_config[key].update(value)
+ else:
+ global_config[key] = value
+ else:
+ sub_keys = key.split('.')
+ assert (
+ sub_keys[0] in global_config
+ ), "the sub_keys can only be one of global_config: {}, but got: {}, please check your running command".format(
+ global_config.keys(), sub_keys[0])
+ cur = global_config[sub_keys[0]]
+ for idx, sub_key in enumerate(sub_keys[1:]):
+ if idx == len(sub_keys) - 2:
+ cur[sub_key] = value
+ else:
+ cur = cur[sub_key]
+
+
+def preprocess(is_train=False):
+ FLAGS = ArgsParser().parse_args()
+ profiler_options = FLAGS.profiler_options
+ config = load_config(FLAGS.config)
+ merge_config(FLAGS.opt)
+ profile_dic = {"profiler_options": FLAGS.profiler_options}
+ merge_config(profile_dic)
+
+ if is_train:
+ # save_config
+ save_model_dir = config['save_model_dir']
+ os.makedirs(save_model_dir, exist_ok=True)
+ with open(os.path.join(save_model_dir, 'config.yml'), 'w') as f:
+ yaml.dump(
+ dict(config), f, default_flow_style=False, sort_keys=False)
+ log_file = '{}/train.log'.format(save_model_dir)
+ else:
+ log_file = None
+ logger = get_logger(name='root', log_file=log_file)
+
+ # check if set use_gpu=True in paddlepaddle cpu version
+ use_gpu = config['use_gpu']
+
+ print_dict(config, logger)
+
+ return config, logger
+
+
+if __name__ == "__main__":
+ config, logger = preprocess(is_train=False)
+ # print(config)
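+
+ # Usage sketch (the YAML name below is only an example, not a file shipped
+ # with this script): the helpers are driven from the command line, e.g.
+ #   python config.py -c mv3.yml -o epoch_num=2 use_gpu=False
+ # Dotted -o keys such as Global.debug=True are routed by merge_config()
+ # into the corresponding nested section of the loaded config.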
diff --git a/test_tipc/supplementary/custom_op/custom_relu_op.cc b/test_tipc/supplementary/custom_op/custom_relu_op.cc
new file mode 100644
index 0000000000000000000000000000000000000000..97002a9118e867588065bf28c5695f53b1d42694
--- /dev/null
+++ b/test_tipc/supplementary/custom_op/custom_relu_op.cc
@@ -0,0 +1,109 @@
+// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+
+// reference from : https://github.com/PaddlePaddle/Paddle-Inference-Demo/blob/master/python/custom-operator/custom_relu_op.cc
+#include <iostream>
+#include <vector>
+
+#include "paddle/extension.h"
+
+template <typename data_t>
+void relu_cpu_forward_kernel(const data_t* x_data,
+ data_t* out_data,
+ int64_t x_numel) {
+ for (int i = 0; i < x_numel; ++i) {
+ out_data[i] = std::max(static_cast<data_t>(0.), x_data[i]);
+ }
+}
+
+template <typename data_t>
+void relu_cpu_backward_kernel(const data_t* grad_out_data,
+ const data_t* out_data,
+ data_t* grad_x_data,
+ int64_t out_numel) {
+ for (int i = 0; i < out_numel; ++i) {
+ grad_x_data[i] =
+ grad_out_data[i] * (out_data[i] > static_cast<data_t>(0) ? 1. : 0.);
+ }
+}
+
+std::vector<paddle::Tensor> relu_cpu_forward(const paddle::Tensor& x) {
+ auto out = paddle::Tensor(paddle::PlaceType::kCPU);
+
+ out.reshape(x.shape());
+ PD_DISPATCH_FLOATING_TYPES(
+ x.type(), "relu_cpu_forward", ([&] {
+ relu_cpu_forward_kernel<data_t>(
+ x.data<data_t>(), out.mutable_data<data_t>(x.place()), x.size());
+ }));
+
+ return {out};
+}
+
+std::vector<paddle::Tensor> relu_cpu_backward(const paddle::Tensor& x,
+ const paddle::Tensor& out,
+ const paddle::Tensor& grad_out) {
+ auto grad_x = paddle::Tensor(paddle::PlaceType::kCPU);
+ grad_x.reshape(x.shape());
+
+ PD_DISPATCH_FLOATING_TYPES(out.type(), "relu_cpu_backward", ([&] {
+ relu_cpu_backward_kernel<data_t>(
+ grad_out.data<data_t>(),
+ out.data<data_t>(),
+ grad_x.mutable_data<data_t>(x.place()),
+ out.size());
+ }));
+
+ return {grad_x};
+}
+
+std::vector<paddle::Tensor> relu_cuda_forward(const paddle::Tensor& x);
+std::vector<paddle::Tensor> relu_cuda_backward(const paddle::Tensor& x,
+ const paddle::Tensor& out,
+ const paddle::Tensor& grad_out);
+
+std::vector<paddle::Tensor> ReluForward(const paddle::Tensor& x) {
+ // TODO(chenweihang): Check Input
+ if (x.place() == paddle::PlaceType::kCPU) {
+ return relu_cpu_forward(x);
+ } else if (x.place() == paddle::PlaceType::kGPU) {
+ return relu_cuda_forward(x);
+ } else {
+ throw std::runtime_error("Not implemented.");
+ }
+}
+
+std::vector<paddle::Tensor> ReluBackward(const paddle::Tensor& x,
+ const paddle::Tensor& out,
+ const paddle::Tensor& grad_out) {
+ // TODO(chenweihang): Check Input
+ if (x.place() == paddle::PlaceType::kCPU) {
+ return relu_cpu_backward(x, out, grad_out);
+ } else if (x.place() == paddle::PlaceType::kGPU) {
+ return relu_cuda_backward(x, out, grad_out);
+ } else {
+ throw std::runtime_error("Not implemented.");
+ }
+}
+
+PD_BUILD_OP(custom_relu)
+ .Inputs({"X"})
+ .Outputs({"Out"})
+ .SetKernelFn(PD_KERNEL(ReluForward));
+
+PD_BUILD_GRAD_OP(custom_relu)
+ .Inputs({"X", "Out", paddle::Grad("Out")})
+ .Outputs({paddle::Grad("X")})
+ .SetKernelFn(PD_KERNEL(ReluBackward));
\ No newline at end of file
diff --git a/test_tipc/supplementary/custom_op/custom_relu_op.cu b/test_tipc/supplementary/custom_op/custom_relu_op.cu
new file mode 100644
index 0000000000000000000000000000000000000000..9b953a33cc73bd9f9ff5086551e4243f580f084c
--- /dev/null
+++ b/test_tipc/supplementary/custom_op/custom_relu_op.cu
@@ -0,0 +1,76 @@
+// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+
+// reference https://github.com/PaddlePaddle/Paddle-Inference-Demo/blob/master/python/custom-operator/custom_relu_op.cu
+
+#include "paddle/extension.h"
+
+template <typename data_t>
+__global__ void relu_cuda_forward_kernel(const data_t* x,
+ data_t* y,
+ const int num) {
+ int gid = blockIdx.x * blockDim.x + threadIdx.x;
+ for (int i = gid; i < num; i += blockDim.x * gridDim.x) {
+ y[i] = max(x[i], static_cast<data_t>(0.));
+ }
+}
+
+template <typename data_t>
+__global__ void relu_cuda_backward_kernel(const data_t* dy,
+ const data_t* y,
+ data_t* dx,
+ const int num) {
+ int gid = blockIdx.x * blockDim.x + threadIdx.x;
+ for (int i = gid; i < num; i += blockDim.x * gridDim.x) {
+ dx[i] = dy[i] * (y[i] > 0 ? 1. : 0.);
+ }
+}
+
+std::vector<paddle::Tensor> relu_cuda_forward(const paddle::Tensor& x) {
+ auto out = paddle::Tensor(paddle::PlaceType::kGPU);
+
+ out.reshape(x.shape());
+ int numel = x.size();
+ int block = 512;
+ int grid = (numel + block - 1) / block;
+ PD_DISPATCH_FLOATING_TYPES(
+ x.type(), "relu_cuda_forward_kernel", ([&] {
+ relu_cuda_forward_kernel<data_t><<<grid, block, 0, x.stream()>>>(
+ x.data<data_t>(), out.mutable_data<data_t>(x.place()), numel);
+ }));
+
+ return {out};
+}
+
+std::vector<paddle::Tensor> relu_cuda_backward(const paddle::Tensor& x,
+ const paddle::Tensor& out,
+ const paddle::Tensor& grad_out) {
+ auto grad_x = paddle::Tensor(paddle::PlaceType::kGPU);
+ grad_x.reshape(x.shape());
+
+ int numel = out.size();
+ int block = 512;
+ int grid = (numel + block - 1) / block;
+ PD_DISPATCH_FLOATING_TYPES(
+ out.type(), "relu_cuda_backward_kernel", ([&] {
+ relu_cuda_backward_kernel<data_t><<<grid, block, 0, x.stream()>>>(
+ grad_out.data<data_t>(),
+ out.data<data_t>(),
+ grad_x.mutable_data<data_t>(x.place()),
+ numel);
+ }));
+
+ return {grad_x};
+}
diff --git a/test_tipc/supplementary/custom_op/test.py b/test_tipc/supplementary/custom_op/test.py
new file mode 100644
index 0000000000000000000000000000000000000000..8b7f303dd65d52c2e4332fdd2c77dd15a057101b
--- /dev/null
+++ b/test_tipc/supplementary/custom_op/test.py
@@ -0,0 +1,76 @@
+import paddle
+import paddle.nn as nn
+from paddle.vision.transforms import Compose, Normalize
+from paddle.utils.cpp_extension import load
+from paddle.inference import Config
+from paddle.inference import create_predictor
+import numpy as np
+
+EPOCH_NUM = 4
+BATCH_SIZE = 64
+
+# jit compile custom op
+custom_ops = load(
+ name="custom_jit_ops", sources=["custom_relu_op.cc", "custom_relu_op.cu"])
+
+
+class LeNet(nn.Layer):
+ def __init__(self):
+ super(LeNet, self).__init__()
+ self.conv1 = nn.Conv2D(
+ in_channels=1, out_channels=6, kernel_size=5, stride=1, padding=2)
+ self.max_pool1 = nn.MaxPool2D(kernel_size=2, stride=2)
+ self.conv2 = nn.Conv2D(
+ in_channels=6, out_channels=16, kernel_size=5, stride=1)
+ self.max_pool2 = nn.MaxPool2D(kernel_size=2, stride=2)
+ self.linear1 = nn.Linear(in_features=16 * 5 * 5, out_features=120)
+ self.linear2 = nn.Linear(in_features=120, out_features=84)
+ self.linear3 = nn.Linear(in_features=84, out_features=10)
+
+ def forward(self, x):
+ x = self.conv1(x)
+ x = custom_ops.custom_relu(x)
+ x = self.max_pool1(x)
+ x = custom_ops.custom_relu(x)
+ x = self.conv2(x)
+ x = self.max_pool2(x)
+ x = paddle.flatten(x, start_axis=1, stop_axis=-1)
+ x = self.linear1(x)
+ x = custom_ops.custom_relu(x)
+ x = self.linear2(x)
+ x = custom_ops.custom_relu(x)
+ x = self.linear3(x)
+ return x
+
+
+# set device
+paddle.set_device("gpu")
+
+# model
+net = LeNet()
+loss_fn = nn.CrossEntropyLoss()
+opt = paddle.optimizer.Adam(learning_rate=0.001, parameters=net.parameters())
+
+# data loader
+transform = Compose([Normalize(mean=[127.5], std=[127.5], data_format='CHW')])
+train_dataset = paddle.vision.datasets.MNIST(mode='train', transform=transform)
+train_loader = paddle.io.DataLoader(
+ train_dataset,
+ batch_size=BATCH_SIZE,
+ shuffle=True,
+ drop_last=True,
+ num_workers=2)
+
+# train
+for epoch_id in range(EPOCH_NUM):
+ for batch_id, (image, label) in enumerate(train_loader()):
+ out = net(image)
+ loss = loss_fn(out, label)
+ loss.backward()
+
+ if batch_id % 300 == 0:
+ print("Epoch {} batch {}: loss = {}".format(epoch_id, batch_id,
+ np.mean(loss.numpy())))
+
+ opt.step()
+ opt.clear_grad()
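+
+# Sanity-check sketch: the jit-compiled custom ReLU is expected to agree with
+# paddle's built-in ReLU on random input (x_check is an arbitrary test tensor).
+x_check = paddle.randn([4, 16], dtype="float32")
+max_diff = np.abs(custom_ops.custom_relu(x_check).numpy() -
+                  nn.functional.relu(x_check).numpy()).max()
+print("max abs diff between custom_relu and built-in relu:", max_diff)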
diff --git a/test_tipc/supplementary/data.py b/test_tipc/supplementary/data.py
new file mode 100644
index 0000000000000000000000000000000000000000..2770a9a42c745d52e8310abd61356bf92e50f436
--- /dev/null
+++ b/test_tipc/supplementary/data.py
@@ -0,0 +1,140 @@
+import numpy as np
+import paddle
+import os
+import cv2
+import glob
+import six
+from paddle.io import Dataset
+
+
+def transform(data, ops=None):
+ """ transform """
+ if ops is None:
+ ops = []
+ for op in ops:
+ data = op(data)
+ if data is None:
+ return None
+ return data
+
+
+def create_operators(op_param_list, global_config=None):
+ """
+ create operators based on the config
+ Args:
+ params(list): a dict list, used to create some operators
+ """
+ assert isinstance(op_param_list, list), ('operator config should be a list')
+ ops = []
+ for operator in op_param_list:
+ assert isinstance(operator,
+ dict) and len(operator) == 1, "yaml format error"
+ op_name = list(operator)[0]
+ param = {} if operator[op_name] is None else operator[op_name]
+ if global_config is not None:
+ param.update(global_config)
+ op = eval(op_name)(**param)
+ ops.append(op)
+ return ops
+
+
+class DecodeImage(object):
+ """ decode image """
+
+ def __init__(self, img_mode='RGB', channel_first=False, **kwargs):
+ self.img_mode = img_mode
+ self.channel_first = channel_first
+
+ def __call__(self, data):
+ img = data['image']
+ if six.PY2:
+ assert type(img) is str and len(
+ img) > 0, "invalid input 'img' in DecodeImage"
+ else:
+ assert type(img) is bytes and len(
+ img) > 0, "invalid input 'img' in DecodeImage"
+ img = np.frombuffer(img, dtype='uint8')
+ img = cv2.imdecode(img, 1)
+ if img is None:
+ return None
+ if self.img_mode == 'GRAY':
+ img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
+ elif self.img_mode == 'RGB':
+ assert img.shape[2] == 3, 'invalid shape of image[%s]' % (img.shape)
+ img = img[:, :, ::-1]
+
+ if self.channel_first:
+ img = img.transpose((2, 0, 1))
+
+ data['image'] = img
+ data['src_image'] = img
+ return data
+
+
+class NormalizeImage(object):
+ """ normalize image such as subtract mean, divide std
+ """
+
+ def __init__(self, scale=None, mean=None, std=None, order='chw', **kwargs):
+ if isinstance(scale, str):
+ scale = eval(scale)
+ self.scale = np.float32(scale if scale is not None else 1.0 / 255.0)
+ mean = mean if mean is not None else [0.485, 0.456, 0.406]
+ std = std if std is not None else [0.229, 0.224, 0.225]
+
+ shape = (3, 1, 1) if order == 'chw' else (1, 1, 3)
+ self.mean = np.array(mean).reshape(shape).astype('float32')
+ self.std = np.array(std).reshape(shape).astype('float32')
+
+ def __call__(self, data):
+ img = data['image']
+ from PIL import Image
+ if isinstance(img, Image.Image):
+ img = np.array(img)
+ assert isinstance(img,
+ np.ndarray), "invalid input 'img' in NormalizeImage"
+ data['image'] = (
+ img.astype('float32') * self.scale - self.mean) / self.std
+ return data
+
+
+class ToCHWImage(object):
+ """ convert hwc image to chw image
+ """
+
+ def __init__(self, **kwargs):
+ pass
+
+ def __call__(self, data):
+ img = data['image']
+ from PIL import Image
+ if isinstance(img, Image.Image):
+ img = np.array(img)
+ data['image'] = img.transpose((2, 0, 1))
+
+ src_img = data['src_image']
+ if isinstance(src_img, Image.Image):
+ src_img = np.array(src_img)
+ data['src_image'] = src_img.transpose((2, 0, 1))
+
+ return data
+
+
+class SimpleDataset(Dataset):
+ def __init__(self, config, mode, logger, seed=None):
+ self.logger = logger
+ self.mode = mode.lower()
+
+ data_dir = config['Train']['data_dir']
+
+ self.img_list = self.get_image_list(data_dir)
+
+ self.ops = create_operators(config['transforms'], None)
+
+ def get_image_list(self, img_dir):
+ imgs = glob.glob(os.path.join(img_dir, "*.png"))
+ if len(imgs) == 0:
+ raise ValueError(f"no images found in {img_dir}")
+ return imgs
+
+ def __getitem__(self, idx):
+ return None
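+
+
+if __name__ == "__main__":
+    # Minimal sketch of how the operators above are chained; the random PNG
+    # bytes below only stand in for a real training image.
+    ops = create_operators([
+        {'DecodeImage': {'img_mode': 'RGB'}},
+        {'NormalizeImage': {'order': 'hwc'}},
+        {'ToCHWImage': None},
+    ])
+    fake = (np.random.rand(32, 32, 3) * 255).astype('uint8')
+    raw = cv2.imencode('.png', fake)[1].tobytes()
+    sample = transform({'image': raw}, ops)
+    print(sample['image'].shape)  # expected: (3, 32, 32)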
diff --git a/test_tipc/supplementary/data_loader.py b/test_tipc/supplementary/data_loader.py
new file mode 100644
index 0000000000000000000000000000000000000000..049e7b2d36306d4bb7264d1c45a072ed84bbba60
--- /dev/null
+++ b/test_tipc/supplementary/data_loader.py
@@ -0,0 +1,66 @@
+import numpy as np
+from paddle.vision.datasets import Cifar100
+from paddle.vision.transforms import Normalize
+from paddle.fluid.dataloader.collate import default_collate_fn
+import signal
+import os
+from paddle.io import Dataset, DataLoader, DistributedBatchSampler
+
+
+def term_mp(sig_num, frame):
+ """ kill all child processes
+ """
+ pid = os.getpid()
+ pgid = os.getpgid(os.getpid())
+ print("main proc {} exit, kill process group " "{}".format(pid, pgid))
+ os.killpg(pgid, signal.SIGKILL)
+ return
+
+
+def build_dataloader(mode,
+ batch_size=4,
+ seed=None,
+ num_workers=0,
+ device='gpu:0'):
+
+ normalize = Normalize(
+ mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], data_format='HWC')
+
+ if mode.lower() == "train":
+ dataset = Cifar100(mode=mode, transform=normalize)
+ elif mode.lower() in ["test", 'valid', 'eval']:
+ dataset = Cifar100(mode="test", transform=normalize)
+ else:
+ raise ValueError(f"{mode} should be one of ['train', 'test']")
+
+ # define batch sampler
+ batch_sampler = DistributedBatchSampler(
+ dataset=dataset, batch_size=batch_size, shuffle=False, drop_last=True)
+
+ data_loader = DataLoader(
+ dataset=dataset,
+ batch_sampler=batch_sampler,
+ places=device,
+ num_workers=num_workers,
+ return_list=True,
+ use_shared_memory=False)
+
+ # support exit using ctrl+c
+ signal.signal(signal.SIGINT, term_mp)
+ signal.signal(signal.SIGTERM, term_mp)
+
+ return data_loader
+
+
+# cifar100 = Cifar100(mode='train', transform=normalize)
+
+# data = cifar100[0]
+
+# image, label = data
+
+# reader = build_dataloader('train')
+
+# for idx, data in enumerate(reader):
+# print(idx, data[0].shape, data[1].shape)
+# if idx >= 10:
+# break
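+
+
+if __name__ == "__main__":
+    # Smoke-test sketch; assumes the CIFAR-100 archive can be downloaded and
+    # that the default device ('gpu:0') is available.
+    loader = build_dataloader("train", batch_size=8)
+    for idx, (images, labels) in enumerate(loader):
+        print(idx, images.shape, labels.shape)
+        if idx >= 2:
+            break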
diff --git a/test_tipc/supplementary/load_cifar.py b/test_tipc/supplementary/load_cifar.py
new file mode 100644
index 0000000000000000000000000000000000000000..6646dca390dd9e0bde51431f474008d07e638a01
--- /dev/null
+++ b/test_tipc/supplementary/load_cifar.py
@@ -0,0 +1,40 @@
+import os
+import pickle as p
+import numpy as np
+from PIL import Image
+
+
+def load_CIFAR_batch(filename):
+ """ load single batch of cifar """
+ with open(filename, 'rb') as f:
+ datadict = p.load(f, encoding='bytes')
+ # the pickle batch stores the images and labels in a dict
+ X = datadict[b'data']
+ Y = datadict[b'fine_labels']
+ try:
+ # the test split contains 10000 images
+ X = X.reshape(10000, 3, 32, 32)
+ except ValueError:
+ # the train split contains 50000 images
+ X = X.reshape(50000, 3, 32, 32)
+ Y = np.array(Y)
+ print(Y.shape)
+ return X, Y
+
+
+if __name__ == "__main__":
+ mode = "train"
+ imgX, imgY = load_CIFAR_batch(f"./cifar-100-python/{mode}")
+ os.makedirs(f"./cifar-100-python/{mode}_imgs", exist_ok=True)
+ with open(f'./cifar-100-python/{mode}_imgs/img_label.txt', 'a+') as f:
+ for i in range(imgY.shape[0]):
+ f.write('img' + str(i) + ' ' + str(imgY[i]) + '\n')
+
+ for i in range(imgX.shape[0]):
+ imgs = imgX[i]
+ img0 = imgs[0]
+ img1 = imgs[1]
+ img2 = imgs[2]
+ i0 = Image.fromarray(img0)
+ i1 = Image.fromarray(img1)
+ i2 = Image.fromarray(img2)
+ img = Image.merge("RGB", (i0, i1, i2))
+ name = "img" + str(i) + ".png"
+ img.save(f"./cifar-100-python/{mode}_imgs/" + name, "png")
+ print("saved successfully!")
diff --git a/test_tipc/supplementary/loss.py b/test_tipc/supplementary/loss.py
new file mode 100644
index 0000000000000000000000000000000000000000..8cb1cd498c9b02be85975bbba4197b7dc2ef310e
--- /dev/null
+++ b/test_tipc/supplementary/loss.py
@@ -0,0 +1,128 @@
+import paddle
+import paddle.nn.functional as F
+
+
+class Loss(object):
+ """
+    Cross-entropy loss with optional label smoothing.
+ """
+
+ def __init__(self, class_dim=1000, epsilon=None):
+ assert class_dim > 1, "class_dim=%d is not larger than 1" % (class_dim)
+ self._class_dim = class_dim
+ if epsilon is not None and epsilon >= 0.0 and epsilon <= 1.0:
+ self._epsilon = epsilon
+ self._label_smoothing = True
+ else:
+ self._epsilon = None
+ self._label_smoothing = False
+
+ def _labelsmoothing(self, target):
+ if target.shape[-1] != self._class_dim:
+ one_hot_target = F.one_hot(target, self._class_dim)
+ else:
+ one_hot_target = target
+ soft_target = F.label_smooth(one_hot_target, epsilon=self._epsilon)
+ soft_target = paddle.reshape(soft_target, shape=[-1, self._class_dim])
+ return soft_target
+
+ def _crossentropy(self, input, target, use_pure_fp16=False):
+ if self._label_smoothing:
+ target = self._labelsmoothing(target)
+ input = -F.log_softmax(input, axis=-1)
+ cost = paddle.sum(target * input, axis=-1)
+ else:
+ cost = F.cross_entropy(input=input, label=target)
+ if use_pure_fp16:
+ avg_cost = paddle.sum(cost)
+ else:
+ avg_cost = paddle.mean(cost)
+ return avg_cost
+
+ def __call__(self, input, target):
+ return self._crossentropy(input, target)
+
+
+def build_loss(config, epsilon=None):
+ class_dim = config['class_dim']
+ loss_func = Loss(class_dim=class_dim, epsilon=epsilon)
+ return loss_func
+
+
+class LossDistill(Loss):
+ def __init__(self, model_name_list, class_dim=1000, epsilon=None):
+ assert class_dim > 1, "class_dim=%d is not larger than 1" % (class_dim)
+ self._class_dim = class_dim
+ if epsilon is not None and epsilon >= 0.0 and epsilon <= 1.0:
+ self._epsilon = epsilon
+ self._label_smoothing = True
+ else:
+ self._epsilon = None
+ self._label_smoothing = False
+
+ self.model_name_list = model_name_list
+ assert len(self.model_name_list) > 1, "error"
+
+ def __call__(self, input, target):
+ losses = {}
+ for k in self.model_name_list:
+ inp = input[k]
+ losses[k] = self._crossentropy(inp, target)
+ return losses
+
+
+class KLJSLoss(object):
+ def __init__(self, mode='kl'):
+ assert mode in ['kl', 'js', 'KL', 'JS'
+ ], "mode can only be one of ['kl', 'js', 'KL', 'JS']"
+ self.mode = mode
+
+ def __call__(self, p1, p2, reduction="mean"):
+ p1 = F.softmax(p1, axis=-1)
+ p2 = F.softmax(p2, axis=-1)
+
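+        # KL(p2 || p1); the 1e-5 terms keep the division and the log numerically stable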
+ loss = paddle.multiply(p2, paddle.log((p2 + 1e-5) / (p1 + 1e-5) + 1e-5))
+
+ if self.mode.lower() == "js":
+ loss += paddle.multiply(
+ p1, paddle.log((p1 + 1e-5) / (p2 + 1e-5) + 1e-5))
+ loss *= 0.5
+ if reduction == "mean":
+ loss = paddle.mean(loss)
+ elif reduction == "none" or reduction is None:
+ return loss
+ else:
+ loss = paddle.sum(loss)
+ return loss
+
+
+class DMLLoss(object):
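+    """
+    Deep mutual learning (DML) loss: applies the KL/JS divergence loss
+    to each pair of model outputs listed in model_name_pairs.
+    """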
+ def __init__(self, model_name_pairs, mode='js'):
+
+ self.model_name_pairs = self._check_model_name_pairs(model_name_pairs)
+ self.kljs_loss = KLJSLoss(mode=mode)
+
+ def _check_model_name_pairs(self, model_name_pairs):
+ if not isinstance(model_name_pairs, list):
+ return []
+ elif isinstance(model_name_pairs[0], list) and isinstance(
+ model_name_pairs[0][0], str):
+ return model_name_pairs
+ else:
+ return [model_name_pairs]
+
+ def __call__(self, predicts, target=None):
+ loss_dict = dict()
+ for pairs in self.model_name_pairs:
+ p1 = predicts[pairs[0]]
+ p2 = predicts[pairs[1]]
+
+ loss_dict[pairs[0] + "_" + pairs[1]] = self.kljs_loss(p1, p2)
+
+ return loss_dict
+
+
+# def build_distill_loss(config, epsilon=None):
+# class_dim = config['class_dim']
+# loss = LossDistill(model_name_list=['student', 'student1'], )
+# return loss_func
diff --git a/test_tipc/supplementary/metric.py b/test_tipc/supplementary/metric.py
new file mode 100644
index 0000000000000000000000000000000000000000..401cf9b9d22595e20be16314e763e602fd411b70
--- /dev/null
+++ b/test_tipc/supplementary/metric.py
@@ -0,0 +1,56 @@
+import paddle
+import paddle.nn.functional as F
+from collections import OrderedDict
+
+
+def create_metric(out,
+ label,
+ architecture=None,
+ topk=5,
+ classes_num=1000,
+ use_distillation=False,
+ mode="train"):
+ """
+ Create measures of model accuracy, such as top1 and top5
+
+ Args:
+ out(variable): model output variable
+        label(variable): ground truth label variable
+ topk(int): usually top5
+ classes_num(int): num of classes
+ use_distillation(bool): whether to use distillation training
+ mode(str): mode, train/valid
+
+ Returns:
+ fetchs(dict): dict of measures
+ """
+ # if architecture["name"] == "GoogLeNet":
+ # assert len(out) == 3, "GoogLeNet should have 3 outputs"
+ # out = out[0]
+ # else:
+ # # just need student label to get metrics
+ # if use_distillation:
+ # out = out[1]
+ softmax_out = F.softmax(out)
+
+ fetchs = OrderedDict()
+ # set top1 to fetchs
+ top1 = paddle.metric.accuracy(softmax_out, label=label, k=1)
+ # set topk to fetchs
+ k = min(topk, classes_num)
+ topk = paddle.metric.accuracy(softmax_out, label=label, k=k)
+
+ # multi cards' eval
+ if mode != "train" and paddle.distributed.get_world_size() > 1:
+ top1 = paddle.distributed.all_reduce(
+ top1, op=paddle.distributed.ReduceOp.
+ SUM) / paddle.distributed.get_world_size()
+ topk = paddle.distributed.all_reduce(
+ topk, op=paddle.distributed.ReduceOp.
+ SUM) / paddle.distributed.get_world_size()
+
+ fetchs['top1'] = top1
+ topk_name = 'top{}'.format(k)
+ fetchs[topk_name] = topk
+
+ return fetchs
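+
+
+# A minimal usage sketch (assumes `out` is a batch of logits and `label` a column of int64 class ids):
+# metrics = create_metric(out, label, topk=5, classes_num=100)
+# print(metrics['top1'].numpy(), metrics['top5'].numpy())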
diff --git a/test_tipc/supplementary/mv3.py b/test_tipc/supplementary/mv3.py
new file mode 100644
index 0000000000000000000000000000000000000000..9ffcedac03857961d3c0136c3d2d26e0b5feca6d
--- /dev/null
+++ b/test_tipc/supplementary/mv3.py
@@ -0,0 +1,487 @@
+# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+import paddle
+from paddle import ParamAttr
+import paddle.nn as nn
+import paddle.nn.functional as F
+from paddle.nn.functional import hardswish, hardsigmoid
+from paddle.nn import Conv2D, BatchNorm, Linear, Dropout
+from paddle.nn import AdaptiveAvgPool2D, MaxPool2D, AvgPool2D
+from paddle.regularizer import L2Decay
+import math
+
+from paddle.utils.cpp_extension import load
+# jit compile custom op
+custom_ops = load(
+ name="custom_jit_ops",
+ sources=["./custom_op/custom_relu_op.cc", "./custom_op/custom_relu_op.cu"])
+
+
+def make_divisible(v, divisor=8, min_value=None):
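+    # round v to the nearest multiple of divisor, making sure the result does not drop below 90% of v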
+ if min_value is None:
+ min_value = divisor
+ new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
+ if new_v < 0.9 * v:
+ new_v += divisor
+ return new_v
+
+
+class MobileNetV3(nn.Layer):
+ def __init__(self,
+ scale=1.0,
+ model_name="small",
+ dropout_prob=0.2,
+ class_dim=1000,
+ use_custom_relu=False):
+ super(MobileNetV3, self).__init__()
+ self.use_custom_relu = use_custom_relu
+
+ inplanes = 16
+ if model_name == "large":
+ self.cfg = [
+ # k, exp, c, se, nl, s,
+ [3, 16, 16, False, "relu", 1],
+ [3, 64, 24, False, "relu", 2],
+ [3, 72, 24, False, "relu", 1],
+ [5, 72, 40, True, "relu", 2],
+ [5, 120, 40, True, "relu", 1],
+ [5, 120, 40, True, "relu", 1],
+ [3, 240, 80, False, "hardswish", 2],
+ [3, 200, 80, False, "hardswish", 1],
+ [3, 184, 80, False, "hardswish", 1],
+ [3, 184, 80, False, "hardswish", 1],
+ [3, 480, 112, True, "hardswish", 1],
+ [3, 672, 112, True, "hardswish", 1],
+ [5, 672, 160, True, "hardswish", 2],
+ [5, 960, 160, True, "hardswish", 1],
+ [5, 960, 160, True, "hardswish", 1],
+ ]
+ self.cls_ch_squeeze = 960
+ self.cls_ch_expand = 1280
+ elif model_name == "small":
+ self.cfg = [
+ # k, exp, c, se, nl, s,
+ [3, 16, 16, True, "relu", 2],
+ [3, 72, 24, False, "relu", 2],
+ [3, 88, 24, False, "relu", 1],
+ [5, 96, 40, True, "hardswish", 2],
+ [5, 240, 40, True, "hardswish", 1],
+ [5, 240, 40, True, "hardswish", 1],
+ [5, 120, 48, True, "hardswish", 1],
+ [5, 144, 48, True, "hardswish", 1],
+ [5, 288, 96, True, "hardswish", 2],
+ [5, 576, 96, True, "hardswish", 1],
+ [5, 576, 96, True, "hardswish", 1],
+ ]
+ self.cls_ch_squeeze = 576
+ self.cls_ch_expand = 1280
+ else:
+ raise NotImplementedError(
+ "mode[{}_model] is not implemented!".format(model_name))
+
+ self.conv1 = ConvBNLayer(
+ in_c=3,
+ out_c=make_divisible(inplanes * scale),
+ filter_size=3,
+ stride=2,
+ padding=1,
+ num_groups=1,
+ if_act=True,
+ act="hardswish",
+ name="conv1",
+ use_custom_relu=self.use_custom_relu)
+
+ self.block_list = []
+ i = 0
+ inplanes = make_divisible(inplanes * scale)
+ for (k, exp, c, se, nl, s) in self.cfg:
+ block = self.add_sublayer(
+ "conv" + str(i + 2),
+ ResidualUnit(
+ in_c=inplanes,
+ mid_c=make_divisible(scale * exp),
+ out_c=make_divisible(scale * c),
+ filter_size=k,
+ stride=s,
+ use_se=se,
+ act=nl,
+ name="conv" + str(i + 2),
+ use_custom_relu=self.use_custom_relu))
+ self.block_list.append(block)
+ inplanes = make_divisible(scale * c)
+ i += 1
+
+ self.last_second_conv = ConvBNLayer(
+ in_c=inplanes,
+ out_c=make_divisible(scale * self.cls_ch_squeeze),
+ filter_size=1,
+ stride=1,
+ padding=0,
+ num_groups=1,
+ if_act=True,
+ act="hardswish",
+ name="conv_last",
+ use_custom_relu=self.use_custom_relu)
+
+ self.pool = AdaptiveAvgPool2D(1)
+
+ self.last_conv = Conv2D(
+ in_channels=make_divisible(scale * self.cls_ch_squeeze),
+ out_channels=self.cls_ch_expand,
+ kernel_size=1,
+ stride=1,
+ padding=0,
+ weight_attr=ParamAttr(),
+ bias_attr=False)
+
+ self.dropout = Dropout(p=dropout_prob, mode="downscale_in_infer")
+
+ self.out = Linear(
+ self.cls_ch_expand,
+ class_dim,
+ weight_attr=ParamAttr(),
+ bias_attr=ParamAttr())
+
+ def forward(self, inputs, label=None):
+ x = self.conv1(inputs)
+
+ for block in self.block_list:
+ x = block(x)
+
+ x = self.last_second_conv(x)
+ x = self.pool(x)
+
+ x = self.last_conv(x)
+ x = hardswish(x)
+ x = self.dropout(x)
+ x = paddle.flatten(x, start_axis=1, stop_axis=-1)
+ x = self.out(x)
+ return x
+
+
+class ConvBNLayer(nn.Layer):
+ def __init__(self,
+ in_c,
+ out_c,
+ filter_size,
+ stride,
+ padding,
+ num_groups=1,
+ if_act=True,
+ act=None,
+ use_cudnn=True,
+ name="",
+ use_custom_relu=False):
+ super(ConvBNLayer, self).__init__()
+ self.if_act = if_act
+ self.act = act
+ self.conv = Conv2D(
+ in_channels=in_c,
+ out_channels=out_c,
+ kernel_size=filter_size,
+ stride=stride,
+ padding=padding,
+ groups=num_groups,
+ weight_attr=ParamAttr(),
+ bias_attr=False)
+ self.bn = BatchNorm(
+ num_channels=out_c,
+ act=None,
+ param_attr=ParamAttr(regularizer=L2Decay(0.0)),
+ bias_attr=ParamAttr(regularizer=L2Decay(0.0)))
+ # moving_mean_name=name + "_bn_mean",
+ # moving_variance_name=name + "_bn_variance")
+
+ self.use_custom_relu = use_custom_relu
+
+ def forward(self, x):
+ x = self.conv(x)
+ x = self.bn(x)
+ if self.if_act:
+ if self.act == "relu":
+ if self.use_custom_relu:
+ x = custom_ops.custom_relu(x)
+ else:
+ x = F.relu(x)
+ elif self.act == "hardswish":
+ x = hardswish(x)
+ else:
+ print("The activation function is selected incorrectly.")
+ exit()
+ return x
+
+
+class ResidualUnit(nn.Layer):
+ def __init__(self,
+ in_c,
+ mid_c,
+ out_c,
+ filter_size,
+ stride,
+ use_se,
+ act=None,
+ name='',
+ use_custom_relu=False):
+ super(ResidualUnit, self).__init__()
+ self.if_shortcut = stride == 1 and in_c == out_c
+ self.if_se = use_se
+
+ self.use_custom_relu = use_custom_relu
+
+ self.expand_conv = ConvBNLayer(
+ in_c=in_c,
+ out_c=mid_c,
+ filter_size=1,
+ stride=1,
+ padding=0,
+ if_act=True,
+ act=act,
+ name=name + "_expand",
+ use_custom_relu=self.use_custom_relu)
+ self.bottleneck_conv = ConvBNLayer(
+ in_c=mid_c,
+ out_c=mid_c,
+ filter_size=filter_size,
+ stride=stride,
+ padding=int((filter_size - 1) // 2),
+ num_groups=mid_c,
+ if_act=True,
+ act=act,
+ name=name + "_depthwise",
+ use_custom_relu=self.use_custom_relu)
+ if self.if_se:
+ self.mid_se = SEModule(mid_c, name=name + "_se")
+ self.linear_conv = ConvBNLayer(
+ in_c=mid_c,
+ out_c=out_c,
+ filter_size=1,
+ stride=1,
+ padding=0,
+ if_act=False,
+ act=None,
+ name=name + "_linear",
+ use_custom_relu=self.use_custom_relu)
+
+ def forward(self, inputs):
+ x = self.expand_conv(inputs)
+ x = self.bottleneck_conv(x)
+ if self.if_se:
+ x = self.mid_se(x)
+ x = self.linear_conv(x)
+ if self.if_shortcut:
+ x = paddle.add(inputs, x)
+ return x
+
+
+class SEModule(nn.Layer):
+ def __init__(self, channel, reduction=4, name=""):
+ super(SEModule, self).__init__()
+ self.avg_pool = AdaptiveAvgPool2D(1)
+ self.conv1 = Conv2D(
+ in_channels=channel,
+ out_channels=channel // reduction,
+ kernel_size=1,
+ stride=1,
+ padding=0,
+ weight_attr=ParamAttr(),
+ bias_attr=ParamAttr())
+ self.conv2 = Conv2D(
+ in_channels=channel // reduction,
+ out_channels=channel,
+ kernel_size=1,
+ stride=1,
+ padding=0,
+ weight_attr=ParamAttr(),
+ bias_attr=ParamAttr())
+
+ def forward(self, inputs):
+ outputs = self.avg_pool(inputs)
+ outputs = self.conv1(outputs)
+ outputs = F.relu(outputs)
+ outputs = self.conv2(outputs)
+ outputs = hardsigmoid(outputs, slope=0.2, offset=0.5)
+ return paddle.multiply(x=inputs, y=outputs)
+
+
+def MobileNetV3_small_x0_35(**args):
+ model = MobileNetV3(model_name="small", scale=0.35, **args)
+ return model
+
+
+def MobileNetV3_small_x0_5(**args):
+ model = MobileNetV3(model_name="small", scale=0.5, **args)
+ return model
+
+
+def MobileNetV3_small_x0_75(**args):
+ model = MobileNetV3(model_name="small", scale=0.75, **args)
+ return model
+
+
+def MobileNetV3_small_x1_0(**args):
+ model = MobileNetV3(model_name="small", scale=1.0, **args)
+ return model
+
+
+def MobileNetV3_small_x1_25(**args):
+ model = MobileNetV3(model_name="small", scale=1.25, **args)
+ return model
+
+
+def MobileNetV3_large_x0_35(**args):
+ model = MobileNetV3(model_name="large", scale=0.35, **args)
+ return model
+
+
+def MobileNetV3_large_x0_5(**args):
+ model = MobileNetV3(model_name="large", scale=0.5, **args)
+ return model
+
+
+def MobileNetV3_large_x0_75(**args):
+ model = MobileNetV3(model_name="large", scale=0.75, **args)
+ return model
+
+
+def MobileNetV3_large_x1_0(**args):
+ model = MobileNetV3(model_name="large", scale=1.0, **args)
+ return model
+
+
+def MobileNetV3_large_x1_25(**args):
+ model = MobileNetV3(model_name="large", scale=1.25, **args)
+    return model
+
+
+class DistillMV3(nn.Layer):
+ def __init__(self,
+ scale=1.0,
+ model_name="small",
+ dropout_prob=0.2,
+ class_dim=1000,
+ args=None,
+ use_custom_relu=False):
+ super(DistillMV3, self).__init__()
+
+ self.student = MobileNetV3(
+ model_name=model_name,
+ scale=scale,
+ class_dim=class_dim,
+ use_custom_relu=use_custom_relu)
+
+ self.student1 = MobileNetV3(
+ model_name=model_name,
+ scale=scale,
+ class_dim=class_dim,
+ use_custom_relu=use_custom_relu)
+
+ def forward(self, inputs, label=None):
+ predicts = dict()
+ predicts['student'] = self.student(inputs, label)
+ predicts['student1'] = self.student1(inputs, label)
+ return predicts
+
+
+def distillmv3_large_x0_5(**args):
+ model = DistillMV3(model_name="large", scale=0.5, **args)
+ return model
+
+
+class SiameseMV3(nn.Layer):
+ def __init__(self,
+ scale=1.0,
+ model_name="small",
+ dropout_prob=0.2,
+ class_dim=1000,
+ args=None,
+ use_custom_relu=False):
+ super(SiameseMV3, self).__init__()
+
+ self.net = MobileNetV3(
+ model_name=model_name,
+ scale=scale,
+ class_dim=class_dim,
+ use_custom_relu=use_custom_relu)
+ self.net1 = MobileNetV3(
+ model_name=model_name,
+ scale=scale,
+ class_dim=class_dim,
+ use_custom_relu=use_custom_relu)
+
+ def forward(self, inputs, label=None):
+ # net
+ x = self.net.conv1(inputs)
+ for block in self.net.block_list:
+ x = block(x)
+
+ # net1
+ x1 = self.net1.conv1(inputs)
+ for block in self.net1.block_list:
+ x1 = block(x1)
+ # add
+ x = x + x1
+
+ x = self.net.last_second_conv(x)
+ x = self.net.pool(x)
+
+ x = self.net.last_conv(x)
+ x = hardswish(x)
+ x = self.net.dropout(x)
+ x = paddle.flatten(x, start_axis=1, stop_axis=-1)
+ x = self.net.out(x)
+ return x
+
+
+def siamese_mv3(class_dim, use_custom_relu):
+ model = SiameseMV3(
+ scale=0.5,
+ model_name="large",
+ class_dim=class_dim,
+ use_custom_relu=use_custom_relu)
+ return model
+
+
+def build_model(config):
+ model_type = config['model_type']
+ if model_type == "cls":
+ class_dim = config['MODEL']['class_dim']
+ use_custom_relu = config['MODEL']['use_custom_relu']
+ if 'siamese' in config['MODEL'] and config['MODEL']['siamese'] is True:
+ model = siamese_mv3(
+ class_dim=class_dim, use_custom_relu=use_custom_relu)
+ else:
+ model = MobileNetV3_large_x0_5(
+ class_dim=class_dim, use_custom_relu=use_custom_relu)
+
+ elif model_type == "cls_distill":
+ class_dim = config['MODEL']['class_dim']
+ use_custom_relu = config['MODEL']['use_custom_relu']
+ model = distillmv3_large_x0_5(
+ class_dim=class_dim, use_custom_relu=use_custom_relu)
+
+ elif model_type == "cls_distill_multiopt":
+ class_dim = config['MODEL']['class_dim']
+ use_custom_relu = config['MODEL']['use_custom_relu']
+ model = distillmv3_large_x0_5(
+ class_dim=100, use_custom_relu=use_custom_relu)
+ else:
+        raise ValueError(
+            "model_type should be one of ['cls', 'cls_distill', 'cls_distill_multiopt']")
+
+ return model
diff --git a/test_tipc/supplementary/mv3_distill.yml b/test_tipc/supplementary/mv3_distill.yml
new file mode 100644
index 0000000000000000000000000000000000000000..887b1eb17fc6ebcc8abb5a1ce80abba34daacf08
--- /dev/null
+++ b/test_tipc/supplementary/mv3_distill.yml
@@ -0,0 +1,31 @@
+
+class_dim: 100
+total_images: 50000
+epoch: 1000
+topk: 5
+save_model_dir: ./output/
+use_gpu: True
+model_type: cls_distill
+
+LEARNING_RATE:
+ function: 'Cosine'
+ params:
+ lr: 0.001
+ warmup_epoch: 5
+
+OPTIMIZER:
+ function: 'Momentum'
+ params:
+ momentum: 0.9
+ regularizer:
+ function: 'L2'
+ factor: 0.00002
+
+TRAIN:
+ batch_size: 1280
+ num_workers: 4
+
+VALID:
+ batch_size: 64
+ num_workers: 4
+
diff --git a/test_tipc/supplementary/mv3_large_x0_5.yml b/test_tipc/supplementary/mv3_large_x0_5.yml
new file mode 100644
index 0000000000000000000000000000000000000000..531c2f0f50a4b79a03296095eded508ed8d4c12c
--- /dev/null
+++ b/test_tipc/supplementary/mv3_large_x0_5.yml
@@ -0,0 +1,49 @@
+
+class_dim: 100
+total_images: 50000
+epoch: 1000
+topk: 5
+save_model_dir: ./output/
+use_gpu: True
+model_type: cls
+use_custom_relu: false
+pretrained_model:
+checkpoints:
+save_model_dir: ./output/cls/
+
+# slim
+quant_train: false
+prune_train: false
+
+MODEL:
+ class_dim: 100
+ use_custom_relu: False
+ siamese: False
+
+AMP:
+ use_amp: False
+ scale_loss: 1024.0
+ use_dynamic_loss_scale: True
+
+LEARNING_RATE:
+ function: 'Cosine'
+ params:
+ lr: 0.001
+ warmup_epoch: 5
+
+OPTIMIZER:
+ function: 'Momentum'
+ params:
+ momentum: 0.9
+ regularizer:
+ function: 'L2'
+ factor: 0.00002
+
+TRAIN:
+ batch_size: 1280
+ num_workers: 4
+
+VALID:
+ batch_size: 64
+ num_workers: 4
+
diff --git a/test_tipc/supplementary/optimizer.py b/test_tipc/supplementary/optimizer.py
new file mode 100644
index 0000000000000000000000000000000000000000..aaa01534752ed8c3589960e7f7d92e7892b26dd7
--- /dev/null
+++ b/test_tipc/supplementary/optimizer.py
@@ -0,0 +1,325 @@
+import sys
+import math
+from paddle.optimizer.lr import LinearWarmup
+from paddle.optimizer.lr import PiecewiseDecay
+from paddle.optimizer.lr import CosineAnnealingDecay
+from paddle.optimizer.lr import ExponentialDecay
+import paddle
+import paddle.regularizer as regularizer
+from copy import deepcopy
+
+
+class Cosine(CosineAnnealingDecay):
+ """
+ Cosine learning rate decay
+    lr = 0.5 * lr * (math.cos(epoch * (math.pi / epochs)) + 1)
+ Args:
+ lr(float): initial learning rate
+ step_each_epoch(int): steps each epoch
+ epochs(int): total training epochs
+ """
+
+ def __init__(self, lr, step_each_epoch, epochs, **kwargs):
+ super(Cosine, self).__init__(
+ learning_rate=lr,
+ T_max=step_each_epoch * epochs, )
+
+ self.update_specified = False
+
+
+class Piecewise(PiecewiseDecay):
+ """
+ Piecewise learning rate decay
+ Args:
+ lr(float): initial learning rate
+ step_each_epoch(int): steps each epoch
+ decay_epochs(list): piecewise decay epochs
+ gamma(float): decay factor
+ """
+
+ def __init__(self, lr, step_each_epoch, decay_epochs, gamma=0.1, **kwargs):
+ boundaries = [step_each_epoch * e for e in decay_epochs]
+ lr_values = [lr * (gamma**i) for i in range(len(boundaries) + 1)]
+ super(Piecewise, self).__init__(boundaries=boundaries, values=lr_values)
+
+ self.update_specified = False
+
+
+class CosineWarmup(LinearWarmup):
+ """
+ Cosine learning rate decay with warmup
+ [0, warmup_epoch): linear warmup
+ [warmup_epoch, epochs): cosine decay
+ Args:
+ lr(float): initial learning rate
+ step_each_epoch(int): steps each epoch
+ epochs(int): total training epochs
+ warmup_epoch(int): epoch num of warmup
+ """
+
+ def __init__(self, lr, step_each_epoch, epochs, warmup_epoch=5, **kwargs):
+ assert epochs > warmup_epoch, "total epoch({}) should be larger than warmup_epoch({}) in CosineWarmup.".format(
+ epochs, warmup_epoch)
+ warmup_step = warmup_epoch * step_each_epoch
+ start_lr = 0.0
+ end_lr = lr
+ lr_sch = Cosine(lr, step_each_epoch, epochs - warmup_epoch)
+
+ super(CosineWarmup, self).__init__(
+ learning_rate=lr_sch,
+ warmup_steps=warmup_step,
+ start_lr=start_lr,
+ end_lr=end_lr)
+
+ self.update_specified = False
+
+
+class ExponentialWarmup(LinearWarmup):
+ """
+ Exponential learning rate decay with warmup
+ [0, warmup_epoch): linear warmup
+ [warmup_epoch, epochs): Exponential decay
+ Args:
+ lr(float): initial learning rate
+ step_each_epoch(int): steps each epoch
+ decay_epochs(float): decay epochs
+ decay_rate(float): decay rate
+ warmup_epoch(int): epoch num of warmup
+ """
+
+ def __init__(self,
+ lr,
+ step_each_epoch,
+ decay_epochs=2.4,
+ decay_rate=0.97,
+ warmup_epoch=5,
+ **kwargs):
+ warmup_step = warmup_epoch * step_each_epoch
+ start_lr = 0.0
+ end_lr = lr
+ lr_sch = ExponentialDecay(lr, decay_rate)
+
+ super(ExponentialWarmup, self).__init__(
+ learning_rate=lr_sch,
+ warmup_steps=warmup_step,
+ start_lr=start_lr,
+ end_lr=end_lr)
+
+        # NOTE: hack to update the exponential lr scheduler
+ self.update_specified = True
+ self.update_start_step = warmup_step
+ self.update_step_interval = int(decay_epochs * step_each_epoch)
+ self.step_each_epoch = step_each_epoch
+
+
+class LearningRateBuilder():
+ """
+ Build learning rate variable
+ https://www.paddlepaddle.org.cn/documentation/docs/zh/api_cn/layers_cn.html
+ Args:
+ function(str): class name of learning rate
+ params(dict): parameters used for init the class
+ """
+
+ def __init__(self,
+ function='Linear',
+ params={'lr': 0.1,
+ 'steps': 100,
+ 'end_lr': 0.0}):
+ self.function = function
+ self.params = params
+
+ def __call__(self):
+ mod = sys.modules[__name__]
+ lr = getattr(mod, self.function)(**self.params)
+ return lr
+
+
+class L1Decay(object):
+ """
+ L1 Weight Decay Regularization, which encourages the weights to be sparse.
+ Args:
+ factor(float): regularization coeff. Default:0.0.
+ """
+
+ def __init__(self, factor=0.0):
+ super(L1Decay, self).__init__()
+ self.factor = factor
+
+ def __call__(self):
+ reg = regularizer.L1Decay(self.factor)
+ return reg
+
+
+class L2Decay(object):
+ """
+    L2 Weight Decay Regularization, which encourages the weights to be small.
+ Args:
+ factor(float): regularization coeff. Default:0.0.
+ """
+
+ def __init__(self, factor=0.0):
+ super(L2Decay, self).__init__()
+ self.factor = factor
+
+ def __call__(self):
+ reg = regularizer.L2Decay(self.factor)
+ return reg
+
+
+class Momentum(object):
+ """
+ Simple Momentum optimizer with velocity state.
+ Args:
+ learning_rate (float|Variable) - The learning rate used to update parameters.
+ Can be a float value or a Variable with one float value as data element.
+ momentum (float) - Momentum factor.
+ regularization (WeightDecayRegularizer, optional) - The strategy of regularization.
+ """
+
+ def __init__(self,
+ learning_rate,
+ momentum,
+ parameter_list=None,
+ regularization=None,
+ **args):
+ super(Momentum, self).__init__()
+ self.learning_rate = learning_rate
+ self.momentum = momentum
+ self.parameter_list = parameter_list
+ self.regularization = regularization
+
+ def __call__(self):
+ opt = paddle.optimizer.Momentum(
+ learning_rate=self.learning_rate,
+ momentum=self.momentum,
+ parameters=self.parameter_list,
+ weight_decay=self.regularization)
+ return opt
+
+
+class RMSProp(object):
+ """
+ Root Mean Squared Propagation (RMSProp) is an unpublished, adaptive learning rate method.
+ Args:
+ learning_rate (float|Variable) - The learning rate used to update parameters.
+ Can be a float value or a Variable with one float value as data element.
+ momentum (float) - Momentum factor.
+ rho (float) - rho value in equation.
+ epsilon (float) - avoid division by zero, default is 1e-6.
+ regularization (WeightDecayRegularizer, optional) - The strategy of regularization.
+ """
+
+ def __init__(self,
+ learning_rate,
+ momentum,
+ rho=0.95,
+ epsilon=1e-6,
+ parameter_list=None,
+ regularization=None,
+ **args):
+ super(RMSProp, self).__init__()
+ self.learning_rate = learning_rate
+ self.momentum = momentum
+ self.rho = rho
+ self.epsilon = epsilon
+ self.parameter_list = parameter_list
+ self.regularization = regularization
+
+ def __call__(self):
+ opt = paddle.optimizer.RMSProp(
+ learning_rate=self.learning_rate,
+ momentum=self.momentum,
+ rho=self.rho,
+ epsilon=self.epsilon,
+ parameters=self.parameter_list,
+ weight_decay=self.regularization)
+ return opt
+
+
+class OptimizerBuilder(object):
+ """
+ Build optimizer
+ Args:
+ function(str): optimizer name of learning rate
+ params(dict): parameters used for init the class
+ regularizer (dict): parameters used for create regularization
+ """
+
+ def __init__(self,
+ function='Momentum',
+ params={'momentum': 0.9},
+ regularizer=None):
+ self.function = function
+ self.params = params
+ # create regularizer
+ if regularizer is not None:
+ mod = sys.modules[__name__]
+ reg_func = regularizer['function'] + 'Decay'
+ del regularizer['function']
+ reg = getattr(mod, reg_func)(**regularizer)()
+ self.params['regularization'] = reg
+
+ def __call__(self, learning_rate, parameter_list=None):
+ mod = sys.modules[__name__]
+ opt = getattr(mod, self.function)
+ return opt(learning_rate=learning_rate,
+ parameter_list=parameter_list,
+ **self.params)()
+
+
+def create_optimizer(config, parameter_list=None):
+ """
+ Create an optimizer using config, usually including
+ learning rate and regularization.
+
+ Args:
+ config(dict): such as
+ {
+ 'LEARNING_RATE':
+ {'function': 'Cosine',
+ 'params': {'lr': 0.1}
+ },
+ 'OPTIMIZER':
+ {'function': 'Momentum',
+ 'params':{'momentum': 0.9},
+ 'regularizer':
+ {'function': 'L2', 'factor': 0.0001}
+ }
+ }
+
+ Returns:
+ an optimizer instance
+ """
+ # create learning_rate instance
+ lr_config = config['LEARNING_RATE']
+ lr_config['params'].update({
+ 'epochs': config['epoch'],
+ 'step_each_epoch':
+ config['total_images'] // config['TRAIN']['batch_size'],
+ })
+ lr = LearningRateBuilder(**lr_config)()
+
+ # create optimizer instance
+ opt_config = deepcopy(config['OPTIMIZER'])
+
+ opt = OptimizerBuilder(**opt_config)
+ return opt(lr, parameter_list), lr
+
+
+def create_multi_optimizer(config, parameter_list=None):
+    """
+    Create an optimizer and lr scheduler for one sub-model in
+    multi-optimizer (distillation) training; the logic mirrors create_optimizer.
+    """
+ # create learning_rate instance
+ lr_config = config['LEARNING_RATE']
+ lr_config['params'].update({
+ 'epochs': config['epoch'],
+ 'step_each_epoch':
+ config['total_images'] // config['TRAIN']['batch_size'],
+ })
+ lr = LearningRateBuilder(**lr_config)()
+
+ # create optimizer instance
+    opt_config = deepcopy(config['OPTIMIZER'])
+ opt = OptimizerBuilder(**opt_config)
+ return opt(lr, parameter_list), lr
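+
+
+# A minimal usage sketch (assumes `config` was parsed from mv3_large_x0_5.yml and `model` is already built):
+# optimizer, lr_scheduler = create_optimizer(config, parameter_list=model.parameters())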
diff --git a/test_tipc/supplementary/readme.md b/test_tipc/supplementary/readme.md
new file mode 100644
index 0000000000000000000000000000000000000000..0d35f9451f5004498cdbd001edfb2dfe2244ebb7
--- /dev/null
+++ b/test_tipc/supplementary/readme.md
@@ -0,0 +1,67 @@
+
+# Supplementary TIPC training feature test on Linux
+
+The main entry for the basic Linux training and inference feature test is test_train_python.sh. It covers Python-based model training and evaluation, including pruning, quantization, and distillation training.
+
+![](./tipc_train.png)
+
+The test chain is shown in the figure above. It mainly checks that models with shared weights or custom OPs train normally, and that the slim-related training workflows run as expected.
+
+
+# 2. Test procedure
+
+This section describes the test procedure for the supplementary chain.
+
+## 2.1 Install dependencies
+
+- Install PaddlePaddle >= 2.2
+- Install the other dependencies:
+
+```
+pip3 install -r requirements.txt
+```
+
+## 2.2 Feature test
+
+`test_train_python.sh` supports two run modes; each uses a different amount of data and checks that training works correctly:
+
+- Mode 1: lite_train_lite_infer trains on a small amount of data to quickly verify that the training-to-inference pipeline runs end to end; accuracy and speed are not checked.
+
+```
+bash test_tipc/test_train_python.sh ./test_tipc/ch_ppocr_mobile_v2.0_det/train_infer_python.txt 'lite_train_lite_infer'
+```
+
+- Mode 2: whole_train_whole_infer trains on the full dataset to verify the training-to-inference pipeline and the final training accuracy of the model.
+
+```
+bash test_tipc/test_train_python.sh ./test_tipc/ch_ppocr_mobile_v2.0_det/train_infer_python.txt 'whole_train_whole_infer'
+```
+
+Training modes such as quantization and pruning require different configuration files. The test command for quantization-aware training is:
+```
+bash test_tipc/test_train_python.sh ./test_tipc/ch_ppocr_mobile_v2.0_det/train_infer_python_PACT.txt 'lite_train_lite_infer'
+```
+
+Similarly, FPGM pruning is run as follows:
+```
+bash test_tipc/test_train_python.sh ./test_tipc/ch_ppocr_mobile_v2.0_det/train_infer_python_FPGM.txt 'lite_train_lite_infer'
+```
+
+After running the corresponding command, the run logs are saved automatically under the `test_tipc/extra_output` folder. For example, after running the 'lite_train_lite_infer' mode, the folder contains the following file:
+
+```
+test_tipc/extra_output/
+|- results_python.log    # log with the status of each executed command
+```
+
+results_python.log records the status of every command; a successful run prints:
+
+```
+Run successfully with command - python3.7 train.py -c mv3_large_x0_5.yml -o use_gpu=True epoch=20 AMP.use_amp=True TRAIN.batch_size=1280 use_custom_relu=False model_type=cls MODEL.siamese=False !
+Run successfully with command - python3.7 train.py -c mv3_large_x0_5.yml -o use_gpu=True epoch=2 AMP.use_amp=True TRAIN.batch_size=1280 use_custom_relu=False model_type=cls MODEL.siamese=False !
+Run successfully with command - python3.7 train.py -c mv3_large_x0_5.yml -o use_gpu=True epoch=2 AMP.use_amp=True TRAIN.batch_size=1280 use_custom_relu=False model_type=cls MODEL.siamese=True !
+Run successfully with command - python3.7 train.py -c mv3_large_x0_5.yml -o use_gpu=True epoch=2 AMP.use_amp=True TRAIN.batch_size=1280 use_custom_relu=False model_type=cls_distill MODEL.siamese=False !
+Run successfully with command - python3.7 train.py -c mv3_large_x0_5.yml -o use_gpu=True epoch=2 AMP.use_amp=True TRAIN.batch_size=1280 use_custom_relu=False model_type=cls_distill MODEL.siamese=True !
+Run successfully with command - python3.7 train.py -c mv3_large_x0_5.yml -o use_gpu=True epoch=2 AMP.use_amp=True TRAIN.batch_size=1280 use_custom_relu=False model_type=cls_distill_multiopt MODEL.siamese=False !
+
+```
diff --git a/test_tipc/supplementary/requirements.txt b/test_tipc/supplementary/requirements.txt
new file mode 100644
index 0000000000000000000000000000000000000000..c55500a7c4434386b8f5363714056aeac9258710
--- /dev/null
+++ b/test_tipc/supplementary/requirements.txt
@@ -0,0 +1 @@
+paddleslim==2.2.1
diff --git a/test_tipc/supplementary/slim/__init__.py b/test_tipc/supplementary/slim/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/test_tipc/supplementary/slim/slim_fpgm.py b/test_tipc/supplementary/slim/slim_fpgm.py
new file mode 100644
index 0000000000000000000000000000000000000000..0e7621592da88b568eb3b035376135c04f47c787
--- /dev/null
+++ b/test_tipc/supplementary/slim/slim_fpgm.py
@@ -0,0 +1,22 @@
+import paddleslim
+import paddle
+import numpy as np
+
+from paddleslim.dygraph import FPGMFilterPruner
+
+
+def prune_model(model, input_shape, prune_ratio=0.1):
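+    """
+    Prune convolution filters of `model` with the FPGM criterion.
+
+    Args:
+        model (nn.Layer): model to prune.
+        input_shape (list): input shape used to compute FLOPs, e.g. [1, 3, 32, 32].
+        prune_ratio (float): ratio of filters to prune for each sensitive parameter.
+    """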
+
+ flops = paddle.flops(model, input_shape)
+ pruner = FPGMFilterPruner(model, input_shape)
+
+ params_sensitive = {}
+ for param in model.parameters():
+ if 'transpose' not in param.name and 'linear' not in param.name:
+            # set the prune ratio for this parameter; the larger the ratio, the more convolution filters are pruned
+ params_sensitive[param.name] = prune_ratio
+
+ plan = pruner.prune_vars(params_sensitive, [0])
+
+ flops = paddle.flops(model, input_shape)
+ return model
diff --git a/test_tipc/supplementary/slim/slim_quant.py b/test_tipc/supplementary/slim/slim_quant.py
new file mode 100644
index 0000000000000000000000000000000000000000..7c201bf55dcbb94995f80a0658f6fad1956749de
--- /dev/null
+++ b/test_tipc/supplementary/slim/slim_quant.py
@@ -0,0 +1,48 @@
+import paddle
+import numpy as np
+import os
+import paddle.nn as nn
+import paddleslim
+
+
+class PACT(paddle.nn.Layer):
+ def __init__(self):
+ super(PACT, self).__init__()
+ alpha_attr = paddle.ParamAttr(
+ name=self.full_name() + ".pact",
+ initializer=paddle.nn.initializer.Constant(value=20),
+ learning_rate=1.0,
+ regularizer=paddle.regularizer.L2Decay(2e-5))
+
+ self.alpha = self.create_parameter(
+ shape=[1], attr=alpha_attr, dtype='float32')
+
+ def forward(self, x):
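+        # PACT clips the activation to the learnable range [-alpha, alpha]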
+ out_left = paddle.nn.functional.relu(x - self.alpha)
+ out_right = paddle.nn.functional.relu(-self.alpha - x)
+ x = x - out_left + out_right
+ return x
+
+
+quant_config = {
+ # weight preprocess type, default is None and no preprocessing is performed.
+ 'weight_preprocess_type': None,
+ # activation preprocess type, default is None and no preprocessing is performed.
+ 'activation_preprocess_type': None,
+ # weight quantize type, default is 'channel_wise_abs_max'
+ 'weight_quantize_type': 'channel_wise_abs_max',
+ # activation quantize type, default is 'moving_average_abs_max'
+ 'activation_quantize_type': 'moving_average_abs_max',
+ # weight quantize bit num, default is 8
+ 'weight_bits': 8,
+ # activation quantize bit num, default is 8
+ 'activation_bits': 8,
+ # data type after quantization, such as 'uint8', 'int8', etc. default is 'int8'
+ 'dtype': 'int8',
+ # window size for 'range_abs_max' quantization. default is 10000
+ 'window_size': 10000,
+ # The decay coefficient of moving average, default is 0.9
+ 'moving_rate': 0.9,
+ # for dygraph quantization, layers of type in quantizable_layer_type will be quantized
+ 'quantizable_layer_type': ['Conv2D', 'Linear'],
+}
diff --git a/test_tipc/supplementary/test_tipc/common_func.sh b/test_tipc/supplementary/test_tipc/common_func.sh
new file mode 100644
index 0000000000000000000000000000000000000000..e2ff5c4d75845ba4c77ff890725b27db48a450fe
--- /dev/null
+++ b/test_tipc/supplementary/test_tipc/common_func.sh
@@ -0,0 +1,65 @@
+#!/bin/bash
+
+function func_parser_key(){
+ strs=$1
+ IFS=":"
+ array=(${strs})
+ tmp=${array[0]}
+ echo ${tmp}
+}
+
+function func_parser_value(){
+ strs=$1
+ IFS=":"
+ array=(${strs})
+ tmp=${array[1]}
+ echo ${tmp}
+}
+
+function func_set_params(){
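+    # echo "key=value", or just a space when the key or value is null/empty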
+ key=$1
+ value=$2
+ if [ ${key}x = "null"x ];then
+ echo " "
+ elif [[ ${value} = "null" ]] || [[ ${value} = " " ]] || [ ${#value} -le 0 ];then
+ echo " "
+ else
+ echo "${key}=${value}"
+ fi
+}
+
+function func_parser_params(){
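+    # parse a line like "key:modeA=valueA|modeB=valueB" and echo the value whose mode matches $MODE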
+ strs=$1
+ MODE=$2
+ IFS=":"
+ array=(${strs})
+ key=${array[0]}
+ tmp=${array[1]}
+ IFS="|"
+ res=""
+ for _params in ${tmp[*]}; do
+ IFS="="
+ array=(${_params})
+ mode=${array[0]}
+ value=${array[1]}
+ if [[ ${mode} = ${MODE} ]]; then
+ IFS="|"
+ #echo $(func_set_params "${mode}" "${value}")
+ echo $value
+ break
+ fi
+ IFS="|"
+ done
+ echo ${res}
+}
+
+function status_check(){
+ last_status=$1 # the exit code
+ run_command=$2
+ run_log=$3
+ if [ $last_status -eq 0 ]; then
+ echo -e "\033[33m Run successfully with command - ${run_command}! \033[0m" | tee -a ${run_log}
+ else
+ echo -e "\033[33m Run failed with command - ${run_command}! \033[0m" | tee -a ${run_log}
+ fi
+}
\ No newline at end of file
diff --git a/test_tipc/supplementary/test_tipc/test_train_python.sh b/test_tipc/supplementary/test_tipc/test_train_python.sh
new file mode 100644
index 0000000000000000000000000000000000000000..f922b57bba7de97d3631524c6f1bd1fac7395e76
--- /dev/null
+++ b/test_tipc/supplementary/test_tipc/test_train_python.sh
@@ -0,0 +1,117 @@
+#!/bin/bash
+source test_tipc/common_func.sh
+
+FILENAME=$1
+# MODE must be one of ['lite_train_lite_infer', 'whole_train_whole_infer']
+MODE=$2
+
+dataline=$(awk 'NR==1, NR==51{print}' $FILENAME)
+
+# parser params
+IFS=$'\n'
+lines=(${dataline})
+
+model_name=$(func_parser_value "${lines[1]}")
+python=$(func_parser_value "${lines[2]}")
+gpu_list=$(func_parser_value "${lines[3]}")
+train_use_gpu_key=$(func_parser_key "${lines[4]}")
+train_use_gpu_value=$(func_parser_value "${lines[4]}")
+autocast_list=$(func_parser_value "${lines[5]}")
+autocast_key=$(func_parser_key "${lines[5]}")
+epoch_key=$(func_parser_key "${lines[6]}")
+epoch_num=$(func_parser_params "${lines[6]}" "${MODE}")
+save_model_key=$(func_parser_key "${lines[7]}")
+train_batch_key=$(func_parser_key "${lines[8]}")
+train_batch_value=$(func_parser_params "${lines[8]}" "${MODE}")
+pretrain_model_key=$(func_parser_key "${lines[9]}")
+pretrain_model_value=$(func_parser_value "${lines[9]}")
+checkpoints_key=$(func_parser_key "${lines[10]}")
+checkpoints_value=$(func_parser_value "${lines[10]}")
+use_custom_key=$(func_parser_key "${lines[11]}")
+use_custom_list=$(func_parser_value "${lines[11]}")
+model_type_key=$(func_parser_key "${lines[12]}")
+model_type_list=$(func_parser_value "${lines[12]}")
+use_share_conv_key=$(func_parser_key "${lines[13]}")
+use_share_conv_list=$(func_parser_value "${lines[13]}")
+run_train_py=$(func_parser_value "${lines[14]}")
+
+
+LOG_PATH="./test_tipc/extra_output"
+mkdir -p ${LOG_PATH}
+status_log="${LOG_PATH}/results_python.log"
+
+if [ ${MODE} = "lite_train_lite_infer" ] || [ ${MODE} = "whole_train_whole_infer" ]; then
+ IFS="|"
+ export Count=0
+ USE_GPU_KEY=(${train_use_gpu_value})
+    # select cpu, gpu, or distributed training
+ for gpu in ${gpu_list[*]}; do
+ train_use_gpu=${USE_GPU_KEY[Count]}
+ Count=$(($Count + 1))
+ ips=""
+ if [ ${gpu} = "-1" ];then
+ env=""
+ elif [ ${#gpu} -le 1 ];then
+ env="export CUDA_VISIBLE_DEVICES=${gpu}"
+ eval ${env}
+ elif [ ${#gpu} -le 15 ];then
+ IFS=","
+ array=(${gpu})
+ env="export CUDA_VISIBLE_DEVICES=${array[0]}"
+ IFS="|"
+ else
+ IFS=";"
+ array=(${gpu})
+ ips=${array[0]}
+ gpu=${array[1]}
+ IFS="|"
+ env=" "
+ fi
+ for autocast in ${autocast_list[*]}; do
+ # set amp
+ if [ ${autocast} = "amp" ]; then
+ set_amp_config="AMP.use_amp=True"
+ else
+ set_amp_config=" "
+ fi
+
+ if [ ${run_train_py} = "null" ]; then
+ continue
+ fi
+
+ set_autocast=$(func_set_params "${autocast_key}" "${autocast}")
+ set_epoch=$(func_set_params "${epoch_key}" "${epoch_num}")
+ set_pretrain=$(func_set_params "${pretrain_model_key}" "${pretrain_model_value}")
+ set_checkpoints=$(func_set_params "${checkpoints_key}" "${checkpoints_value}")
+ set_batchsize=$(func_set_params "${train_batch_key}" "${train_batch_value}")
+ set_use_gpu=$(func_set_params "${train_use_gpu_key}" "${train_use_gpu}")
+
+ for custom_op in ${use_custom_list[*]}; do
+ for model_type in ${model_type_list[*]}; do
+ for share_conv in ${use_share_conv_list[*]}; do
+ set_use_custom_op=$(func_set_params "${use_custom_key}" "${custom_op}")
+ set_model_type=$(func_set_params "${model_type_key}" "${model_type}")
+ set_use_share_conv=$(func_set_params "${use_share_conv_key}" "${share_conv}")
+
+ set_save_model=$(func_set_params "${save_model_key}" "${save_log}")
+ if [ ${#gpu} -le 2 ];then # train with cpu or single gpu
+ cmd="${python} ${run_train_py} ${set_use_gpu} ${set_save_model} ${set_epoch} ${set_pretrain} ${set_checkpoints} ${set_autocast} ${set_batchsize} ${set_use_custom_op} ${set_model_type} ${set_use_share_conv} ${set_amp_config}"
+ elif [ ${#ips} -le 26 ];then # train with multi-gpu
+ cmd="${python} -m paddle.distributed.launch --gpus=${gpu} ${run_train_py} ${set_use_gpu} ${set_save_model} ${set_epoch} ${set_pretrain} ${set_checkpoints} ${set_autocast} ${set_batchsize} ${set_use_custom_op} ${set_model_type} ${set_use_share_conv} ${set_amp_config}"
+ fi
+
+ # run train
+ eval "unset CUDA_VISIBLE_DEVICES"
+ # echo $cmd
+ eval $cmd
+ status_check $? "${cmd}" "${status_log}"
+ done
+ done
+ done
+ done
+ done
+fi
+
+
+
+
diff --git a/test_tipc/supplementary/test_tipc/tipc_train.png b/test_tipc/supplementary/test_tipc/tipc_train.png
new file mode 100644
index 0000000000000000000000000000000000000000..9ca124ebe69706cedcd59e64831e62ec0f230e23
Binary files /dev/null and b/test_tipc/supplementary/test_tipc/tipc_train.png differ
diff --git a/test_tipc/supplementary/test_tipc/train_infer_python.txt b/test_tipc/supplementary/test_tipc/train_infer_python.txt
new file mode 100644
index 0000000000000000000000000000000000000000..99028c0c49d16c53d18528ad761b68f39ba4f151
--- /dev/null
+++ b/test_tipc/supplementary/test_tipc/train_infer_python.txt
@@ -0,0 +1,17 @@
+===========================train_params===========================
+model_name:ch_PPOCRv2_det
+python:python3.7
+gpu_list:0|0,1
+use_gpu:True|True
+AMP.use_amp:True|False
+epoch:lite_train_lite_infer=2|whole_train_whole_infer=1000
+save_model_dir:./output/
+TRAIN.batch_size:lite_train_lite_infer=1280|whole_train_whole_infer=1280
+pretrained_model:null
+checkpoints:null
+use_custom_relu:False|True
+model_type:cls|cls_distill|cls_distill_multiopt
+MODEL.siamese:False|True
+norm_train:train.py -c mv3_large_x0_5.yml -o
+quant_train:False
+prune_train:False
diff --git a/test_tipc/supplementary/test_tipc/train_infer_python_FPGM.txt b/test_tipc/supplementary/test_tipc/train_infer_python_FPGM.txt
new file mode 100644
index 0000000000000000000000000000000000000000..4c2e28b91e24b34d1bded93cddebe83e0874ae29
--- /dev/null
+++ b/test_tipc/supplementary/test_tipc/train_infer_python_FPGM.txt
@@ -0,0 +1,17 @@
+===========================train_params===========================
+model_name:ch_PPOCRv2_det
+python:python3.7
+gpu_list:0|0,1
+use_gpu:True|True
+AMP.use_amp:True|False
+epoch:lite_train_lite_infer=20|whole_train_whole_infer=1000
+save_model_dir:./output/
+TRAIN.batch_size:lite_train_lite_infer=2|whole_train_whole_infer=4
+pretrained_model:null
+checkpoints:null
+use_custom_relu:False|True
+model_type:cls|cls_distill|cls_distill_multiopt
+MODEL.siamese:False|True
+norm_train:train.py -c mv3_large_x0_5.yml -o prune_train=True
+quant_train:False
+prune_train:False
diff --git a/test_tipc/supplementary/test_tipc/train_infer_python_PACT.txt b/test_tipc/supplementary/test_tipc/train_infer_python_PACT.txt
new file mode 100644
index 0000000000000000000000000000000000000000..079cddf878712b2ba3af3a19f97be3bb5a0896da
--- /dev/null
+++ b/test_tipc/supplementary/test_tipc/train_infer_python_PACT.txt
@@ -0,0 +1,17 @@
+===========================train_params===========================
+model_name:ch_PPOCRv2_det
+python:python3.7
+gpu_list:0|0,1
+use_gpu:True|True
+AMP.use_amp:True|False
+epoch:lite_train_lite_infer=20|whole_train_whole_infer=1000
+save_model_dir:./output/
+TRAIN.batch_size:lite_train_lite_infer=2|whole_train_whole_infer=4
+pretrained_model:null
+checkpoints:null
+use_custom_relu:False|True
+model_type:cls|cls_distill|cls_distill_multiopt
+MODEL.siamese:False|True
+norm_train:train.py -c mv3_large_x0_5.yml -o quant_train=True
+quant_train:False
+prune_train:False
diff --git a/test_tipc/supplementary/train.py b/test_tipc/supplementary/train.py
new file mode 100644
index 0000000000000000000000000000000000000000..e632d1d1803a85144bc750c3ff6ff51b1eb65973
--- /dev/null
+++ b/test_tipc/supplementary/train.py
@@ -0,0 +1,474 @@
+import paddle
+import numpy as np
+import os
+import errno
+import paddle.nn as nn
+import paddle.distributed as dist
+dist.get_world_size()
+dist.init_parallel_env()
+
+from loss import build_loss, LossDistill, DMLLoss, KLJSLoss
+from optimizer import create_optimizer
+from data_loader import build_dataloader
+from metric import create_metric
+from mv3 import MobileNetV3_large_x0_5, distillmv3_large_x0_5, build_model
+from config import preprocess
+import time
+
+from paddleslim.dygraph.quant import QAT
+from slim.slim_quant import PACT, quant_config
+from slim.slim_fpgm import prune_model
+from utils import load_model
+
+
+def _mkdir_if_not_exist(path, logger):
+ """
+ mkdir if not exists, ignore the exception when multiprocess mkdir together
+ """
+ if not os.path.exists(path):
+ try:
+ os.makedirs(path)
+ except OSError as e:
+ if e.errno == errno.EEXIST and os.path.isdir(path):
+                logger.warning(
+                    'directory {} already exists, created by another process'.format(
+                        path))
+ else:
+ raise OSError('Failed to mkdir {}'.format(path))
+
+
+def save_model(model,
+ optimizer,
+ model_path,
+ logger,
+ is_best=False,
+ prefix='ppocr',
+ **kwargs):
+ """
+ save model to the target path
+ """
+ _mkdir_if_not_exist(model_path, logger)
+ model_prefix = os.path.join(model_path, prefix)
+ paddle.save(model.state_dict(), model_prefix + '.pdparams')
+ if type(optimizer) is list:
+ paddle.save(optimizer[0].state_dict(), model_prefix + '.pdopt')
+ paddle.save(optimizer[1].state_dict(), model_prefix + "_1" + '.pdopt')
+
+ else:
+ paddle.save(optimizer.state_dict(), model_prefix + '.pdopt')
+
+ # # save metric and config
+ # with open(model_prefix + '.states', 'wb') as f:
+ # pickle.dump(kwargs, f, protocol=2)
+ if is_best:
+ logger.info('save best model is to {}'.format(model_prefix))
+ else:
+ logger.info("save model in {}".format(model_prefix))
+
+
+def amp_scaler(config):
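+    # build a paddle.amp.GradScaler (and set AMP-related flags) when AMP is enabled in the config; return None otherwise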
+ if 'AMP' in config and config['AMP']['use_amp'] is True:
+ AMP_RELATED_FLAGS_SETTING = {
+ 'FLAGS_cudnn_batchnorm_spatial_persistent': 1,
+ 'FLAGS_max_inplace_grad_add': 8,
+ }
+ paddle.fluid.set_flags(AMP_RELATED_FLAGS_SETTING)
+ scale_loss = config["AMP"].get("scale_loss", 1.0)
+ use_dynamic_loss_scaling = config["AMP"].get("use_dynamic_loss_scaling",
+ False)
+ scaler = paddle.amp.GradScaler(
+ init_loss_scaling=scale_loss,
+ use_dynamic_loss_scaling=use_dynamic_loss_scaling)
+ return scaler
+ else:
+ return None
+
+
+def set_seed(seed):
+ paddle.seed(seed)
+ np.random.seed(seed)
+
+
+def train(config, scaler=None):
+ EPOCH = config['epoch']
+ topk = config['topk']
+
+ batch_size = config['TRAIN']['batch_size']
+ num_workers = config['TRAIN']['num_workers']
+ train_loader = build_dataloader(
+ 'train', batch_size=batch_size, num_workers=num_workers)
+
+ # build metric
+ metric_func = create_metric
+
+ # build model
+ # model = MobileNetV3_large_x0_5(class_dim=100)
+ model = build_model(config)
+
+ # build_optimizer
+ optimizer, lr_scheduler = create_optimizer(
+ config, parameter_list=model.parameters())
+
+ # load model
+ pre_best_model_dict = load_model(config, model, optimizer)
+ if len(pre_best_model_dict) > 0:
+        pre_str = 'The metrics of the loaded model are as follows: {}'.format(', '.join(
+ ['{}: {}'.format(k, v) for k, v in pre_best_model_dict.items()]))
+ logger.info(pre_str)
+
+ # about slim prune and quant
+ if "quant_train" in config and config['quant_train'] is True:
+ quanter = QAT(config=quant_config, act_preprocess=PACT)
+ quanter.quantize(model)
+ elif "prune_train" in config and config['prune_train'] is True:
+ model = prune_model(model, [1, 3, 32, 32], 0.1)
+ else:
+ pass
+
+ # distribution
+ model.train()
+ model = paddle.DataParallel(model)
+ # build loss function
+ loss_func = build_loss(config)
+
+ data_num = len(train_loader)
+
+ best_acc = {}
+ for epoch in range(EPOCH):
+ st = time.time()
+ for idx, data in enumerate(train_loader):
+ img_batch, label = data
+ img_batch = paddle.transpose(img_batch, [0, 3, 1, 2])
+ label = paddle.unsqueeze(label, -1)
+
+ if scaler is not None:
+ with paddle.amp.auto_cast():
+ outs = model(img_batch)
+ else:
+ outs = model(img_batch)
+
+ # cal metric
+ acc = metric_func(outs, label)
+
+ # cal loss
+ avg_loss = loss_func(outs, label)
+
+ if scaler is None:
+ # backward
+ avg_loss.backward()
+ optimizer.step()
+ optimizer.clear_grad()
+ else:
+ scaled_avg_loss = scaler.scale(avg_loss)
+ scaled_avg_loss.backward()
+ scaler.minimize(optimizer, scaled_avg_loss)
+
+ if not isinstance(lr_scheduler, float):
+ lr_scheduler.step()
+
+ if idx % 10 == 0:
+ et = time.time()
+ strs = f"epoch: [{epoch}/{EPOCH}], iter: [{idx}/{data_num}], "
+ strs += f"loss: {avg_loss.numpy()[0]}"
+ strs += f", acc_topk1: {acc['top1'].numpy()[0]}, acc_top5: {acc['top5'].numpy()[0]}"
+ strs += f", batch_time: {round(et-st, 4)} s"
+ logger.info(strs)
+ st = time.time()
+
+ if epoch % 10 == 0:
+ acc = eval(config, model)
+ if len(best_acc) < 1 or acc['top5'].numpy()[0] > best_acc['top5']:
+ best_acc = acc
+ best_acc['epoch'] = epoch
+ is_best = True
+ else:
+ is_best = False
+ logger.info(
+ f"The best acc: acc_topk1: {best_acc['top1'].numpy()[0]}, acc_top5: {best_acc['top5'].numpy()[0]}, best_epoch: {best_acc['epoch']}"
+ )
+ save_model(
+ model,
+ optimizer,
+ config['save_model_dir'],
+ logger,
+ is_best,
+ prefix="cls")
+
+
+def train_distill(config, scaler=None):
+ EPOCH = config['epoch']
+ topk = config['topk']
+
+ batch_size = config['TRAIN']['batch_size']
+ num_workers = config['TRAIN']['num_workers']
+ train_loader = build_dataloader(
+ 'train', batch_size=batch_size, num_workers=num_workers)
+
+ # build metric
+ metric_func = create_metric
+
+ # model = distillmv3_large_x0_5(class_dim=100)
+ model = build_model(config)
+
+ # pact quant train
+ if "quant_train" in config and config['quant_train'] is True:
+ quanter = QAT(config=quant_config, act_preprocess=PACT)
+ quanter.quantize(model)
+ elif "prune_train" in config and config['prune_train'] is True:
+ model = prune_model(model, [1, 3, 32, 32], 0.1)
+ else:
+ pass
+
+ # build_optimizer
+ optimizer, lr_scheduler = create_optimizer(
+ config, parameter_list=model.parameters())
+
+ # load model
+ pre_best_model_dict = load_model(config, model, optimizer)
+ if len(pre_best_model_dict) > 0:
+        pre_str = 'The metrics of the loaded model are as follows: {}'.format(', '.join(
+ ['{}: {}'.format(k, v) for k, v in pre_best_model_dict.items()]))
+ logger.info(pre_str)
+
+ model.train()
+ model = paddle.DataParallel(model)
+
+ # build loss function
+ loss_func_distill = LossDistill(model_name_list=['student', 'student1'])
+ loss_func_dml = DMLLoss(model_name_pairs=['student', 'student1'])
+ loss_func_js = KLJSLoss(mode='js')
+
+ data_num = len(train_loader)
+
+ best_acc = {}
+ for epoch in range(EPOCH):
+ st = time.time()
+ for idx, data in enumerate(train_loader):
+ img_batch, label = data
+ img_batch = paddle.transpose(img_batch, [0, 3, 1, 2])
+ label = paddle.unsqueeze(label, -1)
+ if scaler is not None:
+ with paddle.amp.auto_cast():
+ outs = model(img_batch)
+ else:
+ outs = model(img_batch)
+
+ # cal metric
+ acc = metric_func(outs['student'], label)
+
+ # cal loss
+ avg_loss = loss_func_distill(outs, label)['student'] + \
+ loss_func_distill(outs, label)['student1'] + \
+ loss_func_dml(outs, label)['student_student1']
+
+ # backward
+ if scaler is None:
+ avg_loss.backward()
+ optimizer.step()
+ optimizer.clear_grad()
+ else:
+ scaled_avg_loss = scaler.scale(avg_loss)
+ scaled_avg_loss.backward()
+ scaler.minimize(optimizer, scaled_avg_loss)
+
+ if not isinstance(lr_scheduler, float):
+ lr_scheduler.step()
+
+ if idx % 10 == 0:
+ et = time.time()
+ strs = f"epoch: [{epoch}/{EPOCH}], iter: [{idx}/{data_num}], "
+ strs += f"loss: {avg_loss.numpy()[0]}"
+ strs += f", acc_topk1: {acc['top1'].numpy()[0]}, acc_top5: {acc['top5'].numpy()[0]}"
+ strs += f", batch_time: {round(et-st, 4)} s"
+ logger.info(strs)
+ st = time.time()
+
+ if epoch % 10 == 0:
+ acc = eval(config, model._layers.student)
+ if len(best_acc) < 1 or acc['top5'].numpy()[0] > best_acc['top5']:
+ best_acc = acc
+ best_acc['epoch'] = epoch
+ is_best = True
+ else:
+ is_best = False
+ logger.info(
+ f"The best acc: acc_topk1: {best_acc['top1'].numpy()[0]}, acc_top5: {best_acc['top5'].numpy()[0]}, best_epoch: {best_acc['epoch']}"
+ )
+
+ save_model(
+ model,
+ optimizer,
+ config['save_model_dir'],
+ logger,
+ is_best,
+ prefix="cls_distill")
+
+
+def train_distill_multiopt(config, scaler=None):
+ EPOCH = config['epoch']
+ topk = config['topk']
+
+ batch_size = config['TRAIN']['batch_size']
+ num_workers = config['TRAIN']['num_workers']
+ train_loader = build_dataloader(
+ 'train', batch_size=batch_size, num_workers=num_workers)
+
+ # build metric
+ metric_func = create_metric
+
+ # model = distillmv3_large_x0_5(class_dim=100)
+ model = build_model(config)
+
+ # build_optimizer
+ optimizer, lr_scheduler = create_optimizer(
+ config, parameter_list=model.student.parameters())
+ optimizer1, lr_scheduler1 = create_optimizer(
+ config, parameter_list=model.student1.parameters())
+
+ # load model
+ pre_best_model_dict = load_model(config, model, optimizer)
+ if len(pre_best_model_dict) > 0:
+        pre_str = 'The metrics of the loaded model are as follows: {}'.format(', '.join(
+ ['{}: {}'.format(k, v) for k, v in pre_best_model_dict.items()]))
+ logger.info(pre_str)
+
+ # quant train
+ if "quant_train" in config and config['quant_train'] is True:
+ quanter = QAT(config=quant_config, act_preprocess=PACT)
+ quanter.quantize(model)
+ elif "prune_train" in config and config['prune_train'] is True:
+ model = prune_model(model, [1, 3, 32, 32], 0.1)
+ else:
+ pass
+
+ model.train()
+
+ model = paddle.DataParallel(model)
+
+ # build loss function
+ loss_func_distill = LossDistill(model_name_list=['student', 'student1'])
+ loss_func_dml = DMLLoss(model_name_pairs=['student', 'student1'])
+ loss_func_js = KLJSLoss(mode='js')
+
+ data_num = len(train_loader)
+ best_acc = {}
+ for epoch in range(EPOCH):
+ st = time.time()
+ for idx, data in enumerate(train_loader):
+ img_batch, label = data
+ img_batch = paddle.transpose(img_batch, [0, 3, 1, 2])
+ label = paddle.unsqueeze(label, -1)
+
+ if scaler is not None:
+ with paddle.amp.auto_cast():
+ outs = model(img_batch)
+ else:
+ outs = model(img_batch)
+
+ # cal metric
+ acc = metric_func(outs['student'], label)
+
+ # cal loss
+ avg_loss = loss_func_distill(outs,
+ label)['student'] + loss_func_dml(
+ outs, label)['student_student1']
+ avg_loss1 = loss_func_distill(outs,
+ label)['student1'] + loss_func_dml(
+ outs, label)['student_student1']
+
+ if scaler is None:
+ # backward
+ avg_loss.backward(retain_graph=True)
+ optimizer.step()
+ optimizer.clear_grad()
+
+ avg_loss1.backward()
+ optimizer1.step()
+ optimizer1.clear_grad()
+ else:
+ scaled_avg_loss = scaler.scale(avg_loss)
+ scaled_avg_loss.backward()
+ scaler.minimize(optimizer, scaled_avg_loss)
+
+ scaled_avg_loss = scaler.scale(avg_loss1)
+ scaled_avg_loss.backward()
+ scaler.minimize(optimizer1, scaled_avg_loss)
+
+ if not isinstance(lr_scheduler, float):
+ lr_scheduler.step()
+ if not isinstance(lr_scheduler1, float):
+ lr_scheduler1.step()
+
+ if idx % 10 == 0:
+ et = time.time()
+ strs = f"epoch: [{epoch}/{EPOCH}], iter: [{idx}/{data_num}], "
+ strs += f"loss: {avg_loss.numpy()[0]}, loss1: {avg_loss1.numpy()[0]}"
+ strs += f", acc_topk1: {acc['top1'].numpy()[0]}, acc_top5: {acc['top5'].numpy()[0]}"
+ strs += f", batch_time: {round(et-st, 4)} s"
+ logger.info(strs)
+ st = time.time()
+
+ if epoch % 10 == 0:
+ acc = eval(config, model._layers.student)
+ if len(best_acc) < 1 or acc['top5'].numpy()[0] > best_acc['top5']:
+ best_acc = acc
+ best_acc['epoch'] = epoch
+ is_best = True
+ else:
+ is_best = False
+ logger.info(
+ f"The best acc: acc_topk1: {best_acc['top1'].numpy()[0]}, acc_top5: {best_acc['top5'].numpy()[0]}, best_epoch: {best_acc['epoch']}"
+ )
+ save_model(
+ model, [optimizer, optimizer1],
+ config['save_model_dir'],
+ logger,
+ is_best,
+ prefix="cls_distill_multiopt")
+
+
+def eval(config, model):
+ batch_size = config['VALID']['batch_size']
+ num_workers = config['VALID']['num_workers']
+ valid_loader = build_dataloader(
+ 'test', batch_size=batch_size, num_workers=num_workers)
+
+ # build metric
+ metric_func = create_metric
+
+ outs = []
+ labels = []
+ for idx, data in enumerate(valid_loader):
+ img_batch, label = data
+ img_batch = paddle.transpose(img_batch, [0, 3, 1, 2])
+ label = paddle.unsqueeze(label, -1)
+ out = model(img_batch)
+
+ outs.append(out)
+ labels.append(label)
+
+ outs = paddle.concat(outs, axis=0)
+ labels = paddle.concat(labels, axis=0)
+ acc = metric_func(outs, labels)
+
+    strs = f"The metrics are as follows: acc_topk1: {acc['top1'].numpy()[0]}, acc_top5: {acc['top5'].numpy()[0]}"
+ logger.info(strs)
+ return acc
+
+
+if __name__ == "__main__":
+
+ config, logger = preprocess(is_train=False)
+
+ # AMP scaler
+ scaler = amp_scaler(config)
+
+ model_type = config['model_type']
+
+ if model_type == "cls":
+ train(config)
+ elif model_type == "cls_distill":
+ train_distill(config)
+ elif model_type == "cls_distill_multiopt":
+ train_distill_multiopt(config)
+ else:
+        raise ValueError(
+            "model_type should be one of ['cls', 'cls_distill', 'cls_distill_multiopt']")
diff --git a/test_tipc/supplementary/train.sh b/test_tipc/supplementary/train.sh
new file mode 100644
index 0000000000000000000000000000000000000000..a2c7c90ccc137f50fd3be6d2ce3a2bd081446a7e
--- /dev/null
+++ b/test_tipc/supplementary/train.sh
@@ -0,0 +1,5 @@
+# single GPU
+python3.7 train.py -c mv3_large_x0_5.yml
+
+# distribute training
+python3.7 -m paddle.distributed.launch --log_dir=./debug/ --gpus '0,1' train.py -c mv3_large_x0_5.yml
diff --git a/test_tipc/supplementary/utils.py b/test_tipc/supplementary/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..ae9ae061b93bc43dc14151c203ac8226c5e64aec
--- /dev/null
+++ b/test_tipc/supplementary/utils.py
@@ -0,0 +1,164 @@
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+import sys
+import logging
+import functools
+import pickle
+import six
+import paddle
+import paddle.distributed as dist
+
+logger_initialized = {}
+
+
+def print_dict(d, logger, delimiter=0):
+ """
+ Recursively visualize a dict and
+ indenting acrrording by the relationship of keys.
+ """
+ for k, v in sorted(d.items()):
+ if isinstance(v, dict):
+ logger.info("{}{} : ".format(delimiter * " ", str(k)))
+ print_dict(v, logger, delimiter + 4)
+ elif isinstance(v, list) and len(v) >= 1 and isinstance(v[0], dict):
+ logger.info("{}{} : ".format(delimiter * " ", str(k)))
+ for value in v:
+ print_dict(value, logger, delimiter + 4)
+ else:
+ logger.info("{}{} : {}".format(delimiter * " ", k, v))
+
+
+@functools.lru_cache()
+def get_logger(name='root', log_file=None, log_level=logging.DEBUG):
+ """Initialize and get a logger by name.
+ If the logger has not been initialized, this method will initialize the
+ logger by adding one or two handlers, otherwise the initialized logger will
+ be directly returned. During initialization, a StreamHandler will always be
+ added. If `log_file` is specified a FileHandler will also be added.
+ Args:
+ name (str): Logger name.
+ log_file (str | None): The log filename. If specified, a FileHandler
+ will be added to the logger.
+ log_level (int): The logger level. Note that only the process of
+ rank 0 is affected; other processes have their level set to
+ "Error" and are therefore silent most of the time.
+ Returns:
+ logging.Logger: The expected logger.
+ """
+ logger = logging.getLogger(name)
+ if name in logger_initialized:
+ return logger
+ for logger_name in logger_initialized:
+ if name.startswith(logger_name):
+ return logger
+
+ formatter = logging.Formatter(
+ '[%(asctime)s] %(name)s %(levelname)s: %(message)s',
+ datefmt="%Y/%m/%d %H:%M:%S")
+
+ stream_handler = logging.StreamHandler(stream=sys.stdout)
+ stream_handler.setFormatter(formatter)
+ logger.addHandler(stream_handler)
+ if log_file is not None and dist.get_rank() == 0:
+ log_file_folder = os.path.split(log_file)[0]
+ os.makedirs(log_file_folder, exist_ok=True)
+ file_handler = logging.FileHandler(log_file, 'a')
+ file_handler.setFormatter(formatter)
+ logger.addHandler(file_handler)
+ if dist.get_rank() == 0:
+ logger.setLevel(log_level)
+ else:
+ logger.setLevel(logging.ERROR)
+ logger_initialized[name] = True
+ return logger
+
+
+def load_model(config, model, optimizer=None):
+ """
+ load model from checkpoint or pretrained_model
+ """
+ logger = get_logger()
+ checkpoints = config.get('checkpoints')
+ pretrained_model = config.get('pretrained_model')
+ best_model_dict = {}
+ if checkpoints:
+ if checkpoints.endswith('.pdparams'):
+ checkpoints = checkpoints.replace('.pdparams', '')
+ assert os.path.exists(checkpoints + ".pdparams"), \
+ "The {}.pdparams does not exists!".format(checkpoints)
+
+ # load params from trained model
+ params = paddle.load(checkpoints + '.pdparams')
+ state_dict = model.state_dict()
+ new_state_dict = {}
+ for key, value in state_dict.items():
+ if key not in params:
+ logger.warning("{} not in loaded params {} !".format(
+ key, params.keys()))
+ continue
+ pre_value = params[key]
+ if list(value.shape) == list(pre_value.shape):
+ new_state_dict[key] = pre_value
+ else:
+ logger.warning(
+ "The shape of model params {} {} not matched with loaded params shape {} !".
+ format(key, value.shape, pre_value.shape))
+ model.set_state_dict(new_state_dict)
+
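+ # restore the optimizer state only when a matching .pdopt file exists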
+ if optimizer is not None:
+ if os.path.exists(checkpoints + '.pdopt'):
+ optim_dict = paddle.load(checkpoints + '.pdopt')
+ optimizer.set_state_dict(optim_dict)
+ else:
+ logger.warning(
+ "{}.pdopt is not exists, params of optimizer is not loaded".
+ format(checkpoints))
+
+ if os.path.exists(checkpoints + '.states'):
+ with open(checkpoints + '.states', 'rb') as f:
+ states_dict = pickle.load(f) if six.PY2 else pickle.load(
+ f, encoding='latin1')
+ best_model_dict = states_dict.get('best_model_dict', {})
+ if 'epoch' in states_dict:
+ best_model_dict['start_epoch'] = states_dict['epoch'] + 1
+ logger.info("resume from {}".format(checkpoints))
+ elif pretrained_model:
+ load_pretrained_params(model, pretrained_model)
+ else:
+ logger.info('train from scratch')
+ return best_model_dict
+
+
+def load_pretrained_params(model, path):
+ logger = get_logger()
+ if path.endswith('.pdparams'):
+ path = path.replace('.pdparams', '')
+ assert os.path.exists(path + ".pdparams"), \
+ "The {}.pdparams does not exists!".format(path)
+
+ params = paddle.load(path + '.pdparams')
+ state_dict = model.state_dict()
+ new_state_dict = {}
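+ # copy only the pretrained params whose names and shapes match the current model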
+ for k1 in params.keys():
+ if k1 not in state_dict.keys():
+ logger.warning("The pretrained params {} not in model".format(k1))
+ else:
+ if list(state_dict[k1].shape) == list(params[k1].shape):
+ new_state_dict[k1] = params[k1]
+ else:
+ logger.warning(
+ "The shape of model params {} {} not matched with loaded params {} {} !".
+ format(k1, state_dict[k1].shape, k1, params[k1].shape))
+ model.set_state_dict(new_state_dict)
+ logger.info("load pretrain successful from {}".format(path))
+ return model
diff --git a/test_tipc/test_train_inference_python.sh b/test_tipc/test_train_inference_python.sh
index b9bf9edf309c02fde0a679891b709deef6da9465..9bde89d78e0ee78c7b650306047b036488a3eab9 100644
--- a/test_tipc/test_train_inference_python.sh
+++ b/test_tipc/test_train_inference_python.sh
@@ -183,7 +183,7 @@ function func_inference(){
if [[ ${precision} =~ "fp16" || ${precision} =~ "int8" ]] && [ ${use_trt} = "False" ]; then
continue
fi
- if [[ ${use_trt} = "False" || ${precision} =~ "int8" ]] && [ ${_flag_quant} = "True" ]; then
+ if [[ ${use_trt} = "False" && ${precision} =~ "int8" ]] && [ ${_flag_quant} = "True" ]; then
continue
fi
for batch_size in ${batch_size_list[*]}; do
@@ -227,7 +227,12 @@ if [ ${MODE} = "whole_infer" ] || [ ${MODE} = "klquant_whole_infer" ]; then
for infer_model in ${infer_model_dir_list[*]}; do
# run export
if [ ${infer_run_exports[Count]} != "null" ];then
- save_infer_dir=$(dirname $infer_model)
+ if [ ${MODE} = "klquant_whole_infer" ]; then
+ save_infer_dir="${infer_model}_klquant"
+ fi
+ if [ ${MODE} = "whole_infer" ]; then
+ save_infer_dir="${infer_model}"
+ fi
set_export_weight=$(func_set_params "${export_weight}" "${infer_model}")
set_save_infer_key=$(func_set_params "${save_infer_key}" "${save_infer_dir}")
export_cmd="${python} ${infer_run_exports[Count]} ${set_export_weight} ${set_save_infer_key}"
diff --git a/tools/eval.py b/tools/eval.py
index 13a4a0882f5a20b47e8999042713e1623b32ff5a..3a25c2660d5558e2afa5215e275fec65f78d7c1c 100755
--- a/tools/eval.py
+++ b/tools/eval.py
@@ -61,7 +61,8 @@ def main():
else:
model_type = None
- best_model_dict = load_model(config, model)
+ best_model_dict = load_model(
+ config, model, model_type=config['Architecture']["model_type"])
if len(best_model_dict):
logger.info('metric in ckpt ***************')
for k, v in best_model_dict.items():
diff --git a/tools/export_model.py b/tools/export_model.py
index 9ed8e1b6ace89ded030c946870551c8e078d7340..695af5c8bd092ec9a0ef806f8170cc686b194b73 100755
--- a/tools/export_model.py
+++ b/tools/export_model.py
@@ -85,7 +85,7 @@ def export_single_model(model, arch_config, save_path, logger):
def main():
FLAGS = ArgsParser().parse_args()
config = load_config(FLAGS.config)
- merge_config(FLAGS.opt)
+ config = merge_config(config, FLAGS.opt)
logger = get_logger()
# build post process
diff --git a/tools/infer_vqa_token_ser.py b/tools/infer_vqa_token_ser.py
new file mode 100755
index 0000000000000000000000000000000000000000..5859c28f92085bda67627af2a10acc56cb36d932
--- /dev/null
+++ b/tools/infer_vqa_token_ser.py
@@ -0,0 +1,135 @@
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+
+import os
+import sys
+
+__dir__ = os.path.dirname(os.path.abspath(__file__))
+sys.path.append(__dir__)
+sys.path.append(os.path.abspath(os.path.join(__dir__, '..')))
+
+os.environ["FLAGS_allocator_strategy"] = 'auto_growth'
+import cv2
+import json
+import paddle
+
+from ppocr.data import create_operators, transform
+from ppocr.modeling.architectures import build_model
+from ppocr.postprocess import build_post_process
+from ppocr.utils.save_load import load_model
+from ppocr.utils.visual import draw_ser_results
+from ppocr.utils.utility import get_image_file_list, load_vqa_bio_label_maps
+import tools.program as program
+
+
+def to_tensor(data):
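+ # gather the tensor-convertible fields (ndarray / Tensor / Number) and convert
+ # each into a paddle tensor with a leading batch dimension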
+ import numbers
+ from collections import defaultdict
+ data_dict = defaultdict(list)
+ to_tensor_idxs = []
+ for idx, v in enumerate(data):
+ if isinstance(v, (np.ndarray, paddle.Tensor, numbers.Number)):
+ if idx not in to_tensor_idxs:
+ to_tensor_idxs.append(idx)
+ data_dict[idx].append(v)
+ for idx in to_tensor_idxs:
+ data_dict[idx] = paddle.to_tensor(data_dict[idx])
+ return list(data_dict.values())
+
+
+class SerPredictor(object):
+ def __init__(self, config):
+ global_config = config['Global']
+
+ # build post process
+ self.post_process_class = build_post_process(config['PostProcess'],
+ global_config)
+
+ # build model
+ self.model = build_model(config['Architecture'])
+
+ load_model(
+ config, self.model, model_type=config['Architecture']["model_type"])
+
+ from paddleocr import PaddleOCR
+
+ self.ocr_engine = PaddleOCR(use_angle_cls=False, show_log=False)
+
+ # create data ops
+ transforms = []
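+ # inject the OCR engine into the label ops and restrict KeepKeys to the fields needed at inference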
+ for op in config['Eval']['dataset']['transforms']:
+ op_name = list(op)[0]
+ if 'Label' in op_name:
+ op[op_name]['ocr_engine'] = self.ocr_engine
+ elif op_name == 'KeepKeys':
+ op[op_name]['keep_keys'] = [
+ 'input_ids', 'labels', 'bbox', 'image', 'attention_mask',
+ 'token_type_ids', 'segment_offset_id', 'ocr_info',
+ 'entities'
+ ]
+
+ transforms.append(op)
+ global_config['infer_mode'] = True
+ self.ops = create_operators(config['Eval']['dataset']['transforms'],
+ global_config)
+ self.model.eval()
+
+ def __call__(self, img_path):
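+ # preprocess the raw image, run the SER model, and post-process it with the
+ # attention_mask / segment_offset_id / ocr_info fields (indices follow the KeepKeys order above)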
+ with open(img_path, 'rb') as f:
+ img = f.read()
+ data = {'image': img}
+ batch = transform(data, self.ops)
+ batch = to_tensor(batch)
+ preds = self.model(batch)
+ post_result = self.post_process_class(
+ preds,
+ attention_masks=batch[4],
+ segment_offset_ids=batch[6],
+ ocr_infos=batch[7])
+ return post_result, batch
+
+
+if __name__ == '__main__':
+ config, device, logger, vdl_writer = program.preprocess()
+ os.makedirs(config['Global']['save_res_path'], exist_ok=True)
+
+ ser_engine = SerPredictor(config)
+
+ infer_imgs = get_image_file_list(config['Global']['infer_img'])
+ with open(
+ os.path.join(config['Global']['save_res_path'],
+ "infer_results.txt"),
+ "w",
+ encoding='utf-8') as fout:
+ for idx, img_path in enumerate(infer_imgs):
+ save_img_path = os.path.join(
+ config['Global']['save_res_path'],
+ os.path.splitext(os.path.basename(img_path))[0] + "_ser.jpg")
+ logger.info("process: [{}/{}], save result to {}".format(
+ idx, len(infer_imgs), save_img_path))
+
+ result, _ = ser_engine(img_path)
+ result = result[0]
+ fout.write(img_path + "\t" + json.dumps(
+ {
+ "ocr_info": result,
+ }, ensure_ascii=False) + "\n")
+ img_res = draw_ser_results(img_path, result)
+ cv2.imwrite(save_img_path, img_res)
diff --git a/tools/infer_vqa_token_ser_re.py b/tools/infer_vqa_token_ser_re.py
new file mode 100755
index 0000000000000000000000000000000000000000..fd62ace8aef35db168537580513139e429e88cc3
--- /dev/null
+++ b/tools/infer_vqa_token_ser_re.py
@@ -0,0 +1,199 @@
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+
+import os
+import sys
+
+__dir__ = os.path.dirname(os.path.abspath(__file__))
+sys.path.append(__dir__)
+sys.path.append(os.path.abspath(os.path.join(__dir__, '..')))
+
+os.environ["FLAGS_allocator_strategy"] = 'auto_growth'
+import cv2
+import json
+import paddle
+import paddle.distributed as dist
+
+from ppocr.data import create_operators, transform
+from ppocr.modeling.architectures import build_model
+from ppocr.postprocess import build_post_process
+from ppocr.utils.save_load import load_model
+from ppocr.utils.visual import draw_re_results
+from ppocr.utils.logging import get_logger
+from ppocr.utils.utility import get_image_file_list, load_vqa_bio_label_maps, print_dict
+from tools.program import ArgsParser, load_config, merge_config, check_gpu
+from tools.infer_vqa_token_ser import SerPredictor
+
+
+class ReArgsParser(ArgsParser):
+ def __init__(self):
+ super(ReArgsParser, self).__init__()
+ self.add_argument(
+ "-c_ser", "--config_ser", help="ser configuration file to use")
+ self.add_argument(
+ "-o_ser",
+ "--opt_ser",
+ nargs='+',
+ help="set ser configuration options ")
+
+ def parse_args(self, argv=None):
+ args = super(ReArgsParser, self).parse_args(argv)
+ assert args.config_ser is not None, \
+ "Please specify --config_ser=ser_configure_file_path."
+ args.opt_ser = self._parse_opt(args.opt_ser)
+ return args
+
+
+def make_input(ser_inputs, ser_results):
+ entities_labels = {'HEADER': 0, 'QUESTION': 1, 'ANSWER': 2}
+
+ entities = ser_inputs[8][0]
+ ser_results = ser_results[0]
+ assert len(entities) == len(ser_results)
+
+ # entities
+ start = []
+ end = []
+ label = []
+ entity_idx_dict = {}
+ for i, (res, entity) in enumerate(zip(ser_results, entities)):
+ if res['pred'] == 'O':
+ continue
+ entity_idx_dict[len(start)] = i
+ start.append(entity['start'])
+ end.append(entity['end'])
+ label.append(entities_labels[res['pred']])
+ entities = dict(start=start, end=end, label=label)
+
+ # relations
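+ # every (QUESTION, ANSWER) entity pair is proposed as a candidate relation (labels: 1=QUESTION, 2=ANSWER)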
+ head = []
+ tail = []
+ for i in range(len(entities["label"])):
+ for j in range(len(entities["label"])):
+ if entities["label"][i] == 1 and entities["label"][j] == 2:
+ head.append(i)
+ tail.append(j)
+
+ relations = dict(head=head, tail=tail)
+
+ batch_size = ser_inputs[0].shape[0]
+ entities_batch = []
+ relations_batch = []
+ entity_idx_dict_batch = []
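+ # replicate the entity / relation proposals for every sample in the batch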
+ for b in range(batch_size):
+ entities_batch.append(entities)
+ relations_batch.append(relations)
+ entity_idx_dict_batch.append(entity_idx_dict)
+
+ ser_inputs[8] = entities_batch
+ ser_inputs.append(relations_batch)
+ # remove ocr_info, segment_offset_id and label from the ser inputs
+ ser_inputs.pop(7)
+ ser_inputs.pop(6)
+ ser_inputs.pop(1)
+ return ser_inputs, entity_idx_dict_batch
+
+
+class SerRePredictor(object):
+ def __init__(self, config, ser_config):
+ self.ser_engine = SerPredictor(ser_config)
+
+ # init re model
+ global_config = config['Global']
+
+ # build post process
+ self.post_process_class = build_post_process(config['PostProcess'],
+ global_config)
+
+ # build model
+ self.model = build_model(config['Architecture'])
+
+ load_model(
+ config, self.model, model_type=config['Architecture']["model_type"])
+
+ self.model.eval()
+
+ def __call__(self, img_path):
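+ # run SER first, then convert its outputs into RE model inputs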
+ ser_results, ser_inputs = self.ser_engine(img_path)
+ paddle.save(ser_inputs, 'ser_inputs.npy')
+ paddle.save(ser_results, 'ser_results.npy')
+ re_input, entity_idx_dict_batch = make_input(ser_inputs, ser_results)
+ preds = self.model(re_input)
+ post_result = self.post_process_class(
+ preds,
+ ser_results=ser_results,
+ entity_idx_dict_batch=entity_idx_dict_batch)
+ return post_result
+
+
+def preprocess():
+ FLAGS = ReArgsParser().parse_args()
+ config = load_config(FLAGS.config)
+ config = merge_config(config, FLAGS.opt)
+
+ ser_config = load_config(FLAGS.config_ser)
+ ser_config = merge_config(ser_config, FLAGS.opt_ser)
+
+ logger = get_logger(name='root')
+
+ # check whether use_gpu=True is set with a CPU-only paddlepaddle build
+ use_gpu = config['Global']['use_gpu']
+ check_gpu(use_gpu)
+
+ device = 'gpu:{}'.format(dist.ParallelEnv().dev_id) if use_gpu else 'cpu'
+ device = paddle.set_device(device)
+
+ logger.info('{} re config {}'.format('*' * 10, '*' * 10))
+ print_dict(config, logger)
+ logger.info('\n')
+ logger.info('{} ser config {}'.format('*' * 10, '*' * 10))
+ print_dict(ser_config, logger)
+ logger.info('infer with paddle {} and device {}'.format(paddle.__version__,
+ device))
+ return config, ser_config, device, logger
+
+
+if __name__ == '__main__':
+ config, ser_config, device, logger = preprocess()
+ os.makedirs(config['Global']['save_res_path'], exist_ok=True)
+
+ ser_re_engine = SerRePredictor(config, ser_config)
+
+ infer_imgs = get_image_file_list(config['Global']['infer_img'])
+ with open(
+ os.path.join(config['Global']['save_res_path'],
+ "infer_results.txt"),
+ "w",
+ encoding='utf-8') as fout:
+ for idx, img_path in enumerate(infer_imgs):
+ save_img_path = os.path.join(
+ config['Global']['save_res_path'],
+ os.path.splitext(os.path.basename(img_path))[0] + "_ser_re.jpg")
+ logger.info("process: [{}/{}], save result to {}".format(
+ idx, len(infer_imgs), save_img_path))
+
+ result = ser_re_engine(img_path)
+ result = result[0]
+ fout.write(img_path + "\t" + json.dumps(
+ {
+ "ser_resule": result,
+ }, ensure_ascii=False) + "\n")
+ img_res = draw_re_results(img_path, result)
+ cv2.imwrite(save_img_path, img_res)
diff --git a/tools/program.py b/tools/program.py
index 333e8ed9770cad08ba5e9aa47edec850a74a1808..10299940d61dd0c7b6df770e7441d3c6551954a9 100755
--- a/tools/program.py
+++ b/tools/program.py
@@ -69,24 +69,6 @@ class ArgsParser(ArgumentParser):
return config
-class AttrDict(dict):
- """Single level attribute dict, NOT recursive"""
-
- def __init__(self, **kwargs):
- super(AttrDict, self).__init__()
- super(AttrDict, self).update(kwargs)
-
- def __getattr__(self, key):
- if key in self:
- return self[key]
- raise AttributeError("object has no attribute '{}'".format(key))
-
-
-global_config = AttrDict()
-
-default_config = {'Global': {'debug': False, }}
-
-
def load_config(file_path):
"""
Load config from yml/yaml file.
@@ -94,38 +76,38 @@ def load_config(file_path):
file_path (str): Path of the config file to be loaded.
Returns: global config
"""
- merge_config(default_config)
_, ext = os.path.splitext(file_path)
assert ext in ['.yml', '.yaml'], "only support yaml files for now"
- merge_config(yaml.load(open(file_path, 'rb'), Loader=yaml.Loader))
- return global_config
+ config = yaml.load(open(file_path, 'rb'), Loader=yaml.Loader)
+ return config
-def merge_config(config):
+def merge_config(config, opts):
"""
Merge config into global config.
Args:
config (dict): Config to be merged.
Returns: global config
"""
- for key, value in config.items():
+ for key, value in opts.items():
if "." not in key:
- if isinstance(value, dict) and key in global_config:
- global_config[key].update(value)
+ if isinstance(value, dict) and key in config:
+ config[key].update(value)
else:
- global_config[key] = value
+ config[key] = value
else:
sub_keys = key.split('.')
assert (
- sub_keys[0] in global_config
+ sub_keys[0] in config
), "the sub_keys can only be one of global_config: {}, but get: {}, please check your running command".format(
- global_config.keys(), sub_keys[0])
- cur = global_config[sub_keys[0]]
+ config.keys(), sub_keys[0])
+ cur = config[sub_keys[0]]
for idx, sub_key in enumerate(sub_keys[1:]):
if idx == len(sub_keys) - 2:
cur[sub_key] = value
else:
cur = cur[sub_key]
+ return config
def check_gpu(use_gpu):
@@ -204,20 +186,24 @@ def train(config,
model_type = None
algorithm = config['Architecture']['algorithm']
- if 'start_epoch' in best_model_dict:
- start_epoch = best_model_dict['start_epoch']
- else:
- start_epoch = 1
+ start_epoch = best_model_dict[
+ 'start_epoch'] if 'start_epoch' in best_model_dict else 1
+
+ train_reader_cost = 0.0
+ train_run_cost = 0.0
+ total_samples = 0
+ reader_start = time.time()
+
+ max_iter = len(train_dataloader) - 1 if platform.system(
+ ) == "Windows" else len(train_dataloader)
for epoch in range(start_epoch, epoch_num + 1):
- train_dataloader = build_dataloader(
- config, 'Train', device, logger, seed=epoch)
- train_reader_cost = 0.0
- train_run_cost = 0.0
- total_samples = 0
- reader_start = time.time()
- max_iter = len(train_dataloader) - 1 if platform.system(
- ) == "Windows" else len(train_dataloader)
+ if train_dataloader.dataset.need_reset:
+ train_dataloader = build_dataloader(
+ config, 'Train', device, logger, seed=epoch)
+ max_iter = len(train_dataloader) - 1 if platform.system(
+ ) == "Windows" else len(train_dataloader)
+
for idx, batch in enumerate(train_dataloader):
profiler.add_profiler_step(profiler_options)
train_reader_cost += time.time() - reader_start
@@ -239,10 +225,11 @@ def train(config,
else:
if model_type == 'table' or extra_input:
preds = model(images, data=batch[1:])
- elif model_type == "kie":
+ elif model_type in ["kie", 'vqa']:
preds = model(batch)
else:
preds = model(images)
+
loss = loss_class(preds, batch)
avg_loss = loss['loss']
@@ -256,6 +243,7 @@ def train(config,
optimizer.clear_grad()
train_run_cost += time.time() - train_start
+ global_step += 1
total_samples += len(images)
if not isinstance(lr_scheduler, float):
@@ -285,12 +273,13 @@ def train(config,
(global_step > 0 and global_step % print_batch_step == 0) or
(idx >= len(train_dataloader) - 1)):
logs = train_stats.log()
- strs = 'epoch: [{}/{}], iter: {}, {}, reader_cost: {:.5f} s, batch_cost: {:.5f} s, samples: {}, ips: {:.5f}'.format(
+ strs = 'epoch: [{}/{}], global_step: {}, {}, avg_reader_cost: {:.5f} s, avg_batch_cost: {:.5f} s, avg_samples: {}, ips: {:.5f}'.format(
epoch, epoch_num, global_step, logs, train_reader_cost /
print_batch_step, (train_reader_cost + train_run_cost) /
- print_batch_step, total_samples,
+ print_batch_step, total_samples / print_batch_step,
total_samples / (train_reader_cost + train_run_cost))
logger.info(strs)
+
train_reader_cost = 0.0
train_run_cost = 0.0
total_samples = 0
@@ -330,6 +319,7 @@ def train(config,
optimizer,
save_model_dir,
logger,
+ config,
is_best=True,
prefix='best_accuracy',
best_model_dict=best_model_dict,
@@ -344,8 +334,7 @@ def train(config,
vdl_writer.add_scalar('EVAL/best_{}'.format(main_indicator),
best_model_dict[main_indicator],
global_step)
- global_step += 1
- optimizer.clear_grad()
+
reader_start = time.time()
if dist.get_rank() == 0:
save_model(
@@ -353,6 +342,7 @@ def train(config,
optimizer,
save_model_dir,
logger,
+ config,
is_best=False,
prefix='latest',
best_model_dict=best_model_dict,
@@ -364,6 +354,7 @@ def train(config,
optimizer,
save_model_dir,
logger,
+ config,
is_best=False,
prefix='iter_epoch_{}'.format(epoch),
best_model_dict=best_model_dict,
@@ -401,19 +392,28 @@ def eval(model,
start = time.time()
if model_type == 'table' or extra_input:
preds = model(images, data=batch[1:])
- elif model_type == "kie":
+ elif model_type in ["kie", 'vqa']:
preds = model(batch)
else:
preds = model(images)
- batch = [item.numpy() for item in batch]
+
+ batch_numpy = []
+ for item in batch:
+ if isinstance(item, paddle.Tensor):
+ batch_numpy.append(item.numpy())
+ else:
+ batch_numpy.append(item)
# Obtain usable results from post-processing methods
total_time += time.time() - start
# Evaluate the results of the current batch
if model_type in ['table', 'kie']:
- eval_class(preds, batch)
+ eval_class(preds, batch_numpy)
+ elif model_type in ['vqa']:
+ post_result = post_process_class(preds, batch_numpy)
+ eval_class(post_result, batch_numpy)
else:
- post_result = post_process_class(preds, batch[1])
- eval_class(post_result, batch)
+ post_result = post_process_class(preds, batch_numpy[1])
+ eval_class(post_result, batch_numpy)
pbar.update(1)
total_frame += len(images)
@@ -479,9 +479,9 @@ def preprocess(is_train=False):
FLAGS = ArgsParser().parse_args()
profiler_options = FLAGS.profiler_options
config = load_config(FLAGS.config)
- merge_config(FLAGS.opt)
+ config = merge_config(config, FLAGS.opt)
profile_dic = {"profiler_options": FLAGS.profiler_options}
- merge_config(profile_dic)
+ config = merge_config(config, profile_dic)
if is_train:
# save_config
@@ -503,20 +503,15 @@ def preprocess(is_train=False):
assert alg in [
'EAST', 'DB', 'SAST', 'Rosetta', 'CRNN', 'STARNet', 'RARE', 'SRN',
'CLS', 'PGNet', 'Distillation', 'NRTR', 'TableAttn', 'SAR', 'PSE',
- 'SEED', 'SDMGR'
+ 'SEED', 'SDMGR', 'LayoutXLM', 'LayoutLM'
]
- windows_not_support_list = ['PSE']
- if platform.system() == "Windows" and alg in windows_not_support_list:
- logger.warning('{} is not support in Windows now'.format(
- windows_not_support_list))
- sys.exit()
device = 'gpu:{}'.format(dist.ParallelEnv().dev_id) if use_gpu else 'cpu'
device = paddle.set_device(device)
config['Global']['distributed'] = dist.get_world_size() != 1
- if config['Global']['use_visualdl']:
+ if config['Global']['use_visualdl'] and dist.get_rank() == 0:
from visualdl import LogWriter
save_model_dir = config['Global']['save_model_dir']
vdl_writer_path = '{}/vdl/'.format(save_model_dir)
diff --git a/tools/train.py b/tools/train.py
index f3852469eb198ebfec13713fc4d8f139b2c10f2b..506e0f7fa87fe8afc82cbb12d553a8da4ba298e2 100755
--- a/tools/train.py
+++ b/tools/train.py
@@ -27,8 +27,6 @@ import yaml
import paddle
import paddle.distributed as dist
-paddle.seed(2)
-
from ppocr.data import build_dataloader
from ppocr.modeling.architectures import build_model
from ppocr.losses import build_loss
@@ -36,6 +34,7 @@ from ppocr.optimizer import build_optimizer
from ppocr.postprocess import build_post_process
from ppocr.metrics import build_metric
from ppocr.utils.save_load import load_model
+from ppocr.utils.utility import set_seed
import tools.program as program
dist.get_world_size()
@@ -97,7 +96,8 @@ def main(config, device, logger, vdl_writer):
# build metric
eval_class = build_metric(config['Metric'])
# load pretrain model
- pre_best_model_dict = load_model(config, model, optimizer)
+ pre_best_model_dict = load_model(config, model, optimizer,
+ config['Architecture']["model_type"])
logger.info('train dataloader has {} iters'.format(len(train_dataloader)))
if valid_dataloader is not None:
logger.info('valid dataloader has {} iters'.format(
@@ -145,5 +145,7 @@ def test_reader(config, device, logger):
if __name__ == '__main__':
config, device, logger, vdl_writer = program.preprocess(is_train=True)
+ seed = config['Global']['seed'] if 'seed' in config['Global'] else 1024
+ set_seed(seed)
main(config, device, logger, vdl_writer)
# test_reader(config, device, logger)