Commit 315c07d5 authored by LDOUBLEV

fix conflicts

@@ -106,7 +106,7 @@ class MainWindow(QMainWindow, WindowMixin):
        getStr = lambda strId: self.stringBundle.getString(strId)
        self.defaultSaveDir = defaultSaveDir
-       self.ocr = PaddleOCR(use_pdserving=False, use_angle_cls=True, det=True, cls=True, use_gpu=True, lang=lang)
+       self.ocr = PaddleOCR(use_pdserving=False, use_angle_cls=True, det=True, cls=True, use_gpu=False, lang=lang)
        if os.path.exists('./data/paddle.png'):
            result = self.ocr.ocr('./data/paddle.png', cls=True, det=True)
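For orientation, the whole-image call above uses the 2020-era PaddleOCR Python API, where each result entry pairs a detected quadrilateral with a (text, confidence) tuple. A minimal standalone sketch (the image path and flags mirror the patch; this is not an authoritative API reference):

```python
# Minimal sketch of the PaddleOCR whole-image API used above (2020-era interface).
from paddleocr import PaddleOCR

ocr = PaddleOCR(use_angle_cls=True, use_gpu=False, lang='ch')
result = ocr.ocr('./data/paddle.png', cls=True, det=True)
for box, (text, confidence) in result:
    print(box, text, confidence)  # box: four corner points of the detected text
```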
@@ -274,6 +274,7 @@ class MainWindow(QMainWindow, WindowMixin):
        self.preButton.setIconSize(QSize(40, 100))
        self.preButton.clicked.connect(self.openPrevImg)
        self.preButton.setStyleSheet('border: none;')
+       self.preButton.setShortcut('a')
        self.iconlist = QListWidget()
        self.iconlist.setViewMode(QListView.IconMode)
        self.iconlist.setFlow(QListView.TopToBottom)
@@ -289,12 +290,12 @@ class MainWindow(QMainWindow, WindowMixin):
        self.nextButton.setIconSize(QSize(40, 100))
        self.nextButton.setStyleSheet('border: none;')
        self.nextButton.clicked.connect(self.openNextImg)
+       self.nextButton.setShortcut('d')
        hlayout.addWidget(self.preButton)
        hlayout.addWidget(self.iconlist)
        hlayout.addWidget(self.nextButton)
-       # self.setLayout(hlayout)
        iconListContainer = QWidget()
        iconListContainer.setLayout(hlayout)
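`setShortcut` here comes from `QAbstractButton`, so the single keys 'a'/'d' trigger the prev/next buttons without separate QAction objects. A minimal sketch of the mechanism (widget names are illustrative):

```python
# Minimal PyQt5 sketch: a single-key shortcut bound directly to a button.
from PyQt5.QtWidgets import QApplication, QPushButton

app = QApplication([])
btn = QPushButton('next image')
btn.setShortcut('d')  # QAbstractButton.setShortcut accepts a key-sequence string
btn.clicked.connect(lambda: print('open next image'))
btn.show()
app.exec_()
```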
@@ -359,11 +360,6 @@ class MainWindow(QMainWindow, WindowMixin):
        opendir = action(getStr('openDir'), self.openDirDialog,
                         'Ctrl+u', 'open', getStr('openDir'))
-       openNextImg = action(getStr('nextImg'), self.openNextImg,
-                            'd', 'next', getStr('nextImgDetail'))
-       openPrevImg = action(getStr('prevImg'), self.openPrevImg,
-                            'a', 'prev', getStr('prevImgDetail'))
        save = action(getStr('save'), self.saveFile,
                      'Ctrl+V', 'verify', getStr('saveDetail'), enabled=False)
@@ -371,7 +367,7 @@ class MainWindow(QMainWindow, WindowMixin):
        alcm = action(getStr('choosemodel'), self.autolcm,
                      'Ctrl+M', 'next', getStr('tipchoosemodel'))
-       deleteImg = action(getStr('deleteImg'), self.deleteImg, 'Ctrl+D', 'close', getStr('deleteImgDetail'),
+       deleteImg = action(getStr('deleteImg'), self.deleteImg, 'Ctrl+Shift+D', 'close', getStr('deleteImgDetail'),
                           enabled=True)
        resetAll = action(getStr('resetAll'), self.resetAll, None, 'resetall', getStr('resetAllDetail'))
@@ -388,7 +384,7 @@ class MainWindow(QMainWindow, WindowMixin):
                        'w', 'new', getStr('crtBoxDetail'), enabled=False)
        delete = action(getStr('delBox'), self.deleteSelectedShape,
-                       'Delete', 'delete', getStr('delBoxDetail'), enabled=False)
+                       'backspace', 'delete', getStr('delBoxDetail'), enabled=False)
        copy = action(getStr('dupBox'), self.copySelectedShape,
                      'Ctrl+C', 'copy', getStr('dupBoxDetail'),
                      enabled=False)
@@ -446,8 +442,11 @@ class MainWindow(QMainWindow, WindowMixin):
        reRec = action(getStr('reRecognition'), self.reRecognition,
                       'Ctrl+Shift+R', 'reRec', getStr('reRecognition'), enabled=False)
+       singleRere = action(getStr('singleRe'), self.singleRerecognition,
+                           'Ctrl+R', 'reRec', getStr('singleRe'), enabled=False)
        createpoly = action(getStr('creatPolygon'), self.createPolygon,
-                           'p', 'new', 'Creat Polygon', enabled=True)
+                           'q', 'new', 'Creat Polygon', enabled=True)
        saveRec = action(getStr('saveRec'), self.saveRecResult,
                         '', 'save', getStr('saveRec'), enabled=False)
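These calls go through a labelImg-style `action` helper (a `functools.partial` around a `newAction` utility) taking text, slot, shortcut, icon name, tip and enabled flag in that order. The sketch below reconstructs such a helper as an assumption about its shape, not the exact upstream code:

```python
# Hedged reconstruction of a labelImg-style newAction helper (icon loading omitted).
from functools import partial
from PyQt5.QtWidgets import QAction

def newAction(parent, text, slot=None, shortcut=None, icon=None,
              tip=None, checkable=False, enabled=True):
    a = QAction(text, parent)
    if shortcut is not None:
        a.setShortcut(shortcut)  # e.g. 'Ctrl+Shift+D' or 'backspace'
    if tip is not None:
        a.setToolTip(tip)
        a.setStatusTip(tip)
    if slot is not None:
        a.triggered.connect(slot)
    a.setCheckable(checkable)
    a.setEnabled(enabled)
    return a

# The calls above assume something like: action = partial(newAction, self)
```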
@@ -491,6 +490,7 @@ class MainWindow(QMainWindow, WindowMixin):
                                 icon='color', tip=getStr('shapeFillColorDetail'),
                                 enabled=False)

        # Label list context menu.
        labelMenu = QMenu()
        addActions(labelMenu, (edit, delete))
@@ -501,7 +501,6 @@ class MainWindow(QMainWindow, WindowMixin):
        # Draw squares/rectangles
        self.drawSquaresOption = QAction(getStr('drawSquares'), self)
-       self.drawSquaresOption.setShortcut('Ctrl+Shift+R')
        self.drawSquaresOption.setCheckable(True)
        self.drawSquaresOption.setChecked(settings.get(SETTING_DRAW_SQUARE, False))
        self.drawSquaresOption.triggered.connect(self.toogleDrawSquare)
@@ -509,7 +508,7 @@ class MainWindow(QMainWindow, WindowMixin):
        # Store actions for further handling.
        self.actions = struct(save=save, open=open, resetAll=resetAll, deleteImg=deleteImg,
                              lineColor=color1, create=create, delete=delete, edit=edit, copy=copy,
-                             saveRec=saveRec,
+                             saveRec=saveRec, singleRere=singleRere, AutoRec=AutoRec, reRec=reRec,
                              createMode=createMode, editMode=editMode,
                              shapeLineColor=shapeLineColor, shapeFillColor=shapeFillColor,
                              zoom=zoom, zoomIn=zoomIn, zoomOut=zoomOut, zoomOrg=zoomOrg,
@@ -518,9 +517,9 @@ class MainWindow(QMainWindow, WindowMixin):
                              fileMenuActions=(
                                  open, opendir, saveLabel, resetAll, quit),
                              beginner=(), advanced=(),
-                             editMenu=(createpoly, edit, copy, delete,
+                             editMenu=(createpoly, edit, copy, delete, singleRere,
                                        None, color1, self.drawSquaresOption),
-                             beginnerContext=(create, edit, copy, delete),
+                             beginnerContext=(create, edit, copy, delete, singleRere),
                              advancedContext=(createMode, editMode, edit, copy,
                                               delete, shapeLineColor, shapeFillColor),
                              onLoadActive=(
@@ -562,7 +561,7 @@ class MainWindow(QMainWindow, WindowMixin):
                               zoomIn, zoomOut, zoomOrg, None,
                               fitWindow, fitWidth))
-       addActions(self.menus.autolabel, (alcm, None, help))  #
+       addActions(self.menus.autolabel, (AutoRec, reRec, alcm, None, help))  #
        self.menus.file.aboutToShow.connect(self.updateFileMenu)
@@ -572,6 +571,7 @@ class MainWindow(QMainWindow, WindowMixin):
                                action('&Copy here', self.copyShape),
                                action('&Move here', self.moveShape)))

        self.statusBar().showMessage('%s started.' % __appname__)
        self.statusBar().show()
@@ -919,6 +919,7 @@ class MainWindow(QMainWindow, WindowMixin):
        self.actions.edit.setEnabled(selected)
        self.actions.shapeLineColor.setEnabled(selected)
        self.actions.shapeFillColor.setEnabled(selected)
+       self.actions.singleRere.setEnabled(selected)

    def addLabel(self, shape):
        shape.paintLabel = self.displayLabelOption.isChecked()
@@ -988,6 +989,19 @@ class MainWindow(QMainWindow, WindowMixin):
        self.updateComboBox()
        self.canvas.loadShapes(s)

+   def singleLabel(self, shape):
+       if shape is None:
+           # print('rm empty label')
+           return
+       item = self.shapesToItems[shape]
+       item.setText(shape.label)
+       self.updateComboBox()
+       # ADD:
+       item = self.shapesToItemsbox[shape]
+       item.setText(str([(int(p.x()), int(p.y())) for p in shape.points]))
+       self.updateComboBox()
+
    def updateComboBox(self):
        # Get the unique labels and add them to the Combobox.
        itemsTextList = [str(self.labelList.item(i).text()) for i in range(self.labelList.count())]
@@ -1441,6 +1455,8 @@ class MainWindow(QMainWindow, WindowMixin):
        self.haveAutoReced = False
        self.AutoRecognition.setEnabled(True)
        self.reRecogButton.setEnabled(True)
+       self.actions.AutoRec.setEnabled(True)
+       self.actions.reRec.setEnabled(True)
        self.actions.saveLabel.setEnabled(True)
@@ -1755,6 +1771,7 @@ class MainWindow(QMainWindow, WindowMixin):
        self.loadFile(self.filePath)  # ADD
        self.haveAutoReced = True
        self.AutoRecognition.setEnabled(False)
+       self.actions.AutoRec.setEnabled(False)
        self.setDirty()
        self.saveCacheLabel()
@@ -1794,6 +1811,27 @@ class MainWindow(QMainWindow, WindowMixin):
        else:
            QMessageBox.information(self, "Information", "Draw a box!")

+   def singleRerecognition(self):
+       img = cv2.imread(self.filePath)
+       shape = self.canvas.selectedShape
+       box = [[int(p.x()), int(p.y())] for p in shape.points]
+       assert len(box) == 4
+       img_crop = get_rotate_crop_image(img, np.array(box, np.float32))
+       if img_crop is None:
+           msg = 'Cannot recognise the detection box in ' + self.filePath + '. Please change it manually'
+           QMessageBox.information(self, "Information", msg)
+           return
+       result = self.ocr.ocr(img_crop, cls=True, det=False)
+       if result[0][0] != '':
+           result.insert(0, box)
+           print('result in reRec is ', result)
+           if result[1][0] == shape.label:
+               print('label no change')
+           else:
+               shape.label = result[1][0]
+           self.singleLabel(shape)
+           self.setDirty()
+       print(box)

    def autolcm(self):
        vbox = QVBoxLayout()
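`singleRerecognition` crops the selected quadrilateral before re-running recognition. The helper it relies on, `get_rotate_crop_image`, lives in PaddleOCR's inference utilities; the sketch below shows the approach (perspective-warp the box to a rectangle, then rotate tall crops upright) and may lag the exact upstream code:

```python
# Sketch of PaddleOCR's get_rotate_crop_image (cf. tools/infer/utility.py).
import cv2
import numpy as np

def get_rotate_crop_image(img, points):
    # points: 4x2 float32 array, clockwise from top-left
    crop_w = int(max(np.linalg.norm(points[0] - points[1]),
                     np.linalg.norm(points[2] - points[3])))
    crop_h = int(max(np.linalg.norm(points[0] - points[3]),
                     np.linalg.norm(points[1] - points[2])))
    pts_std = np.float32([[0, 0], [crop_w, 0], [crop_w, crop_h], [0, crop_h]])
    M = cv2.getPerspectiveTransform(points, pts_std)
    dst = cv2.warpPerspective(img, M, (crop_w, crop_h),
                              borderMode=cv2.BORDER_REPLICATE,
                              flags=cv2.INTER_CUBIC)
    h, w = dst.shape[:2]
    if h * 1.0 / w >= 1.5:
        dst = np.rot90(dst)  # rotate vertical text upright
    return dst
```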
@@ -1825,6 +1863,7 @@ class MainWindow(QMainWindow, WindowMixin):
        self.dialog.exec_()
        if self.filePath:
            self.AutoRecognition.setEnabled(True)
+           self.actions.AutoRec.setEnabled(True)

    def modelChoose(self):
......
@@ -6,6 +6,10 @@ PPOCRLabel is a semi-automatic graphic annotation tool suitable for OCR field. I

<img src="./data/gif/steps_en.gif" width="100%"/>

+### Recent Update
+
+- 2020.12.18: Support re-recognition of a single label box (by [ninetailskim](https://github.com/ninetailskim)); improved shortcut keys.

## Installation

### 1. Install PaddleOCR
@@ -92,6 +96,25 @@ Therefore, if the recognition result has been manually changed before, it may ch

## Explanation

+### Shortcut keys
+
+| Shortcut keys    | Description                                      |
+| ---------------- | ------------------------------------------------ |
+| Ctrl + Shift + A | Automatically label all unchecked images         |
+| Ctrl + Shift + R | Re-recognize all the labels of the current image |
+| W                | Create a rect box                                |
+| Q                | Create a four-points box                         |
+| Ctrl + E         | Edit label of the selected box                   |
+| Ctrl + R         | Re-recognize the selected box                    |
+| Backspace        | Delete the selected box                          |
+| Ctrl + V         | Check the image                                  |
+| Ctrl + Shift + D | Delete the image                                 |
+| D                | Next image                                       |
+| A                | Previous image                                   |
+| Ctrl + +         | Zoom in                                          |
+| Ctrl + -         | Zoom out                                         |
+| ↑→↓←             | Move the selected box                            |

### Built-in Model

- Default model: PPOCRLabel uses the Chinese and English ultra-lightweight OCR model in PaddleOCR by default, supports Chinese, English and number recognition, and multiple language detection.
......
@@ -6,6 +6,10 @@ PPOCRLabel is a semi-automatic graphic annotation tool for the OCR field, using P

<img src="./data/gif/steps.gif" width="100%"/>

+#### Recent Update
+
+- 2020.12.18: Support re-recognition of a single label box (by [ninetailskim](https://github.com/ninetailskim)); improved shortcut keys.

## Installation

### 1. Install PaddleOCR
@@ -72,6 +76,26 @@ python3 PPOCRLabel.py --lang ch

| crop_img | Recognition data: images cropped according to the detection boxes, generated together with rec_gt.txt. |

## Explanation

+### Shortcut keys
+
+| Shortcut keys    | Description                                      |
+| ---------------- | ------------------------------------------------ |
+| Ctrl + Shift + A | Automatically label all unchecked images         |
+| Ctrl + Shift + R | Re-recognize all the labels of the current image |
+| W                | Create a rect box                                |
+| Q                | Create a four-points box                         |
+| Ctrl + E         | Edit the label of the selected box               |
+| Ctrl + R         | Re-recognize the selected box                    |
+| Backspace        | Delete the selected box                          |
+| Ctrl + V         | Confirm the labels of the current image          |
+| Ctrl + Shift + D | Delete the current image                         |
+| D                | Next image                                       |
+| A                | Previous image                                   |
+| Ctrl + +         | Zoom in                                          |
+| Ctrl + -         | Zoom out                                         |
+| ↑→↓←             | Move the selected box                            |

### Built-in Model

- Default model: PPOCRLabel uses PaddleOCR's ultra-lightweight Chinese and English OCR model by default, supporting Chinese, English and digit recognition and multi-language detection.
......
@@ -46,8 +46,9 @@ class Worker(QThread):
                chars = res[1][0]
                cond = res[1][1]
                posi = res[0]
-               strs += "Transcription: " + chars + " Probability: " + str(
-                   cond) + " Location: " + json.dumps(posi) + '\n'
+               strs += "Transcription: " + chars + " Probability: " + str(cond) + \
+                       " Location: " + json.dumps(posi) + '\n'
+               # Sending large amounts of data repeatedly through pyqtSignal may affect the program efficiency
                self.listValue.emit(strs)
        self.mainThread.result_dic = self.result_dic
        self.mainThread.filePath = Imgpath
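The added comment flags that `strs` accumulates every previous line, so each `emit` resends an ever-growing payload. One hedged mitigation (names invented for illustration) is to emit only the newly formatted line and append on the receiving side:

```python
# Hypothetical sketch: emit incremental lines instead of the growing buffer.
from PyQt5.QtCore import QThread, pyqtSignal

class Worker(QThread):
    lineReady = pyqtSignal(str)            # invented signal name

    def run(self):
        for res in self.results:           # self.results assumed to be set by the caller
            chars, cond = res[1][0], res[1][1]
            line = "Transcription: %s Probability: %s" % (chars, cond)
            self.lineReady.emit(line)      # O(1) payload per result

# Receiver side: worker.lineReady.connect(list_widget.addItem)
```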
......
This diff is collapsed.
@@ -95,3 +95,4 @@ autolabeling=自动标注中
hideBox=隐藏所有标注
showBox=显示所有标注
saveLabel=保存标记结果
+singleRe=重识别此区块
\ No newline at end of file
saveAsDetail=將標籤保存到其他文件
changeSaveDir=改變存放目錄
openFile=開啟檔案
shapeLineColorDetail=更改線條顏色
resetAll=重置
crtBox=創建區塊
crtBoxDetail=畫一個區塊
dupBoxDetail=複製區塊
verifyImg=驗證圖像
zoominDetail=放大
verifyImgDetail=驗證圖像
saveDetail=將標籤存到
openFileDetail=打開圖像
fitWidthDetail=調整到窗口寬度
tutorial=YouTube教學
editLabel=編輯標籤
openAnnotationDetail=打開標籤文件
quit=結束
shapeFillColorDetail=更改填充顏色
closeCurDetail=關閉目前檔案
closeCur=關閉
deleteImg=刪除圖像
deleteImgDetail=刪除目前圖像
fitWin=調整到跟窗口一樣大小
delBox=刪除選取區塊
boxLineColorDetail=選擇框線顏色
originalsize=原始大小
resetAllDetail=重設所有設定
zoomoutDetail=畫面放大
save=儲存
saveAs=另存為
fitWinDetail=縮放到窗口一樣
openDir=開啟目錄
copyPrevBounding=複製當前圖像中的上一個邊界框
showHide=顯示/隱藏標籤
changeSaveFormat=更改儲存格式
shapeFillColor=填充顏色
quitApp=離開本程式
dupBox=複製區塊
delBoxDetail=刪除區塊
zoomin=放大畫面
info=資訊
openAnnotation=開啟標籤
prevImgDetail=上一個圖像
fitWidth=縮放到跟畫面一樣寬
zoomout=縮小畫面
changeSavedAnnotationDir=更改預設標籤存的目錄
nextImgDetail=下一個圖像
originalsizeDetail=放大到原始大小
prevImg=上一個圖像
tutorialDetail=顯示示範內容
shapeLineColor=形狀線條顏色
boxLineColor=日期分隔線顏色
editLabelDetail=修改所選區塊的標籤
nextImg=下一張圖片
useDefaultLabel=使用預設標籤
useDifficult=有難度的
boxLabelText=區塊的標籤
labels=標籤
autoSaveMode=自動儲存模式
singleClsMode=單一類別模式
displayLabel=顯示類別
fileList=檔案清單
files=檔案
iconList=XX
icon=XX
advancedMode=進階模式
advancedModeDetail=切到進階模式
showAllBoxDetail=顯示所有區塊
hideAllBoxDetail=隱藏所有區塊
@@ -95,3 +95,4 @@ autolabeling=Automatic Labeling
hideBox=Hide All Box
showBox=Show All Box
saveLabel=Save Label
+singleRe=Re-recognition RectBox
\ No newline at end of file
@@ -122,8 +122,7 @@ For a new language request, please refer to [Guideline for new language_requests

<img src="./doc/ppocr_framework.png" width="800">
</div>

-PP-OCR is a practical ultra-lightweight OCR system. It is mainly composed of three parts: DB text detection, detection frame correction and CRNN text recognition. The system adopts 19 effective strategies from 8 aspects including backbone network selection and adjustment, prediction head design, data augmentation, learning rate transformation strategy, regularization parameter selection, pre-training model use, and automatic model tailoring and quantization to optimize and slim down the models of each module. The final results are an ultra-lightweight Chinese and English OCR model with an overall size of 3.5M and a 2.8M English digital OCR model. For more details, please refer to the PP-OCR technical article (https://arxiv.org/abs/2009.09941). Besides, the implementation of the FPGM Pruner and PACT quantization is based on [PaddleSlim](https://github.com/PaddlePaddle/PaddleSlim).
+PP-OCR is a practical ultra-lightweight OCR system. It is mainly composed of three parts: DB text detection [2], detection frame correction and CRNN text recognition [7]. The system adopts 19 effective strategies from 8 aspects including backbone network selection and adjustment, prediction head design, data augmentation, learning rate transformation strategy, regularization parameter selection, pre-training model use, and automatic model tailoring and quantization to optimize and slim down the models of each module. The final results are an ultra-lightweight Chinese and English OCR model with an overall size of 3.5M and a 2.8M English digital OCR model. For more details, please refer to the PP-OCR technical article (https://arxiv.org/abs/2009.09941). Besides, the implementation of the FPGM Pruner [8] and PACT quantization [9] is based on [PaddleSlim](https://github.com/PaddlePaddle/PaddleSlim).
## Visualization [more](./doc/doc_en/visualization_en.md)
......
@@ -8,8 +8,8 @@ PaddleOCR supports both dynamic graph and static graph programming paradigms

- Static graph version: develop branch

**Recent updates**

+- 2020.12.21 [FAQ](./doc/doc_ch/FAQ.md) adds 5 frequently asked questions (132 in total); it is updated every Monday, stay tuned.
- 2020.12.15 Updated the data synthesis tool [Style-Text](./StyleText/README_ch.md), which can batch-synthesize large numbers of images similar to the target scene; verified in several scenarios with clearly improved results.
-- 2020.12.14 [FAQ](./doc/doc_ch/FAQ.md) adds 5 frequently asked questions (127 in total); it is updated every Monday, stay tuned.
- 2020.11.25 Updated the semi-automatic annotation tool [PPOCRLabel](./PPOCRLabel/README_ch.md), which helps developers finish annotation efficiently, with output that connects seamlessly to PP-OCR training.
- 2020.9.22 Updated the PP-OCR technical article, https://arxiv.org/abs/2009.09941
- [More](./doc/doc_ch/update.md)
@@ -101,8 +101,8 @@ PaddleOCR supports both dynamic graph and static graph programming paradigms

- [Visualization](#效果展示)
- FAQ
  - [[Selected] 10 selected OCR questions](./doc/doc_ch/FAQ.md)
-  - [[Theory] 30 general OCR questions](./doc/doc_ch/FAQ.md)
+  - [[Theory] 31 general OCR questions](./doc/doc_ch/FAQ.md)
-  - [[Practice] 84 PaddleOCR practice questions](./doc/doc_ch/FAQ.md)
+  - [[Practice] 91 PaddleOCR practice questions](./doc/doc_ch/FAQ.md)
- [Technical exchange group](#欢迎加入PaddleOCR技术交流群)
- [References](./doc/doc_ch/reference.md)
- [License](#许可证书)
@@ -115,7 +115,7 @@ PaddleOCR supports both dynamic graph and static graph programming paradigms

<img src="./doc/ppocr_framework.png" width="800">
</div>

-PP-OCR is a practical ultra-lightweight OCR system, mainly composed of three parts: DB text detection, detection box rectification, and CRNN text recognition. From 8 aspects, namely backbone network selection and adjustment, prediction head design, data augmentation, learning rate transformation strategy, regularization parameter selection, pretrained model usage, and automatic model pruning and quantization, the system adopts 19 effective strategies to tune and slim down the model of each module, finally obtaining an ultra-lightweight Chinese and English OCR with an overall size of 3.5M and a 2.8M English-digit OCR. For more details, see the PP-OCR technical report https://arxiv.org/abs/2009.09941 . The implementations of the FPGM pruner and PACT quantization can be found in [PaddleSlim](https://github.com/PaddlePaddle/PaddleSlim).
+PP-OCR is a practical ultra-lightweight OCR system, mainly composed of three parts: DB text detection [2], detection box rectification, and CRNN text recognition [7]. From 8 aspects, namely backbone network selection and adjustment, prediction head design, data augmentation, learning rate transformation strategy, regularization parameter selection, pretrained model usage, and automatic model pruning and quantization, the system adopts 19 effective strategies to tune and slim down the model of each module, finally obtaining an ultra-lightweight Chinese and English OCR with an overall size of 3.5M and a 2.8M English-digit OCR. For more details, see the PP-OCR technical report https://arxiv.org/abs/2009.09941 . The implementations of the FPGM pruner [8] and PACT quantization [9] can be found in [PaddleSlim](https://github.com/PaddlePaddle/PaddleSlim).

<a name="效果展示"></a>
## Visualization [more](./doc/doc_ch/visualization.md)
......
@@ -22,7 +22,7 @@ English | [简体中文](README_ch.md)

</div>

-The Style-Text data synthesis tool is a tool based on Baidu's self-developed text editing algorithm "Editing Text in the Wild" [https://arxiv.org/abs/1908.03047](https://arxiv.org/abs/1908.03047).
+The Style-Text data synthesis tool is based on the text editing algorithm "Editing Text in the Wild" [https://arxiv.org/abs/1908.03047](https://arxiv.org/abs/1908.03047), a joint research work of Baidu and HUST.

Different from the commonly used GAN-based data synthesis tools, the main framework of Style-Text includes:

* (1) Text foreground style transfer module.
@@ -78,6 +78,7 @@ python3 tools/synth_image.py -c configs/config.yml --style_image examples/style_

* Note 3: You can modify `use_gpu` in `configs/config.yml` to determine whether to use GPU for prediction.

For example, enter the following image and corpus `PaddleOCR`.

<div align="center">
@@ -142,6 +143,7 @@ We provide a general dataset containing Chinese, English and Korean (50,000 imag

``` bash
python3 tools/synth_dataset.py -c configs/dataset_config.yml
```

We also provide example corpus and images in `examples` folder.

<div align="center">
<img src="examples/style_images/1.jpg" width="300">
......
@@ -21,7 +21,7 @@

</div>

-The Style-Text data synthesis tool is based on Baidu's self-developed text editing algorithm "Editing Text in the Wild" (https://arxiv.org/abs/1908.03047).
+The Style-Text data synthesis tool is based on the text editing algorithm "Editing Text in the Wild" (https://arxiv.org/abs/1908.03047), jointly developed by Baidu and HUST.

Different from the commonly used GAN-based data synthesis tools, the main framework of Style-Text consists of three parts: 1. a text foreground style transfer module; 2. a background extraction module; 3. a fusion module. With these three steps, image-text style transfer can be achieved quickly. The figures below show some results of this data synthesis tool.
@@ -128,7 +128,7 @@ python3 tools/synth_image.py -c configs/config.yml --style_image examples/style_

2. Run `tools/synth_dataset` to synthesize data:

``` bash
-python tools/synth_dataset.py -c configs/dataset_config.yml
+python3 tools/synth_dataset.py -c configs/dataset_config.yml
```

We provide sample images and corpus in the `examples` directory.

<div align="center">
......
@@ -107,10 +107,10 @@ make inference_lib_dist

For more compilation parameter options, please refer to the official website of the Paddle C++ inference library: [https://www.paddlepaddle.org.cn/documentation/docs/en/advanced_guide/inference_deployment/inference/build_and_install_lib_en.html](https://www.paddlepaddle.org.cn/documentation/docs/en/advanced_guide/inference_deployment/inference/build_and_install_lib_en.html).

-* After the compilation process, you can see the following files in the folder of `build/fluid_inference_install_dir/`.
+* After the compilation process, you can see the following files in the folder of `build/paddle_inference_install_dir/`.

```
-build/fluid_inference_install_dir/
+build/paddle_inference_install_dir/
|-- CMakeCache.txt
|-- paddle
|-- third_party
......
@@ -81,14 +81,14 @@ void ResizeImgType0::Run(const cv::Mat &img, cv::Mat &resize_img,
  else if (resize_h / 32 < 1 + 1e-5)
    resize_h = 32;
  else
-   resize_h = (resize_h / 32 - 1) * 32;
+   resize_h = (resize_h / 32) * 32;

  if (resize_w % 32 == 0)
    resize_w = resize_w;
  else if (resize_w / 32 < 1 + 1e-5)
    resize_w = 32;
  else
-   resize_w = (resize_w / 32 - 1) * 32;
+   resize_w = (resize_w / 32) * 32;

  cv::resize(img, resize_img, cv::Size(resize_w, resize_h));
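The change fixes an off-by-one-stride shrink: the `else` branch only runs when the side is not already a multiple of 32 (that case is caught by `% 32 == 0`), yet the old formula subtracted an extra 32. A quick Python check of both formulas, using integer division as in the C++:

```python
# Old vs. fixed rounding down to a multiple of 32 (exact multiples never reach this branch).
for side in (100, 250, 639):
    old = (side // 32 - 1) * 32
    new = (side // 32) * 32
    print(side, old, new)
# 100 -> 64 vs 96; 250 -> 192 vs 224; 639 -> 576 vs 608
```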
......
@@ -10,7 +10,7 @@ max_side_len 960
det_db_thresh 0.3
det_db_box_thresh 0.5
det_db_unclip_ratio 2.0
-det_model_dir ../../../deploy/cpp_infer/inference/ch_ppocr_mobile_v2.0_det_infer/
+det_model_dir ./inference/ch_ppocr_mobile_v2.0_det_infer/

# cls config
use_angle_cls 0
......
@@ -9,42 +9,55 @@

## Summary of PaddleOCR FAQs (continuously updated)

-* [Recent updates (2020.12.14)](#近期更新)
+* [Recent updates (2020.12.21)](#近期更新)
* [[Selected] 10 selected OCR questions](#OCR精选10个问题)
-* [[Theory] 30 general OCR questions](#OCR通用问题)
+* [[Theory] 31 general OCR questions](#OCR通用问题)
  * [7 questions on basics](#基础知识)
  * [7 questions on datasets](#数据集2)
-  * [7 questions on model training and tuning](#模型训练调优2)
-  * [9 questions on prediction and deployment](#预测部署2)
+  * [17 questions on model training and tuning](#模型训练调优2)
-* [[Practice] 87 PaddleOCR practice questions](#PaddleOCR实战问题)
+* [[Practice] 91 PaddleOCR practice questions](#PaddleOCR实战问题)
-  * [21 questions on usage consulting](#使用咨询)
+  * [24 questions on usage consulting](#使用咨询)
  * [17 questions on datasets](#数据集3)
  * [25 questions on model training and tuning](#模型训练调优3)
-  * [24 questions on prediction and deployment](#预测部署3)
+  * [25 questions on prediction and deployment](#预测部署3)
<a name="近期更新"></a>

-## Recent updates (2020.12.14)
+## Recent updates (2020.12.21)

-#### Q3.1.21: Does PaddleOCR support dynamic graphs?
-
-**A**: The dynamic graph version is under intensive development and will be released on December 16, 2020; stay tuned.
-
-#### Q3.3.23: An elementwise_add error occurs when training or predicting with the detection model
-
-**A**: The input size must be a multiple of 32; otherwise, after the network's repeated down-sampling and up-sampling, the feature maps differ by one pixel, which makes elementwise_add report a shape mismatch.
-
-#### Q3.3.24: The DB detection training input size is 640; can it be made larger?
-
-**A**: Not recommended. The detection training input size is the size after random crop during preprocessing, not a direct resize of the original image; in most scenarios it is already large enough. Enlarging it may not actually fit, slows training, and some parameters in the code are adapted to the preset input size, so enlarging it carries hidden risks.
-
-#### Q3.3.25: During recognition model training, the loss decreases normally but acc stays at 0
-
-**A**: An acc of 0 early in recognition training is normal; the metric comes up after training for a while longer.
-
-#### Q3.4.24: The DB model predicts correctly, but switching to EAST or SAST reports errors or gives wrong results
-
-**A**: When running inference with EAST or SAST, specify --det_algorithm="EAST" or --det_algorithm="SAST" in the command. No flag is needed for DB because it is the parameter's default value: https://github.com/PaddlePaddle/PaddleOCR/blob/e7a708e9fdaf413ed7a14da8e4a7b4ac0b211e42/tools/infer/utility.py#L43
+#### Q2.3.17: StyleText synthesized data does not look good?
+
+**A**: Data generated by the StyleText model is mainly used to train OCR recognition models. PaddleOCR's current recognition models take 32 x N input, so the current version mainly suits data of height 32; set the size of the data to be synthesized to 32 x N. Data of roughly similar sizes can also be generated, but very large or very small sizes indeed give poor results.
+
+#### Q3.1.22: ModuleNotFoundError: No module named 'paddle.nn'
+
+**A**: paddle.nn is a Paddle 2.0 feature; install Paddle 2.0.0rc1 or later:
+```
+python3 -m pip install paddlepaddle-gpu==2.0.0rc1 -i https://mirror.baidu.com/pypi/simple
+```
+
+#### Q3.1.23: ImportError: /usr/lib/x86_64_linux-gnu/libstdc++.so.6: version `CXXABI_1.3.11` not found (required by /usr/lib/python3.6/site-package/paddle/fluid/core+avx.so)
+
+**A**: This is caused by an insufficient glibc version. Paddle 2.0rc1 has higher requirements on the gcc and glibc versions: gcc 8.2 and glibc 2.12 or above are recommended. If your environment does not meet this requirement, or you use the docker image `hub.baidubce.com/paddlepaddle/paddle:latest-gpu-cuda9.0-cudnn7-dev`, installing Paddle 2.0rc may raise the above error. For 2.0, the new docker image `paddlepaddle/paddle:latest-dev-cuda10.1-cudnn7-gcc82` is recommended, or visit [dockerhub](https://hub.docker.com/r/paddlepaddle/paddle/tags/) to get an image that fits your machine.
+
+#### Q3.1.24: What is the difference between the PaddleOCR develop and dygraph branches?
+
+**A**: PaddleOCR currently has four branches:
+
+- develop: a branch developed on Paddle static graphs; Paddle 1.8 or 2.0 is recommended. It has complete model training, prediction, inference deployment, quantization and pruning support, and is ahead of release/1.1.
+- release/1.1: the first stable PaddleOCR release, developed on static graphs, with complete training, prediction, inference deployment, quantization and pruning support.
+- dygraph: a branch developed on Paddle dynamic graphs; it will become the main development branch, requires Paddle 2.0rc1, and is still under development.
+- release/2.0-rc1-0: the second stable PaddleOCR release, based on dynamic graphs and Paddle 2.0. Dynamic-graph code is easier to debug; it currently supports model training and prediction, but not yet mobile deployment.
+
+If you are already familiar with PaddleOCR and want to deploy it across various environments, it is currently advisable to use a static-graph branch (develop or release/1.1). If you are a beginner who wants to quickly train and debug the algorithms in PaddleOCR, try the dygraph branch.
+
+**Note**: the develop and dygraph branches require different Paddle versions and local environments; mind the differences in each branch's installation section.
+
+#### Q3.4.25: PaddleOCR Python and C++ prediction results are inconsistent?
+
+**A**: Normally, Python and C++ predictions give the same text. If the results differ noticeably, first locate whether the diff comes from the detection model or the recognition model, or try another model to see whether a similar problem occurs. Next, check whether the Python-side and C++-side data processing differ; it also helps to keep your environment and retry after updating the PaddleOCR code. If updating the code does not solve it, raise your question in the PaddleOCR group or in an issue.
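Since the StyleText note above ties output quality to the 32 x N recognition input, here is a small hedged sketch of normalizing a synthesized crop to height 32 while preserving aspect ratio (the file name is hypothetical):

```python
# Hedged sketch: resize a synthesized crop to the 32 x N shape expected by
# PaddleOCR recognition models, scaling the width to keep the aspect ratio.
import cv2

img = cv2.imread('synth_crop.jpg')        # hypothetical StyleText output
h, w = img.shape[:2]
target_h = 32
target_w = max(1, int(w * target_h / h))  # preserve aspect ratio
img_32 = cv2.resize(img, (target_w, target_h))
```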
<a name="OCR精选10个问题"></a>
## [Selected] 10 selected OCR questions
@@ -238,18 +251,15 @@

(2) Increase the system's [l2 decay value](https://github.com/PaddlePaddle/PaddleOCR/blob/a501603d54ff5513fc4fc760319472e59da25424/configs/rec/ch_ppocr_v1.1/rec_chinese_lite_train_v1.1.yml#L47)

-<a name="预测部署2"></a>
-### Prediction and deployment
-
-#### Q2.4.1: Is there a good way to handle dense text in images?
+#### Q2.3.8: Is there a good way to handle dense text in images?

**A**: Try a pretrained model first, e.g. DB+CRNN, to determine whether the problem with dense-text images lies in detection or recognition, then improve that part specifically. Another option: if the dense text in the image is small, try increasing the image resolution and stretching the image within a reasonable range, so that the text becomes sparser and recognition improves.

-#### Q2.4.2: For text that is slightly blurry at recognition time, are there image enhancement methods?
+#### Q2.3.9: For text that is slightly blurry at recognition time, are there image enhancement methods?

**A**: Provided the text is still legible to the human eye, you can experiment with blur operators from image processing, such as mean filtering, median filtering or Gaussian filtering. You can also strengthen model robustness with data-augmentation perturbations; newer ideas include adversarial training and super-resolution (SR), which are worth borrowing from. There is no commonly agreed best solution in the industry yet, so it is advisable to first add constraints at the data collection stage to improve image quality.
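The blur operators mentioned in this answer map directly onto OpenCV calls; a small sketch (the input file name is hypothetical):

```python
# Hedged sketch: the filters named above, applied with OpenCV.
import cv2

img = cv2.imread('word.jpg')                 # hypothetical input image
mean_f = cv2.blur(img, (3, 3))               # mean filtering
median_f = cv2.medianBlur(img, 3)            # median filtering (odd kernel size)
gauss_f = cv2.GaussianBlur(img, (3, 3), 0)   # Gaussian filtering
```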
-#### Q2.4.3: For detecting specific text, e.g. only the name on an ID card, is it better to detect the specified region, or to detect all regions and then filter?
+#### Q2.3.10: For detecting specific text, e.g. only the name on an ID card, is it better to detect the specified region, or to detect all regions and then filter?

**A**: Looking at it from two angles, detecting all regions and then filtering is generally better.

@@ -257,11 +267,11 @@

(2) Product requirements may change, and later changes in what the model must do cannot be ruled out (for example, another field needs to be added). Compared with retraining the model, post-processing logic is easier to adjust.

-#### Q2.4.4: How can a beginner quickly get started with a Chinese OCR project in practice?
+#### Q2.3.11: How can a beginner quickly get started with a Chinese OCR project in practice?

**A**: It is advisable to first learn the basics of OCR and get a rough idea of the basic detection and recognition algorithms, then browse OCR-related repos on GitHub. In terms of completeness of content, PaddleOCR's bilingual Chinese/English tutorials have a clear advantage: the documentation on datasets, model training and prediction deployment is detailed, so you can get started quickly, and there is a WeChat user group for Q&A, which suits learning by practice. Project address: [PaddleOCR](https://github.com/PaddlePaddle/PaddleOCR)

-#### Q2.4.5: How to recognize English line-text images that contain spaces?
+#### Q2.3.12: How to recognize English line-text images that contain spaces?

**A**: Two approaches can be considered for space recognition:

@@ -269,22 +279,25 @@

(2) Optimize the text recognition algorithm: introduce a space character into the recognition dictionary, annotate spaces in the recognition training data where they occur, and, when synthesizing data, generate text containing spaces by concatenating training samples.

-#### Q2.4.6: When recognizing Chinese and English together, can a space character also be added for training?
+#### Q2.3.13: When recognizing Chinese and English together, can a space character also be added for training?

**A**: Chinese recognition can be trained with spaces as separators; how well it works cannot be judged directly and should be evaluated by training on your actual business data.

-#### Q2.4.7: Are there super-resolution methods for low-pixel text or text with a small font size?
+#### Q2.3.14: Are there super-resolution methods for low-pixel text or text with a small font size?

**A**: Super-resolution methods fall into traditional methods and deep-learning-based ones. Among the latter, SRCNN is a classic; a CVPR 2020 super-resolution paper, "Unpaired Image Super-Resolution using Pseudo-Supervision", is also worth consulting, though it has not been fully validated in practice and the effect needs checking in your actual scenario.

-#### Q2.4.8: Are there good model or paper recommendations for table recognition?
+#### Q2.3.15: Are there good model or paper recommendations for table recognition?

**A**: There are not many mature academic solutions for tables at present; you can try segmentation-based approaches from the literature.

-#### Q2.4.9: For curved text, have you tried OpenCV's TPS for rectification?
+#### Q2.3.16: For curved text, have you tried OpenCV's TPS for rectification?

**A**: OpenCV's TPS requires marking corresponding points on the upper and lower boundaries, which are hard to obtain with either traditional or deep learning methods. The TPS module in PaddleOCR's STAR-Net learns the points and rectifies automatically; you can try that directly.

+#### Q2.3.17: StyleText synthesized data does not look good?
+
+**A**: Data generated by the StyleText model is mainly used to train OCR recognition models. PaddleOCR's current recognition models take 32 x N input, so the current version mainly suits data of height 32; set the size of the data to be synthesized to 32 x N. Data of roughly similar sizes can also be generated, but very large or very small sizes indeed give poor results.

<a name="PaddleOCR实战问题"></a>
@@ -392,6 +405,34 @@

**A**: The dynamic graph version is under intensive development and will be released on December 16, 2020; stay tuned.

+#### Q3.1.22: ModuleNotFoundError: No module named 'paddle.nn'
+
+**A**: paddle.nn is a Paddle 2.0 feature; install Paddle 2.0.0rc1 or later:
+```
+python3 -m pip install paddlepaddle-gpu==2.0.0rc1 -i https://mirror.baidu.com/pypi/simple
+```
+
+#### Q3.1.23: ImportError: /usr/lib/x86_64_linux-gnu/libstdc++.so.6: version `CXXABI_1.3.11` not found (required by /usr/lib/python3.6/site-package/paddle/fluid/core+avx.so)
+
+**A**: This is caused by an insufficient glibc version. Paddle 2.0rc1 has higher requirements on the gcc and glibc versions: gcc 8.2 and glibc 2.12 or above are recommended. If your environment does not meet this requirement, or you use the docker image `hub.baidubce.com/paddlepaddle/paddle:latest-gpu-cuda9.0-cudnn7-dev`, installing Paddle 2.0rc may raise the above error. For 2.0, the new docker image `paddlepaddle/paddle:latest-dev-cuda10.1-cudnn7-gcc82` is recommended, or visit [dockerhub](https://hub.docker.com/r/paddlepaddle/paddle/tags/) to get an image that fits your machine.
+
+#### Q3.1.24: What is the difference between the PaddleOCR develop and dygraph branches?
+
+**A**: PaddleOCR currently has four branches:
+
+- develop: a branch developed on Paddle static graphs; Paddle 1.8 or 2.0 is recommended. It has complete model training, prediction, inference deployment, quantization and pruning support, and is ahead of release/1.1.
+- release/1.1: the first stable PaddleOCR release, developed on static graphs, with complete training, prediction, inference deployment, quantization and pruning support.
+- dygraph: a branch developed on Paddle dynamic graphs; it will become the main development branch, requires Paddle 2.0rc1, and is still under development.
+- release/2.0-rc1-0: the second stable PaddleOCR release, based on dynamic graphs and Paddle 2.0. Dynamic-graph code is easier to debug; it currently supports model training and prediction, but not yet mobile deployment.
+
+If you are already familiar with PaddleOCR and want to deploy it across various environments, it is currently advisable to use a static-graph branch (develop or release/1.1). If you are a beginner who wants to quickly train and debug the algorithms in PaddleOCR, try the dygraph branch.
+
+**Note**: the develop and dygraph branches require different Paddle versions and local environments; mind the differences in each branch's installation section.

<a name="数据集3"></a>
### Datasets
@@ -729,3 +770,9 @@ ps -axu | grep train.py | awk '{print $2}' | xargs kill -9

#### Q3.4.24: The DB model predicts correctly, but switching to EAST or SAST reports errors or gives wrong results

**A**: When running inference with EAST or SAST, specify --det_algorithm="EAST" or --det_algorithm="SAST" in the command. No flag is needed for DB because it is the parameter's default value: https://github.com/PaddlePaddle/PaddleOCR/blob/e7a708e9fdaf413ed7a14da8e4a7b4ac0b211e42/tools/infer/utility.py#L43

+#### Q3.4.25: PaddleOCR Python and C++ prediction results are inconsistent?
+
+**A**: Normally, Python and C++ predictions give the same text. If the results differ noticeably, first locate whether the diff comes from the detection model or the recognition model, or try another model to see whether a similar problem occurs. Next, check whether the Python-side and C++-side data processing differ; it also helps to keep your environment and retry after updating the PaddleOCR code. If updating the code does not solve it, raise your question in the PaddleOCR WeChat group or in an issue.
@@ -9,9 +9,9 @@

### 1. Text detection algorithms

The list of text detection algorithms open-sourced by PaddleOCR:
-- [x] DB([paper](https://arxiv.org/abs/1911.08947)) (recommended by ppocr)
+- [x] DB([paper](https://arxiv.org/abs/1911.08947)) [2] (recommended by ppocr)
-- [x] EAST([paper](https://arxiv.org/abs/1704.03155))
+- [x] EAST([paper](https://arxiv.org/abs/1704.03155)) [1]
-- [x] SAST([paper](https://arxiv.org/abs/1908.05498))
+- [x] SAST([paper](https://arxiv.org/abs/1908.05498)) [4]

On the ICDAR2015 public text detection dataset, the results are as follows:
@@ -38,13 +38,13 @@ For the training and usage of PaddleOCR text detection algorithms, please refer to the tutorial's [model training

### 2. Text recognition algorithms

The list of text recognition algorithms open-sourced by PaddleOCR on dynamic graphs:
-- [x] CRNN([paper](https://arxiv.org/abs/1507.05717)) (recommended by ppocr)
+- [x] CRNN([paper](https://arxiv.org/abs/1507.05717)) [7] (recommended by ppocr)
-- [x] Rosetta([paper](https://arxiv.org/abs/1910.05085))
+- [x] Rosetta([paper](https://arxiv.org/abs/1910.05085)) [10]
-- [ ] STAR-Net([paper](http://www.bmva.org/bmvc/2016/papers/paper043/index.html)) coming soon
+- [ ] STAR-Net([paper](http://www.bmva.org/bmvc/2016/papers/paper043/index.html)) [11] coming soon
-- [ ] RARE([paper](https://arxiv.org/abs/1603.03915v1)) coming soon
+- [ ] RARE([paper](https://arxiv.org/abs/1603.03915v1)) [12] coming soon
-- [ ] SRN([paper](https://arxiv.org/abs/2003.12294)) coming soon
+- [ ] SRN([paper](https://arxiv.org/abs/2003.12294)) [5] coming soon

-Following the [DTRB](https://arxiv.org/abs/1904.01906) text recognition training and evaluation procedure, trained on the MJSynth and SynthText datasets and evaluated on IIIT, SVT, IC03, IC13, IC15, SVTP and CUTE, the results are as follows:
+Following the [DTRB](https://arxiv.org/abs/1904.01906) [3] text recognition training and evaluation procedure, trained on the MJSynth and SynthText datasets and evaluated on IIIT, SVT, IC03, IC13, IC15, SVTP and CUTE, the results are as follows:

| Model | Backbone | Avg Accuracy | Model name | Download link |
|-|-|-|-|-|
......
@@ -117,7 +117,7 @@ python3 tools/eval.py -c configs/cls/cls_mv3.yml -o Global.checkpoints={path/to/

```
# Predict classification results
-python3 tools/infer_cls.py -c configs/cls/cls_mv3.yml -o Global.checkpoints={path/to/weights}/best_accuracy Global.infer_img=doc/imgs_words/ch/word_1.jpg
+python3 tools/infer_cls.py -c configs/cls/cls_mv3.yml -o Global.pretrained_model={path/to/weights}/best_accuracy Global.load_static_weights=false Global.infer_img=doc/imgs_words/ch/word_1.jpg
```

Input image:
......
@@ -120,16 +120,16 @@ python3 tools/eval.py -c configs/det/det_mv3_db.yml -o Global.checkpoints="{pat

Test the detection result on a single image:
```shell
-python3 tools/infer_det.py -c configs/det/det_mv3_db.yml -o Global.infer_img="./doc/imgs_en/img_10.jpg" Global.checkpoints="./output/det_db/best_accuracy"
+python3 tools/infer_det.py -c configs/det/det_mv3_db.yml -o Global.infer_img="./doc/imgs_en/img_10.jpg" Global.pretrained_model="./output/det_db/best_accuracy" Global.load_static_weights=false
```

When testing the DB model, adjust the post-processing thresholds:
```shell
-python3 tools/infer_det.py -c configs/det/det_mv3_db.yml -o Global.infer_img="./doc/imgs_en/img_10.jpg" Global.checkpoints="./output/det_db/best_accuracy" PostProcess.box_thresh=0.6 PostProcess.unclip_ratio=1.5
+python3 tools/infer_det.py -c configs/det/det_mv3_db.yml -o Global.infer_img="./doc/imgs_en/img_10.jpg" Global.pretrained_model="./output/det_db/best_accuracy" Global.load_static_weights=false PostProcess.box_thresh=0.6 PostProcess.unclip_ratio=1.5
```

Test the detection result on all images in the folder:
```shell
-python3 tools/infer_det.py -c configs/det/det_mv3_db.yml -o Global.infer_img="./doc/imgs_en/" Global.checkpoints="./output/det_db/best_accuracy"
+python3 tools/infer_det.py -c configs/det/det_mv3_db.yml -o Global.infer_img="./doc/imgs_en/" Global.pretrained_model="./output/det_db/best_accuracy" Global.load_static_weights=false
```
@@ -245,7 +245,10 @@ python3 tools/infer/predict_det.py --det_algorithm="SAST" --image_dir="./doc/img

For ultra-lightweight Chinese recognition model inference, run:

```
-python3 tools/infer/predict_rec.py --image_dir="./doc/imgs_words/ch/word_4.jpg" --rec_model_dir="./inference/rec_crnn/"
+# Download the ultra-lightweight Chinese recognition model:
+wget https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_mobile_v2.0_rec_infer.tar
+tar xf ch_ppocr_mobile_v2.0_rec_infer.tar
+python3 tools/infer/predict_rec.py --image_dir="./doc/imgs_words/ch/word_4.jpg" --rec_model_dir="ch_ppocr_mobile_v2.0_rec_infer"
```

![](../imgs_words/ch/word_4.jpg)
@@ -266,7 +269,6 @@ Predicts of ./doc/imgs_words/ch/word_4.jpg:('实力活力', 0.98458153)

```
python3 tools/export_model.py -c configs/rec/rec_r34_vd_none_bilstm_ctc.yml -o Global.pretrained_model=./rec_r34_vd_none_bilstm_ctc_v2.0_train/best_accuracy Global.load_static_weights=False Global.save_inference_dir=./inference/rec_crnn
```

For CRNN text recognition model inference, run:
@@ -327,7 +329,10 @@ Predicts of ./doc/imgs_words/korean/1.jpg:('바탕으로', 0.9948904)

For angle classification model inference, run:

```
-python3 tools/infer/predict_cls.py --image_dir="./doc/imgs_words/ch/word_4.jpg" --cls_model_dir="./inference/cls/"
+# Download the ultra-lightweight Chinese angle classifier model:
+wget https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_mobile_v2.0_cls_infer.tar
+tar xf ch_ppocr_mobile_v2.0_cls_infer.tar
+python3 tools/infer/predict_cls.py --image_dir="./doc/imgs_words/ch/word_4.jpg" --cls_model_dir="ch_ppocr_mobile_v2.0_cls_infer"
```

![](../imgs_words/ch/word_1.jpg)
......
@@ -324,7 +324,6 @@ Eval:

The evaluation dataset can be set by modifying `label_file_path` under Eval in `configs/rec/rec_icdar15_train.yml`.

-*Note*: During evaluation, make sure the infer_img field in the configuration file is empty

```
# GPU evaluation; Global.checkpoints is the weight to be tested
python3 -m paddle.distributed.launch --gpus '0' tools/eval.py -c configs/rec/rec_icdar15_train.yml -o Global.checkpoints={path/to/weights}/best_accuracy
@@ -342,7 +341,7 @@ python3 -m paddle.distributed.launch --gpus '0' tools/eval.py -c configs/rec/rec

```
# Predict English results
-python3 tools/infer_rec.py -c configs/rec/rec_icdar15_train.yml -o Global.checkpoints={path/to/weights}/best_accuracy Global.infer_img=doc/imgs_words/en/word_1.png
+python3 tools/infer_rec.py -c configs/rec/rec_icdar15_train.yml -o Global.pretrained_model={path/to/weights}/best_accuracy Global.load_static_weights=false Global.infer_img=doc/imgs_words/en/word_1.png
```

Input image:
@@ -361,7 +360,7 @@ infer_img: doc/imgs_words/en/word_1.png

```
# Predict Chinese results
-python3 tools/infer_rec.py -c configs/rec/ch_ppocr_v2.0/rec_chinese_lite_train_v2.0.yml -o Global.checkpoints={path/to/weights}/best_accuracy Global.infer_img=doc/imgs_words/ch/word_1.jpg
+python3 tools/infer_rec.py -c configs/rec/ch_ppocr_v2.0/rec_chinese_lite_train_v2.0.yml -o Global.pretrained_model={path/to/weights}/best_accuracy Global.load_static_weights=false Global.infer_img=doc/imgs_words/ch/word_1.jpg
```

Input image:
......
@@ -11,11 +11,12 @@
}

2. DB:
-@article{liao2019real,
-  title={Real-time Scene Text Detection with Differentiable Binarization},
-  author={Liao, Minghui and Wan, Zhaoyi and Yao, Cong and Chen, Kai and Bai, Xiang},
-  journal={arXiv preprint arXiv:1911.08947},
-  year={2019}
+@inproceedings{liao2020real,
+  title={Real-Time Scene Text Detection with Differentiable Binarization.},
+  author={Liao, Minghui and Wan, Zhaoyi and Yao, Cong and Chen, Kai and Bai, Xiang},
+  booktitle={AAAI},
+  pages={11474--11481},
+  year={2020}
}
3. DTRB:

@@ -37,10 +38,11 @@
}

5. SRN:
-@article{yu2020towards,
-  title={Towards Accurate Scene Text Recognition with Semantic Reasoning Networks},
-  author={Yu, Deli and Li, Xuan and Zhang, Chengquan and Han, Junyu and Liu, Jingtuo and Ding, Errui},
-  journal={arXiv preprint arXiv:2003.12294},
-  year={2020}
+@inproceedings{yu2020towards,
+  title={Towards accurate scene text recognition with semantic reasoning networks},
+  author={Yu, Deli and Li, Xuan and Zhang, Chengquan and Liu, Tao and Han, Junyu and Liu, Jingtuo and Ding, Errui},
+  booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition},
+  pages={12113--12122},
+  year={2020}
}
@@ -52,4 +54,62 @@
  pages={9086--9095},
  year={2019}
}
7. CRNN:
@article{shi2016end,
title={An end-to-end trainable neural network for image-based sequence recognition and its application to scene text recognition},
author={Shi, Baoguang and Bai, Xiang and Yao, Cong},
journal={IEEE transactions on pattern analysis and machine intelligence},
volume={39},
number={11},
pages={2298--2304},
year={2016},
publisher={IEEE}
}
8. FPGM:
@inproceedings{he2019filter,
title={Filter pruning via geometric median for deep convolutional neural networks acceleration},
author={He, Yang and Liu, Ping and Wang, Ziwei and Hu, Zhilan and Yang, Yi},
booktitle={Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition},
pages={4340--4349},
year={2019}
}
9. PACT:
@article{choi2018pact,
title={Pact: Parameterized clipping activation for quantized neural networks},
author={Choi, Jungwook and Wang, Zhuo and Venkataramani, Swagath and Chuang, Pierce I-Jen and Srinivasan, Vijayalakshmi and Gopalakrishnan, Kailash},
journal={arXiv preprint arXiv:1805.06085},
year={2018}
}
10. Rosetta:
@inproceedings{borisyuk2018rosetta,
title={Rosetta: Large scale system for text detection and recognition in images},
author={Borisyuk, Fedor and Gordo, Albert and Sivakumar, Viswanath},
booktitle={Proceedings of the 24th ACM SIGKDD International Conference on Knowledge Discovery \& Data Mining},
pages={71--79},
year={2018}
}
11. STAR-Net:
@inproceedings{liu2016star,
title={STAR-Net: A SpaTial Attention Residue Network for Scene Text Recognition.},
author={Liu, Wei and Chen, Chaofeng and Wong, Kwan-Yee K and Su, Zhizhong and Han, Junyu},
booktitle={BMVC},
volume={2},
pages={7},
year={2016}
}
12. RARE:
@inproceedings{shi2016robust,
title={Robust scene text recognition with automatic rectification},
author={Shi, Baoguang and Wang, Xinggang and Lyu, Pengyuan and Yao, Cong and Bai, Xiang},
booktitle={Proceedings of the IEEE conference on computer vision and pattern recognition},
pages={4168--4176},
year={2016}
}
```
@@ -11,9 +11,9 @@ This tutorial lists the text detection algorithms and text recognition algorithm

### 1. Text Detection Algorithm

PaddleOCR open source text detection algorithms list:
-- [x] EAST([paper](https://arxiv.org/abs/1704.03155))
+- [x] EAST([paper](https://arxiv.org/abs/1704.03155)) [1]
-- [x] DB([paper](https://arxiv.org/abs/1911.08947))
+- [x] DB([paper](https://arxiv.org/abs/1911.08947)) [2]
-- [x] SAST([paper](https://arxiv.org/abs/1908.05498)) (Baidu Self-Research)
+- [x] SAST([paper](https://arxiv.org/abs/1908.05498)) [4]

On the ICDAR2015 dataset, the text detection result is as follows:
@@ -39,11 +39,11 @@ For the training guide and use of PaddleOCR text detection algorithms, please re

### 2. Text Recognition Algorithm

PaddleOCR open-source text recognition algorithms list:
-- [x] CRNN([paper](https://arxiv.org/abs/1507.05717))
+- [x] CRNN([paper](https://arxiv.org/abs/1507.05717)) [7]
-- [x] Rosetta([paper](https://arxiv.org/abs/1910.05085))
+- [x] Rosetta([paper](https://arxiv.org/abs/1910.05085)) [10]
-- [ ] STAR-Net([paper](http://www.bmva.org/bmvc/2016/papers/paper043/index.html)) coming soon
+- [ ] STAR-Net([paper](http://www.bmva.org/bmvc/2016/papers/paper043/index.html)) [11] coming soon
-- [ ] RARE([paper](https://arxiv.org/abs/1603.03915v1)) coming soon
+- [ ] RARE([paper](https://arxiv.org/abs/1603.03915v1)) [12] coming soon
-- [ ] SRN([paper](https://arxiv.org/abs/2003.12294)) (Baidu Self-Research) coming soon
+- [ ] SRN([paper](https://arxiv.org/abs/2003.12294)) [5] coming soon

Refer to [DTRB](https://arxiv.org/abs/1904.01906), the training and evaluation result of these above text recognition algorithms (using MJSynth and SynthText for training, evaluating on IIIT, SVT, IC03, IC13, IC15, SVTP, CUTE) is as follows:
......
@@ -119,7 +119,7 @@ Use `Global.infer_img` to specify the path of the predicted picture or folder, a

```
# Predict English results
-python3 tools/infer_cls.py -c configs/cls/cls_mv3.yml -o Global.checkpoints={path/to/weights}/best_accuracy Global.infer_img=doc/imgs_words_en/word_10.png
+python3 tools/infer_cls.py -c configs/cls/cls_mv3.yml -o Global.pretrained_model={path/to/weights}/best_accuracy Global.load_static_weights=false Global.infer_img=doc/imgs_words_en/word_10.png
```

Input image:
......
@@ -113,16 +113,16 @@ python3 tools/eval.py -c configs/det/det_mv3_db.yml -o Global.checkpoints="{pat

Test the detection result on a single image:
```shell
-python3 tools/infer_det.py -c configs/det/det_mv3_db.yml -o Global.infer_img="./doc/imgs_en/img_10.jpg" Global.checkpoints="./output/det_db/best_accuracy"
+python3 tools/infer_det.py -c configs/det/det_mv3_db.yml -o Global.infer_img="./doc/imgs_en/img_10.jpg" Global.pretrained_model="./output/det_db/best_accuracy" Global.load_static_weights=false
```

When testing the DB model, adjust the post-processing threshold:
```shell
-python3 tools/infer_det.py -c configs/det/det_mv3_db.yml -o Global.infer_img="./doc/imgs_en/img_10.jpg" Global.checkpoints="./output/det_db/best_accuracy" PostProcess.box_thresh=0.6 PostProcess.unclip_ratio=1.5
+python3 tools/infer_det.py -c configs/det/det_mv3_db.yml -o Global.infer_img="./doc/imgs_en/img_10.jpg" Global.pretrained_model="./output/det_db/best_accuracy" Global.load_static_weights=false PostProcess.box_thresh=0.6 PostProcess.unclip_ratio=1.5
```

Test the detection result on all images in the folder:
```shell
-python3 tools/infer_det.py -c configs/det/det_mv3_db.yml -o Global.infer_img="./doc/imgs_en/" Global.checkpoints="./output/det_db/best_accuracy"
+python3 tools/infer_det.py -c configs/det/det_mv3_db.yml -o Global.infer_img="./doc/imgs_en/" Global.pretrained_model="./output/det_db/best_accuracy" Global.load_static_weights=false
```
@@ -255,15 +255,18 @@ The following will introduce the lightweight Chinese recognition model inference

For lightweight Chinese recognition model inference, you can execute the following commands:

```
-python3 tools/infer/predict_rec.py --image_dir="./doc/imgs_words/ch/word_4.jpg" --rec_model_dir="./inference/rec_crnn/"
+# download CRNN text recognition inference model
+wget https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_mobile_v2.0_rec_infer.tar
+tar xf ch_ppocr_mobile_v2.0_rec_infer.tar
+python3 tools/infer/predict_rec.py --image_dir="./doc/imgs_words_en/word_10.png" --rec_model_dir="ch_ppocr_mobile_v2.0_rec_infer"
```

-![](../imgs_words/ch/word_4.jpg)
+![](../imgs_words_en/word_10.png)

After executing the command, the prediction results (recognized text and score) of the above image will be printed on the screen.

```bash
-Predicts of ./doc/imgs_words/ch/word_4.jpg:('实力活力', 0.98458153)
+Predicts of ./doc/imgs_words_en/word_10.png:('PAIN', 0.9897658)
```

<a name="CTC-BASED_RECOGNITION"></a>
@@ -339,7 +342,12 @@ For angle classification model inference, you can execute the following commands

```
python3 tools/infer/predict_cls.py --image_dir="./doc/imgs_words_en/word_10.png" --cls_model_dir="./inference/cls/"
```
+```
+# download text angle class inference model:
+wget https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_mobile_v2.0_cls_infer.tar
+tar xf ch_ppocr_mobile_v2.0_cls_infer.tar
+python3 tools/infer/predict_cls.py --image_dir="./doc/imgs_words_en/word_10.png" --cls_model_dir="ch_ppocr_mobile_v2.0_cls_infer"
+```

![](../imgs_words_en/word_10.png)

After executing the command, the prediction results (classification angle and score) of the above image will be printed on the screen.
......
@@ -317,11 +317,11 @@ Eval:

<a name="EVALUATION"></a>
### EVALUATION

-The evaluation data set can be modified via `configs/rec/rec_icdar15_reader.yml` setting of `label_file_path` in EvalReader.
+The evaluation dataset can be set by modifying the `Eval.dataset.label_file_list` field in the `configs/rec/rec_icdar15_train.yml` file.

```
# GPU evaluation, Global.checkpoints is the weight to be tested
-python3 -m paddle.distributed.launch --gpus '0' tools/eval.py -c configs/rec/rec_icdar15_reader.yml -o Global.checkpoints={path/to/weights}/best_accuracy
+python3 -m paddle.distributed.launch --gpus '0' tools/eval.py -c configs/rec/rec_icdar15_train.yml -o Global.checkpoints={path/to/weights}/best_accuracy
```

<a name="PREDICTION"></a>
...@@ -336,7 +336,7 @@ The default prediction picture is stored in `infer_img`, and the weight is speci ...@@ -336,7 +336,7 @@ The default prediction picture is stored in `infer_img`, and the weight is speci
``` ```
# Predict English results # Predict English results
python3 tools/infer_rec.py -c configs/rec/ch_ppocr_v2.0/rec_chinese_lite_train_v2.0.yml -o Global.checkpoints={path/to/weights}/best_accuracy TestReader.infer_img=doc/imgs_words/en/word_1.jpg python3 tools/infer_rec.py -c configs/rec/ch_ppocr_v2.0/rec_chinese_lite_train_v2.0.yml -o Global.pretrained_model={path/to/weights}/best_accuracy Global.load_static_weights=false Global.infer_img=doc/imgs_words/en/word_1.jpg
``` ```
Input image: Input image:
...@@ -354,7 +354,7 @@ The configuration file used for prediction must be consistent with the training. ...@@ -354,7 +354,7 @@ The configuration file used for prediction must be consistent with the training.
``` ```
# Predict Chinese results # Predict Chinese results
python3 tools/infer_rec.py -c configs/rec/ch_ppocr_v2.0/rec_chinese_lite_train_v2.0.yml -o Global.checkpoints={path/to/weights}/best_accuracy TestReader.infer_img=doc/imgs_words/ch/word_1.jpg python3 tools/infer_rec.py -c configs/rec/ch_ppocr_v2.0/rec_chinese_lite_train_v2.0.yml -o Global.pretrained_model={path/to/weights}/best_accuracy Global.load_static_weights=false Global.infer_img=doc/imgs_words/ch/word_1.jpg
``` ```
Input image: Input image:
......
@@ -262,8 +262,8 @@ class PaddleOCR(predict_system.TextSystem):
            logger.error('rec_algorithm must in {}'.format(SUPPORT_REC_MODEL))
            sys.exit(0)
-        postprocess_params.rec_char_dict_path = Path(
-            __file__).parent / postprocess_params.rec_char_dict_path
+        postprocess_params.rec_char_dict_path = str(
+            Path(__file__).parent / postprocess_params.rec_char_dict_path)
        # init det_model and rec_model
        super().__init__(postprocess_params)
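The `str()` wrapper matters because `Path.__truediv__` returns a `Path` object, while the downstream predictor config expects a plain string. A minimal standalone sketch of the difference (the dict filename is illustrative):

```python
from pathlib import Path

# Path / str yields a Path, not a str; config layers that expect plain
# string paths can choke on it, hence the str() conversion above
dict_path = Path(__file__).parent / "ppocr/utils/ppocr_keys_v1.txt"
print(type(dict_path))       # e.g. <class 'pathlib.PosixPath'>
print(type(str(dict_path)))  # <class 'str'>
```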
......
@@ -45,7 +45,6 @@ class BalanceLoss(nn.Layer):
        self.balance_loss = balance_loss
        self.main_loss_type = main_loss_type
        self.negative_ratio = negative_ratio
-        self.main_loss_type = main_loss_type
        self.return_origin = return_origin
        self.eps = eps
......
@@ -19,7 +19,6 @@ from __future__ import print_function
import paddle
from paddle import nn
from .det_basic_loss import DiceLoss
-import paddle.fluid as fluid
import numpy as np
@@ -27,9 +26,7 @@ class SASTLoss(nn.Layer):
    """
    """
-    def __init__(self,
-                 eps=1e-6,
-                 **kwargs):
+    def __init__(self, eps=1e-6, **kwargs):
        super(SASTLoss, self).__init__()
        self.dice_loss = DiceLoss(eps=eps)
@@ -53,10 +50,12 @@ class SASTLoss(nn.Layer):
        score_loss = 1.0 - 2 * intersection / (union + 1e-5)
        #border loss
-        l_border_split, l_border_norm = paddle.split(l_border, num_or_sections=[4, 1], axis=1)
+        l_border_split, l_border_norm = paddle.split(
+            l_border, num_or_sections=[4, 1], axis=1)
        f_border_split = f_border
        border_ex_shape = l_border_norm.shape * np.array([1, 4, 1, 1])
-        l_border_norm_split = paddle.expand(x=l_border_norm, shape=border_ex_shape)
+        l_border_norm_split = paddle.expand(
+            x=l_border_norm, shape=border_ex_shape)
        l_border_score = paddle.expand(x=l_score, shape=border_ex_shape)
        l_border_mask = paddle.expand(x=l_mask, shape=border_ex_shape)
@@ -72,7 +71,8 @@ class SASTLoss(nn.Layer):
            (paddle.sum(l_border_score * l_border_mask) + 1e-5)
        #tvo_loss
-        l_tvo_split, l_tvo_norm = paddle.split(l_tvo, num_or_sections=[8, 1], axis=1)
+        l_tvo_split, l_tvo_norm = paddle.split(
+            l_tvo, num_or_sections=[8, 1], axis=1)
        f_tvo_split = f_tvo
        tvo_ex_shape = l_tvo_norm.shape * np.array([1, 8, 1, 1])
        l_tvo_norm_split = paddle.expand(x=l_tvo_norm, shape=tvo_ex_shape)
@@ -91,7 +91,8 @@ class SASTLoss(nn.Layer):
            (paddle.sum(l_tvo_score * l_tvo_mask) + 1e-5)
        #tco_loss
-        l_tco_split, l_tco_norm = paddle.split(l_tco, num_or_sections=[2, 1], axis=1)
+        l_tco_split, l_tco_norm = paddle.split(
+            l_tco, num_or_sections=[2, 1], axis=1)
        f_tco_split = f_tco
        tco_ex_shape = l_tco_norm.shape * np.array([1, 2, 1, 1])
        l_tco_norm_split = paddle.expand(x=l_tco_norm, shape=tco_ex_shape)
@@ -109,7 +110,6 @@ class SASTLoss(nn.Layer):
        tco_loss = paddle.sum(tco_out_loss * l_tco_score * l_tco_mask) / \
            (paddle.sum(l_tco_score * l_tco_mask) + 1e-5)
        # total loss
        tvo_lw, tco_lw = 1.5, 1.5
        score_lw, border_lw = 1.0, 1.0
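The repeated `shape * np.array([1, k, 1, 1])` pattern in this loss broadcasts a single-channel norm map so it can weight k regression channels. A standalone sketch of just the shape arithmetic, assuming an N×1×H×W map:

```python
import numpy as np

# multiplying a shape list by np.array([1, 4, 1, 1]) scales only the
# channel dimension: [N, 1, H, W] -> [N, 4, H, W]; paddle.expand then
# tiles the norm map across the 4 border channels using this shape
norm_map_shape = [2, 1, 128, 128]
expanded = norm_map_shape * np.array([1, 4, 1, 1])
print(expanded)  # [  2   4 128 128]
```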
......
@@ -32,7 +32,7 @@ setup(
    package_dir={'paddleocr': ''},
    include_package_data=True,
    entry_points={"console_scripts": ["paddleocr= paddleocr.paddleocr:main"]},
-    version='2.0.1',
+    version='2.0.2',
    install_requires=requirements,
    license='Apache License 2.0',
    description='Awesome OCR toolkits based on PaddlePaddle (8.6M ultra-lightweight pre-trained model, support training and deployment among server, mobile, embeded and IoT devices',
......
@@ -24,7 +24,6 @@ import numpy as np
import math
import time
import traceback
-import paddle.fluid as fluid
import tools.infer.utility as utility
from ppocr.postprocess import build_post_process
@@ -39,7 +38,6 @@ class TextClassifier(object):
        self.cls_image_shape = [int(v) for v in args.cls_image_shape.split(",")]
        self.cls_batch_num = args.cls_batch_num
        self.cls_thresh = args.cls_thresh
-        self.use_zero_copy_run = args.use_zero_copy_run
        postprocess_params = {
            'name': 'ClsPostProcess',
            "label_list": args.label_list,
@@ -99,12 +97,8 @@ class TextClassifier(object):
            norm_img_batch = norm_img_batch.copy()
            starttime = time.time()
-            if self.use_zero_copy_run:
-                self.input_tensor.copy_from_cpu(norm_img_batch)
-                self.predictor.zero_copy_run()
-            else:
-                norm_img_batch = fluid.core.PaddleTensor(norm_img_batch)
-                self.predictor.run([norm_img_batch])
+            self.input_tensor.copy_from_cpu(norm_img_batch)
+            self.predictor.run()
            prob_out = self.output_tensors[0].copy_to_cpu()
            cls_result = self.postprocess_op(prob_out)
            elapse += time.time() - starttime
@@ -143,10 +137,11 @@ def main(args):
            "Please set --rec_image_shape='3,32,100' and --rec_char_type='en' ")
        exit()
    for ino in range(len(img_list)):
-        logger.info("Predicts of {}:{}".format(valid_image_file_list[ino], cls_res[
-            ino]))
+        logger.info("Predicts of {}:{}".format(valid_image_file_list[ino],
+                                               cls_res[ino]))
    logger.info("Total predict time for {} images, cost: {:.3f}".format(
        len(img_list), predict_time))
if __name__ == "__main__":
    main(utility.parse_args())
@@ -22,7 +22,6 @@ import cv2
import numpy as np
import time
import sys
-import paddle
import tools.infer.utility as utility
from ppocr.utils.logging import get_logger
@@ -35,8 +34,8 @@ logger = get_logger()
class TextDetector(object):
    def __init__(self, args):
+        self.args = args
        self.det_algorithm = args.det_algorithm
-        self.use_zero_copy_run = args.use_zero_copy_run
        pre_process_list = [{
            'DetResizeForTest': {
                'limit_side_len': args.det_limit_side_len,
@@ -70,6 +69,11 @@ class TextDetector(object):
            postprocess_params["cover_thresh"] = args.det_east_cover_thresh
            postprocess_params["nms_thresh"] = args.det_east_nms_thresh
        elif self.det_algorithm == "SAST":
+            pre_process_list[0] = {
+                'DetResizeForTest': {
+                    'resize_long': args.det_limit_side_len
+                }
+            }
            postprocess_params['name'] = 'SASTPostProcess'
            postprocess_params["score_thresh"] = args.det_sast_score_thresh
            postprocess_params["nms_thresh"] = args.det_sast_nms_thresh
@@ -157,12 +161,8 @@ class TextDetector(object):
        img = img.copy()
        starttime = time.time()
-        if self.use_zero_copy_run:
-            self.input_tensor.copy_from_cpu(img)
-            self.predictor.zero_copy_run()
-        else:
-            im = paddle.fluid.core.PaddleTensor(img)
-            self.predictor.run([im])
+        self.input_tensor.copy_from_cpu(img)
+        self.predictor.run()
        outputs = []
        for output_tensor in self.output_tensors:
            output = output_tensor.copy_to_cpu()
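The SAST branch above swaps the resize policy: instead of capping the side length (`limit_side_len`), it scales images so the long side matches a fixed target. A rough sketch of the intended scaling, based on my reading of the `resize_long` semantics in `DetResizeForTest` (not a verbatim excerpt):

```python
def resize_long_scale(h, w, resize_long=960):
    # scale factor that makes the LONGER side equal resize_long, in
    # contrast to the default mode, which only shrinks images whose
    # longer side already exceeds limit_side_len
    return float(resize_long) / max(h, w)

print(resize_long_scale(720, 1280))  # 0.75 -> 540 x 960
print(resize_long_scale(300, 400))   # 2.4  -> upscaled to 720 x 960
```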
......
@@ -23,7 +23,6 @@ import numpy as np
import math
import time
import traceback
-import paddle.fluid as fluid
import tools.infer.utility as utility
from ppocr.postprocess import build_post_process
@@ -39,7 +38,6 @@ class TextRecognizer(object):
        self.character_type = args.rec_char_type
        self.rec_batch_num = args.rec_batch_num
        self.rec_algorithm = args.rec_algorithm
-        self.use_zero_copy_run = args.use_zero_copy_run
        postprocess_params = {
            'name': 'CTCLabelDecode',
            "character_type": args.rec_char_type,
@@ -101,12 +99,8 @@ class TextRecognizer(object):
            norm_img_batch = np.concatenate(norm_img_batch)
            norm_img_batch = norm_img_batch.copy()
            starttime = time.time()
-            if self.use_zero_copy_run:
-                self.input_tensor.copy_from_cpu(norm_img_batch)
-                self.predictor.zero_copy_run()
-            else:
-                norm_img_batch = fluid.core.PaddleTensor(norm_img_batch)
-                self.predictor.run([norm_img_batch])
+            self.input_tensor.copy_from_cpu(norm_img_batch)
+            self.predictor.run()
            outputs = []
            for output_tensor in self.output_tensors:
                output = output_tensor.copy_to_cpu()
@@ -145,8 +139,8 @@ def main(args):
            "Please set --rec_image_shape='3,32,100' and --rec_char_type='en' ")
        exit()
    for ino in range(len(img_list)):
-        logger.info("Predicts of {}:{}".format(valid_image_file_list[ino], rec_res[
-            ino]))
+        logger.info("Predicts of {}:{}".format(valid_image_file_list[ino],
+                                               rec_res[ino]))
    logger.info("Total predict time for {} images, cost: {:.3f}".format(
        len(img_list), predict_time))
......
@@ -20,8 +20,7 @@ import numpy as np
import json
from PIL import Image, ImageDraw, ImageFont
import math
-from paddle.fluid.core import AnalysisConfig
-from paddle.fluid.core import create_paddle_predictor
+from paddle import inference
def parse_args():
@@ -33,6 +32,7 @@ def parse_args():
    parser.add_argument("--use_gpu", type=str2bool, default=True)
    parser.add_argument("--ir_optim", type=str2bool, default=True)
    parser.add_argument("--use_tensorrt", type=str2bool, default=False)
+    parser.add_argument("--use_fp16", type=str2bool, default=False)
    parser.add_argument("--gpu_mem", type=int, default=8000)
    # params for text detector
@@ -46,7 +46,7 @@ def parse_args():
    parser.add_argument("--det_db_thresh", type=float, default=0.3)
    parser.add_argument("--det_db_box_thresh", type=float, default=0.5)
    parser.add_argument("--det_db_unclip_ratio", type=float, default=1.6)
+    parser.add_argument("--max_batch_size", type=int, default=10)
    # EAST parmas
    parser.add_argument("--det_east_score_thresh", type=float, default=0.8)
    parser.add_argument("--det_east_cover_thresh", type=float, default=0.1)
@@ -62,7 +62,7 @@ def parse_args():
    parser.add_argument("--rec_model_dir", type=str)
    parser.add_argument("--rec_image_shape", type=str, default="3, 32, 320")
    parser.add_argument("--rec_char_type", type=str, default='ch')
-    parser.add_argument("--rec_batch_num", type=int, default=6)
+    parser.add_argument("--rec_batch_num", type=int, default=1)
    parser.add_argument("--max_text_length", type=int, default=25)
    parser.add_argument(
        "--rec_char_dict_path",
@@ -78,12 +78,10 @@ def parse_args():
    parser.add_argument("--cls_model_dir", type=str)
    parser.add_argument("--cls_image_shape", type=str, default="3, 48, 192")
    parser.add_argument("--label_list", type=list, default=['0', '180'])
-    parser.add_argument("--cls_batch_num", type=int, default=30)
+    parser.add_argument("--cls_batch_num", type=int, default=6)
    parser.add_argument("--cls_thresh", type=float, default=0.9)
    parser.add_argument("--enable_mkldnn", type=str2bool, default=False)
-    parser.add_argument("--use_zero_copy_run", type=str2bool, default=False)
    parser.add_argument("--use_pdserving", type=str2bool, default=False)
    return parser.parse_args()
@@ -109,10 +107,15 @@ def create_predictor(args, mode, logger):
        logger.info("not find params file path {}".format(params_file_path))
        sys.exit(0)
-    config = AnalysisConfig(model_file_path, params_file_path)
+    config = inference.Config(model_file_path, params_file_path)
    if args.use_gpu:
        config.enable_use_gpu(args.gpu_mem, 0)
+        if args.use_tensorrt:
+            config.enable_tensorrt_engine(
+                precision_mode=inference.PrecisionType.Half
+                if args.use_fp16 else inference.PrecisionType.Float32,
+                max_batch_size=args.max_batch_size)
    else:
        config.disable_gpu()
        config.set_cpu_math_library_num_threads(6)
@@ -124,20 +127,18 @@ def create_predictor(args, mode, logger):
    # config.enable_memory_optim()
    config.disable_glog_info()
-    if args.use_zero_copy_run:
-        config.delete_pass("conv_transpose_eltwiseadd_bn_fuse_pass")
-        config.switch_use_feed_fetch_ops(False)
-    else:
-        config.switch_use_feed_fetch_ops(True)
+    config.delete_pass("conv_transpose_eltwiseadd_bn_fuse_pass")
+    config.switch_use_feed_fetch_ops(False)
-    predictor = create_paddle_predictor(config)
+    # create predictor
+    predictor = inference.create_predictor(config)
    input_names = predictor.get_input_names()
    for name in input_names:
-        input_tensor = predictor.get_input_tensor(name)
+        input_tensor = predictor.get_input_handle(name)
    output_names = predictor.get_output_names()
    output_tensors = []
    for output_name in output_names:
-        output_tensor = predictor.get_output_tensor(output_name)
+        output_tensor = predictor.get_output_handle(output_name)
        output_tensors.append(output_tensor)
    return predictor, input_tensor, output_tensors
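Taken together, these changes migrate from `fluid`'s `AnalysisConfig`/`PaddleTensor` path to the Paddle 2.0 `paddle.inference` API, with zero-copy handles as the only input/output path. A minimal end-to-end sketch of the new flow (model paths and input shape are placeholders, not values from this commit):

```python
import numpy as np
from paddle import inference

# build the config from an exported inference model (placeholder paths)
config = inference.Config("inference/cls/inference.pdmodel",
                          "inference/cls/inference.pdiparams")
config.disable_gpu()
config.switch_use_feed_fetch_ops(False)
predictor = inference.create_predictor(config)

# feed input through a zero-copy handle instead of a PaddleTensor
input_handle = predictor.get_input_handle(predictor.get_input_names()[0])
input_handle.copy_from_cpu(np.random.rand(1, 3, 48, 192).astype("float32"))
predictor.run()

# fetch output the same way
output_handle = predictor.get_output_handle(predictor.get_output_names()[0])
print(output_handle.copy_to_cpu().shape)
```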
......
@@ -131,7 +131,7 @@ def check_gpu(use_gpu):
        "model on CPU"
    try:
-        if use_gpu and not paddle.fluid.is_compiled_with_cuda():
+        if use_gpu and not paddle.is_compiled_with_cuda():
            print(err)
            sys.exit(1)
    except Exception as e:
@@ -332,7 +332,7 @@ def eval(model, valid_dataloader, post_process_class, eval_class):
    return metirc
-def preprocess():
+def preprocess(is_train=False):
    FLAGS = ArgsParser().parse_args()
    config = load_config(FLAGS.config)
    merge_config(FLAGS.opt)
@@ -350,15 +350,17 @@ def preprocess():
    device = paddle.set_device(device)
    config['Global']['distributed'] = dist.get_world_size() != 1
-    # save_config
-    save_model_dir = config['Global']['save_model_dir']
-    os.makedirs(save_model_dir, exist_ok=True)
-    with open(os.path.join(save_model_dir, 'config.yml'), 'w') as f:
-        yaml.dump(dict(config), f, default_flow_style=False, sort_keys=False)
-    logger = get_logger(
-        name='root', log_file='{}/train.log'.format(save_model_dir))
+    if is_train:
+        # save_config
+        save_model_dir = config['Global']['save_model_dir']
+        os.makedirs(save_model_dir, exist_ok=True)
+        with open(os.path.join(save_model_dir, 'config.yml'), 'w') as f:
+            yaml.dump(
+                dict(config), f, default_flow_style=False, sort_keys=False)
+        log_file = '{}/train.log'.format(save_model_dir)
+    else:
+        log_file = None
+    logger = get_logger(name='root', log_file=log_file)
    if config['Global']['use_visualdl']:
        from visualdl import LogWriter
        vdl_writer_path = '{}/vdl/'.format(save_model_dir)
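The net effect of the new `is_train` flag: only training runs create `save_model_dir` and a `train.log`; eval/infer entry points log to the console alone. The same pattern in plain `logging` terms (a sketch, not ppocr's actual `get_logger`):

```python
import logging

def get_logger_sketch(name="root", log_file=None):
    # console handler always attached; file handler only when a
    # log_file path is supplied (i.e. when is_train was True)
    logger = logging.getLogger(name)
    logger.setLevel(logging.INFO)
    logger.addHandler(logging.StreamHandler())
    if log_file is not None:
        logger.addHandler(logging.FileHandler(log_file))
    return logger

get_logger_sketch(log_file=None).info("eval/infer: console only")
```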
......
@@ -110,6 +110,6 @@ def test_reader(config, device, logger):
if __name__ == '__main__':
-    config, device, logger, vdl_writer = program.preprocess()
+    config, device, logger, vdl_writer = program.preprocess(is_train=True)
    main(config, device, logger, vdl_writer)
    # test_reader(config, device, logger)