Commit 85a98fe2 authored by T tink2123

Merge branch 'dygraph' of https://github.com/PaddlePaddle/PaddleOCR into dygraph

@@ -24,4 +24,8 @@ output/
 build/
 dist/
-paddleocr.egg-info/
\ No newline at end of file
+paddleocr.egg-info/
+/deploy/android_demo/app/OpenCV/
+/deploy/android_demo/app/PaddleLite/
+/deploy/android_demo/app/.cxx/
+/deploy/android_demo/app/cache/
@@ -4,4 +4,5 @@ include README.md
 recursive-include ppocr/utils *.txt utility.py logging.py
 recursive-include ppocr/data/ *.py
 recursive-include ppocr/postprocess *.py
-recursive-include tools/infer *.py
\ No newline at end of file
+recursive-include tools/infer *.py
+recursive-include ppocr/utils/e2e_utils/ *.py
\ No newline at end of file
@@ -147,6 +147,7 @@ class MainWindow(QMainWindow, WindowMixin):
         self.itemsToShapesbox = {}
         self.shapesToItemsbox = {}
         self.prevLabelText = getStr('tempLabel')
+        self.noLabelText = getStr('nullLabel')
         self.model = 'paddle'
         self.PPreader = None
         self.autoSaveNum = 5
@@ -1020,7 +1021,7 @@ class MainWindow(QMainWindow, WindowMixin):
         item.setText(str([(int(p.x()), int(p.y())) for p in shape.points]))
         self.updateComboBox()

-    def updateComboBox(self):  # TODO: seems unused
+    def updateComboBox(self):
         # Get the unique labels and add them to the Combobox.
         itemsTextList = [str(self.labelList.item(i).text()) for i in range(self.labelList.count())]
@@ -1040,7 +1041,7 @@ class MainWindow(QMainWindow, WindowMixin):
         return dict(label=s.label,  # str
                     line_color=s.line_color.getRgb(),
                     fill_color=s.fill_color.getRgb(),
-                    points=[(p.x(), p.y()) for p in s.points],  # QPointF
+                    points=[(int(p.x()), int(p.y())) for p in s.points],  # QPointF
                     # add chris
                     difficult=s.difficult)  # bool
@@ -1069,7 +1070,7 @@ class MainWindow(QMainWindow, WindowMixin):
             # print('Image:{0} -> Annotation:{1}'.format(self.filePath, annotationFilePath))
             return True
         except:
-            self.errorMessage(u'Error saving label data')
+            self.errorMessage(u'Error saving label data', u'Error saving label data')
             return False

     def copySelectedShape(self):
@@ -1802,10 +1803,14 @@ class MainWindow(QMainWindow, WindowMixin):
                 result.insert(0, box)
                 print('result in reRec is ', result)
                 self.result_dic.append(result)
-                if result[1][0] == shape.label:
-                    print('label no change')
-                else:
-                    rec_flag += 1
+            else:
+                print('Can not recognise the box')
+                self.result_dic.append([box, (self.noLabelText, 0)])
+
+            if self.noLabelText == shape.label or result[1][0] == shape.label:
+                print('label no change')
+            else:
+                rec_flag += 1

         if len(self.result_dic) > 0 and rec_flag > 0:
             self.saveFile(mode='Auto')
@@ -1836,9 +1841,14 @@ class MainWindow(QMainWindow, WindowMixin):
                 print('label no change')
             else:
                 shape.label = result[1][0]
-            self.singleLabel(shape)
-            self.setDirty()
-            print(box)
+        else:
+            print('Can not recognise the box')
+            if self.noLabelText == shape.label:
+                print('label no change')
+            else:
+                shape.label = self.noLabelText
+        self.singleLabel(shape)
+        self.setDirty()

     def autolcm(self):
         vbox = QVBoxLayout()
......
@@ -29,9 +29,7 @@ PaddleOCR models have been built into PPOCRLabel; please refer to the [PaddleOCR installation guide]
 ### 2. Install PPOCRLabel

-#### Windows + Anaconda
-
-Download and install [Anaconda](https://www.anaconda.com/download/#download) (Python 3+)
+#### Windows

 ```
 pip install pyqt5
......
@@ -31,7 +31,7 @@ PPOCRLabel is a semi-automatic graphical annotation tool for OCR, with built-in PaddleOCR models
 PPOCRLabel ships with PaddleOCR models, so please prepare PaddleOCR by following the [PaddleOCR installation guide](https://github.com/PaddlePaddle/PaddleOCR/blob/develop/doc/doc_ch/installation.md) and make sure it is installed successfully.

 ### 2. Install PPOCRLabel

-#### Windows + Anaconda
+#### Windows

 ```
 pip install pyqt5
......
@@ -45,7 +45,7 @@ class Canvas(QWidget):
     CREATE, EDIT = list(range(2))
     _fill_drawing = False  # draw shadows

-    epsilon = 11.0
+    epsilon = 5.0

     def __init__(self, *args, **kwargs):
         super(Canvas, self).__init__(*args, **kwargs)
......
This source diff is too large to display; you can view the blob instead.
@@ -124,6 +124,15 @@ def natural_sort(list, key=lambda s:s):

 def get_rotate_crop_image(img, points):
+    # Use Green's theorem to judge clockwise or counterclockwise
+    # author: biyanhua
+    d = 0.0
+    for index in range(-1, 3):
+        d += -0.5 * (points[index + 1][1] + points[index][1]) * (
+            points[index + 1][0] - points[index][0])
+    if d < 0:  # counterclockwise
+        tmp = np.array(points)
+        points[1], points[3] = tmp[3], tmp[1]
     try:
         img_crop_width = int(
......
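The check added above is the shoelace form of Green's theorem: it computes the signed area of the quadrilateral, and in image coordinates (y axis pointing down) a negative d means the four points run counterclockwise, so points 1 and 3 are swapped to restore clockwise order before cropping. A standalone sketch of the same orientation test, with hypothetical sample coordinates:

```
import numpy as np

def ensure_clockwise(points):
    # Signed area via the shoelace form of Green's theorem; in image
    # coordinates (y down), d < 0 means the points run counterclockwise.
    d = 0.0
    for i in range(-1, 3):
        d += -0.5 * (points[i + 1][1] + points[i][1]) * (
            points[i + 1][0] - points[i][0])
    if d < 0:
        # Swap points 1 and 3 so the box reads clockwise from point 0.
        tmp = np.array(points)
        points[1], points[3] = tmp[3], tmp[1]
    return points

# Hypothetical counterclockwise box; after the call it reads clockwise:
# top-left, top-right, bottom-right, bottom-left.
box = np.array([[0., 0.], [0., 10.], [100., 10.], [100., 0.]])
print(ensure_clockwise(box))
```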
@@ -87,6 +87,7 @@ creatPolygon=四点标注
 drawSquares=正方形标注
 saveRec=保存识别结果
 tempLabel=待识别
+nullLabel=无法识别
 steps=操作步骤
 choseModelLg=选择模型语言
 cancel=取消
......
@@ -77,7 +77,7 @@ IR=Image Resize
 autoRecognition=Auto Recognition
 reRecognition=Re-recognition
 mfile=File
-medit=Eidt
+medit=Edit
 mview=View
 mhelp=Help
 iconList=Icon List
@@ -87,6 +87,7 @@ creatPolygon=Create Quadrilateral
 drawSquares=Draw Squares
 saveRec=Save Recognition Result
 tempLabel=TEMPORARY
+nullLabel=NULL
 steps=Steps
 choseModelLg=Choose Model Language
 cancel=Cancel
......
@@ -32,7 +32,8 @@ PaddleOCR supports both dynamic graph and static graph programming paradigms
 <div align="center">
     <img src="doc/imgs_results/ch_ppocr_mobile_v2.0/test_add_91.jpg" width="800">
-    <img src="doc/imgs_results/ch_ppocr_mobile_v2.0/00018069.jpg" width="800">
+    <img src="doc/imgs_results/multi_lang/img_01.jpg" width="800">
+    <img src="doc/imgs_results/multi_lang/img_02.jpg" width="800">
 </div>

 The above pictures are visualizations of the general ppocr_server model. For more visualizations, please see [More visualizations](./doc/doc_en/visualization_en.md).
......
@@ -8,9 +8,9 @@ PaddleOCR supports both dynamic graph and static graph programming paradigms
 - Static graph version: develop branch

 **Recent updates**
+- 2021.4.8 Released version 2.1, open-sourcing the AAAI 2021 paper [end-to-end recognition algorithm PGNet](./doc/doc_ch/pgnet.md); the [multilingual models](./doc/doc_ch/multi_languages.md) now cover 80+ languages.
 - 2021.2.1 [FAQ](./doc/doc_ch/FAQ.md) added 5 frequently asked questions (162 in total), updated every Monday; stay tuned.
-- 2021.1.26,28,29 The PaddleOCR R&D team presented a three-day live course with in-depth technical walkthroughs, at 19:30 on Jan 26, 28 and 29; [live stream](https://live.bilibili.com/21689802)
-- 2021.1.21 Updated the multilingual recognition models, now supporting 27+ languages ([model downloads](./doc/doc_ch/models_list.md)), including Simplified Chinese, Traditional Chinese, English, French, German, Korean, Japanese, Italian, Spanish, Portuguese, Russian, Arabic, etc.; for follow-up plans see the [multilingual roadmap](https://github.com/PaddlePaddle/PaddleOCR/issues/1048)
+- 2021.1.21 Updated the multilingual recognition models, now supporting 27+ languages, including Simplified Chinese, Traditional Chinese, English, French, German, Korean, Japanese, Italian, Spanish, Portuguese, Russian, Arabic, etc.; for follow-up plans see the [multilingual roadmap](https://github.com/PaddlePaddle/PaddleOCR/issues/1048)
 - 2020.12.15 Updated the data synthesis tool [Style-Text](./StyleText/README_ch.md), which batch-synthesizes large numbers of images similar to the target scene; validated in multiple scenarios with clearly improved results.
 - 2020.11.25 Updated the semi-automatic annotation tool [PPOCRLabel](./PPOCRLabel/README_ch.md), helping developers label efficiently; its output format plugs directly into PP-OCR training.
 - 2020.9.22 Published the PP-OCR technical paper: https://arxiv.org/abs/2009.09941
@@ -74,11 +74,13 @@ PaddleOCR supports both dynamic graph and static graph programming paradigms
 ## Documentation
 - [Quick installation](./doc/doc_ch/installation.md)
 - [Quick start with Chinese OCR models](./doc/doc_ch/quickstart.md)
+- [Quick start with multilingual OCR models](./doc/doc_ch/multi_languages.md)
 - [Code structure](./doc/doc_ch/tree.md)
 - Algorithms
     - [Text detection](./doc/doc_ch/algorithm_overview.md)
     - [Text recognition](./doc/doc_ch/algorithm_overview.md)
-    - [PP-OCR Pipline](#PP-OCR)
+    - [PP-OCR Pipeline](#PP-OCR)
+    - [End-to-end PGNet algorithm](./doc/doc_ch/pgnet.md)
 - Model training/evaluation
     - [Text detection](./doc/doc_ch/detection.md)
     - [Text recognition](./doc/doc_ch/recognition.md)

@@ -112,7 +114,7 @@ PaddleOCR supports both dynamic graph and static graph programming paradigms
 <a name="PP-OCR"></a>

-## PP-OCR Pipline
+## PP-OCR Pipeline
 <div align="center">
     <img src="./doc/ppocr_framework.png" width="800">
 </div>
......
@@ -7,11 +7,6 @@ Global:
   save_epoch_step: 1200
   # evaluation is run every 2000 iterations after the 3000th iteration
   eval_batch_step: [3000, 2000]
-  # 1. If pretrained_model is saved in static mode, such as classification pretrained model
-  #    from static branch, load_static_weights must be set as True.
-  # 2. If you want to finetune the pretrained models we provide in the docs,
-  #    you should set load_static_weights as False.
-  load_static_weights: True
   cal_metric_during_train: False
   pretrained_model: ./pretrain_models/MobileNetV3_large_x0_5_pretrained
   checkpoints:
......
@@ -7,11 +7,6 @@ Global:
   save_epoch_step: 1200
   # evaluation is run every 2000 iterations after the 3000th iteration
   eval_batch_step: [3000, 2000]
-  # 1. If pretrained_model is saved in static mode, such as classification pretrained model
-  #    from static branch, load_static_weights must be set as True.
-  # 2. If you want to finetune the pretrained models we provide in the docs,
-  #    you should set load_static_weights as False.
-  load_static_weights: True
   cal_metric_during_train: False
   pretrained_model: ./pretrain_models/ResNet18_vd_pretrained
   checkpoints:
......
@@ -7,11 +7,6 @@ Global:
   save_epoch_step: 1200
   # evaluation is run every 2000 iterations
   eval_batch_step: [0, 2000]
-  # 1. If pretrained_model is saved in static mode, such as classification pretrained model
-  #    from static branch, load_static_weights must be set as True.
-  # 2. If you want to finetune the pretrained models we provide in the docs,
-  #    you should set load_static_weights as False.
-  load_static_weights: True
   cal_metric_during_train: False
   pretrained_model: ./pretrain_models/MobileNetV3_large_x0_5_pretrained
   checkpoints:
......
@@ -7,11 +7,6 @@ Global:
   save_epoch_step: 1000
   # evaluation is run every 5000 iterations after the 4000th iteration
   eval_batch_step: [4000, 5000]
-  # 1. If pretrained_model is saved in static mode, such as classification pretrained model
-  #    from static branch, load_static_weights must be set as True.
-  # 2. If you want to finetune the pretrained models we provide in the docs,
-  #    you should set load_static_weights as False.
-  load_static_weights: True
   cal_metric_during_train: False
   pretrained_model: ./pretrain_models/MobileNetV3_large_x0_5_pretrained
   checkpoints:
......
@@ -7,11 +7,6 @@ Global:
   save_epoch_step: 1200
   # evaluation is run every 2000 iterations
   eval_batch_step: [0,2000]
-  # 1. If pretrained_model is saved in static mode, such as classification pretrained model
-  #    from static branch, load_static_weights must be set as True.
-  # 2. If you want to finetune the pretrained models we provide in the docs,
-  #    you should set load_static_weights as False.
-  load_static_weights: True
   cal_metric_during_train: False
   pretrained_model: ./pretrain_models/ResNet50_vd_ssld_pretrained
   checkpoints:
......
@@ -7,11 +7,6 @@ Global:
   save_epoch_step: 1000
   # evaluation is run every 5000 iterations after the 4000th iteration
   eval_batch_step: [4000, 5000]
-  # 1. If pretrained_model is saved in static mode, such as classification pretrained model
-  #    from static branch, load_static_weights must be set as True.
-  # 2. If you want to finetune the pretrained models we provide in the docs,
-  #    you should set load_static_weights as False.
-  load_static_weights: True
   cal_metric_during_train: False
   pretrained_model: ./pretrain_models/ResNet50_vd_pretrained/
   checkpoints:
......
@@ -7,19 +7,15 @@ Global:
   save_epoch_step: 1000
   # evaluation is run every 5000 iterations after the 4000th iteration
   eval_batch_step: [4000, 5000]
-  # 1. If pretrained_model is saved in static mode, such as classification pretrained model
-  #    from static branch, load_static_weights must be set as True.
-  # 2. If you want to finetune the pretrained models we provide in the docs,
-  #    you should set load_static_weights as False.
-  load_static_weights: True
   cal_metric_during_train: False
   pretrained_model: ./pretrain_models/ResNet50_vd_ssld_pretrained/
   checkpoints:
   save_inference_dir:
   use_visualdl: False
   infer_img:
   save_res_path: ./output/sast_r50_vd_ic15/predicts_sast.txt

 Architecture:
   model_type: det
   algorithm: SAST
......
@@ -7,11 +7,6 @@ Global:
   save_epoch_step: 1000
   # evaluation is run every 5000 iterations after the 4000th iteration
   eval_batch_step: [4000, 5000]
-  # 1. If pretrained_model is saved in static mode, such as classification pretrained model
-  #    from static branch, load_static_weights must be set as True.
-  # 2. If you want to finetune the pretrained models we provide in the docs,
-  #    you should set load_static_weights as False.
-  load_static_weights: True
   cal_metric_during_train: False
   pretrained_model: ./pretrain_models/ResNet50_vd_ssld_pretrained/
   checkpoints:
......
Global:
use_gpu: True
epoch_num: 600
log_smooth_window: 20
print_batch_step: 10
save_model_dir: ./output/pgnet_r50_vd_totaltext/
save_epoch_step: 10
# evaluation is run every 1000 iterations
eval_batch_step: [ 0, 1000 ]
cal_metric_during_train: False
pretrained_model:
checkpoints:
save_inference_dir:
use_visualdl: False
infer_img:
valid_set: totaltext # two modes: totaltext evaluates curved text, partvgg evaluates non-curved text
save_res_path: ./output/pgnet_r50_vd_totaltext/predicts_pgnet.txt
character_dict_path: ppocr/utils/ic15_dict.txt
character_type: EN
max_text_length: 50 # the max length in seq
max_text_nums: 30 # the max seq nums in a pic
tcl_len: 64
Architecture:
model_type: e2e
algorithm: PGNet
Transform:
Backbone:
name: ResNet
layers: 50
Neck:
name: PGFPN
Head:
name: PGHead
Loss:
name: PGLoss
tcl_bs: 64
max_text_length: 50 # the same as Global: max_text_length
max_text_nums: 30 # the same as Global:max_text_nums
pad_num: 36 # the length of dict for pad
Optimizer:
name: Adam
beta1: 0.9
beta2: 0.999
lr:
learning_rate: 0.001
regularizer:
name: 'L2'
factor: 0
PostProcess:
name: PGPostProcess
score_thresh: 0.5
mode: fast # two modes: fast or slow
Metric:
name: E2EMetric
mode: A # two eval modes: A reads labels from txt, B reads labels from gt_mat
gt_mat_dir: ./train_data/total_text/gt # the dir of gt_mat
character_dict_path: ppocr/utils/ic15_dict.txt
main_indicator: f_score_e2e
Train:
dataset:
name: PGDataSet
data_dir: ./train_data/total_text/train
label_file_list: [./train_data/total_text/train/train.txt]
ratio_list: [1.0]
transforms:
- DecodeImage: # load image
img_mode: BGR
channel_first: False
- E2ELabelEncodeTrain:
- PGProcessTrain:
batch_size: 14 # same as loader: batch_size_per_card
min_crop_size: 24
min_text_size: 4
max_text_size: 512
- KeepKeys:
keep_keys: [ 'images', 'tcl_maps', 'tcl_label_maps', 'border_maps','direction_maps', 'training_masks', 'label_list', 'pos_list', 'pos_mask' ] # dataloader will return list in this order
loader:
shuffle: True
drop_last: True
batch_size_per_card: 14
num_workers: 16
Eval:
dataset:
name: PGDataSet
data_dir: ./train_data/total_text/test
label_file_list: [./train_data/total_text/test/test.txt]
transforms:
- DecodeImage: # load image
img_mode: RGB
channel_first: False
- E2ELabelEncodeTest:
- E2EResizeForTest:
max_side_len: 768
- NormalizeImage:
scale: 1./255.
mean: [ 0.485, 0.456, 0.406 ]
std: [ 0.229, 0.224, 0.225 ]
order: 'hwc'
- ToCHWImage:
- KeepKeys:
keep_keys: [ 'image', 'shape', 'polys', 'texts', 'ignore_tags', 'img_id']
loader:
shuffle: False
drop_last: False
batch_size_per_card: 1 # must be 1
num_workers: 2
\ No newline at end of file
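Since the PGNet file above is plain YAML, it can be sanity-checked before launching training. A minimal sketch with PyYAML; the config path here is an assumption, so adjust it to wherever the file is saved:

```
import yaml

# Hypothetical location of the PGNet config shown above.
CONFIG_PATH = "configs/e2e/e2e_r50_vd_pg.yml"

with open(CONFIG_PATH, "rb") as f:
    cfg = yaml.safe_load(f)

# A few of the fields the trainer reads from this file.
print(cfg["Global"]["save_model_dir"])               # ./output/pgnet_r50_vd_totaltext/
print(cfg["Architecture"]["algorithm"])              # PGNet
print(cfg["PostProcess"]["mode"])                    # fast
print(cfg["Eval"]["loader"]["batch_size_per_card"])  # must stay 1
```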
@@ -19,6 +19,7 @@ Global:
   max_text_length: 25
   infer_mode: False
   use_space_char: True
+  save_res_path: ./output/rec/predicts_chinese_common_v2.0.txt

 Optimizer:
......
@@ -19,6 +19,7 @@ Global:
   max_text_length: 25
   infer_mode: False
   use_space_char: True
+  save_res_path: ./output/rec/predicts_chinese_lite_v2.0.txt

 Optimizer:
......
@@ -19,21 +19,56 @@ import logging
 logging.basicConfig(level=logging.INFO)

 support_list = {
-    'it':'italian', 'xi':'spanish', 'pu':'portuguese', 'ru':'russian', 'ar':'arabic',
-    'ta':'tamil', 'ug':'uyghur', 'fa':'persian', 'ur':'urdu', 'rs':'serbian latin',
-    'oc':'occitan', 'rsc':'serbian cyrillic', 'bg':'bulgarian', 'uk':'ukranian', 'be':'belarusian',
-    'te':'telugu', 'ka':'kannada', 'chinese_cht':'chinese tradition','hi':'hindi','mr':'marathi',
-    'ne':'nepali',
+    'it': 'italian',
+    'xi': 'spanish',
+    'pu': 'portuguese',
+    'ru': 'russian',
+    'ar': 'arabic',
+    'ta': 'tamil',
+    'ug': 'uyghur',
+    'fa': 'persian',
+    'ur': 'urdu',
+    'rs': 'serbian latin',
+    'oc': 'occitan',
+    'rsc': 'serbian cyrillic',
+    'bg': 'bulgarian',
+    'uk': 'ukranian',
+    'be': 'belarusian',
+    'te': 'telugu',
+    'ka': 'kannada',
+    'chinese_cht': 'chinese tradition',
+    'hi': 'hindi',
+    'mr': 'marathi',
+    'ne': 'nepali',
 }
-assert(
-    os.path.isfile("./rec_multi_language_lite_train.yml")
-),"Loss basic configuration file rec_multi_language_lite_train.yml.\
+
+latin_lang = [
+    'af', 'az', 'bs', 'cs', 'cy', 'da', 'de', 'es', 'et', 'fr', 'ga', 'hr',
+    'hu', 'id', 'is', 'it', 'ku', 'la', 'lt', 'lv', 'mi', 'ms', 'mt', 'nl',
+    'no', 'oc', 'pi', 'pl', 'pt', 'ro', 'rs_latin', 'sk', 'sl', 'sq', 'sv',
+    'sw', 'tl', 'tr', 'uz', 'vi', 'latin'
+]
+arabic_lang = ['ar', 'fa', 'ug', 'ur']
+cyrillic_lang = [
+    'ru', 'rs_cyrillic', 'be', 'bg', 'uk', 'mn', 'abq', 'ady', 'kbd', 'ava',
+    'dar', 'inh', 'che', 'lbe', 'lez', 'tab', 'cyrillic'
+]
+devanagari_lang = [
+    'hi', 'mr', 'ne', 'bh', 'mai', 'ang', 'bho', 'mah', 'sck', 'new', 'gom',
+    'sa', 'bgc', 'devanagari'
+]
+multi_lang = latin_lang + arabic_lang + cyrillic_lang + devanagari_lang
+
+assert (os.path.isfile("./rec_multi_language_lite_train.yml")
+        ), "Loss basic configuration file rec_multi_language_lite_train.yml.\
 You can download it from \
 https://github.com/PaddlePaddle/PaddleOCR/tree/dygraph/configs/rec/multi_language/"
-global_config = yaml.load(open("./rec_multi_language_lite_train.yml", 'rb'), Loader=yaml.Loader)
+global_config = yaml.load(
+    open("./rec_multi_language_lite_train.yml", 'rb'), Loader=yaml.Loader)
 project_path = os.path.abspath(os.path.join(os.getcwd(), "../../../"))


 class ArgsParser(ArgumentParser):
     def __init__(self):
         super(ArgsParser, self).__init__(
@@ -41,15 +76,30 @@ class ArgsParser(ArgumentParser):
         self.add_argument(
             "-o", "--opt", nargs='+', help="set configuration options")
         self.add_argument(
-            "-l", "--language", nargs='+', help="set language type, support {}".format(support_list))
+            "-l",
+            "--language",
+            nargs='+',
+            help="set language type, support {}".format(support_list))
         self.add_argument(
-            "--train",type=str,help="you can use this command to change the train dataset default path")
+            "--train",
+            type=str,
+            help="you can use this command to change the train dataset default path"
+        )
         self.add_argument(
-            "--val",type=str,help="you can use this command to change the eval dataset default path")
+            "--val",
+            type=str,
+            help="you can use this command to change the eval dataset default path"
+        )
         self.add_argument(
-            "--dict",type=str,help="you can use this command to change the dictionary default path")
+            "--dict",
+            type=str,
+            help="you can use this command to change the dictionary default path"
+        )
         self.add_argument(
-            "--data_dir",type=str,help="you can use this command to change the dataset default root path")
+            "--data_dir",
+            type=str,
+            help="you can use this command to change the dataset default root path"
+        )

     def parse_args(self, argv=None):
         args = super(ArgsParser, self).parse_args(argv)
@@ -68,21 +118,37 @@ class ArgsParser(ArgumentParser):
         return config

     def _set_language(self, type):
-        assert(type),"please use -l or --language to choose language type"
+        lang = type[0]
+        assert (type), "please use -l or --language to choose language type"
         assert(
-                type[0] in support_list.keys()
+                lang in support_list.keys() or lang in multi_lang
         ),"the sub_keys(-l or --language) can only be one of support list: \n{},\nbut get: {}, " \
-          "please check your running command".format(support_list, type)
-        global_config['Global']['character_dict_path'] = 'ppocr/utils/dict/{}_dict.txt'.format(type[0])
-        global_config['Global']['save_model_dir'] = './output/rec_{}_lite'.format(type[0])
-        global_config['Train']['dataset']['label_file_list'] = ["train_data/{}_train.txt".format(type[0])]
-        global_config['Eval']['dataset']['label_file_list'] = ["train_data/{}_val.txt".format(type[0])]
-        global_config['Global']['character_type'] = type[0]
-        assert(
-                os.path.isfile(os.path.join(project_path,global_config['Global']['character_dict_path']))
-        ),"Loss default dictionary file {}_dict.txt.You can download it from \
-https://github.com/PaddlePaddle/PaddleOCR/tree/dygraph/ppocr/utils/dict/".format(type[0])
-        return type[0]
+          "please check your running command".format(multi_lang, type)
+        if lang in latin_lang:
+            lang = "latin"
+        elif lang in arabic_lang:
+            lang = "arabic"
+        elif lang in cyrillic_lang:
+            lang = "cyrillic"
+        elif lang in devanagari_lang:
+            lang = "devanagari"
+        global_config['Global'][
+            'character_dict_path'] = 'ppocr/utils/dict/{}_dict.txt'.format(lang)
+        global_config['Global'][
+            'save_model_dir'] = './output/rec_{}_lite'.format(lang)
+        global_config['Train']['dataset'][
+            'label_file_list'] = ["train_data/{}_train.txt".format(lang)]
+        global_config['Eval']['dataset'][
+            'label_file_list'] = ["train_data/{}_val.txt".format(lang)]
+        global_config['Global']['character_type'] = lang
+        assert (
+            os.path.isfile(
+                os.path.join(project_path, global_config['Global'][
+                    'character_dict_path']))
+        ), "Loss default dictionary file {}_dict.txt.You can download it from \
+https://github.com/PaddlePaddle/PaddleOCR/tree/dygraph/ppocr/utils/dict/".format(
+            lang)
+        return lang


 def merge_config(config):
@@ -110,43 +176,51 @@ def merge_config(config):
             cur[sub_key] = value
         else:
             cur = cur[sub_key]


 def loss_file(path):
-    assert(
+    assert (
         os.path.exists(path)
-    ),"There is no such file:{},Please do not forget to put in the specified file".format(path)
+    ), "There is no such file:{},Please do not forget to put in the specified file".format(
+        path)


 if __name__ == '__main__':
     FLAGS = ArgsParser().parse_args()
     merge_config(FLAGS.opt)
     save_file_path = 'rec_{}_lite_train.yml'.format(FLAGS.language)
     if os.path.isfile(save_file_path):
         os.remove(save_file_path)

     if FLAGS.train:
         global_config['Train']['dataset']['label_file_list'] = [FLAGS.train]
-        train_label_path = os.path.join(project_path,FLAGS.train)
+        train_label_path = os.path.join(project_path, FLAGS.train)
         loss_file(train_label_path)
     if FLAGS.val:
         global_config['Eval']['dataset']['label_file_list'] = [FLAGS.val]
-        eval_label_path = os.path.join(project_path,FLAGS.val)
+        eval_label_path = os.path.join(project_path, FLAGS.val)
         loss_file(eval_label_path)
     if FLAGS.dict:
         global_config['Global']['character_dict_path'] = FLAGS.dict
-        dict_path = os.path.join(project_path,FLAGS.dict)
+        dict_path = os.path.join(project_path, FLAGS.dict)
         loss_file(dict_path)
     if FLAGS.data_dir:
         global_config['Eval']['dataset']['data_dir'] = FLAGS.data_dir
         global_config['Train']['dataset']['data_dir'] = FLAGS.data_dir
-        data_dir = os.path.join(project_path,FLAGS.data_dir)
+        data_dir = os.path.join(project_path, FLAGS.data_dir)
         loss_file(data_dir)

     with open(save_file_path, 'w') as f:
-        yaml.dump(dict(global_config), f, default_flow_style=False, sort_keys=False)
+        yaml.dump(
+            dict(global_config), f, default_flow_style=False, sort_keys=False)
     logging.info("Project path is :{}".format(project_path))
-    logging.info("Train list path set to :{}".format(global_config['Train']['dataset']['label_file_list'][0]))
-    logging.info("Eval list path set to :{}".format(global_config['Eval']['dataset']['label_file_list'][0]))
-    logging.info("Dataset root path set to :{}".format(global_config['Eval']['dataset']['data_dir']))
-    logging.info("Dict path set to :{}".format(global_config['Global']['character_dict_path']))
-    logging.info("Config file set to :configs/rec/multi_language/{}".format(save_file_path))
+    logging.info("Train list path set to :{}".format(global_config['Train'][
+        'dataset']['label_file_list'][0]))
+    logging.info("Eval list path set to :{}".format(global_config['Eval'][
+        'dataset']['label_file_list'][0]))
+    logging.info("Dataset root path set to :{}".format(global_config['Eval'][
+        'dataset']['data_dir']))
+    logging.info("Dict path set to :{}".format(global_config['Global'][
+        'character_dict_path']))
+    logging.info("Config file set to :configs/rec/multi_language/{}".
+                 format(save_file_path))
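The net effect of the reworked `_set_language` is that any code in one of the four script groups collapses to a shared group config (latin, arabic, cyrillic or devanagari) before the dict and label paths are filled in, while languages with their own dictionary pass through unchanged. A minimal standalone sketch of that resolution step (the group lists are abbreviated here; the full lists are in the script above):

```
# Abbreviated group lists; see the full definitions in the script above.
latin_lang = ['af', 'de', 'es', 'fr', 'it', 'latin']
arabic_lang = ['ar', 'fa', 'ug', 'ur']
cyrillic_lang = ['ru', 'be', 'bg', 'uk', 'cyrillic']
devanagari_lang = ['hi', 'mr', 'ne', 'devanagari']

def resolve_lang(lang):
    # Collapse a concrete language code to its shared group config name.
    if lang in latin_lang:
        return 'latin'
    if lang in arabic_lang:
        return 'arabic'
    if lang in cyrillic_lang:
        return 'cyrillic'
    if lang in devanagari_lang:
        return 'devanagari'
    return lang  # e.g. 'ta', 'te', 'ka' keep their own dict

# French trains with the shared latin config, so the generated file is
# rec_latin_lite_train.yml and the dict is ppocr/utils/dict/latin_dict.txt.
print(resolve_lang('fr'))  # latin
print(resolve_lang('ta'))  # ta
```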
Global:
use_gpu: true
epoch_num: 500
log_smooth_window: 20
print_batch_step: 10
save_model_dir: ./output/rec_arabic_lite
save_epoch_step: 3
eval_batch_step:
- 0
- 2000
cal_metric_during_train: true
pretrained_model: null
checkpoints: null
save_inference_dir: null
use_visualdl: false
infer_img: null
character_dict_path: ppocr/utils/dict/arabic_dict.txt
character_type: arabic
max_text_length: 25
infer_mode: false
use_space_char: true
Optimizer:
name: Adam
beta1: 0.9
beta2: 0.999
lr:
name: Cosine
learning_rate: 0.001
regularizer:
name: L2
factor: 1.0e-05
Architecture:
model_type: rec
algorithm: CRNN
Transform: null
Backbone:
name: MobileNetV3
scale: 0.5
model_name: small
small_stride:
- 1
- 2
- 2
- 2
Neck:
name: SequenceEncoder
encoder_type: rnn
hidden_size: 48
Head:
name: CTCHead
fc_decay: 1.0e-05
Loss:
name: CTCLoss
PostProcess:
name: CTCLabelDecode
Metric:
name: RecMetric
main_indicator: acc
Train:
dataset:
name: SimpleDataSet
data_dir: train_data/
label_file_list:
- train_data/arabic_train.txt
transforms:
- DecodeImage:
img_mode: BGR
channel_first: false
- RecAug: null
- CTCLabelEncode: null
- RecResizeImg:
image_shape:
- 3
- 32
- 320
- KeepKeys:
keep_keys:
- image
- label
- length
loader:
shuffle: true
batch_size_per_card: 256
drop_last: true
num_workers: 8
Eval:
dataset:
name: SimpleDataSet
data_dir: train_data/
label_file_list:
- train_data/arabic_val.txt
transforms:
- DecodeImage:
img_mode: BGR
channel_first: false
- CTCLabelEncode: null
- RecResizeImg:
image_shape:
- 3
- 32
- 320
- KeepKeys:
keep_keys:
- image
- label
- length
loader:
shuffle: false
drop_last: false
batch_size_per_card: 256
num_workers: 8
Global:
use_gpu: true
epoch_num: 500
log_smooth_window: 20
print_batch_step: 10
save_model_dir: ./output/rec_cyrillic_lite
save_epoch_step: 3
eval_batch_step:
- 0
- 2000
cal_metric_during_train: true
pretrained_model: null
checkpoints: null
save_inference_dir: null
use_visualdl: false
infer_img: null
character_dict_path: ppocr/utils/dict/cyrillic_dict.txt
character_type: cyrillic
max_text_length: 25
infer_mode: false
use_space_char: true
Optimizer:
name: Adam
beta1: 0.9
beta2: 0.999
lr:
name: Cosine
learning_rate: 0.001
regularizer:
name: L2
factor: 1.0e-05
Architecture:
model_type: rec
algorithm: CRNN
Transform: null
Backbone:
name: MobileNetV3
scale: 0.5
model_name: small
small_stride:
- 1
- 2
- 2
- 2
Neck:
name: SequenceEncoder
encoder_type: rnn
hidden_size: 48
Head:
name: CTCHead
fc_decay: 1.0e-05
Loss:
name: CTCLoss
PostProcess:
name: CTCLabelDecode
Metric:
name: RecMetric
main_indicator: acc
Train:
dataset:
name: SimpleDataSet
data_dir: train_data/
label_file_list:
- train_data/cyrillic_train.txt
transforms:
- DecodeImage:
img_mode: BGR
channel_first: false
- RecAug: null
- CTCLabelEncode: null
- RecResizeImg:
image_shape:
- 3
- 32
- 320
- KeepKeys:
keep_keys:
- image
- label
- length
loader:
shuffle: true
batch_size_per_card: 256
drop_last: true
num_workers: 8
Eval:
dataset:
name: SimpleDataSet
data_dir: train_data/
label_file_list:
- train_data/cyrillic_val.txt
transforms:
- DecodeImage:
img_mode: BGR
channel_first: false
- CTCLabelEncode: null
- RecResizeImg:
image_shape:
- 3
- 32
- 320
- KeepKeys:
keep_keys:
- image
- label
- length
loader:
shuffle: false
drop_last: false
batch_size_per_card: 256
num_workers: 8
Global:
use_gpu: true
epoch_num: 500
log_smooth_window: 20
print_batch_step: 10
save_model_dir: ./output/rec_devanagari_lite
save_epoch_step: 3
eval_batch_step:
- 0
- 2000
cal_metric_during_train: true
pretrained_model: null
checkpoints: null
save_inference_dir: null
use_visualdl: false
infer_img: null
character_dict_path: ppocr/utils/dict/devanagari_dict.txt
character_type: devanagari
max_text_length: 25
infer_mode: false
use_space_char: true
Optimizer:
name: Adam
beta1: 0.9
beta2: 0.999
lr:
name: Cosine
learning_rate: 0.001
regularizer:
name: L2
factor: 1.0e-05
Architecture:
model_type: rec
algorithm: CRNN
Transform: null
Backbone:
name: MobileNetV3
scale: 0.5
model_name: small
small_stride:
- 1
- 2
- 2
- 2
Neck:
name: SequenceEncoder
encoder_type: rnn
hidden_size: 48
Head:
name: CTCHead
fc_decay: 1.0e-05
Loss:
name: CTCLoss
PostProcess:
name: CTCLabelDecode
Metric:
name: RecMetric
main_indicator: acc
Train:
dataset:
name: SimpleDataSet
data_dir: train_data/
label_file_list:
- train_data/devanagari_train.txt
transforms:
- DecodeImage:
img_mode: BGR
channel_first: false
- RecAug: null
- CTCLabelEncode: null
- RecResizeImg:
image_shape:
- 3
- 32
- 320
- KeepKeys:
keep_keys:
- image
- label
- length
loader:
shuffle: true
batch_size_per_card: 256
drop_last: true
num_workers: 8
Eval:
dataset:
name: SimpleDataSet
data_dir: train_data/
label_file_list:
- train_data/devanagari_val.txt
transforms:
- DecodeImage:
img_mode: BGR
channel_first: false
- CTCLabelEncode: null
- RecResizeImg:
image_shape:
- 3
- 32
- 320
- KeepKeys:
keep_keys:
- image
- label
- length
loader:
shuffle: false
drop_last: false
batch_size_per_card: 256
num_workers: 8
@@ -15,11 +15,11 @@ Global:
   use_visualdl: False
   infer_img:
   # for data or label process
-  character_dict_path: ppocr/utils/dict/en_dict.txt
+  character_dict_path: ppocr/utils/en_dict.txt
   character_type: EN
   max_text_length: 25
   infer_mode: False
-  use_space_char: False
+  use_space_char: True

 Optimizer:
......
Global:
use_gpu: true
epoch_num: 500
log_smooth_window: 20
print_batch_step: 10
save_model_dir: ./output/rec_latin_lite
save_epoch_step: 3
eval_batch_step:
- 0
- 2000
cal_metric_during_train: true
pretrained_model: null
checkpoints: null
save_inference_dir: null
use_visualdl: false
infer_img: null
character_dict_path: ppocr/utils/dict/latin_dict.txt
character_type: latin
max_text_length: 25
infer_mode: false
use_space_char: true
Optimizer:
name: Adam
beta1: 0.9
beta2: 0.999
lr:
name: Cosine
learning_rate: 0.001
regularizer:
name: L2
factor: 1.0e-05
Architecture:
model_type: rec
algorithm: CRNN
Transform: null
Backbone:
name: MobileNetV3
scale: 0.5
model_name: small
small_stride:
- 1
- 2
- 2
- 2
Neck:
name: SequenceEncoder
encoder_type: rnn
hidden_size: 48
Head:
name: CTCHead
fc_decay: 1.0e-05
Loss:
name: CTCLoss
PostProcess:
name: CTCLabelDecode
Metric:
name: RecMetric
main_indicator: acc
Train:
dataset:
name: SimpleDataSet
data_dir: train_data/
label_file_list:
- train_data/latin_train.txt
transforms:
- DecodeImage:
img_mode: BGR
channel_first: false
- RecAug: null
- CTCLabelEncode: null
- RecResizeImg:
image_shape:
- 3
- 32
- 320
- KeepKeys:
keep_keys:
- image
- label
- length
loader:
shuffle: true
batch_size_per_card: 256
drop_last: true
num_workers: 8
Eval:
dataset:
name: SimpleDataSet
data_dir: train_data/
label_file_list:
- train_data/latin_val.txt
transforms:
- DecodeImage:
img_mode: BGR
channel_first: false
- CTCLabelEncode: null
- RecResizeImg:
image_shape:
- 3
- 32
- 320
- KeepKeys:
keep_keys:
- image
- label
- length
loader:
shuffle: false
drop_last: false
batch_size_per_card: 256
num_workers: 8
@@ -19,6 +19,7 @@ Global:
   max_text_length: 25
   infer_mode: False
   use_space_char: False
+  save_res_path: ./output/rec/predicts_ic15.txt

 Optimizer:
   name: Adam

@@ -81,7 +82,7 @@ Eval:
   dataset:
     name: SimpleDataSet
     data_dir: ./train_data/
-    label_file_list: ["./train_data/train_list.txt"]
+    label_file_list: ["./train_data/val_list.txt"]
     transforms:
       - DecodeImage: # load image
           img_mode: BGR
......
@@ -19,6 +19,7 @@ Global:
   max_text_length: 25
   infer_mode: False
   use_space_char: False
+  save_res_path: ./output/rec/predicts_mv3_none_bilstm_ctc.txt

 Optimizer:
   name: Adam
......
@@ -19,6 +19,7 @@ Global:
   max_text_length: 25
   infer_mode: False
   use_space_char: False
+  save_res_path: ./output/rec/predicts_mv3_none_none_ctc.txt

 Optimizer:
   name: Adam
......
@@ -19,6 +19,7 @@ Global:
   max_text_length: 25
   infer_mode: False
   use_space_char: False
+  save_res_path: ./output/rec/predicts_mv3_tps_bilstm_att.txt

 Optimizer:
......
@@ -19,6 +19,7 @@ Global:
   max_text_length: 25
   infer_mode: False
   use_space_char: False
+  save_res_path: ./output/rec/predicts_mv3_tps_bilstm_ctc.txt

 Optimizer:
   name: Adam
......
@@ -19,6 +19,7 @@ Global:
   max_text_length: 25
   infer_mode: False
   use_space_char: False
+  save_res_path: ./output/rec/predicts_r34_vd_none_bilstm_ctc.txt

 Optimizer:
   name: Adam
......
@@ -19,6 +19,7 @@ Global:
   max_text_length: 25
   infer_mode: False
   use_space_char: False
+  save_res_path: ./output/rec/predicts_r34_vd_none_none_ctc.txt

 Optimizer:
   name: Adam
......
@@ -19,6 +19,7 @@ Global:
   max_text_length: 25
   infer_mode: False
   use_space_char: False
+  save_res_path: ./output/rec/predicts_b3_rare_r34_none_gru.txt

 Optimizer:
......
@@ -19,6 +19,7 @@ Global:
   max_text_length: 25
   infer_mode: False
   use_space_char: False
+  save_res_path: ./output/rec/predicts_r34_vd_tps_bilstm_ctc.txt

 Optimizer:
   name: Adam

@@ -37,7 +38,7 @@ Architecture:
     name: TPS
     num_fiducial: 20
     loc_lr: 0.1
-    model_name: small
+    model_name: large
   Backbone:
     name: ResNet
     layers: 34
......
@@ -20,6 +20,7 @@ Global:
   num_heads: 8
   infer_mode: False
   use_space_char: False
+  save_res_path: ./output/rec/predicts_srn.txt

 Optimizer:
......
*.iml
.gradle
/local.properties
/.idea/*
.DS_Store
/build
/captures
.externalNativeBuild
# How to test quickly

### 1. Install the latest Android Studio
Download it from https://developer.android.com/studio. This demo was written with Android Studio 4.0.

### 2. Install NDK 20 or above
The demo was tested with NDK 20b; any NDK version from 20 up compiles successfully.

If you are a beginner, you can install and test the NDK build environment as follows:
click File -> New -> New Project and create a "Native C++" project.

### 3. Import the project
Click File -> New -> Import Project..., then follow Android Studio's wizard to import it.

# Getting more support
Visit [EasyEdge, the on-device model generation platform](https://ai.baidu.com/easyedge/app/open_source_demo?referrerUrl=paddlelite) for more development support:

- Demo APP: installable by scanning a QR code, for quickly trying text recognition on a phone
- SDK: models packaged as SDKs adapted to different chips and operating systems, with complete interfaces for further development
import java.security.MessageDigest
apply plugin: 'com.android.application'
android {
compileSdkVersion 29
defaultConfig {
applicationId "com.baidu.paddle.lite.demo.ocr"
minSdkVersion 23
targetSdkVersion 29
versionCode 1
versionName "1.0"
testInstrumentationRunner "android.support.test.runner.AndroidJUnitRunner"
externalNativeBuild {
cmake {
cppFlags "-std=c++11 -frtti -fexceptions -Wno-format"
arguments '-DANDROID_PLATFORM=android-23', '-DANDROID_STL=c++_shared' ,"-DANDROID_ARM_NEON=TRUE"
}
}
ndk {
// abiFilters "arm64-v8a", "armeabi-v7a"
abiFilters "arm64-v8a", "armeabi-v7a"
ldLibs "jnigraphics"
}
}
buildTypes {
release {
minifyEnabled false
proguardFiles getDefaultProguardFile('proguard-android-optimize.txt'), 'proguard-rules.pro'
}
}
externalNativeBuild {
cmake {
path "src/main/cpp/CMakeLists.txt"
version "3.10.2"
}
}
}
dependencies {
implementation fileTree(include: ['*.jar'], dir: 'libs')
implementation 'androidx.appcompat:appcompat:1.1.0'
implementation 'androidx.constraintlayout:constraintlayout:1.1.3'
testImplementation 'junit:junit:4.12'
androidTestImplementation 'com.android.support.test:runner:1.0.2'
androidTestImplementation 'com.android.support.test.espresso:espresso-core:3.0.2'
}
def archives = [
[
'src' : 'https://paddleocr.bj.bcebos.com/dygraph_v2.0/lite/paddle_lite_libs_v2_9_0.tar.gz',
'dest': 'PaddleLite'
],
[
'src' : 'https://paddlelite-demo.bj.bcebos.com/libs/android/opencv-4.2.0-android-sdk.tar.gz',
'dest': 'OpenCV'
],
[
'src' : 'https://paddleocr.bj.bcebos.com/dygraph_v2.0/lite/ocr_v2_for_cpu.tar.gz',
'dest' : 'src/main/assets/models'
],
[
'src' : 'https://paddleocr.bj.bcebos.com/dygraph_v2.0/lite/ch_dict.tar.gz',
'dest' : 'src/main/assets/labels'
]
]
task downloadAndExtractArchives(type: DefaultTask) {
doFirst {
println "Downloading and extracting archives including libs and models"
}
doLast {
// Prepare cache folder for archives
String cachePath = "cache"
if (!file("${cachePath}").exists()) {
mkdir "${cachePath}"
}
archives.eachWithIndex { archive, index ->
MessageDigest messageDigest = MessageDigest.getInstance('MD5')
messageDigest.update(archive.src.bytes)
String cacheName = new BigInteger(1, messageDigest.digest()).toString(32)
// Download the target archive if not exists
boolean copyFiles = !file("${archive.dest}").exists()
if (!file("${cachePath}/${cacheName}.tar.gz").exists()) {
ant.get(src: archive.src, dest: file("${cachePath}/${cacheName}.tar.gz"))
copyFiles = true; // force to copy files from the latest archive files
}
// Extract the target archive if its dest path does not exists
if (copyFiles) {
copy {
from tarTree("${cachePath}/${cacheName}.tar.gz")
into "${archive.dest}"
}
}
}
}
}
preBuild.dependsOn downloadAndExtractArchives
\ No newline at end of file
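The `downloadAndExtractArchives` task above caches each archive under a name derived from the MD5 of its source URL, rendered in base 32 by Java's `BigInteger.toString(32)`. A minimal Python sketch of the same naming scheme, useful for checking which file in `cache/` corresponds to which URL (this reimplementation is an assumption that mirrors the Gradle code above):

```
import hashlib

def cache_name(src_url):
    # MD5 of the URL bytes, rendered in base 32 like BigInteger.toString(32),
    # whose digit alphabet is 0-9 followed by a-v.
    digest = int.from_bytes(hashlib.md5(src_url.encode()).digest(), 'big')
    alphabet = '0123456789abcdefghijklmnopqrstuv'
    out = ''
    while digest:
        digest, rem = divmod(digest, 32)
        out = alphabet[rem] + out
    return (out or '0') + '.tar.gz'

print(cache_name('https://paddleocr.bj.bcebos.com/dygraph_v2.0/lite/ocr_v2_for_cpu.tar.gz'))
```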
# Add project specific ProGuard rules here.
# You can control the set of applied configuration files using the
# proguardFiles setting in build.gradle.
#
# For more details, see
# http://developer.android.com/guide/developing/tools/proguard.html
# If your project uses WebView with JS, uncomment the following
# and specify the fully qualified class name to the JavaScript interface
# class:
#-keepclassmembers class fqcn.of.javascript.interface.for.webview {
# public *;
#}
# Uncomment this to preserve the line number information for
# debugging stack traces.
#-keepattributes SourceFile,LineNumberTable
# If you keep the line number information, uncomment this to
# hide the original source file name.
#-renamesourcefileattribute SourceFile
package com.baidu.paddle.lite.demo.ocr;
import android.content.Context;
import android.support.test.InstrumentationRegistry;
import android.support.test.runner.AndroidJUnit4;
import org.junit.Test;
import org.junit.runner.RunWith;
import static org.junit.Assert.*;
/**
* Instrumented test, which will execute on an Android device.
*
* @see <a href="http://d.android.com/tools/testing">Testing documentation</a>
*/
@RunWith(AndroidJUnit4.class)
public class ExampleInstrumentedTest {
@Test
public void useAppContext() {
// Context of the app under test.
Context appContext = InstrumentationRegistry.getTargetContext();
assertEquals("com.baidu.paddle.lite.demo", appContext.getPackageName());
}
}
<?xml version="1.0" encoding="utf-8"?>
<manifest xmlns:android="http://schemas.android.com/apk/res/android"
package="com.baidu.paddle.lite.demo.ocr">
<uses-permission android:name="android.permission.WRITE_EXTERNAL_STORAGE"/>
<uses-permission android:name="android.permission.READ_EXTERNAL_STORAGE"/>
<uses-permission android:name="android.permission.CAMERA"/>
<application
android:allowBackup="true"
android:icon="@mipmap/ic_launcher"
android:label="@string/app_name"
android:roundIcon="@mipmap/ic_launcher_round"
android:supportsRtl="true"
android:theme="@style/AppTheme">
<!-- to test MiniActivity, change this to com.baidu.paddle.lite.demo.ocr.MiniActivity -->
<activity android:name="com.baidu.paddle.lite.demo.ocr.MainActivity">
<intent-filter>
<action android:name="android.intent.action.MAIN"/>
<category android:name="android.intent.category.LAUNCHER"/>
</intent-filter>
</activity>
<activity
android:name="com.baidu.paddle.lite.demo.ocr.SettingsActivity"
android:label="Settings">
</activity>
<provider
android:name="androidx.core.content.FileProvider"
android:authorities="com.baidu.paddle.lite.demo.ocr.fileprovider"
android:exported="false"
android:grantUriPermissions="true">
<meta-data
android:name="android.support.FILE_PROVIDER_PATHS"
android:resource="@xml/file_paths"></meta-data>
</provider>
</application>
</manifest>
\ No newline at end of file
# For more information about using CMake with Android Studio, read the
# documentation: https://d.android.com/studio/projects/add-native-code.html
# Sets the minimum version of CMake required to build the native library.
cmake_minimum_required(VERSION 3.4.1)
# Creates and names a library, sets it as either STATIC or SHARED, and provides
# the relative paths to its source code. You can define multiple libraries, and
# CMake builds them for you. Gradle automatically packages shared libraries with
# your APK.
set(PaddleLite_DIR "${CMAKE_CURRENT_SOURCE_DIR}/../../../PaddleLite")
include_directories(${PaddleLite_DIR}/cxx/include)
set(OpenCV_DIR "${CMAKE_CURRENT_SOURCE_DIR}/../../../OpenCV/sdk/native/jni")
message(STATUS "opencv dir: ${OpenCV_DIR}")
find_package(OpenCV REQUIRED)
message(STATUS "OpenCV libraries: ${OpenCV_LIBS}")
include_directories(${OpenCV_INCLUDE_DIRS})
aux_source_directory(. SOURCES)
set(CMAKE_CXX_FLAGS
"${CMAKE_CXX_FLAGS} -ffast-math -Ofast -Os"
)
set(CMAKE_CXX_FLAGS
"${CMAKE_CXX_FLAGS} -fvisibility=hidden -fvisibility-inlines-hidden -fdata-sections -ffunction-sections"
)
set(CMAKE_SHARED_LINKER_FLAGS
"${CMAKE_SHARED_LINKER_FLAGS} -Wl,--gc-sections -Wl,-z,nocopyreloc")
add_library(
# Sets the name of the library.
Native
# Sets the library as a shared library.
SHARED
# Provides a relative path to your source file(s).
${SOURCES})
find_library(
# Sets the name of the path variable.
log-lib
# Specifies the name of the NDK library that you want CMake to locate.
log)
add_library(
# Sets the name of the library.
paddle_light_api_shared
# Sets the library as a shared library.
SHARED
# Provides a relative path to your source file(s).
IMPORTED)
set_target_properties(
# Specifies the target library.
paddle_light_api_shared
# Specifies the parameter you want to define.
PROPERTIES
IMPORTED_LOCATION
${PaddleLite_DIR}/cxx/libs/${ANDROID_ABI}/libpaddle_light_api_shared.so
# Provides the path to the library you want to import.
)
# Specifies libraries CMake should link to your target library. You can link
# multiple libraries, such as libraries you define in this build script,
# prebuilt third-party libraries, or system libraries.
target_link_libraries(
# Specifies the target library.
Native
paddle_light_api_shared
${OpenCV_LIBS}
GLESv2
EGL
jnigraphics
${log-lib}
)
add_custom_command(
TARGET Native
POST_BUILD
COMMAND
${CMAKE_COMMAND} -E copy
${PaddleLite_DIR}/cxx/libs/${ANDROID_ABI}/libc++_shared.so
${CMAKE_LIBRARY_OUTPUT_DIRECTORY}/libc++_shared.so)
add_custom_command(
TARGET Native
POST_BUILD
COMMAND
${CMAKE_COMMAND} -E copy
${PaddleLite_DIR}/cxx/libs/${ANDROID_ABI}/libpaddle_light_api_shared.so
${CMAKE_LIBRARY_OUTPUT_DIRECTORY}/libpaddle_light_api_shared.so)
add_custom_command(
TARGET Native
POST_BUILD
COMMAND
${CMAKE_COMMAND} -E copy
${PaddleLite_DIR}/cxx/libs/${ANDROID_ABI}/libhiai.so
${CMAKE_LIBRARY_OUTPUT_DIRECTORY}/libhiai.so)
add_custom_command(
TARGET Native
POST_BUILD
COMMAND
${CMAKE_COMMAND} -E copy
${PaddleLite_DIR}/cxx/libs/${ANDROID_ABI}/libhiai_ir.so
${CMAKE_LIBRARY_OUTPUT_DIRECTORY}/libhiai_ir.so)
add_custom_command(
TARGET Native
POST_BUILD
COMMAND
${CMAKE_COMMAND} -E copy
${PaddleLite_DIR}/cxx/libs/${ANDROID_ABI}/libhiai_ir_build.so
${CMAKE_LIBRARY_OUTPUT_DIRECTORY}/libhiai_ir_build.so)
\ No newline at end of file
//
// Created by fu on 4/25/18.
//
#pragma once
#include <numeric>
#include <vector>
#ifdef __ANDROID__
#include <android/log.h>
#define LOG_TAG "OCR_NDK"
#define LOGI(...) __android_log_print(ANDROID_LOG_INFO, LOG_TAG, __VA_ARGS__)
#define LOGW(...) __android_log_print(ANDROID_LOG_WARN, LOG_TAG, __VA_ARGS__)
#define LOGE(...) __android_log_print(ANDROID_LOG_ERROR, LOG_TAG, __VA_ARGS__)
#else
#include <stdio.h>
#define LOGI(format, ...) \
fprintf(stdout, "[" LOG_TAG "]" format "\n", ##__VA_ARGS__)
#define LOGW(format, ...) \
fprintf(stdout, "[" LOG_TAG "]" format "\n", ##__VA_ARGS__)
#define LOGE(format, ...) \
fprintf(stderr, "[" LOG_TAG "]Error: " format "\n", ##__VA_ARGS__)
#endif
enum RETURN_CODE { RETURN_OK = 0 };
enum NET_TYPE { NET_OCR = 900100, NET_OCR_INTERNAL = 991008 };
template <typename T> inline T product(const std::vector<T> &vec) {
if (vec.empty()) {
return 0;
}
return std::accumulate(vec.begin(), vec.end(), 1, std::multiplies<T>());
}
//
// Created by fujiayi on 2020/7/5.
//
#include "native.h"
#include "ocr_ppredictor.h"
#include <algorithm>
#include <paddle_api.h>
#include <string>
static paddle::lite_api::PowerMode str_to_cpu_mode(const std::string &cpu_mode);
extern "C" JNIEXPORT jlong JNICALL
Java_com_baidu_paddle_lite_demo_ocr_OCRPredictorNative_init(
JNIEnv *env, jobject thiz, jstring j_det_model_path,
jstring j_rec_model_path, jstring j_cls_model_path, jint j_thread_num,
jstring j_cpu_mode) {
std::string det_model_path = jstring_to_cpp_string(env, j_det_model_path);
std::string rec_model_path = jstring_to_cpp_string(env, j_rec_model_path);
std::string cls_model_path = jstring_to_cpp_string(env, j_cls_model_path);
int thread_num = j_thread_num;
std::string cpu_mode = jstring_to_cpp_string(env, j_cpu_mode);
ppredictor::OCR_Config conf;
conf.thread_num = thread_num;
conf.mode = str_to_cpu_mode(cpu_mode);
ppredictor::OCR_PPredictor *ocr_predictor =
new ppredictor::OCR_PPredictor{conf};
ocr_predictor->init_from_file(det_model_path, rec_model_path, cls_model_path);
return reinterpret_cast<jlong>(ocr_predictor);
}
/**
* "LITE_POWER_HIGH" convert to paddle::lite_api::LITE_POWER_HIGH
* @param cpu_mode
* @return
*/
static paddle::lite_api::PowerMode
str_to_cpu_mode(const std::string &cpu_mode) {
static std::map<std::string, paddle::lite_api::PowerMode> cpu_mode_map{
{"LITE_POWER_HIGH", paddle::lite_api::LITE_POWER_HIGH},
{"LITE_POWER_LOW", paddle::lite_api::LITE_POWER_HIGH},
{"LITE_POWER_FULL", paddle::lite_api::LITE_POWER_FULL},
{"LITE_POWER_NO_BIND", paddle::lite_api::LITE_POWER_NO_BIND},
{"LITE_POWER_RAND_HIGH", paddle::lite_api::LITE_POWER_RAND_HIGH},
{"LITE_POWER_RAND_LOW", paddle::lite_api::LITE_POWER_RAND_LOW}};
std::string upper_key;
// Append through a back_inserter: writing through upper_key.begin() on an
// empty string would be undefined behavior.
std::transform(cpu_mode.cbegin(), cpu_mode.cend(),
std::back_inserter(upper_key), ::toupper);
auto index = cpu_mode_map.find(upper_key);
if (index == cpu_mode_map.end()) {
LOGE("cpu_mode not found %s", upper_key.c_str());
return paddle::lite_api::LITE_POWER_HIGH;
} else {
return index->second;
}
}
extern "C" JNIEXPORT jfloatArray JNICALL
Java_com_baidu_paddle_lite_demo_ocr_OCRPredictorNative_forward(
JNIEnv *env, jobject thiz, jlong java_pointer, jfloatArray buf,
jfloatArray ddims, jobject original_image) {
LOGI("begin to run native forward");
if (java_pointer == 0) {
LOGE("JAVA pointer is NULL");
return cpp_array_to_jfloatarray(env, nullptr, 0);
}
cv::Mat origin = bitmap_to_cv_mat(env, original_image);
if (origin.empty()) {
LOGE("origin bitmap cannot convert to CV Mat");
return cpp_array_to_jfloatarray(env, nullptr, 0);
}
ppredictor::OCR_PPredictor *ppredictor =
(ppredictor::OCR_PPredictor *)java_pointer;
std::vector<float> dims_float_arr = jfloatarray_to_float_vector(env, ddims);
std::vector<int64_t> dims_arr;
dims_arr.resize(dims_float_arr.size());
std::copy(dims_float_arr.cbegin(), dims_float_arr.cend(), dims_arr.begin());
// The buffer can be large here, so read the elements directly instead of
// going through jfloatarray_to_float_vector.
int64_t buf_len = (int64_t)env->GetArrayLength(buf);
jfloat *buf_data = env->GetFloatArrayElements(buf, nullptr);
float *data = (jfloat *)buf_data;
std::vector<ppredictor::OCRPredictResult> results =
ppredictor->infer_ocr(dims_arr, data, buf_len, NET_OCR, origin);
LOGI("infer_ocr finished with boxes %ld", results.size());
// Serialize the std::vector<ppredictor::OCRPredictResult> into a flat
// float array; it is deserialized again on the Java side.
std::vector<float> float_arr;
for (const ppredictor::OCRPredictResult &r : results) {
float_arr.push_back(r.points.size());
float_arr.push_back(r.word_index.size());
float_arr.push_back(r.score);
for (const std::vector<int> &point : r.points) {
float_arr.push_back(point.at(0));
float_arr.push_back(point.at(1));
}
for (int index : r.word_index) {
float_arr.push_back(index);
}
}
return cpp_array_to_jfloatarray(env, float_arr.data(), float_arr.size());
}
extern "C" JNIEXPORT void JNICALL
Java_com_baidu_paddle_lite_demo_ocr_OCRPredictorNative_release(
JNIEnv *env, jobject thiz, jlong java_pointer) {
if (java_pointer == 0) {
LOGE("JAVA pointer is NULL");
return;
}
ppredictor::OCR_PPredictor *ppredictor =
(ppredictor::OCR_PPredictor *)java_pointer;
delete ppredictor;
}
\ No newline at end of file
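The flat float array returned by `forward` has a simple per-result layout, written by the loop above: point count, word-index count, score, then the flattened (x, y) points, then the word indices. The demo's real decoder lives on the Java side; a minimal Python sketch of a decoder for that layout, with a hypothetical input array:

```
def decode_ocr_results(arr):
    # Per result: [n_points, n_word_indices, score,
    #              x0, y0, ..., idx0, idx1, ...]
    results, i = [], 0
    while i < len(arr):
        n_points, n_words, score = int(arr[i]), int(arr[i + 1]), arr[i + 2]
        i += 3
        points = [(arr[i + 2 * k], arr[i + 2 * k + 1]) for k in range(n_points)]
        i += 2 * n_points
        word_index = [int(v) for v in arr[i:i + n_words]]
        i += n_words
        results.append({'score': score, 'points': points, 'word_index': word_index})
    return results

# Hypothetical array: one box with 2 points, 2 word indices, score 0.9.
print(decode_ocr_results([2, 2, 0.9, 10, 20, 30, 40, 5, 7]))
```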
//
// Created by fujiayi on 2020/7/5.
//
#pragma once
#include "common.h"
#include <android/bitmap.h>
#include <jni.h>
#include <opencv2/opencv.hpp>
#include <string>
#include <vector>
inline std::string jstring_to_cpp_string(JNIEnv *env, jstring jstr) {
// In Java, a unicode char is encoded with 2 bytes (UTF-16), so a jstring
// holds UTF-16 characters. std::string in C++ is essentially a string of
// bytes, not characters, so to pass a jstring from JNI to C++ we have to
// convert the UTF-16 data to bytes.
if (!jstr) {
return "";
}
const jclass stringClass = env->GetObjectClass(jstr);
const jmethodID getBytes =
env->GetMethodID(stringClass, "getBytes", "(Ljava/lang/String;)[B");
const jbyteArray stringJbytes = (jbyteArray)env->CallObjectMethod(
jstr, getBytes, env->NewStringUTF("UTF-8"));
size_t length = (size_t)env->GetArrayLength(stringJbytes);
jbyte *pBytes = env->GetByteArrayElements(stringJbytes, NULL);
std::string ret = std::string(reinterpret_cast<char *>(pBytes), length);
env->ReleaseByteArrayElements(stringJbytes, pBytes, JNI_ABORT);
env->DeleteLocalRef(stringJbytes);
env->DeleteLocalRef(stringClass);
return ret;
}
inline jstring cpp_string_to_jstring(JNIEnv *env, const std::string &str) {
// use str.size() rather than strlen() so embedded NULs are preserved
const char *data = str.c_str();
jclass strClass = env->FindClass("java/lang/String");
jmethodID strClassInitMethodID =
env->GetMethodID(strClass, "<init>", "([BLjava/lang/String;)V");
jbyteArray bytes = env->NewByteArray((jsize)str.size());
env->SetByteArrayRegion(bytes, 0, (jsize)str.size(),
reinterpret_cast<const jbyte *>(data));
jstring encoding = env->NewStringUTF("UTF-8");
jstring res = (jstring)(
env->NewObject(strClass, strClassInitMethodID, bytes, encoding));
env->DeleteLocalRef(strClass);
env->DeleteLocalRef(encoding);
env->DeleteLocalRef(bytes);
return res;
}
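// A minimal round-trip sketch for the two helpers above (illustrative;
// assumes `env` comes from a live JNI call on an attached thread): any valid
// UTF-8 std::string should survive the conversion unchanged, since both
// helpers go through Java's byte[] interface with UTF-8 encoding.
inline bool jni_string_round_trip_ok(JNIEnv *env, const std::string &src) {
  jstring jstr = cpp_string_to_jstring(env, src);
  std::string back = jstring_to_cpp_string(env, jstr);
  env->DeleteLocalRef(jstr);
  return back == src;
}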
inline jfloatArray cpp_array_to_jfloatarray(JNIEnv *env, const float *buf,
int64_t len) {
if (len == 0) {
return env->NewFloatArray(0);
}
jfloatArray result = env->NewFloatArray(len);
env->SetFloatArrayRegion(result, 0, len, buf);
return result;
}
inline jintArray cpp_array_to_jintarray(JNIEnv *env, const int *buf,
int64_t len) {
jintArray result = env->NewIntArray(len);
env->SetIntArrayRegion(result, 0, len, buf);
return result;
}
inline jbyteArray cpp_array_to_jbytearray(JNIEnv *env, const int8_t *buf,
int64_t len) {
jbyteArray result = env->NewByteArray(len);
env->SetByteArrayRegion(result, 0, len, buf);
return result;
}
inline jlongArray int64_vector_to_jlongarray(JNIEnv *env,
const std::vector<int64_t> &vec) {
jlongArray result = env->NewLongArray(vec.size());
jlong *buf = new jlong[vec.size()];
for (size_t i = 0; i < vec.size(); ++i) {
buf[i] = (jlong)vec[i];
}
env->SetLongArrayRegion(result, 0, vec.size(), buf);
delete[] buf;
return result;
}
inline std::vector<int64_t> jlongarray_to_int64_vector(JNIEnv *env,
jlongArray data) {
int data_size = env->GetArrayLength(data);
jlong *data_ptr = env->GetLongArrayElements(data, nullptr);
std::vector<int64_t> data_vec(data_ptr, data_ptr + data_size);
env->ReleaseLongArrayElements(data, data_ptr, 0);
return data_vec;
}
inline std::vector<float> jfloatarray_to_float_vector(JNIEnv *env,
jfloatArray data) {
int data_size = env->GetArrayLength(data);
jfloat *data_ptr = env->GetFloatArrayElements(data, nullptr);
std::vector<float> data_vec(data_ptr, data_ptr + data_size);
env->ReleaseFloatArrayElements(data, data_ptr, 0);
return data_vec;
}
inline cv::Mat bitmap_to_cv_mat(JNIEnv *env, jobject bitmap) {
AndroidBitmapInfo info;
int result = AndroidBitmap_getInfo(env, bitmap, &info);
if (result != ANDROID_BITMAP_RESULT_SUCCESS) {
LOGE("AndroidBitmap_getInfo failed, result: %d", result);
return cv::Mat{};
}
if (info.format != ANDROID_BITMAP_FORMAT_RGBA_8888) {
LOGE("Bitmap format is not RGBA_8888 !");
return cv::Mat{};
}
unsigned char *srcData = NULL;
AndroidBitmap_lockPixels(env, bitmap, (void **)&srcData);
cv::Mat mat = cv::Mat::zeros(info.height, info.width, CV_8UC4);
// Copy row by row: the bitmap stride can be wider than width * 4.
for (uint32_t y = 0; y < info.height; ++y) {
memcpy(mat.ptr(y), srcData + y * info.stride, info.width * 4);
}
AndroidBitmap_unlockPixels(env, bitmap);
cv::cvtColor(mat, mat, cv::COLOR_RGBA2BGR);
/**
if (!cv::imwrite("/sdcard/1/copy.jpg", mat)){
LOGE("Write image failed " );
}
*/
return mat;
}
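// Usage sketch for the converter above (illustrative; `java_bitmap` is a
// hypothetical android.graphics.Bitmap passed in from the Java side).
inline bool bitmap_convert_demo(JNIEnv *env, jobject java_bitmap) {
  cv::Mat bgr = bitmap_to_cv_mat(env, java_bitmap);
  return !bgr.empty(); // false: not RGBA_8888, or the bitmap query failed
}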
/*******************************************************************************
* *
* Author : Angus Johnson *
* Version : 6.4.2 *
* Date : 27 February 2017 *
* Website : http://www.angusj.com *
* Copyright : Angus Johnson 2010-2017 *
* *
* License: *
* Use, modification & distribution is subject to Boost Software License Ver 1. *
* http://www.boost.org/LICENSE_1_0.txt *
* *
* Attributions: *
* The code in this library is an extension of Bala Vatti's clipping algorithm: *
* "A generic solution to polygon clipping" *
* Communications of the ACM, Vol 35, Issue 7 (July 1992) pp 56-63. *
* http://portal.acm.org/citation.cfm?id=129906 *
* *
* Computer graphics and geometric modeling: implementation and algorithms *
* By Max K. Agoston *
* Springer; 1 edition (January 4, 2005) *
* http://books.google.com/books?q=vatti+clipping+agoston *
* *
* See also: *
* "Polygon Offsetting by Computing Winding Numbers" *
* Paper no. DETC2005-85513 pp. 565-575 *
* ASME 2005 International Design Engineering Technical Conferences *
* and Computers and Information in Engineering Conference (IDETC/CIE2005) *
* September 24-28, 2005 , Long Beach, California, USA *
* http://www.me.berkeley.edu/~mcmains/pubs/DAC05OffsetPolygon.pdf *
* *
*******************************************************************************/
/*******************************************************************************
* *
* This is a translation of the Delphi Clipper library and the naming style *
* used has retained a Delphi flavour. *
* *
*******************************************************************************/
#include "ocr_clipper.hpp"
#include <algorithm>
#include <cmath>
#include <cstdlib>
#include <cstring>
#include <functional>
#include <ostream>
#include <stdexcept>
#include <vector>
namespace ClipperLib {
static double const pi = 3.141592653589793238;
static double const two_pi = pi * 2;
static double const def_arc_tolerance = 0.25;
enum Direction { dRightToLeft, dLeftToRight };
static int const Unassigned = -1; // edge not currently 'owning' a solution
static int const Skip = -2; // edge that would otherwise close a path
#define HORIZONTAL (-1.0E+40)
#define TOLERANCE (1.0e-20)
#define NEAR_ZERO(val) (((val) > -TOLERANCE) && ((val) < TOLERANCE))
struct TEdge {
IntPoint Bot;
IntPoint Curr; // current (updated for every new scanbeam)
IntPoint Top;
double Dx;
PolyType PolyTyp;
EdgeSide Side; // side only refers to current side of solution poly
int WindDelta; // 1 or -1 depending on winding direction
int WindCnt;
int WindCnt2; // winding count of the opposite polytype
int OutIdx;
TEdge *Next;
TEdge *Prev;
TEdge *NextInLML;
TEdge *NextInAEL;
TEdge *PrevInAEL;
TEdge *NextInSEL;
TEdge *PrevInSEL;
};
struct IntersectNode {
TEdge *Edge1;
TEdge *Edge2;
IntPoint Pt;
};
struct LocalMinimum {
cInt Y;
TEdge *LeftBound;
TEdge *RightBound;
};
struct OutPt;
// OutRec: contains a path in the clipping solution. Edges in the AEL will
// carry a pointer to an OutRec when they are part of the clipping solution.
struct OutRec {
int Idx;
bool IsHole;
bool IsOpen;
OutRec *FirstLeft; // see comments in clipper.pas
PolyNode *PolyNd;
OutPt *Pts;
OutPt *BottomPt;
};
struct OutPt {
int Idx;
IntPoint Pt;
OutPt *Next;
OutPt *Prev;
};
struct Join {
OutPt *OutPt1;
OutPt *OutPt2;
IntPoint OffPt;
};
struct LocMinSorter {
inline bool operator()(const LocalMinimum &locMin1,
const LocalMinimum &locMin2) {
return locMin2.Y < locMin1.Y;
}
};
//------------------------------------------------------------------------------
//------------------------------------------------------------------------------
inline cInt Round(double val) {
if (val < 0)
return static_cast<cInt>(val - 0.5);
else
return static_cast<cInt>(val + 0.5);
}
//------------------------------------------------------------------------------
inline cInt Abs(cInt val) { return val < 0 ? -val : val; }
//------------------------------------------------------------------------------
// PolyTree methods ...
//------------------------------------------------------------------------------
void PolyTree::Clear() {
for (PolyNodes::size_type i = 0; i < AllNodes.size(); ++i)
delete AllNodes[i];
AllNodes.resize(0);
Childs.resize(0);
}
//------------------------------------------------------------------------------
PolyNode *PolyTree::GetFirst() const {
if (!Childs.empty())
return Childs[0];
else
return 0;
}
//------------------------------------------------------------------------------
int PolyTree::Total() const {
int result = (int)AllNodes.size();
// with negative offsets, ignore the hidden outer polygon ...
if (result > 0 && Childs[0] != AllNodes[0])
result--;
return result;
}
//------------------------------------------------------------------------------
// PolyNode methods ...
//------------------------------------------------------------------------------
PolyNode::PolyNode() : Parent(0), Index(0), m_IsOpen(false) {}
//------------------------------------------------------------------------------
int PolyNode::ChildCount() const { return (int)Childs.size(); }
//------------------------------------------------------------------------------
void PolyNode::AddChild(PolyNode &child) {
unsigned cnt = (unsigned)Childs.size();
Childs.push_back(&child);
child.Parent = this;
child.Index = cnt;
}
//------------------------------------------------------------------------------
PolyNode *PolyNode::GetNext() const {
if (!Childs.empty())
return Childs[0];
else
return GetNextSiblingUp();
}
//------------------------------------------------------------------------------
PolyNode *PolyNode::GetNextSiblingUp() const {
if (!Parent) // protects against PolyTree.GetNextSiblingUp()
return 0;
else if (Index == Parent->Childs.size() - 1)
return Parent->GetNextSiblingUp();
else
return Parent->Childs[Index + 1];
}
//------------------------------------------------------------------------------
bool PolyNode::IsHole() const {
bool result = true;
PolyNode *node = Parent;
while (node) {
result = !result;
node = node->Parent;
}
return result;
}
//------------------------------------------------------------------------------
bool PolyNode::IsOpen() const { return m_IsOpen; }
//------------------------------------------------------------------------------
#ifndef use_int32
//------------------------------------------------------------------------------
// Int128 class (enables safe math on signed 64bit integers)
// eg Int128 val1((long64)9223372036854775807); //ie 2^63 -1
// Int128 val2((long64)9223372036854775807);
// Int128 val3 = val1 * val2;
// val3.AsString => "85070591730234615847396907784232501249" (8.5e+37)
//------------------------------------------------------------------------------
class Int128 {
public:
ulong64 lo;
long64 hi;
Int128(long64 _lo = 0) {
lo = (ulong64)_lo;
if (_lo < 0)
hi = -1;
else
hi = 0;
}
Int128(const Int128 &val) : lo(val.lo), hi(val.hi) {}
Int128(const long64 &_hi, const ulong64 &_lo) : lo(_lo), hi(_hi) {}
Int128 &operator=(const long64 &val) {
lo = (ulong64)val;
if (val < 0)
hi = -1;
else
hi = 0;
return *this;
}
bool operator==(const Int128 &val) const {
return (hi == val.hi && lo == val.lo);
}
bool operator!=(const Int128 &val) const { return !(*this == val); }
bool operator>(const Int128 &val) const {
if (hi != val.hi)
return hi > val.hi;
else
return lo > val.lo;
}
bool operator<(const Int128 &val) const {
if (hi != val.hi)
return hi < val.hi;
else
return lo < val.lo;
}
bool operator>=(const Int128 &val) const { return !(*this < val); }
bool operator<=(const Int128 &val) const { return !(*this > val); }
Int128 &operator+=(const Int128 &rhs) {
hi += rhs.hi;
lo += rhs.lo;
if (lo < rhs.lo)
hi++;
return *this;
}
Int128 operator+(const Int128 &rhs) const {
Int128 result(*this);
result += rhs;
return result;
}
Int128 &operator-=(const Int128 &rhs) {
*this += -rhs;
return *this;
}
Int128 operator-(const Int128 &rhs) const {
Int128 result(*this);
result -= rhs;
return result;
}
Int128 operator-() const // unary negation
{
if (lo == 0)
return Int128(-hi, 0);
else
return Int128(~hi, ~lo + 1);
}
operator double() const {
const double shift64 = 18446744073709551616.0; // 2^64
if (hi < 0) {
if (lo == 0)
return (double)hi * shift64;
else
return -(double)(~lo + ~hi * shift64);
} else
return (double)(lo + hi * shift64);
}
};
//------------------------------------------------------------------------------
Int128 Int128Mul(long64 lhs, long64 rhs) {
bool negate = (lhs < 0) != (rhs < 0);
if (lhs < 0)
lhs = -lhs;
ulong64 int1Hi = ulong64(lhs) >> 32;
ulong64 int1Lo = ulong64(lhs & 0xFFFFFFFF);
if (rhs < 0)
rhs = -rhs;
ulong64 int2Hi = ulong64(rhs) >> 32;
ulong64 int2Lo = ulong64(rhs & 0xFFFFFFFF);
// nb: see comments in clipper.pas
ulong64 a = int1Hi * int2Hi;
ulong64 b = int1Lo * int2Lo;
ulong64 c = int1Hi * int2Lo + int1Lo * int2Hi;
Int128 tmp;
tmp.hi = long64(a + (c >> 32));
tmp.lo = long64(c << 32);
tmp.lo += long64(b);
if (tmp.lo < b)
tmp.hi++;
if (negate)
tmp = -tmp;
return tmp;
}
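//------------------------------------------------------------------------------
// Sanity sketch for Int128Mul (illustrative only): 2^32 * 2^32 = 2^64, which
// is hi == 1 and lo == 0 in the two-word representation above.
static bool Int128MulDemoOk() {
  Int128 v = Int128Mul((long64)1 << 32, (long64)1 << 32);
  return v.hi == 1 && v.lo == 0;
}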
#endif
//------------------------------------------------------------------------------
// Miscellaneous global functions
//------------------------------------------------------------------------------
bool Orientation(const Path &poly) { return Area(poly) >= 0; }
//------------------------------------------------------------------------------
double Area(const Path &poly) {
int size = (int)poly.size();
if (size < 3)
return 0;
double a = 0;
for (int i = 0, j = size - 1; i < size; ++i) {
a += ((double)poly[j].X + poly[i].X) * ((double)poly[j].Y - poly[i].Y);
j = i;
}
return -a * 0.5;
}
//------------------------------------------------------------------------------
double Area(const OutPt *op) {
const OutPt *startOp = op;
if (!op)
return 0;
double a = 0;
do {
a += (double)(op->Prev->Pt.X + op->Pt.X) *
(double)(op->Prev->Pt.Y - op->Pt.Y);
op = op->Next;
} while (op != startOp);
return a * 0.5;
}
//------------------------------------------------------------------------------
double Area(const OutRec &outRec) { return Area(outRec.Pts); }
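//------------------------------------------------------------------------------
// Shoelace sanity sketch (illustrative only): a counter-clockwise 10 x 10
// square yields Area == 100 and, by the convention above, Orientation true.
static bool AreaDemoOk() {
  Path sq;
  sq.push_back(IntPoint(0, 0));
  sq.push_back(IntPoint(10, 0));
  sq.push_back(IntPoint(10, 10));
  sq.push_back(IntPoint(0, 10));
  return Area(sq) == 100.0 && Orientation(sq);
}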
//------------------------------------------------------------------------------
bool PointIsVertex(const IntPoint &Pt, OutPt *pp) {
OutPt *pp2 = pp;
do {
if (pp2->Pt == Pt)
return true;
pp2 = pp2->Next;
} while (pp2 != pp);
return false;
}
//------------------------------------------------------------------------------
// See "The Point in Polygon Problem for Arbitrary Polygons" by Hormann &
// Agathos
// http://citeseerx.ist.psu.edu/viewdoc/download?doi=10.1.1.88.5498&rep=rep1&type=pdf
int PointInPolygon(const IntPoint &pt, const Path &path) {
// returns 0 if false, +1 if true, -1 if pt ON polygon boundary
int result = 0;
size_t cnt = path.size();
if (cnt < 3)
return 0;
IntPoint ip = path[0];
for (size_t i = 1; i <= cnt; ++i) {
IntPoint ipNext = (i == cnt ? path[0] : path[i]);
if (ipNext.Y == pt.Y) {
if ((ipNext.X == pt.X) ||
(ip.Y == pt.Y && ((ipNext.X > pt.X) == (ip.X < pt.X))))
return -1;
}
if ((ip.Y < pt.Y) != (ipNext.Y < pt.Y)) {
if (ip.X >= pt.X) {
if (ipNext.X > pt.X)
result = 1 - result;
else {
double d = (double)(ip.X - pt.X) * (ipNext.Y - pt.Y) -
(double)(ipNext.X - pt.X) * (ip.Y - pt.Y);
if (!d)
return -1;
if ((d > 0) == (ipNext.Y > ip.Y))
result = 1 - result;
}
} else {
if (ipNext.X > pt.X) {
double d = (double)(ip.X - pt.X) * (ipNext.Y - pt.Y) -
(double)(ipNext.X - pt.X) * (ip.Y - pt.Y);
if (!d)
return -1;
if ((d > 0) == (ipNext.Y > ip.Y))
result = 1 - result;
}
}
}
ip = ipNext;
}
return result;
}
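//------------------------------------------------------------------------------
// Usage sketch (illustrative only): classify points against a 10 x 10 square
// using the return convention documented above.
static bool PointInPolygonDemo() {
  Path sq;
  sq.push_back(IntPoint(0, 0));
  sq.push_back(IntPoint(10, 0));
  sq.push_back(IntPoint(10, 10));
  sq.push_back(IntPoint(0, 10));
  return PointInPolygon(IntPoint(5, 5), sq) == 1 &&   // strictly inside
         PointInPolygon(IntPoint(10, 5), sq) == -1 && // on the boundary
         PointInPolygon(IntPoint(20, 5), sq) == 0;    // outside
}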
//------------------------------------------------------------------------------
int PointInPolygon(const IntPoint &pt, OutPt *op) {
// returns 0 if false, +1 if true, -1 if pt ON polygon boundary
int result = 0;
OutPt *startOp = op;
for (;;) {
if (op->Next->Pt.Y == pt.Y) {
if ((op->Next->Pt.X == pt.X) ||
(op->Pt.Y == pt.Y && ((op->Next->Pt.X > pt.X) == (op->Pt.X < pt.X))))
return -1;
}
if ((op->Pt.Y < pt.Y) != (op->Next->Pt.Y < pt.Y)) {
if (op->Pt.X >= pt.X) {
if (op->Next->Pt.X > pt.X)
result = 1 - result;
else {
double d = (double)(op->Pt.X - pt.X) * (op->Next->Pt.Y - pt.Y) -
(double)(op->Next->Pt.X - pt.X) * (op->Pt.Y - pt.Y);
if (!d)
return -1;
if ((d > 0) == (op->Next->Pt.Y > op->Pt.Y))
result = 1 - result;
}
} else {
if (op->Next->Pt.X > pt.X) {
double d = (double)(op->Pt.X - pt.X) * (op->Next->Pt.Y - pt.Y) -
(double)(op->Next->Pt.X - pt.X) * (op->Pt.Y - pt.Y);
if (!d)
return -1;
if ((d > 0) == (op->Next->Pt.Y > op->Pt.Y))
result = 1 - result;
}
}
}
op = op->Next;
if (startOp == op)
break;
}
return result;
}
//------------------------------------------------------------------------------
bool Poly2ContainsPoly1(OutPt *OutPt1, OutPt *OutPt2) {
OutPt *op = OutPt1;
do {
// nb: PointInPolygon returns 0 if false, +1 if true, -1 if pt on polygon
int res = PointInPolygon(op->Pt, OutPt2);
if (res >= 0)
return res > 0;
op = op->Next;
} while (op != OutPt1);
return true;
}
//----------------------------------------------------------------------
bool SlopesEqual(const TEdge &e1, const TEdge &e2, bool UseFullInt64Range) {
#ifndef use_int32
if (UseFullInt64Range)
return Int128Mul(e1.Top.Y - e1.Bot.Y, e2.Top.X - e2.Bot.X) ==
Int128Mul(e1.Top.X - e1.Bot.X, e2.Top.Y - e2.Bot.Y);
else
#endif
return (e1.Top.Y - e1.Bot.Y) * (e2.Top.X - e2.Bot.X) ==
(e1.Top.X - e1.Bot.X) * (e2.Top.Y - e2.Bot.Y);
}
//------------------------------------------------------------------------------
bool SlopesEqual(const IntPoint pt1, const IntPoint pt2, const IntPoint pt3,
bool UseFullInt64Range) {
#ifndef use_int32
if (UseFullInt64Range)
return Int128Mul(pt1.Y - pt2.Y, pt2.X - pt3.X) ==
Int128Mul(pt1.X - pt2.X, pt2.Y - pt3.Y);
else
#endif
return (pt1.Y - pt2.Y) * (pt2.X - pt3.X) ==
(pt1.X - pt2.X) * (pt2.Y - pt3.Y);
}
//------------------------------------------------------------------------------
bool SlopesEqual(const IntPoint pt1, const IntPoint pt2, const IntPoint pt3,
const IntPoint pt4, bool UseFullInt64Range) {
#ifndef use_int32
if (UseFullInt64Range)
return Int128Mul(pt1.Y - pt2.Y, pt3.X - pt4.X) ==
Int128Mul(pt1.X - pt2.X, pt3.Y - pt4.Y);
else
#endif
return (pt1.Y - pt2.Y) * (pt3.X - pt4.X) ==
(pt1.X - pt2.X) * (pt3.Y - pt4.Y);
}
//------------------------------------------------------------------------------
inline bool IsHorizontal(TEdge &e) { return e.Dx == HORIZONTAL; }
//------------------------------------------------------------------------------
inline double GetDx(const IntPoint pt1, const IntPoint pt2) {
return (pt1.Y == pt2.Y) ? HORIZONTAL
: (double)(pt2.X - pt1.X) / (pt2.Y - pt1.Y);
}
//---------------------------------------------------------------------------
inline void SetDx(TEdge &e) {
cInt dy = (e.Top.Y - e.Bot.Y);
if (dy == 0)
e.Dx = HORIZONTAL;
else
e.Dx = (double)(e.Top.X - e.Bot.X) / dy;
}
//---------------------------------------------------------------------------
inline void SwapSides(TEdge &Edge1, TEdge &Edge2) {
EdgeSide Side = Edge1.Side;
Edge1.Side = Edge2.Side;
Edge2.Side = Side;
}
//------------------------------------------------------------------------------
inline void SwapPolyIndexes(TEdge &Edge1, TEdge &Edge2) {
int OutIdx = Edge1.OutIdx;
Edge1.OutIdx = Edge2.OutIdx;
Edge2.OutIdx = OutIdx;
}
//------------------------------------------------------------------------------
inline cInt TopX(TEdge &edge, const cInt currentY) {
return (currentY == edge.Top.Y)
? edge.Top.X
: edge.Bot.X + Round(edge.Dx * (currentY - edge.Bot.Y));
}
//------------------------------------------------------------------------------
void IntersectPoint(TEdge &Edge1, TEdge &Edge2, IntPoint &ip) {
#ifdef use_xyz
ip.Z = 0;
#endif
double b1, b2;
if (Edge1.Dx == Edge2.Dx) {
ip.Y = Edge1.Curr.Y;
ip.X = TopX(Edge1, ip.Y);
return;
} else if (Edge1.Dx == 0) {
ip.X = Edge1.Bot.X;
if (IsHorizontal(Edge2))
ip.Y = Edge2.Bot.Y;
else {
b2 = Edge2.Bot.Y - (Edge2.Bot.X / Edge2.Dx);
ip.Y = Round(ip.X / Edge2.Dx + b2);
}
} else if (Edge2.Dx == 0) {
ip.X = Edge2.Bot.X;
if (IsHorizontal(Edge1))
ip.Y = Edge1.Bot.Y;
else {
b1 = Edge1.Bot.Y - (Edge1.Bot.X / Edge1.Dx);
ip.Y = Round(ip.X / Edge1.Dx + b1);
}
} else {
b1 = Edge1.Bot.X - Edge1.Bot.Y * Edge1.Dx;
b2 = Edge2.Bot.X - Edge2.Bot.Y * Edge2.Dx;
double q = (b2 - b1) / (Edge1.Dx - Edge2.Dx);
ip.Y = Round(q);
if (std::fabs(Edge1.Dx) < std::fabs(Edge2.Dx))
ip.X = Round(Edge1.Dx * q + b1);
else
ip.X = Round(Edge2.Dx * q + b2);
}
if (ip.Y < Edge1.Top.Y || ip.Y < Edge2.Top.Y) {
if (Edge1.Top.Y > Edge2.Top.Y)
ip.Y = Edge1.Top.Y;
else
ip.Y = Edge2.Top.Y;
if (std::fabs(Edge1.Dx) < std::fabs(Edge2.Dx))
ip.X = TopX(Edge1, ip.Y);
else
ip.X = TopX(Edge2, ip.Y);
}
// finally, don't allow 'ip' to be BELOW curr.Y (ie bottom of scanbeam) ...
if (ip.Y > Edge1.Curr.Y) {
ip.Y = Edge1.Curr.Y;
// use the more vertical edge to derive X ...
if (std::fabs(Edge1.Dx) > std::fabs(Edge2.Dx))
ip.X = TopX(Edge2, ip.Y);
else
ip.X = TopX(Edge1, ip.Y);
}
}
//------------------------------------------------------------------------------
void ReversePolyPtLinks(OutPt *pp) {
if (!pp)
return;
OutPt *pp1, *pp2;
pp1 = pp;
do {
pp2 = pp1->Next;
pp1->Next = pp1->Prev;
pp1->Prev = pp2;
pp1 = pp2;
} while (pp1 != pp);
}
//------------------------------------------------------------------------------
void DisposeOutPts(OutPt *&pp) {
if (pp == 0)
return;
pp->Prev->Next = 0;
while (pp) {
OutPt *tmpPp = pp;
pp = pp->Next;
delete tmpPp;
}
}
//------------------------------------------------------------------------------
inline void InitEdge(TEdge *e, TEdge *eNext, TEdge *ePrev, const IntPoint &Pt) {
std::memset(e, 0, sizeof(TEdge));
e->Next = eNext;
e->Prev = ePrev;
e->Curr = Pt;
e->OutIdx = Unassigned;
}
//------------------------------------------------------------------------------
void InitEdge2(TEdge &e, PolyType Pt) {
if (e.Curr.Y >= e.Next->Curr.Y) {
e.Bot = e.Curr;
e.Top = e.Next->Curr;
} else {
e.Top = e.Curr;
e.Bot = e.Next->Curr;
}
SetDx(e);
e.PolyTyp = Pt;
}
//------------------------------------------------------------------------------
TEdge *RemoveEdge(TEdge *e) {
// removes e from double_linked_list (but without removing from memory)
e->Prev->Next = e->Next;
e->Next->Prev = e->Prev;
TEdge *result = e->Next;
e->Prev = 0; // flag as removed (see ClipperBase.Clear)
return result;
}
//------------------------------------------------------------------------------
inline void ReverseHorizontal(TEdge &e) {
// swap horizontal edges' Top and Bottom x's so they follow the natural
// progression of the bounds - ie so their xbots will align with the
// adjoining lower edge. [Helpful in the ProcessHorizontal() method.]
std::swap(e.Top.X, e.Bot.X);
#ifdef use_xyz
std::swap(e.Top.Z, e.Bot.Z);
#endif
}
//------------------------------------------------------------------------------
void SwapPoints(IntPoint &pt1, IntPoint &pt2) {
IntPoint tmp = pt1;
pt1 = pt2;
pt2 = tmp;
}
//------------------------------------------------------------------------------
bool GetOverlapSegment(IntPoint pt1a, IntPoint pt1b, IntPoint pt2a,
IntPoint pt2b, IntPoint &pt1, IntPoint &pt2) {
// precondition: segments are Collinear.
if (Abs(pt1a.X - pt1b.X) > Abs(pt1a.Y - pt1b.Y)) {
if (pt1a.X > pt1b.X)
SwapPoints(pt1a, pt1b);
if (pt2a.X > pt2b.X)
SwapPoints(pt2a, pt2b);
if (pt1a.X > pt2a.X)
pt1 = pt1a;
else
pt1 = pt2a;
if (pt1b.X < pt2b.X)
pt2 = pt1b;
else
pt2 = pt2b;
return pt1.X < pt2.X;
} else {
if (pt1a.Y < pt1b.Y)
SwapPoints(pt1a, pt1b);
if (pt2a.Y < pt2b.Y)
SwapPoints(pt2a, pt2b);
if (pt1a.Y < pt2a.Y)
pt1 = pt1a;
else
pt1 = pt2a;
if (pt1b.Y > pt2b.Y)
pt2 = pt1b;
else
pt2 = pt2b;
return pt1.Y > pt2.Y;
}
}
//------------------------------------------------------------------------------
bool FirstIsBottomPt(const OutPt *btmPt1, const OutPt *btmPt2) {
OutPt *p = btmPt1->Prev;
while ((p->Pt == btmPt1->Pt) && (p != btmPt1))
p = p->Prev;
double dx1p = std::fabs(GetDx(btmPt1->Pt, p->Pt));
p = btmPt1->Next;
while ((p->Pt == btmPt1->Pt) && (p != btmPt1))
p = p->Next;
double dx1n = std::fabs(GetDx(btmPt1->Pt, p->Pt));
p = btmPt2->Prev;
while ((p->Pt == btmPt2->Pt) && (p != btmPt2))
p = p->Prev;
double dx2p = std::fabs(GetDx(btmPt2->Pt, p->Pt));
p = btmPt2->Next;
while ((p->Pt == btmPt2->Pt) && (p != btmPt2))
p = p->Next;
double dx2n = std::fabs(GetDx(btmPt2->Pt, p->Pt));
if (std::max(dx1p, dx1n) == std::max(dx2p, dx2n) &&
std::min(dx1p, dx1n) == std::min(dx2p, dx2n))
return Area(btmPt1) > 0; // if otherwise identical use orientation
else
return (dx1p >= dx2p && dx1p >= dx2n) || (dx1n >= dx2p && dx1n >= dx2n);
}
//------------------------------------------------------------------------------
OutPt *GetBottomPt(OutPt *pp) {
OutPt *dups = 0;
OutPt *p = pp->Next;
while (p != pp) {
if (p->Pt.Y > pp->Pt.Y) {
pp = p;
dups = 0;
} else if (p->Pt.Y == pp->Pt.Y && p->Pt.X <= pp->Pt.X) {
if (p->Pt.X < pp->Pt.X) {
dups = 0;
pp = p;
} else {
if (p->Next != pp && p->Prev != pp)
dups = p;
}
}
p = p->Next;
}
if (dups) {
// there appears to be at least 2 vertices at BottomPt so ...
while (dups != p) {
if (!FirstIsBottomPt(p, dups))
pp = dups;
dups = dups->Next;
while (dups->Pt != pp->Pt)
dups = dups->Next;
}
}
return pp;
}
//------------------------------------------------------------------------------
bool Pt2IsBetweenPt1AndPt3(const IntPoint pt1, const IntPoint pt2,
const IntPoint pt3) {
if ((pt1 == pt3) || (pt1 == pt2) || (pt3 == pt2))
return false;
else if (pt1.X != pt3.X)
return (pt2.X > pt1.X) == (pt2.X < pt3.X);
else
return (pt2.Y > pt1.Y) == (pt2.Y < pt3.Y);
}
//------------------------------------------------------------------------------
bool HorzSegmentsOverlap(cInt seg1a, cInt seg1b, cInt seg2a, cInt seg2b) {
if (seg1a > seg1b)
std::swap(seg1a, seg1b);
if (seg2a > seg2b)
std::swap(seg2a, seg2b);
return (seg1a < seg2b) && (seg2a < seg1b);
}
//------------------------------------------------------------------------------
// ClipperBase class methods ...
//------------------------------------------------------------------------------
ClipperBase::ClipperBase() // constructor
{
m_CurrentLM = m_MinimaList.begin(); // begin() == end() here
m_UseFullRange = false;
}
//------------------------------------------------------------------------------
ClipperBase::~ClipperBase() // destructor
{
Clear();
}
//------------------------------------------------------------------------------
void RangeTest(const IntPoint &Pt, bool &useFullRange) {
if (useFullRange) {
if (Pt.X > hiRange || Pt.Y > hiRange || -Pt.X > hiRange || -Pt.Y > hiRange)
throw clipperException("Coordinate outside allowed range");
} else if (Pt.X > loRange || Pt.Y > loRange || -Pt.X > loRange ||
-Pt.Y > loRange) {
useFullRange = true;
RangeTest(Pt, useFullRange);
}
}
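//------------------------------------------------------------------------------
#ifndef use_int32
// Sketch of the escalation above (illustrative only, for the default 64-bit
// cInt build): a coordinate beyond loRange flips useFullRange, so later
// slope tests switch to Int128 math; anything beyond hiRange throws instead.
static bool RangeEscalationDemo() {
  bool useFull = false;
  RangeTest(IntPoint(loRange + 1, 0), useFull);
  return useFull; // true: subsequent points are checked against hiRange
}
#endif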
//------------------------------------------------------------------------------
TEdge *FindNextLocMin(TEdge *E) {
for (;;) {
while (E->Bot != E->Prev->Bot || E->Curr == E->Top)
E = E->Next;
if (!IsHorizontal(*E) && !IsHorizontal(*E->Prev))
break;
while (IsHorizontal(*E->Prev))
E = E->Prev;
TEdge *E2 = E;
while (IsHorizontal(*E))
E = E->Next;
if (E->Top.Y == E->Prev->Bot.Y)
continue; // ie just an intermediate horz.
if (E2->Prev->Bot.X < E->Bot.X)
E = E2;
break;
}
return E;
}
//------------------------------------------------------------------------------
TEdge *ClipperBase::ProcessBound(TEdge *E, bool NextIsForward) {
TEdge *Result = E;
TEdge *Horz = 0;
if (E->OutIdx == Skip) {
// if edges still remain in the current bound beyond the skip edge then
// create another LocMin and call ProcessBound once more
if (NextIsForward) {
while (E->Top.Y == E->Next->Bot.Y)
E = E->Next;
// don't include top horizontals when parsing a bound a second time,
// they will be contained in the opposite bound ...
while (E != Result && IsHorizontal(*E))
E = E->Prev;
} else {
while (E->Top.Y == E->Prev->Bot.Y)
E = E->Prev;
while (E != Result && IsHorizontal(*E))
E = E->Next;
}
if (E == Result) {
if (NextIsForward)
Result = E->Next;
else
Result = E->Prev;
} else {
// there are more edges in the bound beyond result starting with E
if (NextIsForward)
E = Result->Next;
else
E = Result->Prev;
MinimaList::value_type locMin;
locMin.Y = E->Bot.Y;
locMin.LeftBound = 0;
locMin.RightBound = E;
E->WindDelta = 0;
Result = ProcessBound(E, NextIsForward);
m_MinimaList.push_back(locMin);
}
return Result;
}
TEdge *EStart;
if (IsHorizontal(*E)) {
// We need to be careful with open paths because this may not be a
// true local minima (ie E may be following a skip edge).
// Also, consecutive horz. edges may start heading left before going right.
if (NextIsForward)
EStart = E->Prev;
else
EStart = E->Next;
if (IsHorizontal(*EStart)) // ie an adjoining horizontal skip edge
{
if (EStart->Bot.X != E->Bot.X && EStart->Top.X != E->Bot.X)
ReverseHorizontal(*E);
} else if (EStart->Bot.X != E->Bot.X)
ReverseHorizontal(*E);
}
EStart = E;
if (NextIsForward) {
while (Result->Top.Y == Result->Next->Bot.Y && Result->Next->OutIdx != Skip)
Result = Result->Next;
if (IsHorizontal(*Result) && Result->Next->OutIdx != Skip) {
// nb: at the top of a bound, horizontals are added to the bound
// only when the preceding edge attaches to the horizontal's left vertex
// unless a Skip edge is encountered when that becomes the top divide
Horz = Result;
while (IsHorizontal(*Horz->Prev))
Horz = Horz->Prev;
if (Horz->Prev->Top.X > Result->Next->Top.X)
Result = Horz->Prev;
}
while (E != Result) {
E->NextInLML = E->Next;
if (IsHorizontal(*E) && E != EStart && E->Bot.X != E->Prev->Top.X)
ReverseHorizontal(*E);
E = E->Next;
}
if (IsHorizontal(*E) && E != EStart && E->Bot.X != E->Prev->Top.X)
ReverseHorizontal(*E);
Result = Result->Next; // move to the edge just beyond current bound
} else {
while (Result->Top.Y == Result->Prev->Bot.Y && Result->Prev->OutIdx != Skip)
Result = Result->Prev;
if (IsHorizontal(*Result) && Result->Prev->OutIdx != Skip) {
Horz = Result;
while (IsHorizontal(*Horz->Next))
Horz = Horz->Next;
if (Horz->Next->Top.X == Result->Prev->Top.X ||
Horz->Next->Top.X > Result->Prev->Top.X)
Result = Horz->Next;
}
while (E != Result) {
E->NextInLML = E->Prev;
if (IsHorizontal(*E) && E != EStart && E->Bot.X != E->Next->Top.X)
ReverseHorizontal(*E);
E = E->Prev;
}
if (IsHorizontal(*E) && E != EStart && E->Bot.X != E->Next->Top.X)
ReverseHorizontal(*E);
Result = Result->Prev; // move to the edge just beyond current bound
}
return Result;
}
//------------------------------------------------------------------------------
bool ClipperBase::AddPath(const Path &pg, PolyType PolyTyp, bool Closed) {
#ifdef use_lines
if (!Closed && PolyTyp == ptClip)
throw clipperException("AddPath: Open paths must be subject.");
#else
if (!Closed)
throw clipperException("AddPath: Open paths have been disabled.");
#endif
int highI = (int)pg.size() - 1;
if (Closed)
while (highI > 0 && (pg[highI] == pg[0]))
--highI;
while (highI > 0 && (pg[highI] == pg[highI - 1]))
--highI;
if ((Closed && highI < 2) || (!Closed && highI < 1))
return false;
// create a new edge array ...
TEdge *edges = new TEdge[highI + 1];
bool IsFlat = true;
// 1. Basic (first) edge initialization ...
try {
edges[1].Curr = pg[1];
RangeTest(pg[0], m_UseFullRange);
RangeTest(pg[highI], m_UseFullRange);
InitEdge(&edges[0], &edges[1], &edges[highI], pg[0]);
InitEdge(&edges[highI], &edges[0], &edges[highI - 1], pg[highI]);
for (int i = highI - 1; i >= 1; --i) {
RangeTest(pg[i], m_UseFullRange);
InitEdge(&edges[i], &edges[i + 1], &edges[i - 1], pg[i]);
}
} catch (...) {
delete[] edges;
throw; // range test fails
}
TEdge *eStart = &edges[0];
// 2. Remove duplicate vertices, and (when closed) collinear edges ...
TEdge *E = eStart, *eLoopStop = eStart;
for (;;) {
// nb: allows matching start and end points when not Closed ...
if (E->Curr == E->Next->Curr && (Closed || E->Next != eStart)) {
if (E == E->Next)
break;
if (E == eStart)
eStart = E->Next;
E = RemoveEdge(E);
eLoopStop = E;
continue;
}
if (E->Prev == E->Next)
break; // only two vertices
else if (Closed && SlopesEqual(E->Prev->Curr, E->Curr, E->Next->Curr,
m_UseFullRange) &&
(!m_PreserveCollinear ||
!Pt2IsBetweenPt1AndPt3(E->Prev->Curr, E->Curr, E->Next->Curr))) {
// Collinear edges are allowed for open paths but in closed paths
// the default is to merge adjacent collinear edges into a single edge.
// However, if the PreserveCollinear property is enabled, only overlapping
// collinear edges (ie spikes) will be removed from closed paths.
if (E == eStart)
eStart = E->Next;
E = RemoveEdge(E);
E = E->Prev;
eLoopStop = E;
continue;
}
E = E->Next;
if ((E == eLoopStop) || (!Closed && E->Next == eStart))
break;
}
if ((!Closed && (E == E->Next)) || (Closed && (E->Prev == E->Next))) {
delete[] edges;
return false;
}
if (!Closed) {
m_HasOpenPaths = true;
eStart->Prev->OutIdx = Skip;
}
// 3. Do second stage of edge initialization ...
E = eStart;
do {
InitEdge2(*E, PolyTyp);
E = E->Next;
if (IsFlat && E->Curr.Y != eStart->Curr.Y)
IsFlat = false;
} while (E != eStart);
// 4. Finally, add edge bounds to LocalMinima list ...
// Totally flat paths must be handled differently when adding them
// to LocalMinima list to avoid endless loops etc ...
if (IsFlat) {
if (Closed) {
delete[] edges;
return false;
}
E->Prev->OutIdx = Skip;
MinimaList::value_type locMin;
locMin.Y = E->Bot.Y;
locMin.LeftBound = 0;
locMin.RightBound = E;
locMin.RightBound->Side = esRight;
locMin.RightBound->WindDelta = 0;
for (;;) {
if (E->Bot.X != E->Prev->Top.X)
ReverseHorizontal(*E);
if (E->Next->OutIdx == Skip)
break;
E->NextInLML = E->Next;
E = E->Next;
}
m_MinimaList.push_back(locMin);
m_edges.push_back(edges);
return true;
}
m_edges.push_back(edges);
bool leftBoundIsForward;
TEdge *EMin = 0;
// workaround to avoid an endless loop in the while loop below when
// open paths have matching start and end points ...
if (E->Prev->Bot == E->Prev->Top)
E = E->Next;
for (;;) {
E = FindNextLocMin(E);
if (E == EMin)
break;
else if (!EMin)
EMin = E;
// E and E.Prev now share a local minima (left aligned if horizontal).
// Compare their slopes to find which starts which bound ...
MinimaList::value_type locMin;
locMin.Y = E->Bot.Y;
if (E->Dx < E->Prev->Dx) {
locMin.LeftBound = E->Prev;
locMin.RightBound = E;
leftBoundIsForward = false; // Q.nextInLML = Q.prev
} else {
locMin.LeftBound = E;
locMin.RightBound = E->Prev;
leftBoundIsForward = true; // Q.nextInLML = Q.next
}
if (!Closed)
locMin.LeftBound->WindDelta = 0;
else if (locMin.LeftBound->Next == locMin.RightBound)
locMin.LeftBound->WindDelta = -1;
else
locMin.LeftBound->WindDelta = 1;
locMin.RightBound->WindDelta = -locMin.LeftBound->WindDelta;
E = ProcessBound(locMin.LeftBound, leftBoundIsForward);
if (E->OutIdx == Skip)
E = ProcessBound(E, leftBoundIsForward);
TEdge *E2 = ProcessBound(locMin.RightBound, !leftBoundIsForward);
if (E2->OutIdx == Skip)
E2 = ProcessBound(E2, !leftBoundIsForward);
if (locMin.LeftBound->OutIdx == Skip)
locMin.LeftBound = 0;
else if (locMin.RightBound->OutIdx == Skip)
locMin.RightBound = 0;
m_MinimaList.push_back(locMin);
if (!leftBoundIsForward)
E = E2;
}
return true;
}
//------------------------------------------------------------------------------
bool ClipperBase::AddPaths(const Paths &ppg, PolyType PolyTyp, bool Closed) {
bool result = false;
for (Paths::size_type i = 0; i < ppg.size(); ++i)
if (AddPath(ppg[i], PolyTyp, Closed))
result = true;
return result;
}
//------------------------------------------------------------------------------
void ClipperBase::Clear() {
DisposeLocalMinimaList();
for (EdgeList::size_type i = 0; i < m_edges.size(); ++i) {
TEdge *edges = m_edges[i];
delete[] edges;
}
m_edges.clear();
m_UseFullRange = false;
m_HasOpenPaths = false;
}
//------------------------------------------------------------------------------
void ClipperBase::Reset() {
m_CurrentLM = m_MinimaList.begin();
if (m_CurrentLM == m_MinimaList.end())
return; // ie nothing to process
std::sort(m_MinimaList.begin(), m_MinimaList.end(), LocMinSorter());
m_Scanbeam = ScanbeamList(); // clears/resets priority_queue
// reset all edges ...
for (MinimaList::iterator lm = m_MinimaList.begin(); lm != m_MinimaList.end();
++lm) {
InsertScanbeam(lm->Y);
TEdge *e = lm->LeftBound;
if (e) {
e->Curr = e->Bot;
e->Side = esLeft;
e->OutIdx = Unassigned;
}
e = lm->RightBound;
if (e) {
e->Curr = e->Bot;
e->Side = esRight;
e->OutIdx = Unassigned;
}
}
m_ActiveEdges = 0;
m_CurrentLM = m_MinimaList.begin();
}
//------------------------------------------------------------------------------
void ClipperBase::DisposeLocalMinimaList() {
m_MinimaList.clear();
m_CurrentLM = m_MinimaList.begin();
}
//------------------------------------------------------------------------------
bool ClipperBase::PopLocalMinima(cInt Y, const LocalMinimum *&locMin) {
if (m_CurrentLM == m_MinimaList.end() || (*m_CurrentLM).Y != Y)
return false;
locMin = &(*m_CurrentLM);
++m_CurrentLM;
return true;
}
//------------------------------------------------------------------------------
IntRect ClipperBase::GetBounds() {
IntRect result;
MinimaList::iterator lm = m_MinimaList.begin();
if (lm == m_MinimaList.end()) {
result.left = result.top = result.right = result.bottom = 0;
return result;
}
result.left = lm->LeftBound->Bot.X;
result.top = lm->LeftBound->Bot.Y;
result.right = lm->LeftBound->Bot.X;
result.bottom = lm->LeftBound->Bot.Y;
while (lm != m_MinimaList.end()) {
// todo - needs fixing for open paths
result.bottom = std::max(result.bottom, lm->LeftBound->Bot.Y);
TEdge *e = lm->LeftBound;
for (;;) {
TEdge *bottomE = e;
while (e->NextInLML) {
if (e->Bot.X < result.left)
result.left = e->Bot.X;
if (e->Bot.X > result.right)
result.right = e->Bot.X;
e = e->NextInLML;
}
result.left = std::min(result.left, e->Bot.X);
result.right = std::max(result.right, e->Bot.X);
result.left = std::min(result.left, e->Top.X);
result.right = std::max(result.right, e->Top.X);
result.top = std::min(result.top, e->Top.Y);
if (bottomE == lm->LeftBound)
e = lm->RightBound;
else
break;
}
++lm;
}
return result;
}
//------------------------------------------------------------------------------
void ClipperBase::InsertScanbeam(const cInt Y) { m_Scanbeam.push(Y); }
//------------------------------------------------------------------------------
bool ClipperBase::PopScanbeam(cInt &Y) {
if (m_Scanbeam.empty())
return false;
Y = m_Scanbeam.top();
m_Scanbeam.pop();
while (!m_Scanbeam.empty() && Y == m_Scanbeam.top()) {
m_Scanbeam.pop();
} // Pop duplicates.
return true;
}
//------------------------------------------------------------------------------
void ClipperBase::DisposeAllOutRecs() {
for (PolyOutList::size_type i = 0; i < m_PolyOuts.size(); ++i)
DisposeOutRec(i);
m_PolyOuts.clear();
}
//------------------------------------------------------------------------------
void ClipperBase::DisposeOutRec(PolyOutList::size_type index) {
OutRec *outRec = m_PolyOuts[index];
if (outRec->Pts)
DisposeOutPts(outRec->Pts);
delete outRec;
m_PolyOuts[index] = 0;
}
//------------------------------------------------------------------------------
void ClipperBase::DeleteFromAEL(TEdge *e) {
TEdge *AelPrev = e->PrevInAEL;
TEdge *AelNext = e->NextInAEL;
if (!AelPrev && !AelNext && (e != m_ActiveEdges))
return; // already deleted
if (AelPrev)
AelPrev->NextInAEL = AelNext;
else
m_ActiveEdges = AelNext;
if (AelNext)
AelNext->PrevInAEL = AelPrev;
e->NextInAEL = 0;
e->PrevInAEL = 0;
}
//------------------------------------------------------------------------------
OutRec *ClipperBase::CreateOutRec() {
OutRec *result = new OutRec;
result->IsHole = false;
result->IsOpen = false;
result->FirstLeft = 0;
result->Pts = 0;
result->BottomPt = 0;
result->PolyNd = 0;
m_PolyOuts.push_back(result);
result->Idx = (int)m_PolyOuts.size() - 1;
return result;
}
//------------------------------------------------------------------------------
void ClipperBase::SwapPositionsInAEL(TEdge *Edge1, TEdge *Edge2) {
// check that one or other edge hasn't already been removed from AEL ...
if (Edge1->NextInAEL == Edge1->PrevInAEL ||
Edge2->NextInAEL == Edge2->PrevInAEL)
return;
if (Edge1->NextInAEL == Edge2) {
TEdge *Next = Edge2->NextInAEL;
if (Next)
Next->PrevInAEL = Edge1;
TEdge *Prev = Edge1->PrevInAEL;
if (Prev)
Prev->NextInAEL = Edge2;
Edge2->PrevInAEL = Prev;
Edge2->NextInAEL = Edge1;
Edge1->PrevInAEL = Edge2;
Edge1->NextInAEL = Next;
} else if (Edge2->NextInAEL == Edge1) {
TEdge *Next = Edge1->NextInAEL;
if (Next)
Next->PrevInAEL = Edge2;
TEdge *Prev = Edge2->PrevInAEL;
if (Prev)
Prev->NextInAEL = Edge1;
Edge1->PrevInAEL = Prev;
Edge1->NextInAEL = Edge2;
Edge2->PrevInAEL = Edge1;
Edge2->NextInAEL = Next;
} else {
TEdge *Next = Edge1->NextInAEL;
TEdge *Prev = Edge1->PrevInAEL;
Edge1->NextInAEL = Edge2->NextInAEL;
if (Edge1->NextInAEL)
Edge1->NextInAEL->PrevInAEL = Edge1;
Edge1->PrevInAEL = Edge2->PrevInAEL;
if (Edge1->PrevInAEL)
Edge1->PrevInAEL->NextInAEL = Edge1;
Edge2->NextInAEL = Next;
if (Edge2->NextInAEL)
Edge2->NextInAEL->PrevInAEL = Edge2;
Edge2->PrevInAEL = Prev;
if (Edge2->PrevInAEL)
Edge2->PrevInAEL->NextInAEL = Edge2;
}
if (!Edge1->PrevInAEL)
m_ActiveEdges = Edge1;
else if (!Edge2->PrevInAEL)
m_ActiveEdges = Edge2;
}
//------------------------------------------------------------------------------
void ClipperBase::UpdateEdgeIntoAEL(TEdge *&e) {
if (!e->NextInLML)
throw clipperException("UpdateEdgeIntoAEL: invalid call");
e->NextInLML->OutIdx = e->OutIdx;
TEdge *AelPrev = e->PrevInAEL;
TEdge *AelNext = e->NextInAEL;
if (AelPrev)
AelPrev->NextInAEL = e->NextInLML;
else
m_ActiveEdges = e->NextInLML;
if (AelNext)
AelNext->PrevInAEL = e->NextInLML;
e->NextInLML->Side = e->Side;
e->NextInLML->WindDelta = e->WindDelta;
e->NextInLML->WindCnt = e->WindCnt;
e->NextInLML->WindCnt2 = e->WindCnt2;
e = e->NextInLML;
e->Curr = e->Bot;
e->PrevInAEL = AelPrev;
e->NextInAEL = AelNext;
if (!IsHorizontal(*e))
InsertScanbeam(e->Top.Y);
}
//------------------------------------------------------------------------------
bool ClipperBase::LocalMinimaPending() {
return (m_CurrentLM != m_MinimaList.end());
}
//------------------------------------------------------------------------------
// TClipper methods ...
//------------------------------------------------------------------------------
Clipper::Clipper(int initOptions)
: ClipperBase() // constructor
{
m_ExecuteLocked = false;
m_UseFullRange = false;
m_ReverseOutput = ((initOptions & ioReverseSolution) != 0);
m_StrictSimple = ((initOptions & ioStrictlySimple) != 0);
m_PreserveCollinear = ((initOptions & ioPreserveCollinear) != 0);
m_HasOpenPaths = false;
#ifdef use_xyz
m_ZFill = 0;
#endif
}
//------------------------------------------------------------------------------
#ifdef use_xyz
void Clipper::ZFillFunction(ZFillCallback zFillFunc) { m_ZFill = zFillFunc; }
//------------------------------------------------------------------------------
#endif
bool Clipper::Execute(ClipType clipType, Paths &solution,
PolyFillType fillType) {
return Execute(clipType, solution, fillType, fillType);
}
//------------------------------------------------------------------------------
bool Clipper::Execute(ClipType clipType, PolyTree &polytree,
PolyFillType fillType) {
return Execute(clipType, polytree, fillType, fillType);
}
//------------------------------------------------------------------------------
bool Clipper::Execute(ClipType clipType, Paths &solution,
PolyFillType subjFillType, PolyFillType clipFillType) {
if (m_ExecuteLocked)
return false;
if (m_HasOpenPaths)
throw clipperException(
"Error: PolyTree struct is needed for open path clipping.");
m_ExecuteLocked = true;
solution.resize(0);
m_SubjFillType = subjFillType;
m_ClipFillType = clipFillType;
m_ClipType = clipType;
m_UsingPolyTree = false;
bool succeeded = ExecuteInternal();
if (succeeded)
BuildResult(solution);
DisposeAllOutRecs();
m_ExecuteLocked = false;
return succeeded;
}
//------------------------------------------------------------------------------
bool Clipper::Execute(ClipType clipType, PolyTree &polytree,
PolyFillType subjFillType, PolyFillType clipFillType) {
if (m_ExecuteLocked)
return false;
m_ExecuteLocked = true;
m_SubjFillType = subjFillType;
m_ClipFillType = clipFillType;
m_ClipType = clipType;
m_UsingPolyTree = true;
bool succeeded = ExecuteInternal();
if (succeeded)
BuildResult2(polytree);
DisposeAllOutRecs();
m_ExecuteLocked = false;
return succeeded;
}
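//------------------------------------------------------------------------------
// Usage sketch for the public Execute API above (illustrative only):
// intersecting two overlapping axis-aligned squares should leave a single
// 5 x 5 square spanning (5,5)-(10,10).
static bool ClipperIntersectionDemo() {
  Path subj, clip;
  subj.push_back(IntPoint(0, 0));
  subj.push_back(IntPoint(10, 0));
  subj.push_back(IntPoint(10, 10));
  subj.push_back(IntPoint(0, 10));
  clip.push_back(IntPoint(5, 5));
  clip.push_back(IntPoint(15, 5));
  clip.push_back(IntPoint(15, 15));
  clip.push_back(IntPoint(5, 15));
  Clipper c;
  c.AddPath(subj, ptSubject, true);
  c.AddPath(clip, ptClip, true);
  Paths solution;
  return c.Execute(ctIntersection, solution, pftNonZero) &&
         solution.size() == 1;
}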
//------------------------------------------------------------------------------
void Clipper::FixHoleLinkage(OutRec &outrec) {
// skip OutRecs that (a) contain outermost polygons or
// (b) already have the correct owner/child linkage ...
if (!outrec.FirstLeft ||
(outrec.IsHole != outrec.FirstLeft->IsHole && outrec.FirstLeft->Pts))
return;
OutRec *orfl = outrec.FirstLeft;
while (orfl && ((orfl->IsHole == outrec.IsHole) || !orfl->Pts))
orfl = orfl->FirstLeft;
outrec.FirstLeft = orfl;
}
//------------------------------------------------------------------------------
bool Clipper::ExecuteInternal() {
bool succeeded = true;
try {
Reset();
m_Maxima = MaximaList();
m_SortedEdges = 0;
succeeded = true;
cInt botY, topY;
if (!PopScanbeam(botY))
return false;
InsertLocalMinimaIntoAEL(botY);
while (PopScanbeam(topY) || LocalMinimaPending()) {
ProcessHorizontals();
ClearGhostJoins();
if (!ProcessIntersections(topY)) {
succeeded = false;
break;
}
ProcessEdgesAtTopOfScanbeam(topY);
botY = topY;
InsertLocalMinimaIntoAEL(botY);
}
} catch (...) {
succeeded = false;
}
if (succeeded) {
// fix orientations ...
for (PolyOutList::size_type i = 0; i < m_PolyOuts.size(); ++i) {
OutRec *outRec = m_PolyOuts[i];
if (!outRec->Pts || outRec->IsOpen)
continue;
if ((outRec->IsHole ^ m_ReverseOutput) == (Area(*outRec) > 0))
ReversePolyPtLinks(outRec->Pts);
}
if (!m_Joins.empty())
JoinCommonEdges();
// unfortunately FixupOutPolygon() must be done after JoinCommonEdges()
for (PolyOutList::size_type i = 0; i < m_PolyOuts.size(); ++i) {
OutRec *outRec = m_PolyOuts[i];
if (!outRec->Pts)
continue;
if (outRec->IsOpen)
FixupOutPolyline(*outRec);
else
FixupOutPolygon(*outRec);
}
if (m_StrictSimple)
DoSimplePolygons();
}
ClearJoins();
ClearGhostJoins();
return succeeded;
}
//------------------------------------------------------------------------------
void Clipper::SetWindingCount(TEdge &edge) {
TEdge *e = edge.PrevInAEL;
// find the edge of the same polytype that immediately precedes 'edge' in AEL
while (e && ((e->PolyTyp != edge.PolyTyp) || (e->WindDelta == 0)))
e = e->PrevInAEL;
if (!e) {
if (edge.WindDelta == 0) {
PolyFillType pft =
(edge.PolyTyp == ptSubject ? m_SubjFillType : m_ClipFillType);
edge.WindCnt = (pft == pftNegative ? -1 : 1);
} else
edge.WindCnt = edge.WindDelta;
edge.WindCnt2 = 0;
e = m_ActiveEdges; // ie get ready to calc WindCnt2
} else if (edge.WindDelta == 0 && m_ClipType != ctUnion) {
edge.WindCnt = 1;
edge.WindCnt2 = e->WindCnt2;
e = e->NextInAEL; // ie get ready to calc WindCnt2
} else if (IsEvenOddFillType(edge)) {
// EvenOdd filling ...
if (edge.WindDelta == 0) {
// are we inside a subj polygon ...
bool Inside = true;
TEdge *e2 = e->PrevInAEL;
while (e2) {
if (e2->PolyTyp == e->PolyTyp && e2->WindDelta != 0)
Inside = !Inside;
e2 = e2->PrevInAEL;
}
edge.WindCnt = (Inside ? 0 : 1);
} else {
edge.WindCnt = edge.WindDelta;
}
edge.WindCnt2 = e->WindCnt2;
e = e->NextInAEL; // ie get ready to calc WindCnt2
} else {
// nonZero, Positive or Negative filling ...
if (e->WindCnt * e->WindDelta < 0) {
// prev edge is 'decreasing' WindCount (WC) toward zero
// so we're outside the previous polygon ...
if (Abs(e->WindCnt) > 1) {
// outside prev poly but still inside another.
// when reversing direction of prev poly use the same WC
if (e->WindDelta * edge.WindDelta < 0)
edge.WindCnt = e->WindCnt;
// otherwise continue to 'decrease' WC ...
else
edge.WindCnt = e->WindCnt + edge.WindDelta;
} else
// now outside all polys of same polytype so set own WC ...
edge.WindCnt = (edge.WindDelta == 0 ? 1 : edge.WindDelta);
} else {
// prev edge is 'increasing' WindCount (WC) away from zero
// so we're inside the previous polygon ...
if (edge.WindDelta == 0)
edge.WindCnt = (e->WindCnt < 0 ? e->WindCnt - 1 : e->WindCnt + 1);
// if wind direction is reversing prev then use same WC
else if (e->WindDelta * edge.WindDelta < 0)
edge.WindCnt = e->WindCnt;
// otherwise add to WC ...
else
edge.WindCnt = e->WindCnt + edge.WindDelta;
}
edge.WindCnt2 = e->WindCnt2;
e = e->NextInAEL; // ie get ready to calc WindCnt2
}
// update WindCnt2 ...
if (IsEvenOddAltFillType(edge)) {
// EvenOdd filling ...
while (e != &edge) {
if (e->WindDelta != 0)
edge.WindCnt2 = (edge.WindCnt2 == 0 ? 1 : 0);
e = e->NextInAEL;
}
} else {
// nonZero, Positive or Negative filling ...
while (e != &edge) {
edge.WindCnt2 += e->WindDelta;
e = e->NextInAEL;
}
}
}
//------------------------------------------------------------------------------
bool Clipper::IsEvenOddFillType(const TEdge &edge) const {
if (edge.PolyTyp == ptSubject)
return m_SubjFillType == pftEvenOdd;
else
return m_ClipFillType == pftEvenOdd;
}
//------------------------------------------------------------------------------
bool Clipper::IsEvenOddAltFillType(const TEdge &edge) const {
if (edge.PolyTyp == ptSubject)
return m_ClipFillType == pftEvenOdd;
else
return m_SubjFillType == pftEvenOdd;
}
//------------------------------------------------------------------------------
bool Clipper::IsContributing(const TEdge &edge) const {
PolyFillType pft, pft2;
if (edge.PolyTyp == ptSubject) {
pft = m_SubjFillType;
pft2 = m_ClipFillType;
} else {
pft = m_ClipFillType;
pft2 = m_SubjFillType;
}
switch (pft) {
case pftEvenOdd:
// return false if a subj line has been flagged as inside a subj polygon
if (edge.WindDelta == 0 && edge.WindCnt != 1)
return false;
break;
case pftNonZero:
if (Abs(edge.WindCnt) != 1)
return false;
break;
case pftPositive:
if (edge.WindCnt != 1)
return false;
break;
default: // pftNegative
if (edge.WindCnt != -1)
return false;
}
switch (m_ClipType) {
case ctIntersection:
switch (pft2) {
case pftEvenOdd:
case pftNonZero:
return (edge.WindCnt2 != 0);
case pftPositive:
return (edge.WindCnt2 > 0);
default:
return (edge.WindCnt2 < 0);
}
break;
case ctUnion:
switch (pft2) {
case pftEvenOdd:
case pftNonZero:
return (edge.WindCnt2 == 0);
case pftPositive:
return (edge.WindCnt2 <= 0);
default:
return (edge.WindCnt2 >= 0);
}
break;
case ctDifference:
if (edge.PolyTyp == ptSubject)
switch (pft2) {
case pftEvenOdd:
case pftNonZero:
return (edge.WindCnt2 == 0);
case pftPositive:
return (edge.WindCnt2 <= 0);
default:
return (edge.WindCnt2 >= 0);
}
else
switch (pft2) {
case pftEvenOdd:
case pftNonZero:
return (edge.WindCnt2 != 0);
case pftPositive:
return (edge.WindCnt2 > 0);
default:
return (edge.WindCnt2 < 0);
}
break;
case ctXor:
if (edge.WindDelta == 0) // XOr always contributing unless open
switch (pft2) {
case pftEvenOdd:
case pftNonZero:
return (edge.WindCnt2 == 0);
case pftPositive:
return (edge.WindCnt2 <= 0);
default:
return (edge.WindCnt2 >= 0);
}
else
return true;
break;
default:
return true;
}
}
//------------------------------------------------------------------------------
OutPt *Clipper::AddLocalMinPoly(TEdge *e1, TEdge *e2, const IntPoint &Pt) {
OutPt *result;
TEdge *e, *prevE;
if (IsHorizontal(*e2) || (e1->Dx > e2->Dx)) {
result = AddOutPt(e1, Pt);
e2->OutIdx = e1->OutIdx;
e1->Side = esLeft;
e2->Side = esRight;
e = e1;
if (e->PrevInAEL == e2)
prevE = e2->PrevInAEL;
else
prevE = e->PrevInAEL;
} else {
result = AddOutPt(e2, Pt);
e1->OutIdx = e2->OutIdx;
e1->Side = esRight;
e2->Side = esLeft;
e = e2;
if (e->PrevInAEL == e1)
prevE = e1->PrevInAEL;
else
prevE = e->PrevInAEL;
}
if (prevE && prevE->OutIdx >= 0 && prevE->Top.Y < Pt.Y && e->Top.Y < Pt.Y) {
cInt xPrev = TopX(*prevE, Pt.Y);
cInt xE = TopX(*e, Pt.Y);
if (xPrev == xE && (e->WindDelta != 0) && (prevE->WindDelta != 0) &&
SlopesEqual(IntPoint(xPrev, Pt.Y), prevE->Top, IntPoint(xE, Pt.Y),
e->Top, m_UseFullRange)) {
OutPt *outPt = AddOutPt(prevE, Pt);
AddJoin(result, outPt, e->Top);
}
}
return result;
}
//------------------------------------------------------------------------------
void Clipper::AddLocalMaxPoly(TEdge *e1, TEdge *e2, const IntPoint &Pt) {
AddOutPt(e1, Pt);
if (e2->WindDelta == 0)
AddOutPt(e2, Pt);
if (e1->OutIdx == e2->OutIdx) {
e1->OutIdx = Unassigned;
e2->OutIdx = Unassigned;
} else if (e1->OutIdx < e2->OutIdx)
AppendPolygon(e1, e2);
else
AppendPolygon(e2, e1);
}
//------------------------------------------------------------------------------
void Clipper::AddEdgeToSEL(TEdge *edge) {
// SEL pointers in PEdge are reused to build a list of horizontal edges.
// However, we don't need to worry about order with horizontal edge
// processing.
if (!m_SortedEdges) {
m_SortedEdges = edge;
edge->PrevInSEL = 0;
edge->NextInSEL = 0;
} else {
edge->NextInSEL = m_SortedEdges;
edge->PrevInSEL = 0;
m_SortedEdges->PrevInSEL = edge;
m_SortedEdges = edge;
}
}
//------------------------------------------------------------------------------
bool Clipper::PopEdgeFromSEL(TEdge *&edge) {
if (!m_SortedEdges)
return false;
edge = m_SortedEdges;
DeleteFromSEL(m_SortedEdges);
return true;
}
//------------------------------------------------------------------------------
void Clipper::CopyAELToSEL() {
TEdge *e = m_ActiveEdges;
m_SortedEdges = e;
while (e) {
e->PrevInSEL = e->PrevInAEL;
e->NextInSEL = e->NextInAEL;
e = e->NextInAEL;
}
}
//------------------------------------------------------------------------------
void Clipper::AddJoin(OutPt *op1, OutPt *op2, const IntPoint OffPt) {
Join *j = new Join;
j->OutPt1 = op1;
j->OutPt2 = op2;
j->OffPt = OffPt;
m_Joins.push_back(j);
}
//------------------------------------------------------------------------------
void Clipper::ClearJoins() {
for (JoinList::size_type i = 0; i < m_Joins.size(); i++)
delete m_Joins[i];
m_Joins.resize(0);
}
//------------------------------------------------------------------------------
void Clipper::ClearGhostJoins() {
for (JoinList::size_type i = 0; i < m_GhostJoins.size(); i++)
delete m_GhostJoins[i];
m_GhostJoins.resize(0);
}
//------------------------------------------------------------------------------
void Clipper::AddGhostJoin(OutPt *op, const IntPoint OffPt) {
Join *j = new Join;
j->OutPt1 = op;
j->OutPt2 = 0;
j->OffPt = OffPt;
m_GhostJoins.push_back(j);
}
//------------------------------------------------------------------------------
void Clipper::InsertLocalMinimaIntoAEL(const cInt botY) {
const LocalMinimum *lm;
while (PopLocalMinima(botY, lm)) {
TEdge *lb = lm->LeftBound;
TEdge *rb = lm->RightBound;
OutPt *Op1 = 0;
if (!lb) {
// nb: don't insert LB into either AEL or SEL
InsertEdgeIntoAEL(rb, 0);
SetWindingCount(*rb);
if (IsContributing(*rb))
Op1 = AddOutPt(rb, rb->Bot);
} else if (!rb) {
InsertEdgeIntoAEL(lb, 0);
SetWindingCount(*lb);
if (IsContributing(*lb))
Op1 = AddOutPt(lb, lb->Bot);
InsertScanbeam(lb->Top.Y);
} else {
InsertEdgeIntoAEL(lb, 0);
InsertEdgeIntoAEL(rb, lb);
SetWindingCount(*lb);
rb->WindCnt = lb->WindCnt;
rb->WindCnt2 = lb->WindCnt2;
if (IsContributing(*lb))
Op1 = AddLocalMinPoly(lb, rb, lb->Bot);
InsertScanbeam(lb->Top.Y);
}
if (rb) {
if (IsHorizontal(*rb)) {
AddEdgeToSEL(rb);
if (rb->NextInLML)
InsertScanbeam(rb->NextInLML->Top.Y);
} else
InsertScanbeam(rb->Top.Y);
}
if (!lb || !rb)
continue;
// if any output polygons share an edge, they'll need joining later ...
if (Op1 && IsHorizontal(*rb) && m_GhostJoins.size() > 0 &&
(rb->WindDelta != 0)) {
for (JoinList::size_type i = 0; i < m_GhostJoins.size(); ++i) {
Join *jr = m_GhostJoins[i];
// if the horizontal Rb and a 'ghost' horizontal overlap, then convert
// the 'ghost' join to a real join ready for later ...
if (HorzSegmentsOverlap(jr->OutPt1->Pt.X, jr->OffPt.X, rb->Bot.X,
rb->Top.X))
AddJoin(jr->OutPt1, Op1, jr->OffPt);
}
}
if (lb->OutIdx >= 0 && lb->PrevInAEL &&
lb->PrevInAEL->Curr.X == lb->Bot.X && lb->PrevInAEL->OutIdx >= 0 &&
SlopesEqual(lb->PrevInAEL->Bot, lb->PrevInAEL->Top, lb->Curr, lb->Top,
m_UseFullRange) &&
(lb->WindDelta != 0) && (lb->PrevInAEL->WindDelta != 0)) {
OutPt *Op2 = AddOutPt(lb->PrevInAEL, lb->Bot);
AddJoin(Op1, Op2, lb->Top);
}
if (lb->NextInAEL != rb) {
if (rb->OutIdx >= 0 && rb->PrevInAEL->OutIdx >= 0 &&
SlopesEqual(rb->PrevInAEL->Curr, rb->PrevInAEL->Top, rb->Curr,
rb->Top, m_UseFullRange) &&
(rb->WindDelta != 0) && (rb->PrevInAEL->WindDelta != 0)) {
OutPt *Op2 = AddOutPt(rb->PrevInAEL, rb->Bot);
AddJoin(Op1, Op2, rb->Top);
}
TEdge *e = lb->NextInAEL;
if (e) {
while (e != rb) {
// nb: For calculating winding counts etc, IntersectEdges() assumes
// that param1 will be to the Right of param2 ABOVE the intersection
// ...
IntersectEdges(rb, e, lb->Curr); // order important here
e = e->NextInAEL;
}
}
}
}
}
//------------------------------------------------------------------------------
void Clipper::DeleteFromSEL(TEdge *e) {
TEdge *SelPrev = e->PrevInSEL;
TEdge *SelNext = e->NextInSEL;
if (!SelPrev && !SelNext && (e != m_SortedEdges))
return; // already deleted
if (SelPrev)
SelPrev->NextInSEL = SelNext;
else
m_SortedEdges = SelNext;
if (SelNext)
SelNext->PrevInSEL = SelPrev;
e->NextInSEL = 0;
e->PrevInSEL = 0;
}
//------------------------------------------------------------------------------
#ifdef use_xyz
void Clipper::SetZ(IntPoint &pt, TEdge &e1, TEdge &e2) {
if (pt.Z != 0 || !m_ZFill)
return;
else if (pt == e1.Bot)
pt.Z = e1.Bot.Z;
else if (pt == e1.Top)
pt.Z = e1.Top.Z;
else if (pt == e2.Bot)
pt.Z = e2.Bot.Z;
else if (pt == e2.Top)
pt.Z = e2.Top.Z;
else
(*m_ZFill)(e1.Bot, e1.Top, e2.Bot, e2.Top, pt);
}
//------------------------------------------------------------------------------
#endif
void Clipper::IntersectEdges(TEdge *e1, TEdge *e2, IntPoint &Pt) {
bool e1Contributing = (e1->OutIdx >= 0);
bool e2Contributing = (e2->OutIdx >= 0);
#ifdef use_xyz
SetZ(Pt, *e1, *e2);
#endif
#ifdef use_lines
// if either edge is on an OPEN path ...
if (e1->WindDelta == 0 || e2->WindDelta == 0) {
// ignore subject-subject open path intersections UNLESS they
// are both open paths, AND they are both 'contributing maximas' ...
if (e1->WindDelta == 0 && e2->WindDelta == 0)
return;
// if intersecting a subj line with a subj poly ...
else if (e1->PolyTyp == e2->PolyTyp && e1->WindDelta != e2->WindDelta &&
m_ClipType == ctUnion) {
if (e1->WindDelta == 0) {
if (e2Contributing) {
AddOutPt(e1, Pt);
if (e1Contributing)
e1->OutIdx = Unassigned;
}
} else {
if (e1Contributing) {
AddOutPt(e2, Pt);
if (e2Contributing)
e2->OutIdx = Unassigned;
}
}
} else if (e1->PolyTyp != e2->PolyTyp) {
// toggle subj open path OutIdx on/off when Abs(clip.WndCnt) == 1 ...
if ((e1->WindDelta == 0) && abs(e2->WindCnt) == 1 &&
(m_ClipType != ctUnion || e2->WindCnt2 == 0)) {
AddOutPt(e1, Pt);
if (e1Contributing)
e1->OutIdx = Unassigned;
} else if ((e2->WindDelta == 0) && (abs(e1->WindCnt) == 1) &&
(m_ClipType != ctUnion || e1->WindCnt2 == 0)) {
AddOutPt(e2, Pt);
if (e2Contributing)
e2->OutIdx = Unassigned;
}
}
return;
}
#endif
// update winding counts...
// assumes that e1 will be to the Right of e2 ABOVE the intersection
if (e1->PolyTyp == e2->PolyTyp) {
if (IsEvenOddFillType(*e1)) {
int oldE1WindCnt = e1->WindCnt;
e1->WindCnt = e2->WindCnt;
e2->WindCnt = oldE1WindCnt;
} else {
if (e1->WindCnt + e2->WindDelta == 0)
e1->WindCnt = -e1->WindCnt;
else
e1->WindCnt += e2->WindDelta;
if (e2->WindCnt - e1->WindDelta == 0)
e2->WindCnt = -e2->WindCnt;
else
e2->WindCnt -= e1->WindDelta;
}
} else {
if (!IsEvenOddFillType(*e2))
e1->WindCnt2 += e2->WindDelta;
else
e1->WindCnt2 = (e1->WindCnt2 == 0) ? 1 : 0;
if (!IsEvenOddFillType(*e1))
e2->WindCnt2 -= e1->WindDelta;
else
e2->WindCnt2 = (e2->WindCnt2 == 0) ? 1 : 0;
}
PolyFillType e1FillType, e2FillType, e1FillType2, e2FillType2;
if (e1->PolyTyp == ptSubject) {
e1FillType = m_SubjFillType;
e1FillType2 = m_ClipFillType;
} else {
e1FillType = m_ClipFillType;
e1FillType2 = m_SubjFillType;
}
if (e2->PolyTyp == ptSubject) {
e2FillType = m_SubjFillType;
e2FillType2 = m_ClipFillType;
} else {
e2FillType = m_ClipFillType;
e2FillType2 = m_SubjFillType;
}
cInt e1Wc, e2Wc;
switch (e1FillType) {
case pftPositive:
e1Wc = e1->WindCnt;
break;
case pftNegative:
e1Wc = -e1->WindCnt;
break;
default:
e1Wc = Abs(e1->WindCnt);
}
switch (e2FillType) {
case pftPositive:
e2Wc = e2->WindCnt;
break;
case pftNegative:
e2Wc = -e2->WindCnt;
break;
default:
e2Wc = Abs(e2->WindCnt);
}
if (e1Contributing && e2Contributing) {
if ((e1Wc != 0 && e1Wc != 1) || (e2Wc != 0 && e2Wc != 1) ||
(e1->PolyTyp != e2->PolyTyp && m_ClipType != ctXor)) {
AddLocalMaxPoly(e1, e2, Pt);
} else {
AddOutPt(e1, Pt);
AddOutPt(e2, Pt);
SwapSides(*e1, *e2);
SwapPolyIndexes(*e1, *e2);
}
} else if (e1Contributing) {
if (e2Wc == 0 || e2Wc == 1) {
AddOutPt(e1, Pt);
SwapSides(*e1, *e2);
SwapPolyIndexes(*e1, *e2);
}
} else if (e2Contributing) {
if (e1Wc == 0 || e1Wc == 1) {
AddOutPt(e2, Pt);
SwapSides(*e1, *e2);
SwapPolyIndexes(*e1, *e2);
}
} else if ((e1Wc == 0 || e1Wc == 1) && (e2Wc == 0 || e2Wc == 1)) {
// neither edge is currently contributing ...
cInt e1Wc2, e2Wc2;
switch (e1FillType2) {
case pftPositive:
e1Wc2 = e1->WindCnt2;
break;
case pftNegative:
e1Wc2 = -e1->WindCnt2;
break;
default:
e1Wc2 = Abs(e1->WindCnt2);
}
switch (e2FillType2) {
case pftPositive:
e2Wc2 = e2->WindCnt2;
break;
case pftNegative:
e2Wc2 = -e2->WindCnt2;
break;
default:
e2Wc2 = Abs(e2->WindCnt2);
}
if (e1->PolyTyp != e2->PolyTyp) {
AddLocalMinPoly(e1, e2, Pt);
} else if (e1Wc == 1 && e2Wc == 1)
switch (m_ClipType) {
case ctIntersection:
if (e1Wc2 > 0 && e2Wc2 > 0)
AddLocalMinPoly(e1, e2, Pt);
break;
case ctUnion:
if (e1Wc2 <= 0 && e2Wc2 <= 0)
AddLocalMinPoly(e1, e2, Pt);
break;
case ctDifference:
if (((e1->PolyTyp == ptClip) && (e1Wc2 > 0) && (e2Wc2 > 0)) ||
((e1->PolyTyp == ptSubject) && (e1Wc2 <= 0) && (e2Wc2 <= 0)))
AddLocalMinPoly(e1, e2, Pt);
break;
case ctXor:
AddLocalMinPoly(e1, e2, Pt);
}
else
SwapSides(*e1, *e2);
}
}
//------------------------------------------------------------------------------
void Clipper::SetHoleState(TEdge *e, OutRec *outrec) {
TEdge *e2 = e->PrevInAEL;
TEdge *eTmp = 0;
while (e2) {
if (e2->OutIdx >= 0 && e2->WindDelta != 0) {
if (!eTmp)
eTmp = e2;
else if (eTmp->OutIdx == e2->OutIdx)
eTmp = 0;
}
e2 = e2->PrevInAEL;
}
if (!eTmp) {
outrec->FirstLeft = 0;
outrec->IsHole = false;
} else {
outrec->FirstLeft = m_PolyOuts[eTmp->OutIdx];
outrec->IsHole = !outrec->FirstLeft->IsHole;
}
}
//------------------------------------------------------------------------------
OutRec *GetLowermostRec(OutRec *outRec1, OutRec *outRec2) {
// work out which polygon fragment has the correct hole state ...
if (!outRec1->BottomPt)
outRec1->BottomPt = GetBottomPt(outRec1->Pts);
if (!outRec2->BottomPt)
outRec2->BottomPt = GetBottomPt(outRec2->Pts);
OutPt *OutPt1 = outRec1->BottomPt;
OutPt *OutPt2 = outRec2->BottomPt;
if (OutPt1->Pt.Y > OutPt2->Pt.Y)
return outRec1;
else if (OutPt1->Pt.Y < OutPt2->Pt.Y)
return outRec2;
else if (OutPt1->Pt.X < OutPt2->Pt.X)
return outRec1;
else if (OutPt1->Pt.X > OutPt2->Pt.X)
return outRec2;
else if (OutPt1->Next == OutPt1)
return outRec2;
else if (OutPt2->Next == OutPt2)
return outRec1;
else if (FirstIsBottomPt(OutPt1, OutPt2))
return outRec1;
else
return outRec2;
}
//------------------------------------------------------------------------------
bool OutRec1RightOfOutRec2(OutRec *outRec1, OutRec *outRec2) {
do {
outRec1 = outRec1->FirstLeft;
if (outRec1 == outRec2)
return true;
} while (outRec1);
return false;
}
//------------------------------------------------------------------------------
OutRec *Clipper::GetOutRec(int Idx) {
OutRec *outrec = m_PolyOuts[Idx];
while (outrec != m_PolyOuts[outrec->Idx])
outrec = m_PolyOuts[outrec->Idx];
return outrec;
}
//------------------------------------------------------------------------------
void Clipper::AppendPolygon(TEdge *e1, TEdge *e2) {
// get the start and ends of both output polygons ...
OutRec *outRec1 = m_PolyOuts[e1->OutIdx];
OutRec *outRec2 = m_PolyOuts[e2->OutIdx];
OutRec *holeStateRec;
if (OutRec1RightOfOutRec2(outRec1, outRec2))
holeStateRec = outRec2;
else if (OutRec1RightOfOutRec2(outRec2, outRec1))
holeStateRec = outRec1;
else
holeStateRec = GetLowermostRec(outRec1, outRec2);
// get the start and ends of both output polygons and
// join e2 poly onto e1 poly and delete pointers to e2 ...
OutPt *p1_lft = outRec1->Pts;
OutPt *p1_rt = p1_lft->Prev;
OutPt *p2_lft = outRec2->Pts;
OutPt *p2_rt = p2_lft->Prev;
// join e2 poly onto e1 poly and delete pointers to e2 ...
if (e1->Side == esLeft) {
if (e2->Side == esLeft) {
// z y x a b c
ReversePolyPtLinks(p2_lft);
p2_lft->Next = p1_lft;
p1_lft->Prev = p2_lft;
p1_rt->Next = p2_rt;
p2_rt->Prev = p1_rt;
outRec1->Pts = p2_rt;
} else {
// x y z a b c
p2_rt->Next = p1_lft;
p1_lft->Prev = p2_rt;
p2_lft->Prev = p1_rt;
p1_rt->Next = p2_lft;
outRec1->Pts = p2_lft;
}
} else {
if (e2->Side == esRight) {
// a b c z y x
ReversePolyPtLinks(p2_lft);
p1_rt->Next = p2_rt;
p2_rt->Prev = p1_rt;
p2_lft->Next = p1_lft;
p1_lft->Prev = p2_lft;
} else {
// a b c x y z
p1_rt->Next = p2_lft;
p2_lft->Prev = p1_rt;
p1_lft->Prev = p2_rt;
p2_rt->Next = p1_lft;
}
}
outRec1->BottomPt = 0;
if (holeStateRec == outRec2) {
if (outRec2->FirstLeft != outRec1)
outRec1->FirstLeft = outRec2->FirstLeft;
outRec1->IsHole = outRec2->IsHole;
}
outRec2->Pts = 0;
outRec2->BottomPt = 0;
outRec2->FirstLeft = outRec1;
int OKIdx = e1->OutIdx;
int ObsoleteIdx = e2->OutIdx;
e1->OutIdx =
Unassigned; // nb: safe because we only get here via AddLocalMaxPoly
e2->OutIdx = Unassigned;
TEdge *e = m_ActiveEdges;
while (e) {
if (e->OutIdx == ObsoleteIdx) {
e->OutIdx = OKIdx;
e->Side = e1->Side;
break;
}
e = e->NextInAEL;
}
outRec2->Idx = outRec1->Idx;
}
//------------------------------------------------------------------------------
OutPt *Clipper::AddOutPt(TEdge *e, const IntPoint &pt) {
if (e->OutIdx < 0) {
OutRec *outRec = CreateOutRec();
outRec->IsOpen = (e->WindDelta == 0);
OutPt *newOp = new OutPt;
outRec->Pts = newOp;
newOp->Idx = outRec->Idx;
newOp->Pt = pt;
newOp->Next = newOp;
newOp->Prev = newOp;
if (!outRec->IsOpen)
SetHoleState(e, outRec);
e->OutIdx = outRec->Idx;
return newOp;
} else {
OutRec *outRec = m_PolyOuts[e->OutIdx];
// OutRec.Pts is the 'Left-most' point & OutRec.Pts.Prev is the 'Right-most'
OutPt *op = outRec->Pts;
bool ToFront = (e->Side == esLeft);
if (ToFront && (pt == op->Pt))
return op;
else if (!ToFront && (pt == op->Prev->Pt))
return op->Prev;
OutPt *newOp = new OutPt;
newOp->Idx = outRec->Idx;
newOp->Pt = pt;
newOp->Next = op;
newOp->Prev = op->Prev;
newOp->Prev->Next = newOp;
op->Prev = newOp;
if (ToFront)
outRec->Pts = newOp;
return newOp;
}
}
//------------------------------------------------------------------------------
OutPt *Clipper::GetLastOutPt(TEdge *e) {
OutRec *outRec = m_PolyOuts[e->OutIdx];
if (e->Side == esLeft)
return outRec->Pts;
else
return outRec->Pts->Prev;
}
//------------------------------------------------------------------------------
void Clipper::ProcessHorizontals() {
TEdge *horzEdge;
while (PopEdgeFromSEL(horzEdge))
ProcessHorizontal(horzEdge);
}
//------------------------------------------------------------------------------
inline bool IsMinima(TEdge *e) {
return e && (e->Prev->NextInLML != e) && (e->Next->NextInLML != e);
}
//------------------------------------------------------------------------------
inline bool IsMaxima(TEdge *e, const cInt Y) {
return e && e->Top.Y == Y && !e->NextInLML;
}
//------------------------------------------------------------------------------
inline bool IsIntermediate(TEdge *e, const cInt Y) {
return e->Top.Y == Y && e->NextInLML;
}
//------------------------------------------------------------------------------
TEdge *GetMaximaPair(TEdge *e) {
if ((e->Next->Top == e->Top) && !e->Next->NextInLML)
return e->Next;
else if ((e->Prev->Top == e->Top) && !e->Prev->NextInLML)
return e->Prev;
else
return 0;
}
//------------------------------------------------------------------------------
TEdge *GetMaximaPairEx(TEdge *e) {
// as GetMaximaPair() but returns 0 if MaxPair isn't in AEL (unless it's
// horizontal)
TEdge *result = GetMaximaPair(e);
if (result &&
(result->OutIdx == Skip ||
(result->NextInAEL == result->PrevInAEL && !IsHorizontal(*result))))
return 0;
return result;
}
//------------------------------------------------------------------------------
void Clipper::SwapPositionsInSEL(TEdge *Edge1, TEdge *Edge2) {
if (!(Edge1->NextInSEL) && !(Edge1->PrevInSEL))
return;
if (!(Edge2->NextInSEL) && !(Edge2->PrevInSEL))
return;
if (Edge1->NextInSEL == Edge2) {
TEdge *Next = Edge2->NextInSEL;
if (Next)
Next->PrevInSEL = Edge1;
TEdge *Prev = Edge1->PrevInSEL;
if (Prev)
Prev->NextInSEL = Edge2;
Edge2->PrevInSEL = Prev;
Edge2->NextInSEL = Edge1;
Edge1->PrevInSEL = Edge2;
Edge1->NextInSEL = Next;
} else if (Edge2->NextInSEL == Edge1) {
TEdge *Next = Edge1->NextInSEL;
if (Next)
Next->PrevInSEL = Edge2;
TEdge *Prev = Edge2->PrevInSEL;
if (Prev)
Prev->NextInSEL = Edge1;
Edge1->PrevInSEL = Prev;
Edge1->NextInSEL = Edge2;
Edge2->PrevInSEL = Edge1;
Edge2->NextInSEL = Next;
} else {
TEdge *Next = Edge1->NextInSEL;
TEdge *Prev = Edge1->PrevInSEL;
Edge1->NextInSEL = Edge2->NextInSEL;
if (Edge1->NextInSEL)
Edge1->NextInSEL->PrevInSEL = Edge1;
Edge1->PrevInSEL = Edge2->PrevInSEL;
if (Edge1->PrevInSEL)
Edge1->PrevInSEL->NextInSEL = Edge1;
Edge2->NextInSEL = Next;
if (Edge2->NextInSEL)
Edge2->NextInSEL->PrevInSEL = Edge2;
Edge2->PrevInSEL = Prev;
if (Edge2->PrevInSEL)
Edge2->PrevInSEL->NextInSEL = Edge2;
}
if (!Edge1->PrevInSEL)
m_SortedEdges = Edge1;
else if (!Edge2->PrevInSEL)
m_SortedEdges = Edge2;
}
//------------------------------------------------------------------------------
TEdge *GetNextInAEL(TEdge *e, Direction dir) {
return dir == dLeftToRight ? e->NextInAEL : e->PrevInAEL;
}
//------------------------------------------------------------------------------
void GetHorzDirection(TEdge &HorzEdge, Direction &Dir, cInt &Left,
cInt &Right) {
if (HorzEdge.Bot.X < HorzEdge.Top.X) {
Left = HorzEdge.Bot.X;
Right = HorzEdge.Top.X;
Dir = dLeftToRight;
} else {
Left = HorzEdge.Top.X;
Right = HorzEdge.Bot.X;
Dir = dRightToLeft;
}
}
//------------------------------------------------------------------------
/*******************************************************************************
* Notes: Horizontal edges (HEs) at scanline intersections (ie at the Top or *
* Bottom of a scanbeam) are processed as if layered. The order in which HEs *
* are processed doesn't matter. HEs intersect with other HE Bot.Xs only [#] *
* (or they could intersect with Top.Xs only, ie EITHER Bot.Xs OR Top.Xs), *
* and with other non-horizontal edges [*]. Once these intersections are *
* processed, intermediate HEs then 'promote' the Edge above (NextInLML) into *
* the AEL. These 'promoted' edges may in turn intersect [%] with other HEs. *
*******************************************************************************/
void Clipper::ProcessHorizontal(TEdge *horzEdge) {
Direction dir;
cInt horzLeft, horzRight;
bool IsOpen = (horzEdge->WindDelta == 0);
GetHorzDirection(*horzEdge, dir, horzLeft, horzRight);
TEdge *eLastHorz = horzEdge, *eMaxPair = 0;
while (eLastHorz->NextInLML && IsHorizontal(*eLastHorz->NextInLML))
eLastHorz = eLastHorz->NextInLML;
if (!eLastHorz->NextInLML)
eMaxPair = GetMaximaPair(eLastHorz);
MaximaList::const_iterator maxIt;
MaximaList::const_reverse_iterator maxRit;
if (m_Maxima.size() > 0) {
// get the first maxima in range (X) ...
if (dir == dLeftToRight) {
maxIt = m_Maxima.begin();
while (maxIt != m_Maxima.end() && *maxIt <= horzEdge->Bot.X)
maxIt++;
if (maxIt != m_Maxima.end() && *maxIt >= eLastHorz->Top.X)
maxIt = m_Maxima.end();
} else {
maxRit = m_Maxima.rbegin();
while (maxRit != m_Maxima.rend() && *maxRit > horzEdge->Bot.X)
maxRit++;
if (maxRit != m_Maxima.rend() && *maxRit <= eLastHorz->Top.X)
maxRit = m_Maxima.rend();
}
}
OutPt *op1 = 0;
for (;;) // loop through consec. horizontal edges
{
bool IsLastHorz = (horzEdge == eLastHorz);
TEdge *e = GetNextInAEL(horzEdge, dir);
while (e) {
      // this code block inserts extra coords into horizontal edges (in output
      // polygons) wherever maxima touch these horizontal edges. This helps
      // when 'simplifying' polygons (ie if the StrictlySimple property is set).
if (m_Maxima.size() > 0) {
if (dir == dLeftToRight) {
while (maxIt != m_Maxima.end() && *maxIt < e->Curr.X) {
if (horzEdge->OutIdx >= 0 && !IsOpen)
AddOutPt(horzEdge, IntPoint(*maxIt, horzEdge->Bot.Y));
maxIt++;
}
} else {
while (maxRit != m_Maxima.rend() && *maxRit > e->Curr.X) {
if (horzEdge->OutIdx >= 0 && !IsOpen)
AddOutPt(horzEdge, IntPoint(*maxRit, horzEdge->Bot.Y));
maxRit++;
}
}
      }
if ((dir == dLeftToRight && e->Curr.X > horzRight) ||
(dir == dRightToLeft && e->Curr.X < horzLeft))
break;
// Also break if we've got to the end of an intermediate horizontal edge
// ...
// nb: Smaller Dx's are to the right of larger Dx's ABOVE the horizontal.
if (e->Curr.X == horzEdge->Top.X && horzEdge->NextInLML &&
e->Dx < horzEdge->NextInLML->Dx)
break;
if (horzEdge->OutIdx >= 0 && !IsOpen) // note: may be done multiple times
{
#ifdef use_xyz
if (dir == dLeftToRight)
SetZ(e->Curr, *horzEdge, *e);
else
SetZ(e->Curr, *e, *horzEdge);
#endif
op1 = AddOutPt(horzEdge, e->Curr);
TEdge *eNextHorz = m_SortedEdges;
while (eNextHorz) {
if (eNextHorz->OutIdx >= 0 &&
HorzSegmentsOverlap(horzEdge->Bot.X, horzEdge->Top.X,
eNextHorz->Bot.X, eNextHorz->Top.X)) {
OutPt *op2 = GetLastOutPt(eNextHorz);
AddJoin(op2, op1, eNextHorz->Top);
}
eNextHorz = eNextHorz->NextInSEL;
}
AddGhostJoin(op1, horzEdge->Bot);
}
// OK, so far we're still in range of the horizontal Edge but make sure
// we're at the last of consec. horizontals when matching with eMaxPair
if (e == eMaxPair && IsLastHorz) {
if (horzEdge->OutIdx >= 0)
AddLocalMaxPoly(horzEdge, eMaxPair, horzEdge->Top);
DeleteFromAEL(horzEdge);
DeleteFromAEL(eMaxPair);
return;
}
if (dir == dLeftToRight) {
IntPoint Pt = IntPoint(e->Curr.X, horzEdge->Curr.Y);
IntersectEdges(horzEdge, e, Pt);
} else {
IntPoint Pt = IntPoint(e->Curr.X, horzEdge->Curr.Y);
IntersectEdges(e, horzEdge, Pt);
}
TEdge *eNext = GetNextInAEL(e, dir);
SwapPositionsInAEL(horzEdge, e);
e = eNext;
} // end while(e)
// Break out of loop if HorzEdge.NextInLML is not also horizontal ...
if (!horzEdge->NextInLML || !IsHorizontal(*horzEdge->NextInLML))
break;
UpdateEdgeIntoAEL(horzEdge);
if (horzEdge->OutIdx >= 0)
AddOutPt(horzEdge, horzEdge->Bot);
GetHorzDirection(*horzEdge, dir, horzLeft, horzRight);
} // end for (;;)
if (horzEdge->OutIdx >= 0 && !op1) {
op1 = GetLastOutPt(horzEdge);
TEdge *eNextHorz = m_SortedEdges;
while (eNextHorz) {
if (eNextHorz->OutIdx >= 0 &&
HorzSegmentsOverlap(horzEdge->Bot.X, horzEdge->Top.X,
eNextHorz->Bot.X, eNextHorz->Top.X)) {
OutPt *op2 = GetLastOutPt(eNextHorz);
AddJoin(op2, op1, eNextHorz->Top);
}
eNextHorz = eNextHorz->NextInSEL;
}
AddGhostJoin(op1, horzEdge->Top);
}
if (horzEdge->NextInLML) {
if (horzEdge->OutIdx >= 0) {
op1 = AddOutPt(horzEdge, horzEdge->Top);
UpdateEdgeIntoAEL(horzEdge);
if (horzEdge->WindDelta == 0)
return;
// nb: HorzEdge is no longer horizontal here
TEdge *ePrev = horzEdge->PrevInAEL;
TEdge *eNext = horzEdge->NextInAEL;
if (ePrev && ePrev->Curr.X == horzEdge->Bot.X &&
ePrev->Curr.Y == horzEdge->Bot.Y && ePrev->WindDelta != 0 &&
(ePrev->OutIdx >= 0 && ePrev->Curr.Y > ePrev->Top.Y &&
SlopesEqual(*horzEdge, *ePrev, m_UseFullRange))) {
OutPt *op2 = AddOutPt(ePrev, horzEdge->Bot);
AddJoin(op1, op2, horzEdge->Top);
} else if (eNext && eNext->Curr.X == horzEdge->Bot.X &&
eNext->Curr.Y == horzEdge->Bot.Y && eNext->WindDelta != 0 &&
eNext->OutIdx >= 0 && eNext->Curr.Y > eNext->Top.Y &&
SlopesEqual(*horzEdge, *eNext, m_UseFullRange)) {
OutPt *op2 = AddOutPt(eNext, horzEdge->Bot);
AddJoin(op1, op2, horzEdge->Top);
}
} else
UpdateEdgeIntoAEL(horzEdge);
} else {
if (horzEdge->OutIdx >= 0)
AddOutPt(horzEdge, horzEdge->Top);
DeleteFromAEL(horzEdge);
}
}
//------------------------------------------------------------------------------
bool Clipper::ProcessIntersections(const cInt topY) {
if (!m_ActiveEdges)
return true;
try {
BuildIntersectList(topY);
size_t IlSize = m_IntersectList.size();
if (IlSize == 0)
return true;
if (IlSize == 1 || FixupIntersectionOrder())
ProcessIntersectList();
else
return false;
} catch (...) {
m_SortedEdges = 0;
DisposeIntersectNodes();
throw clipperException("ProcessIntersections error");
}
m_SortedEdges = 0;
return true;
}
//------------------------------------------------------------------------------
void Clipper::DisposeIntersectNodes() {
for (size_t i = 0; i < m_IntersectList.size(); ++i)
delete m_IntersectList[i];
m_IntersectList.clear();
}
//------------------------------------------------------------------------------
void Clipper::BuildIntersectList(const cInt topY) {
if (!m_ActiveEdges)
return;
// prepare for sorting ...
TEdge *e = m_ActiveEdges;
m_SortedEdges = e;
while (e) {
e->PrevInSEL = e->PrevInAEL;
e->NextInSEL = e->NextInAEL;
e->Curr.X = TopX(*e, topY);
e = e->NextInAEL;
}
// bubblesort ...
bool isModified;
do {
isModified = false;
e = m_SortedEdges;
while (e->NextInSEL) {
TEdge *eNext = e->NextInSEL;
IntPoint Pt;
if (e->Curr.X > eNext->Curr.X) {
IntersectPoint(*e, *eNext, Pt);
if (Pt.Y < topY)
Pt = IntPoint(TopX(*e, topY), topY);
IntersectNode *newNode = new IntersectNode;
newNode->Edge1 = e;
newNode->Edge2 = eNext;
newNode->Pt = Pt;
m_IntersectList.push_back(newNode);
SwapPositionsInSEL(e, eNext);
isModified = true;
} else
e = eNext;
}
if (e->PrevInSEL)
e->PrevInSEL->NextInSEL = 0;
else
break;
} while (isModified);
m_SortedEdges = 0; // important
}
//------------------------------------------------------------------------------
void Clipper::ProcessIntersectList() {
for (size_t i = 0; i < m_IntersectList.size(); ++i) {
IntersectNode *iNode = m_IntersectList[i];
{
IntersectEdges(iNode->Edge1, iNode->Edge2, iNode->Pt);
SwapPositionsInAEL(iNode->Edge1, iNode->Edge2);
}
delete iNode;
}
m_IntersectList.clear();
}
//------------------------------------------------------------------------------
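// nb: sorts descending by Pt.Y; since 'Bot' coords carry the larger Y in this
// library (see eg GetLowermostRec above), descending Y puts bottom-most
// intersections first - the precondition FixupIntersectionOrder() relies on.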
bool IntersectListSort(IntersectNode *node1, IntersectNode *node2) {
return node2->Pt.Y < node1->Pt.Y;
}
//------------------------------------------------------------------------------
inline bool EdgesAdjacent(const IntersectNode &inode) {
return (inode.Edge1->NextInSEL == inode.Edge2) ||
(inode.Edge1->PrevInSEL == inode.Edge2);
}
//------------------------------------------------------------------------------
bool Clipper::FixupIntersectionOrder() {
// pre-condition: intersections are sorted Bottom-most first.
// Now it's crucial that intersections are made only between adjacent edges,
// so to ensure this the order of intersections may need adjusting ...
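  // e.g. if the bottom-most node pairs edges that are not yet adjacent in the
  // SEL (say a & c, with SEL order a-b-c), a later node whose edges ARE
  // adjacent is swapped into its place and processed first.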
CopyAELToSEL();
std::sort(m_IntersectList.begin(), m_IntersectList.end(), IntersectListSort);
size_t cnt = m_IntersectList.size();
for (size_t i = 0; i < cnt; ++i) {
if (!EdgesAdjacent(*m_IntersectList[i])) {
size_t j = i + 1;
while (j < cnt && !EdgesAdjacent(*m_IntersectList[j]))
j++;
if (j == cnt)
return false;
std::swap(m_IntersectList[i], m_IntersectList[j]);
}
SwapPositionsInSEL(m_IntersectList[i]->Edge1, m_IntersectList[i]->Edge2);
}
return true;
}
//------------------------------------------------------------------------------
void Clipper::DoMaxima(TEdge *e) {
TEdge *eMaxPair = GetMaximaPairEx(e);
if (!eMaxPair) {
if (e->OutIdx >= 0)
AddOutPt(e, e->Top);
DeleteFromAEL(e);
return;
}
TEdge *eNext = e->NextInAEL;
while (eNext && eNext != eMaxPair) {
IntersectEdges(e, eNext, e->Top);
SwapPositionsInAEL(e, eNext);
eNext = e->NextInAEL;
}
if (e->OutIdx == Unassigned && eMaxPair->OutIdx == Unassigned) {
DeleteFromAEL(e);
DeleteFromAEL(eMaxPair);
} else if (e->OutIdx >= 0 && eMaxPair->OutIdx >= 0) {
    AddLocalMaxPoly(e, eMaxPair, e->Top);
DeleteFromAEL(e);
DeleteFromAEL(eMaxPair);
}
#ifdef use_lines
else if (e->WindDelta == 0) {
if (e->OutIdx >= 0) {
AddOutPt(e, e->Top);
e->OutIdx = Unassigned;
}
DeleteFromAEL(e);
if (eMaxPair->OutIdx >= 0) {
AddOutPt(eMaxPair, e->Top);
eMaxPair->OutIdx = Unassigned;
}
DeleteFromAEL(eMaxPair);
}
#endif
else
throw clipperException("DoMaxima error");
}
//------------------------------------------------------------------------------
void Clipper::ProcessEdgesAtTopOfScanbeam(const cInt topY) {
TEdge *e = m_ActiveEdges;
while (e) {
// 1. process maxima, treating them as if they're 'bent' horizontal edges,
// but exclude maxima with horizontal edges. nb: e can't be a horizontal.
bool IsMaximaEdge = IsMaxima(e, topY);
if (IsMaximaEdge) {
TEdge *eMaxPair = GetMaximaPairEx(e);
IsMaximaEdge = (!eMaxPair || !IsHorizontal(*eMaxPair));
}
if (IsMaximaEdge) {
if (m_StrictSimple)
m_Maxima.push_back(e->Top.X);
TEdge *ePrev = e->PrevInAEL;
DoMaxima(e);
if (!ePrev)
e = m_ActiveEdges;
else
e = ePrev->NextInAEL;
} else {
// 2. promote horizontal edges, otherwise update Curr.X and Curr.Y ...
if (IsIntermediate(e, topY) && IsHorizontal(*e->NextInLML)) {
UpdateEdgeIntoAEL(e);
if (e->OutIdx >= 0)
AddOutPt(e, e->Bot);
AddEdgeToSEL(e);
} else {
e->Curr.X = TopX(*e, topY);
e->Curr.Y = topY;
#ifdef use_xyz
e->Curr.Z =
topY == e->Top.Y ? e->Top.Z : (topY == e->Bot.Y ? e->Bot.Z : 0);
#endif
}
// When StrictlySimple and 'e' is being touched by another edge, then
// make sure both edges have a vertex here ...
if (m_StrictSimple) {
TEdge *ePrev = e->PrevInAEL;
if ((e->OutIdx >= 0) && (e->WindDelta != 0) && ePrev &&
(ePrev->OutIdx >= 0) && (ePrev->Curr.X == e->Curr.X) &&
(ePrev->WindDelta != 0)) {
IntPoint pt = e->Curr;
#ifdef use_xyz
SetZ(pt, *ePrev, *e);
#endif
OutPt *op = AddOutPt(ePrev, pt);
OutPt *op2 = AddOutPt(e, pt);
AddJoin(op, op2, pt); // StrictlySimple (type-3) join
}
}
e = e->NextInAEL;
}
}
// 3. Process horizontals at the Top of the scanbeam ...
m_Maxima.sort();
ProcessHorizontals();
m_Maxima.clear();
// 4. Promote intermediate vertices ...
e = m_ActiveEdges;
while (e) {
if (IsIntermediate(e, topY)) {
OutPt *op = 0;
if (e->OutIdx >= 0)
op = AddOutPt(e, e->Top);
UpdateEdgeIntoAEL(e);
// if output polygons share an edge, they'll need joining later ...
TEdge *ePrev = e->PrevInAEL;
TEdge *eNext = e->NextInAEL;
if (ePrev && ePrev->Curr.X == e->Bot.X && ePrev->Curr.Y == e->Bot.Y &&
op && ePrev->OutIdx >= 0 && ePrev->Curr.Y > ePrev->Top.Y &&
SlopesEqual(e->Curr, e->Top, ePrev->Curr, ePrev->Top,
m_UseFullRange) &&
(e->WindDelta != 0) && (ePrev->WindDelta != 0)) {
OutPt *op2 = AddOutPt(ePrev, e->Bot);
AddJoin(op, op2, e->Top);
} else if (eNext && eNext->Curr.X == e->Bot.X &&
eNext->Curr.Y == e->Bot.Y && op && eNext->OutIdx >= 0 &&
eNext->Curr.Y > eNext->Top.Y &&
SlopesEqual(e->Curr, e->Top, eNext->Curr, eNext->Top,
m_UseFullRange) &&
(e->WindDelta != 0) && (eNext->WindDelta != 0)) {
OutPt *op2 = AddOutPt(eNext, e->Bot);
AddJoin(op, op2, e->Top);
}
}
e = e->NextInAEL;
}
}
//------------------------------------------------------------------------------
void Clipper::FixupOutPolyline(OutRec &outrec) {
OutPt *pp = outrec.Pts;
OutPt *lastPP = pp->Prev;
while (pp != lastPP) {
pp = pp->Next;
if (pp->Pt == pp->Prev->Pt) {
if (pp == lastPP)
lastPP = pp->Prev;
OutPt *tmpPP = pp->Prev;
tmpPP->Next = pp->Next;
pp->Next->Prev = tmpPP;
delete pp;
pp = tmpPP;
}
}
if (pp == pp->Prev) {
DisposeOutPts(pp);
outrec.Pts = 0;
return;
}
}
//------------------------------------------------------------------------------
void Clipper::FixupOutPolygon(OutRec &outrec) {
// FixupOutPolygon() - removes duplicate points and simplifies consecutive
// parallel edges by removing the middle vertex.
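  // e.g. the middle vertex of (0,0)-(5,0)-(10,0) is dropped as collinear
  // unless preserveCol is set (the vertex lies between its neighbours),
  // whereas the 'spike' vertex in (0,0)-(10,0)-(5,0) is always removed.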
OutPt *lastOK = 0;
outrec.BottomPt = 0;
OutPt *pp = outrec.Pts;
bool preserveCol = m_PreserveCollinear || m_StrictSimple;
for (;;) {
if (pp->Prev == pp || pp->Prev == pp->Next) {
DisposeOutPts(pp);
outrec.Pts = 0;
return;
}
// test for duplicate points and collinear edges ...
if ((pp->Pt == pp->Next->Pt) || (pp->Pt == pp->Prev->Pt) ||
(SlopesEqual(pp->Prev->Pt, pp->Pt, pp->Next->Pt, m_UseFullRange) &&
(!preserveCol ||
!Pt2IsBetweenPt1AndPt3(pp->Prev->Pt, pp->Pt, pp->Next->Pt)))) {
lastOK = 0;
OutPt *tmp = pp;
pp->Prev->Next = pp->Next;
pp->Next->Prev = pp->Prev;
pp = pp->Prev;
delete tmp;
} else if (pp == lastOK)
break;
else {
if (!lastOK)
lastOK = pp;
pp = pp->Next;
}
}
outrec.Pts = pp;
}
//------------------------------------------------------------------------------
int PointCount(OutPt *Pts) {
if (!Pts)
return 0;
int result = 0;
OutPt *p = Pts;
do {
result++;
p = p->Next;
} while (p != Pts);
return result;
}
//------------------------------------------------------------------------------
void Clipper::BuildResult(Paths &polys) {
polys.reserve(m_PolyOuts.size());
for (PolyOutList::size_type i = 0; i < m_PolyOuts.size(); ++i) {
if (!m_PolyOuts[i]->Pts)
continue;
Path pg;
OutPt *p = m_PolyOuts[i]->Pts->Prev;
int cnt = PointCount(p);
if (cnt < 2)
continue;
pg.reserve(cnt);
    for (int j = 0; j < cnt; ++j) {
pg.push_back(p->Pt);
p = p->Prev;
}
polys.push_back(pg);
}
}
//------------------------------------------------------------------------------
void Clipper::BuildResult2(PolyTree &polytree) {
polytree.Clear();
polytree.AllNodes.reserve(m_PolyOuts.size());
// add each output polygon/contour to polytree ...
for (PolyOutList::size_type i = 0; i < m_PolyOuts.size(); i++) {
OutRec *outRec = m_PolyOuts[i];
int cnt = PointCount(outRec->Pts);
if ((outRec->IsOpen && cnt < 2) || (!outRec->IsOpen && cnt < 3))
continue;
FixHoleLinkage(*outRec);
PolyNode *pn = new PolyNode();
// nb: polytree takes ownership of all the PolyNodes
polytree.AllNodes.push_back(pn);
outRec->PolyNd = pn;
pn->Parent = 0;
pn->Index = 0;
pn->Contour.reserve(cnt);
OutPt *op = outRec->Pts->Prev;
for (int j = 0; j < cnt; j++) {
pn->Contour.push_back(op->Pt);
op = op->Prev;
}
}
// fixup PolyNode links etc ...
polytree.Childs.reserve(m_PolyOuts.size());
for (PolyOutList::size_type i = 0; i < m_PolyOuts.size(); i++) {
OutRec *outRec = m_PolyOuts[i];
if (!outRec->PolyNd)
continue;
if (outRec->IsOpen) {
outRec->PolyNd->m_IsOpen = true;
polytree.AddChild(*outRec->PolyNd);
} else if (outRec->FirstLeft && outRec->FirstLeft->PolyNd)
outRec->FirstLeft->PolyNd->AddChild(*outRec->PolyNd);
else
polytree.AddChild(*outRec->PolyNd);
}
}
//------------------------------------------------------------------------------
void SwapIntersectNodes(IntersectNode &int1, IntersectNode &int2) {
  // just swap the contents (because fIntersectNodes is a singly-linked list)
IntersectNode inode = int1; // gets a copy of Int1
int1.Edge1 = int2.Edge1;
int1.Edge2 = int2.Edge2;
int1.Pt = int2.Pt;
int2.Edge1 = inode.Edge1;
int2.Edge2 = inode.Edge2;
int2.Pt = inode.Pt;
}
//------------------------------------------------------------------------------
inline bool E2InsertsBeforeE1(TEdge &e1, TEdge &e2) {
if (e2.Curr.X == e1.Curr.X) {
if (e2.Top.Y > e1.Top.Y)
return e2.Top.X < TopX(e1, e2.Top.Y);
else
return e1.Top.X > TopX(e2, e1.Top.Y);
} else
return e2.Curr.X < e1.Curr.X;
}
//------------------------------------------------------------------------------
bool GetOverlap(const cInt a1, const cInt a2, const cInt b1, const cInt b2,
cInt &Left, cInt &Right) {
if (a1 < a2) {
if (b1 < b2) {
Left = std::max(a1, b1);
Right = std::min(a2, b2);
} else {
Left = std::max(a1, b2);
Right = std::min(a2, b1);
}
} else {
if (b1 < b2) {
Left = std::max(a2, b1);
Right = std::min(a1, b2);
} else {
Left = std::max(a2, b2);
Right = std::min(a1, b1);
}
}
return Left < Right;
}
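// e.g. GetOverlap(0, 10, 5, 20, l, r) sets l = 5, r = 10 and returns true,
// while GetOverlap(0, 4, 5, 9, l, r) sets l = 5, r = 4 and returns false.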
//------------------------------------------------------------------------------
inline void UpdateOutPtIdxs(OutRec &outrec) {
OutPt *op = outrec.Pts;
do {
op->Idx = outrec.Idx;
op = op->Prev;
} while (op != outrec.Pts);
}
//------------------------------------------------------------------------------
void Clipper::InsertEdgeIntoAEL(TEdge *edge, TEdge *startEdge) {
if (!m_ActiveEdges) {
edge->PrevInAEL = 0;
edge->NextInAEL = 0;
m_ActiveEdges = edge;
} else if (!startEdge && E2InsertsBeforeE1(*m_ActiveEdges, *edge)) {
edge->PrevInAEL = 0;
edge->NextInAEL = m_ActiveEdges;
m_ActiveEdges->PrevInAEL = edge;
m_ActiveEdges = edge;
} else {
if (!startEdge)
startEdge = m_ActiveEdges;
while (startEdge->NextInAEL &&
!E2InsertsBeforeE1(*startEdge->NextInAEL, *edge))
startEdge = startEdge->NextInAEL;
edge->NextInAEL = startEdge->NextInAEL;
if (startEdge->NextInAEL)
startEdge->NextInAEL->PrevInAEL = edge;
edge->PrevInAEL = startEdge;
startEdge->NextInAEL = edge;
}
}
//----------------------------------------------------------------------
OutPt *DupOutPt(OutPt *outPt, bool InsertAfter) {
OutPt *result = new OutPt;
result->Pt = outPt->Pt;
result->Idx = outPt->Idx;
if (InsertAfter) {
result->Next = outPt->Next;
result->Prev = outPt;
outPt->Next->Prev = result;
outPt->Next = result;
} else {
result->Prev = outPt->Prev;
result->Next = outPt;
outPt->Prev->Next = result;
outPt->Prev = result;
}
return result;
}
//------------------------------------------------------------------------------
bool JoinHorz(OutPt *op1, OutPt *op1b, OutPt *op2, OutPt *op2b,
const IntPoint Pt, bool DiscardLeft) {
Direction Dir1 = (op1->Pt.X > op1b->Pt.X ? dRightToLeft : dLeftToRight);
Direction Dir2 = (op2->Pt.X > op2b->Pt.X ? dRightToLeft : dLeftToRight);
if (Dir1 == Dir2)
return false;
// When DiscardLeft, we want Op1b to be on the Left of Op1, otherwise we
// want Op1b to be on the Right. (And likewise with Op2 and Op2b.)
// So, to facilitate this while inserting Op1b and Op2b ...
// when DiscardLeft, make sure we're AT or RIGHT of Pt before adding Op1b,
// otherwise make sure we're AT or LEFT of Pt. (Likewise with Op2b.)
if (Dir1 == dLeftToRight) {
while (op1->Next->Pt.X <= Pt.X && op1->Next->Pt.X >= op1->Pt.X &&
op1->Next->Pt.Y == Pt.Y)
op1 = op1->Next;
if (DiscardLeft && (op1->Pt.X != Pt.X))
op1 = op1->Next;
op1b = DupOutPt(op1, !DiscardLeft);
if (op1b->Pt != Pt) {
op1 = op1b;
op1->Pt = Pt;
op1b = DupOutPt(op1, !DiscardLeft);
}
} else {
while (op1->Next->Pt.X >= Pt.X && op1->Next->Pt.X <= op1->Pt.X &&
op1->Next->Pt.Y == Pt.Y)
op1 = op1->Next;
if (!DiscardLeft && (op1->Pt.X != Pt.X))
op1 = op1->Next;
op1b = DupOutPt(op1, DiscardLeft);
if (op1b->Pt != Pt) {
op1 = op1b;
op1->Pt = Pt;
op1b = DupOutPt(op1, DiscardLeft);
}
}
if (Dir2 == dLeftToRight) {
while (op2->Next->Pt.X <= Pt.X && op2->Next->Pt.X >= op2->Pt.X &&
op2->Next->Pt.Y == Pt.Y)
op2 = op2->Next;
if (DiscardLeft && (op2->Pt.X != Pt.X))
op2 = op2->Next;
op2b = DupOutPt(op2, !DiscardLeft);
if (op2b->Pt != Pt) {
op2 = op2b;
op2->Pt = Pt;
op2b = DupOutPt(op2, !DiscardLeft);
    }
} else {
while (op2->Next->Pt.X >= Pt.X && op2->Next->Pt.X <= op2->Pt.X &&
op2->Next->Pt.Y == Pt.Y)
op2 = op2->Next;
if (!DiscardLeft && (op2->Pt.X != Pt.X))
op2 = op2->Next;
op2b = DupOutPt(op2, DiscardLeft);
if (op2b->Pt != Pt) {
op2 = op2b;
op2->Pt = Pt;
op2b = DupOutPt(op2, DiscardLeft);
    }
  }
if ((Dir1 == dLeftToRight) == DiscardLeft) {
op1->Prev = op2;
op2->Next = op1;
op1b->Next = op2b;
op2b->Prev = op1b;
} else {
op1->Next = op2;
op2->Prev = op1;
op1b->Prev = op2b;
op2b->Next = op1b;
}
return true;
}
//------------------------------------------------------------------------------
bool Clipper::JoinPoints(Join *j, OutRec *outRec1, OutRec *outRec2) {
OutPt *op1 = j->OutPt1, *op1b;
OutPt *op2 = j->OutPt2, *op2b;
// There are 3 kinds of joins for output polygons ...
// 1. Horizontal joins where Join.OutPt1 & Join.OutPt2 are vertices anywhere
// along (horizontal) collinear edges (& Join.OffPt is on the same
// horizontal).
// 2. Non-horizontal joins where Join.OutPt1 & Join.OutPt2 are at the same
// location at the Bottom of the overlapping segment (& Join.OffPt is above).
// 3. StrictSimple joins where edges touch but are not collinear and where
// Join.OutPt1, Join.OutPt2 & Join.OffPt all share the same point.
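  // (the type-3 joins referred to above are the ones queued in
  // ProcessEdgesAtTopOfScanbeam when the StrictlySimple option is set)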
bool isHorizontal = (j->OutPt1->Pt.Y == j->OffPt.Y);
if (isHorizontal && (j->OffPt == j->OutPt1->Pt) &&
(j->OffPt == j->OutPt2->Pt)) {
// Strictly Simple join ...
if (outRec1 != outRec2)
return false;
op1b = j->OutPt1->Next;
while (op1b != op1 && (op1b->Pt == j->OffPt))
op1b = op1b->Next;
bool reverse1 = (op1b->Pt.Y > j->OffPt.Y);
op2b = j->OutPt2->Next;
while (op2b != op2 && (op2b->Pt == j->OffPt))
op2b = op2b->Next;
bool reverse2 = (op2b->Pt.Y > j->OffPt.Y);
if (reverse1 == reverse2)
return false;
if (reverse1) {
op1b = DupOutPt(op1, false);
op2b = DupOutPt(op2, true);
op1->Prev = op2;
op2->Next = op1;
op1b->Next = op2b;
op2b->Prev = op1b;
j->OutPt1 = op1;
j->OutPt2 = op1b;
return true;
} else {
op1b = DupOutPt(op1, true);
op2b = DupOutPt(op2, false);
op1->Next = op2;
op2->Prev = op1;
op1b->Prev = op2b;
op2b->Next = op1b;
j->OutPt1 = op1;
j->OutPt2 = op1b;
return true;
}
} else if (isHorizontal) {
// treat horizontal joins differently to non-horizontal joins since with
// them we're not yet sure where the overlapping is. OutPt1.Pt & OutPt2.Pt
// may be anywhere along the horizontal edge.
op1b = op1;
while (op1->Prev->Pt.Y == op1->Pt.Y && op1->Prev != op1b &&
op1->Prev != op2)
op1 = op1->Prev;
while (op1b->Next->Pt.Y == op1b->Pt.Y && op1b->Next != op1 &&
op1b->Next != op2)
op1b = op1b->Next;
if (op1b->Next == op1 || op1b->Next == op2)
return false; // a flat 'polygon'
op2b = op2;
while (op2->Prev->Pt.Y == op2->Pt.Y && op2->Prev != op2b &&
op2->Prev != op1b)
op2 = op2->Prev;
while (op2b->Next->Pt.Y == op2b->Pt.Y && op2b->Next != op2 &&
op2b->Next != op1)
op2b = op2b->Next;
if (op2b->Next == op2 || op2b->Next == op1)
return false; // a flat 'polygon'
cInt Left, Right;
    // Op1 --> Op1b & Op2 --> Op2b are the extremities of the horizontal edges
if (!GetOverlap(op1->Pt.X, op1b->Pt.X, op2->Pt.X, op2b->Pt.X, Left, Right))
return false;
    // DiscardLeftSide: when overlapping edges are joined, a spike will be
    // created which needs to be cleaned up. However, we don't want Op1 or Op2
    // caught up on the discarded side as either may still be needed for other
    // joins ...
IntPoint Pt;
bool DiscardLeftSide;
if (op1->Pt.X >= Left && op1->Pt.X <= Right) {
Pt = op1->Pt;
DiscardLeftSide = (op1->Pt.X > op1b->Pt.X);
} else if (op2->Pt.X >= Left && op2->Pt.X <= Right) {
Pt = op2->Pt;
DiscardLeftSide = (op2->Pt.X > op2b->Pt.X);
} else if (op1b->Pt.X >= Left && op1b->Pt.X <= Right) {
Pt = op1b->Pt;
DiscardLeftSide = op1b->Pt.X > op1->Pt.X;
} else {
Pt = op2b->Pt;
DiscardLeftSide = (op2b->Pt.X > op2->Pt.X);
}
j->OutPt1 = op1;
j->OutPt2 = op2;
return JoinHorz(op1, op1b, op2, op2b, Pt, DiscardLeftSide);
} else {
// nb: For non-horizontal joins ...
// 1. Jr.OutPt1.Pt.Y == Jr.OutPt2.Pt.Y
    // 2. Jr.OutPt1.Pt.Y > Jr.OffPt.Y
// make sure the polygons are correctly oriented ...
op1b = op1->Next;
while ((op1b->Pt == op1->Pt) && (op1b != op1))
op1b = op1b->Next;
bool Reverse1 = ((op1b->Pt.Y > op1->Pt.Y) ||
!SlopesEqual(op1->Pt, op1b->Pt, j->OffPt, m_UseFullRange));
if (Reverse1) {
op1b = op1->Prev;
while ((op1b->Pt == op1->Pt) && (op1b != op1))
op1b = op1b->Prev;
if ((op1b->Pt.Y > op1->Pt.Y) ||
!SlopesEqual(op1->Pt, op1b->Pt, j->OffPt, m_UseFullRange))
return false;
    }
op2b = op2->Next;
while ((op2b->Pt == op2->Pt) && (op2b != op2))
op2b = op2b->Next;
bool Reverse2 = ((op2b->Pt.Y > op2->Pt.Y) ||
!SlopesEqual(op2->Pt, op2b->Pt, j->OffPt, m_UseFullRange));
if (Reverse2) {
op2b = op2->Prev;
while ((op2b->Pt == op2->Pt) && (op2b != op2))
op2b = op2b->Prev;
if ((op2b->Pt.Y > op2->Pt.Y) ||
!SlopesEqual(op2->Pt, op2b->Pt, j->OffPt, m_UseFullRange))
return false;
}
if ((op1b == op1) || (op2b == op2) || (op1b == op2b) ||
((outRec1 == outRec2) && (Reverse1 == Reverse2)))
return false;
if (Reverse1) {
op1b = DupOutPt(op1, false);
op2b = DupOutPt(op2, true);
op1->Prev = op2;
op2->Next = op1;
op1b->Next = op2b;
op2b->Prev = op1b;
j->OutPt1 = op1;
j->OutPt2 = op1b;
return true;
} else {
op1b = DupOutPt(op1, true);
op2b = DupOutPt(op2, false);
op1->Next = op2;
op2->Prev = op1;
op1b->Prev = op2b;
op2b->Next = op1b;
j->OutPt1 = op1;
j->OutPt2 = op1b;
return true;
}
}
}
//----------------------------------------------------------------------
static OutRec *ParseFirstLeft(OutRec *FirstLeft) {
while (FirstLeft && !FirstLeft->Pts)
FirstLeft = FirstLeft->FirstLeft;
return FirstLeft;
}
//------------------------------------------------------------------------------
void Clipper::FixupFirstLefts1(OutRec *OldOutRec, OutRec *NewOutRec) {
// tests if NewOutRec contains the polygon before reassigning FirstLeft
for (PolyOutList::size_type i = 0; i < m_PolyOuts.size(); ++i) {
OutRec *outRec = m_PolyOuts[i];
OutRec *firstLeft = ParseFirstLeft(outRec->FirstLeft);
if (outRec->Pts && firstLeft == OldOutRec) {
if (Poly2ContainsPoly1(outRec->Pts, NewOutRec->Pts))
outRec->FirstLeft = NewOutRec;
}
}
}
//----------------------------------------------------------------------
void Clipper::FixupFirstLefts2(OutRec *InnerOutRec, OutRec *OuterOutRec) {
// A polygon has split into two such that one is now the inner of the other.
// It's possible that these polygons now wrap around other polygons, so check
// every polygon that's also contained by OuterOutRec's FirstLeft container
  // (including 0) to see if they've become inner to the new inner polygon ...
OutRec *orfl = OuterOutRec->FirstLeft;
for (PolyOutList::size_type i = 0; i < m_PolyOuts.size(); ++i) {
OutRec *outRec = m_PolyOuts[i];
if (!outRec->Pts || outRec == OuterOutRec || outRec == InnerOutRec)
continue;
OutRec *firstLeft = ParseFirstLeft(outRec->FirstLeft);
if (firstLeft != orfl && firstLeft != InnerOutRec &&
firstLeft != OuterOutRec)
continue;
if (Poly2ContainsPoly1(outRec->Pts, InnerOutRec->Pts))
outRec->FirstLeft = InnerOutRec;
else if (Poly2ContainsPoly1(outRec->Pts, OuterOutRec->Pts))
outRec->FirstLeft = OuterOutRec;
else if (outRec->FirstLeft == InnerOutRec ||
outRec->FirstLeft == OuterOutRec)
outRec->FirstLeft = orfl;
}
}
//----------------------------------------------------------------------
void Clipper::FixupFirstLefts3(OutRec *OldOutRec, OutRec *NewOutRec) {
// reassigns FirstLeft WITHOUT testing if NewOutRec contains the polygon
for (PolyOutList::size_type i = 0; i < m_PolyOuts.size(); ++i) {
OutRec *outRec = m_PolyOuts[i];
OutRec *firstLeft = ParseFirstLeft(outRec->FirstLeft);
if (outRec->Pts && firstLeft == OldOutRec)
outRec->FirstLeft = NewOutRec;
}
}
//----------------------------------------------------------------------
void Clipper::JoinCommonEdges() {
for (JoinList::size_type i = 0; i < m_Joins.size(); i++) {
Join *join = m_Joins[i];
OutRec *outRec1 = GetOutRec(join->OutPt1->Idx);
OutRec *outRec2 = GetOutRec(join->OutPt2->Idx);
if (!outRec1->Pts || !outRec2->Pts)
continue;
if (outRec1->IsOpen || outRec2->IsOpen)
continue;
// get the polygon fragment with the correct hole state (FirstLeft)
// before calling JoinPoints() ...
OutRec *holeStateRec;
if (outRec1 == outRec2)
holeStateRec = outRec1;
else if (OutRec1RightOfOutRec2(outRec1, outRec2))
holeStateRec = outRec2;
else if (OutRec1RightOfOutRec2(outRec2, outRec1))
holeStateRec = outRec1;
else
holeStateRec = GetLowermostRec(outRec1, outRec2);
if (!JoinPoints(join, outRec1, outRec2))
continue;
if (outRec1 == outRec2) {
// instead of joining two polygons, we've just created a new one by
// splitting one polygon into two.
outRec1->Pts = join->OutPt1;
outRec1->BottomPt = 0;
outRec2 = CreateOutRec();
outRec2->Pts = join->OutPt2;
// update all OutRec2.Pts Idx's ...
UpdateOutPtIdxs(*outRec2);
if (Poly2ContainsPoly1(outRec2->Pts, outRec1->Pts)) {
// outRec1 contains outRec2 ...
outRec2->IsHole = !outRec1->IsHole;
outRec2->FirstLeft = outRec1;
if (m_UsingPolyTree)
FixupFirstLefts2(outRec2, outRec1);
if ((outRec2->IsHole ^ m_ReverseOutput) == (Area(*outRec2) > 0))
ReversePolyPtLinks(outRec2->Pts);
} else if (Poly2ContainsPoly1(outRec1->Pts, outRec2->Pts)) {
// outRec2 contains outRec1 ...
outRec2->IsHole = outRec1->IsHole;
outRec1->IsHole = !outRec2->IsHole;
outRec2->FirstLeft = outRec1->FirstLeft;
outRec1->FirstLeft = outRec2;
if (m_UsingPolyTree)
FixupFirstLefts2(outRec1, outRec2);
if ((outRec1->IsHole ^ m_ReverseOutput) == (Area(*outRec1) > 0))
ReversePolyPtLinks(outRec1->Pts);
} else {
// the 2 polygons are completely separate ...
outRec2->IsHole = outRec1->IsHole;
outRec2->FirstLeft = outRec1->FirstLeft;
// fixup FirstLeft pointers that may need reassigning to OutRec2
if (m_UsingPolyTree)
FixupFirstLefts1(outRec1, outRec2);
}
} else {
// joined 2 polygons together ...
outRec2->Pts = 0;
outRec2->BottomPt = 0;
outRec2->Idx = outRec1->Idx;
outRec1->IsHole = holeStateRec->IsHole;
if (holeStateRec == outRec2)
outRec1->FirstLeft = outRec2->FirstLeft;
outRec2->FirstLeft = outRec1;
if (m_UsingPolyTree)
FixupFirstLefts3(outRec2, outRec1);
}
}
}
//------------------------------------------------------------------------------
// ClipperOffset support functions ...
//------------------------------------------------------------------------------
DoublePoint GetUnitNormal(const IntPoint &pt1, const IntPoint &pt2) {
if (pt2.X == pt1.X && pt2.Y == pt1.Y)
return DoublePoint(0, 0);
double Dx = (double)(pt2.X - pt1.X);
double dy = (double)(pt2.Y - pt1.Y);
  double f = 1.0 / std::sqrt(Dx * Dx + dy * dy);
Dx *= f;
dy *= f;
return DoublePoint(dy, -Dx);
}
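// e.g. for pt1 = (0,0), pt2 = (3,4): Dx = 3, dy = 4, f = 1/5, so the result
// is (0.8, -0.6) - ie the (dy, -dx) unit normal of the directed segment
// pt1 -> pt2.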
//------------------------------------------------------------------------------
// ClipperOffset class
//------------------------------------------------------------------------------
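// A minimal usage sketch (illustrative only - 'subject' and the 10.0 delta
// are assumed example values, in the same integer units as the input):
//   ClipperOffset co;
//   co.AddPath(subject, jtRound, etClosedPolygon);
//   Paths solution;
//   co.Execute(solution, 10.0); // a positive delta expands a closed polygon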
ClipperOffset::ClipperOffset(double miterLimit, double arcTolerance) {
this->MiterLimit = miterLimit;
this->ArcTolerance = arcTolerance;
m_lowest.X = -1;
}
//------------------------------------------------------------------------------
ClipperOffset::~ClipperOffset() { Clear(); }
//------------------------------------------------------------------------------
void ClipperOffset::Clear() {
for (int i = 0; i < m_polyNodes.ChildCount(); ++i)
delete m_polyNodes.Childs[i];
m_polyNodes.Childs.clear();
m_lowest.X = -1;
}
//------------------------------------------------------------------------------
void ClipperOffset::AddPath(const Path &path, JoinType joinType,
EndType endType) {
int highI = (int)path.size() - 1;
if (highI < 0)
return;
PolyNode *newNode = new PolyNode();
newNode->m_jointype = joinType;
newNode->m_endtype = endType;
// strip duplicate points from path and also get index to the lowest point ...
if (endType == etClosedLine || endType == etClosedPolygon)
while (highI > 0 && path[0] == path[highI])
highI--;
newNode->Contour.reserve(highI + 1);
newNode->Contour.push_back(path[0]);
int j = 0, k = 0;
for (int i = 1; i <= highI; i++)
if (newNode->Contour[j] != path[i]) {
j++;
newNode->Contour.push_back(path[i]);
if (path[i].Y > newNode->Contour[k].Y ||
(path[i].Y == newNode->Contour[k].Y &&
path[i].X < newNode->Contour[k].X))
k = j;
}
if (endType == etClosedPolygon && j < 2) {
delete newNode;
return;
}
m_polyNodes.AddChild(*newNode);
// if this path's lowest pt is lower than all the others then update m_lowest
if (endType != etClosedPolygon)
return;
if (m_lowest.X < 0)
m_lowest = IntPoint(m_polyNodes.ChildCount() - 1, k);
else {
IntPoint ip = m_polyNodes.Childs[(int)m_lowest.X]->Contour[(int)m_lowest.Y];
if (newNode->Contour[k].Y > ip.Y ||
(newNode->Contour[k].Y == ip.Y && newNode->Contour[k].X < ip.X))
m_lowest = IntPoint(m_polyNodes.ChildCount() - 1, k);
}
}
//------------------------------------------------------------------------------
void ClipperOffset::AddPaths(const Paths &paths, JoinType joinType,
EndType endType) {
for (Paths::size_type i = 0; i < paths.size(); ++i)
AddPath(paths[i], joinType, endType);
}
//------------------------------------------------------------------------------
void ClipperOffset::FixOrientations() {
// fixup orientations of all closed paths if the orientation of the
// closed path with the lowermost vertex is wrong ...
if (m_lowest.X >= 0 &&
!Orientation(m_polyNodes.Childs[(int)m_lowest.X]->Contour)) {
for (int i = 0; i < m_polyNodes.ChildCount(); ++i) {
PolyNode &node = *m_polyNodes.Childs[i];
if (node.m_endtype == etClosedPolygon ||
(node.m_endtype == etClosedLine && Orientation(node.Contour)))
ReversePath(node.Contour);
}
} else {
for (int i = 0; i < m_polyNodes.ChildCount(); ++i) {
PolyNode &node = *m_polyNodes.Childs[i];
if (node.m_endtype == etClosedLine && !Orientation(node.Contour))
ReversePath(node.Contour);
}
}
}
//------------------------------------------------------------------------------
void ClipperOffset::Execute(Paths &solution, double delta) {
solution.clear();
FixOrientations();
DoOffset(delta);
// now clean up 'corners' ...
Clipper clpr;
clpr.AddPaths(m_destPolys, ptSubject, true);
if (delta > 0) {
clpr.Execute(ctUnion, solution, pftPositive, pftPositive);
} else {
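    // nb: a negative delta reverses contour orientation, so the union is
    // taken inside an enclosing rectangle with pftNegative filling and the
    // rectangle - the first, outermost polygon - is discarded again below
    // (the PolyTree overload uses the same trick).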
IntRect r = clpr.GetBounds();
Path outer(4);
outer[0] = IntPoint(r.left - 10, r.bottom + 10);
outer[1] = IntPoint(r.right + 10, r.bottom + 10);
outer[2] = IntPoint(r.right + 10, r.top - 10);
outer[3] = IntPoint(r.left - 10, r.top - 10);
clpr.AddPath(outer, ptSubject, true);
clpr.ReverseSolution(true);
clpr.Execute(ctUnion, solution, pftNegative, pftNegative);
if (solution.size() > 0)
solution.erase(solution.begin());
}
}
//------------------------------------------------------------------------------
void ClipperOffset::Execute(PolyTree &solution, double delta) {
solution.Clear();
FixOrientations();
DoOffset(delta);
// now clean up 'corners' ...
Clipper clpr;
clpr.AddPaths(m_destPolys, ptSubject, true);
if (delta > 0) {
clpr.Execute(ctUnion, solution, pftPositive, pftPositive);
} else {
IntRect r = clpr.GetBounds();
Path outer(4);
outer[0] = IntPoint(r.left - 10, r.bottom + 10);
outer[1] = IntPoint(r.right + 10, r.bottom + 10);
outer[2] = IntPoint(r.right + 10, r.top - 10);
outer[3] = IntPoint(r.left - 10, r.top - 10);
clpr.AddPath(outer, ptSubject, true);
clpr.ReverseSolution(true);
clpr.Execute(ctUnion, solution, pftNegative, pftNegative);
// remove the outer PolyNode rectangle ...
if (solution.ChildCount() == 1 && solution.Childs[0]->ChildCount() > 0) {
PolyNode *outerNode = solution.Childs[0];
solution.Childs.reserve(outerNode->ChildCount());
solution.Childs[0] = outerNode->Childs[0];
solution.Childs[0]->Parent = outerNode->Parent;
for (int i = 1; i < outerNode->ChildCount(); ++i)
solution.AddChild(*outerNode->Childs[i]);
} else
solution.Clear();
}
}
//------------------------------------------------------------------------------
void ClipperOffset::DoOffset(double delta) {
m_destPolys.clear();
m_delta = delta;
  // if Zero offset, just copy any CLOSED polygons to m_destPolys and return ...
if (NEAR_ZERO(delta)) {
m_destPolys.reserve(m_polyNodes.ChildCount());
for (int i = 0; i < m_polyNodes.ChildCount(); i++) {
PolyNode &node = *m_polyNodes.Childs[i];
if (node.m_endtype == etClosedPolygon)
m_destPolys.push_back(node.Contour);
}
return;
}
// see offset_triginometry3.svg in the documentation folder ...
if (MiterLimit > 2)
m_miterLim = 2 / (MiterLimit * MiterLimit);
else
m_miterLim = 0.5;
double y;
if (ArcTolerance <= 0.0)
y = def_arc_tolerance;
else if (ArcTolerance > std::fabs(delta) * def_arc_tolerance)
y = std::fabs(delta) * def_arc_tolerance;
else
y = ArcTolerance;
// see offset_triginometry2.svg in the documentation folder ...
double steps = pi / std::acos(1 - y / std::fabs(delta));
if (steps > std::fabs(delta) * pi)
steps = std::fabs(delta) * pi; // ie excessive precision check
m_sin = std::sin(two_pi / steps);
m_cos = std::cos(two_pi / steps);
m_StepsPerRad = steps / two_pi;
if (delta < 0.0)
m_sin = -m_sin;
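  // e.g. with delta = 10 and y = def_arc_tolerance (0.25, assuming the
  // header's default): steps = pi / acos(1 - 0.25 / 10) ~= 14, ie a full
  // circle is approximated by about 14 vertices at this tolerance.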
m_destPolys.reserve(m_polyNodes.ChildCount() * 2);
for (int i = 0; i < m_polyNodes.ChildCount(); i++) {
PolyNode &node = *m_polyNodes.Childs[i];
m_srcPoly = node.Contour;
int len = (int)m_srcPoly.size();
if (len == 0 ||
(delta <= 0 && (len < 3 || node.m_endtype != etClosedPolygon)))
continue;
m_destPoly.clear();
if (len == 1) {
if (node.m_jointype == jtRound) {
double X = 1.0, Y = 0.0;
for (cInt j = 1; j <= steps; j++) {
m_destPoly.push_back(IntPoint(Round(m_srcPoly[0].X + X * delta),
Round(m_srcPoly[0].Y + Y * delta)));
double X2 = X;
X = X * m_cos - m_sin * Y;
Y = X2 * m_sin + Y * m_cos;
}
} else {
double X = -1.0, Y = -1.0;
for (int j = 0; j < 4; ++j) {
m_destPoly.push_back(IntPoint(Round(m_srcPoly[0].X + X * delta),
Round(m_srcPoly[0].Y + Y * delta)));
if (X < 0)
X = 1;
else if (Y < 0)
Y = 1;
else
X = -1;
}
}
m_destPolys.push_back(m_destPoly);
continue;
}
// build m_normals ...
m_normals.clear();
m_normals.reserve(len);
for (int j = 0; j < len - 1; ++j)
m_normals.push_back(GetUnitNormal(m_srcPoly[j], m_srcPoly[j + 1]));
if (node.m_endtype == etClosedLine || node.m_endtype == etClosedPolygon)
m_normals.push_back(GetUnitNormal(m_srcPoly[len - 1], m_srcPoly[0]));
else
m_normals.push_back(DoublePoint(m_normals[len - 2]));
if (node.m_endtype == etClosedPolygon) {
int k = len - 1;
for (int j = 0; j < len; ++j)
OffsetPoint(j, k, node.m_jointype);
m_destPolys.push_back(m_destPoly);
} else if (node.m_endtype == etClosedLine) {
int k = len - 1;
for (int j = 0; j < len; ++j)
OffsetPoint(j, k, node.m_jointype);
m_destPolys.push_back(m_destPoly);
m_destPoly.clear();
// re-build m_normals ...
DoublePoint n = m_normals[len - 1];
for (int j = len - 1; j > 0; j--)
m_normals[j] = DoublePoint(-m_normals[j - 1].X, -m_normals[j - 1].Y);
m_normals[0] = DoublePoint(-n.X, -n.Y);
k = 0;
for (int j = len - 1; j >= 0; j--)
OffsetPoint(j, k, node.m_jointype);
m_destPolys.push_back(m_destPoly);
} else {
int k = 0;
for (int j = 1; j < len - 1; ++j)
OffsetPoint(j, k, node.m_jointype);
IntPoint pt1;
if (node.m_endtype == etOpenButt) {
int j = len - 1;
pt1 = IntPoint((cInt)Round(m_srcPoly[j].X + m_normals[j].X * delta),
(cInt)Round(m_srcPoly[j].Y + m_normals[j].Y * delta));
m_destPoly.push_back(pt1);
pt1 = IntPoint((cInt)Round(m_srcPoly[j].X - m_normals[j].X * delta),
(cInt)Round(m_srcPoly[j].Y - m_normals[j].Y * delta));
m_destPoly.push_back(pt1);
} else {
int j = len - 1;
k = len - 2;
m_sinA = 0;
m_normals[j] = DoublePoint(-m_normals[j].X, -m_normals[j].Y);
if (node.m_endtype == etOpenSquare)
DoSquare(j, k);
else
DoRound(j, k);
}
// re-build m_normals ...
for (int j = len - 1; j > 0; j--)
m_normals[j] = DoublePoint(-m_normals[j - 1].X, -m_normals[j - 1].Y);
m_normals[0] = DoublePoint(-m_normals[1].X, -m_normals[1].Y);
k = len - 1;
for (int j = k - 1; j > 0; --j)
OffsetPoint(j, k, node.m_jointype);
if (node.m_endtype == etOpenButt) {
pt1 = IntPoint((cInt)Round(m_srcPoly[0].X - m_normals[0].X * delta),
(cInt)Round(m_srcPoly[0].Y - m_normals[0].Y * delta));
m_destPoly.push_back(pt1);
pt1 = IntPoint((cInt)Round(m_srcPoly[0].X + m_normals[0].X * delta),
(cInt)Round(m_srcPoly[0].Y + m_normals[0].Y * delta));
m_destPoly.push_back(pt1);
} else {
k = 1;
m_sinA = 0;
if (node.m_endtype == etOpenSquare)
DoSquare(0, 1);
else
DoRound(0, 1);
}
m_destPolys.push_back(m_destPoly);
}
}
}
//------------------------------------------------------------------------------
void ClipperOffset::OffsetPoint(int j, int &k, JoinType jointype) {
// cross product ...
m_sinA = (m_normals[k].X * m_normals[j].Y - m_normals[j].X * m_normals[k].Y);
if (std::fabs(m_sinA * m_delta) < 1.0) {
// dot product ...
double cosA =
(m_normals[k].X * m_normals[j].X + m_normals[j].Y * m_normals[k].Y);
if (cosA > 0) // angle => 0 degrees
{
m_destPoly.push_back(
IntPoint(Round(m_srcPoly[j].X + m_normals[k].X * m_delta),
Round(m_srcPoly[j].Y + m_normals[k].Y * m_delta)));
return;
}
// else angle => 180 degrees
} else if (m_sinA > 1.0)
m_sinA = 1.0;
else if (m_sinA < -1.0)
m_sinA = -1.0;
if (m_sinA * m_delta < 0) {
m_destPoly.push_back(
IntPoint(Round(m_srcPoly[j].X + m_normals[k].X * m_delta),
Round(m_srcPoly[j].Y + m_normals[k].Y * m_delta)));
m_destPoly.push_back(m_srcPoly[j]);
m_destPoly.push_back(
IntPoint(Round(m_srcPoly[j].X + m_normals[j].X * m_delta),
Round(m_srcPoly[j].Y + m_normals[j].Y * m_delta)));
} else
switch (jointype) {
case jtMiter: {
double r = 1 + (m_normals[j].X * m_normals[k].X +
m_normals[j].Y * m_normals[k].Y);
if (r >= m_miterLim)
DoMiter(j, k, r);
else
DoSquare(j, k);
break;
}
case jtSquare:
DoSquare(j, k);
break;
case jtRound:
DoRound(j, k);
break;
}
k = j;
}
//------------------------------------------------------------------------------
void ClipperOffset::DoSquare(int j, int k) {
double dx = std::tan(std::atan2(m_sinA, m_normals[k].X * m_normals[j].X +
m_normals[k].Y * m_normals[j].Y) /
4);
m_destPoly.push_back(IntPoint(
Round(m_srcPoly[j].X + m_delta * (m_normals[k].X - m_normals[k].Y * dx)),
Round(m_srcPoly[j].Y +
m_delta * (m_normals[k].Y + m_normals[k].X * dx))));
m_destPoly.push_back(IntPoint(
Round(m_srcPoly[j].X + m_delta * (m_normals[j].X + m_normals[j].Y * dx)),
Round(m_srcPoly[j].Y +
m_delta * (m_normals[j].Y - m_normals[j].X * dx))));
}
//------------------------------------------------------------------------------
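// nb: r is 1 + cos(angle between the two unit normals), so q = m_delta / r
// places the miter vertex delta / cos(angle/2) away from the corner along
// n_k + n_j - the standard miter construction.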
void ClipperOffset::DoMiter(int j, int k, double r) {
double q = m_delta / r;
m_destPoly.push_back(
IntPoint(Round(m_srcPoly[j].X + (m_normals[k].X + m_normals[j].X) * q),
Round(m_srcPoly[j].Y + (m_normals[k].Y + m_normals[j].Y) * q)));
}
//------------------------------------------------------------------------------
void ClipperOffset::DoRound(int j, int k) {
double a = std::atan2(m_sinA, m_normals[k].X * m_normals[j].X +
m_normals[k].Y * m_normals[j].Y);
int steps = std::max((int)Round(m_StepsPerRad * std::fabs(a)), 1);
double X = m_normals[k].X, Y = m_normals[k].Y, X2;
for (int i = 0; i < steps; ++i) {
m_destPoly.push_back(IntPoint(Round(m_srcPoly[j].X + X * m_delta),
Round(m_srcPoly[j].Y + Y * m_delta)));
X2 = X;
X = X * m_cos - m_sin * Y;
Y = X2 * m_sin + Y * m_cos;
}
m_destPoly.push_back(
IntPoint(Round(m_srcPoly[j].X + m_normals[j].X * m_delta),
Round(m_srcPoly[j].Y + m_normals[j].Y * m_delta)));
}
//------------------------------------------------------------------------------
// Miscellaneous public functions
//------------------------------------------------------------------------------
void Clipper::DoSimplePolygons() {
PolyOutList::size_type i = 0;
while (i < m_PolyOuts.size()) {
OutRec *outrec = m_PolyOuts[i++];
OutPt *op = outrec->Pts;
if (!op || outrec->IsOpen)
continue;
do // for each Pt in Polygon until duplicate found do ...
{
OutPt *op2 = op->Next;
while (op2 != outrec->Pts) {
if ((op->Pt == op2->Pt) && op2->Next != op && op2->Prev != op) {
// split the polygon into two ...
OutPt *op3 = op->Prev;
OutPt *op4 = op2->Prev;
op->Prev = op4;
op4->Next = op;
op2->Prev = op3;
op3->Next = op2;
outrec->Pts = op;
OutRec *outrec2 = CreateOutRec();
outrec2->Pts = op2;
UpdateOutPtIdxs(*outrec2);
if (Poly2ContainsPoly1(outrec2->Pts, outrec->Pts)) {
// OutRec2 is contained by OutRec1 ...
outrec2->IsHole = !outrec->IsHole;
outrec2->FirstLeft = outrec;
if (m_UsingPolyTree)
FixupFirstLefts2(outrec2, outrec);
} else if (Poly2ContainsPoly1(outrec->Pts, outrec2->Pts)) {
// OutRec1 is contained by OutRec2 ...
outrec2->IsHole = outrec->IsHole;
outrec->IsHole = !outrec2->IsHole;
outrec2->FirstLeft = outrec->FirstLeft;
outrec->FirstLeft = outrec2;
if (m_UsingPolyTree)
FixupFirstLefts2(outrec, outrec2);
} else {
// the 2 polygons are separate ...
outrec2->IsHole = outrec->IsHole;
outrec2->FirstLeft = outrec->FirstLeft;
if (m_UsingPolyTree)
FixupFirstLefts1(outrec, outrec2);
}
op2 = op; // ie get ready for the Next iteration
}
op2 = op2->Next;
}
op = op->Next;
} while (op != outrec->Pts);
}
}
//------------------------------------------------------------------------------
void ReversePath(Path &p) { std::reverse(p.begin(), p.end()); }
//------------------------------------------------------------------------------
void ReversePaths(Paths &p) {
for (Paths::size_type i = 0; i < p.size(); ++i)
ReversePath(p[i]);
}
//------------------------------------------------------------------------------
void SimplifyPolygon(const Path &in_poly, Paths &out_polys,
PolyFillType fillType) {
Clipper c;
c.StrictlySimple(true);
c.AddPath(in_poly, ptSubject, true);
c.Execute(ctUnion, out_polys, fillType, fillType);
}
//------------------------------------------------------------------------------
void SimplifyPolygons(const Paths &in_polys, Paths &out_polys,
PolyFillType fillType) {
Clipper c;
c.StrictlySimple(true);
c.AddPaths(in_polys, ptSubject, true);
c.Execute(ctUnion, out_polys, fillType, fillType);
}
//------------------------------------------------------------------------------
void SimplifyPolygons(Paths &polys, PolyFillType fillType) {
SimplifyPolygons(polys, polys, fillType);
}
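// Usage sketch (illustrative only): SimplifyPolygon removes self-intersections
// by unioning a path with itself, so a self-crossing "bowtie" quad splits into
// two simple triangles:
//   Path bowtie;
//   bowtie << IntPoint(0, 0) << IntPoint(10, 10) << IntPoint(10, 0)
//          << IntPoint(0, 10);
//   Paths out;
//   SimplifyPolygon(bowtie, out, pftEvenOdd); // out holds two triangles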
//------------------------------------------------------------------------------
inline double DistanceSqrd(const IntPoint &pt1, const IntPoint &pt2) {
double Dx = ((double)pt1.X - pt2.X);
double dy = ((double)pt1.Y - pt2.Y);
return (Dx * Dx + dy * dy);
}
//------------------------------------------------------------------------------
double DistanceFromLineSqrd(const IntPoint &pt, const IntPoint &ln1,
const IntPoint &ln2) {
// The equation of a line in general form (Ax + By + C = 0)
// given 2 points (x1,y1) & (x2,y2) is ...
// (y1 - y2)x + (x2 - x1)y + (y1 - y2)x1 - (x2 - x1)y1 = 0
// A = (y1 - y2); B = (x2 - x1); C = (y1 - y2)x1 - (x2 - x1)y1
// perpendicular distance of point (x3,y3) = (Ax3 + By3 + C)/Sqrt(A^2 + B^2)
// see http://en.wikipedia.org/wiki/Perpendicular_distance
double A = double(ln1.Y - ln2.Y);
double B = double(ln2.X - ln1.X);
double C = A * ln1.X + B * ln1.Y;
C = A * pt.X + B * pt.Y - C;
return (C * C) / (A * A + B * B);
}
//---------------------------------------------------------------------------
bool SlopesNearCollinear(const IntPoint &pt1, const IntPoint &pt2,
const IntPoint &pt3, double distSqrd) {
// this function is more accurate when the point that's geometrically
// between the other 2 points is the one that's tested for distance.
// ie makes it more likely to pick up 'spikes' ...
if (Abs(pt1.X - pt2.X) > Abs(pt1.Y - pt2.Y)) {
if ((pt1.X > pt2.X) == (pt1.X < pt3.X))
return DistanceFromLineSqrd(pt1, pt2, pt3) < distSqrd;
else if ((pt2.X > pt1.X) == (pt2.X < pt3.X))
return DistanceFromLineSqrd(pt2, pt1, pt3) < distSqrd;
else
return DistanceFromLineSqrd(pt3, pt1, pt2) < distSqrd;
} else {
if ((pt1.Y > pt2.Y) == (pt1.Y < pt3.Y))
return DistanceFromLineSqrd(pt1, pt2, pt3) < distSqrd;
else if ((pt2.Y > pt1.Y) == (pt2.Y < pt3.Y))
return DistanceFromLineSqrd(pt2, pt1, pt3) < distSqrd;
else
return DistanceFromLineSqrd(pt3, pt1, pt2) < distSqrd;
}
}
//------------------------------------------------------------------------------
bool PointsAreClose(IntPoint pt1, IntPoint pt2, double distSqrd) {
double Dx = (double)pt1.X - pt2.X;
double dy = (double)pt1.Y - pt2.Y;
return ((Dx * Dx) + (dy * dy) <= distSqrd);
}
//------------------------------------------------------------------------------
OutPt *ExcludeOp(OutPt *op) {
OutPt *result = op->Prev;
result->Next = op->Next;
op->Next->Prev = result;
result->Idx = 0;
return result;
}
//------------------------------------------------------------------------------
void CleanPolygon(const Path &in_poly, Path &out_poly, double distance) {
// distance = proximity in units/pixels below which vertices
// will be stripped. Default ~= sqrt(2).
size_t size = in_poly.size();
if (size == 0) {
out_poly.clear();
return;
}
OutPt *outPts = new OutPt[size];
for (size_t i = 0; i < size; ++i) {
outPts[i].Pt = in_poly[i];
outPts[i].Next = &outPts[(i + 1) % size];
outPts[i].Next->Prev = &outPts[i];
outPts[i].Idx = 0;
}
double distSqrd = distance * distance;
OutPt *op = &outPts[0];
while (op->Idx == 0 && op->Next != op->Prev) {
if (PointsAreClose(op->Pt, op->Prev->Pt, distSqrd)) {
op = ExcludeOp(op);
size--;
} else if (PointsAreClose(op->Prev->Pt, op->Next->Pt, distSqrd)) {
ExcludeOp(op->Next);
op = ExcludeOp(op);
size -= 2;
} else if (SlopesNearCollinear(op->Prev->Pt, op->Pt, op->Next->Pt,
distSqrd)) {
op = ExcludeOp(op);
size--;
} else {
op->Idx = 1;
op = op->Next;
}
}
if (size < 3)
size = 0;
out_poly.resize(size);
for (size_t i = 0; i < size; ++i) {
out_poly[i] = op->Pt;
op = op->Next;
}
delete[] outPts;
}
//------------------------------------------------------------------------------
void CleanPolygon(Path &poly, double distance) {
CleanPolygon(poly, poly, distance);
}
//------------------------------------------------------------------------------
void CleanPolygons(const Paths &in_polys, Paths &out_polys, double distance) {
out_polys.resize(in_polys.size());
for (Paths::size_type i = 0; i < in_polys.size(); ++i)
CleanPolygon(in_polys[i], out_polys[i], distance);
}
//------------------------------------------------------------------------------
void CleanPolygons(Paths &polys, double distance) {
CleanPolygons(polys, polys, distance);
}
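// Usage sketch (illustrative only): CleanPolygon with the default distance
// (~sqrt(2)) strips near-duplicate and near-collinear vertices:
//   Path noisy, tidy;
//   noisy << IntPoint(0, 0) << IntPoint(1, 0) << IntPoint(50, 0)
//         << IntPoint(50, 50) << IntPoint(0, 50);
//   CleanPolygon(noisy, tidy); // (1,0) is dropped; tidy is a clean square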
//------------------------------------------------------------------------------
void Minkowski(const Path &poly, const Path &path, Paths &solution, bool isSum,
bool isClosed) {
int delta = (isClosed ? 1 : 0);
size_t polyCnt = poly.size();
size_t pathCnt = path.size();
Paths pp;
pp.reserve(pathCnt);
if (isSum)
for (size_t i = 0; i < pathCnt; ++i) {
Path p;
p.reserve(polyCnt);
for (size_t j = 0; j < poly.size(); ++j)
p.push_back(IntPoint(path[i].X + poly[j].X, path[i].Y + poly[j].Y));
pp.push_back(p);
}
else
for (size_t i = 0; i < pathCnt; ++i) {
Path p;
p.reserve(polyCnt);
for (size_t j = 0; j < poly.size(); ++j)
p.push_back(IntPoint(path[i].X - poly[j].X, path[i].Y - poly[j].Y));
pp.push_back(p);
}
solution.clear();
solution.reserve((pathCnt + delta) * (polyCnt + 1));
for (size_t i = 0; i < pathCnt - 1 + delta; ++i)
for (size_t j = 0; j < polyCnt; ++j) {
Path quad;
quad.reserve(4);
quad.push_back(pp[i % pathCnt][j % polyCnt]);
quad.push_back(pp[(i + 1) % pathCnt][j % polyCnt]);
quad.push_back(pp[(i + 1) % pathCnt][(j + 1) % polyCnt]);
quad.push_back(pp[i % pathCnt][(j + 1) % polyCnt]);
if (!Orientation(quad))
ReversePath(quad);
solution.push_back(quad);
}
}
//------------------------------------------------------------------------------
void MinkowskiSum(const Path &pattern, const Path &path, Paths &solution,
bool pathIsClosed) {
Minkowski(pattern, path, solution, true, pathIsClosed);
Clipper c;
c.AddPaths(solution, ptSubject, true);
c.Execute(ctUnion, solution, pftNonZero, pftNonZero);
}
//------------------------------------------------------------------------------
void TranslatePath(const Path &input, Path &output, const IntPoint delta) {
// precondition: input != output
output.resize(input.size());
for (size_t i = 0; i < input.size(); ++i)
output[i] = IntPoint(input[i].X + delta.X, input[i].Y + delta.Y);
}
//------------------------------------------------------------------------------
void MinkowskiSum(const Path &pattern, const Paths &paths, Paths &solution,
bool pathIsClosed) {
Clipper c;
for (size_t i = 0; i < paths.size(); ++i) {
Paths tmp;
Minkowski(pattern, paths[i], tmp, true, pathIsClosed);
c.AddPaths(tmp, ptSubject, true);
if (pathIsClosed) {
Path tmp2;
TranslatePath(paths[i], tmp2, pattern[0]);
c.AddPath(tmp2, ptClip, true);
}
}
c.Execute(ctUnion, solution, pftNonZero, pftNonZero);
}
//------------------------------------------------------------------------------
void MinkowskiDiff(const Path &poly1, const Path &poly2, Paths &solution) {
Minkowski(poly1, poly2, solution, false, true);
Clipper c;
c.AddPaths(solution, ptSubject, true);
c.Execute(ctUnion, solution, pftNonZero, pftNonZero);
}
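// Usage sketch (illustrative only): MinkowskiSum sweeps a pattern along a
// path; sweeping a small diamond along an open polyline thickens it:
//   Path diamond, line;
//   diamond << IntPoint(-2, 0) << IntPoint(0, 2) << IntPoint(2, 0)
//           << IntPoint(0, -2);
//   line << IntPoint(0, 0) << IntPoint(100, 0) << IntPoint(100, 100);
//   Paths solution;
//   MinkowskiSum(diamond, line, solution, false); // false: path is open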
//------------------------------------------------------------------------------
enum NodeType { ntAny, ntOpen, ntClosed };
void AddPolyNodeToPaths(const PolyNode &polynode, NodeType nodetype,
Paths &paths) {
bool match = true;
if (nodetype == ntClosed)
match = !polynode.IsOpen();
else if (nodetype == ntOpen)
return;
if (!polynode.Contour.empty() && match)
paths.push_back(polynode.Contour);
for (int i = 0; i < polynode.ChildCount(); ++i)
AddPolyNodeToPaths(*polynode.Childs[i], nodetype, paths);
}
//------------------------------------------------------------------------------
void PolyTreeToPaths(const PolyTree &polytree, Paths &paths) {
paths.resize(0);
paths.reserve(polytree.Total());
AddPolyNodeToPaths(polytree, ntAny, paths);
}
//------------------------------------------------------------------------------
void ClosedPathsFromPolyTree(const PolyTree &polytree, Paths &paths) {
paths.resize(0);
paths.reserve(polytree.Total());
AddPolyNodeToPaths(polytree, ntClosed, paths);
}
//------------------------------------------------------------------------------
void OpenPathsFromPolyTree(PolyTree &polytree, Paths &paths) {
paths.resize(0);
paths.reserve(polytree.Total());
// Open paths are top level only, so ...
for (int i = 0; i < polytree.ChildCount(); ++i)
if (polytree.Childs[i]->IsOpen())
paths.push_back(polytree.Childs[i]->Contour);
}
//------------------------------------------------------------------------------
std::ostream &operator<<(std::ostream &s, const IntPoint &p) {
s << "(" << p.X << "," << p.Y << ")";
return s;
}
//------------------------------------------------------------------------------
std::ostream &operator<<(std::ostream &s, const Path &p) {
if (p.empty())
return s;
Path::size_type last = p.size() - 1;
for (Path::size_type i = 0; i < last; i++)
s << "(" << p[i].X << "," << p[i].Y << "), ";
s << "(" << p[last].X << "," << p[last].Y << ")\n";
return s;
}
//------------------------------------------------------------------------------
std::ostream &operator<<(std::ostream &s, const Paths &p) {
for (Paths::size_type i = 0; i < p.size(); i++)
s << p[i];
s << "\n";
return s;
}
//------------------------------------------------------------------------------
} // ClipperLib namespace
/*******************************************************************************
* *
* Author : Angus Johnson *
* Version : 6.4.2 *
* Date : 27 February 2017 *
* Website : http://www.angusj.com *
* Copyright : Angus Johnson 2010-2017 *
* *
* License: *
* Use, modification & distribution is subject to Boost Software License Ver 1. *
* http://www.boost.org/LICENSE_1_0.txt *
* *
* Attributions: *
* The code in this library is an extension of Bala Vatti's clipping algorithm: *
* "A generic solution to polygon clipping" *
* Communications of the ACM, Vol 35, Issue 7 (July 1992) pp 56-63. *
* http://portal.acm.org/citation.cfm?id=129906 *
* *
* Computer graphics and geometric modeling: implementation and algorithms *
* By Max K. Agoston *
* Springer; 1 edition (January 4, 2005) *
* http://books.google.com/books?q=vatti+clipping+agoston *
* *
* See also: *
* "Polygon Offsetting by Computing Winding Numbers" *
* Paper no. DETC2005-85513 pp. 565-575 *
* ASME 2005 International Design Engineering Technical Conferences *
* and Computers and Information in Engineering Conference (IDETC/CIE2005) *
* September 24-28, 2005 , Long Beach, California, USA *
* http://www.me.berkeley.edu/~mcmains/pubs/DAC05OffsetPolygon.pdf *
* *
*******************************************************************************/
#ifndef clipper_hpp
#define clipper_hpp
#define CLIPPER_VERSION "6.4.2"
// use_int32: When enabled 32bit ints are used instead of 64bit ints. This
// improves performance but coordinate values are limited to the range +/- 46340
//#define use_int32
// use_xyz: adds a Z member to IntPoint. Adds a minor cost to performance.
//#define use_xyz
// use_lines: Enables line clipping. Adds a very minor cost to performance.
#define use_lines
// use_deprecated: Enables temporary support for the obsolete functions
//#define use_deprecated
#include <cstdlib>
#include <cstring>
#include <functional>
#include <list>
#include <ostream>
#include <queue>
#include <set>
#include <stdexcept>
#include <vector>
namespace ClipperLib {
enum ClipType { ctIntersection, ctUnion, ctDifference, ctXor };
enum PolyType { ptSubject, ptClip };
// By far the most widely used winding rules for polygon filling are
// EvenOdd & NonZero (GDI, GDI+, XLib, OpenGL, Cairo, AGG, Quartz, SVG, Gr32)
// Other rules include Positive, Negative and ABS_GTR_EQ_TWO (only in OpenGL)
// see http://glprogramming.com/red/chapter11.html
enum PolyFillType { pftEvenOdd, pftNonZero, pftPositive, pftNegative };
#ifdef use_int32
typedef int cInt;
static cInt const loRange = 0x7FFF;
static cInt const hiRange = 0x7FFF;
#else
typedef signed long long cInt;
static cInt const loRange = 0x3FFFFFFF;
static cInt const hiRange = 0x3FFFFFFFFFFFFFFFLL;
typedef signed long long long64; // used by Int128 class
typedef unsigned long long ulong64;
#endif
struct IntPoint {
cInt X;
cInt Y;
#ifdef use_xyz
cInt Z;
IntPoint(cInt x = 0, cInt y = 0, cInt z = 0) : X(x), Y(y), Z(z){};
#else
IntPoint(cInt x = 0, cInt y = 0) : X(x), Y(y){};
#endif
friend inline bool operator==(const IntPoint &a, const IntPoint &b) {
return a.X == b.X && a.Y == b.Y;
}
friend inline bool operator!=(const IntPoint &a, const IntPoint &b) {
return a.X != b.X || a.Y != b.Y;
}
};
//------------------------------------------------------------------------------
typedef std::vector<IntPoint> Path;
typedef std::vector<Path> Paths;
inline Path &operator<<(Path &poly, const IntPoint &p) {
poly.push_back(p);
return poly;
}
inline Paths &operator<<(Paths &polys, const Path &p) {
polys.push_back(p);
return polys;
}
std::ostream &operator<<(std::ostream &s, const IntPoint &p);
std::ostream &operator<<(std::ostream &s, const Path &p);
std::ostream &operator<<(std::ostream &s, const Paths &p);
struct DoublePoint {
double X;
double Y;
DoublePoint(double x = 0, double y = 0) : X(x), Y(y) {}
DoublePoint(IntPoint ip) : X((double)ip.X), Y((double)ip.Y) {}
};
//------------------------------------------------------------------------------
#ifdef use_xyz
typedef void (*ZFillCallback)(IntPoint &e1bot, IntPoint &e1top, IntPoint &e2bot,
IntPoint &e2top, IntPoint &pt);
#endif
enum InitOptions {
ioReverseSolution = 1,
ioStrictlySimple = 2,
ioPreserveCollinear = 4
};
enum JoinType { jtSquare, jtRound, jtMiter };
enum EndType {
etClosedPolygon,
etClosedLine,
etOpenButt,
etOpenSquare,
etOpenRound
};
class PolyNode;
typedef std::vector<PolyNode *> PolyNodes;
class PolyNode {
public:
PolyNode();
virtual ~PolyNode(){};
Path Contour;
PolyNodes Childs;
PolyNode *Parent;
PolyNode *GetNext() const;
bool IsHole() const;
bool IsOpen() const;
int ChildCount() const;
private:
// PolyNode& operator =(PolyNode& other);
unsigned Index; // node index in Parent.Childs
bool m_IsOpen;
JoinType m_jointype;
EndType m_endtype;
PolyNode *GetNextSiblingUp() const;
void AddChild(PolyNode &child);
friend class Clipper; // to access Index
friend class ClipperOffset;
};
class PolyTree : public PolyNode {
public:
~PolyTree() { Clear(); };
PolyNode *GetFirst() const;
void Clear();
int Total() const;
private:
// PolyTree& operator =(PolyTree& other);
PolyNodes AllNodes;
friend class Clipper; // to access AllNodes
};
bool Orientation(const Path &poly);
double Area(const Path &poly);
int PointInPolygon(const IntPoint &pt, const Path &path);
void SimplifyPolygon(const Path &in_poly, Paths &out_polys,
PolyFillType fillType = pftEvenOdd);
void SimplifyPolygons(const Paths &in_polys, Paths &out_polys,
PolyFillType fillType = pftEvenOdd);
void SimplifyPolygons(Paths &polys, PolyFillType fillType = pftEvenOdd);
void CleanPolygon(const Path &in_poly, Path &out_poly, double distance = 1.415);
void CleanPolygon(Path &poly, double distance = 1.415);
void CleanPolygons(const Paths &in_polys, Paths &out_polys,
double distance = 1.415);
void CleanPolygons(Paths &polys, double distance = 1.415);
void MinkowskiSum(const Path &pattern, const Path &path, Paths &solution,
bool pathIsClosed);
void MinkowskiSum(const Path &pattern, const Paths &paths, Paths &solution,
bool pathIsClosed);
void MinkowskiDiff(const Path &poly1, const Path &poly2, Paths &solution);
void PolyTreeToPaths(const PolyTree &polytree, Paths &paths);
void ClosedPathsFromPolyTree(const PolyTree &polytree, Paths &paths);
void OpenPathsFromPolyTree(PolyTree &polytree, Paths &paths);
void ReversePath(Path &p);
void ReversePaths(Paths &p);
struct IntRect {
cInt left;
cInt top;
cInt right;
cInt bottom;
};
// enums that are used internally ...
enum EdgeSide { esLeft = 1, esRight = 2 };
// forward declarations (for stuff used internally) ...
struct TEdge;
struct IntersectNode;
struct LocalMinimum;
struct OutPt;
struct OutRec;
struct Join;
typedef std::vector<OutRec *> PolyOutList;
typedef std::vector<TEdge *> EdgeList;
typedef std::vector<Join *> JoinList;
typedef std::vector<IntersectNode *> IntersectList;
//------------------------------------------------------------------------------
// ClipperBase is the ancestor to the Clipper class. It should not be
// instantiated directly. This class simply abstracts the conversion of sets of
// polygon coordinates into edge objects that are stored in a LocalMinima list.
class ClipperBase {
public:
ClipperBase();
virtual ~ClipperBase();
virtual bool AddPath(const Path &pg, PolyType PolyTyp, bool Closed);
bool AddPaths(const Paths &ppg, PolyType PolyTyp, bool Closed);
virtual void Clear();
IntRect GetBounds();
bool PreserveCollinear() { return m_PreserveCollinear; };
void PreserveCollinear(bool value) { m_PreserveCollinear = value; };
protected:
void DisposeLocalMinimaList();
TEdge *AddBoundsToLML(TEdge *e, bool IsClosed);
virtual void Reset();
TEdge *ProcessBound(TEdge *E, bool IsClockwise);
void InsertScanbeam(const cInt Y);
bool PopScanbeam(cInt &Y);
bool LocalMinimaPending();
bool PopLocalMinima(cInt Y, const LocalMinimum *&locMin);
OutRec *CreateOutRec();
void DisposeAllOutRecs();
void DisposeOutRec(PolyOutList::size_type index);
void SwapPositionsInAEL(TEdge *edge1, TEdge *edge2);
void DeleteFromAEL(TEdge *e);
void UpdateEdgeIntoAEL(TEdge *&e);
typedef std::vector<LocalMinimum> MinimaList;
MinimaList::iterator m_CurrentLM;
MinimaList m_MinimaList;
bool m_UseFullRange;
EdgeList m_edges;
bool m_PreserveCollinear;
bool m_HasOpenPaths;
PolyOutList m_PolyOuts;
TEdge *m_ActiveEdges;
typedef std::priority_queue<cInt> ScanbeamList;
ScanbeamList m_Scanbeam;
};
//------------------------------------------------------------------------------
class Clipper : public virtual ClipperBase {
public:
Clipper(int initOptions = 0);
bool Execute(ClipType clipType, Paths &solution,
PolyFillType fillType = pftEvenOdd);
bool Execute(ClipType clipType, Paths &solution, PolyFillType subjFillType,
PolyFillType clipFillType);
bool Execute(ClipType clipType, PolyTree &polytree,
PolyFillType fillType = pftEvenOdd);
bool Execute(ClipType clipType, PolyTree &polytree, PolyFillType subjFillType,
PolyFillType clipFillType);
bool ReverseSolution() { return m_ReverseOutput; };
void ReverseSolution(bool value) { m_ReverseOutput = value; };
bool StrictlySimple() { return m_StrictSimple; };
void StrictlySimple(bool value) { m_StrictSimple = value; };
// set the callback function for z value filling on intersections (otherwise Z
// is 0)
#ifdef use_xyz
void ZFillFunction(ZFillCallback zFillFunc);
#endif
protected:
virtual bool ExecuteInternal();
private:
JoinList m_Joins;
JoinList m_GhostJoins;
IntersectList m_IntersectList;
ClipType m_ClipType;
typedef std::list<cInt> MaximaList;
MaximaList m_Maxima;
TEdge *m_SortedEdges;
bool m_ExecuteLocked;
PolyFillType m_ClipFillType;
PolyFillType m_SubjFillType;
bool m_ReverseOutput;
bool m_UsingPolyTree;
bool m_StrictSimple;
#ifdef use_xyz
ZFillCallback m_ZFill; // custom callback
#endif
void SetWindingCount(TEdge &edge);
bool IsEvenOddFillType(const TEdge &edge) const;
bool IsEvenOddAltFillType(const TEdge &edge) const;
void InsertLocalMinimaIntoAEL(const cInt botY);
void InsertEdgeIntoAEL(TEdge *edge, TEdge *startEdge);
void AddEdgeToSEL(TEdge *edge);
bool PopEdgeFromSEL(TEdge *&edge);
void CopyAELToSEL();
void DeleteFromSEL(TEdge *e);
void SwapPositionsInSEL(TEdge *edge1, TEdge *edge2);
bool IsContributing(const TEdge &edge) const;
bool IsTopHorz(const cInt XPos);
void DoMaxima(TEdge *e);
void ProcessHorizontals();
void ProcessHorizontal(TEdge *horzEdge);
void AddLocalMaxPoly(TEdge *e1, TEdge *e2, const IntPoint &pt);
OutPt *AddLocalMinPoly(TEdge *e1, TEdge *e2, const IntPoint &pt);
OutRec *GetOutRec(int idx);
void AppendPolygon(TEdge *e1, TEdge *e2);
void IntersectEdges(TEdge *e1, TEdge *e2, IntPoint &pt);
OutPt *AddOutPt(TEdge *e, const IntPoint &pt);
OutPt *GetLastOutPt(TEdge *e);
bool ProcessIntersections(const cInt topY);
void BuildIntersectList(const cInt topY);
void ProcessIntersectList();
void ProcessEdgesAtTopOfScanbeam(const cInt topY);
void BuildResult(Paths &polys);
void BuildResult2(PolyTree &polytree);
void SetHoleState(TEdge *e, OutRec *outrec);
void DisposeIntersectNodes();
bool FixupIntersectionOrder();
void FixupOutPolygon(OutRec &outrec);
void FixupOutPolyline(OutRec &outrec);
bool IsHole(TEdge *e);
bool FindOwnerFromSplitRecs(OutRec &outRec, OutRec *&currOrfl);
void FixHoleLinkage(OutRec &outrec);
void AddJoin(OutPt *op1, OutPt *op2, const IntPoint offPt);
void ClearJoins();
void ClearGhostJoins();
void AddGhostJoin(OutPt *op, const IntPoint offPt);
bool JoinPoints(Join *j, OutRec *outRec1, OutRec *outRec2);
void JoinCommonEdges();
void DoSimplePolygons();
void FixupFirstLefts1(OutRec *OldOutRec, OutRec *NewOutRec);
void FixupFirstLefts2(OutRec *InnerOutRec, OutRec *OuterOutRec);
void FixupFirstLefts3(OutRec *OldOutRec, OutRec *NewOutRec);
#ifdef use_xyz
void SetZ(IntPoint &pt, TEdge &e1, TEdge &e2);
#endif
};
//------------------------------------------------------------------------------
class ClipperOffset {
public:
ClipperOffset(double miterLimit = 2.0, double roundPrecision = 0.25);
~ClipperOffset();
void AddPath(const Path &path, JoinType joinType, EndType endType);
void AddPaths(const Paths &paths, JoinType joinType, EndType endType);
void Execute(Paths &solution, double delta);
void Execute(PolyTree &solution, double delta);
void Clear();
double MiterLimit;
double ArcTolerance;
private:
Paths m_destPolys;
Path m_srcPoly;
Path m_destPoly;
std::vector<DoublePoint> m_normals;
double m_delta, m_sinA, m_sin, m_cos;
double m_miterLim, m_StepsPerRad;
IntPoint m_lowest;
PolyNode m_polyNodes;
void FixOrientations();
void DoOffset(double delta);
void OffsetPoint(int j, int &k, JoinType jointype);
void DoSquare(int j, int k);
void DoMiter(int j, int k, double r);
void DoRound(int j, int k);
};
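// Usage sketch (illustrative only): grow a closed square outward by 5 units
// with rounded joins, keeping the constructor defaults for MiterLimit and
// ArcTolerance:
//   ClipperOffset co;
//   Path sq;
//   sq << IntPoint(0, 0) << IntPoint(100, 0) << IntPoint(100, 100)
//      << IntPoint(0, 100);
//   co.AddPath(sq, jtRound, etClosedPolygon);
//   Paths grown;
//   co.Execute(grown, 5.0);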
//------------------------------------------------------------------------------
class clipperException : public std::exception {
public:
clipperException(const char *description) : m_descr(description) {}
virtual ~clipperException() throw() {}
virtual const char *what() const throw() { return m_descr.c_str(); }
private:
std::string m_descr;
};
//------------------------------------------------------------------------------
} // ClipperLib namespace
#endif // clipper_hpp
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "ocr_cls_process.h"
#include <cmath>
#include <cstring>
#include <fstream>
#include <iostream>
#include <vector>
const std::vector<int> CLS_IMAGE_SHAPE = {3, 48, 192};
cv::Mat cls_resize_img(const cv::Mat &img) {
int imgC = CLS_IMAGE_SHAPE[0];
int imgW = CLS_IMAGE_SHAPE[2];
int imgH = CLS_IMAGE_SHAPE[1];
float ratio = float(img.cols) / float(img.rows);
int resize_w = 0;
if (ceilf(imgH * ratio) > imgW)
resize_w = imgW;
else
resize_w = int(ceilf(imgH * ratio));
cv::Mat resize_img;
cv::resize(img, resize_img, cv::Size(resize_w, imgH), 0.f, 0.f,
cv::INTER_CUBIC);
if (resize_w < imgW) {
cv::copyMakeBorder(resize_img, resize_img, 0, 0, 0, int(imgW - resize_w),
cv::BORDER_CONSTANT, {0, 0, 0});
}
return resize_img;
}
\ No newline at end of file
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "common.h"
#include <opencv2/opencv.hpp>
#include <vector>
extern const std::vector<int> CLS_IMAGE_SHAPE;
cv::Mat cls_resize_img(const cv::Mat &img);
\ No newline at end of file
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "ocr_crnn_process.h"
#include <cmath>
#include <cstring>
#include <fstream>
#include <iostream>
#include <vector>
const std::string CHARACTER_TYPE = "ch";
const int MAX_DICT_LENGTH = 6624;
const std::vector<int> REC_IMAGE_SHAPE = {3, 32, 320};
static cv::Mat crnn_resize_norm_img(cv::Mat img, float wh_ratio) {
int imgC = REC_IMAGE_SHAPE[0];
int imgW = REC_IMAGE_SHAPE[2];
int imgH = REC_IMAGE_SHAPE[1];
if (CHARACTER_TYPE == "ch")
imgW = int(32 * wh_ratio);
float ratio = float(img.cols) / float(img.rows);
int resize_w = 0;
if (ceilf(imgH * ratio) > imgW)
resize_w = imgW;
else
resize_w = int(ceilf(imgH * ratio));
cv::Mat resize_img;
cv::resize(img, resize_img, cv::Size(resize_w, imgH), 0.f, 0.f,
cv::INTER_CUBIC);
resize_img.convertTo(resize_img, CV_32FC3, 1 / 255.f);
for (int h = 0; h < resize_img.rows; h++) {
for (int w = 0; w < resize_img.cols; w++) {
resize_img.at<cv::Vec3f>(h, w)[0] =
(resize_img.at<cv::Vec3f>(h, w)[0] - 0.5) * 2;
resize_img.at<cv::Vec3f>(h, w)[1] =
(resize_img.at<cv::Vec3f>(h, w)[1] - 0.5) * 2;
resize_img.at<cv::Vec3f>(h, w)[2] =
(resize_img.at<cv::Vec3f>(h, w)[2] - 0.5) * 2;
}
}
cv::Mat dist;
cv::copyMakeBorder(resize_img, dist, 0, 0, 0, int(imgW - resize_w),
cv::BORDER_CONSTANT, {0, 0, 0});
return dist;
}
cv::Mat crnn_resize_img(const cv::Mat &img, float wh_ratio) {
int imgC = REC_IMAGE_SHAPE[0];
int imgW = REC_IMAGE_SHAPE[2];
int imgH = REC_IMAGE_SHAPE[1];
if (CHARACTER_TYPE == "ch") {
imgW = int(32 * wh_ratio);
}
float ratio = float(img.cols) / float(img.rows);
int resize_w = 0;
if (ceilf(imgH * ratio) > imgW)
resize_w = imgW;
else
resize_w = int(ceilf(imgH * ratio));
cv::Mat resize_img;
cv::resize(img, resize_img, cv::Size(resize_w, imgH));
return resize_img;
}
cv::Mat get_rotate_crop_image(const cv::Mat &srcimage,
const std::vector<std::vector<int>> &box) {
std::vector<std::vector<int>> points = box;
int x_collect[4] = {box[0][0], box[1][0], box[2][0], box[3][0]};
int y_collect[4] = {box[0][1], box[1][1], box[2][1], box[3][1]};
int left = int(*std::min_element(x_collect, x_collect + 4));
int right = int(*std::max_element(x_collect, x_collect + 4));
int top = int(*std::min_element(y_collect, y_collect + 4));
int bottom = int(*std::max_element(y_collect, y_collect + 4));
cv::Mat img_crop;
srcimage(cv::Rect(left, top, right - left, bottom - top)).copyTo(img_crop);
for (int i = 0; i < points.size(); i++) {
points[i][0] -= left;
points[i][1] -= top;
}
int img_crop_width = int(sqrt(pow(points[0][0] - points[1][0], 2) +
pow(points[0][1] - points[1][1], 2)));
int img_crop_height = int(sqrt(pow(points[0][0] - points[3][0], 2) +
pow(points[0][1] - points[3][1], 2)));
cv::Point2f pts_std[4];
pts_std[0] = cv::Point2f(0., 0.);
pts_std[1] = cv::Point2f(img_crop_width, 0.);
pts_std[2] = cv::Point2f(img_crop_width, img_crop_height);
pts_std[3] = cv::Point2f(0.f, img_crop_height);
cv::Point2f pointsf[4];
pointsf[0] = cv::Point2f(points[0][0], points[0][1]);
pointsf[1] = cv::Point2f(points[1][0], points[1][1]);
pointsf[2] = cv::Point2f(points[2][0], points[2][1]);
pointsf[3] = cv::Point2f(points[3][0], points[3][1]);
cv::Mat M = cv::getPerspectiveTransform(pointsf, pts_std);
cv::Mat dst_img;
cv::warpPerspective(img_crop, dst_img, M,
cv::Size(img_crop_width, img_crop_height),
cv::BORDER_REPLICATE);
if (float(dst_img.rows) >= float(dst_img.cols) * 1.5) {
/*
cv::Mat srcCopy = cv::Mat(dst_img.rows, dst_img.cols, dst_img.depth());
cv::transpose(dst_img, srcCopy);
cv::flip(srcCopy, srcCopy, 0);
return srcCopy;
*/
cv::transpose(dst_img, dst_img);
cv::flip(dst_img, dst_img, 0);
return dst_img;
} else {
return dst_img;
}
}
//
// Created by fujiayi on 2020/7/3.
//
#pragma once
#include "common.h"
#include <opencv2/opencv.hpp>
#include <vector>
extern const std::vector<int> REC_IMAGE_SHAPE;
cv::Mat get_rotate_crop_image(const cv::Mat &srcimage,
const std::vector<std::vector<int>> &box);
cv::Mat crnn_resize_img(const cv::Mat &img, float wh_ratio);
template <class ForwardIterator>
inline size_t argmax(ForwardIterator first, ForwardIterator last) {
return std::distance(first, std::max_element(first, last));
}
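// Usage sketch (illustrative only): argmax returns the index of the largest
// element, as used for picking the best class per time step in CTC decoding:
//   float row[] = {0.1f, 0.7f, 0.2f};
//   size_t best = argmax(row, row + 3); // best == 1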
\ No newline at end of file
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "ocr_clipper.hpp"
#include "opencv2/core.hpp"
#include "opencv2/imgcodecs.hpp"
#include "opencv2/imgproc.hpp"
#include <iostream>
#include <math.h>
#include <vector>
static void getcontourarea(float **box, float unclip_ratio, float &distance) {
int pts_num = 4;
float area = 0.0f;
float dist = 0.0f;
for (int i = 0; i < pts_num; i++) {
area += box[i][0] * box[(i + 1) % pts_num][1] -
box[i][1] * box[(i + 1) % pts_num][0];
dist += sqrtf((box[i][0] - box[(i + 1) % pts_num][0]) *
(box[i][0] - box[(i + 1) % pts_num][0]) +
(box[i][1] - box[(i + 1) % pts_num][1]) *
(box[i][1] - box[(i + 1) % pts_num][1]));
}
area = fabs(float(area / 2.0));
distance = area * unclip_ratio / dist;
}
static cv::RotatedRect unclip(float **box) {
float unclip_ratio = 2.0;
float distance = 1.0;
getcontourarea(box, unclip_ratio, distance);
ClipperLib::ClipperOffset offset;
ClipperLib::Path p;
p << ClipperLib::IntPoint(int(box[0][0]), int(box[0][1]))
<< ClipperLib::IntPoint(int(box[1][0]), int(box[1][1]))
<< ClipperLib::IntPoint(int(box[2][0]), int(box[2][1]))
<< ClipperLib::IntPoint(int(box[3][0]), int(box[3][1]));
offset.AddPath(p, ClipperLib::jtRound, ClipperLib::etClosedPolygon);
ClipperLib::Paths soln;
offset.Execute(soln, distance);
std::vector<cv::Point2f> points;
for (int j = 0; j < soln.size(); j++) {
for (int i = 0; i < soln[j].size(); i++) {
points.emplace_back(soln[j][i].X, soln[j][i].Y);
}
}
cv::RotatedRect res = cv::minAreaRect(points);
return res;
}
static float **Mat2Vec(cv::Mat mat) {
auto **array = new float *[mat.rows];
for (int i = 0; i < mat.rows; ++i) {
array[i] = new float[mat.cols];
}
for (int i = 0; i < mat.rows; ++i) {
for (int j = 0; j < mat.cols; ++j) {
array[i][j] = mat.at<float>(i, j);
}
}
return array;
}
static void quickSort(float **s, int l, int r) {
if (l < r) {
int i = l, j = r;
float x = s[l][0];
float *xp = s[l];
while (i < j) {
while (i < j && s[j][0] >= x) {
j--;
}
if (i < j) {
std::swap(s[i++], s[j]);
}
while (i < j && s[i][0] < x) {
i++;
}
if (i < j) {
std::swap(s[j--], s[i]);
}
}
s[i] = xp;
quickSort(s, l, i - 1);
quickSort(s, i + 1, r);
}
}
static void quickSort_vector(std::vector<std::vector<int>> &box, int l, int r,
int axis) {
if (l < r) {
int i = l, j = r;
int x = box[l][axis];
std::vector<int> xp(box[l]);
while (i < j) {
while (i < j && box[j][axis] >= x) {
j--;
}
if (i < j) {
std::swap(box[i++], box[j]);
}
while (i < j && box[i][axis] < x) {
i++;
}
if (i < j) {
std::swap(box[j--], box[i]);
}
}
box[i] = xp;
quickSort_vector(box, l, i - 1, axis);
quickSort_vector(box, i + 1, r, axis);
}
}
static std::vector<std::vector<int>>
order_points_clockwise(std::vector<std::vector<int>> pts) {
std::vector<std::vector<int>> box = pts;
quickSort_vector(box, 0, int(box.size() - 1), 0);
std::vector<std::vector<int>> leftmost = {box[0], box[1]};
std::vector<std::vector<int>> rightmost = {box[2], box[3]};
if (leftmost[0][1] > leftmost[1][1]) {
std::swap(leftmost[0], leftmost[1]);
}
if (rightmost[0][1] > rightmost[1][1]) {
std::swap(rightmost[0], rightmost[1]);
}
std::vector<std::vector<int>> rect = {leftmost[0], rightmost[0], rightmost[1],
leftmost[1]};
return rect;
}
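// Usage sketch (illustrative only): corners supplied in arbitrary order come
// back ordered top-left, top-right, bottom-right, bottom-left:
//   std::vector<std::vector<int>> pts = {{10, 10}, {0, 0}, {0, 10}, {10, 0}};
//   auto rect = order_points_clockwise(pts);
//   // rect == {{0, 0}, {10, 0}, {10, 10}, {0, 10}}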
static float **get_mini_boxes(cv::RotatedRect box, float &ssid) {
ssid = box.size.width >= box.size.height ? box.size.height : box.size.width;
cv::Mat points;
cv::boxPoints(box, points);
// sorted box points
auto array = Mat2Vec(points);
quickSort(array, 0, 3);
float *idx1 = array[0], *idx2 = array[1], *idx3 = array[2], *idx4 = array[3];
if (array[3][1] <= array[2][1]) {
idx2 = array[3];
idx3 = array[2];
} else {
idx2 = array[2];
idx3 = array[3];
}
if (array[1][1] <= array[0][1]) {
idx1 = array[1];
idx4 = array[0];
} else {
idx1 = array[0];
idx4 = array[1];
}
array[0] = idx1;
array[1] = idx2;
array[2] = idx3;
array[3] = idx4;
return array;
}
template <class T> T clamp(T x, T min, T max) {
if (x > max) {
return max;
}
if (x < min) {
return min;
}
return x;
}
static float clampf(float x, float min, float max) {
if (x > max)
return max;
if (x < min)
return min;
return x;
}
float box_score_fast(float **box_array, cv::Mat pred) {
auto array = box_array;
int width = pred.cols;
int height = pred.rows;
float box_x[4] = {array[0][0], array[1][0], array[2][0], array[3][0]};
float box_y[4] = {array[0][1], array[1][1], array[2][1], array[3][1]};
int xmin = clamp(int(std::floor(*(std::min_element(box_x, box_x + 4)))), 0,
width - 1);
int xmax = clamp(int(std::ceil(*(std::max_element(box_x, box_x + 4)))), 0,
width - 1);
int ymin = clamp(int(std::floor(*(std::min_element(box_y, box_y + 4)))), 0,
height - 1);
int ymax = clamp(int(std::ceil(*(std::max_element(box_y, box_y + 4)))), 0,
height - 1);
cv::Mat mask;
mask = cv::Mat::zeros(ymax - ymin + 1, xmax - xmin + 1, CV_8UC1);
cv::Point root_point[4];
root_point[0] = cv::Point(int(array[0][0]) - xmin, int(array[0][1]) - ymin);
root_point[1] = cv::Point(int(array[1][0]) - xmin, int(array[1][1]) - ymin);
root_point[2] = cv::Point(int(array[2][0]) - xmin, int(array[2][1]) - ymin);
root_point[3] = cv::Point(int(array[3][0]) - xmin, int(array[3][1]) - ymin);
const cv::Point *ppt[1] = {root_point};
int npt[] = {4};
cv::fillPoly(mask, ppt, npt, 1, cv::Scalar(1));
cv::Mat croppedImg;
pred(cv::Rect(xmin, ymin, xmax - xmin + 1, ymax - ymin + 1))
.copyTo(croppedImg);
auto score = cv::mean(croppedImg, mask)[0];
return score;
}
std::vector<std::vector<std::vector<int>>>
boxes_from_bitmap(const cv::Mat &pred, const cv::Mat &bitmap) {
const int min_size = 3;
const int max_candidates = 1000;
const float box_thresh = 0.5;
int width = bitmap.cols;
int height = bitmap.rows;
std::vector<std::vector<cv::Point>> contours;
std::vector<cv::Vec4i> hierarchy;
cv::findContours(bitmap, contours, hierarchy, cv::RETR_LIST,
cv::CHAIN_APPROX_SIMPLE);
int num_contours =
contours.size() >= max_candidates ? max_candidates : contours.size();
std::vector<std::vector<std::vector<int>>> boxes;
for (int _i = 0; _i < num_contours; _i++) {
float ssid;
cv::RotatedRect box = cv::minAreaRect(contours[_i]);
auto array = get_mini_boxes(box, ssid);
auto box_for_unclip = array;
// end get_mini_box
if (ssid < min_size) {
continue;
}
float score;
score = box_score_fast(array, pred);
// end box_score_fast
if (score < box_thresh) {
continue;
}
// start for unclip
cv::RotatedRect points = unclip(box_for_unclip);
// end for unclip
cv::RotatedRect clipbox = points;
auto cliparray = get_mini_boxes(clipbox, ssid);
if (ssid < min_size + 2)
continue;
int dest_width = pred.cols;
int dest_height = pred.rows;
std::vector<std::vector<int>> intcliparray;
for (int num_pt = 0; num_pt < 4; num_pt++) {
std::vector<int> a{int(clampf(roundf(cliparray[num_pt][0] / float(width) *
float(dest_width)),
0, float(dest_width))),
int(clampf(roundf(cliparray[num_pt][1] /
float(height) * float(dest_height)),
0, float(dest_height)))};
intcliparray.emplace_back(std::move(a));
}
boxes.emplace_back(std::move(intcliparray));
} // end for
return boxes;
}
int _max(int a, int b) { return a >= b ? a : b; }
int _min(int a, int b) { return a >= b ? b : a; }
std::vector<std::vector<std::vector<int>>>
filter_tag_det_res(const std::vector<std::vector<std::vector<int>>> &o_boxes,
float ratio_h, float ratio_w, const cv::Mat &srcimg) {
int oriimg_h = srcimg.rows;
int oriimg_w = srcimg.cols;
std::vector<std::vector<std::vector<int>>> boxes{o_boxes};
std::vector<std::vector<std::vector<int>>> root_points;
for (int n = 0; n < boxes.size(); n++) {
boxes[n] = order_points_clockwise(boxes[n]);
for (int m = 0; m < boxes[0].size(); m++) {
boxes[n][m][0] /= ratio_w;
boxes[n][m][1] /= ratio_h;
boxes[n][m][0] = int(_min(_max(boxes[n][m][0], 0), oriimg_w - 1));
boxes[n][m][1] = int(_min(_max(boxes[n][m][1], 0), oriimg_h - 1));
}
}
for (int n = 0; n < boxes.size(); n++) {
int rect_width, rect_height;
rect_width = int(sqrt(pow(boxes[n][0][0] - boxes[n][1][0], 2) +
pow(boxes[n][0][1] - boxes[n][1][1], 2)));
rect_height = int(sqrt(pow(boxes[n][0][0] - boxes[n][3][0], 2) +
pow(boxes[n][0][1] - boxes[n][3][1], 2)));
if (rect_width <= 10 || rect_height <= 10)
continue;
root_points.push_back(boxes[n]);
}
return root_points;
}
\ No newline at end of file
//
// Created by fujiayi on 2020/7/2.
//
#pragma once
#include <opencv2/opencv.hpp>
#include <vector>
std::vector<std::vector<std::vector<int>>>
boxes_from_bitmap(const cv::Mat &pred, const cv::Mat &bitmap);
std::vector<std::vector<std::vector<int>>>
filter_tag_det_res(const std::vector<std::vector<std::vector<int>>> &o_boxes,
float ratio_h, float ratio_w, const cv::Mat &srcimg);
\ No newline at end of file
//
// Created by fujiayi on 2020/7/1.
//
#include "ocr_ppredictor.h"
#include "common.h"
#include "ocr_cls_process.h"
#include "ocr_crnn_process.h"
#include "ocr_db_post_process.h"
#include "preprocess.h"
namespace ppredictor {
OCR_PPredictor::OCR_PPredictor(const OCR_Config &config) : _config(config) {}
int OCR_PPredictor::init(const std::string &det_model_content,
const std::string &rec_model_content,
const std::string &cls_model_content) {
_det_predictor = std::unique_ptr<PPredictor>(
new PPredictor{_config.thread_num, NET_OCR, _config.mode});
_det_predictor->init_nb(det_model_content);
_rec_predictor = std::unique_ptr<PPredictor>(
new PPredictor{_config.thread_num, NET_OCR_INTERNAL, _config.mode});
_rec_predictor->init_nb(rec_model_content);
_cls_predictor = std::unique_ptr<PPredictor>(
new PPredictor{_config.thread_num, NET_OCR_INTERNAL, _config.mode});
_cls_predictor->init_nb(cls_model_content);
return RETURN_OK;
}
int OCR_PPredictor::init_from_file(const std::string &det_model_path,
const std::string &rec_model_path,
const std::string &cls_model_path) {
_det_predictor = std::unique_ptr<PPredictor>(
new PPredictor{_config.thread_num, NET_OCR, _config.mode});
_det_predictor->init_from_file(det_model_path);
_rec_predictor = std::unique_ptr<PPredictor>(
new PPredictor{_config.thread_num, NET_OCR_INTERNAL, _config.mode});
_rec_predictor->init_from_file(rec_model_path);
_cls_predictor = std::unique_ptr<PPredictor>(
new PPredictor{_config.thread_num, NET_OCR_INTERNAL, _config.mode});
_cls_predictor->init_from_file(cls_model_path);
return RETURN_OK;
}
/**
* For debugging: visualize the boxes produced by the first (detection) step
* @param filter_boxes
* @param boxes
* @param srcimg
*/
static void
visual_img(const std::vector<std::vector<std::vector<int>>> &filter_boxes,
const std::vector<std::vector<std::vector<int>>> &boxes,
const cv::Mat &srcimg) {
// visualization
cv::Point rook_points[filter_boxes.size()][4];
for (int n = 0; n < filter_boxes.size(); n++) {
for (int m = 0; m < filter_boxes[0].size(); m++) {
rook_points[n][m] =
cv::Point(int(filter_boxes[n][m][0]), int(filter_boxes[n][m][1]));
}
}
cv::Mat img_vis;
srcimg.copyTo(img_vis);
for (int n = 0; n < boxes.size(); n++) {
const cv::Point *ppt[1] = {rook_points[n]};
int npt[] = {4};
cv::polylines(img_vis, ppt, npt, 1, 1, CV_RGB(0, 255, 0), 2, 8, 0);
}
// For debugging; replace the output path with one appropriate for your device
cv::imwrite("/sdcard/1/vis.png", img_vis);
}
std::vector<OCRPredictResult>
OCR_PPredictor::infer_ocr(const std::vector<int64_t> &dims,
const float *input_data, int input_len, int net_flag,
cv::Mat &origin) {
PredictorInput input = _det_predictor->get_first_input();
input.set_dims(dims);
input.set_data(input_data, input_len);
std::vector<PredictorOutput> results = _det_predictor->infer();
PredictorOutput &res = results.at(0);
std::vector<std::vector<std::vector<int>>> filtered_box = calc_filtered_boxes(
res.get_float_data(), res.get_size(), (int)dims[2], (int)dims[3], origin);
LOGI("Filter_box size %ld", filtered_box.size());
return infer_rec(filtered_box, origin);
}
std::vector<OCRPredictResult> OCR_PPredictor::infer_rec(
const std::vector<std::vector<std::vector<int>>> &boxes,
const cv::Mat &origin_img) {
std::vector<float> mean = {0.5f, 0.5f, 0.5f};
std::vector<float> scale = {1 / 0.5f, 1 / 0.5f, 1 / 0.5f};
std::vector<int64_t> dims = {1, 3, 0, 0};
std::vector<OCRPredictResult> ocr_results;
PredictorInput input = _rec_predictor->get_first_input();
for (auto bp = boxes.crbegin(); bp != boxes.crend(); ++bp) {
const std::vector<std::vector<int>> &box = *bp;
cv::Mat crop_img = get_rotate_crop_image(origin_img, box);
crop_img = infer_cls(crop_img);
float wh_ratio = float(crop_img.cols) / float(crop_img.rows);
cv::Mat input_image = crnn_resize_img(crop_img, wh_ratio);
input_image.convertTo(input_image, CV_32FC3, 1 / 255.0f);
const float *dimg = reinterpret_cast<const float *>(input_image.data);
int input_size = input_image.rows * input_image.cols;
dims[2] = input_image.rows;
dims[3] = input_image.cols;
input.set_dims(dims);
neon_mean_scale(dimg, input.get_mutable_float_data(), input_size, mean,
scale);
std::vector<PredictorOutput> results = _rec_predictor->infer();
const float *predict_batch = results.at(0).get_float_data();
const std::vector<int64_t> predict_shape = results.at(0).get_shape();
OCRPredictResult res;
// ctc decode
int argmax_idx;
int last_index = 0;
float score = 0.f;
int count = 0;
float max_value = 0.0f;
for (int n = 0; n < predict_shape[1]; n++) {
argmax_idx = int(argmax(&predict_batch[n * predict_shape[2]],
&predict_batch[(n + 1) * predict_shape[2]]));
max_value =
float(*std::max_element(&predict_batch[n * predict_shape[2]],
&predict_batch[(n + 1) * predict_shape[2]]));
if (argmax_idx > 0 && (!(n > 0 && argmax_idx == last_index))) {
score += max_value;
count += 1;
res.word_index.push_back(argmax_idx);
}
last_index = argmax_idx;
}
if (res.word_index.empty()) {
continue;
}
score /= count;
res.score = score;
res.points = box;
ocr_results.emplace_back(std::move(res));
}
LOGI("ocr_results finished %lu", ocr_results.size());
return ocr_results;
}
cv::Mat OCR_PPredictor::infer_cls(const cv::Mat &img, float thresh) {
std::vector<float> mean = {0.5f, 0.5f, 0.5f};
std::vector<float> scale = {1 / 0.5f, 1 / 0.5f, 1 / 0.5f};
std::vector<int64_t> dims = {1, 3, 0, 0};
std::vector<OCRPredictResult> ocr_results;
PredictorInput input = _cls_predictor->get_first_input();
cv::Mat input_image = cls_resize_img(img);
input_image.convertTo(input_image, CV_32FC3, 1 / 255.0f);
const float *dimg = reinterpret_cast<const float *>(input_image.data);
int input_size = input_image.rows * input_image.cols;
dims[2] = input_image.rows;
dims[3] = input_image.cols;
input.set_dims(dims);
neon_mean_scale(dimg, input.get_mutable_float_data(), input_size, mean,
scale);
std::vector<PredictorOutput> results = _cls_predictor->infer();
const float *scores = results.at(0).get_float_data();
float score = 0;
int label = 0;
for (int64_t i = 0; i < results.at(0).get_size(); i++) {
LOGI("output scores [%f]", scores[i]);
if (scores[i] > score) {
score = scores[i];
label = i;
}
}
cv::Mat srcimg;
img.copyTo(srcimg);
if (label % 2 == 1 && score > thresh) {
cv::rotate(srcimg, srcimg, 1);
}
return srcimg;
}
std::vector<std::vector<std::vector<int>>>
OCR_PPredictor::calc_filtered_boxes(const float *pred, int pred_size,
int output_height, int output_width,
const cv::Mat &origin) {
const double threshold = 0.3;
const double maxvalue = 1;
cv::Mat pred_map = cv::Mat::zeros(output_height, output_width, CV_32F);
memcpy(pred_map.data, pred, pred_size * sizeof(float));
cv::Mat cbuf_map;
pred_map.convertTo(cbuf_map, CV_8UC1);
cv::Mat bit_map;
cv::threshold(cbuf_map, bit_map, threshold, maxvalue, cv::THRESH_BINARY);
std::vector<std::vector<std::vector<int>>> boxes =
boxes_from_bitmap(pred_map, bit_map);
float ratio_h = output_height * 1.0f / origin.rows;
float ratio_w = output_width * 1.0f / origin.cols;
std::vector<std::vector<std::vector<int>>> filter_boxes =
filter_tag_det_res(boxes, ratio_h, ratio_w, origin);
return filter_boxes;
}
std::vector<int>
OCR_PPredictor::postprocess_rec_word_index(const PredictorOutput &res) {
const int *rec_idx = res.get_int_data();
const std::vector<std::vector<uint64_t>> rec_idx_lod = res.get_lod();
std::vector<int> pred_idx;
for (int n = int(rec_idx_lod[0][0]); n < int(rec_idx_lod[0][1] * 2); n += 2) {
pred_idx.emplace_back(rec_idx[n]);
}
return pred_idx;
}
float OCR_PPredictor::postprocess_rec_score(const PredictorOutput &res) {
const float *predict_batch = res.get_float_data();
const std::vector<int64_t> predict_shape = res.get_shape();
const std::vector<std::vector<uint64_t>> predict_lod = res.get_lod();
int blank = predict_shape[1];
float score = 0.f;
int count = 0;
for (int n = predict_lod[0][0]; n < predict_lod[0][1] - 1; n++) {
int argmax_idx = argmax(predict_batch + n * predict_shape[1],
predict_batch + (n + 1) * predict_shape[1]);
float max_value = predict_batch[n * predict_shape[1] + argmax_idx];
if (blank - 1 - argmax_idx > 1e-5) {
score += max_value;
count += 1;
}
}
if (count == 0) {
LOGE("calc score count 0");
} else {
score /= count;
}
LOGI("calc score: %f", score);
return score;
}
NET_TYPE OCR_PPredictor::get_net_flag() const { return NET_OCR; }
}
\ No newline at end of file
//
// Created by fujiayi on 2020/7/1.
//
#pragma once
#include "ppredictor.h"
#include <opencv2/opencv.hpp>
#include <paddle_api.h>
#include <string>
namespace ppredictor {
/**
* Config
*/
struct OCR_Config {
int thread_num = 4; // Thread num
paddle::lite_api::PowerMode mode =
paddle::lite_api::LITE_POWER_HIGH; // PaddleLite Mode
};
/**
* Polygon result
*/
struct OCRPredictResult {
std::vector<int> word_index;
std::vector<std::vector<int>> points;
float score;
};
/**
* OCR uses 2 models:
* 1. The first model (det) selects polygons marking where the text is
* 2. Regions cropped from the original image with these polygons are fed to
* the second model for inference
*/
class OCR_PPredictor : public PPredictor_Interface {
public:
OCR_PPredictor(const OCR_Config &config);
virtual ~OCR_PPredictor() {}
/**
* Initialize the Predictor for each of the two models
* @param det_model_content
* @param rec_model_content
* @return
*/
int init(const std::string &det_model_content,
const std::string &rec_model_content,
const std::string &cls_model_content);
int init_from_file(const std::string &det_model_path,
const std::string &rec_model_path,
const std::string &cls_model_path);
/**
* Return OCR result
* @param dims
* @param input_data
* @param input_len
* @param net_flag
* @param origin
* @return
*/
virtual std::vector<OCRPredictResult>
infer_ocr(const std::vector<int64_t> &dims, const float *input_data,
int input_len, int net_flag, cv::Mat &origin);
virtual NET_TYPE get_net_flag() const;
private:
/**
* Calculate polygons from the output image of the first model
* @param pred
* @param output_height
* @param output_width
* @param origin
* @return
*/
std::vector<std::vector<std::vector<int>>>
calc_filtered_boxes(const float *pred, int pred_size, int output_height,
int output_width, const cv::Mat &origin);
/**
* Inference with the second (rec) model
*
* @param boxes
* @param origin
* @return
*/
std::vector<OCRPredictResult>
infer_rec(const std::vector<std::vector<std::vector<int>>> &boxes,
const cv::Mat &origin);
/**
* Inference with the cls (direction classifier) model
*
* @param boxes
* @param origin
* @return
*/
cv::Mat infer_cls(const cv::Mat &origin, float thresh = 0.9);
/**
* Postprocess the second model's output to extract the text
* @param res
* @return
*/
std::vector<int> postprocess_rec_word_index(const PredictorOutput &res);
/**
* Calculate the confidence of the second model's text result
* @param res
* @return
*/
float postprocess_rec_score(const PredictorOutput &res);
std::unique_ptr<PPredictor> _det_predictor;
std::unique_ptr<PPredictor> _rec_predictor;
std::unique_ptr<PPredictor> _cls_predictor;
OCR_Config _config;
};
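// Usage sketch (illustrative only; the model paths and input preparation are
// placeholders): load the three models from disk, then run the full
// det -> cls -> rec pipeline on one image:
//   OCR_Config cfg;
//   OCR_PPredictor ocr(cfg);
//   ocr.init_from_file("det.nb", "rec.nb", "cls.nb");
//   // ... fill `dims`, `input_data`, `input_len` from the detector input ...
//   std::vector<OCRPredictResult> results =
//       ocr.infer_ocr(dims, input_data, input_len, NET_OCR, origin_mat);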
}
#include "ppredictor.h"
#include "common.h"
namespace ppredictor {
PPredictor::PPredictor(int thread_num, int net_flag,
paddle::lite_api::PowerMode mode)
: _thread_num(thread_num), _net_flag(net_flag), _mode(mode) {}
int PPredictor::init_nb(const std::string &model_content) {
paddle::lite_api::MobileConfig config;
config.set_model_from_buffer(model_content);
return _init(config);
}
int PPredictor::init_from_file(const std::string &model_content) {
paddle::lite_api::MobileConfig config;
config.set_model_from_file(model_content);
return _init(config);
}
template <typename ConfigT> int PPredictor::_init(ConfigT &config) {
config.set_threads(_thread_num);
config.set_power_mode(_mode);
_predictor = paddle::lite_api::CreatePaddlePredictor(config);
LOGI("paddle instance created");
return RETURN_OK;
}
PredictorInput PPredictor::get_input(int index) {
PredictorInput input{_predictor->GetInput(index), index, _net_flag};
_is_input_get = true;
return input;
}
std::vector<PredictorInput> PPredictor::get_inputs(int num) {
std::vector<PredictorInput> results;
for (int i = 0; i < num; i++) {
results.emplace_back(get_input(i));
}
return results;
}
PredictorInput PPredictor::get_first_input() { return get_input(0); }
std::vector<PredictorOutput> PPredictor::infer() {
LOGI("infer Run start %d", _net_flag);
std::vector<PredictorOutput> results;
if (!_is_input_get) {
return results;
}
_predictor->Run();
LOGI("infer Run end");
for (int i = 0; i < _predictor->GetOutputNames().size(); i++) {
std::unique_ptr<const paddle::lite_api::Tensor> output_tensor =
_predictor->GetOutput(i);
LOGI("output tensor[%d] size %ld", i, product(output_tensor->shape()));
PredictorOutput result{std::move(output_tensor), i, _net_flag};
results.emplace_back(std::move(result));
}
return results;
}
NET_TYPE PPredictor::get_net_flag() const { return (NET_TYPE)_net_flag; }
}
\ No newline at end of file
#pragma once
#include "paddle_api.h"
#include "predictor_input.h"
#include "predictor_output.h"
namespace ppredictor {
/**
* PaddleLite Predictor Common Interface
*/
class PPredictor_Interface {
public:
virtual ~PPredictor_Interface() {}
virtual NET_TYPE get_net_flag() const = 0;
};
/**
* Common Predictor
*/
class PPredictor : public PPredictor_Interface {
public:
PPredictor(
int thread_num, int net_flag = 0,
paddle::lite_api::PowerMode mode = paddle::lite_api::LITE_POWER_HIGH);
virtual ~PPredictor() {}
/**
* Initialize from a PaddleLite opt model buffer (.nb format); use
* init_from_file to load the same format from a path
* @param model_content
* @return 0
*/
virtual int init_nb(const std::string &model_content);
virtual int init_from_file(const std::string &model_content);
std::vector<PredictorOutput> infer();
std::shared_ptr<paddle::lite_api::PaddlePredictor> get_predictor() {
return _predictor;
}
virtual std::vector<PredictorInput> get_inputs(int num);
virtual PredictorInput get_input(int index);
virtual PredictorInput get_first_input();
virtual NET_TYPE get_net_flag() const;
protected:
template <typename ConfigT> int _init(ConfigT &config);
private:
int _thread_num;
paddle::lite_api::PowerMode _mode;
std::shared_ptr<paddle::lite_api::PaddlePredictor> _predictor;
bool _is_input_get = false;
int _net_flag;
};
}
#include "predictor_input.h"
namespace ppredictor {
void PredictorInput::set_dims(std::vector<int64_t> dims) {
// yolov3
if (_net_flag == 101 && _index == 1) {
_tensor->Resize({1, 2});
_tensor->mutable_data<int>()[0] = (int)dims.at(2);
_tensor->mutable_data<int>()[1] = (int)dims.at(3);
} else {
_tensor->Resize(dims);
}
_is_dims_set = true;
}
float *PredictorInput::get_mutable_float_data() {
if (!_is_dims_set) {
LOGE("PredictorInput::set_dims is not called");
}
return _tensor->mutable_data<float>();
}
void PredictorInput::set_data(const float *input_data, int input_float_len) {
float *input_raw_data = get_mutable_float_data();
memcpy(input_raw_data, input_data, input_float_len * sizeof(float));
}
}
\ No newline at end of file
#pragma once
#include "common.h"
#include <paddle_api.h>
#include <vector>
namespace ppredictor {
class PredictorInput {
public:
PredictorInput(std::unique_ptr<paddle::lite_api::Tensor> &&tensor, int index,
int net_flag)
: _tensor(std::move(tensor)), _index(index), _net_flag(net_flag) {}
void set_dims(std::vector<int64_t> dims);
float *get_mutable_float_data();
void set_data(const float *input_data, int input_float_len);
private:
std::unique_ptr<paddle::lite_api::Tensor> _tensor;
bool _is_dims_set = false;
int _index;
int _net_flag;
};
}
#include "predictor_output.h"
namespace ppredictor {
const float *PredictorOutput::get_float_data() const {
return _tensor->data<float>();
}
const int *PredictorOutput::get_int_data() const {
return _tensor->data<int>();
}
const std::vector<std::vector<uint64_t>> PredictorOutput::get_lod() const {
return _tensor->lod();
}
int64_t PredictorOutput::get_size() const {
if (_net_flag == NET_OCR) {
return _tensor->shape().at(2) * _tensor->shape().at(3);
} else {
return product(_tensor->shape());
}
}
const std::vector<int64_t> PredictorOutput::get_shape() const {
return _tensor->shape();
}
}
\ No newline at end of file
#pragma once
#include "common.h"
#include <paddle_api.h>
#include <vector>
namespace ppredictor {
class PredictorOutput {
public:
PredictorOutput() {}
PredictorOutput(std::unique_ptr<const paddle::lite_api::Tensor> &&tensor,
int index, int net_flag)
: _tensor(std::move(tensor)), _index(index), _net_flag(net_flag) {}
const float *get_float_data() const;
const int *get_int_data() const;
int64_t get_size() const;
const std::vector<std::vector<uint64_t>> get_lod() const;
const std::vector<int64_t> get_shape() const;
std::vector<float> data; // float results; use data_int for int output
std::vector<int> data_int; // some layers return int; otherwise use data
std::vector<int64_t> shape; // PaddleLite output shape
std::vector<std::vector<uint64_t>> lod; // PaddleLite output lod
private:
std::unique_ptr<const paddle::lite_api::Tensor> _tensor;
int _index;
int _net_flag;
};
}
#include "preprocess.h"
#include <android/bitmap.h>
cv::Mat bitmap_to_cv_mat(JNIEnv *env, jobject bitmap) {
AndroidBitmapInfo info;
int result = AndroidBitmap_getInfo(env, bitmap, &info);
if (result != ANDROID_BITMAP_RESULT_SUCCESS) {
LOGE("AndroidBitmap_getInfo failed, result: %d", result);
return cv::Mat{};
}
if (info.format != ANDROID_BITMAP_FORMAT_RGBA_8888) {
LOGE("Bitmap format is not RGBA_8888 !");
return cv::Mat{};
}
unsigned char *srcData = NULL;
AndroidBitmap_lockPixels(env, bitmap, (void **)&srcData);
cv::Mat mat = cv::Mat::zeros(info.height, info.width, CV_8UC4);
memcpy(mat.data, srcData, info.height * info.width * 4);
AndroidBitmap_unlockPixels(env, bitmap);
cv::cvtColor(mat, mat, cv::COLOR_RGBA2BGR);
/**
if (!cv::imwrite("/sdcard/1/copy.jpg", mat)){
LOGE("Write image failed " );
}
*/
return mat;
}
cv::Mat resize_img(const cv::Mat &img, int height, int width) {
if (img.rows == height && img.cols == width) {
return img;
}
cv::Mat new_img;
cv::resize(img, new_img, cv::Size(width, height)); // cv::Size is (width, height)
return new_img;
}
// Fill the tensor: apply mean/scale normalization and transpose the layout
// NHWC -> NCHW, using NEON to speed things up
void neon_mean_scale(const float *din, float *dout, int size,
const std::vector<float> &mean,
const std::vector<float> &scale) {
if (mean.size() != 3 || scale.size() != 3) {
LOGE("[ERROR] mean or scale size must equal to 3");
return;
}
float32x4_t vmean0 = vdupq_n_f32(mean[0]);
float32x4_t vmean1 = vdupq_n_f32(mean[1]);
float32x4_t vmean2 = vdupq_n_f32(mean[2]);
float32x4_t vscale0 = vdupq_n_f32(scale[0]);
float32x4_t vscale1 = vdupq_n_f32(scale[1]);
float32x4_t vscale2 = vdupq_n_f32(scale[2]);
float *dout_c0 = dout;
float *dout_c1 = dout + size;
float *dout_c2 = dout + size * 2;
int i = 0;
for (; i < size - 3; i += 4) {
float32x4x3_t vin3 = vld3q_f32(din);
float32x4_t vsub0 = vsubq_f32(vin3.val[0], vmean0);
float32x4_t vsub1 = vsubq_f32(vin3.val[1], vmean1);
float32x4_t vsub2 = vsubq_f32(vin3.val[2], vmean2);
float32x4_t vs0 = vmulq_f32(vsub0, vscale0);
float32x4_t vs1 = vmulq_f32(vsub1, vscale1);
float32x4_t vs2 = vmulq_f32(vsub2, vscale2);
vst1q_f32(dout_c0, vs0);
vst1q_f32(dout_c1, vs1);
vst1q_f32(dout_c2, vs2);
din += 12;
dout_c0 += 4;
dout_c1 += 4;
dout_c2 += 4;
}
for (; i < size; i++) {
*(dout_c0++) = (*(din++) - mean[0]) * scale[0];
*(dout_c1++) = (*(din++) - mean[1]) * scale[1];
*(dout_c2++) = (*(din++) - mean[2]) * scale[2];
}
}
#pragma once
#include "common.h"
#include <jni.h>
#include <opencv2/opencv.hpp>
cv::Mat bitmap_to_cv_mat(JNIEnv *env, jobject bitmap);
cv::Mat resize_img(const cv::Mat &img, int height, int width);
void neon_mean_scale(const float *din, float *dout, int size,
const std::vector<float> &mean,
const std::vector<float> &scale);
/*
* Copyright (C) 2014 The Android Open Source Project
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.baidu.paddle.lite.demo.ocr;
import android.content.res.Configuration;
import android.os.Bundle;
import android.preference.PreferenceActivity;
import android.view.MenuInflater;
import android.view.View;
import android.view.ViewGroup;
import androidx.annotation.LayoutRes;
import androidx.annotation.Nullable;
import androidx.appcompat.app.ActionBar;
import androidx.appcompat.app.AppCompatDelegate;
import androidx.appcompat.widget.Toolbar;
/**
* A {@link PreferenceActivity} which implements and proxies the necessary calls
* to be used with AppCompat.
* <p>
* This technique can be used with an {@link android.app.Activity} class, not just
* {@link PreferenceActivity}.
*/
public abstract class AppCompatPreferenceActivity extends PreferenceActivity {
private AppCompatDelegate mDelegate;
@Override
protected void onCreate(Bundle savedInstanceState) {
getDelegate().installViewFactory();
getDelegate().onCreate(savedInstanceState);
super.onCreate(savedInstanceState);
}
@Override
protected void onPostCreate(Bundle savedInstanceState) {
super.onPostCreate(savedInstanceState);
getDelegate().onPostCreate(savedInstanceState);
}
public ActionBar getSupportActionBar() {
return getDelegate().getSupportActionBar();
}
public void setSupportActionBar(@Nullable Toolbar toolbar) {
getDelegate().setSupportActionBar(toolbar);
}
@Override
public MenuInflater getMenuInflater() {
return getDelegate().getMenuInflater();
}
@Override
public void setContentView(@LayoutRes int layoutResID) {
getDelegate().setContentView(layoutResID);
}
@Override
public void setContentView(View view) {
getDelegate().setContentView(view);
}
@Override
public void setContentView(View view, ViewGroup.LayoutParams params) {
getDelegate().setContentView(view, params);
}
@Override
public void addContentView(View view, ViewGroup.LayoutParams params) {
getDelegate().addContentView(view, params);
}
@Override
protected void onPostResume() {
super.onPostResume();
getDelegate().onPostResume();
}
@Override
protected void onTitleChanged(CharSequence title, int color) {
super.onTitleChanged(title, color);
getDelegate().setTitle(title);
}
@Override
public void onConfigurationChanged(Configuration newConfig) {
super.onConfigurationChanged(newConfig);
getDelegate().onConfigurationChanged(newConfig);
}
@Override
protected void onStop() {
super.onStop();
getDelegate().onStop();
}
@Override
protected void onDestroy() {
super.onDestroy();
getDelegate().onDestroy();
}
public void invalidateOptionsMenu() {
getDelegate().invalidateOptionsMenu();
}
private AppCompatDelegate getDelegate() {
if (mDelegate == null) {
mDelegate = AppCompatDelegate.create(this, null);
}
return mDelegate;
}
}
package com.baidu.paddle.lite.demo.ocr;
import android.Manifest;
import android.app.ProgressDialog;
import android.content.ContentResolver;
import android.content.Context;
import android.content.Intent;
import android.content.SharedPreferences;
import android.content.pm.PackageManager;
import android.database.Cursor;
import android.graphics.Bitmap;
import android.graphics.BitmapFactory;
import android.graphics.drawable.BitmapDrawable;
import android.media.ExifInterface;
import android.content.res.AssetManager;
import android.net.Uri;
import android.os.Bundle;
import android.os.Environment;
import android.os.Handler;
import android.os.HandlerThread;
import android.os.Message;
import android.preference.PreferenceManager;
import android.provider.MediaStore;
import android.text.method.ScrollingMovementMethod;
import android.util.Log;
import android.view.Menu;
import android.view.MenuInflater;
import android.view.MenuItem;
import android.view.View;
import android.widget.ImageView;
import android.widget.TextView;
import android.widget.Toast;
import androidx.annotation.NonNull;
import androidx.appcompat.app.AppCompatActivity;
import androidx.core.app.ActivityCompat;
import androidx.core.content.ContextCompat;
import androidx.core.content.FileProvider;
import java.io.File;
import java.io.IOException;
import java.io.InputStream;
import java.text.SimpleDateFormat;
import java.util.Date;
public class MainActivity extends AppCompatActivity {
private static final String TAG = MainActivity.class.getSimpleName();
public static final int OPEN_GALLERY_REQUEST_CODE = 0;
public static final int TAKE_PHOTO_REQUEST_CODE = 1;
public static final int REQUEST_LOAD_MODEL = 0;
public static final int REQUEST_RUN_MODEL = 1;
public static final int RESPONSE_LOAD_MODEL_SUCCESSED = 0;
public static final int RESPONSE_LOAD_MODEL_FAILED = 1;
public static final int RESPONSE_RUN_MODEL_SUCCESSED = 2;
public static final int RESPONSE_RUN_MODEL_FAILED = 3;
protected ProgressDialog pbLoadModel = null;
protected ProgressDialog pbRunModel = null;
protected Handler receiver = null; // Receive messages from worker thread
protected Handler sender = null; // Send command to worker thread
protected HandlerThread worker = null; // Worker thread to load&run model
// UI components
protected TextView tvInputSetting;
protected TextView tvStatus;
protected ImageView ivInputImage;
protected TextView tvOutputResult;
protected TextView tvInferenceTime;
// Model settings
protected String modelPath = "";
protected String labelPath = "";
protected String imagePath = "";
protected int cpuThreadNum = 1;
protected String cpuPowerMode = "";
protected String inputColorFormat = "";
protected long[] inputShape = new long[]{};
protected float[] inputMean = new float[]{};
protected float[] inputStd = new float[]{};
protected float scoreThreshold = 0.1f;
private String currentPhotoPath;
private AssetManager assetManager = null;
protected Predictor predictor = new Predictor();
@Override
protected void onCreate(Bundle savedInstanceState) {
super.onCreate(savedInstanceState);
setContentView(R.layout.activity_main);
// Clear all setting items to avoid the app crashing due to incorrect settings
SharedPreferences sharedPreferences = PreferenceManager.getDefaultSharedPreferences(this);
SharedPreferences.Editor editor = sharedPreferences.edit();
editor.clear();
editor.apply();
// Setup the UI components
tvInputSetting = findViewById(R.id.tv_input_setting);
tvStatus = findViewById(R.id.tv_model_img_status);
ivInputImage = findViewById(R.id.iv_input_image);
tvInferenceTime = findViewById(R.id.tv_inference_time);
tvOutputResult = findViewById(R.id.tv_output_result);
tvInputSetting.setMovementMethod(ScrollingMovementMethod.getInstance());
tvOutputResult.setMovementMethod(ScrollingMovementMethod.getInstance());
// Prepare the worker thread for model loading and inference
receiver = new Handler() {
@Override
public void handleMessage(Message msg) {
switch (msg.what) {
case RESPONSE_LOAD_MODEL_SUCCESSED:
if(pbLoadModel!=null && pbLoadModel.isShowing()){
pbLoadModel.dismiss();
}
onLoadModelSuccessed();
break;
case RESPONSE_LOAD_MODEL_FAILED:
if(pbLoadModel!=null && pbLoadModel.isShowing()){
pbLoadModel.dismiss();
}
Toast.makeText(MainActivity.this, "Load model failed!", Toast.LENGTH_SHORT).show();
onLoadModelFailed();
break;
case RESPONSE_RUN_MODEL_SUCCESSED:
if(pbRunModel!=null && pbRunModel.isShowing()){
pbRunModel.dismiss();
}
onRunModelSuccessed();
break;
case RESPONSE_RUN_MODEL_FAILED:
if(pbRunModel!=null && pbRunModel.isShowing()){
pbRunModel.dismiss();
}
Toast.makeText(MainActivity.this, "Run model failed!", Toast.LENGTH_SHORT).show();
onRunModelFailed();
break;
default:
break;
}
}
};
worker = new HandlerThread("Predictor Worker");
worker.start();
sender = new Handler(worker.getLooper()) {
public void handleMessage(Message msg) {
switch (msg.what) {
case REQUEST_LOAD_MODEL:
// Load model and reload test image
if (onLoadModel()) {
receiver.sendEmptyMessage(RESPONSE_LOAD_MODEL_SUCCESSED);
} else {
receiver.sendEmptyMessage(RESPONSE_LOAD_MODEL_FAILED);
}
break;
case REQUEST_RUN_MODEL:
// Run model if model is loaded
if (onRunModel()) {
receiver.sendEmptyMessage(RESPONSE_RUN_MODEL_SUCCESSED);
} else {
receiver.sendEmptyMessage(RESPONSE_RUN_MODEL_FAILED);
}
break;
default:
break;
}
}
};
}
@Override
protected void onResume() {
super.onResume();
SharedPreferences sharedPreferences = PreferenceManager.getDefaultSharedPreferences(this);
boolean settingsChanged = false;
String model_path = sharedPreferences.getString(getString(R.string.MODEL_PATH_KEY),
getString(R.string.MODEL_PATH_DEFAULT));
String label_path = sharedPreferences.getString(getString(R.string.LABEL_PATH_KEY),
getString(R.string.LABEL_PATH_DEFAULT));
String image_path = sharedPreferences.getString(getString(R.string.IMAGE_PATH_KEY),
getString(R.string.IMAGE_PATH_DEFAULT));
settingsChanged |= !model_path.equalsIgnoreCase(modelPath);
settingsChanged |= !label_path.equalsIgnoreCase(labelPath);
settingsChanged |= !image_path.equalsIgnoreCase(imagePath);
int cpu_thread_num = Integer.parseInt(sharedPreferences.getString(getString(R.string.CPU_THREAD_NUM_KEY),
getString(R.string.CPU_THREAD_NUM_DEFAULT)));
settingsChanged |= cpu_thread_num != cpuThreadNum;
String cpu_power_mode =
sharedPreferences.getString(getString(R.string.CPU_POWER_MODE_KEY),
getString(R.string.CPU_POWER_MODE_DEFAULT));
settingsChanged |= !cpu_power_mode.equalsIgnoreCase(cpuPowerMode);
String input_color_format =
sharedPreferences.getString(getString(R.string.INPUT_COLOR_FORMAT_KEY),
getString(R.string.INPUT_COLOR_FORMAT_DEFAULT));
settingsChanged |= !input_color_format.equalsIgnoreCase(inputColorFormat);
long[] input_shape =
Utils.parseLongsFromString(sharedPreferences.getString(getString(R.string.INPUT_SHAPE_KEY),
getString(R.string.INPUT_SHAPE_DEFAULT)), ",");
float[] input_mean =
Utils.parseFloatsFromString(sharedPreferences.getString(getString(R.string.INPUT_MEAN_KEY),
getString(R.string.INPUT_MEAN_DEFAULT)), ",");
float[] input_std =
Utils.parseFloatsFromString(sharedPreferences.getString(getString(R.string.INPUT_STD_KEY)
, getString(R.string.INPUT_STD_DEFAULT)), ",");
settingsChanged |= input_shape.length != inputShape.length;
settingsChanged |= input_mean.length != inputMean.length;
settingsChanged |= input_std.length != inputStd.length;
if (!settingsChanged) {
for (int i = 0; i < input_shape.length; i++) {
settingsChanged |= input_shape[i] != inputShape[i];
}
for (int i = 0; i < input_mean.length; i++) {
settingsChanged |= input_mean[i] != inputMean[i];
}
for (int i = 0; i < input_std.length; i++) {
settingsChanged |= input_std[i] != inputStd[i];
}
}
float score_threshold =
Float.parseFloat(sharedPreferences.getString(getString(R.string.SCORE_THRESHOLD_KEY),
getString(R.string.SCORE_THRESHOLD_DEFAULT)));
settingsChanged |= scoreThreshold != score_threshold;
if (settingsChanged) {
modelPath = model_path;
labelPath = label_path;
imagePath = image_path;
cpuThreadNum = cpu_thread_num;
cpuPowerMode = cpu_power_mode;
inputColorFormat = input_color_format;
inputShape = input_shape;
inputMean = input_mean;
inputStd = input_std;
scoreThreshold = score_threshold;
// Update UI
tvInputSetting.setText("Model: " + modelPath.substring(modelPath.lastIndexOf("/") + 1) + "\n" + "CPU" +
" Thread Num: " + Integer.toString(cpuThreadNum) + "\n" + "CPU Power Mode: " + cpuPowerMode);
tvInputSetting.scrollTo(0, 0);
// Reload the model if the configuration has changed
// loadModel();
set_img();
}
}
public void loadModel() {
pbLoadModel = ProgressDialog.show(this, "", "loading model...", false, false);
sender.sendEmptyMessage(REQUEST_LOAD_MODEL);
}
public void runModel() {
pbRunModel = ProgressDialog.show(this, "", "running model...", false, false);
sender.sendEmptyMessage(REQUEST_RUN_MODEL);
}
public boolean onLoadModel() {
return predictor.init(MainActivity.this, modelPath, labelPath, cpuThreadNum,
cpuPowerMode,
inputColorFormat,
inputShape, inputMean,
inputStd, scoreThreshold);
}
public boolean onRunModel() {
return predictor.isLoaded() && predictor.runModel();
}
public void onLoadModelSuccessed() {
tvStatus.setText("STATUS: load model succeeded");
}
public void onLoadModelFailed() {
tvStatus.setText("STATUS: load model failed");
}
public void onRunModelSuccessed() {
tvStatus.setText("STATUS: run model successed");
// Obtain results and update UI
tvInferenceTime.setText("Inference time: " + predictor.inferenceTime() + " ms");
Bitmap outputImage = predictor.outputImage();
if (outputImage != null) {
ivInputImage.setImageBitmap(outputImage);
}
tvOutputResult.setText(predictor.outputResult());
tvOutputResult.scrollTo(0, 0);
}
public void onRunModelFailed() {
tvStatus.setText("STATUS: run model failed");
}
public void onImageChanged(Bitmap image) {
// Rerun model if users pick test image from gallery or camera
if (image != null && predictor.isLoaded()) {
predictor.setInputImage(image);
runModel();
}
}
public void set_img() {
// Load the test image from assets and show it
try {
assetManager = getAssets();
InputStream in = assetManager.open(imagePath);
Bitmap bmp = BitmapFactory.decodeStream(in);
in.close();
ivInputImage.setImageBitmap(bmp);
} catch (IOException e) {
Toast.makeText(MainActivity.this, "Load image failed!", Toast.LENGTH_SHORT).show();
e.printStackTrace();
}
}
public void onSettingsClicked() {
startActivity(new Intent(MainActivity.this, SettingsActivity.class));
}
@Override
public boolean onCreateOptionsMenu(Menu menu) {
MenuInflater inflater = getMenuInflater();
inflater.inflate(R.menu.menu_action_options, menu);
return true;
}
@Override
public boolean onPrepareOptionsMenu(Menu menu) {
return super.onPrepareOptionsMenu(menu);
}
@Override
public boolean onOptionsItemSelected(MenuItem item) {
switch (item.getItemId()) {
case android.R.id.home:
finish();
break;
case R.id.settings:
if (requestAllPermissions()) {
// Make sure we have SDCard r&w permissions to load model from SDCard
onSettingsClicked();
}
break;
}
return super.onOptionsItemSelected(item);
}
@Override
public void onRequestPermissionsResult(int requestCode, @NonNull String[] permissions,
@NonNull int[] grantResults) {
super.onRequestPermissionsResult(requestCode, permissions, grantResults);
if (grantResults[0] != PackageManager.PERMISSION_GRANTED || grantResults[1] != PackageManager.PERMISSION_GRANTED) {
Toast.makeText(this, "Permission Denied", Toast.LENGTH_SHORT).show();
}
}
private boolean requestAllPermissions() {
if (ContextCompat.checkSelfPermission(this, Manifest.permission.WRITE_EXTERNAL_STORAGE)
!= PackageManager.PERMISSION_GRANTED || ContextCompat.checkSelfPermission(this,
Manifest.permission.CAMERA)
!= PackageManager.PERMISSION_GRANTED) {
ActivityCompat.requestPermissions(this, new String[]{Manifest.permission.WRITE_EXTERNAL_STORAGE,
Manifest.permission.CAMERA},
0);
return false;
}
return true;
}
private void openGallery() {
Intent intent = new Intent(Intent.ACTION_PICK, null);
intent.setDataAndType(MediaStore.Images.Media.EXTERNAL_CONTENT_URI, "image/*");
startActivityForResult(intent, OPEN_GALLERY_REQUEST_CODE);
}
private void takePhoto() {
Intent takePictureIntent = new Intent(MediaStore.ACTION_IMAGE_CAPTURE);
// Ensure that there's a camera activity to handle the intent
if (takePictureIntent.resolveActivity(getPackageManager()) != null) {
// Create the File where the photo should go
File photoFile = null;
try {
photoFile = createImageFile();
} catch (IOException ex) {
Log.e("MainActitity", ex.getMessage(), ex);
Toast.makeText(MainActivity.this,
"Create Camera temp file failed: " + ex.getMessage(), Toast.LENGTH_SHORT).show();
}
// Continue only if the File was successfully created
if (photoFile != null) {
Log.i(TAG, "FILEPATH " + getExternalFilesDir("Pictures").getAbsolutePath());
Uri photoURI = FileProvider.getUriForFile(this,
"com.baidu.paddle.lite.demo.ocr.fileprovider",
photoFile);
currentPhotoPath = photoFile.getAbsolutePath();
takePictureIntent.putExtra(MediaStore.EXTRA_OUTPUT, photoURI);
startActivityForResult(takePictureIntent, TAKE_PHOTO_REQUEST_CODE);
Log.i(TAG, "startActivityForResult finished");
}
}
}
private File createImageFile() throws IOException {
// Create an image file name
String timeStamp = new SimpleDateFormat("yyyyMMdd_HHmmss").format(new Date());
String imageFileName = "JPEG_" + timeStamp + "_";
File storageDir = getExternalFilesDir(Environment.DIRECTORY_PICTURES);
File image = File.createTempFile(
imageFileName, /* prefix */
".bmp", /* suffix */
storageDir /* directory */
);
return image;
}
@Override
protected void onActivityResult(int requestCode, int resultCode, Intent data) {
super.onActivityResult(requestCode, resultCode, data);
if (resultCode == RESULT_OK) {
switch (requestCode) {
case OPEN_GALLERY_REQUEST_CODE:
if (data == null) {
break;
}
try {
ContentResolver resolver = getContentResolver();
Uri uri = data.getData();
Bitmap image = MediaStore.Images.Media.getBitmap(resolver, uri);
if (image != null) {
// onImageChanged(image);
ivInputImage.setImageBitmap(image);
}
} catch (IOException e) {
Log.e(TAG, e.toString());
}
break;
case TAKE_PHOTO_REQUEST_CODE:
if (currentPhotoPath != null) {
ExifInterface exif = null;
try {
exif = new ExifInterface(currentPhotoPath);
} catch (IOException e) {
e.printStackTrace();
}
// Guard against a null exif (e.g. when reading the EXIF data threw above)
int orientation = exif != null
? exif.getAttributeInt(ExifInterface.TAG_ORIENTATION, ExifInterface.ORIENTATION_UNDEFINED)
: ExifInterface.ORIENTATION_UNDEFINED;
Log.i(TAG, "rotation " + orientation);
Bitmap image = BitmapFactory.decodeFile(currentPhotoPath);
image = Utils.rotateBitmap(image, orientation);
if (image != null) {
// onImageChanged(image);
ivInputImage.setImageBitmap(image);
}
} else {
Log.e(TAG, "currentPhotoPath is null");
}
break;
default:
break;
}
}
}
public void btn_load_model_click(View view) {
tvStatus.setText("STATUS: load model ......");
loadModel();
}
public void btn_run_model_click(View view) {
Bitmap image = ((BitmapDrawable) ivInputImage.getDrawable()).getBitmap();
if (image == null) {
tvStatus.setText("STATUS: image does not exist");
} else if (!predictor.isLoaded()) {
tvStatus.setText("STATUS: model is not loaded");
} else {
tvStatus.setText("STATUS: run model ......");
predictor.setInputImage(image);
runModel();
}
}
public void btn_choice_img_click(View view) {
if (requestAllPermissions()) {
openGallery();
}
}
public void btn_take_photo_click(View view) {
if (requestAllPermissions()) {
takePhoto();
}
}
@Override
protected void onDestroy() {
if (predictor != null) {
predictor.releaseModel();
}
worker.quit();
super.onDestroy();
}
}
package com.baidu.paddle.lite.demo.ocr;
import android.graphics.Bitmap;
import android.graphics.BitmapFactory;
import android.os.Build;
import android.os.Bundle;
import android.os.Handler;
import android.os.HandlerThread;
import android.os.Message;
import android.util.Log;
import android.view.View;
import android.widget.Button;
import android.widget.ImageView;
import android.widget.TextView;
import android.widget.Toast;
import androidx.appcompat.app.AppCompatActivity;
import java.io.IOException;
import java.io.InputStream;
public class MiniActivity extends AppCompatActivity {
public static final int REQUEST_LOAD_MODEL = 0;
public static final int REQUEST_RUN_MODEL = 1;
public static final int REQUEST_UNLOAD_MODEL = 2;
public static final int RESPONSE_LOAD_MODEL_SUCCESSED = 0;
public static final int RESPONSE_LOAD_MODEL_FAILED = 1;
public static final int RESPONSE_RUN_MODEL_SUCCESSED = 2;
public static final int RESPONSE_RUN_MODEL_FAILED = 3;
private static final String TAG = "MiniActivity";
protected Handler receiver = null; // Receive messages from worker thread
protected Handler sender = null; // Send command to worker thread
protected HandlerThread worker = null; // Worker thread to load&run model
protected volatile Predictor predictor = null;
private String assetModelDirPath = "models/ocr_v2_for_cpu";
private String assetlabelFilePath = "labels/ppocr_keys_v1.txt";
private Button button;
private ImageView imageView; // image result
private TextView textView; // text result
@Override
protected void onCreate(Bundle savedInstanceState) {
super.onCreate(savedInstanceState);
setContentView(R.layout.activity_mini);
Log.i(TAG, "SHOW in Logcat");
// Prepare the worker thread for model loading and inference
worker = new HandlerThread("Predictor Worker");
worker.start();
sender = new Handler(worker.getLooper()) {
public void handleMessage(Message msg) {
switch (msg.what) {
case REQUEST_LOAD_MODEL:
// Load model and reload test image
if (!onLoadModel()) {
runOnUiThread(new Runnable() {
@Override
public void run() {
Toast.makeText(MiniActivity.this, "Load model failed!", Toast.LENGTH_SHORT).show();
}
});
}
break;
case REQUEST_RUN_MODEL:
// Run model if model is loaded
final boolean isSuccessed = onRunModel();
runOnUiThread(new Runnable() {
@Override
public void run() {
if (isSuccessed){
onRunModelSuccessed();
}else{
Toast.makeText(MiniActivity.this, "Run model failed!", Toast.LENGTH_SHORT).show();
}
}
});
break;
}
}
};
sender.sendEmptyMessage(REQUEST_LOAD_MODEL); // trigger onLoadModel() on the worker thread
imageView = findViewById(R.id.imageView);
textView = findViewById(R.id.sample_text);
button = findViewById(R.id.button);
button.setOnClickListener(new View.OnClickListener() {
@Override
public void onClick(View v) {
sender.sendEmptyMessage(REQUEST_RUN_MODEL);
}
});
}
@Override
protected void onDestroy() {
onUnloadModel();
if (Build.VERSION.SDK_INT >= Build.VERSION_CODES.JELLY_BEAN_MR2) {
worker.quitSafely();
} else {
worker.quit();
}
super.onDestroy();
}
/**
* Called from onCreate to load the model on the worker thread.
*
* @return true if the model was loaded successfully
*/
private boolean onLoadModel() {
if (predictor == null) {
predictor = new Predictor();
}
return predictor.init(this, assetModelDirPath, assetlabelFilePath);
}
/**
* Runs the model on a bundled test image; invoked from the worker thread.
*
* @return true if inference succeeded
*/
private boolean onRunModel() {
try {
String assetImagePath = "images/0.jpg";
InputStream imageStream = getAssets().open(assetImagePath);
Bitmap image = BitmapFactory.decodeStream(imageStream);
// Input is Bitmap
predictor.setInputImage(image);
return predictor.isLoaded() && predictor.runModel();
} catch (IOException e) {
e.printStackTrace();
return false;
}
}
private void onRunModelSuccessed() {
Log.i(TAG, "onRunModelSuccessed");
textView.setText(predictor.outputResult);
imageView.setImageBitmap(predictor.outputImage);
}
private void onUnloadModel() {
if (predictor != null) {
predictor.releaseModel();
}
}
}
package com.baidu.paddle.lite.demo.ocr;
import android.graphics.Bitmap;
import android.util.Log;
import java.util.ArrayList;
import java.util.concurrent.atomic.AtomicBoolean;
public class OCRPredictorNative {
private static final AtomicBoolean isSOLoaded = new AtomicBoolean();
public static void loadLibrary() throws RuntimeException {
if (!isSOLoaded.get() && isSOLoaded.compareAndSet(false, true)) {
try {
System.loadLibrary("Native");
} catch (Throwable e) {
RuntimeException exception = new RuntimeException(
"Load libNative.so failed, please check it exists in apk file.", e);
throw exception;
}
}
}
private Config config;
private long nativePointer = 0;
public OCRPredictorNative(Config config) {
this.config = config;
loadLibrary();
nativePointer = init(config.detModelFilename, config.recModelFilename, config.clsModelFilename,
config.cpuThreadNum, config.cpuPower);
Log.i("OCRPredictorNative", "load success " + nativePointer);
}
public ArrayList<OcrResultModel> runImage(float[] inputData, int width, int height, int channels, Bitmap originalImage) {
Log.i("OCRPredictorNative", "begin to run image " + inputData.length + " " + width + " " + height);
float[] dims = new float[]{1, channels, height, width};
float[] rawResults = forward(nativePointer, inputData, dims, originalImage);
ArrayList<OcrResultModel> results = postprocess(rawResults);
return results;
}
public static class Config {
public int cpuThreadNum;
public String cpuPower;
public String detModelFilename;
public String recModelFilename;
public String clsModelFilename;
}
public void destory() {
// a native pointer cast to long may be negative on 64-bit devices, so compare against 0
if (nativePointer != 0) {
release(nativePointer);
nativePointer = 0;
}
}
protected native long init(String detModelPath, String recModelPath, String clsModelPath, int threadNum, String cpuMode);
protected native float[] forward(long pointer, float[] buf, float[] ddims, Bitmap originalImage);
protected native void release(long pointer);
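// Layout of the raw float array returned by forward(), as decoded by
// postprocess()/parse() below (inferred from the parsing code): per detected box,
//   [ point_num, word_num, confidence,
//     x_0, y_0, ..., x_{point_num-1}, y_{point_num-1},
//     word_index_0, ..., word_index_{word_num-1} ]
// so each record occupies 2 + 1 + point_num * 2 + word_num floats.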
private ArrayList<OcrResultModel> postprocess(float[] raw) {
ArrayList<OcrResultModel> results = new ArrayList<OcrResultModel>();
int begin = 0;
while (begin < raw.length) {
int point_num = Math.round(raw[begin]);
int word_num = Math.round(raw[begin + 1]);
OcrResultModel model = parse(raw, begin + 2, point_num, word_num);
begin += 2 + 1 + point_num * 2 + word_num;
results.add(model);
}
return results;
}
private OcrResultModel parse(float[] raw, int begin, int pointNum, int wordNum) {
int current = begin;
OcrResultModel model = new OcrResultModel();
model.setConfidence(raw[current]);
current++;
for (int i = 0; i < pointNum; i++) {
model.addPoints(Math.round(raw[current + i * 2]), Math.round(raw[current + i * 2 + 1]));
}
current += (pointNum * 2);
for (int i = 0; i < wordNum; i++) {
int index = Math.round(raw[current + i]);
model.addWordIndex(index);
}
Log.i("OCRPredictorNative", "word finished " + wordNum);
return model;
}
}
package com.baidu.paddle.lite.demo.ocr;
import android.graphics.Point;
import java.util.ArrayList;
import java.util.List;
public class OcrResultModel {
private List<Point> points;
private List<Integer> wordIndex;
private String label;
private float confidence;
public OcrResultModel() {
super();
points = new ArrayList<>();
wordIndex = new ArrayList<>();
}
public void addPoints(int x, int y) {
Point point = new Point(x, y);
points.add(point);
}
public void addWordIndex(int index) {
wordIndex.add(index);
}
public List<Point> getPoints() {
return points;
}
public List<Integer> getWordIndex() {
return wordIndex;
}
public String getLabel() {
return label;
}
public void setLabel(String label) {
this.label = label;
}
public float getConfidence() {
return confidence;
}
public void setConfidence(float confidence) {
this.confidence = confidence;
}
}
package com.baidu.paddle.lite.demo.ocr;
import android.content.Context;
import android.graphics.Bitmap;
import android.graphics.Canvas;
import android.graphics.Color;
import android.graphics.Paint;
import android.graphics.Path;
import android.graphics.Point;
import android.util.Log;
import java.io.File;
import java.io.InputStream;
import java.util.ArrayList;
import java.util.Date;
import java.util.List;
import java.util.Vector;
import static android.graphics.Color.*;
public class Predictor {
private static final String TAG = Predictor.class.getSimpleName();
public boolean isLoaded = false;
public int warmupIterNum = 1;
public int inferIterNum = 1;
public int cpuThreadNum = 4;
public String cpuPowerMode = "LITE_POWER_HIGH";
public String modelPath = "";
public String modelName = "";
protected OCRPredictorNative paddlePredictor = null;
protected float inferenceTime = 0;
// Word labels loaded from the label file
protected Vector<String> wordLabels = new Vector<String>();
protected String inputColorFormat = "BGR";
protected long[] inputShape = new long[]{1, 3, 960};
protected float[] inputMean = new float[]{0.485f, 0.456f, 0.406f};
protected float[] inputStd = new float[]{1.0f / 0.229f, 1.0f / 0.224f, 1.0f / 0.225f};
protected float scoreThreshold = 0.1f;
protected Bitmap inputImage = null;
protected Bitmap outputImage = null;
protected volatile String outputResult = "";
protected float preprocessTime = 0;
protected float postprocessTime = 0;
public Predictor() {
}
public boolean init(Context appCtx, String modelPath, String labelPath) {
isLoaded = loadModel(appCtx, modelPath, cpuThreadNum, cpuPowerMode);
if (!isLoaded) {
return false;
}
isLoaded = loadLabel(appCtx, labelPath);
return isLoaded;
}
public boolean init(Context appCtx, String modelPath, String labelPath, int cpuThreadNum, String cpuPowerMode,
String inputColorFormat,
long[] inputShape, float[] inputMean,
float[] inputStd, float scoreThreshold) {
if (inputShape.length != 3) {
Log.e(TAG, "Size of input shape should be: 3");
return false;
}
if (inputMean.length != inputShape[1]) {
Log.e(TAG, "Size of input mean should be: " + Long.toString(inputShape[1]));
return false;
}
if (inputStd.length != inputShape[1]) {
Log.e(TAG, "Size of input std should be: " + Long.toString(inputShape[1]));
return false;
}
if (inputShape[0] != 1) {
Log.e(TAG, "Only one batch is supported in the image classification demo, you can use any batch size in " +
"your Apps!");
return false;
}
if (inputShape[1] != 1 && inputShape[1] != 3) {
Log.e(TAG, "Only one/three channels are supported in the image classification demo, you can use any " +
"channel size in your Apps!");
return false;
}
if (!inputColorFormat.equalsIgnoreCase("BGR")) {
Log.e(TAG, "Only BGR color format is supported.");
return false;
}
boolean isLoaded = init(appCtx, modelPath, labelPath);
if (!isLoaded) {
return false;
}
this.inputColorFormat = inputColorFormat;
this.inputShape = inputShape;
this.inputMean = inputMean;
this.inputStd = inputStd;
this.scoreThreshold = scoreThreshold;
return true;
}
protected boolean loadModel(Context appCtx, String modelPath, int cpuThreadNum, String cpuPowerMode) {
// Release model if exists
releaseModel();
// Load model
if (modelPath.isEmpty()) {
return false;
}
String realPath = modelPath;
if (!modelPath.startsWith("/")) {
// If the model path is not absolute, treat it as an asset path and copy the
// model files from assets into the app cache directory
realPath = appCtx.getCacheDir() + "/" + modelPath;
Utils.copyDirectoryFromAssets(appCtx, modelPath, realPath);
}
}
if (realPath.isEmpty()) {
return false;
}
OCRPredictorNative.Config config = new OCRPredictorNative.Config();
config.cpuThreadNum = cpuThreadNum;
config.detModelFilename = realPath + File.separator + "ch_ppocr_mobile_v2.0_det_opt.nb";
config.recModelFilename = realPath + File.separator + "ch_ppocr_mobile_v2.0_rec_opt.nb";
config.clsModelFilename = realPath + File.separator + "ch_ppocr_mobile_v2.0_cls_opt.nb";
Log.e("Predictor", "model path" + config.detModelFilename + " ; " + config.recModelFilename + ";" + config.clsModelFilename);
config.cpuPower = cpuPowerMode;
paddlePredictor = new OCRPredictorNative(config);
this.cpuThreadNum = cpuThreadNum;
this.cpuPowerMode = cpuPowerMode;
this.modelPath = realPath;
this.modelName = realPath.substring(realPath.lastIndexOf("/") + 1);
return true;
}
public void releaseModel() {
if (paddlePredictor != null) {
paddlePredictor.destory();
paddlePredictor = null;
}
isLoaded = false;
cpuThreadNum = 1;
cpuPowerMode = "LITE_POWER_HIGH";
modelPath = "";
modelName = "";
}
protected boolean loadLabel(Context appCtx, String labelPath) {
wordLabels.clear();
// Placeholder at index 0 so that label indices from the native predictor start
// at 1 (index 0 appears to be reserved, e.g. for the CTC blank)
wordLabels.add("black");
// Load word labels from file
try {
InputStream assetsInputStream = appCtx.getAssets().open(labelPath);
int available = assetsInputStream.available();
byte[] lines = new byte[available];
assetsInputStream.read(lines);
assetsInputStream.close();
String words = new String(lines);
String[] contents = words.split("\n");
for (String content : contents) {
wordLabels.add(content);
}
Log.i(TAG, "Word label size: " + wordLabels.size());
} catch (Exception e) {
Log.e(TAG, e.getMessage());
return false;
}
return true;
}
public boolean runModel() {
if (inputImage == null || !isLoaded()) {
return false;
}
// Pre-process image, and feed input tensor with pre-processed data
Bitmap scaleImage = Utils.resizeWithStep(inputImage, (int) inputShape[2], 32);
Date start = new Date();
int channels = (int) inputShape[1];
int width = scaleImage.getWidth();
int height = scaleImage.getHeight();
float[] inputData = new float[channels * width * height];
if (channels == 3) {
int[] channelIdx = null;
if (inputColorFormat.equalsIgnoreCase("RGB")) {
channelIdx = new int[]{0, 1, 2};
} else if (inputColorFormat.equalsIgnoreCase("BGR")) {
channelIdx = new int[]{2, 1, 0};
} else {
Log.i(TAG, "Unknown color format " + inputColorFormat + ", only RGB and BGR color format is " +
"supported!");
return false;
}
// Planar CHW layout: channel c of pixel (x, y) is stored at y * width + x + c * width * height
int[] channelStride = new int[]{width * height, width * height * 2};
for (int y = 0; y < height; y++) {
for (int x = 0; x < width; x++) {
int color = scaleImage.getPixel(x, y);
float[] rgb = new float[]{(float) red(color) / 255.0f, (float) green(color) / 255.0f,
(float) blue(color) / 255.0f};
inputData[y * width + x] = (rgb[channelIdx[0]] - inputMean[0]) / inputStd[0];
inputData[y * width + x + channelStride[0]] = (rgb[channelIdx[1]] - inputMean[1]) / inputStd[1];
inputData[y * width + x + channelStride[1]] = (rgb[channelIdx[2]] - inputMean[2]) / inputStd[2];
}
}
} else if (channels == 1) {
for (int y = 0; y < height; y++) {
for (int x = 0; x < width; x++) {
// Read from the scaled image; reading from inputImage here would sample the wrong resolution
int color = scaleImage.getPixel(x, y);
float gray = (float) (red(color) + green(color) + blue(color)) / 3.0f / 255.0f;
inputData[y * width + x] = (gray - inputMean[0]) / inputStd[0];
}
}
} else {
Log.i(TAG, "Unsupported channel size " + Integer.toString(channels) + ", only channel 1 and 3 is " +
"supported!");
return false;
}
float[] pixels = inputData;
Log.i(TAG, "pixels " + pixels[0] + " " + pixels[1] + " " + pixels[2] + " " + pixels[3]
+ " " + pixels[pixels.length / 2] + " " + pixels[pixels.length / 2 + 1] + " " + pixels[pixels.length - 2] + " " + pixels[pixels.length - 1]);
Date end = new Date();
preprocessTime = (float) (end.getTime() - start.getTime());
// Warm up
for (int i = 0; i < warmupIterNum; i++) {
paddlePredictor.runImage(inputData, width, height, channels, inputImage);
}
warmupIterNum = 0; // warm up only before the first inference
// Run inference
start = new Date();
ArrayList<OcrResultModel> results = paddlePredictor.runImage(inputData, width, height, channels, inputImage);
end = new Date();
inferenceTime = (end.getTime() - start.getTime()) / (float) inferIterNum;
results = postprocess(results);
Log.i(TAG, "[stat] Preprocess Time: " + preprocessTime
+ " ; Inference Time: " + inferenceTime + " ;Box Size " + results.size());
drawResults(results);
return true;
}
public boolean isLoaded() {
return paddlePredictor != null && isLoaded;
}
public String modelPath() {
return modelPath;
}
public String modelName() {
return modelName;
}
public int cpuThreadNum() {
return cpuThreadNum;
}
public String cpuPowerMode() {
return cpuPowerMode;
}
public float inferenceTime() {
return inferenceTime;
}
public Bitmap inputImage() {
return inputImage;
}
public Bitmap outputImage() {
return outputImage;
}
public String outputResult() {
return outputResult;
}
public float preprocessTime() {
return preprocessTime;
}
public float postprocessTime() {
return postprocessTime;
}
public void setInputImage(Bitmap image) {
if (image == null) {
return;
}
this.inputImage = image.copy(Bitmap.Config.ARGB_8888, true);
}
private ArrayList<OcrResultModel> postprocess(ArrayList<OcrResultModel> results) {
for (OcrResultModel r : results) {
StringBuffer word = new StringBuffer();
for (int index : r.getWordIndex()) {
if (index >= 0 && index < wordLabels.size()) {
word.append(wordLabels.get(index));
} else {
Log.e(TAG, "Word index is not in label list:" + index);
word.append("×");
}
}
r.setLabel(word.toString());
}
return results;
}
private void drawResults(ArrayList<OcrResultModel> results) {
StringBuffer outputResultSb = new StringBuffer("");
for (int i = 0; i < results.size(); i++) {
OcrResultModel result = results.get(i);
StringBuilder sb = new StringBuilder("");
sb.append(result.getLabel());
sb.append(" ").append(result.getConfidence());
sb.append("; Points: ");
for (Point p : result.getPoints()) {
sb.append("(").append(p.x).append(",").append(p.y).append(") ");
}
Log.i(TAG, sb.toString()); // show LOG in Logcat panel
outputResultSb.append(i + 1).append(": ").append(result.getLabel()).append("\n");
}
outputResult = outputResultSb.toString();
outputImage = inputImage;
Canvas canvas = new Canvas(outputImage);
Paint paintFillAlpha = new Paint();
paintFillAlpha.setStyle(Paint.Style.FILL);
paintFillAlpha.setColor(Color.parseColor("#3B85F5"));
paintFillAlpha.setAlpha(50);
Paint paint = new Paint();
paint.setColor(Color.parseColor("#3B85F5"));
paint.setStrokeWidth(5);
paint.setStyle(Paint.Style.STROKE);
for (OcrResultModel result : results) {
Path path = new Path();
List<Point> points = result.getPoints();
path.moveTo(points.get(0).x, points.get(0).y);
for (int i = points.size() - 1; i >= 0; i--) {
Point p = points.get(i);
path.lineTo(p.x, p.y);
}
canvas.drawPath(path, paint);
canvas.drawPath(path, paintFillAlpha);
}
}
}
package com.baidu.paddle.lite.demo.ocr;
import android.content.SharedPreferences;
import android.os.Bundle;
import android.preference.CheckBoxPreference;
import android.preference.EditTextPreference;
import android.preference.ListPreference;
import androidx.appcompat.app.ActionBar;
import java.util.ArrayList;
import java.util.List;
public class SettingsActivity extends AppCompatPreferenceActivity implements SharedPreferences.OnSharedPreferenceChangeListener {
ListPreference lpChoosePreInstalledModel = null;
CheckBoxPreference cbEnableCustomSettings = null;
EditTextPreference etModelPath = null;
EditTextPreference etLabelPath = null;
ListPreference etImagePath = null;
ListPreference lpCPUThreadNum = null;
ListPreference lpCPUPowerMode = null;
ListPreference lpInputColorFormat = null;
EditTextPreference etInputShape = null;
EditTextPreference etInputMean = null;
EditTextPreference etInputStd = null;
EditTextPreference etScoreThreshold = null;
List<String> preInstalledModelPaths = null;
List<String> preInstalledLabelPaths = null;
List<String> preInstalledImagePaths = null;
List<String> preInstalledInputShapes = null;
List<String> preInstalledCPUThreadNums = null;
List<String> preInstalledCPUPowerModes = null;
List<String> preInstalledInputColorFormats = null;
List<String> preInstalledInputMeans = null;
List<String> preInstalledInputStds = null;
List<String> preInstalledScoreThresholds = null;
@Override
public void onCreate(Bundle savedInstanceState) {
super.onCreate(savedInstanceState);
addPreferencesFromResource(R.xml.settings);
ActionBar supportActionBar = getSupportActionBar();
if (supportActionBar != null) {
supportActionBar.setDisplayHomeAsUpEnabled(true);
}
// Initialize the pre-installed model lists
preInstalledModelPaths = new ArrayList<String>();
preInstalledLabelPaths = new ArrayList<String>();
preInstalledImagePaths = new ArrayList<String>();
preInstalledInputShapes = new ArrayList<String>();
preInstalledCPUThreadNums = new ArrayList<String>();
preInstalledCPUPowerModes = new ArrayList<String>();
preInstalledInputColorFormats = new ArrayList<String>();
preInstalledInputMeans = new ArrayList<String>();
preInstalledInputStds = new ArrayList<String>();
preInstalledScoreThresholds = new ArrayList<String>();
// Add the default pre-installed model and its settings
preInstalledModelPaths.add(getString(R.string.MODEL_PATH_DEFAULT));
preInstalledLabelPaths.add(getString(R.string.LABEL_PATH_DEFAULT));
preInstalledImagePaths.add(getString(R.string.IMAGE_PATH_DEFAULT));
preInstalledCPUThreadNums.add(getString(R.string.CPU_THREAD_NUM_DEFAULT));
preInstalledCPUPowerModes.add(getString(R.string.CPU_POWER_MODE_DEFAULT));
preInstalledInputColorFormats.add(getString(R.string.INPUT_COLOR_FORMAT_DEFAULT));
preInstalledInputShapes.add(getString(R.string.INPUT_SHAPE_DEFAULT));
preInstalledInputMeans.add(getString(R.string.INPUT_MEAN_DEFAULT));
preInstalledInputStds.add(getString(R.string.INPUT_STD_DEFAULT));
preInstalledScoreThresholds.add(getString(R.string.SCORE_THRESHOLD_DEFAULT));
// Setup UI components
lpChoosePreInstalledModel =
(ListPreference) findPreference(getString(R.string.CHOOSE_PRE_INSTALLED_MODEL_KEY));
String[] preInstalledModelNames = new String[preInstalledModelPaths.size()];
for (int i = 0; i < preInstalledModelPaths.size(); i++) {
preInstalledModelNames[i] =
preInstalledModelPaths.get(i).substring(preInstalledModelPaths.get(i).lastIndexOf("/") + 1);
}
lpChoosePreInstalledModel.setEntries(preInstalledModelNames);
lpChoosePreInstalledModel.setEntryValues(preInstalledModelPaths.toArray(new String[preInstalledModelPaths.size()]));
cbEnableCustomSettings =
(CheckBoxPreference) findPreference(getString(R.string.ENABLE_CUSTOM_SETTINGS_KEY));
etModelPath = (EditTextPreference) findPreference(getString(R.string.MODEL_PATH_KEY));
etModelPath.setTitle("Model Path (SDCard: " + Utils.getSDCardDirectory() + ")");
etLabelPath = (EditTextPreference) findPreference(getString(R.string.LABEL_PATH_KEY));
etImagePath = (ListPreference) findPreference(getString(R.string.IMAGE_PATH_KEY));
lpCPUThreadNum =
(ListPreference) findPreference(getString(R.string.CPU_THREAD_NUM_KEY));
lpCPUPowerMode =
(ListPreference) findPreference(getString(R.string.CPU_POWER_MODE_KEY));
lpInputColorFormat =
(ListPreference) findPreference(getString(R.string.INPUT_COLOR_FORMAT_KEY));
etInputShape = (EditTextPreference) findPreference(getString(R.string.INPUT_SHAPE_KEY));
etInputMean = (EditTextPreference) findPreference(getString(R.string.INPUT_MEAN_KEY));
etInputStd = (EditTextPreference) findPreference(getString(R.string.INPUT_STD_KEY));
etScoreThreshold = (EditTextPreference) findPreference(getString(R.string.SCORE_THRESHOLD_KEY));
}
private void reloadPreferenceAndUpdateUI() {
SharedPreferences sharedPreferences = getPreferenceScreen().getSharedPreferences();
boolean enableCustomSettings =
sharedPreferences.getBoolean(getString(R.string.ENABLE_CUSTOM_SETTINGS_KEY), false);
String modelPath = sharedPreferences.getString(getString(R.string.CHOOSE_PRE_INSTALLED_MODEL_KEY),
getString(R.string.MODEL_PATH_DEFAULT));
int modelIdx = lpChoosePreInstalledModel.findIndexOfValue(modelPath);
if (modelIdx >= 0 && modelIdx < preInstalledModelPaths.size()) {
if (!enableCustomSettings) {
SharedPreferences.Editor editor = sharedPreferences.edit();
editor.putString(getString(R.string.MODEL_PATH_KEY), preInstalledModelPaths.get(modelIdx));
editor.putString(getString(R.string.LABEL_PATH_KEY), preInstalledLabelPaths.get(modelIdx));
editor.putString(getString(R.string.IMAGE_PATH_KEY), preInstalledImagePaths.get(modelIdx));
editor.putString(getString(R.string.CPU_THREAD_NUM_KEY), preInstalledCPUThreadNums.get(modelIdx));
editor.putString(getString(R.string.CPU_POWER_MODE_KEY), preInstalledCPUPowerModes.get(modelIdx));
editor.putString(getString(R.string.INPUT_COLOR_FORMAT_KEY),
preInstalledInputColorFormats.get(modelIdx));
editor.putString(getString(R.string.INPUT_SHAPE_KEY), preInstalledInputShapes.get(modelIdx));
editor.putString(getString(R.string.INPUT_MEAN_KEY), preInstalledInputMeans.get(modelIdx));
editor.putString(getString(R.string.INPUT_STD_KEY), preInstalledInputStds.get(modelIdx));
editor.putString(getString(R.string.SCORE_THRESHOLD_KEY),
preInstalledScoreThresholds.get(modelIdx));
editor.apply();
}
lpChoosePreInstalledModel.setSummary(modelPath);
}
cbEnableCustomSettings.setChecked(enableCustomSettings);
etModelPath.setEnabled(enableCustomSettings);
etLabelPath.setEnabled(enableCustomSettings);
etImagePath.setEnabled(enableCustomSettings);
lpCPUThreadNum.setEnabled(enableCustomSettings);
lpCPUPowerMode.setEnabled(enableCustomSettings);
lpInputColorFormat.setEnabled(enableCustomSettings);
etInputShape.setEnabled(enableCustomSettings);
etInputMean.setEnabled(enableCustomSettings);
etInputStd.setEnabled(enableCustomSettings);
etScoreThreshold.setEnabled(enableCustomSettings);
modelPath = sharedPreferences.getString(getString(R.string.MODEL_PATH_KEY),
getString(R.string.MODEL_PATH_DEFAULT));
String labelPath = sharedPreferences.getString(getString(R.string.LABEL_PATH_KEY),
getString(R.string.LABEL_PATH_DEFAULT));
String imagePath = sharedPreferences.getString(getString(R.string.IMAGE_PATH_KEY),
getString(R.string.IMAGE_PATH_DEFAULT));
String cpuThreadNum = sharedPreferences.getString(getString(R.string.CPU_THREAD_NUM_KEY),
getString(R.string.CPU_THREAD_NUM_DEFAULT));
String cpuPowerMode = sharedPreferences.getString(getString(R.string.CPU_POWER_MODE_KEY),
getString(R.string.CPU_POWER_MODE_DEFAULT));
String inputColorFormat = sharedPreferences.getString(getString(R.string.INPUT_COLOR_FORMAT_KEY),
getString(R.string.INPUT_COLOR_FORMAT_DEFAULT));
String inputShape = sharedPreferences.getString(getString(R.string.INPUT_SHAPE_KEY),
getString(R.string.INPUT_SHAPE_DEFAULT));
String inputMean = sharedPreferences.getString(getString(R.string.INPUT_MEAN_KEY),
getString(R.string.INPUT_MEAN_DEFAULT));
String inputStd = sharedPreferences.getString(getString(R.string.INPUT_STD_KEY),
getString(R.string.INPUT_STD_DEFAULT));
String scoreThreshold = sharedPreferences.getString(getString(R.string.SCORE_THRESHOLD_KEY),
getString(R.string.SCORE_THRESHOLD_DEFAULT));
etModelPath.setSummary(modelPath);
etModelPath.setText(modelPath);
etLabelPath.setSummary(labelPath);
etLabelPath.setText(labelPath);
etImagePath.setSummary(imagePath);
etImagePath.setValue(imagePath);
lpCPUThreadNum.setValue(cpuThreadNum);
lpCPUThreadNum.setSummary(cpuThreadNum);
lpCPUPowerMode.setValue(cpuPowerMode);
lpCPUPowerMode.setSummary(cpuPowerMode);
lpInputColorFormat.setValue(inputColorFormat);
lpInputColorFormat.setSummary(inputColorFormat);
etInputShape.setSummary(inputShape);
etInputShape.setText(inputShape);
etInputMean.setSummary(inputMean);
etInputMean.setText(inputMean);
etInputStd.setSummary(inputStd);
etInputStd.setText(inputStd);
etScoreThreshold.setText(scoreThreshold);
etScoreThreshold.setSummary(scoreThreshold);
}
@Override
protected void onResume() {
super.onResume();
getPreferenceScreen().getSharedPreferences().registerOnSharedPreferenceChangeListener(this);
reloadPreferenceAndUpdateUI();
}
@Override
protected void onPause() {
super.onPause();
getPreferenceScreen().getSharedPreferences().unregisterOnSharedPreferenceChangeListener(this);
}
@Override
public void onSharedPreferenceChanged(SharedPreferences sharedPreferences, String key) {
if (key.equals(getString(R.string.CHOOSE_PRE_INSTALLED_MODEL_KEY))) {
SharedPreferences.Editor editor = sharedPreferences.edit();
editor.putBoolean(getString(R.string.ENABLE_CUSTOM_SETTINGS_KEY), false);
editor.commit();
}
reloadPreferenceAndUpdateUI();
}
}
package com.baidu.paddle.lite.demo.ocr;
import android.content.Context;
import android.graphics.Bitmap;
import android.graphics.Matrix;
import android.media.ExifInterface;
import android.os.Environment;
import java.io.*;
public class Utils {
private static final String TAG = Utils.class.getSimpleName();
public static void copyFileFromAssets(Context appCtx, String srcPath, String dstPath) {
if (srcPath.isEmpty() || dstPath.isEmpty()) {
return;
}
InputStream is = null;
OutputStream os = null;
try {
is = new BufferedInputStream(appCtx.getAssets().open(srcPath));
os = new BufferedOutputStream(new FileOutputStream(new File(dstPath)));
byte[] buffer = new byte[1024];
int length = 0;
while ((length = is.read(buffer)) != -1) {
os.write(buffer, 0, length);
}
} catch (FileNotFoundException e) {
e.printStackTrace();
} catch (IOException e) {
e.printStackTrace();
} finally {
// Guard against nulls: if open() threw, os and/or is were never assigned
try {
if (os != null) {
os.close();
}
if (is != null) {
is.close();
}
} catch (IOException e) {
e.printStackTrace();
}
}
}
public static void copyDirectoryFromAssets(Context appCtx, String srcDir, String dstDir) {
if (srcDir.isEmpty() || dstDir.isEmpty()) {
return;
}
try {
if (!new File(dstDir).exists()) {
new File(dstDir).mkdirs();
}
for (String fileName : appCtx.getAssets().list(srcDir)) {
String srcSubPath = srcDir + File.separator + fileName;
String dstSubPath = dstDir + File.separator + fileName;
if (new File(srcSubPath).isDirectory()) {
copyDirectoryFromAssets(appCtx, srcSubPath, dstSubPath);
} else {
copyFileFromAssets(appCtx, srcSubPath, dstSubPath);
}
}
} catch (Exception e) {
e.printStackTrace();
}
}
public static float[] parseFloatsFromString(String string, String delimiter) {
String[] pieces = string.trim().toLowerCase().split(delimiter);
float[] floats = new float[pieces.length];
for (int i = 0; i < pieces.length; i++) {
floats[i] = Float.parseFloat(pieces[i].trim());
}
return floats;
}
public static long[] parseLongsFromString(String string, String delimiter) {
String[] pieces = string.trim().toLowerCase().split(delimiter);
long[] longs = new long[pieces.length];
for (int i = 0; i < pieces.length; i++) {
longs[i] = Long.parseLong(pieces[i].trim());
}
return longs;
}
public static String getSDCardDirectory() {
return Environment.getExternalStorageDirectory().getAbsolutePath();
}
public static boolean isSupportedNPU() {
return false;
// String hardware = android.os.Build.HARDWARE;
// return hardware.equalsIgnoreCase("kirin810") || hardware.equalsIgnoreCase("kirin990");
}
public static Bitmap resizeWithStep(Bitmap bitmap, int maxLength, int step) {
int width = bitmap.getWidth();
int height = bitmap.getHeight();
int maxWH = Math.max(width, height);
float ratio = 1;
int newWidth = width;
int newHeight = height;
if (maxWH > maxLength) {
ratio = maxLength * 1.0f / maxWH;
newWidth = (int) Math.floor(ratio * width);
newHeight = (int) Math.floor(ratio * height);
}
newWidth = newWidth - newWidth % step;
if (newWidth == 0) {
newWidth = step;
}
newHeight = newHeight - newHeight % step;
if (newHeight == 0) {
newHeight = step;
}
return Bitmap.createScaledBitmap(bitmap, newWidth, newHeight, true);
}
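// Worked example for resizeWithStep: a 1000x750 bitmap with maxLength = 960 and
// step = 32 gives ratio = 0.96 -> 960x720, then snapping each side down to a
// multiple of 32 yields 960x704.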
public static Bitmap rotateBitmap(Bitmap bitmap, int orientation) {
Matrix matrix = new Matrix();
switch (orientation) {
case ExifInterface.ORIENTATION_NORMAL:
return bitmap;
case ExifInterface.ORIENTATION_FLIP_HORIZONTAL:
matrix.setScale(-1, 1);
break;
case ExifInterface.ORIENTATION_ROTATE_180:
matrix.setRotate(180);
break;
case ExifInterface.ORIENTATION_FLIP_VERTICAL:
matrix.setRotate(180);
matrix.postScale(-1, 1);
break;
case ExifInterface.ORIENTATION_TRANSPOSE:
matrix.setRotate(90);
matrix.postScale(-1, 1);
break;
case ExifInterface.ORIENTATION_ROTATE_90:
matrix.setRotate(90);
break;
case ExifInterface.ORIENTATION_TRANSVERSE:
matrix.setRotate(-90);
matrix.postScale(-1, 1);
break;
case ExifInterface.ORIENTATION_ROTATE_270:
matrix.setRotate(-90);
break;
default:
return bitmap;
}
try {
Bitmap bmRotated = Bitmap.createBitmap(bitmap, 0, 0, bitmap.getWidth(), bitmap.getHeight(), matrix, true);
bitmap.recycle();
return bmRotated;
}
catch (OutOfMemoryError e) {
e.printStackTrace();
return null;
}
}
}
<vector xmlns:android="http://schemas.android.com/apk/res/android"
xmlns:aapt="http://schemas.android.com/aapt"
android:width="108dp"
android:height="108dp"
android:viewportWidth="108"
android:viewportHeight="108">
<path
android:fillType="evenOdd"
android:pathData="M32,64C32,64 38.39,52.99 44.13,50.95C51.37,48.37 70.14,49.57 70.14,49.57L108.26,87.69L108,109.01L75.97,107.97L32,64Z"
android:strokeWidth="1"
android:strokeColor="#00000000">
<aapt:attr name="android:fillColor">
<gradient
android:endX="78.5885"
android:endY="90.9159"
android:startX="48.7653"
android:startY="61.0927"
android:type="linear">
<item
android:color="#44000000"
android:offset="0.0" />
<item
android:color="#00000000"
android:offset="1.0" />
</gradient>
</aapt:attr>
</path>
<path
android:fillColor="#FFFFFF"
android:fillType="nonZero"
android:pathData="M66.94,46.02L66.94,46.02C72.44,50.07 76,56.61 76,64L32,64C32,56.61 35.56,50.11 40.98,46.06L36.18,41.19C35.45,40.45 35.45,39.3 36.18,38.56C36.91,37.81 38.05,37.81 38.78,38.56L44.25,44.05C47.18,42.57 50.48,41.71 54,41.71C57.48,41.71 60.78,42.57 63.68,44.05L69.11,38.56C69.84,37.81 70.98,37.81 71.71,38.56C72.44,39.3 72.44,40.45 71.71,41.19L66.94,46.02ZM62.94,56.92C64.08,56.92 65,56.01 65,54.88C65,53.76 64.08,52.85 62.94,52.85C61.8,52.85 60.88,53.76 60.88,54.88C60.88,56.01 61.8,56.92 62.94,56.92ZM45.06,56.92C46.2,56.92 47.13,56.01 47.13,54.88C47.13,53.76 46.2,52.85 45.06,52.85C43.92,52.85 43,53.76 43,54.88C43,56.01 43.92,56.92 45.06,56.92Z"
android:strokeWidth="1"
android:strokeColor="#00000000" />
</vector>
<?xml version="1.0" encoding="utf-8"?>
<vector xmlns:android="http://schemas.android.com/apk/res/android"
android:width="108dp"
android:height="108dp"
android:viewportWidth="108"
android:viewportHeight="108">
<path
android:fillColor="#008577"
android:pathData="M0,0h108v108h-108z" />
<path
android:fillColor="#00000000"
android:pathData="M9,0L9,108"
android:strokeWidth="0.8"
android:strokeColor="#33FFFFFF" />
<path
android:fillColor="#00000000"
android:pathData="M19,0L19,108"
android:strokeWidth="0.8"
android:strokeColor="#33FFFFFF" />
<path
android:fillColor="#00000000"
android:pathData="M29,0L29,108"
android:strokeWidth="0.8"
android:strokeColor="#33FFFFFF" />
<path
android:fillColor="#00000000"
android:pathData="M39,0L39,108"
android:strokeWidth="0.8"
android:strokeColor="#33FFFFFF" />
<path
android:fillColor="#00000000"
android:pathData="M49,0L49,108"
android:strokeWidth="0.8"
android:strokeColor="#33FFFFFF" />
<path
android:fillColor="#00000000"
android:pathData="M59,0L59,108"
android:strokeWidth="0.8"
android:strokeColor="#33FFFFFF" />
<path
android:fillColor="#00000000"
android:pathData="M69,0L69,108"
android:strokeWidth="0.8"
android:strokeColor="#33FFFFFF" />
<path
android:fillColor="#00000000"
android:pathData="M79,0L79,108"
android:strokeWidth="0.8"
android:strokeColor="#33FFFFFF" />
<path
android:fillColor="#00000000"
android:pathData="M89,0L89,108"
android:strokeWidth="0.8"
android:strokeColor="#33FFFFFF" />
<path
android:fillColor="#00000000"
android:pathData="M99,0L99,108"
android:strokeWidth="0.8"
android:strokeColor="#33FFFFFF" />
<path
android:fillColor="#00000000"
android:pathData="M0,9L108,9"
android:strokeWidth="0.8"
android:strokeColor="#33FFFFFF" />
<path
android:fillColor="#00000000"
android:pathData="M0,19L108,19"
android:strokeWidth="0.8"
android:strokeColor="#33FFFFFF" />
<path
android:fillColor="#00000000"
android:pathData="M0,29L108,29"
android:strokeWidth="0.8"
android:strokeColor="#33FFFFFF" />
<path
android:fillColor="#00000000"
android:pathData="M0,39L108,39"
android:strokeWidth="0.8"
android:strokeColor="#33FFFFFF" />
<path
android:fillColor="#00000000"
android:pathData="M0,49L108,49"
android:strokeWidth="0.8"
android:strokeColor="#33FFFFFF" />
<path
android:fillColor="#00000000"
android:pathData="M0,59L108,59"
android:strokeWidth="0.8"
android:strokeColor="#33FFFFFF" />
<path
android:fillColor="#00000000"
android:pathData="M0,69L108,69"
android:strokeWidth="0.8"
android:strokeColor="#33FFFFFF" />
<path
android:fillColor="#00000000"
android:pathData="M0,79L108,79"
android:strokeWidth="0.8"
android:strokeColor="#33FFFFFF" />
<path
android:fillColor="#00000000"
android:pathData="M0,89L108,89"
android:strokeWidth="0.8"
android:strokeColor="#33FFFFFF" />
<path
android:fillColor="#00000000"
android:pathData="M0,99L108,99"
android:strokeWidth="0.8"
android:strokeColor="#33FFFFFF" />
<path
android:fillColor="#00000000"
android:pathData="M19,29L89,29"
android:strokeWidth="0.8"
android:strokeColor="#33FFFFFF" />
<path
android:fillColor="#00000000"
android:pathData="M19,39L89,39"
android:strokeWidth="0.8"
android:strokeColor="#33FFFFFF" />
<path
android:fillColor="#00000000"
android:pathData="M19,49L89,49"
android:strokeWidth="0.8"
android:strokeColor="#33FFFFFF" />
<path
android:fillColor="#00000000"
android:pathData="M19,59L89,59"
android:strokeWidth="0.8"
android:strokeColor="#33FFFFFF" />
<path
android:fillColor="#00000000"
android:pathData="M19,69L89,69"
android:strokeWidth="0.8"
android:strokeColor="#33FFFFFF" />
<path
android:fillColor="#00000000"
android:pathData="M19,79L89,79"
android:strokeWidth="0.8"
android:strokeColor="#33FFFFFF" />
<path
android:fillColor="#00000000"
android:pathData="M29,19L29,89"
android:strokeWidth="0.8"
android:strokeColor="#33FFFFFF" />
<path
android:fillColor="#00000000"
android:pathData="M39,19L39,89"
android:strokeWidth="0.8"
android:strokeColor="#33FFFFFF" />
<path
android:fillColor="#00000000"
android:pathData="M49,19L49,89"
android:strokeWidth="0.8"
android:strokeColor="#33FFFFFF" />
<path
android:fillColor="#00000000"
android:pathData="M59,19L59,89"
android:strokeWidth="0.8"
android:strokeColor="#33FFFFFF" />
<path
android:fillColor="#00000000"
android:pathData="M69,19L69,89"
android:strokeWidth="0.8"
android:strokeColor="#33FFFFFF" />
<path
android:fillColor="#00000000"
android:pathData="M79,19L79,89"
android:strokeWidth="0.8"
android:strokeColor="#33FFFFFF" />
</vector>
<?xml version="1.0" encoding="utf-8"?>
<androidx.constraintlayout.widget.ConstraintLayout xmlns:android="http://schemas.android.com/apk/res/android"
xmlns:app="http://schemas.android.com/apk/res-auto"
xmlns:tools="http://schemas.android.com/tools"
android:layout_width="match_parent"
android:layout_height="match_parent"
tools:context=".MainActivity">
<RelativeLayout
android:layout_width="match_parent"
android:layout_height="match_parent">
<LinearLayout
android:id="@+id/v_input_info"
android:layout_width="fill_parent"
android:layout_height="wrap_content"
android:layout_alignParentTop="true"
android:orientation="vertical">
<LinearLayout
android:id="@+id/btn_layout"
android:layout_width="fill_parent"
android:layout_height="wrap_content"
android:orientation="horizontal">
<Button
android:id="@+id/btn_load_model"
android:layout_width="0dp"
android:layout_height="wrap_content"
android:layout_weight="1"
android:onClick="btn_load_model_click"
android:text="加载模型" />
<Button
android:id="@+id/btn_run_model"
android:layout_width="0dp"
android:layout_height="wrap_content"
android:layout_weight="1"
android:onClick="btn_run_model_click"
android:text="运行模型" />
<Button
android:id="@+id/btn_take_photo"
android:layout_width="0dp"
android:layout_height="wrap_content"
android:layout_weight="1"
android:onClick="btn_take_photo_click"
android:text="拍照识别" />
<Button
android:id="@+id/btn_choice_img"
android:layout_width="0dp"
android:layout_height="wrap_content"
android:layout_weight="1"
android:onClick="btn_choice_img_click"
android:text="选取图片" />
</LinearLayout>
<TextView
android:id="@+id/tv_input_setting"
android:layout_width="wrap_content"
android:layout_height="wrap_content"
android:scrollbars="vertical"
android:layout_marginLeft="12dp"
android:layout_marginRight="12dp"
android:layout_marginTop="10dp"
android:layout_marginBottom="5dp"
android:lineSpacingExtra="4dp"
android:singleLine="false"
android:maxLines="6"
android:text=""/>
<TextView
android:id="@+id/tv_model_img_status"
android:layout_width="wrap_content"
android:layout_height="wrap_content"
android:scrollbars="vertical"
android:layout_marginLeft="12dp"
android:layout_marginRight="12dp"
android:layout_marginTop="-5dp"
android:layout_marginBottom="5dp"
android:lineSpacingExtra="4dp"
android:singleLine="false"
android:maxLines="6"
android:text="STATUS: ok"/>
</LinearLayout>
<RelativeLayout
android:layout_width="match_parent"
android:layout_height="match_parent"
android:layout_above="@+id/v_output_info"
android:layout_below="@+id/v_input_info">
<ImageView
android:id="@+id/iv_input_image"
android:layout_width="400dp"
android:layout_height="400dp"
android:layout_centerHorizontal="true"
android:layout_centerVertical="true"
android:layout_marginLeft="12dp"
android:layout_marginRight="12dp"
android:layout_marginTop="5dp"
android:layout_marginBottom="5dp"
android:adjustViewBounds="true"
android:scaleType="fitCenter"/>
</RelativeLayout>
<RelativeLayout
android:id="@+id/v_output_info"
android:layout_width="wrap_content"
android:layout_height="wrap_content"
android:layout_alignParentBottom="true"
android:layout_centerHorizontal="true">
<TextView
android:id="@+id/tv_output_result"
android:layout_width="wrap_content"
android:layout_height="wrap_content"
android:layout_alignParentTop="true"
android:layout_centerHorizontal="true"
android:layout_centerVertical="true"
android:scrollbars="vertical"
android:layout_marginLeft="12dp"
android:layout_marginRight="12dp"
android:layout_marginTop="5dp"
android:layout_marginBottom="5dp"
android:textAlignment="center"
android:lineSpacingExtra="5dp"
android:singleLine="false"
android:maxLines="5"
android:text=""/>
<TextView
android:id="@+id/tv_inference_time"
android:layout_width="wrap_content"
android:layout_height="wrap_content"
android:layout_below="@+id/tv_output_result"
android:layout_centerHorizontal="true"
android:layout_centerVertical="true"
android:textAlignment="center"
android:layout_marginLeft="12dp"
android:layout_marginRight="12dp"
android:layout_marginTop="5dp"
android:layout_marginBottom="10dp"
android:text=""/>
</RelativeLayout>
</RelativeLayout>
</androidx.constraintlayout.widget.ConstraintLayout>
<?xml version="1.0" encoding="utf-8"?>
<!-- for MiniActivity Use Only -->
<androidx.constraintlayout.widget.ConstraintLayout xmlns:android="http://schemas.android.com/apk/res/android"
xmlns:app="http://schemas.android.com/apk/res-auto"
xmlns:tools="http://schemas.android.com/tools"
android:layout_width="match_parent"
android:layout_height="match_parent"
app:layout_constraintLeft_toLeftOf="parent"
app:layout_constraintLeft_toRightOf="parent"
tools:context=".MainActivity">
<TextView
android:id="@+id/sample_text"
android:layout_width="0dp"
android:layout_height="wrap_content"
android:text="Hello World!"
app:layout_constraintLeft_toLeftOf="parent"
app:layout_constraintRight_toRightOf="parent"
app:layout_constraintTop_toBottomOf="@id/imageView"
android:scrollbars="vertical"
/>
<ImageView
android:id="@+id/imageView"
android:layout_width="wrap_content"
android:layout_height="wrap_content"
android:paddingTop="20dp"
android:paddingBottom="20dp"
app:layout_constraintBottom_toTopOf="@id/imageView"
app:layout_constraintLeft_toLeftOf="parent"
app:layout_constraintRight_toRightOf="parent"
app:layout_constraintTop_toTopOf="parent"
tools:srcCompat="@tools:sample/avatars" />
<Button
android:id="@+id/button"
android:layout_width="wrap_content"
android:layout_height="wrap_content"
android:layout_marginBottom="4dp"
android:text="Button"
app:layout_constraintBottom_toBottomOf="parent"
app:layout_constraintLeft_toLeftOf="parent"
app:layout_constraintRight_toRightOf="parent"
tools:layout_editor_absoluteX="161dp" />
</androidx.constraintlayout.widget.ConstraintLayout>
<menu xmlns:android="http://schemas.android.com/apk/res/android"
xmlns:app="http://schemas.android.com/apk/res-auto">
<group>
<item
android:id="@+id/settings"
android:title="Settings..."
app:showAsAction="withText"/>
</group>
</menu>
<?xml version="1.0" encoding="utf-8"?>
<adaptive-icon xmlns:android="http://schemas.android.com/apk/res/android">
<background android:drawable="@drawable/ic_launcher_background" />
<foreground android:drawable="@drawable/ic_launcher_foreground" />
</adaptive-icon>
<?xml version="1.0" encoding="utf-8"?>
<adaptive-icon xmlns:android="http://schemas.android.com/apk/res/android">
<background android:drawable="@drawable/ic_launcher_background" />
<foreground android:drawable="@drawable/ic_launcher_foreground" />
</adaptive-icon>
<?xml version="1.0" encoding="utf-8"?>
<resources>
<string-array name="image_name_entries">
<item>0.jpg</item>
<item>90.jpg</item>
<item>180.jpg</item>
<item>270.jpg</item>
</string-array>
<string-array name="image_name_values">
<item>images/0.jpg</item>
<item>images/90.jpg</item>
<item>images/180.jpg</item>
<item>images/270.jpg</item>
</string-array>
<string-array name="cpu_thread_num_entries">
<item>1 threads</item>
<item>2 threads</item>
<item>4 threads</item>
<item>8 threads</item>
</string-array>
<string-array name="cpu_thread_num_values">
<item>1</item>
<item>2</item>
<item>4</item>
<item>8</item>
</string-array>
<string-array name="cpu_power_mode_entries">
<item>HIGH(only big cores)</item>
<item>LOW(only LITTLE cores)</item>
<item>FULL(all cores)</item>
<item>NO_BIND(depends on system)</item>
<item>RAND_HIGH</item>
<item>RAND_LOW</item>
</string-array>
<string-array name="cpu_power_mode_values">
<item>LITE_POWER_HIGH</item>
<item>LITE_POWER_LOW</item>
<item>LITE_POWER_FULL</item>
<item>LITE_POWER_NO_BIND</item>
<item>LITE_POWER_RAND_HIGH</item>
<item>LITE_POWER_RAND_LOW</item>
</string-array>
<string-array name="input_color_format_entries">
<item>BGR color format</item>
<item>RGB color format</item>
</string-array>
<string-array name="input_color_format_values">
<item>BGR</item>
<item>RGB</item>
</string-array>
</resources>
<?xml version="1.0" encoding="utf-8"?>
<resources>
<color name="colorPrimary">#008577</color>
<color name="colorPrimaryDark">#00574B</color>
<color name="colorAccent">#D81B60</color>
</resources>
<resources>
<string name="app_name">OCR Chinese</string>
<string name="CHOOSE_PRE_INSTALLED_MODEL_KEY">CHOOSE_PRE_INSTALLED_MODEL_KEY</string>
<string name="ENABLE_CUSTOM_SETTINGS_KEY">ENABLE_CUSTOM_SETTINGS_KEY</string>
<string name="MODEL_PATH_KEY">MODEL_PATH_KEY</string>
<string name="LABEL_PATH_KEY">LABEL_PATH_KEY</string>
<string name="IMAGE_PATH_KEY">IMAGE_PATH_KEY</string>
<string name="CPU_THREAD_NUM_KEY">CPU_THREAD_NUM_KEY</string>
<string name="CPU_POWER_MODE_KEY">CPU_POWER_MODE_KEY</string>
<string name="INPUT_COLOR_FORMAT_KEY">INPUT_COLOR_FORMAT_KEY</string>
<string name="INPUT_SHAPE_KEY">INPUT_SHAPE_KEY</string>
<string name="INPUT_MEAN_KEY">INPUT_MEAN_KEY</string>
<string name="INPUT_STD_KEY">INPUT_STD_KEY</string>
<string name="SCORE_THRESHOLD_KEY">SCORE_THRESHOLD_KEY</string>
<string name="MODEL_PATH_DEFAULT">models/ocr_v2_for_cpu</string>
<string name="LABEL_PATH_DEFAULT">labels/ppocr_keys_v1.txt</string>
<string name="IMAGE_PATH_DEFAULT">images/0.jpg</string>
<string name="CPU_THREAD_NUM_DEFAULT">4</string>
<string name="CPU_POWER_MODE_DEFAULT">LITE_POWER_HIGH</string>
<string name="INPUT_COLOR_FORMAT_DEFAULT">BGR</string>
<string name="INPUT_SHAPE_DEFAULT">1,3,960</string>
<string name="INPUT_MEAN_DEFAULT">0.485, 0.456, 0.406</string>
<string name="INPUT_STD_DEFAULT">0.229,0.224,0.225</string>
<string name="SCORE_THRESHOLD_DEFAULT">0.1</string>
</resources>
<resources>
<!-- Base application theme. -->
<style name="AppTheme" parent="Theme.AppCompat.Light.DarkActionBar">
<!-- Customize your theme here. -->
<item name="colorPrimary">@color/colorPrimary</item>
<item name="colorPrimaryDark">@color/colorPrimaryDark</item>
<item name="colorAccent">@color/colorAccent</item>
<item name="actionOverflowMenuStyle">@style/OverflowMenuStyle</item>
</style>
<style name="OverflowMenuStyle" parent="Widget.AppCompat.Light.PopupMenu.Overflow">
<item name="overlapAnchor">false</item>
</style>
<style name="AppTheme.NoActionBar">
<item name="windowActionBar">false</item>
<item name="windowNoTitle">true</item>
</style>
<style name="AppTheme.AppBarOverlay" parent="ThemeOverlay.AppCompat.Dark.ActionBar"/>
<style name="AppTheme.PopupOverlay" parent="ThemeOverlay.AppCompat.Light"/>
</resources>
<?xml version="1.0" encoding="utf-8"?>
<paths xmlns:android="http://schemas.android.com/apk/res/android">
<external-files-path name="my_images" path="Pictures" />
</paths>
<?xml version="1.0" encoding="utf-8"?>
<PreferenceScreen xmlns:android="http://schemas.android.com/apk/res/android" >
<PreferenceCategory android:title="Model Settings">
<ListPreference
android:defaultValue="@string/MODEL_PATH_DEFAULT"
android:key="@string/CHOOSE_PRE_INSTALLED_MODEL_KEY"
android:negativeButtonText="@null"
android:positiveButtonText="@null"
android:title="Choose pre-installed models" />
<CheckBoxPreference
android:defaultValue="false"
android:key="@string/ENABLE_CUSTOM_SETTINGS_KEY"
android:summaryOn="Enable"
android:summaryOff="Disable"
android:title="Enable custom settings"/>
<EditTextPreference
android:key="@string/MODEL_PATH_KEY"
android:defaultValue="@string/MODEL_PATH_DEFAULT"
android:title="Model Path" />
<EditTextPreference
android:key="@string/LABEL_PATH_KEY"
android:defaultValue="@string/LABEL_PATH_DEFAULT"
android:title="Label Path" />
<ListPreference
android:key="@string/IMAGE_PATH_KEY"
android:defaultValue="@string/IMAGE_PATH_DEFAULT"
android:entries="@array/image_name_entries"
android:entryValues="@array/image_name_values"
android:title="Image Path" />
</PreferenceCategory>
<PreferenceCategory android:title="CPU Settings">
<ListPreference
android:defaultValue="@string/CPU_THREAD_NUM_DEFAULT"
android:key="@string/CPU_THREAD_NUM_KEY"
android:negativeButtonText="@null"
android:positiveButtonText="@null"
android:title="CPU Thread Num"
android:entries="@array/cpu_thread_num_entries"
android:entryValues="@array/cpu_thread_num_values"/>
<ListPreference
android:defaultValue="@string/CPU_POWER_MODE_DEFAULT"
android:key="@string/CPU_POWER_MODE_KEY"
android:negativeButtonText="@null"
android:positiveButtonText="@null"
android:title="CPU Power Mode"
android:entries="@array/cpu_power_mode_entries"
android:entryValues="@array/cpu_power_mode_values"/>
</PreferenceCategory>
<PreferenceCategory android:title="Input Settings">
<ListPreference
android:defaultValue="@string/INPUT_COLOR_FORMAT_DEFAULT"
android:key="@string/INPUT_COLOR_FORMAT_KEY"
android:negativeButtonText="@null"
android:positiveButtonText="@null"
android:title="Input Color Format: BGR or RGB"
android:entries="@array/input_color_format_entries"
android:entryValues="@array/input_color_format_values"/>
<EditTextPreference
android:key="@string/INPUT_SHAPE_KEY"
android:defaultValue="@string/INPUT_SHAPE_DEFAULT"
android:title="Input Shape: (1,1,max_width_height) or (1,3,max_width_height)" />
<EditTextPreference
android:key="@string/INPUT_MEAN_KEY"
android:defaultValue="@string/INPUT_MEAN_DEFAULT"
android:title="Input Mean: (channel/255-mean)/std" />
<EditTextPreference
android:key="@string/INPUT_STD_KEY"
android:defaultValue="@string/INPUT_STD_DEFAULT"
android:title="Input Std: (channel/255-mean)/std" />
</PreferenceCategory>
<PreferenceCategory android:title="Output Settings">
<EditTextPreference
android:key="@string/SCORE_THRESHOLD_KEY"
android:defaultValue="@string/SCORE_THRESHOLD_DEFAULT"
android:title="Score Threshold" />
</PreferenceCategory>
</PreferenceScreen>
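The Input Mean/Std preferences above parameterize the usual per-channel normalization applied to the input image before inference. A minimal sketch of that formula, using the default values from strings.xml (the `normalize` helper is illustrative, not part of the demo):

```python
import numpy as np

# Defaults from strings.xml: INPUT_MEAN_DEFAULT / INPUT_STD_DEFAULT.
MEAN = np.array([0.485, 0.456, 0.406], dtype=np.float32)
STD = np.array([0.229, 0.224, 0.225], dtype=np.float32)

def normalize(img_uint8):
    """Apply (channel / 255 - mean) / std per channel to an HWC uint8 image."""
    return (img_uint8.astype(np.float32) / 255.0 - MEAN) / STD
```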
package com.baidu.paddle.lite.demo.ocr;
import org.junit.Test;
import static org.junit.Assert.*;
/**
* Example local unit test, which will execute on the development machine (host).
*
* @see <a href="http://d.android.com/tools/testing">Testing documentation</a>
*/
public class ExampleUnitTest {
@Test
public void addition_isCorrect() {
assertEquals(4, 2 + 2);
}
}
// Top-level build file where you can add configuration options common to all sub-projects/modules.
buildscript {
repositories {
google()
jcenter()
}
dependencies {
classpath 'com.android.tools.build:gradle:4.1.2'
// NOTE: Do not place your application dependencies here; they belong
// in the individual module build.gradle files
}
}
allprojects {
repositories {
google()
jcenter()
}
}
task clean(type: Delete) {
delete rootProject.buildDir
}
# Project-wide Gradle settings.
# IDE (e.g. Android Studio) users:
# Gradle settings configured through the IDE *will override*
# any settings specified in this file.
# For more details on how to configure your build environment visit
# http://www.gradle.org/docs/current/userguide/build_environment.html
# Specifies the JVM arguments used for the daemon process.
# The setting is particularly useful for tweaking memory settings.
org.gradle.jvmargs=-Xmx1536m
# When configured, Gradle will run in incubating parallel mode.
# This option should only be used with decoupled projects. More details, visit
# http://www.gradle.org/docs/current/userguide/multi_project_builds.html#sec:decoupled_projects
# org.gradle.parallel=true
android.useAndroidX=true
#Thu Feb 04 20:28:08 CST 2021
distributionBase=GRADLE_USER_HOME
distributionPath=wrapper/dists
zipStoreBase=GRADLE_USER_HOME
zipStorePath=wrapper/dists
distributionUrl=https\://services.gradle.org/distributions/gradle-6.5-bin.zip
#!/usr/bin/env sh
##############################################################################
##
## Gradle start up script for UN*X
##
##############################################################################
# Attempt to set APP_HOME
# Resolve links: $0 may be a link
PRG="$0"
# Need this for relative symlinks.
while [ -h "$PRG" ] ; do
ls=`ls -ld "$PRG"`
link=`expr "$ls" : '.*-> \(.*\)$'`
if expr "$link" : '/.*' > /dev/null; then
PRG="$link"
else
PRG=`dirname "$PRG"`"/$link"
fi
done
SAVED="`pwd`"
cd "`dirname \"$PRG\"`/" >/dev/null
APP_HOME="`pwd -P`"
cd "$SAVED" >/dev/null
APP_NAME="Gradle"
APP_BASE_NAME=`basename "$0"`
# Add default JVM options here. You can also use JAVA_OPTS and GRADLE_OPTS to pass JVM options to this script.
DEFAULT_JVM_OPTS=""
# Use the maximum available, or set MAX_FD != -1 to use that value.
MAX_FD="maximum"
warn () {
echo "$*"
}
die () {
echo
echo "$*"
echo
exit 1
}
# OS specific support (must be 'true' or 'false').
cygwin=false
msys=false
darwin=false
nonstop=false
case "`uname`" in
CYGWIN* )
cygwin=true
;;
Darwin* )
darwin=true
;;
MINGW* )
msys=true
;;
NONSTOP* )
nonstop=true
;;
esac
CLASSPATH=$APP_HOME/gradle/wrapper/gradle-wrapper.jar
# Determine the Java command to use to start the JVM.
if [ -n "$JAVA_HOME" ] ; then
if [ -x "$JAVA_HOME/jre/sh/java" ] ; then
# IBM's JDK on AIX uses strange locations for the executables
JAVACMD="$JAVA_HOME/jre/sh/java"
else
JAVACMD="$JAVA_HOME/bin/java"
fi
if [ ! -x "$JAVACMD" ] ; then
die "ERROR: JAVA_HOME is set to an invalid directory: $JAVA_HOME
Please set the JAVA_HOME variable in your environment to match the
location of your Java installation."
fi
else
JAVACMD="java"
which java >/dev/null 2>&1 || die "ERROR: JAVA_HOME is not set and no 'java' command could be found in your PATH.
Please set the JAVA_HOME variable in your environment to match the
location of your Java installation."
fi
# Increase the maximum file descriptors if we can.
if [ "$cygwin" = "false" -a "$darwin" = "false" -a "$nonstop" = "false" ] ; then
MAX_FD_LIMIT=`ulimit -H -n`
if [ $? -eq 0 ] ; then
if [ "$MAX_FD" = "maximum" -o "$MAX_FD" = "max" ] ; then
MAX_FD="$MAX_FD_LIMIT"
fi
ulimit -n $MAX_FD
if [ $? -ne 0 ] ; then
warn "Could not set maximum file descriptor limit: $MAX_FD"
fi
else
warn "Could not query maximum file descriptor limit: $MAX_FD_LIMIT"
fi
fi
# For Darwin, add options to specify how the application appears in the dock
if $darwin; then
GRADLE_OPTS="$GRADLE_OPTS \"-Xdock:name=$APP_NAME\" \"-Xdock:icon=$APP_HOME/media/gradle.icns\""
fi
# For Cygwin, switch paths to Windows format before running java
if $cygwin ; then
APP_HOME=`cygpath --path --mixed "$APP_HOME"`
CLASSPATH=`cygpath --path --mixed "$CLASSPATH"`
JAVACMD=`cygpath --unix "$JAVACMD"`
# We build the pattern for arguments to be converted via cygpath
ROOTDIRSRAW=`find -L / -maxdepth 1 -mindepth 1 -type d 2>/dev/null`
SEP=""
for dir in $ROOTDIRSRAW ; do
ROOTDIRS="$ROOTDIRS$SEP$dir"
SEP="|"
done
OURCYGPATTERN="(^($ROOTDIRS))"
# Add a user-defined pattern to the cygpath arguments
if [ "$GRADLE_CYGPATTERN" != "" ] ; then
OURCYGPATTERN="$OURCYGPATTERN|($GRADLE_CYGPATTERN)"
fi
# Now convert the arguments - kludge to limit ourselves to /bin/sh
i=0
for arg in "$@" ; do
CHECK=`echo "$arg"|egrep -c "$OURCYGPATTERN" -`
CHECK2=`echo "$arg"|egrep -c "^-"` ### Determine if an option
if [ $CHECK -ne 0 ] && [ $CHECK2 -eq 0 ] ; then ### Added a condition
eval `echo args$i`=`cygpath --path --ignore --mixed "$arg"`
else
eval `echo args$i`="\"$arg\""
fi
i=$((i+1))
done
case $i in
(0) set -- ;;
(1) set -- "$args0" ;;
(2) set -- "$args0" "$args1" ;;
(3) set -- "$args0" "$args1" "$args2" ;;
(4) set -- "$args0" "$args1" "$args2" "$args3" ;;
(5) set -- "$args0" "$args1" "$args2" "$args3" "$args4" ;;
(6) set -- "$args0" "$args1" "$args2" "$args3" "$args4" "$args5" ;;
(7) set -- "$args0" "$args1" "$args2" "$args3" "$args4" "$args5" "$args6" ;;
(8) set -- "$args0" "$args1" "$args2" "$args3" "$args4" "$args5" "$args6" "$args7" ;;
(9) set -- "$args0" "$args1" "$args2" "$args3" "$args4" "$args5" "$args6" "$args7" "$args8" ;;
esac
fi
# Escape application args
save () {
for i do printf %s\\n "$i" | sed "s/'/'\\\\''/g;1s/^/'/;\$s/\$/' \\\\/" ; done
echo " "
}
APP_ARGS=$(save "$@")
# Collect all arguments for the java command, following the shell quoting and substitution rules
eval set -- $DEFAULT_JVM_OPTS $JAVA_OPTS $GRADLE_OPTS "\"-Dorg.gradle.appname=$APP_BASE_NAME\"" -classpath "\"$CLASSPATH\"" org.gradle.wrapper.GradleWrapperMain "$APP_ARGS"
# by default we should be in the correct project dir, but when run from Finder on Mac, the cwd is wrong
if [ "$(uname)" = "Darwin" ] && [ "$HOME" = "$PWD" ]; then
cd "$(dirname "$0")"
fi
exec "$JAVACMD" "$@"
@if "%DEBUG%" == "" @echo off
@rem ##########################################################################
@rem
@rem Gradle startup script for Windows
@rem
@rem ##########################################################################
@rem Set local scope for the variables with windows NT shell
if "%OS%"=="Windows_NT" setlocal
set DIRNAME=%~dp0
if "%DIRNAME%" == "" set DIRNAME=.
set APP_BASE_NAME=%~n0
set APP_HOME=%DIRNAME%
@rem Add default JVM options here. You can also use JAVA_OPTS and GRADLE_OPTS to pass JVM options to this script.
set DEFAULT_JVM_OPTS=
@rem Find java.exe
if defined JAVA_HOME goto findJavaFromJavaHome
set JAVA_EXE=java.exe
%JAVA_EXE% -version >NUL 2>&1
if "%ERRORLEVEL%" == "0" goto init
echo.
echo ERROR: JAVA_HOME is not set and no 'java' command could be found in your PATH.
echo.
echo Please set the JAVA_HOME variable in your environment to match the
echo location of your Java installation.
goto fail
:findJavaFromJavaHome
set JAVA_HOME=%JAVA_HOME:"=%
set JAVA_EXE=%JAVA_HOME%/bin/java.exe
if exist "%JAVA_EXE%" goto init
echo.
echo ERROR: JAVA_HOME is set to an invalid directory: %JAVA_HOME%
echo.
echo Please set the JAVA_HOME variable in your environment to match the
echo location of your Java installation.
goto fail
:init
@rem Get command-line arguments, handling Windows variants
if not "%OS%" == "Windows_NT" goto win9xME_args
:win9xME_args
@rem Slurp the command line arguments.
set CMD_LINE_ARGS=
set _SKIP=2
:win9xME_args_slurp
if "x%~1" == "x" goto execute
set CMD_LINE_ARGS=%*
:execute
@rem Setup the command line
set CLASSPATH=%APP_HOME%\gradle\wrapper\gradle-wrapper.jar
@rem Execute Gradle
"%JAVA_EXE%" %DEFAULT_JVM_OPTS% %JAVA_OPTS% %GRADLE_OPTS% "-Dorg.gradle.appname=%APP_BASE_NAME%" -classpath "%CLASSPATH%" org.gradle.wrapper.GradleWrapperMain %CMD_LINE_ARGS%
:end
@rem End local scope for the variables with windows NT shell
if "%ERRORLEVEL%"=="0" goto mainEnd
:fail
rem Set variable GRADLE_EXIT_CONSOLE if you need the _script_ return code instead of
rem the _cmd.exe /c_ return code!
if not "" == "%GRADLE_EXIT_CONSOLE%" exit 1
exit /b 1
:mainEnd
if "%OS%"=="Windows_NT" endlocal
:omega
...
@@ -40,6 +40,7 @@ endif()
 if (WIN32)
   include_directories("${PADDLE_LIB}/paddle/fluid/inference")
   include_directories("${PADDLE_LIB}/paddle/include")
+  link_directories("${PADDLE_LIB}/paddle/lib")
   link_directories("${PADDLE_LIB}/paddle/fluid/inference")
   find_package(OpenCV REQUIRED PATHS ${OPENCV_DIR}/build/ NO_DEFAULT_PATH)
@@ -140,22 +141,22 @@ else()
   endif ()
 endif()
-# Note: libpaddle_inference_api.so/a must put before libpaddle_fluid.so/a
+# Note: libpaddle_inference_api.so/a must put before libpaddle_inference.so/a
 if(WITH_STATIC_LIB)
   if(WIN32)
     set(DEPS
-      ${PADDLE_LIB}/paddle/lib/paddle_fluid${CMAKE_STATIC_LIBRARY_SUFFIX})
+      ${PADDLE_LIB}/paddle/lib/paddle_inference${CMAKE_STATIC_LIBRARY_SUFFIX})
   else()
     set(DEPS
-      ${PADDLE_LIB}/paddle/lib/libpaddle_fluid${CMAKE_STATIC_LIBRARY_SUFFIX})
+      ${PADDLE_LIB}/paddle/lib/libpaddle_inference${CMAKE_STATIC_LIBRARY_SUFFIX})
   endif()
 else()
   if(WIN32)
     set(DEPS
-      ${PADDLE_LIB}/paddle/lib/paddle_fluid${CMAKE_SHARED_LIBRARY_SUFFIX})
+      ${PADDLE_LIB}/paddle/lib/paddle_inference${CMAKE_SHARED_LIBRARY_SUFFIX})
   else()
     set(DEPS
-      ${PADDLE_LIB}/paddle/lib/libpaddle_fluid${CMAKE_SHARED_LIBRARY_SUFFIX})
+      ${PADDLE_LIB}/paddle/lib/libpaddle_inference${CMAKE_SHARED_LIBRARY_SUFFIX})
   endif()
 endif(WITH_STATIC_LIB)
...
@@ -49,6 +49,8 @@ public:
     this->det_db_unclip_ratio = stod(config_map_["det_db_unclip_ratio"]);
+    this->use_polygon_score = bool(stoi(config_map_["use_polygon_score"]));
     this->det_model_dir.assign(config_map_["det_model_dir"]);
     this->rec_model_dir.assign(config_map_["rec_model_dir"]);
@@ -86,6 +88,8 @@ public:
   double det_db_unclip_ratio = 2.0;
+  bool use_polygon_score = false;
   std::string det_model_dir;
   std::string rec_model_dir;
...
@@ -44,7 +44,8 @@ public:
             const bool &use_mkldnn, const int &max_side_len,
             const double &det_db_thresh,
             const double &det_db_box_thresh,
-            const double &det_db_unclip_ratio, const bool &visualize,
+            const double &det_db_unclip_ratio,
+            const bool &use_polygon_score, const bool &visualize,
             const bool &use_tensorrt, const bool &use_fp16) {
     this->use_gpu_ = use_gpu;
     this->gpu_id_ = gpu_id;
@@ -57,6 +58,7 @@ public:
     this->det_db_thresh_ = det_db_thresh;
     this->det_db_box_thresh_ = det_db_box_thresh;
     this->det_db_unclip_ratio_ = det_db_unclip_ratio;
+    this->use_polygon_score_ = use_polygon_score;
     this->visualize_ = visualize;
     this->use_tensorrt_ = use_tensorrt;
@@ -85,6 +87,7 @@ private:
   double det_db_thresh_ = 0.3;
   double det_db_box_thresh_ = 0.5;
   double det_db_unclip_ratio_ = 2.0;
+  bool use_polygon_score_ = false;
   bool visualize_ = true;
   bool use_tensorrt_ = false;
...
@@ -51,10 +51,12 @@ public:
                float &ssid);
   float BoxScoreFast(std::vector<std::vector<float>> box_array, cv::Mat pred);
+  float PolygonScoreAcc(std::vector<cv::Point> contour, cv::Mat pred);
   std::vector<std::vector<std::vector<int>>>
   BoxesFromBitmap(const cv::Mat pred, const cv::Mat bitmap,
-                  const float &box_thresh, const float &det_db_unclip_ratio);
+                  const float &box_thresh, const float &det_db_unclip_ratio,
+                  const bool &use_polygon_score);
   std::vector<std::vector<std::vector<int>>>
   FilterTagDetRes(std::vector<std::vector<std::vector<int>>> boxes,
...
@@ -74,9 +74,10 @@ opencv3/
 * There are two ways to obtain the Paddle inference library; both are described in detail below.
 #### 1.2.1 Direct download and installation
-* The [Paddle inference library site](https://www.paddlepaddle.org.cn/documentation/docs/zh/develop/guides/05_inference_deployment/inference/build_and_install_lib_cn.html) provides Linux inference libraries for different CUDA versions; check the site and pick a suitable version.
+* The [Paddle inference library site](https://www.paddlepaddle.org.cn/documentation/docs/zh/advanced_guide/inference_deployment/inference/build_and_install_lib_cn.html) provides Linux inference libraries for different CUDA versions; check the site and pick a suitable version (*a library built from paddle >= 2.0.1 is recommended*).
 * After downloading, extract it as shown below.
@@ -130,8 +131,6 @@ build/paddle_inference_install_dir/
 Here `paddle` is the Paddle library needed for C++ inference, and `version.txt` records the version of the inference library.
 ## 2 Getting started
 ### 2.1 Export the model as an inference model
@@ -184,7 +183,7 @@ cmake .. \
 make -j
 ```
-`OPENCV_DIR` is the OpenCV install path; `LIB_DIR` is the downloaded (`paddle_inference` folder) or self-compiled Paddle inference library path (`build/paddle_inference_install_dir` folder); `CUDA_LIB_DIR` is the CUDA library path, `/usr/local/cuda/lib64` in docker; `CUDNN_LIB_DIR` is the cuDNN library path, `/usr/lib/x86_64-linux-gnu/` in docker.
+`OPENCV_DIR` is the OpenCV install path; `LIB_DIR` is the downloaded (`paddle_inference` folder) or self-compiled Paddle inference library path (`build/paddle_inference_install_dir` folder); `CUDA_LIB_DIR` is the CUDA library path, which is `/usr/local/cuda/lib64` in docker; `CUDNN_LIB_DIR` is the cuDNN library path, which is `/usr/lib/x86_64-linux-gnu/` in docker.
 * After compilation, an executable named `ocr_system` is generated in the `build` folder.
@@ -212,6 +211,7 @@ max_side_len 960 # When the long side of the input image exceeds 960, the image is scaled proportionally
 det_db_thresh 0.3 # Filters the binarized map of the DB prediction; values in the 0.-0.3 range have little visible effect on the result
 det_db_box_thresh 0.5 # DB post-processing threshold for filtering boxes; reduce it if boxes are missed
 det_db_unclip_ratio 1.6 # Controls how tightly the box fits the text; the smaller the value, the closer the box is to the text
+use_polygon_score 1 # Whether to compute the bbox score with a polygon; 0 uses a rectangle. Rectangles are faster to compute; polygons are more accurate for curved text regions.
 det_model_dir ./inference/det_db # Path of the detection inference model
 # cls config
@@ -232,7 +232,7 @@ visualize 1 # Whether to visualize the results; when set to 1, the prediction is saved in the current folder
 The detection results are finally printed to the screen as follows.
 <div align="center">
-    <img src="../imgs/cpp_infer_pred_12.png" width="600">
+    <img src="./imgs/cpp_infer_pred_12.png" width="600">
 </div>
...
@@ -91,8 +91,8 @@ tar -xf paddle_inference.tgz
 Finally you can see the following files in the folder of `paddle_inference/`.
 #### 1.2.2 Compile from the source code
-* If you want to get the latest Paddle inference library features, you can download the latest code from the Paddle github repository and compile the inference library from the source code.
+* If you want to get the latest Paddle inference library features, you can download the latest code from the Paddle github repository and compile the inference library from the source code. It is recommended to use an inference library with a paddle version greater than or equal to 2.0.1.
-* You can refer to the [Paddle inference library](https://www.paddlepaddle.org.cn/documentation/docs/en/develop/guides/05_inference_deployment/inference/build_and_install_lib_en.html) page to get the Paddle source code from github, then compile it to generate the latest inference library. The method of using git to access the code is as follows.
+* You can refer to the [Paddle inference library](https://www.paddlepaddle.org.cn/documentation/docs/en/advanced_guide/inference_deployment/inference/build_and_install_lib_en.html) page to get the Paddle source code from github, then compile it to generate the latest inference library. The method of using git to access the code is as follows.
 ```shell
@@ -217,6 +217,7 @@ max_side_len 960 # Limit the maximum image height and width to 960
 det_db_thresh 0.3 # Used to filter the binarized image of the DB prediction; setting it in the 0.-0.3 range has no obvious effect on the result
 det_db_box_thresh 0.5 # DB post-processing box-filtering threshold; if boxes are missed, it can be reduced as appropriate
 det_db_unclip_ratio 1.6 # Indicates the compactness of the text box; the smaller the value, the closer the box is to the text
+use_polygon_score 1 # Whether to compute the bbox score with a polygon; 0 uses a rectangle. Rectangles are faster to compute; polygons are more accurate for curved text regions.
 det_model_dir ./inference/det_db # Address of the detection inference model
 # cls config
@@ -238,7 +239,7 @@ visualize 1 # Whether to visualize the results; when it is set to 1, the prediction is saved in the current folder
 The detection results will be shown on the screen as follows.
 <div align="center">
-    <img src="../imgs/cpp_infer_pred_12.png" width="600">
+    <img src="./imgs/cpp_infer_pred_12.png" width="600">
 </div>
...
@@ -59,7 +59,8 @@ int main(int argc, char **argv) {
                      config.gpu_mem, config.cpu_math_library_num_threads,
                      config.use_mkldnn, config.max_side_len, config.det_db_thresh,
                      config.det_db_box_thresh, config.det_db_unclip_ratio,
-                     config.visualize, config.use_tensorrt, config.use_fp16);
+                     config.use_polygon_score, config.visualize,
+                     config.use_tensorrt, config.use_fp16);
   Classifier *cls = nullptr;
   if (config.use_angle_cls == true) {
...
@@ -109,9 +109,9 @@ void DBDetector::Run(cv::Mat &img,
     cv::Mat dilation_map;
     cv::Mat dila_ele = cv::getStructuringElement(cv::MORPH_RECT, cv::Size(2, 2));
     cv::dilate(bit_map, dilation_map, dila_ele);
-    boxes = post_processor_.BoxesFromBitmap(pred_map, dilation_map,
-                                            this->det_db_box_thresh_,
-                                            this->det_db_unclip_ratio_);
+    boxes = post_processor_.BoxesFromBitmap(
+        pred_map, dilation_map, this->det_db_box_thresh_,
+        this->det_db_unclip_ratio_, this->use_polygon_score_);
   boxes = post_processor_.FilterTagDetRes(boxes, ratio_h, ratio_w, srcimg);
...
@@ -159,6 +159,53 @@ std::vector<std::vector<float>> PostProcessor::GetMiniBoxes(cv::RotatedRect box,
   return array;
 }
+float PostProcessor::PolygonScoreAcc(std::vector<cv::Point> contour,
+                                     cv::Mat pred) {
+  int width = pred.cols;
+  int height = pred.rows;
+  std::vector<float> box_x;
+  std::vector<float> box_y;
+  for (int i = 0; i < contour.size(); ++i) {
+    box_x.push_back(contour[i].x);
+    box_y.push_back(contour[i].y);
+  }
+  int xmin =
+      clamp(int(std::floor(*(std::min_element(box_x.begin(), box_x.end())))), 0,
+            width - 1);
+  int xmax =
+      clamp(int(std::ceil(*(std::max_element(box_x.begin(), box_x.end())))), 0,
+            width - 1);
+  int ymin =
+      clamp(int(std::floor(*(std::min_element(box_y.begin(), box_y.end())))), 0,
+            height - 1);
+  int ymax =
+      clamp(int(std::ceil(*(std::max_element(box_y.begin(), box_y.end())))), 0,
+            height - 1);
+  cv::Mat mask;
+  mask = cv::Mat::zeros(ymax - ymin + 1, xmax - xmin + 1, CV_8UC1);
+  cv::Point *rook_point = new cv::Point[contour.size()];
+  for (int i = 0; i < contour.size(); ++i) {
+    rook_point[i] = cv::Point(int(box_x[i]) - xmin, int(box_y[i]) - ymin);
+  }
+  const cv::Point *ppt[1] = {rook_point};
+  int npt[] = {int(contour.size())};
+  cv::fillPoly(mask, ppt, npt, 1, cv::Scalar(1));
+  cv::Mat croppedImg;
+  pred(cv::Rect(xmin, ymin, xmax - xmin + 1, ymax - ymin + 1))
+      .copyTo(croppedImg);
+  float score = cv::mean(croppedImg, mask)[0];
+  delete[] rook_point;
+  return score;
+}
 float PostProcessor::BoxScoreFast(std::vector<std::vector<float>> box_array,
                                   cv::Mat pred) {
   auto array = box_array;
@@ -197,10 +244,9 @@ float PostProcessor::BoxScoreFast(std::vector<std::vector<float>> box_array,
   return score;
 }
-std::vector<std::vector<std::vector<int>>>
-PostProcessor::BoxesFromBitmap(const cv::Mat pred, const cv::Mat bitmap,
-                               const float &box_thresh,
-                               const float &det_db_unclip_ratio) {
+std::vector<std::vector<std::vector<int>>> PostProcessor::BoxesFromBitmap(
+    const cv::Mat pred, const cv::Mat bitmap, const float &box_thresh,
+    const float &det_db_unclip_ratio, const bool &use_polygon_score) {
   const int min_size = 3;
   const int max_candidates = 1000;
@@ -234,7 +280,12 @@ PostProcessor::BoxesFromBitmap(const cv::Mat pred, const cv::Mat bitmap,
     }
     float score;
-    score = BoxScoreFast(array, pred);
+    if (use_polygon_score)
+      /* compute using polygon */
+      score = PolygonScoreAcc(contours[_i], pred);
+    else
+      score = BoxScoreFast(array, pred);
     if (score < box_thresh)
       continue;
...
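For reference, `PolygonScoreAcc` above scores a candidate region as the mean of the probability map inside the contour mask, while `BoxScoreFast` uses the min-area rectangle. A minimal NumPy/OpenCV sketch of the same idea, assuming the contour lies inside the map (the C++ version clamps coordinates to the image bounds):

```python
import cv2
import numpy as np

def polygon_score(contour, pred):
    """Mean of the probability map inside the polygon (cf. PolygonScoreAcc)."""
    xmin, ymin = contour.min(axis=0)
    xmax, ymax = contour.max(axis=0)
    # Build a binary mask of the polygon, shifted into the bounding box.
    mask = np.zeros((ymax - ymin + 1, xmax - xmin + 1), dtype=np.uint8)
    shifted = (contour - np.array([xmin, ymin])).astype(np.int32)
    cv2.fillPoly(mask, [shifted], 1)
    crop = pred[ymin:ymax + 1, xmin:xmax + 1]
    return cv2.mean(crop, mask)[0]

pred = np.random.rand(64, 64).astype(np.float32)
contour = np.array([[10, 10], [40, 12], [38, 30], [12, 28]], dtype=np.int32)
print(polygon_score(contour, pred))
```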
@@ -77,19 +77,10 @@ void ResizeImgType0::Run(const cv::Mat &img, cv::Mat &resize_img,
   int resize_h = int(float(h) * ratio);
   int resize_w = int(float(w) * ratio);
-  if (resize_h % 32 == 0)
-    resize_h = resize_h;
-  else if (resize_h / 32 < 1 + 1e-5)
-    resize_h = 32;
-  else
-    resize_h = (resize_h / 32) * 32;
-  if (resize_w % 32 == 0)
-    resize_w = resize_w;
-  else if (resize_w / 32 < 1 + 1e-5)
-    resize_w = 32;
-  else
-    resize_w = (resize_w / 32) * 32;
+  resize_h = max(int(round(float(resize_h) / 32) * 32), 32);
+  resize_w = max(int(round(float(resize_w) / 32) * 32), 32);
   if (!use_tensorrt) {
     cv::resize(img, resize_img, cv::Size(resize_w, resize_h));
     ratio_h = float(resize_h) / float(h);
...
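Note that this rewrite changes behavior slightly: the removed branches floored the size down to a multiple of 32 (with a minimum of 32), while the new expression rounds to the nearest multiple. A quick check mirroring the two variants (helper names are illustrative):

```python
def old_round32(v):
    # Mirrors the removed C++ branches (integer division floors).
    if v % 32 == 0:
        return v
    if v // 32 < 1 + 1e-5:
        return 32
    return (v // 32) * 32

def new_round32(v):
    # Mirrors the replacement: round to the nearest multiple of 32, min 32.
    return max(int(round(float(v) / 32) * 32), 32)

for v in (31, 50, 100, 113):
    print(v, old_round32(v), new_round32(v))
# 31 -> 32 32, 50 -> 32 64, 100 -> 96 96, 113 -> 96 128
```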
@@ -10,6 +10,7 @@ max_side_len 960
 det_db_thresh 0.3
 det_db_box_thresh 0.5
 det_db_unclip_ratio 1.6
+use_polygon_score 1
 det_model_dir ./inference/ch_ppocr_mobile_v2.0_det_infer/
 # cls config
...
@@ -6,6 +6,7 @@ from __future__ import print_function
 import os
 import sys
 sys.path.insert(0, ".")
+import copy
 from paddlehub.common.logger import logger
 from paddlehub.module.module import moduleinfo, runnable, serving
@@ -14,6 +15,8 @@ import paddlehub as hub
 from tools.infer.utility import base64_to_cv2
 from tools.infer.predict_cls import TextClassifier
+from tools.infer.utility import parse_args
+from deploy.hubserving.ocr_cls.params import read_params
 @moduleinfo(
@@ -28,8 +31,7 @@ class OCRCls(hub.Module):
         """
         initialize with the necessary elements
         """
-        from ocr_cls.params import read_params
-        cfg = read_params()
+        cfg = self.merge_configs()
         cfg.use_gpu = use_gpu
         if use_gpu:
@@ -48,6 +50,20 @@ class OCRCls(hub.Module):
         self.text_classifier = TextClassifier(cfg)
+    def merge_configs(self, ):
+        # default cfg
+        backup_argv = copy.deepcopy(sys.argv)
+        sys.argv = sys.argv[:1]
+        cfg = parse_args()
+        update_cfg_map = vars(read_params())
+        for key in update_cfg_map:
+            cfg.__setattr__(key, update_cfg_map[key])
+        sys.argv = copy.deepcopy(backup_argv)
+        return cfg
     def read_images(self, paths=[]):
         images = []
         for img_path in paths:
...
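The `merge_configs` helper added in these hubserving modules builds its config by parsing the CLI defaults with a temporarily truncated `sys.argv` (so hub serving's own flags never reach argparse) and then overlaying the module's `read_params()` values. A standalone sketch of the same pattern, with a stand-in `parse_args` and an overrides dict in place of the real ones in tools/infer/utility.py and the params modules:

```python
import argparse
import copy
import sys

def parse_args():
    # Stand-in for tools.infer.utility.parse_args: defines defaults only.
    parser = argparse.ArgumentParser()
    parser.add_argument("--use_gpu", type=bool, default=False)
    parser.add_argument("--cls_thresh", type=float, default=0.9)
    return parser.parse_args()

def merge_configs(overrides):
    backup_argv = copy.deepcopy(sys.argv)
    sys.argv = sys.argv[:1]          # hide unrelated CLI flags from argparse
    cfg = parse_args()
    for key, value in overrides.items():
        setattr(cfg, key, value)     # read_params()-style values win over defaults
    sys.argv = backup_argv
    return cfg

cfg = merge_configs({"cls_thresh": 0.5})
print(cfg.use_gpu, cfg.cls_thresh)   # -> False 0.5
```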
@@ -7,6 +7,8 @@ import os
 import sys
 sys.path.insert(0, ".")
+import copy
 from paddlehub.common.logger import logger
 from paddlehub.module.module import moduleinfo, runnable, serving
 import cv2
@@ -15,6 +17,8 @@ import paddlehub as hub
 from tools.infer.utility import base64_to_cv2
 from tools.infer.predict_det import TextDetector
+from tools.infer.utility import parse_args
+from deploy.hubserving.ocr_system.params import read_params
 @moduleinfo(
@@ -29,8 +33,7 @@ class OCRDet(hub.Module):
         """
         initialize with the necessary elements
         """
-        from ocr_det.params import read_params
-        cfg = read_params()
+        cfg = self.merge_configs()
         cfg.use_gpu = use_gpu
         if use_gpu:
@@ -49,6 +52,20 @@ class OCRDet(hub.Module):
         self.text_detector = TextDetector(cfg)
+    def merge_configs(self, ):
+        # default cfg
+        backup_argv = copy.deepcopy(sys.argv)
+        sys.argv = sys.argv[:1]
+        cfg = parse_args()
+        update_cfg_map = vars(read_params())
+        for key in update_cfg_map:
+            cfg.__setattr__(key, update_cfg_map[key])
+        sys.argv = copy.deepcopy(backup_argv)
+        return cfg
     def read_images(self, paths=[]):
         images = []
         for img_path in paths:
...
@@ -22,6 +22,7 @@ def read_params():
     cfg.det_db_box_thresh = 0.5
     cfg.det_db_unclip_ratio = 1.6
     cfg.use_dilation = False
+    cfg.det_db_score_mode = "fast"
     # #EAST params
     # cfg.det_east_score_thresh = 0.8
...
@@ -6,6 +6,7 @@ from __future__ import print_function
 import os
 import sys
 sys.path.insert(0, ".")
+import copy
 from paddlehub.common.logger import logger
 from paddlehub.module.module import moduleinfo, runnable, serving
@@ -14,6 +15,8 @@ import paddlehub as hub
 from tools.infer.utility import base64_to_cv2
 from tools.infer.predict_rec import TextRecognizer
+from tools.infer.utility import parse_args
+from deploy.hubserving.ocr_rec.params import read_params
 @moduleinfo(
@@ -28,8 +31,7 @@ class OCRRec(hub.Module):
         """
         initialize with the necessary elements
         """
-        from ocr_rec.params import read_params
-        cfg = read_params()
+        cfg = self.merge_configs()
         cfg.use_gpu = use_gpu
         if use_gpu:
@@ -48,6 +50,20 @@ class OCRRec(hub.Module):
         self.text_recognizer = TextRecognizer(cfg)
+    def merge_configs(self, ):
+        # default cfg
+        backup_argv = copy.deepcopy(sys.argv)
+        sys.argv = sys.argv[:1]
+        cfg = parse_args()
+        update_cfg_map = vars(read_params())
+        for key in update_cfg_map:
+            cfg.__setattr__(key, update_cfg_map[key])
+        sys.argv = copy.deepcopy(backup_argv)
+        return cfg
     def read_images(self, paths=[]):
         images = []
         for img_path in paths:
...
@@ -6,6 +6,7 @@ from __future__ import print_function
 import os
 import sys
 sys.path.insert(0, ".")
+import copy
 import time
@@ -17,6 +18,8 @@ import paddlehub as hub
 from tools.infer.utility import base64_to_cv2
 from tools.infer.predict_system import TextSystem
+from tools.infer.utility import parse_args
+from deploy.hubserving.ocr_system.params import read_params
 @moduleinfo(
@@ -31,8 +34,7 @@ class OCRSystem(hub.Module):
         """
         initialize with the necessary elements
         """
-        from ocr_system.params import read_params
-        cfg = read_params()
+        cfg = self.merge_configs()
         cfg.use_gpu = use_gpu
         if use_gpu:
@@ -51,6 +53,20 @@ class OCRSystem(hub.Module):
         self.text_sys = TextSystem(cfg)
+    def merge_configs(self, ):
+        # default cfg
+        backup_argv = copy.deepcopy(sys.argv)
+        sys.argv = sys.argv[:1]
+        cfg = parse_args()
+        update_cfg_map = vars(read_params())
+        for key in update_cfg_map:
+            cfg.__setattr__(key, update_cfg_map[key])
+        sys.argv = copy.deepcopy(backup_argv)
+        return cfg
     def read_images(self, paths=[]):
         images = []
         for img_path in paths:
...
@@ -22,6 +22,7 @@ def read_params():
     cfg.det_db_box_thresh = 0.5
     cfg.det_db_unclip_ratio = 1.6
     cfg.use_dilation = False
+    cfg.det_db_score_mode = "fast"
     #EAST params
     cfg.det_east_score_thresh = 0.8
...
ARM_ABI = arm8
export ARM_ABI
include ../Makefile.def
LITE_ROOT=../../../
THIRD_PARTY_DIR=${LITE_ROOT}/third_party
OPENCV_VERSION=opencv4.1.0
OPENCV_LIBS = ../../../third_party/${OPENCV_VERSION}/arm64-v8a/libs/libopencv_imgcodecs.a \
../../../third_party/${OPENCV_VERSION}/arm64-v8a/libs/libopencv_imgproc.a \
../../../third_party/${OPENCV_VERSION}/arm64-v8a/libs/libopencv_core.a \
../../../third_party/${OPENCV_VERSION}/arm64-v8a/3rdparty/libs/libtegra_hal.a \
../../../third_party/${OPENCV_VERSION}/arm64-v8a/3rdparty/libs/liblibjpeg-turbo.a \
../../../third_party/${OPENCV_VERSION}/arm64-v8a/3rdparty/libs/liblibwebp.a \
../../../third_party/${OPENCV_VERSION}/arm64-v8a/3rdparty/libs/liblibpng.a \
../../../third_party/${OPENCV_VERSION}/arm64-v8a/3rdparty/libs/liblibjasper.a \
../../../third_party/${OPENCV_VERSION}/arm64-v8a/3rdparty/libs/liblibtiff.a \
../../../third_party/${OPENCV_VERSION}/arm64-v8a/3rdparty/libs/libIlmImf.a \
../../../third_party/${OPENCV_VERSION}/arm64-v8a/3rdparty/libs/libtbb.a \
../../../third_party/${OPENCV_VERSION}/arm64-v8a/3rdparty/libs/libcpufeatures.a
OPENCV_INCLUDE = -I../../../third_party/${OPENCV_VERSION}/arm64-v8a/include
CXX_INCLUDES = $(INCLUDES) ${OPENCV_INCLUDE} -I$(LITE_ROOT)/cxx/include
CXX_LIBS = ${OPENCV_LIBS} -L$(LITE_ROOT)/cxx/lib/ -lpaddle_light_api_shared $(SYSTEM_LIBS)
###############################################################
# How to use one of the static libraries:                     #
# `libpaddle_api_full_bundled.a` #
# `libpaddle_api_light_bundled.a` #
###############################################################
# Note: the lite shared library is used by default.           #
###############################################################
# 1. Comment out the line above that uses `libpaddle_light_api_shared.so`
# 2. Uncomment the line below that uses `libpaddle_api_light_bundled.a`
#CXX_LIBS = $(LITE_ROOT)/cxx/lib/libpaddle_api_light_bundled.a $(SYSTEM_LIBS)
ocr_db_crnn: fetch_opencv ocr_db_crnn.o crnn_process.o db_post_process.o clipper.o cls_process.o
$(CC) $(SYSROOT_LINK) $(CXXFLAGS_LINK) ocr_db_crnn.o crnn_process.o db_post_process.o clipper.o cls_process.o -o ocr_db_crnn $(CXX_LIBS) $(LDFLAGS)
ocr_db_crnn.o: ocr_db_crnn.cc
$(CC) $(SYSROOT_COMPLILE) $(CXX_DEFINES) $(CXX_INCLUDES) $(CXX_FLAGS) -o ocr_db_crnn.o -c ocr_db_crnn.cc
crnn_process.o: fetch_opencv crnn_process.cc
$(CC) $(SYSROOT_COMPLILE) $(CXX_DEFINES) $(CXX_INCLUDES) $(CXX_FLAGS) -o crnn_process.o -c crnn_process.cc
cls_process.o: fetch_opencv cls_process.cc
$(CC) $(SYSROOT_COMPLILE) $(CXX_DEFINES) $(CXX_INCLUDES) $(CXX_FLAGS) -o cls_process.o -c cls_process.cc
db_post_process.o: fetch_clipper fetch_opencv db_post_process.cc
$(CC) $(SYSROOT_COMPLILE) $(CXX_DEFINES) $(CXX_INCLUDES) $(CXX_FLAGS) -o db_post_process.o -c db_post_process.cc
clipper.o: fetch_clipper
$(CC) $(SYSROOT_COMPLILE) $(CXX_DEFINES) $(CXX_INCLUDES) $(CXX_FLAGS) -o clipper.o -c clipper.cpp
fetch_clipper:
@test -e clipper.hpp || \
( echo "Fetch clipper " && \
wget -c https://paddle-inference-dist.cdn.bcebos.com/PaddleLite/Clipper/clipper.hpp)
@ test -e clipper.cpp || \
wget -c https://paddle-inference-dist.cdn.bcebos.com/PaddleLite/Clipper/clipper.cpp
fetch_opencv:
@ test -d ${THIRD_PARTY_DIR} || mkdir ${THIRD_PARTY_DIR}
@ test -e ${THIRD_PARTY_DIR}/${OPENCV_VERSION}.tar.gz || \
(echo "fetch opencv libs" && \
wget -P ${THIRD_PARTY_DIR} https://paddle-inference-dist.bj.bcebos.com/${OPENCV_VERSION}.tar.gz)
@ test -d ${THIRD_PARTY_DIR}/${OPENCV_VERSION} || \
tar -zxvf ${THIRD_PARTY_DIR}/${OPENCV_VERSION}.tar.gz -C ${THIRD_PARTY_DIR}
.PHONY: clean
clean:
rm -f ocr_db_crnn.o clipper.o db_post_process.o crnn_process.o cls_process.o
rm -f ocr_db_crnn
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "cls_process.h" //NOLINT
#include <algorithm>
#include <memory>
#include <string>
const std::vector<int> rec_image_shape{3, 48, 192};
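// Resizes to the classifier input height (48) while keeping the aspect ratio,
// then right-pads with black pixels up to width 192 when the result is narrower.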
cv::Mat ClsResizeImg(cv::Mat img) {
int imgC, imgH, imgW;
imgC = rec_image_shape[0];
imgH = rec_image_shape[1];
imgW = rec_image_shape[2];
float ratio = static_cast<float>(img.cols) / static_cast<float>(img.rows);
int resize_w, resize_h;
if (ceilf(imgH * ratio) > imgW)
resize_w = imgW;
else
resize_w = int(ceilf(imgH * ratio));
cv::Mat resize_img;
cv::resize(img, resize_img, cv::Size(resize_w, imgH), 0.f, 0.f,
cv::INTER_LINEAR);
if (resize_w < imgW) {
cv::copyMakeBorder(resize_img, resize_img, 0, 0, 0, imgW - resize_w,
cv::BORDER_CONSTANT, cv::Scalar(0, 0, 0));
}
return resize_img;
}
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <cstring>
#include <fstream>
#include <iostream>
#include <memory>
#include <string>
#include <vector>
#include "math.h" //NOLINT
#include "opencv2/core.hpp"
#include "opencv2/imgcodecs.hpp"
#include "opencv2/imgproc.hpp"
cv::Mat ClsResizeImg(cv::Mat img);
max_side_len 960
det_db_thresh 0.3
det_db_box_thresh 0.5
det_db_unclip_ratio 1.6
det_db_use_dilate 0
det_use_polygon_score 1
use_direction_classify 1
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "crnn_process.h" //NOLINT
#include <algorithm>
#include <memory>
#include <string>
const std::vector<int> rec_image_shape{3, 32, 320};
cv::Mat CrnnResizeImg(cv::Mat img, float wh_ratio) {
int imgC, imgH, imgW;
imgC = rec_image_shape[0];
imgW = rec_image_shape[2];
imgH = rec_image_shape[1];
imgW = int(32 * wh_ratio);
float ratio = static_cast<float>(img.cols) / static_cast<float>(img.rows);
int resize_w, resize_h;
if (ceilf(imgH * ratio) > imgW)
resize_w = imgW;
else
resize_w = static_cast<int>(ceilf(imgH * ratio));
cv::Mat resize_img;
cv::resize(img, resize_img, cv::Size(resize_w, imgH), 0.f, 0.f,
cv::INTER_LINEAR);
return resize_img;
}
std::vector<std::string> ReadDict(std::string path) {
std::ifstream in(path);
std::string filename;
std::string line;
std::vector<std::string> m_vec;
if (in) {
while (getline(in, line)) {
m_vec.push_back(line);
}
} else {
std::cout << "no such file" << std::endl;
}
return m_vec;
}
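// Crops the quadrilateral text region with a perspective transform; boxes that
// come out much taller than wide (ratio >= 1.5) are rotated to horizontal.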
cv::Mat GetRotateCropImage(cv::Mat srcimage,
std::vector<std::vector<int>> box) {
cv::Mat image;
srcimage.copyTo(image);
std::vector<std::vector<int>> points = box;
int x_collect[4] = {box[0][0], box[1][0], box[2][0], box[3][0]};
int y_collect[4] = {box[0][1], box[1][1], box[2][1], box[3][1]};
int left = int(*std::min_element(x_collect, x_collect + 4));
int right = int(*std::max_element(x_collect, x_collect + 4));
int top = int(*std::min_element(y_collect, y_collect + 4));
int bottom = int(*std::max_element(y_collect, y_collect + 4));
cv::Mat img_crop;
image(cv::Rect(left, top, right - left, bottom - top)).copyTo(img_crop);
for (int i = 0; i < points.size(); i++) {
points[i][0] -= left;
points[i][1] -= top;
}
int img_crop_width =
static_cast<int>(sqrt(pow(points[0][0] - points[1][0], 2) +
pow(points[0][1] - points[1][1], 2)));
int img_crop_height =
static_cast<int>(sqrt(pow(points[0][0] - points[3][0], 2) +
pow(points[0][1] - points[3][1], 2)));
cv::Point2f pts_std[4];
pts_std[0] = cv::Point2f(0., 0.);
pts_std[1] = cv::Point2f(img_crop_width, 0.);
pts_std[2] = cv::Point2f(img_crop_width, img_crop_height);
pts_std[3] = cv::Point2f(0.f, img_crop_height);
cv::Point2f pointsf[4];
pointsf[0] = cv::Point2f(points[0][0], points[0][1]);
pointsf[1] = cv::Point2f(points[1][0], points[1][1]);
pointsf[2] = cv::Point2f(points[2][0], points[2][1]);
pointsf[3] = cv::Point2f(points[3][0], points[3][1]);
cv::Mat M = cv::getPerspectiveTransform(pointsf, pts_std);
cv::Mat dst_img;
cv::warpPerspective(img_crop, dst_img, M,
cv::Size(img_crop_width, img_crop_height),
cv::BORDER_REPLICATE);
const float ratio = 1.5;
if (static_cast<float>(dst_img.rows) >=
static_cast<float>(dst_img.cols) * ratio) {
cv::Mat srcCopy = cv::Mat(dst_img.rows, dst_img.cols, dst_img.depth());
cv::transpose(dst_img, srcCopy);
cv::flip(srcCopy, srcCopy, 0);
return srcCopy;
} else {
return dst_img;
}
}
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <cstring>
#include <fstream>
#include <iostream>
#include <memory>
#include <string>
#include <vector>
#include "math.h" //NOLINT
#include "opencv2/core.hpp"
#include "opencv2/imgcodecs.hpp"
#include "opencv2/imgproc.hpp"
cv::Mat CrnnResizeImg(cv::Mat img, float wh_ratio);
std::vector<std::string> ReadDict(std::string path);
cv::Mat GetRotateCropImage(cv::Mat srcimage, std::vector<std::vector<int>> box);
template <class ForwardIterator>
inline size_t Argmax(ForwardIterator first, ForwardIterator last) {
return std::distance(first, std::max_element(first, last));
}
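// Usage sketch (assumed layout): given one CRNN timestep stored as a
// contiguous row of predict_shape[2] class scores, Argmax returns the
// predicted class index, which RunRecModel maps through the dictionary
// (index 0 being the CTC blank):
//   const float *row = &predict_batch[n * predict_shape[2]];
//   size_t cls = Argmax(row, row + predict_shape[2]);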
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "db_post_process.h" // NOLINT
#include <algorithm>
#include <utility>
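// Computes the offset distance for DB's unclip step: for a quadrilateral with
// area A (shoelace formula below) and perimeter L, the box is later expanded
// by distance = A * unclip_ratio / L via ClipperLib polygon offsetting.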
void GetContourArea(std::vector<std::vector<float>> box, float unclip_ratio,
float &distance) {
int pts_num = 4;
float area = 0.0f;
float dist = 0.0f;
for (int i = 0; i < pts_num; i++) {
area += box[i][0] * box[(i + 1) % pts_num][1] -
box[i][1] * box[(i + 1) % pts_num][0];
dist += sqrtf((box[i][0] - box[(i + 1) % pts_num][0]) *
(box[i][0] - box[(i + 1) % pts_num][0]) +
(box[i][1] - box[(i + 1) % pts_num][1]) *
(box[i][1] - box[(i + 1) % pts_num][1]));
}
area = fabs(float(area / 2.0));
distance = area * unclip_ratio / dist;
}
cv::RotatedRect Unclip(std::vector<std::vector<float>> box,
float unclip_ratio) {
float distance = 1.0;
GetContourArea(box, unclip_ratio, distance);
ClipperLib::ClipperOffset offset;
ClipperLib::Path p;
p << ClipperLib::IntPoint(static_cast<int>(box[0][0]),
static_cast<int>(box[0][1]))
<< ClipperLib::IntPoint(static_cast<int>(box[1][0]),
static_cast<int>(box[1][1]))
<< ClipperLib::IntPoint(static_cast<int>(box[2][0]),
static_cast<int>(box[2][1]))
<< ClipperLib::IntPoint(static_cast<int>(box[3][0]),
static_cast<int>(box[3][1]));
offset.AddPath(p, ClipperLib::jtRound, ClipperLib::etClosedPolygon);
ClipperLib::Paths soln;
offset.Execute(soln, distance);
std::vector<cv::Point2f> points;
  for (int j = 0; j < soln.size(); j++) {
    for (int i = 0; i < soln[j].size(); i++) {
      points.emplace_back(soln[j][i].X, soln[j][i].Y);
    }
  }
cv::RotatedRect res = cv::minAreaRect(points);
return res;
}
std::vector<std::vector<float>> Mat2Vector(cv::Mat mat) {
std::vector<std::vector<float>> img_vec;
std::vector<float> tmp;
for (int i = 0; i < mat.rows; ++i) {
tmp.clear();
for (int j = 0; j < mat.cols; ++j) {
tmp.push_back(mat.at<float>(i, j));
}
img_vec.push_back(tmp);
}
return img_vec;
}
bool XsortFp32(std::vector<float> a, std::vector<float> b) {
if (a[0] != b[0])
return a[0] < b[0];
return false;
}
bool XsortInt(std::vector<int> a, std::vector<int> b) {
if (a[0] != b[0])
return a[0] < b[0];
return false;
}
std::vector<std::vector<int>>
OrderPointsClockwise(std::vector<std::vector<int>> pts) {
std::vector<std::vector<int>> box = pts;
std::sort(box.begin(), box.end(), XsortInt);
std::vector<std::vector<int>> leftmost = {box[0], box[1]};
std::vector<std::vector<int>> rightmost = {box[2], box[3]};
if (leftmost[0][1] > leftmost[1][1])
std::swap(leftmost[0], leftmost[1]);
if (rightmost[0][1] > rightmost[1][1])
std::swap(rightmost[0], rightmost[1]);
std::vector<std::vector<int>> rect = {leftmost[0], rightmost[0], rightmost[1],
leftmost[1]};
return rect;
}
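// Example: for pts = {{7, 2}, {0, 0}, {7, 9}, {0, 6}}, sorting by x gives
// {{0,0},{0,6},{7,2},{7,9}}; after splitting into leftmost/rightmost pairs
// the result is {{0,0},{7,2},{7,9},{0,6}}, i.e. top-left, top-right,
// bottom-right, bottom-left.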
std::vector<std::vector<float>> GetMiniBoxes(cv::RotatedRect box, float &ssid) {
ssid = std::min(box.size.width, box.size.height);
cv::Mat points;
cv::boxPoints(box, points);
auto array = Mat2Vector(points);
std::sort(array.begin(), array.end(), XsortFp32);
std::vector<float> idx1 = array[0], idx2 = array[1], idx3 = array[2],
idx4 = array[3];
if (array[3][1] <= array[2][1]) {
idx2 = array[3];
idx3 = array[2];
} else {
idx2 = array[2];
idx3 = array[3];
}
if (array[1][1] <= array[0][1]) {
idx1 = array[1];
idx4 = array[0];
} else {
idx1 = array[0];
idx4 = array[1];
}
array[0] = idx1;
array[1] = idx2;
array[2] = idx3;
array[3] = idx4;
return array;
}
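// Note: the returned corners are ordered clockwise starting from the
// top-left, and `ssid` reports the rotated rect's shorter side, which
// BoxesFromBitmap compares against its min_size threshold.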
float BoxScoreFast(std::vector<std::vector<float>> box_array, cv::Mat pred) {
auto array = box_array;
int width = pred.cols;
int height = pred.rows;
float box_x[4] = {array[0][0], array[1][0], array[2][0], array[3][0]};
float box_y[4] = {array[0][1], array[1][1], array[2][1], array[3][1]};
int xmin = clamp(
static_cast<int>(std::floorf(*(std::min_element(box_x, box_x + 4)))), 0,
width - 1);
int xmax =
clamp(static_cast<int>(std::ceilf(*(std::max_element(box_x, box_x + 4)))),
0, width - 1);
int ymin = clamp(
static_cast<int>(std::floorf(*(std::min_element(box_y, box_y + 4)))), 0,
height - 1);
int ymax =
clamp(static_cast<int>(std::ceilf(*(std::max_element(box_y, box_y + 4)))),
0, height - 1);
cv::Mat mask;
mask = cv::Mat::zeros(ymax - ymin + 1, xmax - xmin + 1, CV_8UC1);
cv::Point root_point[4];
root_point[0] = cv::Point(static_cast<int>(array[0][0]) - xmin,
static_cast<int>(array[0][1]) - ymin);
root_point[1] = cv::Point(static_cast<int>(array[1][0]) - xmin,
static_cast<int>(array[1][1]) - ymin);
root_point[2] = cv::Point(static_cast<int>(array[2][0]) - xmin,
static_cast<int>(array[2][1]) - ymin);
root_point[3] = cv::Point(static_cast<int>(array[3][0]) - xmin,
static_cast<int>(array[3][1]) - ymin);
const cv::Point *ppt[1] = {root_point};
int npt[] = {4};
cv::fillPoly(mask, ppt, npt, 1, cv::Scalar(1));
cv::Mat croppedImg;
pred(cv::Rect(xmin, ymin, xmax - xmin + 1, ymax - ymin + 1))
.copyTo(croppedImg);
auto score = cv::mean(croppedImg, mask)[0];
return score;
}
float PolygonScoreAcc(std::vector<cv::Point> contour, cv::Mat pred) {
int width = pred.cols;
int height = pred.rows;
std::vector<float> box_x;
std::vector<float> box_y;
for (int i = 0; i < contour.size(); ++i) {
box_x.push_back(contour[i].x);
box_y.push_back(contour[i].y);
}
int xmin =
clamp(int(std::floor(*(std::min_element(box_x.begin(), box_x.end())))), 0,
width - 1);
int xmax =
clamp(int(std::ceil(*(std::max_element(box_x.begin(), box_x.end())))), 0,
width - 1);
int ymin =
clamp(int(std::floor(*(std::min_element(box_y.begin(), box_y.end())))), 0,
height - 1);
int ymax =
clamp(int(std::ceil(*(std::max_element(box_y.begin(), box_y.end())))), 0,
height - 1);
cv::Mat mask;
mask = cv::Mat::zeros(ymax - ymin + 1, xmax - xmin + 1, CV_8UC1);
cv::Point *rook_point = new cv::Point[contour.size()];
for (int i = 0; i < contour.size(); ++i) {
rook_point[i] = cv::Point(int(box_x[i]) - xmin, int(box_y[i]) - ymin);
}
const cv::Point *ppt[1] = {rook_point};
int npt[] = {int(contour.size())};
cv::fillPoly(mask, ppt, npt, 1, cv::Scalar(1));
cv::Mat croppedImg;
pred(cv::Rect(xmin, ymin, xmax - xmin + 1, ymax - ymin + 1))
.copyTo(croppedImg);
float score = cv::mean(croppedImg, mask)[0];
delete[] rook_point;
return score;
}
std::vector<std::vector<std::vector<int>>>
BoxesFromBitmap(const cv::Mat pred, const cv::Mat bitmap,
std::map<std::string, double> Config) {
const int min_size = 3;
const int max_candidates = 1000;
const float box_thresh = static_cast<float>(Config["det_db_box_thresh"]);
const float unclip_ratio = static_cast<float>(Config["det_db_unclip_ratio"]);
const int det_use_polygon_score = int(Config["det_use_polygon_score"]);
int width = bitmap.cols;
int height = bitmap.rows;
std::vector<std::vector<cv::Point>> contours;
std::vector<cv::Vec4i> hierarchy;
cv::findContours(bitmap, contours, hierarchy, cv::RETR_LIST,
cv::CHAIN_APPROX_SIMPLE);
int num_contours =
contours.size() >= max_candidates ? max_candidates : contours.size();
std::vector<std::vector<std::vector<int>>> boxes;
for (int i = 0; i < num_contours; i++) {
float ssid;
if (contours[i].size() <= 2)
continue;
cv::RotatedRect box = cv::minAreaRect(contours[i]);
auto array = GetMiniBoxes(box, ssid);
auto box_for_unclip = array;
// end get_mini_box
if (ssid < min_size) {
continue;
}
float score;
if (det_use_polygon_score) {
score = PolygonScoreAcc(contours[i], pred);
} else {
score = BoxScoreFast(array, pred);
}
// end box_score_fast
if (score < box_thresh)
continue;
// start for unclip
cv::RotatedRect points = Unclip(box_for_unclip, unclip_ratio);
if (points.size.height < 1.001 && points.size.width < 1.001)
continue;
// end for unclip
cv::RotatedRect clipbox = points;
auto cliparray = GetMiniBoxes(clipbox, ssid);
if (ssid < min_size + 2)
continue;
int dest_width = pred.cols;
int dest_height = pred.rows;
std::vector<std::vector<int>> intcliparray;
for (int num_pt = 0; num_pt < 4; num_pt++) {
std::vector<int> a{
static_cast<int>(clamp(
roundf(cliparray[num_pt][0] / float(width) * float(dest_width)),
float(0), float(dest_width))),
static_cast<int>(clamp(
roundf(cliparray[num_pt][1] / float(height) * float(dest_height)),
float(0), float(dest_height)))};
intcliparray.push_back(a);
}
boxes.push_back(intcliparray);
} // end for
return boxes;
}
std::vector<std::vector<std::vector<int>>>
FilterTagDetRes(std::vector<std::vector<std::vector<int>>> boxes, float ratio_h,
float ratio_w, cv::Mat srcimg) {
int oriimg_h = srcimg.rows;
int oriimg_w = srcimg.cols;
std::vector<std::vector<std::vector<int>>> root_points;
for (int n = 0; n < static_cast<int>(boxes.size()); n++) {
boxes[n] = OrderPointsClockwise(boxes[n]);
for (int m = 0; m < static_cast<int>(boxes[0].size()); m++) {
boxes[n][m][0] /= ratio_w;
boxes[n][m][1] /= ratio_h;
boxes[n][m][0] =
static_cast<int>(std::min(std::max(boxes[n][m][0], 0), oriimg_w - 1));
boxes[n][m][1] =
static_cast<int>(std::min(std::max(boxes[n][m][1], 0), oriimg_h - 1));
}
}
for (int n = 0; n < boxes.size(); n++) {
int rect_width, rect_height;
rect_width =
static_cast<int>(sqrt(pow(boxes[n][0][0] - boxes[n][1][0], 2) +
pow(boxes[n][0][1] - boxes[n][1][1], 2)));
rect_height =
static_cast<int>(sqrt(pow(boxes[n][0][0] - boxes[n][3][0], 2) +
pow(boxes[n][0][1] - boxes[n][3][1], 2)));
if (rect_width <= 4 || rect_height <= 4)
continue;
root_points.push_back(boxes[n]);
}
return root_points;
}
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <math.h>
#include <iostream>
#include <map>
#include <vector>
#include "clipper.hpp"
#include "opencv2/core.hpp"
#include "opencv2/imgcodecs.hpp"
#include "opencv2/imgproc.hpp"
template <class T> T clamp(T x, T min, T max) {
if (x > max)
return max;
if (x < min)
return min;
return x;
}
std::vector<std::vector<float>> Mat2Vector(cv::Mat mat);
void GetContourArea(std::vector<std::vector<float>> box, float unclip_ratio,
float &distance);
cv::RotatedRect Unclip(std::vector<std::vector<float>> box, float unclip_ratio);
bool XsortFp32(std::vector<float> a, std::vector<float> b);
bool XsortInt(std::vector<int> a, std::vector<int> b);
std::vector<std::vector<int>>
OrderPointsClockwise(std::vector<std::vector<int>> pts);
std::vector<std::vector<float>> GetMiniBoxes(cv::RotatedRect box, float &ssid);
float BoxScoreFast(std::vector<std::vector<float>> box_array, cv::Mat pred);
std::vector<std::vector<std::vector<int>>>
BoxesFromBitmap(const cv::Mat pred, const cv::Mat bitmap,
std::map<std::string, double> Config);
std::vector<std::vector<std::vector<int>>>
FilterTagDetRes(std::vector<std::vector<std::vector<int>>> boxes, float ratio_h,
float ratio_w, cv::Mat srcimg);
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle_api.h" // NOLINT
#include <chrono>
#include "cls_process.h"
#include "crnn_process.h"
#include "db_post_process.h"
using namespace paddle::lite_api; // NOLINT
using namespace std;
// fill tensor with mean and scale and trans layout: nhwc -> nchw, neon speed up
void NeonMeanScale(const float *din, float *dout, int size,
const std::vector<float> mean,
const std::vector<float> scale) {
if (mean.size() != 3 || scale.size() != 3) {
std::cerr << "[ERROR] mean or scale size must equal to 3\n";
exit(1);
}
float32x4_t vmean0 = vdupq_n_f32(mean[0]);
float32x4_t vmean1 = vdupq_n_f32(mean[1]);
float32x4_t vmean2 = vdupq_n_f32(mean[2]);
float32x4_t vscale0 = vdupq_n_f32(scale[0]);
float32x4_t vscale1 = vdupq_n_f32(scale[1]);
float32x4_t vscale2 = vdupq_n_f32(scale[2]);
float *dout_c0 = dout;
float *dout_c1 = dout + size;
float *dout_c2 = dout + size * 2;
int i = 0;
for (; i < size - 3; i += 4) {
float32x4x3_t vin3 = vld3q_f32(din);
float32x4_t vsub0 = vsubq_f32(vin3.val[0], vmean0);
float32x4_t vsub1 = vsubq_f32(vin3.val[1], vmean1);
float32x4_t vsub2 = vsubq_f32(vin3.val[2], vmean2);
float32x4_t vs0 = vmulq_f32(vsub0, vscale0);
float32x4_t vs1 = vmulq_f32(vsub1, vscale1);
float32x4_t vs2 = vmulq_f32(vsub2, vscale2);
vst1q_f32(dout_c0, vs0);
vst1q_f32(dout_c1, vs1);
vst1q_f32(dout_c2, vs2);
din += 12;
dout_c0 += 4;
dout_c1 += 4;
dout_c2 += 4;
}
for (; i < size; i++) {
*(dout_c0++) = (*(din++) - mean[0]) * scale[0];
*(dout_c1++) = (*(din++) - mean[1]) * scale[1];
*(dout_c2++) = (*(din++) - mean[2]) * scale[2];
}
}
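// Example: with mean = {0.5, 0.5, 0.5} and scale = {2, 2, 2} (i.e. 1/0.5),
// an input value of 1.0 maps to (1.0 - 0.5) * 2 = 1.0, so inputs in [0, 1]
// are normalized to [-1, 1] while being rearranged from NHWC to NCHW planes.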
// resize image to a size multiple of 32 which is required by the network
cv::Mat DetResizeImg(const cv::Mat img, int max_size_len,
std::vector<float> &ratio_hw) {
int w = img.cols;
int h = img.rows;
float ratio = 1.f;
int max_wh = w >= h ? w : h;
if (max_wh > max_size_len) {
if (h > w) {
ratio = static_cast<float>(max_size_len) / static_cast<float>(h);
} else {
ratio = static_cast<float>(max_size_len) / static_cast<float>(w);
}
}
int resize_h = static_cast<int>(float(h) * ratio);
int resize_w = static_cast<int>(float(w) * ratio);
  if (resize_h % 32 != 0) {
    if (resize_h / 32 < 1 + 1e-5)
      resize_h = 32;
    else
      resize_h = (resize_h / 32 - 1) * 32;
  }
  if (resize_w % 32 != 0) {
    if (resize_w / 32 < 1 + 1e-5)
      resize_w = 32;
    else
      resize_w = (resize_w / 32 - 1) * 32;
  }
cv::Mat resize_img;
cv::resize(img, resize_img, cv::Size(resize_w, resize_h));
ratio_hw.push_back(static_cast<float>(resize_h) / static_cast<float>(h));
ratio_hw.push_back(static_cast<float>(resize_w) / static_cast<float>(w));
return resize_img;
}
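// Example: a 640x900 (w x h) image with max_size_len = 960 keeps ratio = 1;
// resize_w = 640 is already a multiple of 32, while resize_h = 900 rounds
// down to (900 / 32 - 1) * 32 = 864, so the network input becomes 640x864
// and ratio_hw records {864/900, 640/640} for mapping boxes back.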
cv::Mat RunClsModel(cv::Mat img, std::shared_ptr<PaddlePredictor> predictor_cls,
const float thresh = 0.9) {
std::vector<float> mean = {0.5f, 0.5f, 0.5f};
std::vector<float> scale = {1 / 0.5f, 1 / 0.5f, 1 / 0.5f};
cv::Mat srcimg;
img.copyTo(srcimg);
cv::Mat crop_img;
img.copyTo(crop_img);
cv::Mat resize_img;
int index = 0;
float wh_ratio =
static_cast<float>(crop_img.cols) / static_cast<float>(crop_img.rows);
resize_img = ClsResizeImg(crop_img);
resize_img.convertTo(resize_img, CV_32FC3, 1 / 255.f);
const float *dimg = reinterpret_cast<const float *>(resize_img.data);
std::unique_ptr<Tensor> input_tensor0(std::move(predictor_cls->GetInput(0)));
input_tensor0->Resize({1, 3, resize_img.rows, resize_img.cols});
auto *data0 = input_tensor0->mutable_data<float>();
NeonMeanScale(dimg, data0, resize_img.rows * resize_img.cols, mean, scale);
// Run CLS predictor
predictor_cls->Run();
// Get output and run postprocess
std::unique_ptr<const Tensor> softmax_out(
std::move(predictor_cls->GetOutput(0)));
auto *softmax_scores = softmax_out->mutable_data<float>();
auto softmax_out_shape = softmax_out->shape();
float score = 0;
int label = 0;
for (int i = 0; i < softmax_out_shape[1]; i++) {
if (softmax_scores[i] > score) {
score = softmax_scores[i];
label = i;
}
}
  if (label % 2 == 1 && score > thresh) {
    cv::rotate(srcimg, srcimg, cv::ROTATE_180);
  }
return srcimg;
}
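// Note: an odd label means the classifier judged the crop to be upside down,
// so it is rotated 180 degrees whenever the softmax score exceeds `thresh`.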
void RunRecModel(std::vector<std::vector<std::vector<int>>> boxes, cv::Mat img,
std::shared_ptr<PaddlePredictor> predictor_crnn,
std::vector<std::string> &rec_text,
std::vector<float> &rec_text_score,
std::vector<std::string> charactor_dict,
std::shared_ptr<PaddlePredictor> predictor_cls,
int use_direction_classify) {
std::vector<float> mean = {0.5f, 0.5f, 0.5f};
std::vector<float> scale = {1 / 0.5f, 1 / 0.5f, 1 / 0.5f};
cv::Mat srcimg;
img.copyTo(srcimg);
cv::Mat crop_img;
cv::Mat resize_img;
int index = 0;
for (int i = boxes.size() - 1; i >= 0; i--) {
crop_img = GetRotateCropImage(srcimg, boxes[i]);
if (use_direction_classify >= 1) {
crop_img = RunClsModel(crop_img, predictor_cls);
}
float wh_ratio =
static_cast<float>(crop_img.cols) / static_cast<float>(crop_img.rows);
resize_img = CrnnResizeImg(crop_img, wh_ratio);
resize_img.convertTo(resize_img, CV_32FC3, 1 / 255.f);
const float *dimg = reinterpret_cast<const float *>(resize_img.data);
std::unique_ptr<Tensor> input_tensor0(
std::move(predictor_crnn->GetInput(0)));
input_tensor0->Resize({1, 3, resize_img.rows, resize_img.cols});
auto *data0 = input_tensor0->mutable_data<float>();
NeonMeanScale(dimg, data0, resize_img.rows * resize_img.cols, mean, scale);
//// Run CRNN predictor
predictor_crnn->Run();
// Get output and run postprocess
std::unique_ptr<const Tensor> output_tensor0(
std::move(predictor_crnn->GetOutput(0)));
auto *predict_batch = output_tensor0->data<float>();
auto predict_shape = output_tensor0->shape();
// ctc decode
std::string str_res;
int argmax_idx;
int last_index = 0;
float score = 0.f;
int count = 0;
float max_value = 0.0f;
for (int n = 0; n < predict_shape[1]; n++) {
argmax_idx = int(Argmax(&predict_batch[n * predict_shape[2]],
&predict_batch[(n + 1) * predict_shape[2]]));
max_value =
float(*std::max_element(&predict_batch[n * predict_shape[2]],
&predict_batch[(n + 1) * predict_shape[2]]));
if (argmax_idx > 0 && (!(n > 0 && argmax_idx == last_index))) {
score += max_value;
count += 1;
str_res += charactor_dict[argmax_idx];
}
last_index = argmax_idx;
}
    if (count > 0)
      score /= count;
rec_text.push_back(str_res);
rec_text_score.push_back(score);
}
}
std::vector<std::vector<std::vector<int>>>
RunDetModel(std::shared_ptr<PaddlePredictor> predictor, cv::Mat img,
std::map<std::string, double> Config) {
// Read img
int max_side_len = int(Config["max_side_len"]);
int det_db_use_dilate = int(Config["det_db_use_dilate"]);
cv::Mat srcimg;
img.copyTo(srcimg);
std::vector<float> ratio_hw;
img = DetResizeImg(img, max_side_len, ratio_hw);
cv::Mat img_fp;
img.convertTo(img_fp, CV_32FC3, 1.0 / 255.f);
// Prepare input data from image
std::unique_ptr<Tensor> input_tensor0(std::move(predictor->GetInput(0)));
input_tensor0->Resize({1, 3, img_fp.rows, img_fp.cols});
auto *data0 = input_tensor0->mutable_data<float>();
std::vector<float> mean = {0.485f, 0.456f, 0.406f};
std::vector<float> scale = {1 / 0.229f, 1 / 0.224f, 1 / 0.225f};
const float *dimg = reinterpret_cast<const float *>(img_fp.data);
NeonMeanScale(dimg, data0, img_fp.rows * img_fp.cols, mean, scale);
// Run predictor
predictor->Run();
// Get output and post process
std::unique_ptr<const Tensor> output_tensor(
std::move(predictor->GetOutput(0)));
auto *outptr = output_tensor->data<float>();
auto shape_out = output_tensor->shape();
  // Save output (heap buffers; variable-length arrays are not standard C++)
  std::vector<float> pred(shape_out[2] * shape_out[3]);
  std::vector<unsigned char> cbuf(shape_out[2] * shape_out[3]);
  for (int i = 0; i < int(shape_out[2] * shape_out[3]); i++) {
    pred[i] = static_cast<float>(outptr[i]);
    cbuf[i] = static_cast<unsigned char>((outptr[i]) * 255);
  }
  cv::Mat cbuf_map(shape_out[2], shape_out[3], CV_8UC1, cbuf.data());
  cv::Mat pred_map(shape_out[2], shape_out[3], CV_32F, pred.data());
const double threshold = double(Config["det_db_thresh"]) * 255;
const double max_value = 255;
cv::Mat bit_map;
cv::threshold(cbuf_map, bit_map, threshold, max_value, cv::THRESH_BINARY);
if (det_db_use_dilate == 1) {
cv::Mat dilation_map;
cv::Mat dila_ele =
cv::getStructuringElement(cv::MORPH_RECT, cv::Size(2, 2));
cv::dilate(bit_map, dilation_map, dila_ele);
bit_map = dilation_map;
}
auto boxes = BoxesFromBitmap(pred_map, bit_map, Config);
std::vector<std::vector<std::vector<int>>> filter_boxes =
FilterTagDetRes(boxes, ratio_hw[0], ratio_hw[1], srcimg);
return filter_boxes;
}
std::shared_ptr<PaddlePredictor> loadModel(std::string model_file) {
MobileConfig config;
config.set_model_from_file(model_file);
std::shared_ptr<PaddlePredictor> predictor =
CreatePaddlePredictor<MobileConfig>(config);
return predictor;
}
cv::Mat Visualization(cv::Mat srcimg,
std::vector<std::vector<std::vector<int>>> boxes) {
  cv::Mat img_vis;
  srcimg.copyTo(img_vis);
  for (int n = 0; n < boxes.size(); n++) {
    std::vector<cv::Point> rook_points;
    for (int m = 0; m < boxes[n].size(); m++) {
      rook_points.push_back(cv::Point(static_cast<int>(boxes[n][m][0]),
                                      static_cast<int>(boxes[n][m][1])));
    }
    const cv::Point *ppt[1] = {rook_points.data()};
    int npt[] = {static_cast<int>(rook_points.size())};
    cv::polylines(img_vis, ppt, npt, 1, 1, CV_RGB(0, 255, 0), 2, 8, 0);
  }
cv::imwrite("./vis.jpg", img_vis);
std::cout << "The detection visualized image saved in ./vis.jpg" << std::endl;
return img_vis;
}
std::vector<std::string> split(const std::string &str,
                               const std::string &delim) {
  std::vector<std::string> res;
  size_t start = 0;
  while (start < str.size()) {
    size_t pos = str.find(delim, start);
    if (pos == std::string::npos)
      pos = str.size();
    if (pos > start)
      res.push_back(str.substr(start, pos - start));
    start = pos + delim.size();
  }
  return res;
}
std::map<std::string, double> LoadConfigTxt(std::string config_path) {
auto config = ReadDict(config_path);
std::map<std::string, double> dict;
for (int i = 0; i < config.size(); i++) {
std::vector<std::string> res = split(config[i], " ");
dict[res[0]] = stod(res[1]);
}
return dict;
}
int main(int argc, char **argv) {
  if (argc < 6) {
    std::cerr << "[ERROR] usage: " << argv[0]
              << " det_model_file rec_model_file cls_model_file image_path "
                 "charactor_dict\n";
exit(1);
}
std::string det_model_file = argv[1];
std::string rec_model_file = argv[2];
std::string cls_model_file = argv[3];
std::string img_path = argv[4];
std::string dict_path = argv[5];
//// load config from txt file
auto Config = LoadConfigTxt("./config.txt");
int use_direction_classify = int(Config["use_direction_classify"]);
auto start = std::chrono::system_clock::now();
auto det_predictor = loadModel(det_model_file);
auto rec_predictor = loadModel(rec_model_file);
auto cls_predictor = loadModel(cls_model_file);
auto charactor_dict = ReadDict(dict_path);
charactor_dict.insert(charactor_dict.begin(), "#"); // blank char for ctc
charactor_dict.push_back(" ");
cv::Mat srcimg = cv::imread(img_path, cv::IMREAD_COLOR);
auto boxes = RunDetModel(det_predictor, srcimg, Config);
std::vector<std::string> rec_text;
std::vector<float> rec_text_score;
RunRecModel(boxes, srcimg, rec_predictor, rec_text, rec_text_score,
charactor_dict, cls_predictor, use_direction_classify);
auto end = std::chrono::system_clock::now();
auto duration =
std::chrono::duration_cast<std::chrono::microseconds>(end - start);
//// visualization
auto img_vis = Visualization(srcimg, boxes);
//// print recognized text
for (int i = 0; i < rec_text.size(); i++) {
std::cout << i << "\t" << rec_text[i] << "\t" << rec_text_score[i]
<< std::endl;
}
  std::cout << "Elapsed time: "
            << double(duration.count()) *
                   std::chrono::microseconds::period::num /
                   std::chrono::microseconds::period::den
            << " s" << std::endl;
return 0;
}
#!/bin/bash
mkdir -p "$1/demo/cxx/ocr/debug/"
cp ../../ppocr/utils/ppocr_keys_v1.txt "$1/demo/cxx/ocr/debug/"
cp -r ./* "$1/demo/cxx/ocr/"
cp ./config.txt "$1/demo/cxx/ocr/debug/"
cp ../../doc/imgs/11.jpg "$1/demo/cxx/ocr/debug/"
echo "Prepare Done"
# On-device deployment
This tutorial introduces the detailed steps for deploying PaddleOCR's ultra-lightweight Chinese detection and recognition models on mobile devices with [Paddle Lite](https://github.com/PaddlePaddle/Paddle-Lite).
Paddle Lite is PaddlePaddle's lightweight inference engine. It provides efficient inference for mobile phones and IoT devices, integrates a wide range of cross-platform hardware, and offers a lightweight deployment solution for on-device applications.
## 1. Prepare the environment
### Prerequisites
- A computer (for compiling Paddle Lite)
- An Android phone (armv7 or armv8)
### 1.1 Prepare the cross-compilation environment
The cross-compilation environment is used to compile Paddle Lite and the C++ demo of PaddleOCR.
Several development environments are supported; for the compilation steps in each one, refer to the corresponding document.
1. [Docker](https://paddle-lite.readthedocs.io/zh/latest/source_compile/compile_env.html#docker)
2. [Linux](https://paddle-lite.readthedocs.io/zh/latest/source_compile/compile_env.html#linux)
3. [MAC OS](https://paddle-lite.readthedocs.io/zh/latest/source_compile/compile_env.html#mac-os)
### 1.2 Prepare the prediction library
There are two ways to obtain the prediction library:
- 1. Download it directly; the download links are as follows:
| Platform | Prediction library download link |
|---|---|
|Android|[arm7](https://github.com/PaddlePaddle/Paddle-Lite/releases/download/v2.9/inference_lite_lib.android.armv7.gcc.c++_shared.with_extra.with_cv.tar.gz) / [arm8](https://github.com/PaddlePaddle/Paddle-Lite/releases/download/v2.9/inference_lite_lib.android.armv8.gcc.c++_shared.with_extra.with_cv.tar.gz)|
|IOS|[arm7](https://github.com/PaddlePaddle/Paddle-Lite/releases/download/v2.9/inference_lite_lib.ios.armv7.with_cv.with_extra.with_log.tiny_publish.tar.gz) / [arm8](https://github.com/PaddlePaddle/Paddle-Lite/releases/download/v2.9/inference_lite_lib.ios.armv8.with_cv.with_extra.with_log.tiny_publish.tar.gz)|
Note: 1. The prediction libraries above were compiled from the Paddle-Lite 2.9 branch. For details about Paddle-Lite 2.9, see [link](https://github.com/PaddlePaddle/Paddle-Lite/releases/tag/v2.9).
- 2. [Recommended] Compile Paddle-Lite to obtain the prediction library. Paddle-Lite is compiled as follows:
```
git clone https://github.com/PaddlePaddle/Paddle-Lite.git
cd Paddle-Lite
# Switch to the Paddle-Lite release/v2.9 stable branch
git checkout release/v2.9
./lite/tools/build_android.sh --arch=armv8 --with_cv=ON --with_extra=ON
```
Note: when compiling Paddle-Lite to obtain the prediction library, the two options `--with_cv=ON --with_extra=ON` must be enabled, and `--arch` specifies the `arm` version (armv8 here).
For more compilation commands, see [link](https://paddle-lite.readthedocs.io/zh/latest/source_compile/compile_andriod.html).
After downloading the prediction library directly and extracting it, you get the `inference_lite_lib.android.armv8/` folder; the prediction library built by compiling Paddle-Lite is located in the
`Paddle-Lite/build.lite.android.armv8.gcc/inference_lite_lib.android.armv8/` folder.
The directory structure of the prediction library is as follows:
```
inference_lite_lib.android.armv8/
|-- cxx                                        C++ prediction library and headers
|   |-- include                                C++ header files
|   |   |-- paddle_api.h
|   |   |-- paddle_image_preprocess.h
|   |   |-- paddle_lite_factory_helper.h
|   |   |-- paddle_place.h
|   |   |-- paddle_use_kernels.h
|   |   |-- paddle_use_ops.h
|   |   `-- paddle_use_passes.h
|   `-- lib                                    C++ prediction library
|       |-- libpaddle_api_light_bundled.a      C++ static library
|       `-- libpaddle_light_api_shared.so      C++ dynamic library
|-- java                                       Java prediction library
|   |-- jar
|   |   `-- PaddlePredictor.jar
|   |-- so
|   |   `-- libpaddle_lite_jni.so
|   `-- src
|-- demo                                       C++ and Java demos
|   |-- cxx                                    C++ demo
|   `-- java                                   Java demo
```
## 2 Getting started
### 2.1 Model optimization
Paddle-Lite provides a variety of strategies to automatically optimize the original model, including quantization, subgraph fusion, hybrid scheduling, and kernel selection. The opt tool of Paddle-Lite can
optimize an inference model automatically; the optimized model is lighter and runs faster.
If you already have a model file ending in `.nb`, you can skip this step.
The table below also provides a series of Chinese mobile models:
|Version|Description|Model size|Detection model|Text direction classifier|Recognition model|Paddle-Lite version|
|---|---|---|---|---|---|---|
|V2.0|Ultra-lightweight Chinese OCR mobile model|7.8M|[download link](https://paddleocr.bj.bcebos.com/dygraph_v2.0/lite/ch_ppocr_mobile_v2.0_det_opt.nb)|[download link](https://paddleocr.bj.bcebos.com/dygraph_v2.0/lite/ch_ppocr_mobile_v2.0_cls_opt.nb)|[download link](https://paddleocr.bj.bcebos.com/dygraph_v2.0/lite/ch_ppocr_mobile_v2.0_rec_opt.nb)|v2.9|
|V2.0(slim)|Ultra-lightweight Chinese OCR mobile model|3.3M|[download link](https://paddleocr.bj.bcebos.com/dygraph_v2.0/lite/ch_ppocr_mobile_v2.0_det_slim_opt.nb)|[download link](https://paddleocr.bj.bcebos.com/dygraph_v2.0/lite/ch_ppocr_mobile_v2.0_cls_slim_opt.nb)|[download link](https://paddleocr.bj.bcebos.com/dygraph_v2.0/lite/ch_ppocr_mobile_v2.0_rec_slim_opt.nb)|v2.9|
If you deploy the models in the table above directly, you can skip the following steps and read [Section 2.2](#2.2与手机联调) directly.
If the model to deploy is not in the table above, follow the steps below to obtain the optimized model.
Model optimization requires Paddle-Lite's opt executable, which is obtained by compiling the Paddle-Lite source code:
```
# If Paddle-Lite was already cloned while preparing the environment, there is no need to clone it again
git clone https://github.com/PaddlePaddle/Paddle-Lite.git
cd Paddle-Lite
git checkout release/v2.9
# Start the compilation
./lite/tools/build.sh build_optimize_tool
```
After compilation, the opt file is located under `build.opt/lite/api/`. You can view opt's options and usage as follows:
```
cd build.opt/lite/api/
./opt
```
|Option|Description|
|---|---|
|--model_dir|Path of the PaddlePaddle model (non-combined form) to be optimized|
|--model_file|Path of the network structure file of the PaddlePaddle model (combined form) to be optimized|
|--param_file|Path of the weight file of the PaddlePaddle model (combined form) to be optimized|
|--optimize_out_type|Output model type; currently protobuf and naive_buffer are supported, where naive_buffer is a more lightweight serialization/deserialization implementation. For model prediction on mobile, set this option to naive_buffer. Defaults to protobuf|
|--optimize_out|Output path of the optimized model|
|--valid_targets|Backends on which the model can run; defaults to arm. Currently x86, arm, opencl, npu and xpu are supported, and several backends can be given at once (separated by spaces); the Model Optimize Tool then picks the best one automatically. For Huawei NPU support (the DaVinci-architecture NPU in the Kirin 810/990 SoC), set it to npu, arm|
|--record_tailoring_info|When trimming library files by model, set this option to true to record the kernel and OP information contained in the optimized model. Defaults to false|
`--model_dir` applies to non-combined models; PaddleOCR's inference models are combined, i.e., the model structure and the model parameters are each stored as a single file.
The following takes PaddleOCR's ultra-lightweight Chinese model as an example to show how the compiled opt file converts an inference model into a Paddle-Lite optimized model.
```
# [Recommended] Download the Chinese and English inference models of PaddleOCR V2.0
wget https://paddleocr.bj.bcebos.com/dygraph_v2.0/slim/ch_ppocr_mobile_v2.0_det_slim_infer.tar && tar xf ch_ppocr_mobile_v2.0_det_slim_infer.tar
wget https://paddleocr.bj.bcebos.com/dygraph_v2.0/slim/ch_ppocr_mobile_v2.0_rec_slim_infer.tar && tar xf ch_ppocr_mobile_v2.0_rec_slim_infer.tar
wget https://paddleocr.bj.bcebos.com/dygraph_v2.0/slim/ch_ppocr_mobile_v2.0_cls_slim_infer.tar && tar xf ch_ppocr_mobile_v2.0_cls_slim_infer.tar
# Convert the V2.0 detection model
./opt --model_file=./ch_ppocr_mobile_v2.0_det_slim_infer/inference.pdmodel --param_file=./ch_ppocr_mobile_v2.0_det_slim_infer/inference.pdiparams --optimize_out=./ch_ppocr_mobile_v2.0_det_slim_opt --valid_targets=arm --optimize_out_type=naive_buffer
# Convert the V2.0 recognition model
./opt --model_file=./ch_ppocr_mobile_v2.0_rec_slim_infer/inference.pdmodel --param_file=./ch_ppocr_mobile_v2.0_rec_slim_infer/inference.pdiparams --optimize_out=./ch_ppocr_mobile_v2.0_rec_slim_opt --valid_targets=arm --optimize_out_type=naive_buffer
# Convert the V2.0 direction classifier model
./opt --model_file=./ch_ppocr_mobile_v2.0_cls_slim_infer/inference.pdmodel --param_file=./ch_ppocr_mobile_v2.0_cls_slim_infer/inference.pdiparams --optimize_out=./ch_ppocr_mobile_v2.0_cls_slim_opt --valid_targets=arm --optimize_out_type=naive_buffer
```
After a successful conversion, files ending in `.nb` appear in the inference model directory; these are the converted model files.
Note: when deploying with Paddle-Lite, the models must be optimized with the opt tool; the input of the opt tool is an inference model saved by Paddle.
<a name="2.2与手机联调"></a>
### 2.2 Deploy and debug on the phone
Some preparation is needed first.
1. Prepare an armv8 Android phone. If the compiled prediction library and opt file target armv7, an armv7 phone is needed instead, and `ARM_ABI = arm7` must be set in the Makefile.
2. Enable USB debugging on the phone, select file transfer mode, and connect the phone to the computer.
3. Install the adb tool on the computer for debugging. adb can be installed as follows:
3.1. Install ADB on macOS:
```
brew cask install android-platform-tools
```
3.2. Install ADB on Linux
```
sudo apt update
sudo apt install -y wget adb
```
3.3. Install ADB on Windows
On Windows, download the adb package from Google's Android platform and install it: [link](https://developer.android.com/studio)
Open a terminal, connect the phone to the computer, and type in the terminal
```
adb devices
```
If a device is listed, the installation succeeded.
```
List of devices attached
744be294 device
```
4. Prepare the optimized models, prediction library files, test image, and dictionary file.
```
git clone https://github.com/PaddlePaddle/PaddleOCR.git
cd PaddleOCR/deploy/lite/
# Run prepare.sh to place the prediction library files, test image and dictionary file under demo/cxx/ocr inside the prediction library
sh prepare.sh /{lite prediction library path}/inference_lite_lib.android.armv8
# Enter the working directory of the OCR demo
cd /{lite prediction library path}/inference_lite_lib.android.armv8/
cd demo/cxx/ocr/
# Copy the C++ prediction dynamic library (.so) into the debug folder
cp ../../../cxx/lib/libpaddle_light_api_shared.so ./debug/
```
Prepare the test image, taking `PaddleOCR/doc/imgs/11.jpg` as an example, and copy it to the `demo/cxx/ocr/debug/` folder.
Prepare the model files optimized by the lite opt tool, e.g. `ch_ppocr_mobile_v2.0_det_slim_opt.nb, ch_ppocr_mobile_v2.0_rec_slim_opt.nb, ch_ppocr_mobile_v2.0_cls_slim_opt.nb`, and place them under the `demo/cxx/ocr/debug/` folder.
After this, the ocr folder has the following layout:
```
demo/cxx/ocr/
|-- debug/
| |--ch_ppocr_mobile_v2.0_det_slim_opt.nb  Optimized detection model
| |--ch_ppocr_mobile_v2.0_rec_slim_opt.nb  Optimized recognition model
| |--ch_ppocr_mobile_v2.0_cls_slim_opt.nb  Optimized text direction classifier model
| |--11.jpg  Test image
| |--ppocr_keys_v1.txt  Chinese dictionary file
| |--libpaddle_light_api_shared.so  C++ prediction library
| |--config.txt  Hyperparameter configuration
|-- config.txt  Hyperparameter configuration
|-- cls_process.cc  Pre- and post-processing for the direction classifier
|-- cls_process.h
|-- crnn_process.cc  Pre- and post-processing for the CRNN recognition model
|-- crnn_process.h
|-- db_post_process.cc  Post-processing for the DB detection model
|-- db_post_process.h
|-- Makefile  Build file
|-- ocr_db_crnn.cc  C++ prediction source file
```
#### Note:
1. ppocr_keys_v1.txt is a Chinese dictionary file. If the nb model used targets English, digits, or another language, it must be replaced with the dictionary of the corresponding language.
PaddleOCR provides a variety of dictionaries under ppocr/utils/, including:
```
dict/french_dict.txt # French dictionary
dict/german_dict.txt # German dictionary
ic15_dict.txt # English dictionary
dict/japan_dict.txt # Japanese dictionary
dict/korean_dict.txt # Korean dictionary
ppocr_keys_v1.txt # Chinese dictionary
...
```
2. `config.txt` contains the hyperparameters of the detector and the classifier, as follows (a usage sketch follows the block):
```
max_side_len 960 # When the image's height or width exceeds 960, it is scaled proportionally so that its longest side is 960
det_db_thresh 0.3 # Used to binarize the probability map predicted by DB; values in the 0-0.3 range have no obvious effect on the result
det_db_box_thresh 0.5 # Threshold of the DB post-processing box filter; lower it if boxes are being missed
det_db_unclip_ratio 1.6 # Compactness of the text boxes; the smaller the value, the tighter the boxes fit the text
use_direction_classify 0 # Whether to use the direction classifier: 0 means no, 1 means yes
```
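For reference, here is a minimal sketch of how the demo consumes these values; it mirrors `LoadConfigTxt` and `RunDetModel` in `ocr_db_crnn.cc`, with illustrative variable names:
```
// Each config.txt line is "key value"; LoadConfigTxt parses it into a map.
std::map<std::string, double> Config = LoadConfigTxt("./config.txt");
float box_thresh = static_cast<float>(Config["det_db_box_thresh"]); // 0.5 above
int use_cls = int(Config["use_direction_classify"]);                // 0 or 1
```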
5. Start debugging
Once the above steps are done, use adb to push the files to the phone and run the demo:
```
# Compile to obtain the executable ocr_db_crnn
make -j
# Move the compiled executable into the debug folder
mv ocr_db_crnn ./debug/
# Push the debug folder to the phone
adb push debug /data/local/tmp/
adb shell
cd /data/local/tmp/debug
export LD_LIBRARY_PATH=${PWD}:$LD_LIBRARY_PATH
# Usage of the ocr_db_crnn executable:
# ./ocr_db_crnn detection_model_file recognition_model_file direction_classifier_model_file test_image_path dictionary_file_path
./ocr_db_crnn ch_ppocr_mobile_v2.0_det_slim_opt.nb ch_ppocr_mobile_v2.0_rec_slim_opt.nb ch_ppocr_mobile_v2.0_cls_slim_opt.nb ./11.jpg ppocr_keys_v1.txt
```
If you modify the code, recompile and push the result to the phone again.
The output looks like this:
<div align="center">
<img src="imgs/lite_demo.png" width="600">
</div>
## FAQ
Q1: What if I want to switch models; do I have to go through the whole process again?
A1: If you have already walked through the steps above, switching models only requires replacing the .nb model file; also remember to update the dictionary accordingly.
Q2: How do I test with another image?
A2: Replace the .jpg test image under debug with the image you want to test and adb push it to the phone.
Q3: How do I package this into a phone app?
A3: This demo provides the core OCR algorithm that can run on a phone; PaddleOCR/deploy/android_demo is an example of wrapping this demo into a phone app, for reference.
# Tutorial of PaddleOCR Mobile deployment
This tutorial introduces how to use [Paddle Lite](https://github.com/PaddlePaddle/Paddle-Lite) to deploy PaddleOCR's ultra-lightweight Chinese and English detection and recognition models on mobile phones.
Paddle Lite is a lightweight inference engine for PaddlePaddle. It provides efficient inference for mobile phones and IoT devices, and extensively integrates cross-platform hardware to provide lightweight deployment solutions for on-device deployment.
## 1. Preparation
### Prerequisites
- Computer (for Compiling Paddle Lite)
- Mobile phone (arm7 or arm8)
### 1.1 Prepare the cross-compilation environment
The cross-compilation environment is used to compile C++ demos of Paddle Lite and PaddleOCR.
Multiple development environments are supported.
For the compilation process in each development environment, please refer to the corresponding document.
1. [Docker](https://paddle-lite.readthedocs.io/zh/latest/source_compile/compile_env.html#docker)
2. [Linux](https://paddle-lite.readthedocs.io/zh/latest/source_compile/compile_env.html#linux)
3. [MAC OS](https://paddle-lite.readthedocs.io/zh/latest/source_compile/compile_env.html#mac-os)
### 1.2 Prepare Paddle-Lite library
There are two ways to obtain the Paddle-Lite library:
- 1. Download directly, the download link of the Paddle-Lite library is as follows:
| Platform | Paddle-Lite library download link |
|---|---|
|Android|[arm7](https://github.com/PaddlePaddle/Paddle-Lite/releases/download/v2.9/inference_lite_lib.android.armv7.gcc.c++_shared.with_extra.with_cv.tar.gz) / [arm8](https://github.com/PaddlePaddle/Paddle-Lite/releases/download/v2.9/inference_lite_lib.android.armv8.gcc.c++_shared.with_extra.with_cv.tar.gz)|
|IOS|[arm7](https://github.com/PaddlePaddle/Paddle-Lite/releases/download/v2.9/inference_lite_lib.ios.armv7.with_cv.with_extra.with_log.tiny_publish.tar.gz) / [arm8](https://github.com/PaddlePaddle/Paddle-Lite/releases/download/v2.9/inference_lite_lib.ios.armv8.with_cv.with_extra.with_log.tiny_publish.tar.gz)|
Note: 1. The above Paddle-Lite library is compiled from the Paddle-Lite 2.9 branch. For more information about Paddle-Lite 2.9, please refer to [link](https://github.com/PaddlePaddle/Paddle-Lite/releases/tag/v2.9).
- 2. [Recommended] Compile Paddle-Lite to get the prediction library. The compilation method of Paddle-Lite is as follows:
```
git clone https://github.com/PaddlePaddle/Paddle-Lite.git
cd Paddle-Lite
# Switch to the Paddle-Lite release/v2.9 stable branch
git checkout release/v2.9
./lite/tools/build_android.sh --arch=armv8 --with_cv=ON --with_extra=ON
```
Note: When compiling Paddle-Lite to obtain the Paddle-Lite library, you need to turn on the two options `--with_cv=ON --with_extra=ON`; `--arch` means the `arm` version and is set to armv8 here.
For more compilation commands, refer to the introduction [link](https://paddle-lite.readthedocs.io/zh/latest/source_compile/compile_andriod.html)
After directly downloading the Paddle-Lite library and decompressing it, you get the `inference_lite_lib.android.armv8/` folder; the library obtained by compiling Paddle-Lite is located in the
`Paddle-Lite/build.lite.android.armv8.gcc/inference_lite_lib.android.armv8/` folder.
The structure of the prediction library is as follows:
```
inference_lite_lib.android.armv8/
|-- cxx C++ prebuild library
| |-- include C++
| | |-- paddle_api.h
| | |-- paddle_image_preprocess.h
| | |-- paddle_lite_factory_helper.h
| | |-- paddle_place.h
| | |-- paddle_use_kernels.h
| | |-- paddle_use_ops.h
| | `-- paddle_use_passes.h
| `-- lib C++ library
| |-- libpaddle_api_light_bundled.a C++ static library
| `-- libpaddle_light_api_shared.so C++ dynamic library
|-- java Java library
| |-- jar
| | `-- PaddlePredictor.jar
| |-- so
| | `-- libpaddle_lite_jni.so
| `-- src
|-- demo C++ and Java demo
| |-- cxx C++ demo
| `-- java Java demo
```
## 2 Run
### 2.1 Inference Model Optimization
Paddle Lite provides a variety of strategies to automatically optimize the original training model, including quantization, sub-graph fusion, hybrid scheduling, kernel optimization and so on. To make the optimization process more convenient and easy to use, Paddle Lite provides the opt tool to automatically complete the optimization steps and output a lightweight, optimized executable model.
If you have prepared the model file ending in .nb, you can skip this step.
The following table also provides a series of models that can be deployed on mobile phones to recognize Chinese. You can directly download the optimized model.
|Version|Introduction|Model size|Detection model|Text Direction model|Recognition model|Paddle-Lite branch|
|---|---|---|---|---|---|---|
|V2.0|extra-lightweight Chinese OCR optimized model|7.8M|[download link](https://paddleocr.bj.bcebos.com/dygraph_v2.0/lite/ch_ppocr_mobile_v2.0_det_opt.nb)|[download link](https://paddleocr.bj.bcebos.com/dygraph_v2.0/lite/ch_ppocr_mobile_v2.0_cls_opt.nb)|[download link](https://paddleocr.bj.bcebos.com/dygraph_v2.0/lite/ch_ppocr_mobile_v2.0_rec_opt.nb)|v2.9|
|V2.0(slim)|extra-lightweight Chinese OCR optimized model|3.3M|[download link](https://paddleocr.bj.bcebos.com/dygraph_v2.0/lite/ch_ppocr_mobile_v2.0_det_slim_opt.nb)|[download link](https://paddleocr.bj.bcebos.com/dygraph_v2.0/lite/ch_ppocr_mobile_v2.0_cls_slim_opt.nb)|[download link](https://paddleocr.bj.bcebos.com/dygraph_v2.0/lite/ch_ppocr_mobile_v2.0_rec_slim_opt.nb)|v2.9|
If you directly use the model in the above table for deployment, you can skip the following steps and directly read [Section 2.2](#2.2 Run optimized model on Phone).
If the model to be deployed is not in the above table, you need to follow the steps below to obtain the optimized model.
The `opt` tool can be obtained by compiling Paddle Lite.
```
git clone https://github.com/PaddlePaddle/Paddle-Lite.git
cd Paddle-Lite
git checkout release/v2.9
./lite/tools/build.sh build_optimize_tool
```
After the compilation is complete, the opt file is located under `build.opt/lite/api/`. You can view the operating options and usage of opt as follows:
```
cd build.opt/lite/api/
./opt
```
|Options|Description|
|---|---|
|--model_dir|The path of the PaddlePaddle model to be optimized (non-combined form)|
|--model_file|The network structure file path of the PaddlePaddle model (combined form) to be optimized|
|--param_file|The weight file path of the PaddlePaddle model (combined form) to be optimized|
|--optimize_out_type|Output model type, currently supports two types: protobuf and naive_buffer, among which naive_buffer is a more lightweight serialization/deserialization implementation. If you need to perform model prediction on the mobile side, please set this option to naive_buffer. The default is protobuf|
|--optimize_out|The output path of the optimized model|
|--valid_targets|The executable backend of the model, the default is arm. Currently it supports x86, arm, opencl, npu, xpu, multiple backends can be specified at the same time (separated by spaces), and Model Optimize Tool will automatically select the best method. If you need to support Huawei NPU (DaVinci architecture NPU equipped with Kirin 810/990 Soc), it should be set to npu, arm|
|--record_tailoring_info|When using the function of cutting library files according to the model, set this option to true to record the kernel and OP information contained in the optimized model. The default is false|
`--model_dir` is suitable for models to be optimized that are in non-combined form; PaddleOCR's inference models are in combined form, that is, the model structure and the model parameters are each stored as a single file.
The following takes the ultra-lightweight Chinese model of PaddleOCR as an example to introduce the use of the compiled opt file to complete the conversion of the inference model to the Paddle-Lite optimized model.
```
# [Recommendation] Download the Chinese and English inference model of PaddleOCR V2.0
wget https://paddleocr.bj.bcebos.com/dygraph_v2.0/slim/ch_ppocr_mobile_v2.0_det_slim_infer.tar && tar xf ch_ppocr_mobile_v2.0_det_slim_infer.tar
wget https://paddleocr.bj.bcebos.com/dygraph_v2.0/slim/ch_ppocr_mobile_v2.0_rec_slim_infer.tar && tar xf ch_ppocr_mobile_v2.0_rec_slim_infer.tar
wget https://paddleocr.bj.bcebos.com/dygraph_v2.0/slim/ch_ppocr_mobile_v2.0_cls_slim_infer.tar && tar xf ch_ppocr_mobile_v2.0_cls_slim_infer.tar
# Convert V2.0 detection model
./opt --model_file=./ch_ppocr_mobile_v2.0_det_slim_infer/inference.pdmodel --param_file=./ch_ppocr_mobile_v2.0_det_slim_infer/inference.pdiparams --optimize_out=./ch_ppocr_mobile_v2.0_det_slim_opt --valid_targets=arm --optimize_out_type=naive_buffer
# Convert V2.0 recognition model
./opt --model_file=./ch_ppocr_mobile_v2.0_rec_slim_infer/inference.pdmodel --param_file=./ch_ppocr_mobile_v2.0_rec_slim_infer/inference.pdiparams --optimize_out=./ch_ppocr_mobile_v2.0_rec_slim_opt --valid_targets=arm --optimize_out_type=naive_buffer
# Convert V2.0 angle classifier model
./opt --model_file=./ch_ppocr_mobile_v2.0_cls_slim_infer/inference.pdmodel --param_file=./ch_ppocr_mobile_v2.0_cls_slim_infer/inference.pdiparams --optimize_out=./ch_ppocr_mobile_v2.0_cls_slim_opt --valid_targets=arm --optimize_out_type=naive_buffer
```
After the conversion succeeds, additional files ending in `.nb` appear in the inference model directory; these are the converted model files.
<a name="2.2 Run optimized model on Phone"></a>
### 2.2 Run optimized model on Phone
Some preparatory work is required first.
1. Prepare an Android phone with arm8. If the compiled prediction library and opt file are armv7, you need an arm7 phone and modify ARM_ABI = arm7 in the Makefile.
2. Make sure the phone is connected to the computer, open the USB debugging option of the phone, and select the file transfer mode.
3. Install the adb tool on the computer.
3.1. Install ADB for MAC:
```
brew cask install android-platform-tools
```
3.2. Install ADB for Linux
```
sudo apt update
sudo apt install -y wget adb
```
3.3. Install ADB for windows
To install on Windows, download the adb package from Google's Android platform and install it: [link](https://developer.android.com/studio)
Verify whether adb is installed successfully
```
adb devices
```
If there is device output, it means the installation is successful.
```
List of devices attached
744be294 device
```
4. Prepare optimized models, prediction library files, test images and dictionary files used.
```
git clone https://github.com/PaddlePaddle/PaddleOCR.git
cd PaddleOCR/deploy/lite/
# Run prepare.sh to place the prediction library files, test image and dictionary file under demo/cxx/ocr
sh prepare.sh /{lite prediction library path}/inference_lite_lib.android.armv8
# Enter the working directory of the OCR demo
cd /{lite prediction library path}/inference_lite_lib.android.armv8/
cd demo/cxx/ocr/
# copy paddle-lite C++ .so file to debug/ directory
cp ../../../cxx/lib/libpaddle_light_api_shared.so ./debug/
```
Prepare the test image, taking PaddleOCR/doc/imgs/11.jpg as an example, and copy it to the demo/cxx/ocr/debug/ folder. Prepare the model files optimized by the lite opt tool, e.g. ch_ppocr_mobile_v2.0_det_slim_opt.nb, ch_ppocr_mobile_v2.0_rec_slim_opt.nb and ch_ppocr_mobile_v2.0_cls_slim_opt.nb, and place them under the demo/cxx/ocr/debug/ folder.
The structure of the OCR demo is as follows after the above command is executed:
```
demo/cxx/ocr/
|-- debug/
| |--ch_ppocr_mobile_v2.0_det_slim_opt.nb Detection model
| |--ch_ppocr_mobile_v2.0_rec_slim_opt.nb Recognition model
| |--ch_ppocr_mobile_v2.0_cls_slim_opt.nb Text direction classification model
| |--11.jpg Image for OCR
| |--ppocr_keys_v1.txt Dictionary file
| |--libpaddle_light_api_shared.so C++ .so file
| |--config.txt Config file
|-- config.txt Config file
|-- cls_process.cc Pre-processing and post-processing files for the angle classifier
|-- cls_process.h
|-- crnn_process.cc Pre-processing and post-processing files for the CRNN model
|-- crnn_process.h
|-- db_post_process.cc Pre-processing and post-processing files for the DB model
|-- db_post_process.h
|-- Makefile
|-- ocr_db_crnn.cc C++ main code
```
#### Note:
1. `ppocr_keys_v1.txt` is a Chinese dictionary file. If the nb model is used for English, digits, or another language, the dictionary file should be replaced with one for the corresponding language (see the sketch after this list). PaddleOCR provides a variety of dictionaries under ppocr/utils/, including:
```
dict/french_dict.txt # French
dict/german_dict.txt # German
ic15_dict.txt # English
dict/japan_dict.txt # Japanese
dict/korean_dict.txt # Korean
ppocr_keys_v1.txt # Chinese
```
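For context, a short sketch of how the demo wires the dictionary into CTC decoding, mirroring `main` and `RunRecModel` in `ocr_db_crnn.cc`: index 0 is reserved for the CTC blank and a trailing space is appended, so a predicted class index maps directly into this vector.
```
// One character per line; the file must match the language of the rec model.
auto charactor_dict = ReadDict("./ppocr_keys_v1.txt");
charactor_dict.insert(charactor_dict.begin(), "#"); // blank char for ctc
charactor_dict.push_back(" ");
// A CTC argmax index i (> 0 and not a repeat of the previous timestep)
// decodes to charactor_dict[i].
```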
2. `config.txt` contains the hyperparameters of the detector and classifier, as shown below:
```
max_side_len 960 # When the image's height or width exceeds 960, it is scaled proportionally so that its longest side is 960
det_db_thresh 0.3 # Used to binarize the probability map predicted by DB; values in the 0-0.3 range have no obvious effect on the result
det_db_box_thresh 0.5 # Threshold of the DB post-processing box filter; if boxes are being missed, lower it as appropriate
det_db_unclip_ratio 1.6 # Indicates the compactness of the text boxes; the smaller the value, the tighter the boxes fit the text
use_direction_classify 0 # Whether to use the direction classifier, 0 means not to use, 1 means to use
```
5. Run Model on phone
After the above steps are completed, you can use adb to push the files to the phone and run the demo; the steps are as follows:
```
# Execute the compilation and get the executable file ocr_db_crnn
make -j
# Move the compiled executable file to the debug folder
mv ocr_db_crnn ./debug/
# Push the debug folder to the phone
adb push debug /data/local/tmp/
adb shell
cd /data/local/tmp/debug
export LD_LIBRARY_PATH=${PWD}:$LD_LIBRARY_PATH
# Usage of ocr_db_crnn:
# ./ocr_db_crnn detection_model_file recognition_model_file direction_classifier_model_file test_image_path dictionary_file_path
./ocr_db_crnn ch_ppocr_mobile_v2.0_det_slim_opt.nb ch_ppocr_mobile_v2.0_rec_slim_opt.nb ch_ppocr_mobile_v2.0_cls_slim_opt.nb ./11.jpg ppocr_keys_v1.txt
```
If you modify the code, you need to recompile and push to the phone.
The outputs are as follows:
<div align="center">
<img src="imgs/lite_demo.png" width="600">
</div>
## FAQ
Q1: What if I want to change the model, do I need to run it again according to the process?
A1: If you have performed the above steps, you only need to replace the .nb model file to complete the model switch; remember to update the dictionary file if the model's language changes.
Q2: How to test with another picture?
A2: Replace the .jpg test image under ./debug with the image you want to test, and run adb push to send the new image to the phone.
Q3: How to package it into the mobile APP?
A3: This demo aims to provide the core algorithm part that can run OCR on mobile phones. Further, PaddleOCR/deploy/android_demo is an example of encapsulating this demo into a mobile app for reference.
The image annotation information before `json.dumps` encoding is a list of multiple dicts, where the
## Quick start for training
First download the pretrained model for the backbone. PaddleOCR's detection models currently support two backbones, the MobileNetV3 and ResNet_vd series;
you can replace the backbone with models from [PaddleClas](https://github.com/PaddlePaddle/PaddleClas/tree/develop/ppcls/modeling/architectures) as needed.
Download links for the corresponding pretrained backbone weights can be found on the [PaddleClas repo homepage](https://github.com/PaddlePaddle/PaddleClas#mobile-series).
```shell
cd PaddleOCR/
# Download the pretrained MobileNetV3 model
wget -P ./pretrain_models/ https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/MobileNetV3_large_x0_5_pretrained.pdparams
# Or, download the pretrained ResNet18_vd model
wget -P ./pretrain_models/ https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/ResNet18_vd_pretrained.pdparams
# Or, download the pretrained ResNet50_vd model
wget -P ./pretrain_models/ https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/ResNet50_vd_ssld_pretrained.pdparams
```
python3 tools/eval.py -c configs/det/det_mv3_db.yml -o Global.checkpoints="{pat
Test the detection result on a single image:
```shell
python3 tools/infer_det.py -c configs/det/det_mv3_db.yml -o Global.infer_img="./doc/imgs_en/img_10.jpg" Global.pretrained_model="./output/det_db/best_accuracy"
```
When testing a DB model, adjust the post-processing thresholds:
```shell
python3 tools/infer_det.py -c configs/det/det_mv3_db.yml -o Global.infer_img="./doc/imgs_en/img_10.jpg" Global.pretrained_model="./output/det_db/best_accuracy" PostProcess.box_thresh=0.6 PostProcess.unclip_ratio=1.5
```
Test the detection results for all images in a folder:
```shell
python3 tools/infer_det.py -c configs/det/det_mv3_db.yml -o Global.infer_img="./doc/imgs_en/" Global.pretrained_model="./output/det_db/best_accuracy"
```
inference model (a model saved with `paddle.jit.save`)
- [1. Convert a training model to an inference model](#训练模型转inference模型)
  - [Convert a detection model to an inference model](#检测模型转inference模型)
  - [Convert a recognition model to an inference model](#识别模型转inference模型)
  - [Convert a direction classification model to an inference model](#方向分类模型转inference模型)
- [2. Text detection model inference](#文本检测模型推理)
  - [1. Ultra-lightweight Chinese detection model inference](#超轻量中文检测模型推理)
wget -P ./ch_lite/ https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_mobi
# Set the yml config file of the training algorithm after -c
# Set optional parameters after -o
# Global.pretrained_model sets the path of the training model to be converted; do not append the file suffix .pdmodel, .pdopt or .pdparams.
# Global.save_inference_dir sets the path where the converted model is saved.
python3 tools/export_model.py -c configs/det/ch_ppocr_v2.0/ch_det_mv3_db_v2.0.yml -o Global.pretrained_model=./ch_lite/ch_ppocr_mobile_v2.0_det_train/best_accuracy Global.save_inference_dir=./inference/det_db/
```
When converting to an inference model, the config file used is the same as the one used for training. In addition, the `Global.pretrained_model` parameter in the config file must be set; it points to the model parameter file saved during training.
After a successful conversion, there are three files in the model save directory:
wget -P ./ch_lite/ https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_mobi
# Set the yml config file of the training algorithm after -c
# Set optional parameters after -o
# Global.pretrained_model sets the path of the training model to be converted; do not append the file suffix .pdmodel, .pdopt or .pdparams.
# Global.save_inference_dir sets the path where the converted model is saved.
python3 tools/export_model.py -c configs/rec/ch_ppocr_v2.0/rec_chinese_lite_train_v2.0.yml -o Global.pretrained_model=./ch_lite/ch_ppocr_mobile_v2.0_rec_train/best_accuracy Global.save_inference_dir=./inference/rec_crnn/
```
**Note:** If the model was trained on your own dataset and you adjusted the dictionary of Chinese characters, check that `character_dict_path` in the config file points to the dictionary file you need.
wget -P ./ch_lite/ https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_mobi
# Set the yml config file of the training algorithm after -c
# Set optional parameters after -o
# Global.pretrained_model sets the path of the training model to be converted; do not append the file suffix .pdmodel, .pdopt or .pdparams.
# Global.save_inference_dir sets the path where the converted model is saved.
python3 tools/export_model.py -c configs/cls/cls_mv3.yml -o Global.pretrained_model=./ch_lite/ch_ppocr_mobile_v2.0_cls_train/best_accuracy Global.save_inference_dir=./inference/cls/
```
After a successful conversion, there are three files in the directory:
...@@ -164,7 +161,7 @@ python3 tools/infer/predict_det.py --image_dir="./doc/imgs/2.jpg" --det_model_di ...@@ -164,7 +161,7 @@ python3 tools/infer/predict_det.py --image_dir="./doc/imgs/2.jpg" --det_model_di
首先将DB文本检测训练过程中保存的模型,转换成inference model。以基于Resnet50_vd骨干网络,在ICDAR2015英文数据集训练的模型为例( [模型下载地址](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/det_r50_vd_db_v2.0_train.tar) ),可以使用如下命令进行转换: 首先将DB文本检测训练过程中保存的模型,转换成inference model。以基于Resnet50_vd骨干网络,在ICDAR2015英文数据集训练的模型为例( [模型下载地址](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/det_r50_vd_db_v2.0_train.tar) ),可以使用如下命令进行转换:
``` ```
python3 tools/export_model.py -c configs/det/det_r50_vd_db.yml -o Global.pretrained_model=./det_r50_vd_db_v2.0_train/best_accuracy Global.load_static_weights=False Global.save_inference_dir=./inference/det_db python3 tools/export_model.py -c configs/det/det_r50_vd_db.yml -o Global.pretrained_model=./det_r50_vd_db_v2.0_train/best_accuracy Global.save_inference_dir=./inference/det_db
``` ```
DB文本检测模型推理,可以执行如下命令: DB文本检测模型推理,可以执行如下命令:
...@@ -185,7 +182,7 @@ python3 tools/infer/predict_det.py --image_dir="./doc/imgs_en/img_10.jpg" --det_ ...@@ -185,7 +182,7 @@ python3 tools/infer/predict_det.py --image_dir="./doc/imgs_en/img_10.jpg" --det_
首先将EAST文本检测训练过程中保存的模型,转换成inference model。以基于Resnet50_vd骨干网络,在ICDAR2015英文数据集训练的模型为例( [模型下载地址](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/det_r50_vd_east_v2.0_train.tar) ),可以使用如下命令进行转换: 首先将EAST文本检测训练过程中保存的模型,转换成inference model。以基于Resnet50_vd骨干网络,在ICDAR2015英文数据集训练的模型为例( [模型下载地址](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/det_r50_vd_east_v2.0_train.tar) ),可以使用如下命令进行转换:
``` ```
python3 tools/export_model.py -c configs/det/det_r50_vd_east.yml -o Global.pretrained_model=./det_r50_vd_east_v2.0_train/best_accuracy Global.load_static_weights=False Global.save_inference_dir=./inference/det_east python3 tools/export_model.py -c configs/det/det_r50_vd_east.yml -o Global.pretrained_model=./det_r50_vd_east_v2.0_train/best_accuracy Global.save_inference_dir=./inference/det_east
``` ```
**EAST文本检测模型推理,需要设置参数`--det_algorithm="EAST"`**,可以执行如下命令: **EAST文本检测模型推理,需要设置参数`--det_algorithm="EAST"`**,可以执行如下命令:
...@@ -205,7 +202,7 @@ python3 tools/infer/predict_det.py --det_algorithm="EAST" --image_dir="./doc/img ...@@ -205,7 +202,7 @@ python3 tools/infer/predict_det.py --det_algorithm="EAST" --image_dir="./doc/img
#### (1). 四边形文本检测模型(ICDAR2015) #### (1). 四边形文本检测模型(ICDAR2015)
首先将SAST文本检测训练过程中保存的模型,转换成inference model。以基于Resnet50_vd骨干网络,在ICDAR2015英文数据集训练的模型为例([模型下载地址](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/det_r50_vd_sast_icdar15_v2.0_train.tar)),可以使用如下命令进行转换: 首先将SAST文本检测训练过程中保存的模型,转换成inference model。以基于Resnet50_vd骨干网络,在ICDAR2015英文数据集训练的模型为例([模型下载地址](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/det_r50_vd_sast_icdar15_v2.0_train.tar)),可以使用如下命令进行转换:
``` ```
python3 tools/export_model.py -c configs/det/det_r50_vd_sast_icdar15.yml -o Global.pretrained_model=./det_r50_vd_sast_icdar15_v2.0_train/best_accuracy Global.load_static_weights=False Global.save_inference_dir=./inference/det_sast_ic15 python3 tools/export_model.py -c configs/det/det_r50_vd_sast_icdar15.yml -o Global.pretrained_model=./det_r50_vd_sast_icdar15_v2.0_train/best_accuracy Global.save_inference_dir=./inference/det_sast_ic15
``` ```
**SAST文本检测模型推理,需要设置参数`--det_algorithm="SAST"`**,可以执行如下命令: **SAST文本检测模型推理,需要设置参数`--det_algorithm="SAST"`**,可以执行如下命令:
...@@ -220,7 +217,7 @@ python3 tools/infer/predict_det.py --det_algorithm="SAST" --image_dir="./doc/img ...@@ -220,7 +217,7 @@ python3 tools/infer/predict_det.py --det_algorithm="SAST" --image_dir="./doc/img
首先将SAST文本检测训练过程中保存的模型,转换成inference model。以基于Resnet50_vd骨干网络,在Total-Text英文数据集训练的模型为例([模型下载地址](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/det_r50_vd_sast_totaltext_v2.0_train.tar)),可以使用如下命令进行转换: 首先将SAST文本检测训练过程中保存的模型,转换成inference model。以基于Resnet50_vd骨干网络,在Total-Text英文数据集训练的模型为例([模型下载地址](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/det_r50_vd_sast_totaltext_v2.0_train.tar)),可以使用如下命令进行转换:
``` ```
python3 tools/export_model.py -c configs/det/det_r50_vd_sast_totaltext.yml -o Global.pretrained_model=./det_r50_vd_sast_totaltext_v2.0_train/best_accuracy Global.load_static_weights=False Global.save_inference_dir=./inference/det_sast_tt python3 tools/export_model.py -c configs/det/det_r50_vd_sast_totaltext.yml -o Global.pretrained_model=./det_r50_vd_sast_totaltext_v2.0_train/best_accuracy Global.save_inference_dir=./inference/det_sast_tt
``` ```
...@@ -270,7 +267,7 @@ Predicts of ./doc/imgs_words/ch/word_4.jpg:('实力活力', 0.98458153) ...@@ -270,7 +267,7 @@ Predicts of ./doc/imgs_words/ch/word_4.jpg:('实力活力', 0.98458153)
的模型为例( [模型下载地址](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/rec_r34_vd_none_bilstm_ctc_v2.0_train.tar) ),可以使用如下命令进行转换: 的模型为例( [模型下载地址](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/rec_r34_vd_none_bilstm_ctc_v2.0_train.tar) ),可以使用如下命令进行转换:
``` ```
python3 tools/export_model.py -c configs/rec/rec_r34_vd_none_bilstm_ctc.yml -o Global.pretrained_model=./rec_r34_vd_none_bilstm_ctc_v2.0_train/best_accuracy Global.load_static_weights=False Global.save_inference_dir=./inference/rec_crnn python3 tools/export_model.py -c configs/rec/rec_r34_vd_none_bilstm_ctc.yml -o Global.pretrained_model=./rec_r34_vd_none_bilstm_ctc_v2.0_train/best_accuracy Global.save_inference_dir=./inference/rec_crnn
``` ```
CRNN 文本识别模型推理,可以执行如下命令: CRNN 文本识别模型推理,可以执行如下命令:
...@@ -362,17 +359,18 @@ Predicts of ./doc/imgs_words/ch/word_4.jpg:['0', 0.9999982] ...@@ -362,17 +359,18 @@ Predicts of ./doc/imgs_words/ch/word_4.jpg:['0', 0.9999982]
<a name="超轻量中文OCR模型推理"></a> <a name="超轻量中文OCR模型推理"></a>
### 1. 超轻量中文OCR模型推理 ### 1. 超轻量中文OCR模型推理
When running prediction, specify the path of a single image or an image folder with the parameter `image_dir`, and use the parameters `det_model_dir`, `cls_model_dir` and `rec_model_dir` to set the inference model paths for detection, direction classification and recognition respectively. The parameter `use_angle_cls` controls whether the direction classifier is enabled, `use_mp` controls whether multi-process inference is used, and `total_process_num` sets the number of processes when multi-process is enabled. Visualized recognition results are saved to the ./inference_results folder by default.
```shell
# use the direction classifier
python3 tools/infer/predict_system.py --image_dir="./doc/imgs/00018069.jpg" --det_model_dir="./inference/det_db/" --cls_model_dir="./inference/cls/" --rec_model_dir="./inference/rec_crnn/" --use_angle_cls=true
# do not use the direction classifier
python3 tools/infer/predict_system.py --image_dir="./doc/imgs/00018069.jpg" --det_model_dir="./inference/det_db/" --rec_model_dir="./inference/rec_crnn/" --use_angle_cls=false
# use multi-process
python3 tools/infer/predict_system.py --image_dir="./doc/imgs/00018069.jpg" --det_model_dir="./inference/det_db/" --rec_model_dir="./inference/rec_crnn/" --use_angle_cls=false --use_mp=True --total_process_num=6
```
......
...@@ -104,27 +104,16 @@ python3 generate_multi_language_configs.py -l it \ ...@@ -104,27 +104,16 @@ python3 generate_multi_language_configs.py -l it \
| german_mobile_v2.0_rec |德文识别|[rec_german_lite_train.yml](../../configs/rec/multi_language/rec_german_lite_train.yml)|2.65M|[推理模型](https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/german_mobile_v2.0_rec_infer.tar) / [训练模型](https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/german_mobile_v2.0_rec_train.tar) | | german_mobile_v2.0_rec |德文识别|[rec_german_lite_train.yml](../../configs/rec/multi_language/rec_german_lite_train.yml)|2.65M|[推理模型](https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/german_mobile_v2.0_rec_infer.tar) / [训练模型](https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/german_mobile_v2.0_rec_train.tar) |
| korean_mobile_v2.0_rec |韩文识别|[rec_korean_lite_train.yml](../../configs/rec/multi_language/rec_korean_lite_train.yml)|3.9M|[推理模型](https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/korean_mobile_v2.0_rec_infer.tar) / [训练模型](https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/korean_mobile_v2.0_rec_train.tar) | | korean_mobile_v2.0_rec |韩文识别|[rec_korean_lite_train.yml](../../configs/rec/multi_language/rec_korean_lite_train.yml)|3.9M|[推理模型](https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/korean_mobile_v2.0_rec_infer.tar) / [训练模型](https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/korean_mobile_v2.0_rec_train.tar) |
| japan_mobile_v2.0_rec |日文识别|[rec_japan_lite_train.yml](../../configs/rec/multi_language/rec_japan_lite_train.yml)|4.23M|[推理模型](https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/japan_mobile_v2.0_rec_infer.tar) / [训练模型](https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/japan_mobile_v2.0_rec_train.tar) | | japan_mobile_v2.0_rec |日文识别|[rec_japan_lite_train.yml](../../configs/rec/multi_language/rec_japan_lite_train.yml)|4.23M|[推理模型](https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/japan_mobile_v2.0_rec_infer.tar) / [训练模型](https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/japan_mobile_v2.0_rec_train.tar) |
| it_mobile_v2.0_rec |意大利文识别|rec_it_lite_train.yml|2.53M|[推理模型](https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/it_mobile_v2.0_rec_infer.tar) / [训练模型](https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/it_mobile_v2.0_rec_train.tar) |
| xi_mobile_v2.0_rec |西班牙文识别|rec_xi_lite_train.yml|2.53M|[推理模型](https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/xi_mobile_v2.0_rec_infer.tar) / [训练模型](https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/xi_mobile_v2.0_rec_train.tar) |
| pu_mobile_v2.0_rec |葡萄牙文识别|rec_pu_lite_train.yml|2.63M|[推理模型](https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/pu_mobile_v2.0_rec_infer.tar) / [训练模型](https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/pu_mobile_v2.0_rec_train.tar) |
| ru_mobile_v2.0_rec |俄罗斯文识别|rec_ru_lite_train.yml|2.63M|[推理模型](https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/ru_mobile_v2.0_rec_infer.tar) / [训练模型](https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/ru_mobile_v2.0_rec_train.tar) |
| ar_mobile_v2.0_rec |阿拉伯文识别|rec_ar_lite_train.yml|2.53M|[推理模型](https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/ar_mobile_v2.0_rec_infer.tar) / [训练模型](https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/ar_mobile_v2.0_rec_train.tar) |
| hi_mobile_v2.0_rec |印地文识别|rec_hi_lite_train.yml|2.63M|[推理模型](https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/hi_mobile_v2.0_rec_infer.tar) / [训练模型](https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/hi_mobile_v2.0_rec_train.tar) |
| chinese_cht_mobile_v2.0_rec |中文繁体识别|rec_chinese_cht_lite_train.yml|5.63M|[推理模型](https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/chinese_cht_mobile_v2.0_rec_infer.tar) / [训练模型](https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/chinese_cht_mobile_v2.0_rec_train.tar) | | chinese_cht_mobile_v2.0_rec |中文繁体识别|rec_chinese_cht_lite_train.yml|5.63M|[推理模型](https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/chinese_cht_mobile_v2.0_rec_infer.tar) / [训练模型](https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/chinese_cht_mobile_v2.0_rec_train.tar) |
| ug_mobile_v2.0_rec |维吾尔文识别|rec_ug_lite_train.yml|2.63M|[推理模型](https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/ug_mobile_v2.0_rec_infer.tar) / [训练模型](https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/ug_mobile_v2.0_rec_train.tar) |
| fa_mobile_v2.0_rec |波斯文识别|rec_fa_lite_train.yml|2.63M|[推理模型](https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/fa_mobile_v2.0_rec_infer.tar) / [训练模型](https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/fa_mobile_v2.0_rec_train.tar) |
| ur_mobile_v2.0_rec |乌尔都文识别|rec_ur_lite_train.yml|2.63M|[推理模型](https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/ur_mobile_v2.0_rec_infer.tar) / [训练模型](https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/ur_mobile_v2.0_rec_train.tar) |
| rs_mobile_v2.0_rec |塞尔维亚文(latin)识别|rec_rs_lite_train.yml|2.53M|[推理模型](https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/rs_mobile_v2.0_rec_infer.tar) / [训练模型](https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/rs_mobile_v2.0_rec_train.tar) |
| oc_mobile_v2.0_rec |欧西坦文识别|rec_oc_lite_train.yml|2.53M|[推理模型](https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/oc_mobile_v2.0_rec_infer.tar) / [训练模型](https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/oc_mobile_v2.0_rec_train.tar) |
| mr_mobile_v2.0_rec |马拉地文识别|rec_mr_lite_train.yml|2.63M|[推理模型](https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/mr_mobile_v2.0_rec_infer.tar) / [训练模型](https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/mr_mobile_v2.0_rec_train.tar) |
| ne_mobile_v2.0_rec |尼泊尔文识别|rec_ne_lite_train.yml|2.63M|[推理模型](https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/ne_mobile_v2.0_rec_infer.tar) / [训练模型](https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/ne_mobile_v2.0_rec_train.tar) |
| rsc_mobile_v2.0_rec |塞尔维亚文(cyrillic)识别|rec_rsc_lite_train.yml|2.63M|[推理模型](https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/rsc_mobile_v2.0_rec_infer.tar) / [训练模型](https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/rsc_mobile_v2.0_rec_train.tar) |
| bg_mobile_v2.0_rec |保加利亚文识别|rec_bg_lite_train.yml|2.63M|[推理模型](https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/bg_mobile_v2.0_rec_infer.tar) / [训练模型](https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/bg_mobile_v2.0_rec_train.tar) |
| uk_mobile_v2.0_rec |乌克兰文识别|rec_uk_lite_train.yml|2.63M|[推理模型](https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/uk_mobile_v2.0_rec_infer.tar) / [训练模型](https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/uk_mobile_v2.0_rec_train.tar) |
| be_mobile_v2.0_rec |白俄罗斯文识别|rec_be_lite_train.yml|2.63M|[推理模型](https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/be_mobile_v2.0_rec_infer.tar) / [训练模型](https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/be_mobile_v2.0_rec_train.tar) |
| te_mobile_v2.0_rec |泰卢固文识别|rec_te_lite_train.yml|2.63M|[推理模型](https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/te_mobile_v2.0_rec_infer.tar) / [训练模型](https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/te_mobile_v2.0_rec_train.tar) | | te_mobile_v2.0_rec |泰卢固文识别|rec_te_lite_train.yml|2.63M|[推理模型](https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/te_mobile_v2.0_rec_infer.tar) / [训练模型](https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/te_mobile_v2.0_rec_train.tar) |
| ka_mobile_v2.0_rec |卡纳达文识别|rec_ka_lite_train.yml|2.63M|[推理模型](https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/ka_mobile_v2.0_rec_infer.tar) / [训练模型](https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/ka_mobile_v2.0_rec_train.tar) | | ka_mobile_v2.0_rec |卡纳达文识别|rec_ka_lite_train.yml|2.63M|[推理模型](https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/ka_mobile_v2.0_rec_infer.tar) / [训练模型](https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/ka_mobile_v2.0_rec_train.tar) |
| ta_mobile_v2.0_rec |泰米尔文识别|rec_ta_lite_train.yml|2.63M|[推理模型](https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/ta_mobile_v2.0_rec_infer.tar) / [训练模型](https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/ta_mobile_v2.0_rec_train.tar) | | ta_mobile_v2.0_rec |泰米尔文识别|rec_ta_lite_train.yml|2.63M|[推理模型](https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/ta_mobile_v2.0_rec_infer.tar) / [训练模型](https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/ta_mobile_v2.0_rec_train.tar) |
| latin_mobile_v2.0_rec | 拉丁文识别 | [rec_latin_lite_train.yml](../../configs/rec/multi_language/rec_latin_lite_train.yml) |2.6M|[推理模型](https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/latin_ppocr_mobile_v2.0_rec_infer.tar) / [训练模型](https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/latin_ppocr_mobile_v2.0_rec_train.tar) |
| arabic_mobile_v2.0_rec | 阿拉伯字母 | [rec_arabic_lite_train.yml](../../configs/rec/multi_language/rec_arabic_lite_train.yml) |2.6M|[推理模型](https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/arabic_ppocr_mobile_v2.0_rec_infer.tar) / [训练模型](https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/arabic_ppocr_mobile_v2.0_rec_train.tar) |
| cyrillic_mobile_v2.0_rec | 斯拉夫字母 | [rec_cyrillic_lite_train.yml](../../configs/rec/multi_language/rec_cyrillic_lite_train.yml) |2.6M|[推理模型](https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/cyrillic_ppocr_mobile_v2.0_rec_infer.tar) / [训练模型](https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/cyrillic_ppocr_mobile_v2.0_rec_train.tar) |
| devanagari_mobile_v2.0_rec | 梵文字母 | [rec_devanagari_lite_train.yml](../../configs/rec/multi_language/rec_devanagari_lite_train.yml) |2.6M|[推理模型](https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/devanagari_ppocr_mobile_v2.0_rec_infer.tar) / [训练模型](https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/devanagari_ppocr_mobile_v2.0_rec_train.tar) |
更多支持语种请参考: [多语言模型](./multi_languages.md)
<a name="文本方向分类模型"></a> <a name="文本方向分类模型"></a>
......
# 多语言模型
**近期更新**
- 2021.4.9 支持**80种**语言的检测和识别
- 2021.4.9 支持**轻量高精度**英文模型检测识别
PaddleOCR aims to provide a rich, leading and practical OCR toolkit: besides Chinese and English models for general scenarios, it also provides a model trained specifically for English scenarios,
as well as lightweight models covering [80 languages](#语种缩写).
The English model detects and recognizes upper- and lower-case letters and common punctuation, with optimized recognition of the space character:
<div align="center">
<img src="../imgs_results/multi_lang/img_12.jpg" width="900" height="300">
</div>
小语种模型覆盖了拉丁语系、阿拉伯语系、中文繁体、韩语、日语等等:
<div align="center">
<img src="../imgs_results/multi_lang/japan_2.jpg" width="600" height="300">
<img src="../imgs_results/multi_lang/french_0.jpg" width="300" height="300">
<img src="../imgs_results/multi_lang/korean_0.jpg" width="500" height="300">
<img src="../imgs_results/multi_lang/arabic_0.jpg" width="300" height="300">
</div>
本文档将简要介绍小语种模型的使用方法。
- [1 安装](#安装)
- [1.1 paddle 安装](#paddle安装)
- [1.2 paddleocr package 安装](#paddleocr_package_安装)
- [2 快速使用](#快速使用)
- [2.1 命令行运行](#命令行运行)
- [2.2 python 脚本运行](#python_脚本运行)
- [3 自定义训练](#自定义训练)
- [4 预测部署](#预测部署)
- [5 支持语种及缩写](#语种缩写)
<a name="安装"></a>
## 1 安装
<a name="paddle安装"></a>
### 1.1 paddle 安装
```
# cpu
pip install paddlepaddle
# gpu
pip install paddlepaddle-gpu
```
<a name="paddleocr_package_安装"></a>
### 1.2 paddleocr package 安装
pip 安装
```
pip install "paddleocr>=2.0.6" # 推荐使用2.0.6版本
```
本地构建并安装
```
python3 setup.py bdist_wheel
pip3 install dist/paddleocr-x.x.x-py3-none-any.whl # x.x.x是paddleocr的版本号
```
<a name="快速使用"></a>
## 2 快速使用
<a name="命令行运行"></a>
### 2.1 命令行运行
查看帮助信息
```
paddleocr -h
```
* 整图预测(检测+识别)
Paddleocr目前支持80个语种,可以通过修改--lang参数进行切换,具体支持的[语种](#语种缩写)可查看表格。
``` bash
paddleocr --image_dir doc/imgs_en/254.jpg --lang=en
```
<div align="center">
<img src="../imgs_en/254.jpg" width="300" height="600">
<img src="../imgs_results/multi_lang/img_02.jpg" width="600" height="600">
</div>
结果是一个list,每个item包含了文本框,文字和识别置信度
```text
[('PHO CAPITAL', 0.95723116), [[66.0, 50.0], [327.0, 44.0], [327.0, 76.0], [67.0, 82.0]]]
[('107 State Street', 0.96311164), [[72.0, 90.0], [451.0, 84.0], [452.0, 116.0], [73.0, 121.0]]]
[('Montpelier Vermont', 0.97389287), [[69.0, 132.0], [501.0, 126.0], [501.0, 158.0], [70.0, 164.0]]]
[('8022256183', 0.99810505), [[71.0, 175.0], [363.0, 170.0], [364.0, 202.0], [72.0, 207.0]]]
[('REG 07-24-201706:59 PM', 0.93537045), [[73.0, 299.0], [653.0, 281.0], [654.0, 318.0], [74.0, 336.0]]]
[('045555', 0.99346405), [[509.0, 331.0], [651.0, 325.0], [652.0, 356.0], [511.0, 362.0]]]
[('CT1', 0.9988654), [[535.0, 367.0], [654.0, 367.0], [654.0, 406.0], [535.0, 406.0]]]
......
```
* 识别预测
```bash
paddleocr --image_dir doc/imgs_words_en/word_308.png --det false --lang=en
```
结果是一个tuple,返回识别结果和识别置信度
```text
(0.99879867, 'LITTLE')
```
* 检测预测
```
paddleocr --image_dir PaddleOCR/doc/imgs/11.jpg --rec false
```
结果是一个list,每个item只包含文本框
```
[[26.0, 457.0], [137.0, 457.0], [137.0, 477.0], [26.0, 477.0]]
[[25.0, 425.0], [372.0, 425.0], [372.0, 448.0], [25.0, 448.0]]
[[128.0, 397.0], [273.0, 397.0], [273.0, 414.0], [128.0, 414.0]]
......
```
<a name="python_脚本运行"></a>
### 2.2 python 脚本运行
ppocr can also be run from Python scripts, which makes it easy to embed in your own code:
* 整图预测(检测+识别)
```
from paddleocr import PaddleOCR, draw_ocr
# 同样也是通过修改 lang 参数切换语种
ocr = PaddleOCR(lang="korean") # 首次执行会自动下载模型文件
img_path = 'doc/imgs/korean_1.jpg'
result = ocr.ocr(img_path)
# 可通过参数控制单独执行识别、检测
# result = ocr.ocr(img_path, det=False) 只执行识别
# result = ocr.ocr(img_path, rec=False) 只执行检测
# 打印检测框和识别结果
for line in result:
print(line)
# 可视化
from PIL import Image
image = Image.open(img_path).convert('RGB')
boxes = [line[0] for line in result]
txts = [line[1][0] for line in result]
scores = [line[1][1] for line in result]
im_show = draw_ocr(image, boxes, txts, scores, font_path='/path/to/PaddleOCR/doc/fonts/korean.ttf')
im_show = Image.fromarray(im_show)
im_show.save('result.jpg')
```
结果可视化:
<div align="center">
<img src="https://raw.githubusercontent.com/PaddlePaddle/PaddleOCR/release/2.1/doc/imgs_results/korean.jpg" width="800">
</div>
ppocr 还支持方向分类, 更多使用方式请参考:[whl包使用说明](https://github.com/PaddlePaddle/PaddleOCR/blob/release/2.0/doc/doc_ch/whl.md)
<a name="自定义训练"></a>
## 3 自定义训练
ppocr supports custom training or fine-tuning on your own data. For recognition, you can start from the [French configuration file](../../configs/rec/multi_language/rec_french_lite_train.yml)
and modify the training data path, dictionary and other parameters, as sketched below.
For data preparation and the training procedure, see [text detection](../doc_ch/detection.md) and [text recognition](../doc_ch/recognition.md); for more features such as inference deployment and
data annotation, read the full [documentation](../../README_ch.md).
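A sketch of the fields typically adjusted in such a config, assuming the standard layout of PaddleOCR recognition configs (the paths below are examples):
```yaml
Global:
  character_dict_path: ppocr/utils/dict/french_dict.txt
Train:
  dataset:
    data_dir: ./train_data/
    label_file_list: ["./train_data/train_list.txt"]
```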
<a name="预测部署"></a>
## 4 预测部署
除了安装whl包进行快速预测,ppocr 也提供了多种预测部署方式,如有需求可阅读相关文档:
- [基于Python脚本预测引擎推理](./inference.md)
- [基于C++预测引擎推理](../../deploy/cpp_infer/readme.md)
- [服务化部署](../../deploy/hubserving/readme.md)
- [端侧部署](https://github.com/PaddlePaddle/PaddleOCR/blob/develop/deploy/lite/readme.md)
- [Benchmark](./benchmark.md)
<a name="语种缩写"></a>
## 5 支持语种及缩写
| 语种 | 描述 | 缩写 | | 语种 | 描述 | 缩写 |
| --- | --- | --- | ---|--- | --- | --- |
|中文|chinese and english|ch| |保加利亚文|Bulgarian |bg|
|英文|english|en| |乌克兰文|Ukranian|uk|
|法文|french|fr| |白俄罗斯文|Belarusian|be|
|德文|german|german| |泰卢固文|Telugu |te|
|日文|japan|japan| |阿巴扎文|Abaza |abq|
|韩文|korean|korean| |泰米尔文|Tamil |ta|
|中文繁体|chinese traditional |ch_tra| |南非荷兰文 |Afrikaans |af|
|意大利文| Italian |it| |阿塞拜疆文 |Azerbaijani |az|
|西班牙文|Spanish |es| |波斯尼亚文|Bosnian|bs|
|葡萄牙文| Portuguese|pt| |捷克文|Czech|cs|
|俄罗斯文|Russian|ru| |威尔士文 |Welsh |cy|
|阿拉伯文|Arabic|ar| |丹麦文 |Danish|da|
|印地文|Hindi|hi| |爱沙尼亚文 |Estonian |et|
|维吾尔|Uyghur|ug| |爱尔兰文 |Irish |ga|
|波斯文|Persian|fa| |克罗地亚文|Croatian |hr|
|乌尔都文|Urdu|ur| |匈牙利文|Hungarian |hu|
|塞尔维亚文(latin)| Serbian(latin) |rs_latin| |印尼文|Indonesian|id|
|欧西坦文|Occitan |oc| |冰岛文 |Icelandic|is|
|马拉地文|Marathi|mr| |库尔德文 |Kurdish|ku|
|尼泊尔文|Nepali|ne| |立陶宛文|Lithuanian |lt|
|塞尔维亚文(cyrillic)|Serbian(cyrillic)|rs_cyrillic| |拉脱维亚文 |Latvian |lv|
|毛利文|Maori|mi| | 达尔瓦文|Dargwa |dar|
|马来文 |Malay|ms| | 因古什文|Ingush |inh|
|马耳他文 |Maltese |mt| | 拉克文|Lak |lbe|
|荷兰文 |Dutch |nl| | 莱兹甘文|Lezghian |lez|
|挪威文 |Norwegian |no| |塔巴萨兰文 |Tabassaran |tab|
|波兰文|Polish |pl| | 比尔哈文|Bihari |bh|
| 罗马尼亚文|Romanian |ro| | 迈蒂利文|Maithili |mai|
| 斯洛伐克文|Slovak |sk| | 昂加文|Angika |ang|
| 斯洛文尼亚文|Slovenian |sl| | 博杰普尔文|Bhojpuri |bho|
| 阿尔巴尼亚文|Albanian |sq| | 摩揭陀文 |Magahi |mah|
| 瑞典文|Swedish |sv| | 那格浦尔文|Nagpur |sck|
| 西瓦希里文|Swahili |sw| | 尼瓦尔文|Newari |new|
| 塔加洛文|Tagalog |tl| | 果阿孔卡尼文 |Goan Konkani|gom|
| 土耳其文|Turkish |tr| | 沙特阿拉伯文|Saudi Arabia|sa|
| 乌兹别克文|Uzbek |uz| | 阿瓦尔文|Avar |ava|
| 越南文|Vietnamese |vi| | 阿瓦尔文|Avar |ava|
| 蒙古文|Mongolian |mn| | 阿迪赫文|Adyghe |ady|
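Any abbreviation in the table can be passed via `--lang`; for example, with the Korean sample image used in section 2.2 above:
```bash
# switch recognition language using an abbreviation from the table
paddleocr --image_dir doc/imgs/korean_1.jpg --lang=korean
```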
# 端对端OCR算法-PGNet
- [一、简介](#简介)
- [二、环境配置](#环境配置)
- [三、快速使用](#快速使用)
- [四、模型训练、评估、推理](#模型训练、评估、推理)
<a name="简介"></a>
## 一、简介
OCR algorithms fall into two categories: two-stage algorithms and end-to-end algorithms. A two-stage OCR pipeline has two parts, text detection and text recognition: the text detection algorithm extracts text-line boxes from the image, and the recognition algorithm then reads the content of each box. An end-to-end OCR algorithm completes detection and recognition within a single model: it is designed with both a detection unit and a recognition module, shares their CNN features, and trains them jointly. Because a single model does the whole job, end-to-end models are smaller and faster.
### PGNet算法介绍
End-to-end OCR algorithms have developed rapidly in recent years, including the MaskTextSpotter series, TextSnake, TextDragon and the PGNet series. Among these, PGNet offers advantages the others lack:
- A purpose-built PGNet loss guides training, so no character-level annotations are required
- No NMS or ROI-related operations are needed, which speeds up prediction
- A module is proposed to predict the reading order within a text line
- A graph-based refinement module (GRM) further improves recognition performance
- Higher accuracy and faster prediction
PGNet algorithm details are given in the [paper](https://www.aaai.org/AAAI21Papers/AAAI-2885.WangP.pdf); the architecture is illustrated below:
![](../pgnet_framework.png)
The input image is passed through feature extraction and fed into four branches: the text border offset (TBO) prediction module, the text center line (TCL) prediction module, the text direction offset (TDO) prediction module, and the text character classification map (TCC) prediction module.
The outputs of TBO and TCL yield the text detection result after post-processing, while TCL, TDO and TCC are responsible for text recognition.
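A minimal, illustrative sketch of the four prediction heads (channel sizes, class count and output names are assumptions, not the exact PaddleOCR implementation):
```python
import paddle.nn as nn
import paddle.nn.functional as F

class PGHeadSketch(nn.Layer):
    """Four 1x1-conv prediction heads over shared backbone features (sketch only)."""
    def __init__(self, in_channels=128, char_num=37):  # 37: assumed English charset + blank
        super().__init__()
        self.conv_tcl = nn.Conv2D(in_channels, 1, 1)         # TCL: text center line map
        self.conv_tbo = nn.Conv2D(in_channels, 4, 1)         # TBO: offsets to the text border
        self.conv_tdo = nn.Conv2D(in_channels, 2, 1)         # TDO: reading-direction offsets
        self.conv_tcc = nn.Conv2D(in_channels, char_num, 1)  # TCC: per-pixel character logits

    def forward(self, feats):
        return {
            "f_score": F.sigmoid(self.conv_tcl(feats)),  # where text centers are
            "f_border": self.conv_tbo(feats),            # recover the detection polygon
            "f_direction": self.conv_tdo(feats),         # order the center points
            "f_char": self.conv_tcc(feats),              # classify characters along the line
        }
```
Roughly, post-processing walks the TCL region in the direction given by TDO and decodes TCC along the way, which is what removes the need for NMS and ROI operations.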
Example detection and recognition results:
![](../imgs_results/e2e_res_img293_pgnet.png)
![](../imgs_results/e2e_res_img295_pgnet.png)
### 性能指标
测试集: Total Text
测试环境: NVIDIA Tesla V100-SXM2-16GB
|PGNetA|det_precision|det_recall|det_f_score|e2e_precision|e2e_recall|e2e_f_score|FPS|下载|
| --- | --- | --- | --- | --- | --- | --- | --- | --- |
|Paper|85.30|86.80|86.1|-|-|61.7|38.20 (size=640)|-|
|Ours|87.03|82.48|84.69|61.71|58.43|60.03|48.73 (size=768)|[下载链接](https://paddleocr.bj.bcebos.com/dygraph_v2.0/pgnet/en_server_pgnetA.tar)|
*Note: the PGNet implementation in PaddleOCR is optimized for prediction speed; within an acceptable drop in accuracy, it significantly accelerates end-to-end prediction.*
<a name="环境配置"></a>
## 二、环境配置
请先参考[快速安装](./installation.md)配置PaddleOCR运行环境。
<a name="快速使用"></a>
## 三、快速使用
### inference模型下载
本节以训练好的端到端模型为例,快速使用模型预测,首先下载训练好的端到端inference模型[下载地址](https://paddleocr.bj.bcebos.com/dygraph_v2.0/pgnet/e2e_server_pgnetA_infer.tar)
```
mkdir inference && cd inference
# 下载英文端到端模型并解压
wget https://paddleocr.bj.bcebos.com/dygraph_v2.0/pgnet/e2e_server_pgnetA_infer.tar && tar xf e2e_server_pgnetA_infer.tar
```
* windows 环境下如果没有安装wget,下载模型时可将链接复制到浏览器中下载,并解压放置在相应目录下
解压完毕后应有如下文件结构:
```
├── e2e_server_pgnetA_infer
│ ├── inference.pdiparams
│ ├── inference.pdiparams.info
│ └── inference.pdmodel
```
### 单张图像或者图像集合预测
```bash
# 预测image_dir指定的单张图像
python3 tools/infer/predict_e2e.py --e2e_algorithm="PGNet" --image_dir="./doc/imgs_en/img623.jpg" --e2e_model_dir="./inference/e2e_server_pgnetA_infer/" --e2e_pgnet_polygon=True
# 预测image_dir指定的图像集合
python3 tools/infer/predict_e2e.py --e2e_algorithm="PGNet" --image_dir="./doc/imgs_en/" --e2e_model_dir="./inference/e2e_server_pgnetA_infer/" --e2e_pgnet_polygon=True
# 如果想使用CPU进行预测,需设置use_gpu参数为False
python3 tools/infer/predict_e2e.py --e2e_algorithm="PGNet" --image_dir="./doc/imgs_en/img623.jpg" --e2e_model_dir="./inference/e2e_server_pgnetA_infer/" --e2e_pgnet_polygon=True --use_gpu=False
```
### 可视化结果
可视化文本检测结果默认保存到./inference_results文件夹里面,结果文件的名称前缀为'e2e_res'。结果示例如下:
![](../imgs_results/e2e_res_img623_pgnet.jpg)
<a name="模型训练、评估、推理"></a>
## 四、模型训练、评估、推理
本节以totaltext数据集为例,介绍PaddleOCR中端到端模型的训练、评估与测试。
### 准备数据
下载解压[totaltext](https://paddleocr.bj.bcebos.com/dataset/total_text.tar) 数据集到PaddleOCR/train_data/目录,数据集组织结构:
```
/PaddleOCR/train_data/total_text/train/
|- rgb/ # total_text数据集的训练数据
|- img11.jpg
| ...
|- train.txt # total_text数据集的训练标注
```
The train.txt annotation file has the following format, with the file name and the annotation information separated by "\t":
```
" 图像文件名 json.dumps编码的图像标注信息"
rgb/img11.jpg [{"transcription": "ASRAMA", "points": [[214.0, 325.0], [235.0, 308.0], [259.0, 296.0], [286.0, 291.0], [313.0, 295.0], [338.0, 305.0], [362.0, 320.0], [349.0, 347.0], [330.0, 337.0], [310.0, 329.0], [290.0, 324.0], [269.0, 328.0], [249.0, 336.0], [231.0, 346.0]]}, {...}]
```
The annotation before json.dumps encoding is a list of dictionaries. In each dictionary, `points` holds the (x, y) coordinates of the vertices of the text region, starting from the top-left vertex and listed clockwise.
`transcription` is the text of the current region. **When its content is "###", the text region is invalid and is skipped during training.**
If you want to train on other datasets, you can build annotation files following the format above.
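A minimal Python sketch for reading such a label file (the helper name is ours, not part of PaddleOCR):
```python
import json

def load_pgnet_labels(label_path):
    """Parse lines of 'image_name<TAB>json_annotation' into (name, polygons) pairs."""
    samples = []
    with open(label_path, "r", encoding="utf-8") as f:
        for line in f:
            img_name, annotation = line.rstrip("\n").split("\t", 1)
            labels = json.loads(annotation)
            # regions transcribed as "###" are invalid and skipped during training
            valid = [l for l in labels if l["transcription"] != "###"]
            samples.append((img_name, valid))
    return samples
```
Each returned sample pairs an image name with its valid polygons, ready to feed into a data loader.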
### 启动训练
PGNet training has two steps. Step 1: train on synthetic data to obtain a pretrained model, whose accuracy is still low at this stage. Step 2: load the pretrained model and train on the totaltext dataset. To speed this up, the step-1 pretrained model is provided directly.
```shell
cd PaddleOCR/
# download the step-1 pretrained model and extract it
wget -P ./pretrain_models/ https://paddleocr.bj.bcebos.com/dygraph_v2.0/pgnet/train_step1.tar
tar -xf ./pretrain_models/train_step1.tar -C ./pretrain_models/
# the extracted files are laid out as follows:
./pretrain_models/train_step1/
  └─ best_accuracy.pdopt
  └─ best_accuracy.states
  └─ best_accuracy.pdparams
```
*If you installed the CPU version of Paddle, set the `use_gpu` field in the configuration file to false.*
```shell
# 单机单卡训练 e2e 模型
python3 tools/train.py -c configs/e2e/e2e_r50_vd_pg.yml -o Global.pretrained_model=./pretrain_models/train_step1/best_accuracy Global.load_static_weights=False
# 单机多卡训练,通过 --gpus 参数设置使用的GPU ID
python3 -m paddle.distributed.launch --gpus '0,1,2,3' tools/train.py -c configs/e2e/e2e_r50_vd_pg.yml -o Global.pretrained_model=./pretrain_models/train_step1/best_accuracy Global.load_static_weights=False
```
上述指令中,通过-c 选择训练使用configs/e2e/e2e_r50_vd_pg.yml配置文件。
有关配置文件的详细解释,请参考[链接](./config.md)
您也可以通过-o参数在不需要修改yml文件的情况下,改变训练的参数,比如,调整训练的学习率为0.0001
```shell
python3 tools/train.py -c configs/e2e/e2e_r50_vd_pg.yml -o Optimizer.base_lr=0.0001
```
#### 断点训练
If the training program is interrupted and you want to resume from the interrupted model, specify the path of the model to load via Global.checkpoints:
```shell
python3 tools/train.py -c configs/e2e/e2e_r50_vd_pg.yml -o Global.checkpoints=./your/trained/model
```
**Note**: `Global.checkpoints` takes precedence over `Global.pretrained_model`. If both are specified, the model pointed to by `Global.checkpoints` is loaded first; if that path is invalid, the model specified by `Global.pretrained_model` is loaded instead.
### 模型评估
PaddleOCR computes three end-to-end OCR metrics: Precision, Recall and Hmean.
Run the command below to compute the evaluation metrics from the test-set result file specified by `save_res_path` in the configuration file `e2e_r50_vd_pg.yml`.
During evaluation the post-processing parameter `max_side_len=768` is used; when training on different datasets or with different models, this parameter can be tuned.
Model parameters saved during training are stored under the `Global.save_model_dir` directory by default. To compute the metrics, set `Global.checkpoints` to the saved parameter file.
```shell
python3 tools/eval.py -c configs/e2e/e2e_r50_vd_pg.yml -o Global.checkpoints="{path/to/weights}/best_accuracy"
```
### 模型预测
测试单张图像的端到端识别效果
```shell
python3 tools/infer_e2e.py -c configs/e2e/e2e_r50_vd_pg.yml -o Global.infer_img="./doc/imgs_en/img_10.jpg" Global.pretrained_model="./output/e2e_pgnet/best_accuracy" Global.load_static_weights=false
```
测试文件夹下所有图像的端到端识别效果
```shell
python3 tools/infer_e2e.py -c configs/e2e/e2e_r50_vd_pg.yml -o Global.infer_img="./doc/imgs_en/" Global.pretrained_model="./output/e2e_pgnet/best_accuracy" Global.load_static_weights=false
```
### 预测推理
#### (1). 四边形文本检测模型(ICDAR2015)
首先将PGNet端到端训练过程中保存的模型,转换成inference model。以基于Resnet50_vd骨干网络,以英文数据集训练的模型为例[模型下载地址](https://paddleocr.bj.bcebos.com/dygraph_v2.0/pgnet/en_server_pgnetA.tar) ,可以使用如下命令进行转换:
```
wget https://paddleocr.bj.bcebos.com/dygraph_v2.0/pgnet/en_server_pgnetA.tar && tar xf en_server_pgnetA.tar
python3 tools/export_model.py -c configs/e2e/e2e_r50_vd_pg.yml -o Global.pretrained_model=./en_server_pgnetA/best_accuracy Global.load_static_weights=False Global.save_inference_dir=./inference/e2e
```
**For PGNet end-to-end model inference, the parameter `--e2e_algorithm="PGNet"` must be set**; you can run the following command:
```
python3 tools/infer/predict_e2e.py --e2e_algorithm="PGNet" --image_dir="./doc/imgs_en/img_10.jpg" --e2e_model_dir="./inference/e2e/" --e2e_pgnet_polygon=False
```
可视化文本检测结果默认保存到`./inference_results`文件夹里面,结果文件的名称前缀为'e2e_res'。结果示例如下:
![](../imgs_results/e2e_res_img_10_pgnet.jpg)
#### (2). 弯曲文本检测模型(Total-Text)
For curved text samples, **PGNet end-to-end inference requires the parameter `--e2e_algorithm="PGNet"` and, additionally, `--e2e_pgnet_polygon=True`**; you can run the following command:
```
python3 tools/infer/predict_e2e.py --e2e_algorithm="PGNet" --image_dir="./doc/imgs_en/img623.jpg" --e2e_model_dir="./inference/e2e/" --e2e_pgnet_polygon=True
```
可视化文本端到端结果默认保存到`./inference_results`文件夹里面,结果文件的名称前缀为'e2e_res'。结果示例如下:
![](../imgs_results/e2e_res_img623_pgnet.jpg)
...@@ -138,7 +138,7 @@ PaddleOCR内置了一部分字典,可以按需使用。 ...@@ -138,7 +138,7 @@ PaddleOCR内置了一部分字典,可以按需使用。
`ppocr/utils/dict/german_dict.txt` 是一个包含131个字符的德文字典 `ppocr/utils/dict/german_dict.txt` 是一个包含131个字符的德文字典
`ppocr/utils/dict/en_dict.txt` 是一个包含63个字符的英文字典 `ppocr/utils/en_dict.txt` 是一个包含96个字符的英文字典
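For context, each of these dictionary files simply lists one character per line; a custom dictionary uses the same layout, for example:
```
0
1
a
b
c
```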
...@@ -285,7 +285,7 @@ Eval: ...@@ -285,7 +285,7 @@ Eval:
<a name="小语种"></a> <a name="小语种"></a>
#### 2.3 小语种 #### 2.3 小语种
PaddleOCR目前已支持26种(除中文外)语种识别,`configs/rec/multi_languages` 路径下提供了一个多语言的配置文件模版: [rec_multi_language_lite_train.yml](../../configs/rec/multi_language/rec_multi_language_lite_train.yml) PaddleOCR目前已支持80种(除中文外)语种识别,`configs/rec/multi_languages` 路径下提供了一个多语言的配置文件模版: [rec_multi_language_lite_train.yml](../../configs/rec/multi_language/rec_multi_language_lite_train.yml)
您有两种方式创建所需的配置文件: 您有两种方式创建所需的配置文件:
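One route is the `generate_multi_language_configs.py` helper shown above; a minimal example invocation, with dataset-path options omitted, so treat it as a sketch:
```bash
# generate an Italian recognition config from the multilingual template
python3 generate_multi_language_configs.py -l it
```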
...@@ -368,26 +368,12 @@ PaddleOCR目前已支持26种(除中文外)语种识别,`configs/rec/multi ...@@ -368,26 +368,12 @@ PaddleOCR目前已支持26种(除中文外)语种识别,`configs/rec/multi
| rec_ger_lite_train.yml | CRNN | Mobilenet_v3 small 0.5 | None | BiLSTM | ctc | 德语 | german | | rec_ger_lite_train.yml | CRNN | Mobilenet_v3 small 0.5 | None | BiLSTM | ctc | 德语 | german |
| rec_japan_lite_train.yml | CRNN | Mobilenet_v3 small 0.5 | None | BiLSTM | ctc | 日语 | japan | | rec_japan_lite_train.yml | CRNN | Mobilenet_v3 small 0.5 | None | BiLSTM | ctc | 日语 | japan |
| rec_korean_lite_train.yml | CRNN | Mobilenet_v3 small 0.5 | None | BiLSTM | ctc | 韩语 | korean | | rec_korean_lite_train.yml | CRNN | Mobilenet_v3 small 0.5 | None | BiLSTM | ctc | 韩语 | korean |
| rec_it_lite_train.yml | CRNN | Mobilenet_v3 small 0.5 | None | BiLSTM | ctc | 意大利语 | it | | rec_latin_lite_train.yml | CRNN | Mobilenet_v3 small 0.5 | None | BiLSTM | ctc | 拉丁字母 | latin |
| rec_xi_lite_train.yml | CRNN | Mobilenet_v3 small 0.5 | None | BiLSTM | ctc | 西班牙语 | xi | | rec_arabic_lite_train.yml | CRNN | Mobilenet_v3 small 0.5 | None | BiLSTM | ctc | 阿拉伯字母 | ar |
| rec_pu_lite_train.yml | CRNN | Mobilenet_v3 small 0.5 | None | BiLSTM | ctc | 葡萄牙语 | pu | | rec_cyrillic_lite_train.yml | CRNN | Mobilenet_v3 small 0.5 | None | BiLSTM | ctc | 斯拉夫字母 | cyrillic |
| rec_ru_lite_train.yml | CRNN | Mobilenet_v3 small 0.5 | None | BiLSTM | ctc | 俄罗斯语 | ru | | rec_devanagari_lite_train.yml | CRNN | Mobilenet_v3 small 0.5 | None | BiLSTM | ctc | 梵文字母 | devanagari |
| rec_ar_lite_train.yml | CRNN | Mobilenet_v3 small 0.5 | None | BiLSTM | ctc | 阿拉伯语 | ar |
| rec_hi_lite_train.yml | CRNN | Mobilenet_v3 small 0.5 | None | BiLSTM | ctc | 印地语 | hi | 更多支持语种请参考: [多语言模型](https://github.com/PaddlePaddle/PaddleOCR/blob/release/2.1/doc/doc_ch/multi_languages.md#%E8%AF%AD%E7%A7%8D%E7%BC%A9%E5%86%99)
| rec_ug_lite_train.yml | CRNN | Mobilenet_v3 small 0.5 | None | BiLSTM | ctc | 维吾尔语 | ug |
| rec_fa_lite_train.yml | CRNN | Mobilenet_v3 small 0.5 | None | BiLSTM | ctc | 波斯语 | fa |
| rec_ur_ite_train.yml | CRNN | Mobilenet_v3 small 0.5 | None | BiLSTM | ctc | 乌尔都语 | ur |
| rec_rs_lite_train.yml | CRNN | Mobilenet_v3 small 0.5 | None | BiLSTM | ctc | 塞尔维亚(latin)语 | rs |
| rec_oc_lite_train.yml | CRNN | Mobilenet_v3 small 0.5 | None | BiLSTM | ctc | 欧西坦语 | oc |
| rec_mr_lite_train.yml | CRNN | Mobilenet_v3 small 0.5 | None | BiLSTM | ctc | 马拉地语 | mr |
| rec_ne_lite_train.yml | CRNN | Mobilenet_v3 small 0.5 | None | BiLSTM | ctc | 尼泊尔语 | ne |
| rec_rsc_lite_train.yml | CRNN | Mobilenet_v3 small 0.5 | None | BiLSTM | ctc | 塞尔维亚(cyrillic)语 | rsc |
| rec_bg_lite_train.yml | CRNN | Mobilenet_v3 small 0.5 | None | BiLSTM | ctc | 保加利亚语 | bg |
| rec_uk_lite_train.yml | CRNN | Mobilenet_v3 small 0.5 | None | BiLSTM | ctc | 乌克兰语 | uk |
| rec_be_lite_train.yml | CRNN | Mobilenet_v3 small 0.5 | None | BiLSTM | ctc | 白俄罗斯语 | be |
| rec_te_lite_train.yml | CRNN | Mobilenet_v3 small 0.5 | None | BiLSTM | ctc | 泰卢固语 | te |
| rec_ka_lite_train.yml | CRNN | Mobilenet_v3 small 0.5 | None | BiLSTM | ctc | 卡纳达语 | ka |
| rec_ta_lite_train.yml | CRNN | Mobilenet_v3 small 0.5 | None | BiLSTM | ctc | 泰米尔语 | ta |
多语言模型训练方式与中文模型一致,训练数据集均为100w的合成数据,少量的字体可以在 [百度网盘](https://pan.baidu.com/s/1bS_u207Rm7YbY33wOECKDA) 上下载,提取码:frgi。 多语言模型训练方式与中文模型一致,训练数据集均为100w的合成数据,少量的字体可以在 [百度网盘](https://pan.baidu.com/s/1bS_u207Rm7YbY33wOECKDA) 上下载,提取码:frgi。
......
...@@ -36,7 +36,7 @@ image = Image.open(img_path).convert('RGB') ...@@ -36,7 +36,7 @@ image = Image.open(img_path).convert('RGB')
boxes = [line[0] for line in result] boxes = [line[0] for line in result]
txts = [line[1][0] for line in result] txts = [line[1][0] for line in result]
scores = [line[1][1] for line in result] scores = [line[1][1] for line in result]
im_show = draw_ocr(image, boxes, txts, scores, font_path='/path/to/PaddleOCR/doc/simfang.ttf') im_show = draw_ocr(image, boxes, txts, scores, font_path='/path/to/PaddleOCR/doc/fonts/simfang.ttf')
im_show = Image.fromarray(im_show) im_show = Image.fromarray(im_show)
im_show.save('result.jpg') im_show.save('result.jpg')
``` ```
...@@ -69,7 +69,7 @@ image = Image.open(img_path).convert('RGB') ...@@ -69,7 +69,7 @@ image = Image.open(img_path).convert('RGB')
boxes = [line[0] for line in result] boxes = [line[0] for line in result]
txts = [line[1][0] for line in result] txts = [line[1][0] for line in result]
scores = [line[1][1] for line in result] scores = [line[1][1] for line in result]
im_show = draw_ocr(image, boxes, txts, scores, font_path='/path/to/PaddleOCR/doc/simfang.ttf') im_show = draw_ocr(image, boxes, txts, scores, font_path='/path/to/PaddleOCR/doc/fonts/simfang.ttf')
im_show = Image.fromarray(im_show) im_show = Image.fromarray(im_show)
im_show.save('result.jpg') im_show.save('result.jpg')
``` ```
...@@ -114,7 +114,7 @@ for line in result: ...@@ -114,7 +114,7 @@ for line in result:
from PIL import Image from PIL import Image
image = Image.open(img_path).convert('RGB') image = Image.open(img_path).convert('RGB')
im_show = draw_ocr(image, result, txts=None, scores=None, font_path='/path/to/PaddleOCR/doc/simfang.ttf') im_show = draw_ocr(image, result, txts=None, scores=None, font_path='/path/to/PaddleOCR/doc/fonts/simfang.ttf')
im_show = Image.fromarray(im_show) im_show = Image.fromarray(im_show)
im_show.save('result.jpg') im_show.save('result.jpg')
``` ```
...@@ -253,7 +253,7 @@ image = Image.open(img_path).convert('RGB') ...@@ -253,7 +253,7 @@ image = Image.open(img_path).convert('RGB')
boxes = [line[0] for line in result] boxes = [line[0] for line in result]
txts = [line[1][0] for line in result] txts = [line[1][0] for line in result]
scores = [line[1][1] for line in result] scores = [line[1][1] for line in result]
im_show = draw_ocr(image, boxes, txts, scores, font_path='/path/to/PaddleOCR/doc/simfang.ttf') im_show = draw_ocr(image, boxes, txts, scores, font_path='/path/to/PaddleOCR/doc/fonts/simfang.ttf')
im_show = Image.fromarray(im_show) im_show = Image.fromarray(im_show)
im_show.save('result.jpg') im_show.save('result.jpg')
``` ```
...@@ -285,7 +285,7 @@ image = Image.open(img_path).convert('RGB') ...@@ -285,7 +285,7 @@ image = Image.open(img_path).convert('RGB')
boxes = [line[0] for line in result] boxes = [line[0] for line in result]
txts = [line[1][0] for line in result] txts = [line[1][0] for line in result]
scores = [line[1][1] for line in result] scores = [line[1][1] for line in result]
im_show = draw_ocr(image, boxes, txts, scores, font_path='/path/to/PaddleOCR/doc/simfang.ttf') im_show = draw_ocr(image, boxes, txts, scores, font_path='/path/to/PaddleOCR/doc/fonts/simfang.ttf')
im_show = Image.fromarray(im_show) im_show = Image.fromarray(im_show)
im_show.save('result.jpg') im_show.save('result.jpg')
``` ```
...@@ -314,7 +314,7 @@ image = Image.open(img_path).convert('RGB') ...@@ -314,7 +314,7 @@ image = Image.open(img_path).convert('RGB')
boxes = [line[0] for line in result] boxes = [line[0] for line in result]
txts = [line[1][0] for line in result] txts = [line[1][0] for line in result]
scores = [line[1][1] for line in result] scores = [line[1][1] for line in result]
im_show = draw_ocr(image, boxes, txts, scores, font_path='/path/to/PaddleOCR/doc/simfang.ttf') im_show = draw_ocr(image, boxes, txts, scores, font_path='/path/to/PaddleOCR/doc/fonts/simfang.ttf')
im_show = Image.fromarray(im_show) im_show = Image.fromarray(im_show)
im_show.save('result.jpg') im_show.save('result.jpg')
``` ```
......
...@@ -38,28 +38,17 @@ If you want to train PaddleOCR on other datasets, please build the annotation fi ...@@ -38,28 +38,17 @@ If you want to train PaddleOCR on other datasets, please build the annotation fi
## TRAINING ## TRAINING
First download the pretrained model. The detection model of PaddleOCR currently supports 3 backbones, namely MobileNetV3, ResNet18_vd and ResNet50_vd. You can use the model in [PaddleClas](https://github.com/PaddlePaddle/PaddleClas/tree/master/ppcls/modeling/architectures) to replace backbone according to your needs. First download the pretrained model. The detection model of PaddleOCR currently supports 3 backbones, namely MobileNetV3, ResNet18_vd and ResNet50_vd. You can use the model in [PaddleClas](https://github.com/PaddlePaddle/PaddleClas/tree/develop/ppcls/modeling/architectures) to replace backbone according to your needs.
And the responding download link of backbone pretrain weights can be found in [PaddleClas repo](https://github.com/PaddlePaddle/PaddleClas#mobile-series).
```shell ```shell
cd PaddleOCR/ cd PaddleOCR/
# Download the pre-trained model of MobileNetV3 # Download the pre-trained model of MobileNetV3
wget -P ./pretrain_models/ https://paddle-imagenet-models-name.bj.bcebos.com/MobileNetV3_large_x0_5_pretrained.tar wget -P ./pretrain_models/ https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/MobileNetV3_large_x0_5_pretrained.pdparams
# or, download the pre-trained model of ResNet18_vd # or, download the pre-trained model of ResNet18_vd
wget -P ./pretrain_models/ https://paddle-imagenet-models-name.bj.bcebos.com/ResNet18_vd_pretrained.tar wget -P ./pretrain_models/ https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/ResNet18_vd_pretrained.pdparams
# or, download the pre-trained model of ResNet50_vd # or, download the pre-trained model of ResNet50_vd
wget -P ./pretrain_models/ https://paddle-imagenet-models-name.bj.bcebos.com/ResNet50_vd_ssld_pretrained.tar wget -P ./pretrain_models/ https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/ResNet50_vd_ssld_pretrained.pdparams
# decompressing the pre-training model file, take MobileNetV3 as an example
tar -xf ./pretrain_models/MobileNetV3_large_x0_5_pretrained.tar ./pretrain_models/
# Note: After decompressing the backbone pre-training weight file correctly, the file list in the folder is as follows:
./pretrain_models/MobileNetV3_large_x0_5_pretrained/
└─ conv_last_bn_mean
└─ conv_last_bn_offset
└─ conv_last_bn_scale
└─ conv_last_bn_variance
└─ ......
```
#### START TRAINING #### START TRAINING
*If the CPU version is installed, please set the parameter `use_gpu` to `false` in the configuration.*
...@@ -113,16 +102,16 @@ python3 tools/eval.py -c configs/det/det_mv3_db.yml -o Global.checkpoints="{pat ...@@ -113,16 +102,16 @@ python3 tools/eval.py -c configs/det/det_mv3_db.yml -o Global.checkpoints="{pat
Test the detection result on a single image: Test the detection result on a single image:
```shell ```shell
python3 tools/infer_det.py -c configs/det/det_mv3_db.yml -o Global.infer_img="./doc/imgs_en/img_10.jpg" Global.pretrained_model="./output/det_db/best_accuracy" Global.load_static_weights=false python3 tools/infer_det.py -c configs/det/det_mv3_db.yml -o Global.infer_img="./doc/imgs_en/img_10.jpg" Global.pretrained_model="./output/det_db/best_accuracy"
``` ```
When testing the DB model, adjust the post-processing threshold: When testing the DB model, adjust the post-processing threshold:
```shell ```shell
python3 tools/infer_det.py -c configs/det/det_mv3_db.yml -o Global.infer_img="./doc/imgs_en/img_10.jpg" Global.pretrained_model="./output/det_db/best_accuracy" Global.load_static_weights=false PostProcess.box_thresh=0.6 PostProcess.unclip_ratio=1.5 python3 tools/infer_det.py -c configs/det/det_mv3_db.yml -o Global.infer_img="./doc/imgs_en/img_10.jpg" Global.pretrained_model="./output/det_db/best_accuracy" PostProcess.box_thresh=0.6 PostProcess.unclip_ratio=1.5
``` ```
Test the detection result on all images in the folder: Test the detection result on all images in the folder:
```shell ```shell
python3 tools/infer_det.py -c configs/det/det_mv3_db.yml -o Global.infer_img="./doc/imgs_en/" Global.pretrained_model="./output/det_db/best_accuracy" Global.load_static_weights=false python3 tools/infer_det.py -c configs/det/det_mv3_db.yml -o Global.infer_img="./doc/imgs_en/" Global.pretrained_model="./output/det_db/best_accuracy"
``` ```
...@@ -52,10 +52,9 @@ The above model is a DB algorithm trained with MobileNetV3 as the backbone. To c ...@@ -52,10 +52,9 @@ The above model is a DB algorithm trained with MobileNetV3 as the backbone. To c
# -c Set the training algorithm yml configuration file # -c Set the training algorithm yml configuration file
# -o Set optional parameters # -o Set optional parameters
# Global.pretrained_model parameter Set the training model address to be converted without adding the file suffix .pdmodel, .pdopt or .pdparams. # Global.pretrained_model parameter Set the training model address to be converted without adding the file suffix .pdmodel, .pdopt or .pdparams.
# Global.load_static_weights needs to be set to False
# Global.save_inference_dir Set the address where the converted model will be saved. # Global.save_inference_dir Set the address where the converted model will be saved.
python3 tools/export_model.py -c configs/det/ch_ppocr_v2.0/ch_det_mv3_db_v2.0.yml -o Global.pretrained_model=./ch_lite/ch_ppocr_mobile_v2.0_det_train/best_accuracy Global.load_static_weights=False Global.save_inference_dir=./inference/det_db/ python3 tools/export_model.py -c configs/det/ch_ppocr_v2.0/ch_det_mv3_db_v2.0.yml -o Global.pretrained_model=./ch_lite/ch_ppocr_mobile_v2.0_det_train/best_accuracy Global.save_inference_dir=./inference/det_db/
``` ```
When converting to an inference model, the configuration file used is the same as the configuration file used during training. In addition, you also need to set the `Global.pretrained_model` parameter in the configuration file. When converting to an inference model, the configuration file used is the same as the configuration file used during training. In addition, you also need to set the `Global.pretrained_model` parameter in the configuration file.
...@@ -80,10 +79,9 @@ The recognition model is converted to the inference model in the same way as the ...@@ -80,10 +79,9 @@ The recognition model is converted to the inference model in the same way as the
# -c Set the training algorithm yml configuration file # -c Set the training algorithm yml configuration file
# -o Set optional parameters # -o Set optional parameters
# Global.pretrained_model parameter Set the training model address to be converted without adding the file suffix .pdmodel, .pdopt or .pdparams. # Global.pretrained_model parameter Set the training model address to be converted without adding the file suffix .pdmodel, .pdopt or .pdparams.
# Global.load_static_weights needs to be set to False
# Global.save_inference_dir Set the address where the converted model will be saved. # Global.save_inference_dir Set the address where the converted model will be saved.
python3 tools/export_model.py -c configs/rec/ch_ppocr_v2.0/rec_chinese_lite_train_v2.0.yml -o Global.pretrained_model=./ch_lite/ch_ppocr_mobile_v2.0_rec_train/best_accuracy Global.load_static_weights=False Global.save_inference_dir=./inference/rec_crnn/ python3 tools/export_model.py -c configs/rec/ch_ppocr_v2.0/rec_chinese_lite_train_v2.0.yml -o Global.pretrained_model=./ch_lite/ch_ppocr_mobile_v2.0_rec_train/best_accuracy Global.save_inference_dir=./inference/rec_crnn/
``` ```
If you have a model trained on your own dataset with a different dictionary file, please make sure that you modify the `character_dict_path` in the configuration file to your dictionary file path. If you have a model trained on your own dataset with a different dictionary file, please make sure that you modify the `character_dict_path` in the configuration file to your dictionary file path.
...@@ -109,10 +107,9 @@ The angle classification model is converted to the inference model in the same w ...@@ -109,10 +107,9 @@ The angle classification model is converted to the inference model in the same w
# -c Set the training algorithm yml configuration file # -c Set the training algorithm yml configuration file
# -o Set optional parameters # -o Set optional parameters
# Global.pretrained_model parameter Set the training model address to be converted without adding the file suffix .pdmodel, .pdopt or .pdparams. # Global.pretrained_model parameter Set the training model address to be converted without adding the file suffix .pdmodel, .pdopt or .pdparams.
# Global.load_static_weights needs to be set to False
# Global.save_inference_dir Set the address where the converted model will be saved. # Global.save_inference_dir Set the address where the converted model will be saved.
python3 tools/export_model.py -c configs/cls/cls_mv3.yml -o Global.pretrained_model=./ch_lite/ch_ppocr_mobile_v2.0_cls_train/best_accuracy Global.load_static_weights=False Global.save_inference_dir=./inference/cls/ python3 tools/export_model.py -c configs/cls/cls_mv3.yml -o Global.pretrained_model=./ch_lite/ch_ppocr_mobile_v2.0_cls_train/best_accuracy Global.save_inference_dir=./inference/cls/
``` ```
After the conversion is successful, there are three files in the directory:
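For reference, the exported directory mirrors the inference-model layout used elsewhere in this document:
```
inference/cls/
    ├── inference.pdiparams         # model parameter file
    ├── inference.pdiparams.info    # parameter metadata file
    └── inference.pdmodel           # model structure file
```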
...@@ -171,7 +168,7 @@ python3 tools/infer/predict_det.py --image_dir="./doc/imgs/22.jpg" --det_model_d ...@@ -171,7 +168,7 @@ python3 tools/infer/predict_det.py --image_dir="./doc/imgs/22.jpg" --det_model_d
First, convert the model saved in the DB text detection training process into an inference model. Taking the model based on the Resnet50_vd backbone network and trained on the ICDAR2015 English dataset as an example ([model download link](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/det_r50_vd_db_v2.0_train.tar)), you can use the following command to convert: First, convert the model saved in the DB text detection training process into an inference model. Taking the model based on the Resnet50_vd backbone network and trained on the ICDAR2015 English dataset as an example ([model download link](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/det_r50_vd_db_v2.0_train.tar)), you can use the following command to convert:
``` ```
python3 tools/export_model.py -c configs/det/det_r50_vd_db.yml -o Global.pretrained_model=./det_r50_vd_db_v2.0_train/best_accuracy Global.load_static_weights=False Global.save_inference_dir=./inference/det_db python3 tools/export_model.py -c configs/det/det_r50_vd_db.yml -o Global.pretrained_model=./det_r50_vd_db_v2.0_train/best_accuracy Global.save_inference_dir=./inference/det_db
``` ```
DB text detection model inference, you can execute the following command: DB text detection model inference, you can execute the following command:
...@@ -192,7 +189,7 @@ The visualized text detection results are saved to the `./inference_results` fol ...@@ -192,7 +189,7 @@ The visualized text detection results are saved to the `./inference_results` fol
First, convert the model saved in the EAST text detection training process into an inference model. Taking the model based on the Resnet50_vd backbone network and trained on the ICDAR2015 English dataset as an example ([model download link](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/det_r50_vd_east_v2.0_train.tar)), you can use the following command to convert: First, convert the model saved in the EAST text detection training process into an inference model. Taking the model based on the Resnet50_vd backbone network and trained on the ICDAR2015 English dataset as an example ([model download link](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/det_r50_vd_east_v2.0_train.tar)), you can use the following command to convert:
``` ```
python3 tools/export_model.py -c configs/det/det_r50_vd_east.yml -o Global.pretrained_model=./det_r50_vd_east_v2.0_train/best_accuracy Global.load_static_weights=False Global.save_inference_dir=./inference/det_east python3 tools/export_model.py -c configs/det/det_r50_vd_east.yml -o Global.pretrained_model=./det_r50_vd_east_v2.0_train/best_accuracy Global.save_inference_dir=./inference/det_east
``` ```
**For EAST text detection model inference, you need to set the parameter ``--det_algorithm="EAST"``**, run the following command: **For EAST text detection model inference, you need to set the parameter ``--det_algorithm="EAST"``**, run the following command:
...@@ -213,7 +210,7 @@ The visualized text detection results are saved to the `./inference_results` fol ...@@ -213,7 +210,7 @@ The visualized text detection results are saved to the `./inference_results` fol
First, convert the model saved in the SAST text detection training process into an inference model. Taking the model based on the Resnet50_vd backbone network and trained on the ICDAR2015 English dataset as an example ([model download link](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/det_r50_vd_sast_icdar15_v2.0_train.tar)), you can use the following command to convert: First, convert the model saved in the SAST text detection training process into an inference model. Taking the model based on the Resnet50_vd backbone network and trained on the ICDAR2015 English dataset as an example ([model download link](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/det_r50_vd_sast_icdar15_v2.0_train.tar)), you can use the following command to convert:
``` ```
python3 tools/export_model.py -c configs/det/det_r50_vd_sast_icdar15.yml -o Global.pretrained_model=./det_r50_vd_sast_icdar15_v2.0_train/best_accuracy Global.load_static_weights=False Global.save_inference_dir=./inference/det_sast_ic15 python3 tools/export_model.py -c configs/det/det_r50_vd_sast_icdar15.yml -o Global.pretrained_model=./det_r50_vd_sast_icdar15_v2.0_train/best_accuracy Global.save_inference_dir=./inference/det_sast_ic15
``` ```
**For SAST quadrangle text detection model inference, you need to set the parameter `--det_algorithm="SAST"`**, run the following command: **For SAST quadrangle text detection model inference, you need to set the parameter `--det_algorithm="SAST"`**, run the following command:
...@@ -230,7 +227,7 @@ The visualized text detection results are saved to the `./inference_results` fol ...@@ -230,7 +227,7 @@ The visualized text detection results are saved to the `./inference_results` fol
First, convert the model saved in the SAST text detection training process into an inference model. Taking the model based on the Resnet50_vd backbone network and trained on the Total-Text English dataset as an example ([model download link](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/det_r50_vd_sast_totaltext_v2.0_train.tar)), you can use the following command to convert: First, convert the model saved in the SAST text detection training process into an inference model. Taking the model based on the Resnet50_vd backbone network and trained on the Total-Text English dataset as an example ([model download link](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/det_r50_vd_sast_totaltext_v2.0_train.tar)), you can use the following command to convert:
``` ```
python3 tools/export_model.py -c configs/det/det_r50_vd_sast_totaltext.yml -o Global.pretrained_model=./det_r50_vd_sast_totaltext_v2.0_train/best_accuracy Global.load_static_weights=False Global.save_inference_dir=./inference/det_sast_tt python3 tools/export_model.py -c configs/det/det_r50_vd_sast_totaltext.yml -o Global.pretrained_model=./det_r50_vd_sast_totaltext_v2.0_train/best_accuracy Global.save_inference_dir=./inference/det_sast_tt
``` ```
**For SAST curved text detection model inference, you need to set the parameter `--det_algorithm="SAST"` and `--det_sast_polygon=True`**, run the following command: **For SAST curved text detection model inference, you need to set the parameter `--det_algorithm="SAST"` and `--det_sast_polygon=True`**, run the following command:
...@@ -279,7 +276,7 @@ Taking CRNN as an example, we introduce the recognition model inference based on ...@@ -279,7 +276,7 @@ Taking CRNN as an example, we introduce the recognition model inference based on
First, convert the model saved in the CRNN text recognition training process into an inference model. Taking the model based on Resnet34_vd backbone network, using MJSynth and SynthText (two English text recognition synthetic datasets) for training, as an example ([model download address](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/rec_r34_vd_none_bilstm_ctc_v2.0_train.tar)). It can be converted as follow: First, convert the model saved in the CRNN text recognition training process into an inference model. Taking the model based on Resnet34_vd backbone network, using MJSynth and SynthText (two English text recognition synthetic datasets) for training, as an example ([model download address](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/rec_r34_vd_none_bilstm_ctc_v2.0_train.tar)). It can be converted as follow:
``` ```
python3 tools/export_model.py -c configs/rec/rec_r34_vd_none_bilstm_ctc.yml -o Global.pretrained_model=./rec_r34_vd_none_bilstm_ctc_v2.0_train/best_accuracy Global.save_inference_dir=./inference/rec_crnn
``` ```
For CRNN text recognition model inference, execute the following commands: For CRNN text recognition model inference, execute the following commands:
...@@ -379,14 +376,18 @@ After executing the command, the prediction results (classification angle and sc ...@@ -379,14 +376,18 @@ After executing the command, the prediction results (classification angle and sc
<a name="LIGHTWEIGHT_CHINESE_MODEL"></a> <a name="LIGHTWEIGHT_CHINESE_MODEL"></a>
### 1. LIGHTWEIGHT CHINESE MODEL ### 1. LIGHTWEIGHT CHINESE MODEL
When performing prediction, you need to specify the path of a single image or a folder of images through the parameter `image_dir`; the parameter `det_model_dir` specifies the path of the detection inference model, `cls_model_dir` the path of the angle classification inference model, and `rec_model_dir` the path of the recognition inference model. The parameter `use_angle_cls` controls whether the angle classification model is enabled, `use_mp` specifies whether to use multi-process inference, and `total_process_num` specifies the number of processes when multi-process is enabled. The visualized recognition results are saved to the `./inference_results` folder by default.
``` ```shell
# use direction classifier # use direction classifier
python3 tools/infer/predict_system.py --image_dir="./doc/imgs/00018069.jpg" --det_model_dir="./inference/det_db/" --cls_model_dir="./inference/cls/" --rec_model_dir="./inference/rec_crnn/" --use_angle_cls=true python3 tools/infer/predict_system.py --image_dir="./doc/imgs/00018069.jpg" --det_model_dir="./inference/det_db/" --cls_model_dir="./inference/cls/" --rec_model_dir="./inference/rec_crnn/" --use_angle_cls=true
# not use use direction classifier # not use use direction classifier
python3 tools/infer/predict_system.py --image_dir="./doc/imgs/00018069.jpg" --det_model_dir="./inference/det_db/" --rec_model_dir="./inference/rec_crnn/" python3 tools/infer/predict_system.py --image_dir="./doc/imgs/00018069.jpg" --det_model_dir="./inference/det_db/" --rec_model_dir="./inference/rec_crnn/"
# use multi-process
python3 tools/infer/predict_system.py --image_dir="./doc/imgs/00018069.jpg" --det_model_dir="./inference/det_db/" --rec_model_dir="./inference/rec_crnn/" --use_angle_cls=false --use_mp=True --total_process_num=6
```
After executing the command, the recognition result image is as follows:
...@@ -102,27 +102,16 @@ python3 generate_multi_language_configs.py -l it \
| german_mobile_v2.0_rec |Lightweight model for German recognition|[rec_german_lite_train.yml](../../configs/rec/multi_language/rec_german_lite_train.yml)|2.65M|[inference model](https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/german_mobile_v2.0_rec_infer.tar) / [trained model](https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/german_mobile_v2.0_rec_train.tar) |
| korean_mobile_v2.0_rec |Lightweight model for Korean recognition|[rec_korean_lite_train.yml](../../configs/rec/multi_language/rec_korean_lite_train.yml)|3.9M|[inference model](https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/korean_mobile_v2.0_rec_infer.tar) / [trained model](https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/korean_mobile_v2.0_rec_train.tar) |
| japan_mobile_v2.0_rec |Lightweight model for Japanese recognition|[rec_japan_lite_train.yml](../../configs/rec/multi_language/rec_japan_lite_train.yml)|4.23M|[inference model](https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/japan_mobile_v2.0_rec_infer.tar) / [trained model](https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/japan_mobile_v2.0_rec_train.tar) |
| it_mobile_v2.0_rec |Lightweight model for Italian recognition|rec_it_lite_train.yml|2.53M|[inference model](https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/it_mobile_v2.0_rec_infer.tar) / [trained model](https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/it_mobile_v2.0_rec_train.tar) |
| xi_mobile_v2.0_rec |Lightweight model for Spanish recognition|rec_xi_lite_train.yml|2.53M|[inference model](https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/xi_mobile_v2.0_rec_infer.tar) / [trained model](https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/xi_mobile_v2.0_rec_train.tar) |
| pu_mobile_v2.0_rec |Lightweight model for Portuguese recognition|rec_pu_lite_train.yml|2.63M|[inference model](https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/pu_mobile_v2.0_rec_infer.tar) / [trained model](https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/pu_mobile_v2.0_rec_train.tar) |
| ru_mobile_v2.0_rec |Lightweight model for Russia recognition|rec_ru_lite_train.yml|2.63M|[inference model](https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/ru_mobile_v2.0_rec_infer.tar) / [trained model](https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/ru_mobile_v2.0_rec_train.tar) |
| ar_mobile_v2.0_rec |Lightweight model for Arabic recognition|rec_ar_lite_train.yml|2.53M|[inference model](https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/ar_mobile_v2.0_rec_infer.tar) / [trained model](https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/ar_mobile_v2.0_rec_train.tar) |
| hi_mobile_v2.0_rec |Lightweight model for Hindi recognition|rec_hi_lite_train.yml|2.63M|[inference model](https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/hi_mobile_v2.0_rec_infer.tar) / [trained model](https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/hi_mobile_v2.0_rec_train.tar) |
| chinese_cht_mobile_v2.0_rec |Lightweight model for chinese traditional recognition|rec_chinese_cht_lite_train.yml|5.63M|[inference model](https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/chinese_cht_mobile_v2.0_rec_infer.tar) / [trained model](https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/chinese_cht_mobile_v2.0_rec_train.tar) |
| ug_mobile_v2.0_rec |Lightweight model for Uyghur recognition|rec_ug_lite_train.yml|2.63M|[inference model](https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/ug_mobile_v2.0_rec_infer.tar) / [trained model](https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/ug_mobile_v2.0_rec_train.tar) |
| fa_mobile_v2.0_rec |Lightweight model for Persian recognition|rec_fa_lite_train.yml|2.63M|[inference model](https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/fa_mobile_v2.0_rec_infer.tar) / [trained model](https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/fa_mobile_v2.0_rec_train.tar) |
| ur_mobile_v2.0_rec |Lightweight model for Urdu recognition|rec_ur_lite_train.yml|2.63M|[inference model](https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/ur_mobile_v2.0_rec_infer.tar) / [trained model](https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/ur_mobile_v2.0_rec_train.tar) |
| rs_mobile_v2.0_rec |Lightweight model for Serbian(latin) recognition|rec_rs_lite_train.yml|2.53M|[inference model](https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/rs_mobile_v2.0_rec_infer.tar) / [trained model](https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/rs_mobile_v2.0_rec_train.tar) |
| oc_mobile_v2.0_rec |Lightweight model for Occitan recognition|rec_oc_lite_train.yml|2.53M|[inference model](https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/oc_mobile_v2.0_rec_infer.tar) / [trained model](https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/oc_mobile_v2.0_rec_train.tar) |
| mr_mobile_v2.0_rec |Lightweight model for Marathi recognition|rec_mr_lite_train.yml|2.63M|[inference model](https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/mr_mobile_v2.0_rec_infer.tar) / [trained model](https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/mr_mobile_v2.0_rec_train.tar) |
| ne_mobile_v2.0_rec |Lightweight model for Nepali recognition|rec_ne_lite_train.yml|2.63M|[inference model](https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/ne_mobile_v2.0_rec_infer.tar) / [trained model](https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/ne_mobile_v2.0_rec_train.tar) |
| rsc_mobile_v2.0_rec |Lightweight model for Serbian(cyrillic) recognition|rec_rsc_lite_train.yml|2.63M|[inference model](https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/rsc_mobile_v2.0_rec_infer.tar) / [trained model](https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/rsc_mobile_v2.0_rec_train.tar) |
| bg_mobile_v2.0_rec |Lightweight model for Bulgarian recognition|rec_bg_lite_train.yml|2.63M|[inference model](https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/bg_mobile_v2.0_rec_infer.tar) / [trained model](https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/bg_mobile_v2.0_rec_train.tar) |
| uk_mobile_v2.0_rec |Lightweight model for Ukranian recognition|rec_uk_lite_train.yml|2.63M|[inference model](https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/uk_mobile_v2.0_rec_infer.tar) / [trained model](https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/uk_mobile_v2.0_rec_train.tar) |
| be_mobile_v2.0_rec |Lightweight model for Belarusian recognition|rec_be_lite_train.yml|2.63M|[inference model](https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/be_mobile_v2.0_rec_infer.tar) / [trained model](https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/be_mobile_v2.0_rec_train.tar) |
| te_mobile_v2.0_rec |Lightweight model for Telugu recognition|rec_te_lite_train.yml|2.63M|[inference model](https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/te_mobile_v2.0_rec_infer.tar) / [trained model](https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/te_mobile_v2.0_rec_train.tar) |
| ka_mobile_v2.0_rec |Lightweight model for Kannada recognition|rec_ka_lite_train.yml|2.63M|[inference model](https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/ka_mobile_v2.0_rec_infer.tar) / [trained model](https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/ka_mobile_v2.0_rec_train.tar) |
| ta_mobile_v2.0_rec |Lightweight model for Tamil recognition|rec_ta_lite_train.yml|2.63M|[inference model](https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/ta_mobile_v2.0_rec_infer.tar) / [trained model](https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/ta_mobile_v2.0_rec_train.tar) |
| latin_mobile_v2.0_rec | Lightweight model for latin recognition | [rec_latin_lite_train.yml](../../configs/rec/multi_language/rec_latin_lite_train.yml) |2.6M|[inference model](https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/latin_ppocr_mobile_v2.0_rec_infer.tar) / [trained model](https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/latin_ppocr_mobile_v2.0_rec_train.tar) |
| arabic_mobile_v2.0_rec | Lightweight model for arabic recognition | [rec_arabic_lite_train.yml](../../configs/rec/multi_language/rec_arabic_lite_train.yml) |2.6M|[inference model](https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/arabic_ppocr_mobile_v2.0_rec_infer.tar) / [trained model](https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/arabic_ppocr_mobile_v2.0_rec_train.tar) |
| cyrillic_mobile_v2.0_rec | Lightweight model for cyrillic recognition | [rec_cyrillic_lite_train.yml](../../configs/rec/multi_language/rec_cyrillic_lite_train.yml) |2.6M|[inference model](https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/cyrillic_ppocr_mobile_v2.0_rec_infer.tar) / [trained model](https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/cyrillic_ppocr_mobile_v2.0_rec_train.tar) |
| devanagari_mobile_v2.0_rec | Lightweight model for devanagari recognition | [rec_devanagari_lite_train.yml](../../configs/rec/multi_language/rec_devanagari_lite_train.yml) |2.6M|[inference model](https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/devanagari_ppocr_mobile_v2.0_rec_infer.tar) / [trained model](https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/devanagari_ppocr_mobile_v2.0_rec_train.tar) |
For more supported languages, please refer to: [Multi-language model](./multi_languages_en.md)
<a name="Angle"></a>
# Multi-language model
**Recent Update**
- 2021.4.9 supports the detection and recognition of 80 languages
- 2021.4.9 supports **lightweight high-precision** English model detection and recognition
PaddleOCR aims to create a rich, leading, and practical OCR tool library that not only provides
Chinese and English models for general scenarios, but also provides models specifically trained
for English scenarios, as well as multilingual models covering [80 languages](#language_abbreviations).
Among them, the English model supports the detection and recognition of uppercase and lowercase
letters and common punctuation, and the recognition of space characters is optimized:
<div align="center">
<img src="../imgs_results/multi_lang/img_12.jpg" width="900" height="300">
</div>
The multilingual models cover Latin, Arabic, Traditional Chinese, Korean, Japanese, etc.:
<div align="center">
<img src="../imgs_results/multi_lang/japan_2.jpg" width="600" height="300">
<img src="../imgs_results/multi_lang/french_0.jpg" width="300" height="300">
<img src="../imgs_results/multi_lang/korean_0.jpg" width="500" height="300">
<img src="../imgs_results/multi_lang/arabic_0.jpg" width="300" height="300">
</div>
This document will briefly introduce how to use the multilingual model.
- [1 Installation](#Install)
  - [1.1 paddle installation](#paddle_install)
  - [1.2 paddleocr package installation](#paddleocr_package_install)
- [2 Quick Use](#Quick_Use)
  - [2.1 Command line operation](#Command_line_operation)
  - [2.2 python script running](#python_script_running)
- [3 Custom Training](#Custom_Training)
- [4 Inference and Deployment](#inference)
- [5 Supported languages and abbreviations](#language_abbreviations)
<a name="Install"></a>
## 1 Installation
<a name="paddle_install"></a>
### 1.1 paddle installation
```
# cpu
pip install paddlepaddle
# gpu
pip install paddlepaddle-gpu
```
<a name="paddleocr_package_install"></a>
### 1.2 paddleocr package installation
Install with pip
```
pip install "paddleocr>=2.0.6" # 2.0.6 version is recommended
```
Build and install locally
```
python3 setup.py bdist_wheel
pip3 install dist/paddleocr-x.x.x-py3-none-any.whl # x.x.x is the version number of paddleocr
```
<a name="Quick_Use"></a>
## 2 Quick use
<a name="Command_line_operation"></a>
### 2.1 Command line operation
View help information
```
paddleocr -h
```
* Whole image prediction (detection + recognition)
PaddleOCR currently supports 80 languages, which can be switched by modifying the `--lang` parameter.
The supported [languages](#language_abbreviations) are listed in the table below.
```bash
paddleocr --image_dir doc/imgs_en/254.jpg --lang=en
```
<div align="center">
<img src="../imgs_en/254.jpg" width="300" height="600">
<img src="../imgs_results/multi_lang/img_02.jpg" width="600" height="600">
</div>
The result is a list; each item contains the text box, the text, and the recognition confidence:
```text
[('PHO CAPITAL', 0.95723116), [[66.0, 50.0], [327.0, 44.0], [327.0, 76.0], [67.0, 82.0]]]
[('107 State Street', 0.96311164), [[72.0, 90.0], [451.0, 84.0], [452.0, 116.0], [73.0, 121.0]]]
[('Montpelier Vermont', 0.97389287), [[69.0, 132.0], [501.0, 126.0], [501.0, 158.0], [70.0, 164.0]]]
[('8022256183', 0.99810505), [[71.0, 175.0], [363.0, 170.0], [364.0, 202.0], [72.0, 207.0]]]
[('REG 07-24-201706:59 PM', 0.93537045), [[73.0, 299.0], [653.0, 281.0], [654.0, 318.0], [74.0, 336.0]]]
[('045555', 0.99346405), [[509.0, 331.0], [651.0, 325.0], [652.0, 356.0], [511.0, 362.0]]]
[('CT1', 0.9988654), [[535.0, 367.0], [654.0, 367.0], [654.0, 406.0], [535.0, 406.0]]]
......
```
* Recognition
```bash
paddleocr --image_dir doc/imgs_words_en/word_308.png --det false --lang=en
```
![](https://raw.githubusercontent.com/PaddlePaddle/PaddleOCR/release/2.1/doc/imgs_words_en/word_308.png)
The result is a tuple containing the recognition result and the recognition confidence:
```text
(0.99879867, 'LITTLE')
```
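The same recognition-only call can be made from a Python script; a minimal sketch (image path taken from the example above, printed output is indicative):
```python
from paddleocr import PaddleOCR

# det=False skips detection and recognizes the whole image as one text line
ocr = PaddleOCR(lang='en')
result = ocr.ocr('doc/imgs_words_en/word_308.png', det=False)
print(result)  # e.g. [('LITTLE', 0.9987...)]
```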
* Detection
```
paddleocr --image_dir PaddleOCR/doc/imgs/11.jpg --rec false
```
The result is a list; each item contains only a text box:
```
[[26.0, 457.0], [137.0, 457.0], [137.0, 477.0], [26.0, 477.0]]
[[25.0, 425.0], [372.0, 425.0], [372.0, 448.0], [25.0, 448.0]]
[[128.0, 397.0], [273.0, 397.0], [273.0, 414.0], [128.0, 414.0]]
......
```
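The detection-only boxes can also be visualized with `draw_ocr`; a minimal sketch reusing the call signature shown in the whl documentation (the font path is a placeholder):
```python
from paddleocr import PaddleOCR, draw_ocr
from PIL import Image

ocr = PaddleOCR()  # models are downloaded automatically on first use
img_path = 'PaddleOCR/doc/imgs/11.jpg'
result = ocr.ocr(img_path, rec=False)  # boxes only, no recognition

image = Image.open(img_path).convert('RGB')
# with rec=False, result is a plain list of boxes, so txts and scores are None
im_show = draw_ocr(image, result, txts=None, scores=None,
                   font_path='/path/to/PaddleOCR/doc/fonts/simfang.ttf')
Image.fromarray(im_show).save('result_det.jpg')
```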
<a name="python_script_running"></a>
### 2.2 python script running
ppocr also supports running in python scripts for easy embedding in your own code:
* Whole image prediction (detection + recognition)
```
from paddleocr import PaddleOCR, draw_ocr
# Also switch the language by modifying the lang parameter
ocr = PaddleOCR(lang="korean") # The model file will be downloaded automatically when executed for the first time
img_path = 'doc/imgs/korean_1.jpg'
result = ocr.ocr(img_path)
# Recognition and detection can be performed separately through parameter control
# result = ocr.ocr(img_path, det=False) Only perform recognition
# result = ocr.ocr(img_path, rec=False) Only perform detection
# Print detection frame and recognition result
for line in result:
print(line)
# Visualization
from PIL import Image
image = Image.open(img_path).convert('RGB')
boxes = [line[0] for line in result]
txts = [line[1][0] for line in result]
scores = [line[1][1] for line in result]
im_show = draw_ocr(image, boxes, txts, scores, font_path='/path/to/PaddleOCR/doc/fonts/korean.ttf')
im_show = Image.fromarray(im_show)
im_show.save('result.jpg')
```
Visualization of results:
![](https://raw.githubusercontent.com/PaddlePaddle/PaddleOCR/release/2.1/doc/imgs_results/korean.jpg)
ppocr also supports direction classification. For more usage methods, please refer to: [whl package instructions](https://github.com/PaddlePaddle/PaddleOCR/blob/release/2.0/doc/doc_ch/whl.md).
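A minimal sketch of enabling the direction classifier from Python (parameter names as used elsewhere in this repository):
```python
from paddleocr import PaddleOCR

# use_angle_cls loads the text direction classifier;
# cls=True applies it before recognition
ocr = PaddleOCR(lang="korean", use_angle_cls=True)
result = ocr.ocr('doc/imgs/korean_1.jpg', cls=True)
for line in result:
    print(line)
```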
<a name="Custom_Training"></a>
## 3 Custom training
ppocr supports using your own data for custom training or fine-tuning. For the recognition model, you can start from the [French configuration file](../../configs/rec/multi_language/rec_french_lite_train.yml) and modify the training data path, dictionary, and other parameters, as sketched below.
For the specific data preparation and training process, please refer to [Text Detection](../doc_en/detection_en.md) and [Text Recognition](../doc_en/recognition_en.md); for more functions such as inference deployment and data annotation, you can read the complete [Document Tutorial](../../README.md).
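An illustrative sketch of such modifications (the config keys below follow the usual PaddleOCR yml layout; treat the exact field names and paths as assumptions and check them against the file you copy):
```python
import yaml  # pip install pyyaml

# start from the provided French recognition config
with open("configs/rec/multi_language/rec_french_lite_train.yml") as f:
    cfg = yaml.safe_load(f)

# point the config at your own data and dictionary (paths are placeholders)
cfg["Train"]["dataset"]["data_dir"] = "./train_data/"
cfg["Train"]["dataset"]["label_file_list"] = ["./train_data/rec/train.txt"]
cfg["Global"]["character_dict_path"] = "ppocr/utils/dict/french_dict.txt"

with open("configs/rec/multi_language/my_rec_train.yml", "w") as f:
    yaml.safe_dump(cfg, f)
```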
<a name="inference"></a>
## 4 Inference and Deployment
In addition to installing the whl package for quick prediction,
ppocr also provides a variety of inference deployment methods.
If necessary, you can read the related documents:
- [Python Inference](./inference_en.md)
- [C++ Inference](../../deploy/cpp_infer/readme_en.md)
- [Serving](../../deploy/hubserving/readme_en.md)
- [Mobile](https://github.com/PaddlePaddle/PaddleOCR/blob/develop/deploy/lite/readme_en.md)
- [Benchmark](./benchmark_en.md)
<a name="language_abbreviations"></a>
## 5 Supported languages and abbreviations
| Language | Abbreviation | | Language | Abbreviation |
| --- | --- | --- | --- | --- |
|chinese and english|ch| |Arabic|ar|
|english|en| |Hindi|hi|
|french|fr| |Uyghur|ug|
|german|german| |Persian|fa|
|japan|japan| |Urdu|ur|
|korean|korean| | Serbian(latin) |rs_latin|
|chinese traditional |ch_tra| |Occitan |oc|
| Italian |it| |Marathi|mr|
|Spanish |es| |Nepali|ne|
| Portuguese|pt| |Serbian(cyrillic)|rs_cyrillic|
|Russian|ru| |Bulgarian |bg|
|Ukrainian|uk| |Estonian |et|
|Belarusian|be| |Irish |ga|
|Telugu |te| |Croatian |hr|
|Sanskrit|sa| |Hungarian |hu|
|Tamil |ta| |Indonesian|id|
|Afrikaans |af| |Icelandic|is|
|Azerbaijani |az| |Kurdish|ku|
|Bosnian|bs| |Lithuanian |lt|
|Czech|cs| |Latvian |lv|
|Welsh |cy| |Maori|mi|
|Danish|da| |Malay|ms|
|Maltese |mt| |Adyghe |ady|
|Dutch |nl| |Kabardian |kbd|
|Norwegian |no| |Avar |ava|
|Polish |pl| |Dargwa |dar|
|Romanian |ro| |Ingush |inh|
|Slovak |sk| |Lak |lbe|
|Slovenian |sl| |Lezghian |lez|
|Albanian |sq| |Tabassaran |tab|
|Swedish |sv| |Bihari |bh|
|Swahili |sw| |Maithili |mai|
|Tagalog |tl| |Angika |ang|
|Turkish |tr| |Bhojpuri |bho|
|Uzbek |uz| |Magahi |mah|
|Vietnamese |vi| |Nagpur |sck|
|Mongolian |mn| |Newari |new|
|Abaza |abq| |Goan Konkani|gom|
# End-to-end OCR Algorithm-PGNet
- [1. Brief Introduction](#Brief_Introduction)
- [2. Environment Configuration](#Environment_Configuration)
- [3. Quick Use](#Quick_Use)
- [4. Model Training, Evaluation and Inference](#Model_Training_Evaluation_And_Inference)
<a name="Brief_Introduction"></a>
## 1. Brief Introduction
OCR algorithms can be divided into two-stage algorithms and end-to-end algorithms. A two-stage OCR algorithm is generally split into two parts: a text detection algorithm and a text recognition algorithm. The text detection algorithm obtains the detection boxes of text lines from the image, and the recognition algorithm then identifies the content of the text boxes. An end-to-end OCR algorithm completes text detection and recognition in a single algorithm: its basic idea is to design a model with both a detection unit and a recognition module, share the CNN features of both, and train them together. Because a single model completes both tasks, the end-to-end model is smaller and faster.
### Introduction Of PGNet Algorithm
In recent years, end-to-end OCR algorithms have developed well, including the MaskTextSpotter series, TextSnake, TextDragon, the PGNet series, and so on. Among these algorithms, PGNet has advantages that the others do not:
- The PGNet loss is designed to guide training, so no character-level annotations are needed
- NMS and RoI-related operations are not needed, which accelerates prediction
- A reading order prediction module is proposed
- A graph refinement module (GRM) is proposed to further improve the performance of model recognition
- Higher accuracy and faster prediction speed
For details of the PGNet algorithm, please refer to the [paper](https://www.aaai.org/AAAI21Papers/AAAI-2885.WangP.pdf). The schematic diagram of the algorithm is as follows:
![](../pgnet_framework.png)
After feature extraction, the input image is sent to four branches: the TBO module for text border offset prediction, the TCL module for text center line prediction, the TDO module for text direction offset prediction, and the TCC module for text character classification map prediction.
The outputs of TBO and TCL yield the text detection results after post-processing, while TCL, TDO, and TCC are responsible for text recognition.
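A schematic sketch of that dataflow (module names from the description above; the callables are placeholders, not the real implementation):
```python
def pgnet_forward(image, backbone, tbo, tcl, tdo, tcc):
    # shared CNN features feed four parallel branches
    feats = backbone(image)
    border_offset = tbo(feats)  # TBO: text border offset map
    center_line = tcl(feats)    # TCL: text center line map
    direction = tdo(feats)      # TDO: text direction offset map
    char_cls = tcc(feats)       # TCC: character classification map
    # detection = post-processing of TBO + TCL;
    # recognition = TCL + TDO + TCC
    det_inputs = (border_offset, center_line)
    rec_inputs = (center_line, direction, char_cls)
    return det_inputs, rec_inputs
```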
The results of detection and recognition are as follows:
![](../imgs_results/e2e_res_img293_pgnet.png)
![](../imgs_results/e2e_res_img295_pgnet.png)
### Performance
#### Test set: Total-Text
#### Test environment: NVIDIA Tesla V100-SXM2-16GB
|PGNetA|det_precision|det_recall|det_f_score|e2e_precision|e2e_recall|e2e_f_score|FPS|download|
| --- | --- | --- | --- | --- | --- | --- | --- | --- |
|Paper|85.30|86.80|86.1|-|-|61.7|38.20 (size=640)|-|
|Ours|87.03|82.48|84.69|61.71|58.43|60.03|48.73 (size=768)|[download link](https://paddleocr.bj.bcebos.com/dygraph_v2.0/pgnet/en_server_pgnetA.tar)|
*Note: PGNet in PaddleOCR is optimized for prediction speed, which significantly improves end-to-end prediction speed within an acceptable range of accuracy reduction.*
<a name="Environment_Configuration"></a>
## 2. Environment Configuration
Please refer to [Quick Installation](./installation_en.md) to configure the PaddleOCR running environment.
<a name="Quick_Use"></a>
## 3. Quick Use
### Inference model download
This section takes the trained end-to-end model as an example to quickly run model prediction. First, download the trained end-to-end inference model ([download address](https://paddleocr.bj.bcebos.com/dygraph_v2.0/pgnet/e2e_server_pgnetA_infer.tar)):
```
mkdir inference && cd inference
# Download the English end-to-end model and unzip it
wget https://paddleocr.bj.bcebos.com/dygraph_v2.0/pgnet/e2e_server_pgnetA_infer.tar && tar xf e2e_server_pgnetA_infer.tar
```
* On Windows, if `wget` is not installed, you can copy the link into a browser to download the model, then decompress it and place it in the corresponding directory.
After decompression, there should be the following file structure:
```
├── e2e_server_pgnetA_infer
│ ├── inference.pdiparams
│ ├── inference.pdiparams.info
│ └── inference.pdmodel
```
### Single image or image set prediction
```bash
# Predict a single image specified by image_dir
python3 tools/infer/predict_e2e.py --e2e_algorithm="PGNet" --image_dir="./doc/imgs_en/img623.jpg" --e2e_model_dir="./inference/e2e_server_pgnetA_infer/" --e2e_pgnet_polygon=True
# Predict all images in the folder specified by image_dir
python3 tools/infer/predict_e2e.py --e2e_algorithm="PGNet" --image_dir="./doc/imgs_en/" --e2e_model_dir="./inference/e2e_server_pgnetA_infer/" --e2e_pgnet_polygon=True
# If you want to use CPU for prediction, set the use_gpu parameter to false
python3 tools/infer/predict_e2e.py --e2e_algorithm="PGNet" --image_dir="./doc/imgs_en/img623.jpg" --e2e_model_dir="./inference/e2e_server_pgnetA_infer/" --e2e_pgnet_polygon=True --use_gpu=False
```
### Visualization results
The visualized end-to-end results are saved to the `./inference_results` folder by default, and the name of the result file is prefixed with 'e2e_res'. Examples of results are as follows:
![](../imgs_results/e2e_res_img623_pgnet.jpg)
<a name="Model_Training_Evaluation_And_Inference"></a>
## 4. Model Training, Evaluation and Inference
This section takes the totaltext dataset as an example to introduce the training, evaluation and testing of the end-to-end model in PaddleOCR.
### Data Preparation
Download and unzip the [totaltext](https://paddleocr.bj.bcebos.com/dataset/total_text.tar) dataset to PaddleOCR/train_data/; the dataset organization structure is as follows:
```
/PaddleOCR/train_data/total_text/train/
|- rgb/          # training images of the total_text dataset
|- img11.png
| ...
|- train.txt     # training annotations of the total_text dataset
```
train.txt: the format of the annotation file is as follows; the image file name and the annotation information are separated by "\t":
```
" Image file name Image annotation information encoded by json.dumps"
rgb/img11.jpg [{"transcription": "ASRAMA", "points": [[214.0, 325.0], [235.0, 308.0], [259.0, 296.0], [286.0, 291.0], [313.0, 295.0], [338.0, 305.0], [362.0, 320.0], [349.0, 347.0], [330.0, 337.0], [310.0, 329.0], [290.0, 324.0], [269.0, 328.0], [249.0, 336.0], [231.0, 346.0]]}, {...}]
```
The image annotation after **json.dumps()** encoding is a list containing multiple dictionaries.
The `points` in the dictionary represent the coordinates (x, y) of the vertices of the text box, arranged clockwise starting from the vertex at the upper left corner.
`transcription` represents the text of the current text box. **When its content is "###" it means that the text box is invalid and will be skipped during training.**
If you want to train PaddleOCR on other datasets, please build the annotation file according to the above format.
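For reference, a minimal Python sketch that emits one annotation line in this format (the sample polygon, transcriptions, and paths are illustrative only):
```python
import json

# one text instance: polygon vertices (clockwise from the upper-left point)
# and its transcription; "###" marks an invalid box skipped during training
instances = [
    {"transcription": "ASRAMA",
     "points": [[214.0, 325.0], [362.0, 320.0], [349.0, 347.0], [231.0, 346.0]]},
    {"transcription": "###",
     "points": [[10.0, 10.0], [50.0, 10.0], [50.0, 30.0], [10.0, 30.0]]},
]

# the image file name and the json-encoded annotation are separated by a tab
with open("train_data/total_text/train/train.txt", "a", encoding="utf-8") as f:
    f.write("rgb/img11.jpg" + "\t" + json.dumps(instances) + "\n")
```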
### Start Training
PGNet training is divided into two steps. Step 1: train on synthetic data to obtain a pretrained model, whose accuracy is still low. Step 2: load the pretrained model and train on the totaltext dataset. For fast training, we directly provide the pretrained model of step 1 ([download link](https://paddleocr.bj.bcebos.com/dygraph_v2.0/pgnet/train_step1.tar)).
```shell
cd PaddleOCR/
# download step1 pretrain_models
wget -P ./pretrain_models/ https://paddleocr.bj.bcebos.com/dygraph_v2.0/pgnet/train_step1.tar
# you can get the following file format
./pretrain_models/train_step1/
└─ best_accuracy.pdopt
└─ best_accuracy.states
└─ best_accuracy.pdparams
```
*If the CPU version is installed, please set the parameter `use_gpu` to `false` in the configuration.*
```shell
# single GPU training
python3 tools/train.py -c configs/e2e/e2e_r50_vd_pg.yml -o Global.pretrained_model=./pretrain_models/train_step1/best_accuracy Global.load_static_weights=False
# multi-GPU training
# Set the GPU ID used by the '--gpus' parameter.
python3 -m paddle.distributed.launch --gpus '0,1,2,3' tools/train.py -c configs/e2e/e2e_r50_vd_pg.yml -o Global.pretrained_model=./pretrain_models/train_step1/best_accuracy Global.load_static_weights=False
```
In the above command, `-c` selects the `configs/e2e/e2e_r50_vd_pg.yml` configuration file for training.
For a detailed explanation of the configuration file, please refer to [config](./config_en.md).
You can also use `-o` to change the training parameters without modifying the yml file. For example, adjust the training learning rate to 0.0001
```shell
python3 tools/train.py -c configs/e2e/e2e_r50_vd_pg.yml -o Optimizer.base_lr=0.0001
```
#### Load trained model and continue training
If you expect to load trained model and continue the training again, you can specify the parameter `Global.checkpoints` as the model path to be loaded.
```shell
python3 tools/train.py -c configs/e2e/e2e_r50_vd_pg.yml -o Global.checkpoints=./your/trained/model
```
**Note**: The priority of `Global.checkpoints` is higher than that of `Global.pretrain_weights`, that is, when two parameters are specified at the same time, the model specified by `Global.checkpoints` will be loaded first. If the model path specified by `Global.checkpoints` is wrong, the one specified by `Global.pretrain_weights` will be loaded.
PaddleOCR calculates three indicators for evaluating performance of OCR end-to-end task: Precision, Recall, and Hmean.
Run the following command to calculate the evaluation indicators. The result will be saved in the test result file specified by `save_res_path` in the configuration file `e2e_r50_vd_pg.yml`.
When evaluating, the post-processing parameter `max_side_len=768` is used. If you use different datasets or models for training, this parameter can be adjusted accordingly.
The model parameters during training are saved in the `Global.save_model_dir` directory by default. When evaluating indicators, you need to set `Global.checkpoints` to point to the saved parameter file.
```shell
python3 tools/eval.py -c configs/e2e/e2e_r50_vd_pg.yml -o Global.checkpoints="{path/to/weights}/best_accuracy"
```
### Model Test
Test the end-to-end result on a single image:
```shell
python3 tools/infer_e2e.py -c configs/e2e/e2e_r50_vd_pg.yml -o Global.infer_img="./doc/imgs_en/img_10.jpg" Global.pretrained_model="./output/e2e_pgnet/best_accuracy" Global.load_static_weights=false
```
Test the end-to-end result on all images in the folder:
```shell
python3 tools/infer_e2e.py -c configs/e2e/e2e_r50_vd_pg.yml -o Global.infer_img="./doc/imgs_en/" Global.pretrained_model="./output/e2e_pgnet/best_accuracy" Global.load_static_weights=false
```
### Model inference
#### (1). Quadrangle text detection model (ICDAR2015)
First, convert the model saved during PGNet end-to-end training into an inference model. Taking the model trained in the first stage on the synthetic English dataset as an example ([model download link](https://paddleocr.bj.bcebos.com/dygraph_v2.0/pgnet/en_server_pgnetA.tar)), you can use the following commands to convert it:
```
wget https://paddleocr.bj.bcebos.com/dygraph_v2.0/pgnet/en_server_pgnetA.tar && tar xf en_server_pgnetA.tar
python3 tools/export_model.py -c configs/e2e/e2e_r50_vd_pg.yml -o Global.pretrained_model=./en_server_pgnetA/best_accuracy Global.load_static_weights=False Global.save_inference_dir=./inference/e2e
```
**For PGNet quadrangle end-to-end model inference, you need to set the parameter `--e2e_algorithm="PGNet"`**, run the following command:
```
python3 tools/infer/predict_e2e.py --e2e_algorithm="PGNet" --image_dir="./doc/imgs_en/img_10.jpg" --e2e_model_dir="./inference/e2e/" --e2e_pgnet_polygon=False
```
The visualized text detection results are saved to the `./inference_results` folder by default, and the name of the result file is prefixed with 'e2e_res'. Examples of results are as follows:
![](../imgs_results/e2e_res_img_10_pgnet.jpg)
#### (2). Curved text detection model (Total-Text)
For the curved text example, we use the same model as for the quadrangle case.
**For PGNet end-to-end curved text detection model inference, you need to set the parameter `--e2e_algorithm="PGNet"` and `--e2e_pgnet_polygon=True`**, run the following command:
```
python3 tools/infer/predict_e2e.py --e2e_algorithm="PGNet" --image_dir="./doc/imgs_en/img623.jpg" --e2e_model_dir="./inference/e2e/" --e2e_pgnet_polygon=True
```
The visualized text detection results are saved to the `./inference_results` folder by default, and the name of the result file is prefixed with 'e2e_res'. Examples of results are as follows:
![](../imgs_results/e2e_res_img623_pgnet.jpg)
...@@ -131,7 +131,7 @@ PaddleOCR has built-in dictionaries, which can be used on demand.
`ppocr/utils/dict/german_dict.txt` is a German dictionary with 131 characters
`ppocr/utils/en_dict.txt` is an English dictionary with 96 characters
The current multi-language model is still in the demo stage, and we will continue to optimize the models and add more languages. **You are very welcome to provide us with dictionaries and fonts in other languages.**
...@@ -279,7 +279,7 @@ Eval:
<a name="Multi_language"></a>
#### 2.3 Multi-language
PaddleOCR currently supports 80 (except Chinese) language recognition. A multi-language configuration file template is
provided under the path `configs/rec/multi_languages`: [rec_multi_language_lite_train.yml](../../configs/rec/multi_language/rec_multi_language_lite_train.yml)
There are two ways to create the required configuration file:
...@@ -368,27 +368,12 @@ Currently, the multi-language algorithms supported by PaddleOCR are:
| rec_ger_lite_train.yml | CRNN | Mobilenet_v3 small 0.5 | None | BiLSTM | ctc | German | german |
| rec_japan_lite_train.yml | CRNN | Mobilenet_v3 small 0.5 | None | BiLSTM | ctc | Japanese | japan |
| rec_korean_lite_train.yml | CRNN | Mobilenet_v3 small 0.5 | None | BiLSTM | ctc | Korean | korean |
| rec_it_lite_train.yml | CRNN | Mobilenet_v3 small 0.5 | None | BiLSTM | ctc | Italian | it |
| rec_xi_lite_train.yml | CRNN | Mobilenet_v3 small 0.5 | None | BiLSTM | ctc | Spanish | xi |
| rec_pu_lite_train.yml | CRNN | Mobilenet_v3 small 0.5 | None | BiLSTM | ctc | Portuguese | pu |
| rec_ru_lite_train.yml | CRNN | Mobilenet_v3 small 0.5 | None | BiLSTM | ctc | Russian | ru |
| rec_ar_lite_train.yml | CRNN | Mobilenet_v3 small 0.5 | None | BiLSTM | ctc | Arabic | ar |
| rec_hi_lite_train.yml | CRNN | Mobilenet_v3 small 0.5 | None | BiLSTM | ctc | Hindi | hi |
| rec_ug_lite_train.yml | CRNN | Mobilenet_v3 small 0.5 | None | BiLSTM | ctc | Uyghur | ug |
| rec_fa_lite_train.yml | CRNN | Mobilenet_v3 small 0.5 | None | BiLSTM | ctc | Persian(Farsi) | fa |
| rec_ur_lite_train.yml | CRNN | Mobilenet_v3 small 0.5 | None | BiLSTM | ctc | Urdu | ur |
| rec_rs_lite_train.yml | CRNN | Mobilenet_v3 small 0.5 | None | BiLSTM | ctc | Serbian(latin) | rs |
| rec_oc_lite_train.yml | CRNN | Mobilenet_v3 small 0.5 | None | BiLSTM | ctc | Occitan | oc |
| rec_mr_lite_train.yml | CRNN | Mobilenet_v3 small 0.5 | None | BiLSTM | ctc | Marathi | mr |
| rec_ne_lite_train.yml | CRNN | Mobilenet_v3 small 0.5 | None | BiLSTM | ctc | Nepali | ne |
| rec_rsc_lite_train.yml | CRNN | Mobilenet_v3 small 0.5 | None | BiLSTM | ctc | Serbian(cyrillic) | rsc |
| rec_bg_lite_train.yml | CRNN | Mobilenet_v3 small 0.5 | None | BiLSTM | ctc | Bulgarian | bg |
| rec_uk_lite_train.yml | CRNN | Mobilenet_v3 small 0.5 | None | BiLSTM | ctc | Ukranian | uk |
| rec_be_lite_train.yml | CRNN | Mobilenet_v3 small 0.5 | None | BiLSTM | ctc | Belarusian | be |
| rec_te_lite_train.yml | CRNN | Mobilenet_v3 small 0.5 | None | BiLSTM | ctc | Telugu | te |
| rec_ka_lite_train.yml | CRNN | Mobilenet_v3 small 0.5 | None | BiLSTM | ctc | Kannada | ka |
| rec_ta_lite_train.yml | CRNN | Mobilenet_v3 small 0.5 | None | BiLSTM | ctc | Tamil | ta |
| rec_latin_lite_train.yml | CRNN | Mobilenet_v3 small 0.5 | None | BiLSTM | ctc | Latin | latin |
| rec_arabic_lite_train.yml | CRNN | Mobilenet_v3 small 0.5 | None | BiLSTM | ctc | arabic | ar |
| rec_cyrillic_lite_train.yml | CRNN | Mobilenet_v3 small 0.5 | None | BiLSTM | ctc | cyrillic | cyrillic |
| rec_devanagari_lite_train.yml | CRNN | Mobilenet_v3 small 0.5 | None | BiLSTM | ctc | devanagari | devanagari |
For more supported languages, please refer to : [Multi-language model](https://github.com/PaddlePaddle/PaddleOCR/blob/release/2.1/doc/doc_en/multi_languages_en.md#4-support-languages-and-abbreviations)
The multi-language model training method is the same as the Chinese model. The training dataset consists of 1,000,000 synthetic images. A small amount of fonts and test data can be downloaded from [Baidu Netdisk](https://pan.baidu.com/s/1bS_u207Rm7YbY33wOECKDA), extraction code: frgi.
...@@ -35,7 +35,7 @@ image = Image.open(img_path).convert('RGB')
boxes = [line[0] for line in result]
txts = [line[1][0] for line in result]
scores = [line[1][1] for line in result]
im_show = draw_ocr(image, boxes, txts, scores, font_path='/path/to/PaddleOCR/doc/fonts/simfang.ttf')
im_show = Image.fromarray(im_show)
im_show.save('result.jpg')
```
...@@ -69,7 +69,7 @@ image = Image.open(img_path).convert('RGB')
boxes = [line[0] for line in result]
txts = [line[1][0] for line in result]
scores = [line[1][1] for line in result]
im_show = draw_ocr(image, boxes, txts, scores, font_path='/path/to/PaddleOCR/doc/fonts/simfang.ttf')
im_show = Image.fromarray(im_show)
im_show.save('result.jpg')
```
...@@ -116,7 +116,7 @@ for line in result:
from PIL import Image
image = Image.open(img_path).convert('RGB')
im_show = draw_ocr(image, result, txts=None, scores=None, font_path='/path/to/PaddleOCR/doc/fonts/simfang.ttf')
im_show = Image.fromarray(im_show)
im_show.save('result.jpg')
```
...@@ -262,7 +262,7 @@ image = Image.open(img_path).convert('RGB')
boxes = [line[0] for line in result]
txts = [line[1][0] for line in result]
scores = [line[1][1] for line in result]
im_show = draw_ocr(image, boxes, txts, scores, font_path='/path/to/PaddleOCR/doc/fonts/simfang.ttf')
im_show = Image.fromarray(im_show)
im_show.save('result.jpg')
```
...@@ -292,7 +292,7 @@ image = Image.open(img_path).convert('RGB')
boxes = [line[0] for line in result]
txts = [line[1][0] for line in result]
scores = [line[1][1] for line in result]
im_show = draw_ocr(image, boxes, txts, scores, font_path='/path/to/PaddleOCR/doc/fonts/simfang.ttf')
im_show = Image.fromarray(im_show)
im_show.save('result.jpg')
```
...@@ -320,7 +320,7 @@ image = Image.open(img_path).convert('RGB')
boxes = [line[0] for line in result]
txts = [line[1][0] for line in result]
scores = [line[1][1] for line in result]
im_show = draw_ocr(image, boxes, txts, scores, font_path='/path/to/PaddleOCR/doc/fonts/simfang.ttf')
im_show = Image.fromarray(im_show)
im_show.save('result.jpg')
```
doc/imgs_results/whl/12_det.jpg: binary image updated (166.3 KB → 409.6 KB)
doc/joinus.PNG: binary image updated (109.2 KB → 108.0 KB)
...@@ -30,12 +30,17 @@ from ppocr.utils.logging import get_logger
logger = get_logger()
from ppocr.utils.utility import check_and_read_gif, get_image_file_list
from tools.infer.utility import draw_ocr
__all__ = ['PaddleOCR']
model_urls = {
'det': {
'ch':
'https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_mobile_v2.0_det_infer.tar',
'en':
'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/en_ppocr_mobile_v2.0_det_infer.tar'
},
'rec': {
'ch': {
'url':
...@@ -45,7 +50,7 @@ model_urls = {
'en': {
'url':
'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/en_number_mobile_v2.0_rec_infer.tar',
'dict_path': './ppocr/utils/en_dict.txt'
},
'french': {
'url':
...@@ -66,6 +71,46 @@ model_urls = {
'url':
'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/japan_mobile_v2.0_rec_infer.tar',
'dict_path': './ppocr/utils/dict/japan_dict.txt'
},
'chinese_cht': {
'url':
'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/chinese_cht_mobile_v2.0_rec_infer.tar',
'dict_path': './ppocr/utils/dict/chinese_cht_dict.txt'
},
'ta': {
'url':
'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/ta_mobile_v2.0_rec_infer.tar',
'dict_path': './ppocr/utils/dict/ta_dict.txt'
},
'te': {
'url':
'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/te_mobile_v2.0_rec_infer.tar',
'dict_path': './ppocr/utils/dict/te_dict.txt'
},
'ka': {
'url':
'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/ka_mobile_v2.0_rec_infer.tar',
'dict_path': './ppocr/utils/dict/ka_dict.txt'
},
'latin': {
'url':
'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/latin_ppocr_mobile_v2.0_rec_infer.tar',
'dict_path': './ppocr/utils/dict/latin_dict.txt'
},
'arabic': {
'url':
'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/arabic_ppocr_mobile_v2.0_rec_infer.tar',
'dict_path': './ppocr/utils/dict/arabic_dict.txt'
},
'cyrillic': {
'url':
'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/cyrillic_ppocr_mobile_v2.0_rec_infer.tar',
'dict_path': './ppocr/utils/dict/cyrillic_dict.txt'
},
'devanagari': {
'url':
'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/devanagari_ppocr_mobile_v2.0_rec_infer.tar',
'dict_path': './ppocr/utils/dict/devanagari_dict.txt'
}
},
'cls':
...@@ -73,7 +118,7 @@ model_urls = {
}
SUPPORT_DET_MODEL = ['DB']
VERSION = '2.1'
SUPPORT_REC_MODEL = ['CRNN']
BASE_DIR = os.path.expanduser("~/.paddleocr/")
...@@ -148,6 +193,7 @@ def parse_args(mMain=True, add_help=True):
parser.add_argument("--det_db_box_thresh", type=float, default=0.5)
parser.add_argument("--det_db_unclip_ratio", type=float, default=1.6)
parser.add_argument("--use_dilation", type=bool, default=False)
parser.add_argument("--det_db_score_mode", type=str, default="fast")
# EAST params
parser.add_argument("--det_east_score_thresh", type=float, default=0.8)
...@@ -159,7 +205,7 @@ def parse_args(mMain=True, add_help=True):
parser.add_argument("--rec_model_dir", type=str, default=None)
parser.add_argument("--rec_image_shape", type=str, default="3, 32, 320")
parser.add_argument("--rec_char_type", type=str, default='ch')
parser.add_argument("--rec_batch_num", type=int, default=6)
parser.add_argument("--max_text_length", type=int, default=25)
parser.add_argument("--rec_char_dict_path", type=str, default=None)
parser.add_argument("--use_space_char", type=bool, default=True)
...@@ -169,7 +215,7 @@ def parse_args(mMain=True, add_help=True):
parser.add_argument("--cls_model_dir", type=str, default=None)
parser.add_argument("--cls_image_shape", type=str, default="3, 48, 192")
parser.add_argument("--label_list", type=list, default=['0', '180'])
parser.add_argument("--cls_batch_num", type=int, default=6)
parser.add_argument("--cls_thresh", type=float, default=0.9)
parser.add_argument("--enable_mkldnn", type=bool, default=False)
...@@ -196,6 +242,7 @@ def parse_args(mMain=True, add_help=True):
det_db_box_thresh=0.5,
det_db_unclip_ratio=1.6,
use_dilation=False,
det_db_score_mode="fast",
det_east_score_thresh=0.8,
det_east_cover_thresh=0.1,
det_east_nms_thresh=0.2,
...@@ -203,7 +250,7 @@ def parse_args(mMain=True, add_help=True):
rec_model_dir=None,
rec_image_shape="3, 32, 320",
rec_char_type='ch',
rec_batch_num=6,
max_text_length=25,
rec_char_dict_path=None,
use_space_char=True,
...@@ -211,7 +258,7 @@ def parse_args(mMain=True, add_help=True):
cls_model_dir=None,
cls_image_shape="3, 48, 192",
label_list=['0', '180'],
cls_batch_num=6,
cls_thresh=0.9,
enable_mkldnn=False,
use_zero_copy_run=False,
...@@ -233,9 +280,36 @@ class PaddleOCR(predict_system.TextSystem):
postprocess_params.__dict__.update(**kwargs)
self.use_angle_cls = postprocess_params.use_angle_cls
lang = postprocess_params.lang
latin_lang = [
'af', 'az', 'bs', 'cs', 'cy', 'da', 'de', 'es', 'et', 'fr', 'ga',
'hr', 'hu', 'id', 'is', 'it', 'ku', 'la', 'lt', 'lv', 'mi', 'ms',
'mt', 'nl', 'no', 'oc', 'pi', 'pl', 'pt', 'ro', 'rs_latin', 'sk',
'sl', 'sq', 'sv', 'sw', 'tl', 'tr', 'uz', 'vi'
]
arabic_lang = ['ar', 'fa', 'ug', 'ur']
cyrillic_lang = [
'ru', 'rs_cyrillic', 'be', 'bg', 'uk', 'mn', 'abq', 'ady', 'kbd',
'ava', 'dar', 'inh', 'che', 'lbe', 'lez', 'tab'
]
devanagari_lang = [
'hi', 'mr', 'ne', 'bh', 'mai', 'ang', 'bho', 'mah', 'sck', 'new',
'gom', 'sa', 'bgc'
]
if lang in latin_lang:
lang = "latin"
elif lang in arabic_lang:
lang = "arabic"
elif lang in cyrillic_lang:
lang = "cyrillic"
elif lang in devanagari_lang:
lang = "devanagari"
assert lang in model_urls[
'rec'], 'param lang must in {}, but got {}'.format(
model_urls['rec'].keys(), lang)
if lang == "ch":
det_lang = "ch"
else:
det_lang = "en"
use_inner_dict = False
if postprocess_params.rec_char_dict_path is None:
use_inner_dict = True
...@@ -244,17 +318,17 @@ class PaddleOCR(predict_system.TextSystem):
# init model dir
if postprocess_params.det_model_dir is None:
postprocess_params.det_model_dir = os.path.join(BASE_DIR, VERSION, 'det', det_lang)
if postprocess_params.rec_model_dir is None:
postprocess_params.rec_model_dir = os.path.join(BASE_DIR, VERSION, 'rec', lang)
if postprocess_params.cls_model_dir is None:
postprocess_params.cls_model_dir = os.path.join(BASE_DIR, 'cls')
print(postprocess_params)
# download model
maybe_download(postprocess_params.det_model_dir, model_urls['det'][det_lang])
maybe_download(postprocess_params.rec_model_dir, model_urls['rec'][lang]['url'])
maybe_download(postprocess_params.cls_model_dir, model_urls['cls'])
...@@ -34,6 +34,7 @@ import paddle.distributed as dist
from ppocr.data.imaug import transform, create_operators
from ppocr.data.simple_dataset import SimpleDataSet
from ppocr.data.lmdb_dataset import LMDBDataSet
from ppocr.data.pgnet_dataset import PGDataSet
__all__ = ['build_dataloader', 'transform', 'create_operators']
...@@ -54,7 +55,7 @@ signal.signal(signal.SIGTERM, term_mp)
def build_dataloader(config, mode, device, logger, seed=None):
config = copy.deepcopy(config)
support_dict = ['SimpleDataSet', 'LMDBDataSet', 'PGDataSet']
module_name = config[mode]['dataset']['name']
assert module_name in support_dict, Exception(
'DataSet only support {}'.format(support_dict))
...@@ -72,14 +73,14 @@ def build_dataloader(config, mode, device, logger, seed=None):
else:
use_shared_memory = True
if mode == "Train":
# Distribute data to multiple cards
batch_sampler = DistributedBatchSampler(
dataset=dataset,
batch_size=batch_size,
shuffle=shuffle,
drop_last=drop_last)
else:
# Distribute data to single card
batch_sampler = BatchSampler(
dataset=dataset,
batch_size=batch_size,
...@@ -28,6 +28,7 @@ from .label_ops import *
from .east_process import *
from .sast_process import *
from .pg_process import *
def transform(data, ops=None):
...@@ -96,7 +96,7 @@ class BaseRecLabelEncode(object):
'ch', 'en', 'EN_symbol', 'french', 'german', 'japan', 'korean',
'EN', 'it', 'xi', 'pu', 'ru', 'ar', 'ta', 'ug', 'fa', 'ur', 'rs',
'oc', 'rsc', 'bg', 'uk', 'be', 'te', 'ka', 'chinese_cht', 'hi',
'mr', 'ne', 'latin', 'arabic', 'cyrillic', 'devanagari'
]
assert character_type in support_character_type, "Only {} are supported now but get {}".format(
support_character_type, character_type)
...@@ -187,6 +187,78 @@ class CTCLabelEncode(BaseRecLabelEncode): ...@@ -187,6 +187,78 @@ class CTCLabelEncode(BaseRecLabelEncode):
return dict_character return dict_character
class E2ELabelEncodeTest(BaseRecLabelEncode):
def __init__(self,
max_text_length,
character_dict_path=None,
character_type='EN',
use_space_char=False,
**kwargs):
super(E2ELabelEncodeTest,
self).__init__(max_text_length, character_dict_path,
character_type, use_space_char)
def __call__(self, data):
import json
padnum = len(self.dict)
label = data['label']
label = json.loads(label)
nBox = len(label)
boxes, txts, txt_tags = [], [], []
for bno in range(0, nBox):
box = label[bno]['points']
txt = label[bno]['transcription']
boxes.append(box)
txts.append(txt)
if txt in ['*', '###']:
txt_tags.append(True)
else:
txt_tags.append(False)
boxes = np.array(boxes, dtype=np.float32)
        txt_tags = np.array(txt_tags, dtype=bool)
data['polys'] = boxes
data['ignore_tags'] = txt_tags
temp_texts = []
for text in txts:
text = text.lower()
text = self.encode(text)
if text is None:
return None
            text = text + [padnum] * (
                self.max_text_len - len(text))  # pad to max_text_len with the blank index
temp_texts.append(text)
data['texts'] = np.array(temp_texts)
return data
class E2ELabelEncodeTrain(object):
def __init__(self, **kwargs):
pass
def __call__(self, data):
import json
label = data['label']
label = json.loads(label)
nBox = len(label)
boxes, txts, txt_tags = [], [], []
for bno in range(0, nBox):
box = label[bno]['points']
txt = label[bno]['transcription']
boxes.append(box)
txts.append(txt)
if txt in ['*', '###']:
txt_tags.append(True)
else:
txt_tags.append(False)
boxes = np.array(boxes, dtype=np.float32)
        txt_tags = np.array(txt_tags, dtype=bool)
data['polys'] = boxes
data['texts'] = txts
data['ignore_tags'] = txt_tags
return data
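
# Illustration (not part of the commit): the label format both E2E encoders
# above expect -- a JSON list of {'points', 'transcription'} dicts, where a
# transcription of '*' or '###' marks a region to ignore. Values are made up.
if __name__ == '__main__':
    import json
    sample = {
        'label': json.dumps([
            {'points': [[0, 0], [10, 0], [10, 5], [0, 5]],
             'transcription': 'OCR'},
            {'points': [[20, 0], [30, 0], [30, 5], [20, 5]],
             'transcription': '###'},
        ])
    }
    out = E2ELabelEncodeTrain()(sample)
    print(out['polys'].shape, out['ignore_tags'])  # (2, 4, 2) [False  True]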
class AttnLabelEncode(BaseRecLabelEncode):
    """ Convert between text-label and text-index """
......
...@@ -197,7 +197,6 @@ class DetResizeForTest(object):
             sys.exit(0)
         ratio_h = resize_h / float(h)
         ratio_w = resize_w / float(w)
-        # return img, np.array([h, w])
         return img, [ratio_h, ratio_w]

     def resize_image_type2(self, img):
...@@ -206,7 +205,6 @@ class DetResizeForTest(object):
         resize_w = w
         resize_h = h

-        # Fix the longer side
         if resize_h > resize_w:
             ratio = float(self.resize_long) / resize_h
         else:
...@@ -223,3 +221,72 @@ class DetResizeForTest(object):
         ratio_w = resize_w / float(w)
         return img, [ratio_h, ratio_w]
class E2EResizeForTest(object):
def __init__(self, **kwargs):
super(E2EResizeForTest, self).__init__()
self.max_side_len = kwargs['max_side_len']
self.valid_set = kwargs['valid_set']
def __call__(self, data):
img = data['image']
src_h, src_w, _ = img.shape
if self.valid_set == 'totaltext':
im_resized, [ratio_h, ratio_w] = self.resize_image_for_totaltext(
img, max_side_len=self.max_side_len)
else:
im_resized, (ratio_h, ratio_w) = self.resize_image(
img, max_side_len=self.max_side_len)
data['image'] = im_resized
data['shape'] = np.array([src_h, src_w, ratio_h, ratio_w])
return data
def resize_image_for_totaltext(self, im, max_side_len=512):
h, w, _ = im.shape
resize_w = w
resize_h = h
ratio = 1.25
if h * ratio > max_side_len:
ratio = float(max_side_len) / resize_h
resize_h = int(resize_h * ratio)
resize_w = int(resize_w * ratio)
max_stride = 128
resize_h = (resize_h + max_stride - 1) // max_stride * max_stride
resize_w = (resize_w + max_stride - 1) // max_stride * max_stride
im = cv2.resize(im, (int(resize_w), int(resize_h)))
ratio_h = resize_h / float(h)
ratio_w = resize_w / float(w)
return im, (ratio_h, ratio_w)
def resize_image(self, im, max_side_len=512):
"""
resize image to a size multiple of max_stride which is required by the network
        :param im: the image to resize
:param max_side_len: limit of max image size to avoid out of memory in gpu
:return: the resized image and the resize ratio
"""
h, w, _ = im.shape
resize_w = w
resize_h = h
# Fix the longer side
if resize_h > resize_w:
ratio = float(max_side_len) / resize_h
else:
ratio = float(max_side_len) / resize_w
resize_h = int(resize_h * ratio)
resize_w = int(resize_w * ratio)
max_stride = 128
resize_h = (resize_h + max_stride - 1) // max_stride * max_stride
resize_w = (resize_w + max_stride - 1) // max_stride * max_stride
im = cv2.resize(im, (int(resize_w), int(resize_h)))
ratio_h = resize_h / float(h)
ratio_w = resize_w / float(w)
return im, (ratio_h, ratio_w)
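
# Illustration (not part of the commit): how the stride-128 rounding in
# resize_image above behaves for a 720x1280 input with max_side_len=768
# (numbers are illustrative only).
if __name__ == '__main__':
    h, w = 720, 1280
    ratio = 768.0 / max(h, w)              # cap the longer side at max_side_len
    resize_h, resize_w = int(h * ratio), int(w * ratio)    # 432, 768
    resize_h = (resize_h + 127) // 128 * 128               # -> 512
    resize_w = (resize_w + 127) // 128 * 128               # -> 768
    print(resize_h / float(h), resize_w / float(w))        # ratios for the net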
# copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
import cv2
import numpy as np
__all__ = ['PGProcessTrain']
class PGProcessTrain(object):
def __init__(self,
character_dict_path,
max_text_length,
max_text_nums,
tcl_len,
batch_size=14,
min_crop_size=24,
min_text_size=4,
max_text_size=512,
**kwargs):
self.tcl_len = tcl_len
self.max_text_length = max_text_length
self.max_text_nums = max_text_nums
self.batch_size = batch_size
self.min_crop_size = min_crop_size
self.min_text_size = min_text_size
self.max_text_size = max_text_size
self.Lexicon_Table = self.get_dict(character_dict_path)
self.pad_num = len(self.Lexicon_Table)
self.img_id = 0
def get_dict(self, character_dict_path):
character_str = ""
with open(character_dict_path, "rb") as fin:
lines = fin.readlines()
for line in lines:
line = line.decode('utf-8').strip("\n").strip("\r\n")
character_str += line
dict_character = list(character_str)
return dict_character
def quad_area(self, poly):
"""
compute area of a polygon
:param poly:
:return:
"""
edge = [(poly[1][0] - poly[0][0]) * (poly[1][1] + poly[0][1]),
(poly[2][0] - poly[1][0]) * (poly[2][1] + poly[1][1]),
(poly[3][0] - poly[2][0]) * (poly[3][1] + poly[2][1]),
(poly[0][0] - poly[3][0]) * (poly[0][1] + poly[3][1])]
return np.sum(edge) / 2.
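
    # Illustration (not part of the commit): quad_area's sign convention.
    # In image coordinates (y grows downward) the trapezoid sum is negative
    # for a clockwise quad, the expected direction; a positive area makes
    # check_and_validate_polys below reverse the polygon.
    #
    # >>> quad = np.array([[0, 0], [10, 0], [10, 5], [0, 5]], np.float32)
    # >>> PGProcessTrain.quad_area(None, quad)
    # -50.0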
def gen_quad_from_poly(self, poly):
"""
Generate min area quad from poly.
"""
point_num = poly.shape[0]
min_area_quad = np.zeros((4, 2), dtype=np.float32)
rect = cv2.minAreaRect(poly.astype(
np.int32)) # (center (x,y), (width, height), angle of rotation)
box = np.array(cv2.boxPoints(rect))
first_point_idx = 0
min_dist = 1e4
for i in range(4):
dist = np.linalg.norm(box[(i + 0) % 4] - poly[0]) + \
np.linalg.norm(box[(i + 1) % 4] - poly[point_num // 2 - 1]) + \
np.linalg.norm(box[(i + 2) % 4] - poly[point_num // 2]) + \
np.linalg.norm(box[(i + 3) % 4] - poly[-1])
if dist < min_dist:
min_dist = dist
first_point_idx = i
for i in range(4):
min_area_quad[i] = box[(first_point_idx + i) % 4]
return min_area_quad
def check_and_validate_polys(self, polys, tags, im_size):
"""
        check that text polys share a consistent direction,
        and also filter out some invalid polygons
:param polys:
:param tags:
:return:
"""
(h, w) = im_size
if polys.shape[0] == 0:
return polys, np.array([]), np.array([])
polys[:, :, 0] = np.clip(polys[:, :, 0], 0, w - 1)
polys[:, :, 1] = np.clip(polys[:, :, 1], 0, h - 1)
validated_polys = []
validated_tags = []
hv_tags = []
for poly, tag in zip(polys, tags):
quad = self.gen_quad_from_poly(poly)
p_area = self.quad_area(quad)
if abs(p_area) < 1:
print('invalid poly')
continue
if p_area > 0:
if tag == False:
print('poly in wrong direction')
                    tag = True  # reversed cases should be ignored
poly = poly[(0, 15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2,
1), :]
quad = quad[(0, 3, 2, 1), :]
len_w = np.linalg.norm(quad[0] - quad[1]) + np.linalg.norm(quad[3] -
quad[2])
len_h = np.linalg.norm(quad[0] - quad[3]) + np.linalg.norm(quad[1] -
quad[2])
hv_tag = 1
if len_w * 2.0 < len_h:
hv_tag = 0
validated_polys.append(poly)
validated_tags.append(tag)
hv_tags.append(hv_tag)
return np.array(validated_polys), np.array(validated_tags), np.array(
hv_tags)
def crop_area(self,
im,
polys,
tags,
hv_tags,
txts,
crop_background=False,
max_tries=25):
"""
make random crop from the input image
:param im:
:param polys: [b,4,2]
:param tags:
:param crop_background:
:param max_tries: 50 -> 25
:return:
"""
h, w, _ = im.shape
pad_h = h // 10
pad_w = w // 10
h_array = np.zeros((h + pad_h * 2), dtype=np.int32)
w_array = np.zeros((w + pad_w * 2), dtype=np.int32)
for poly in polys:
poly = np.round(poly, decimals=0).astype(np.int32)
minx = np.min(poly[:, 0])
maxx = np.max(poly[:, 0])
w_array[minx + pad_w:maxx + pad_w] = 1
miny = np.min(poly[:, 1])
maxy = np.max(poly[:, 1])
h_array[miny + pad_h:maxy + pad_h] = 1
        # ensure the crop boundaries do not cross any text region
h_axis = np.where(h_array == 0)[0]
w_axis = np.where(w_array == 0)[0]
if len(h_axis) == 0 or len(w_axis) == 0:
return im, polys, tags, hv_tags, txts
for i in range(max_tries):
xx = np.random.choice(w_axis, size=2)
xmin = np.min(xx) - pad_w
xmax = np.max(xx) - pad_w
xmin = np.clip(xmin, 0, w - 1)
xmax = np.clip(xmax, 0, w - 1)
yy = np.random.choice(h_axis, size=2)
ymin = np.min(yy) - pad_h
ymax = np.max(yy) - pad_h
ymin = np.clip(ymin, 0, h - 1)
ymax = np.clip(ymax, 0, h - 1)
if xmax - xmin < self.min_crop_size or \
ymax - ymin < self.min_crop_size:
continue
if polys.shape[0] != 0:
poly_axis_in_area = (polys[:, :, 0] >= xmin) & (polys[:, :, 0] <= xmax) \
& (polys[:, :, 1] >= ymin) & (polys[:, :, 1] <= ymax)
selected_polys = np.where(
np.sum(poly_axis_in_area, axis=1) == 4)[0]
else:
selected_polys = []
if len(selected_polys) == 0:
# no text in this area
if crop_background:
txts_tmp = []
for selected_poly in selected_polys:
txts_tmp.append(txts[selected_poly])
txts = txts_tmp
return im[ymin: ymax + 1, xmin: xmax + 1, :], \
polys[selected_polys], tags[selected_polys], hv_tags[selected_polys], txts
else:
continue
im = im[ymin:ymax + 1, xmin:xmax + 1, :]
polys = polys[selected_polys]
tags = tags[selected_polys]
hv_tags = hv_tags[selected_polys]
txts_tmp = []
for selected_poly in selected_polys:
txts_tmp.append(txts[selected_poly])
txts = txts_tmp
polys[:, :, 0] -= xmin
polys[:, :, 1] -= ymin
return im, polys, tags, hv_tags, txts
return im, polys, tags, hv_tags, txts
def fit_and_gather_tcl_points_v2(self,
min_area_quad,
poly,
max_h,
max_w,
fixed_point_num=64,
img_id=0,
reference_height=3):
"""
Find the center point of poly as key_points, then fit and gather.
"""
key_point_xys = []
point_num = poly.shape[0]
for idx in range(point_num // 2):
center_point = (poly[idx] + poly[point_num - 1 - idx]) / 2.0
key_point_xys.append(center_point)
tmp_image = np.zeros(
shape=(
max_h,
max_w, ), dtype='float32')
cv2.polylines(tmp_image, [np.array(key_point_xys).astype('int32')],
False, 1.0)
ys, xs = np.where(tmp_image > 0)
xy_text = np.array(list(zip(xs, ys)), dtype='float32')
left_center_pt = (
(min_area_quad[0] - min_area_quad[1]) / 2.0).reshape(1, 2)
right_center_pt = (
(min_area_quad[1] - min_area_quad[2]) / 2.0).reshape(1, 2)
proj_unit_vec = (right_center_pt - left_center_pt) / (
np.linalg.norm(right_center_pt - left_center_pt) + 1e-6)
proj_unit_vec_tile = np.tile(proj_unit_vec,
(xy_text.shape[0], 1)) # (n, 2)
left_center_pt_tile = np.tile(left_center_pt,
(xy_text.shape[0], 1)) # (n, 2)
xy_text_to_left_center = xy_text - left_center_pt_tile
proj_value = np.sum(xy_text_to_left_center * proj_unit_vec_tile, axis=1)
xy_text = xy_text[np.argsort(proj_value)]
        # convert to np and keep the number of points not greater than fixed_point_num
pos_info = np.array(xy_text).reshape(-1, 2)[:, ::-1] # xy-> yx
point_num = len(pos_info)
if point_num > fixed_point_num:
keep_ids = [
int((point_num * 1.0 / fixed_point_num) * x)
for x in range(fixed_point_num)
]
pos_info = pos_info[keep_ids, :]
keep = int(min(len(pos_info), fixed_point_num))
if np.random.rand() < 0.2 and reference_height >= 3:
dl = (np.random.rand(keep) - 0.5) * reference_height * 0.3
random_float = np.array([1, 0]).reshape([1, 2]) * dl.reshape(
[keep, 1])
pos_info += random_float
pos_info[:, 0] = np.clip(pos_info[:, 0], 0, max_h - 1)
pos_info[:, 1] = np.clip(pos_info[:, 1], 0, max_w - 1)
# padding to fixed length
pos_l = np.zeros((self.tcl_len, 3), dtype=np.int32)
pos_l[:, 0] = np.ones((self.tcl_len, )) * img_id
pos_m = np.zeros((self.tcl_len, 1), dtype=np.float32)
pos_l[:keep, 1:] = np.round(pos_info).astype(np.int32)
pos_m[:keep] = 1.0
return pos_l, pos_m
def generate_direction_map(self, poly_quads, n_char, direction_map):
"""
"""
width_list = []
height_list = []
for quad in poly_quads:
quad_w = (np.linalg.norm(quad[0] - quad[1]) +
np.linalg.norm(quad[2] - quad[3])) / 2.0
quad_h = (np.linalg.norm(quad[0] - quad[3]) +
np.linalg.norm(quad[2] - quad[1])) / 2.0
width_list.append(quad_w)
height_list.append(quad_h)
norm_width = max(sum(width_list) / n_char, 1.0)
average_height = max(sum(height_list) / len(height_list), 1.0)
k = 1
for quad in poly_quads:
direct_vector_full = (
(quad[1] + quad[2]) - (quad[0] + quad[3])) / 2.0
direct_vector = direct_vector_full / (
np.linalg.norm(direct_vector_full) + 1e-6) * norm_width
direction_label = tuple(
map(float,
[direct_vector[0], direct_vector[1], 1.0 / average_height]))
cv2.fillPoly(direction_map,
quad.round().astype(np.int32)[np.newaxis, :, :],
direction_label)
k += 1
return direction_map
def calculate_average_height(self, poly_quads):
"""
"""
height_list = []
for quad in poly_quads:
quad_h = (np.linalg.norm(quad[0] - quad[3]) +
np.linalg.norm(quad[2] - quad[1])) / 2.0
height_list.append(quad_h)
average_height = max(sum(height_list) / len(height_list), 1.0)
return average_height
def generate_tcl_ctc_label(self,
h,
w,
polys,
tags,
text_strs,
ds_ratio,
tcl_ratio=0.3,
shrink_ratio_of_width=0.15):
"""
        Generate the TCL score maps, border and direction maps, and CTC labels.
"""
score_map_big = np.zeros(
(
h,
w, ), dtype=np.float32)
h, w = int(h * ds_ratio), int(w * ds_ratio)
polys = polys * ds_ratio
score_map = np.zeros(
(
h,
w, ), dtype=np.float32)
score_label_map = np.zeros(
(
h,
w, ), dtype=np.float32)
tbo_map = np.zeros((h, w, 5), dtype=np.float32)
training_mask = np.ones(
(
h,
w, ), dtype=np.float32)
direction_map = np.ones((h, w, 3)) * np.array([0, 0, 1]).reshape(
[1, 1, 3]).astype(np.float32)
label_idx = 0
score_label_map_text_label_list = []
pos_list, pos_mask, label_list = [], [], []
for poly_idx, poly_tag in enumerate(zip(polys, tags)):
poly = poly_tag[0]
tag = poly_tag[1]
# generate min_area_quad
min_area_quad, center_point = self.gen_min_area_quad_from_poly(poly)
min_area_quad_h = 0.5 * (
np.linalg.norm(min_area_quad[0] - min_area_quad[3]) +
np.linalg.norm(min_area_quad[1] - min_area_quad[2]))
min_area_quad_w = 0.5 * (
np.linalg.norm(min_area_quad[0] - min_area_quad[1]) +
np.linalg.norm(min_area_quad[2] - min_area_quad[3]))
if min(min_area_quad_h, min_area_quad_w) < self.min_text_size * ds_ratio \
or min(min_area_quad_h, min_area_quad_w) > self.max_text_size * ds_ratio:
continue
if tag:
cv2.fillPoly(training_mask,
poly.astype(np.int32)[np.newaxis, :, :], 0.15)
else:
text_label = text_strs[poly_idx]
text_label = self.prepare_text_label(text_label,
self.Lexicon_Table)
text_label_index_list = [[self.Lexicon_Table.index(c_)]
for c_ in text_label
if c_ in self.Lexicon_Table]
if len(text_label_index_list) < 1:
continue
tcl_poly = self.poly2tcl(poly, tcl_ratio)
tcl_quads = self.poly2quads(tcl_poly)
poly_quads = self.poly2quads(poly)
stcl_quads, quad_index = self.shrink_poly_along_width(
tcl_quads,
shrink_ratio_of_width=shrink_ratio_of_width,
expand_height_ratio=1.0 / tcl_ratio)
cv2.fillPoly(score_map,
np.round(stcl_quads).astype(np.int32), 1.0)
cv2.fillPoly(score_map_big,
np.round(stcl_quads / ds_ratio).astype(np.int32),
1.0)
for idx, quad in enumerate(stcl_quads):
quad_mask = np.zeros((h, w), dtype=np.float32)
quad_mask = cv2.fillPoly(
quad_mask,
np.round(quad[np.newaxis, :, :]).astype(np.int32), 1.0)
tbo_map = self.gen_quad_tbo(poly_quads[quad_index[idx]],
quad_mask, tbo_map)
# score label map and score_label_map_text_label_list for refine
if label_idx == 0:
text_pos_list_ = [[len(self.Lexicon_Table)], ]
score_label_map_text_label_list.append(text_pos_list_)
label_idx += 1
cv2.fillPoly(score_label_map,
np.round(poly_quads).astype(np.int32), label_idx)
score_label_map_text_label_list.append(text_label_index_list)
# direction info, fix-me
n_char = len(text_label_index_list)
direction_map = self.generate_direction_map(poly_quads, n_char,
direction_map)
# pos info
average_shrink_height = self.calculate_average_height(
stcl_quads)
pos_l, pos_m = self.fit_and_gather_tcl_points_v2(
min_area_quad,
poly,
max_h=h,
max_w=w,
fixed_point_num=64,
img_id=self.img_id,
reference_height=average_shrink_height)
label_l = text_label_index_list
if len(text_label_index_list) < 2:
continue
pos_list.append(pos_l)
pos_mask.append(pos_m)
label_list.append(label_l)
# use big score_map for smooth tcl lines
score_map_big_resized = cv2.resize(
score_map_big, dsize=None, fx=ds_ratio, fy=ds_ratio)
score_map = np.array(score_map_big_resized > 1e-3, dtype='float32')
return score_map, score_label_map, tbo_map, direction_map, training_mask, \
pos_list, pos_mask, label_list, score_label_map_text_label_list
def adjust_point(self, poly):
"""
adjust point order.
"""
point_num = poly.shape[0]
if point_num == 4:
len_1 = np.linalg.norm(poly[0] - poly[1])
len_2 = np.linalg.norm(poly[1] - poly[2])
len_3 = np.linalg.norm(poly[2] - poly[3])
len_4 = np.linalg.norm(poly[3] - poly[0])
if (len_1 + len_3) * 1.5 < (len_2 + len_4):
poly = poly[[1, 2, 3, 0], :]
elif point_num > 4:
vector_1 = poly[0] - poly[1]
vector_2 = poly[1] - poly[2]
cos_theta = np.dot(vector_1, vector_2) / (
np.linalg.norm(vector_1) * np.linalg.norm(vector_2) + 1e-6)
theta = np.arccos(np.round(cos_theta, decimals=4))
if abs(theta) > (70 / 180 * math.pi):
index = list(range(1, point_num)) + [0]
poly = poly[np.array(index), :]
return poly
def gen_min_area_quad_from_poly(self, poly):
"""
Generate min area quad from poly.
"""
point_num = poly.shape[0]
min_area_quad = np.zeros((4, 2), dtype=np.float32)
if point_num == 4:
min_area_quad = poly
center_point = np.sum(poly, axis=0) / 4
else:
rect = cv2.minAreaRect(poly.astype(
np.int32)) # (center (x,y), (width, height), angle of rotation)
center_point = rect[0]
box = np.array(cv2.boxPoints(rect))
first_point_idx = 0
min_dist = 1e4
for i in range(4):
dist = np.linalg.norm(box[(i + 0) % 4] - poly[0]) + \
np.linalg.norm(box[(i + 1) % 4] - poly[point_num // 2 - 1]) + \
np.linalg.norm(box[(i + 2) % 4] - poly[point_num // 2]) + \
np.linalg.norm(box[(i + 3) % 4] - poly[-1])
if dist < min_dist:
min_dist = dist
first_point_idx = i
for i in range(4):
min_area_quad[i] = box[(first_point_idx + i) % 4]
return min_area_quad, center_point
def shrink_quad_along_width(self,
quad,
begin_width_ratio=0.,
end_width_ratio=1.):
"""
        Shrink a quad along its width between begin_width_ratio and end_width_ratio.
"""
ratio_pair = np.array(
[[begin_width_ratio], [end_width_ratio]], dtype=np.float32)
p0_1 = quad[0] + (quad[1] - quad[0]) * ratio_pair
p3_2 = quad[3] + (quad[2] - quad[3]) * ratio_pair
return np.array([p0_1[0], p0_1[1], p3_2[1], p3_2[0]])
def shrink_poly_along_width(self,
quads,
shrink_ratio_of_width,
expand_height_ratio=1.0):
"""
        Shrink a poly (given as a sequence of quads) along its width.
"""
upper_edge_list = []
def get_cut_info(edge_len_list, cut_len):
for idx, edge_len in enumerate(edge_len_list):
cut_len -= edge_len
if cut_len <= 0.000001:
ratio = (cut_len + edge_len_list[idx]) / edge_len_list[idx]
return idx, ratio
for quad in quads:
upper_edge_len = np.linalg.norm(quad[0] - quad[1])
upper_edge_list.append(upper_edge_len)
# length of left edge and right edge.
left_length = np.linalg.norm(quads[0][0] - quads[0][
3]) * expand_height_ratio
right_length = np.linalg.norm(quads[-1][1] - quads[-1][
2]) * expand_height_ratio
shrink_length = min(left_length, right_length,
sum(upper_edge_list)) * shrink_ratio_of_width
# shrinking length
upper_len_left = shrink_length
upper_len_right = sum(upper_edge_list) - shrink_length
left_idx, left_ratio = get_cut_info(upper_edge_list, upper_len_left)
left_quad = self.shrink_quad_along_width(
quads[left_idx], begin_width_ratio=left_ratio, end_width_ratio=1)
right_idx, right_ratio = get_cut_info(upper_edge_list, upper_len_right)
right_quad = self.shrink_quad_along_width(
quads[right_idx], begin_width_ratio=0, end_width_ratio=right_ratio)
out_quad_list = []
if left_idx == right_idx:
out_quad_list.append(
[left_quad[0], right_quad[1], right_quad[2], left_quad[3]])
else:
out_quad_list.append(left_quad)
for idx in range(left_idx + 1, right_idx):
out_quad_list.append(quads[idx])
out_quad_list.append(right_quad)
return np.array(out_quad_list), list(range(left_idx, right_idx + 1))
def prepare_text_label(self, label_str, Lexicon_Table):
"""
        Prepare the text label with the given Lexicon_Table.
"""
if len(Lexicon_Table) == 36:
return label_str.lower()
else:
return label_str
def vector_angle(self, A, B):
"""
Calculate the angle between vector AB and x-axis positive direction.
"""
AB = np.array([B[1] - A[1], B[0] - A[0]])
return np.arctan2(*AB)
def theta_line_cross_point(self, theta, point):
"""
        Calculate the line through a given point at the given angle, in ax + by + c = 0 form.
"""
x, y = point
cos = np.cos(theta)
sin = np.sin(theta)
return [sin, -cos, cos * y - sin * x]
def line_cross_two_point(self, A, B):
"""
Calculate the line through given point A and B in ax + by + c =0 form.
"""
angle = self.vector_angle(A, B)
return self.theta_line_cross_point(angle, A)
def average_angle(self, poly):
"""
Calculate the average angle between left and right edge in given poly.
"""
p0, p1, p2, p3 = poly
angle30 = self.vector_angle(p3, p0)
angle21 = self.vector_angle(p2, p1)
return (angle30 + angle21) / 2
def line_cross_point(self, line1, line2):
"""
        Compute the intersection point of line1 and line2, both given in ax + by + c = 0 form.
"""
a1, b1, c1 = line1
a2, b2, c2 = line2
d = a1 * b2 - a2 * b1
if d == 0:
print('Cross point does not exist')
return np.array([0, 0], dtype=np.float32)
else:
x = (b1 * c2 - b2 * c1) / d
y = (a2 * c1 - a1 * c2) / d
return np.array([x, y], dtype=np.float32)
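
    # Illustration (not part of the commit): the ax + by + c = 0 helpers on
    # two simple lines. theta_line_cross_point(0, (0, 2)) is y = 2 and
    # line_cross_two_point((0, 0), (5, 5)) is y = x; they meet at (2, 2).
    #
    # >>> p = PGProcessTrain.__new__(PGProcessTrain)  # skip __init__ for demo
    # >>> horiz = p.theta_line_cross_point(0.0, (0, 2))
    # >>> diag = p.line_cross_two_point((0, 0), (5, 5))
    # >>> p.line_cross_point(horiz, diag)
    # array([2., 2.], dtype=float32)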
def quad2tcl(self, poly, ratio):
"""
        Generate the text center line from a clockwise quad of shape (4, 2).
"""
ratio_pair = np.array(
[[0.5 - ratio / 2], [0.5 + ratio / 2]], dtype=np.float32)
p0_3 = poly[0] + (poly[3] - poly[0]) * ratio_pair
p1_2 = poly[1] + (poly[2] - poly[1]) * ratio_pair
return np.array([p0_3[0], p1_2[0], p1_2[1], p0_3[1]])
def poly2tcl(self, poly, ratio):
"""
        Generate the text center line from clockwise poly points.
"""
ratio_pair = np.array(
[[0.5 - ratio / 2], [0.5 + ratio / 2]], dtype=np.float32)
tcl_poly = np.zeros_like(poly)
point_num = poly.shape[0]
for idx in range(point_num // 2):
point_pair = poly[idx] + (poly[point_num - 1 - idx] - poly[idx]
) * ratio_pair
tcl_poly[idx] = point_pair[0]
tcl_poly[point_num - 1 - idx] = point_pair[1]
return tcl_poly
def gen_quad_tbo(self, quad, tcl_mask, tbo_map):
"""
        Generate tbo_map for the given quad.
"""
# upper and lower line function: ax + by + c = 0;
up_line = self.line_cross_two_point(quad[0], quad[1])
lower_line = self.line_cross_two_point(quad[3], quad[2])
quad_h = 0.5 * (np.linalg.norm(quad[0] - quad[3]) +
np.linalg.norm(quad[1] - quad[2]))
quad_w = 0.5 * (np.linalg.norm(quad[0] - quad[1]) +
np.linalg.norm(quad[2] - quad[3]))
# average angle of left and right line.
angle = self.average_angle(quad)
xy_in_poly = np.argwhere(tcl_mask == 1)
for y, x in xy_in_poly:
point = (x, y)
line = self.theta_line_cross_point(angle, point)
cross_point_upper = self.line_cross_point(up_line, line)
cross_point_lower = self.line_cross_point(lower_line, line)
##FIX, offset reverse
upper_offset_x, upper_offset_y = cross_point_upper - point
lower_offset_x, lower_offset_y = cross_point_lower - point
tbo_map[y, x, 0] = upper_offset_y
tbo_map[y, x, 1] = upper_offset_x
tbo_map[y, x, 2] = lower_offset_y
tbo_map[y, x, 3] = lower_offset_x
tbo_map[y, x, 4] = 1.0 / max(min(quad_h, quad_w), 1.0) * 2
return tbo_map
def poly2quads(self, poly):
"""
Split poly into quads.
"""
quad_list = []
point_num = poly.shape[0]
# point pair
point_pair_list = []
for idx in range(point_num // 2):
point_pair = [poly[idx], poly[point_num - 1 - idx]]
point_pair_list.append(point_pair)
quad_num = point_num // 2 - 1
for idx in range(quad_num):
# reshape and adjust to clock-wise
quad_list.append((np.array(point_pair_list)[[idx, idx + 1]]
).reshape(4, 2)[[0, 2, 3, 1]])
return np.array(quad_list)
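
    # Illustration (not part of the commit): poly2quads on a 6-point polygon
    # (upper edge left-to-right, then lower edge right-to-left, the ordering
    # the pairing loop above assumes) yields two clockwise quads.
    #
    # >>> poly = np.array([[0, 0], [5, 0], [10, 0],
    # ...                  [10, 4], [5, 4], [0, 4]], np.float32)
    # >>> PGProcessTrain.poly2quads(None, poly).shape
    # (2, 4, 2)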
def rotate_im_poly(self, im, text_polys):
"""
        Rotate the image by 90 / 180 / 270 degrees.
"""
im_w, im_h = im.shape[1], im.shape[0]
dst_im = im.copy()
dst_polys = []
rand_degree_ratio = np.random.rand()
rand_degree_cnt = 1
if rand_degree_ratio > 0.5:
rand_degree_cnt = 3
for i in range(rand_degree_cnt):
dst_im = np.rot90(dst_im)
rot_degree = -90 * rand_degree_cnt
rot_angle = rot_degree * math.pi / 180.0
n_poly = text_polys.shape[0]
cx, cy = 0.5 * im_w, 0.5 * im_h
ncx, ncy = 0.5 * dst_im.shape[1], 0.5 * dst_im.shape[0]
for i in range(n_poly):
wordBB = text_polys[i]
poly = []
for j in range(4): # 16->4
sx, sy = wordBB[j][0], wordBB[j][1]
dx = math.cos(rot_angle) * (sx - cx) - math.sin(rot_angle) * (
sy - cy) + ncx
dy = math.sin(rot_angle) * (sx - cx) + math.cos(rot_angle) * (
sy - cy) + ncy
poly.append([dx, dy])
dst_polys.append(poly)
return dst_im, np.array(dst_polys, dtype=np.float32)
def __call__(self, data):
input_size = 512
im = data['image']
text_polys = data['polys']
text_tags = data['ignore_tags']
text_strs = data['texts']
h, w, _ = im.shape
text_polys, text_tags, hv_tags = self.check_and_validate_polys(
text_polys, text_tags, (h, w))
if text_polys.shape[0] <= 0:
return None
        # set aspect ratio and keep area fixed
asp_scales = np.arange(1.0, 1.55, 0.1)
asp_scale = np.random.choice(asp_scales)
if np.random.rand() < 0.5:
asp_scale = 1.0 / asp_scale
asp_scale = math.sqrt(asp_scale)
asp_wx = asp_scale
asp_hy = 1.0 / asp_scale
im = cv2.resize(im, dsize=None, fx=asp_wx, fy=asp_hy)
text_polys[:, :, 0] *= asp_wx
text_polys[:, :, 1] *= asp_hy
h, w, _ = im.shape
if max(h, w) > 2048:
rd_scale = 2048.0 / max(h, w)
im = cv2.resize(im, dsize=None, fx=rd_scale, fy=rd_scale)
text_polys *= rd_scale
h, w, _ = im.shape
if min(h, w) < 16:
return None
# no background
im, text_polys, text_tags, hv_tags, text_strs = self.crop_area(
im,
text_polys,
text_tags,
hv_tags,
text_strs,
crop_background=False)
if text_polys.shape[0] == 0:
return None
        # skip samples in which every text region is ignored
if np.sum((text_tags * 1.0)) >= text_tags.size:
return None
new_h, new_w, _ = im.shape
if (new_h is None) or (new_w is None):
return None
# resize image
std_ratio = float(input_size) / max(new_w, new_h)
rand_scales = np.array(
[0.25, 0.375, 0.5, 0.625, 0.75, 0.875, 1.0, 1.0, 1.0, 1.0, 1.0])
rz_scale = std_ratio * np.random.choice(rand_scales)
im = cv2.resize(im, dsize=None, fx=rz_scale, fy=rz_scale)
text_polys[:, :, 0] *= rz_scale
text_polys[:, :, 1] *= rz_scale
# add gaussian blur
if np.random.rand() < 0.1 * 0.5:
ks = np.random.permutation(5)[0] + 1
ks = int(ks / 2) * 2 + 1
im = cv2.GaussianBlur(im, ksize=(ks, ks), sigmaX=0, sigmaY=0)
# add brighter
if np.random.rand() < 0.1 * 0.5:
im = im * (1.0 + np.random.rand() * 0.5)
im = np.clip(im, 0.0, 255.0)
# add darker
if np.random.rand() < 0.1 * 0.5:
im = im * (1.0 - np.random.rand() * 0.5)
im = np.clip(im, 0.0, 255.0)
# Padding the im to [input_size, input_size]
new_h, new_w, _ = im.shape
if min(new_w, new_h) < input_size * 0.5:
return None
im_padded = np.ones((input_size, input_size, 3), dtype=np.float32)
im_padded[:, :, 2] = 0.485 * 255
im_padded[:, :, 1] = 0.456 * 255
im_padded[:, :, 0] = 0.406 * 255
        # Randomize the start position
del_h = input_size - new_h
del_w = input_size - new_w
sh, sw = 0, 0
if del_h > 1:
sh = int(np.random.rand() * del_h)
if del_w > 1:
sw = int(np.random.rand() * del_w)
# Padding
im_padded[sh:sh + new_h, sw:sw + new_w, :] = im.copy()
text_polys[:, :, 0] += sw
text_polys[:, :, 1] += sh
        score_map, score_label_map, border_map, direction_map, training_mask, \
            pos_list, pos_mask, label_list, score_label_map_text_label = \
            self.generate_tcl_ctc_label(input_size, input_size, text_polys,
                                        text_tags, text_strs, 0.25)
if len(label_list) <= 0: # eliminate negative samples
return None
pos_list_temp = np.zeros([64, 3])
pos_mask_temp = np.zeros([64, 1])
label_list_temp = np.zeros([self.max_text_length, 1]) + self.pad_num
for i, label in enumerate(label_list):
n = len(label)
if n > self.max_text_length:
label_list[i] = label[:self.max_text_length]
continue
while n < self.max_text_length:
label.append([self.pad_num])
n += 1
for i in range(len(label_list)):
label_list[i] = np.array(label_list[i])
if len(pos_list) <= 0 or len(pos_list) > self.max_text_nums:
return None
for __ in range(self.max_text_nums - len(pos_list), 0, -1):
pos_list.append(pos_list_temp)
pos_mask.append(pos_mask_temp)
label_list.append(label_list_temp)
if self.img_id == self.batch_size - 1:
self.img_id = 0
else:
self.img_id += 1
im_padded[:, :, 2] -= 0.485 * 255
im_padded[:, :, 1] -= 0.456 * 255
im_padded[:, :, 0] -= 0.406 * 255
im_padded[:, :, 2] /= (255.0 * 0.229)
im_padded[:, :, 1] /= (255.0 * 0.224)
im_padded[:, :, 0] /= (255.0 * 0.225)
im_padded = im_padded.transpose((2, 0, 1))
images = im_padded[::-1, :, :]
tcl_maps = score_map[np.newaxis, :, :]
tcl_label_maps = score_label_map[np.newaxis, :, :]
border_maps = border_map.transpose((2, 0, 1))
direction_maps = direction_map.transpose((2, 0, 1))
training_masks = training_mask[np.newaxis, :, :]
pos_list = np.array(pos_list)
pos_mask = np.array(pos_mask)
label_list = np.array(label_list)
data['images'] = images
data['tcl_maps'] = tcl_maps
data['tcl_label_maps'] = tcl_label_maps
data['border_maps'] = border_maps
data['direction_maps'] = direction_maps
data['training_masks'] = training_masks
data['label_list'] = label_list
data['pos_list'] = pos_list
data['pos_mask'] = pos_mask
return data
# copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
import os
from paddle.io import Dataset
from .imaug import transform, create_operators
import random
class PGDataSet(Dataset):
def __init__(self, config, mode, logger, seed=None):
super(PGDataSet, self).__init__()
self.logger = logger
self.seed = seed
self.mode = mode
global_config = config['Global']
dataset_config = config[mode]['dataset']
loader_config = config[mode]['loader']
self.delimiter = dataset_config.get('delimiter', '\t')
label_file_list = dataset_config.pop('label_file_list')
data_source_num = len(label_file_list)
ratio_list = dataset_config.get("ratio_list", [1.0])
if isinstance(ratio_list, (float, int)):
ratio_list = [float(ratio_list)] * int(data_source_num)
assert len(
ratio_list
) == data_source_num, "The length of ratio_list should be the same as the file_list."
self.data_dir = dataset_config['data_dir']
self.do_shuffle = loader_config['shuffle']
logger.info("Initialize indexs of datasets:%s" % label_file_list)
self.data_lines = self.get_image_info_list(label_file_list, ratio_list)
self.data_idx_order_list = list(range(len(self.data_lines)))
if mode.lower() == "train":
self.shuffle_data_random()
self.ops = create_operators(dataset_config['transforms'], global_config)
def shuffle_data_random(self):
if self.do_shuffle:
random.seed(self.seed)
random.shuffle(self.data_lines)
return
def get_image_info_list(self, file_list, ratio_list):
if isinstance(file_list, str):
file_list = [file_list]
data_lines = []
for idx, file in enumerate(file_list):
with open(file, "rb") as f:
lines = f.readlines()
if self.mode == "train" or ratio_list[idx] < 1.0:
random.seed(self.seed)
lines = random.sample(lines,
round(len(lines) * ratio_list[idx]))
data_lines.extend(lines)
return data_lines
def __getitem__(self, idx):
file_idx = self.data_idx_order_list[idx]
data_line = self.data_lines[file_idx]
img_id = 0
try:
data_line = data_line.decode('utf-8')
substr = data_line.strip("\n").split(self.delimiter)
file_name = substr[0]
label = substr[1]
img_path = os.path.join(self.data_dir, file_name)
if self.mode.lower() == 'eval':
try:
img_id = int(data_line.split(".")[0][7:])
except:
img_id = 0
data = {'img_path': img_path, 'label': label, 'img_id': img_id}
if not os.path.exists(img_path):
raise Exception("{} does not exist!".format(img_path))
with open(data['img_path'], 'rb') as f:
img = f.read()
data['image'] = img
outs = transform(data, self.ops)
except Exception as e:
self.logger.error(
"When parsing line {}, error happened with msg: {}".format(
self.data_idx_order_list[idx], e))
outs = None
if outs is None:
return self.__getitem__(np.random.randint(self.__len__()))
return outs
def __len__(self):
return len(self.data_idx_order_list)
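
# Illustration (not part of the commit): the minimal config shape PGDataSet
# reads. Field names follow the accesses above; paths and values are made up.
if __name__ == '__main__':
    sample_config = {
        'Global': {},
        'Train': {
            'dataset': {
                'name': 'PGDataSet',
                'data_dir': './train_data/total_text/train/',
                'label_file_list': ['./train_data/total_text/train/train.txt'],
                'ratio_list': [1.0],   # per-file sampling ratio
                'transforms': [],      # imaug operator configs
            },
            'loader': {'shuffle': True},
        },
    }
    # PGDataSet(sample_config, 'Train', logger) then samples each label file
    # by its ratio, shuffles in train mode, and builds the transform pipeline.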
...@@ -23,6 +23,7 @@ class SimpleDataSet(Dataset):
     def __init__(self, config, mode, logger, seed=None):
         super(SimpleDataSet, self).__init__()
         self.logger = logger
+        self.mode = mode.lower()

         global_config = config['Global']
         dataset_config = config[mode]['dataset']
...@@ -45,7 +46,7 @@ class SimpleDataSet(Dataset):
         logger.info("Initialize indexs of datasets:%s" % label_file_list)
         self.data_lines = self.get_image_info_list(label_file_list, ratio_list)
         self.data_idx_order_list = list(range(len(self.data_lines)))
-        if mode.lower() == "train":
+        if self.mode == "train" and self.do_shuffle:
             self.shuffle_data_random()
         self.ops = create_operators(dataset_config['transforms'], global_config)
...@@ -56,16 +57,16 @@ class SimpleDataSet(Dataset):
         for idx, file in enumerate(file_list):
             with open(file, "rb") as f:
                 lines = f.readlines()
-                random.seed(self.seed)
-                lines = random.sample(lines,
-                                      round(len(lines) * ratio_list[idx]))
+                if self.mode == "train" or ratio_list[idx] < 1.0:
+                    random.seed(self.seed)
+                    lines = random.sample(lines,
+                                          round(len(lines) * ratio_list[idx]))
                 data_lines.extend(lines)
         return data_lines

     def shuffle_data_random(self):
-        if self.do_shuffle:
-            random.seed(self.seed)
-            random.shuffle(self.data_lines)
+        random.seed(self.seed)
+        random.shuffle(self.data_lines)
         return

     def __getitem__(self, idx):
...@@ -90,7 +91,10 @@ class SimpleDataSet(Dataset):
                 data_line, e))
             outs = None
         if outs is None:
-            return self.__getitem__(np.random.randint(self.__len__()))
+            # during evaluation, fix the idx to get the same results across repeated runs.
+            rnd_idx = np.random.randint(self.__len__(
+            )) if self.mode == "train" else (idx + 1) % self.__len__()
+            return self.__getitem__(rnd_idx)
         return outs

     def __len__(self):
......
...@@ -29,10 +29,11 @@ def build_loss(config):
     # cls loss
     from .cls_loss import ClsLoss

+    # e2e loss
+    from .e2e_pg_loss import PGLoss
     support_dict = [
         'DBLoss', 'EASTLoss', 'SASTLoss', 'CTCLoss', 'ClsLoss', 'AttentionLoss',
-        'SRNLoss'
-    ]
+        'SRNLoss', 'PGLoss']

     config = copy.deepcopy(config)
     module_name = config.pop('name')
......
# copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from paddle import nn
import paddle
from .det_basic_loss import DiceLoss
from ppocr.utils.e2e_utils.extract_batchsize import pre_process
class PGLoss(nn.Layer):
def __init__(self,
tcl_bs,
max_text_length,
max_text_nums,
pad_num,
eps=1e-6,
**kwargs):
super(PGLoss, self).__init__()
self.tcl_bs = tcl_bs
self.max_text_nums = max_text_nums
self.max_text_length = max_text_length
self.pad_num = pad_num
self.dice_loss = DiceLoss(eps=eps)
def border_loss(self, f_border, l_border, l_score, l_mask):
l_border_split, l_border_norm = paddle.tensor.split(
l_border, num_or_sections=[4, 1], axis=1)
f_border_split = f_border
b, c, h, w = l_border_norm.shape
l_border_norm_split = paddle.expand(
x=l_border_norm, shape=[b, 4 * c, h, w])
b, c, h, w = l_score.shape
l_border_score = paddle.expand(x=l_score, shape=[b, 4 * c, h, w])
b, c, h, w = l_mask.shape
l_border_mask = paddle.expand(x=l_mask, shape=[b, 4 * c, h, w])
border_diff = l_border_split - f_border_split
abs_border_diff = paddle.abs(border_diff)
border_sign = abs_border_diff < 1.0
border_sign = paddle.cast(border_sign, dtype='float32')
border_sign.stop_gradient = True
border_in_loss = 0.5 * abs_border_diff * abs_border_diff * border_sign + \
(abs_border_diff - 0.5) * (1.0 - border_sign)
border_out_loss = l_border_norm_split * border_in_loss
border_loss = paddle.sum(border_out_loss * l_border_score * l_border_mask) / \
(paddle.sum(l_border_score * l_border_mask) + 1e-5)
return border_loss
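
    # Illustration (not part of the commit): border_loss above is a masked
    # smooth-L1 -- quadratic 0.5*d^2 where |d| < 1 and linear |d| - 0.5
    # elsewhere, selected by border_sign. In plain numpy:
    #
    # >>> import numpy as np
    # >>> d = np.array([-2.0, -0.5, 0.0, 0.5, 2.0])
    # >>> np.where(np.abs(d) < 1.0, 0.5 * d * d, np.abs(d) - 0.5)
    # array([1.5  , 0.125, 0.   , 0.125, 1.5  ])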
def direction_loss(self, f_direction, l_direction, l_score, l_mask):
l_direction_split, l_direction_norm = paddle.tensor.split(
l_direction, num_or_sections=[2, 1], axis=1)
f_direction_split = f_direction
b, c, h, w = l_direction_norm.shape
l_direction_norm_split = paddle.expand(
x=l_direction_norm, shape=[b, 2 * c, h, w])
b, c, h, w = l_score.shape
l_direction_score = paddle.expand(x=l_score, shape=[b, 2 * c, h, w])
b, c, h, w = l_mask.shape
l_direction_mask = paddle.expand(x=l_mask, shape=[b, 2 * c, h, w])
direction_diff = l_direction_split - f_direction_split
abs_direction_diff = paddle.abs(direction_diff)
direction_sign = abs_direction_diff < 1.0
direction_sign = paddle.cast(direction_sign, dtype='float32')
direction_sign.stop_gradient = True
direction_in_loss = 0.5 * abs_direction_diff * abs_direction_diff * direction_sign + \
(abs_direction_diff - 0.5) * (1.0 - direction_sign)
direction_out_loss = l_direction_norm_split * direction_in_loss
direction_loss = paddle.sum(direction_out_loss * l_direction_score * l_direction_mask) / \
(paddle.sum(l_direction_score * l_direction_mask) + 1e-5)
return direction_loss
def ctcloss(self, f_char, tcl_pos, tcl_mask, tcl_label, label_t):
f_char = paddle.transpose(f_char, [0, 2, 3, 1])
tcl_pos = paddle.reshape(tcl_pos, [-1, 3])
tcl_pos = paddle.cast(tcl_pos, dtype=int)
f_tcl_char = paddle.gather_nd(f_char, tcl_pos)
f_tcl_char = paddle.reshape(f_tcl_char,
[-1, 64, 37]) # len(Lexicon_Table)+1
f_tcl_char_fg, f_tcl_char_bg = paddle.split(f_tcl_char, [36, 1], axis=2)
f_tcl_char_bg = f_tcl_char_bg * tcl_mask + (1.0 - tcl_mask) * 20.0
b, c, l = tcl_mask.shape
tcl_mask_fg = paddle.expand(x=tcl_mask, shape=[b, c, 36 * l])
tcl_mask_fg.stop_gradient = True
f_tcl_char_fg = f_tcl_char_fg * tcl_mask_fg + (1.0 - tcl_mask_fg) * (
-20.0)
f_tcl_char_mask = paddle.concat([f_tcl_char_fg, f_tcl_char_bg], axis=2)
f_tcl_char_ld = paddle.transpose(f_tcl_char_mask, (1, 0, 2))
N, B, _ = f_tcl_char_ld.shape
input_lengths = paddle.to_tensor([N] * B, dtype='int64')
cost = paddle.nn.functional.ctc_loss(
log_probs=f_tcl_char_ld,
labels=tcl_label,
input_lengths=input_lengths,
label_lengths=label_t,
blank=self.pad_num,
reduction='none')
cost = cost.mean()
return cost
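
    # Illustration (not part of the commit): gather_nd with the (img_id, y, x)
    # rows of tcl_pos selects one 37-way character logit vector per sampled
    # center-line point; those sequences are what the CTC loss above consumes.
    #
    # >>> import numpy as np
    # >>> f_char = np.zeros([2, 8, 8, 37])            # [B, H, W, C] after transpose
    # >>> tcl_pos = np.array([[0, 1, 2], [1, 3, 0]])  # (img_id, y, x)
    # >>> f_char[tcl_pos[:, 0], tcl_pos[:, 1], tcl_pos[:, 2]].shape
    # (2, 37)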
def forward(self, predicts, labels):
images, tcl_maps, tcl_label_maps, border_maps \
, direction_maps, training_masks, label_list, pos_list, pos_mask = labels
# for all the batch_size
pos_list, pos_mask, label_list, label_t = pre_process(
label_list, pos_list, pos_mask, self.max_text_length,
self.max_text_nums, self.pad_num, self.tcl_bs)
f_score, f_border, f_direction, f_char = predicts['f_score'], predicts['f_border'], predicts['f_direction'], \
predicts['f_char']
score_loss = self.dice_loss(f_score, tcl_maps, training_masks)
border_loss = self.border_loss(f_border, border_maps, tcl_maps,
training_masks)
direction_loss = self.direction_loss(f_direction, direction_maps,
tcl_maps, training_masks)
ctc_loss = self.ctcloss(f_char, pos_list, pos_mask, label_list, label_t)
loss_all = score_loss + border_loss + direction_loss + 5 * ctc_loss
losses = {
'loss': loss_all,
"score_loss": score_loss,
"border_loss": border_loss,
"direction_loss": direction_loss,
"ctc_loss": ctc_loss
}
return losses
...@@ -26,8 +26,9 @@ def build_metric(config):
     from .det_metric import DetMetric
     from .rec_metric import RecMetric
     from .cls_metric import ClsMetric
+    from .e2e_metric import E2EMetric

-    support_dict = ['DetMetric', 'RecMetric', 'ClsMetric']
+    support_dict = ['DetMetric', 'RecMetric', 'ClsMetric', 'E2EMetric']

     config = copy.deepcopy(config)
     module_name = config.pop('name')
......
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
__all__ = ['E2EMetric']
from ppocr.utils.e2e_metric.Deteval import get_socre_A, get_socre_B, combine_results
from ppocr.utils.e2e_utils.extract_textpoint_slow import get_dict
class E2EMetric(object):
def __init__(self,
mode,
gt_mat_dir,
character_dict_path,
main_indicator='f_score_e2e',
**kwargs):
self.mode = mode
self.gt_mat_dir = gt_mat_dir
self.label_list = get_dict(character_dict_path)
self.max_index = len(self.label_list)
self.main_indicator = main_indicator
self.reset()
def __call__(self, preds, batch, **kwargs):
if self.mode == 'A':
gt_polyons_batch = batch[2]
temp_gt_strs_batch = batch[3][0]
ignore_tags_batch = batch[4]
gt_strs_batch = []
for temp_list in temp_gt_strs_batch:
t = ""
for index in temp_list:
if index < self.max_index:
t += self.label_list[index]
gt_strs_batch.append(t)
for pred, gt_polyons, gt_strs, ignore_tags in zip(
[preds], gt_polyons_batch, [gt_strs_batch], ignore_tags_batch):
# prepare gt
gt_info_list = [{
'points': gt_polyon,
'text': gt_str,
'ignore': ignore_tag
} for gt_polyon, gt_str, ignore_tag in
zip(gt_polyons, gt_strs, ignore_tags)]
# prepare det
e2e_info_list = [{
'points': det_polyon,
'texts': pred_str
} for det_polyon, pred_str in
zip(pred['points'], pred['texts'])]
result = get_socre_A(gt_info_list, e2e_info_list)
self.results.append(result)
else:
img_id = batch[5][0]
e2e_info_list = [{
'points': det_polyon,
'texts': pred_str
} for det_polyon, pred_str in zip(preds['points'], preds['texts'])]
result = get_socre_B(self.gt_mat_dir, img_id, e2e_info_list)
self.results.append(result)
def get_metric(self):
        metrics = combine_results(self.results)
        self.reset()
        return metrics
def reset(self):
self.results = [] # clear results
...@@ -150,7 +150,7 @@ class DetectionIoUEvaluator(object):
                         pairs.append({'gt': gtNum, 'det': detNum})
                         detMatchedNums.append(detNum)
                         evaluationLog += "Match GT #" + \
-                                         str(gtNum) + " with Det #" + str(detNum) + "\n"
+                            str(gtNum) + " with Det #" + str(detNum) + "\n"

         numGtCare = (len(gtPols) - len(gtDontCarePolsNum))
         numDetCare = (len(detPols) - len(detDontCarePolsNum))
...@@ -162,7 +162,7 @@ class DetectionIoUEvaluator(object):
         precision = 0 if numDetCare == 0 else float(detMatched) / numDetCare
         hmean = 0 if (precision + recall) == 0 else 2.0 * \
-                precision * recall / (precision + recall)
+            precision * recall / (precision + recall)

         matchedSum += detMatched
         numGlobalCareGt += numGtCare
...@@ -200,7 +200,8 @@ class DetectionIoUEvaluator(object):
         methodPrecision = 0 if numGlobalCareDet == 0 else float(
             matchedSum) / numGlobalCareDet
         methodHmean = 0 if methodRecall + methodPrecision == 0 else 2 * \
-            methodRecall * methodPrecision / (methodRecall + methodPrecision)
+            methodRecall * methodPrecision / (
+                methodRecall + methodPrecision)
         # print(methodRecall, methodPrecision, methodHmean)
         # sys.exit(-1)
         methodMetrics = {
......
...@@ -26,6 +26,9 @@ def build_backbone(config, model_type):
         from .rec_resnet_vd import ResNet
         from .rec_resnet_fpn import ResNetFPN
         support_dict = ['MobileNetV3', 'ResNet', 'ResNetFPN']
+    elif model_type == 'e2e':
+        from .e2e_resnet_vd_pg import ResNet
+        support_dict = ['ResNet']
     else:
         raise NotImplementedError
......
# copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import paddle
from paddle import ParamAttr
import paddle.nn as nn
import paddle.nn.functional as F
__all__ = ["ResNet"]
class ConvBNLayer(nn.Layer):
def __init__(
self,
in_channels,
out_channels,
kernel_size,
stride=1,
groups=1,
is_vd_mode=False,
act=None,
name=None, ):
super(ConvBNLayer, self).__init__()
self.is_vd_mode = is_vd_mode
self._pool2d_avg = nn.AvgPool2D(
kernel_size=2, stride=2, padding=0, ceil_mode=True)
self._conv = nn.Conv2D(
in_channels=in_channels,
out_channels=out_channels,
kernel_size=kernel_size,
stride=stride,
padding=(kernel_size - 1) // 2,
groups=groups,
weight_attr=ParamAttr(name=name + "_weights"),
bias_attr=False)
if name == "conv1":
bn_name = "bn_" + name
else:
bn_name = "bn" + name[3:]
self._batch_norm = nn.BatchNorm(
out_channels,
act=act,
param_attr=ParamAttr(name=bn_name + '_scale'),
bias_attr=ParamAttr(bn_name + '_offset'),
moving_mean_name=bn_name + '_mean',
moving_variance_name=bn_name + '_variance')
def forward(self, inputs):
y = self._conv(inputs)
y = self._batch_norm(y)
return y
class BottleneckBlock(nn.Layer):
def __init__(self,
in_channels,
out_channels,
stride,
shortcut=True,
if_first=False,
name=None):
super(BottleneckBlock, self).__init__()
self.conv0 = ConvBNLayer(
in_channels=in_channels,
out_channels=out_channels,
kernel_size=1,
act='relu',
name=name + "_branch2a")
self.conv1 = ConvBNLayer(
in_channels=out_channels,
out_channels=out_channels,
kernel_size=3,
stride=stride,
act='relu',
name=name + "_branch2b")
self.conv2 = ConvBNLayer(
in_channels=out_channels,
out_channels=out_channels * 4,
kernel_size=1,
act=None,
name=name + "_branch2c")
if not shortcut:
self.short = ConvBNLayer(
in_channels=in_channels,
out_channels=out_channels * 4,
kernel_size=1,
stride=stride,
is_vd_mode=False if if_first else True,
name=name + "_branch1")
self.shortcut = shortcut
def forward(self, inputs):
y = self.conv0(inputs)
conv1 = self.conv1(y)
conv2 = self.conv2(conv1)
if self.shortcut:
short = inputs
else:
short = self.short(inputs)
y = paddle.add(x=short, y=conv2)
y = F.relu(y)
return y
class BasicBlock(nn.Layer):
def __init__(self,
in_channels,
out_channels,
stride,
shortcut=True,
if_first=False,
name=None):
super(BasicBlock, self).__init__()
self.stride = stride
self.conv0 = ConvBNLayer(
in_channels=in_channels,
out_channels=out_channels,
kernel_size=3,
stride=stride,
act='relu',
name=name + "_branch2a")
self.conv1 = ConvBNLayer(
in_channels=out_channels,
out_channels=out_channels,
kernel_size=3,
act=None,
name=name + "_branch2b")
if not shortcut:
self.short = ConvBNLayer(
in_channels=in_channels,
out_channels=out_channels,
kernel_size=1,
stride=1,
is_vd_mode=False if if_first else True,
name=name + "_branch1")
self.shortcut = shortcut
def forward(self, inputs):
y = self.conv0(inputs)
conv1 = self.conv1(y)
if self.shortcut:
short = inputs
else:
short = self.short(inputs)
y = paddle.add(x=short, y=conv1)
y = F.relu(y)
return y
class ResNet(nn.Layer):
def __init__(self, in_channels=3, layers=50, **kwargs):
super(ResNet, self).__init__()
self.layers = layers
supported_layers = [18, 34, 50, 101, 152, 200]
assert layers in supported_layers, \
"supported layers are {} but input layer is {}".format(
supported_layers, layers)
if layers == 18:
depth = [2, 2, 2, 2]
elif layers == 34 or layers == 50:
# depth = [3, 4, 6, 3]
depth = [3, 4, 6, 3, 3]
elif layers == 101:
depth = [3, 4, 23, 3]
elif layers == 152:
depth = [3, 8, 36, 3]
elif layers == 200:
depth = [3, 12, 48, 3]
num_channels = [64, 256, 512, 1024,
2048] if layers >= 50 else [64, 64, 128, 256]
num_filters = [64, 128, 256, 512, 512]
self.conv1_1 = ConvBNLayer(
in_channels=in_channels,
out_channels=64,
kernel_size=7,
stride=2,
act='relu',
name="conv1_1")
self.pool2d_max = nn.MaxPool2D(kernel_size=3, stride=2, padding=1)
self.stages = []
self.out_channels = [3, 64]
# num_filters = [64, 128, 256, 512, 512]
if layers >= 50:
for block in range(len(depth)):
block_list = []
shortcut = False
for i in range(depth[block]):
if layers in [101, 152] and block == 2:
if i == 0:
conv_name = "res" + str(block + 2) + "a"
else:
conv_name = "res" + str(block + 2) + "b" + str(i)
else:
conv_name = "res" + str(block + 2) + chr(97 + i)
bottleneck_block = self.add_sublayer(
'bb_%d_%d' % (block, i),
BottleneckBlock(
in_channels=num_channels[block]
if i == 0 else num_filters[block] * 4,
out_channels=num_filters[block],
stride=2 if i == 0 and block != 0 else 1,
shortcut=shortcut,
if_first=block == i == 0,
name=conv_name))
shortcut = True
block_list.append(bottleneck_block)
self.out_channels.append(num_filters[block] * 4)
self.stages.append(nn.Sequential(*block_list))
else:
for block in range(len(depth)):
block_list = []
shortcut = False
for i in range(depth[block]):
conv_name = "res" + str(block + 2) + chr(97 + i)
basic_block = self.add_sublayer(
'bb_%d_%d' % (block, i),
BasicBlock(
in_channels=num_channels[block]
if i == 0 else num_filters[block],
out_channels=num_filters[block],
stride=2 if i == 0 and block != 0 else 1,
shortcut=shortcut,
if_first=block == i == 0,
name=conv_name))
shortcut = True
block_list.append(basic_block)
self.out_channels.append(num_filters[block])
self.stages.append(nn.Sequential(*block_list))
def forward(self, inputs):
out = [inputs]
y = self.conv1_1(inputs)
out.append(y)
y = self.pool2d_max(y)
for block in self.stages:
y = block(y)
out.append(y)
return out
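
# Illustration (not part of the commit): the e2e variant above uses five
# residual stages for layers=50 (depth [3, 4, 6, 3, 3]) and returns the
# input, the stem output and every stage output as a feature pyramid.
if __name__ == '__main__':
    model = ResNet(in_channels=3, layers=50)
    feats = model(paddle.rand([1, 3, 512, 512]))
    print([f.shape[1] for f in feats])  # [3, 64, 256, 512, 1024, 2048, 2048]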
...@@ -249,7 +249,7 @@ class ResNet(nn.Layer):
                         name=conv_name))
                 shortcut = True
                 self.block_list.append(bottleneck_block)
-            self.out_channels = num_filters[block]
+            self.out_channels = num_filters[block] * 4
         else:
             for block in range(len(depth)):
                 shortcut = False
......
...@@ -20,6 +20,7 @@ def build_head(config):
     from .det_db_head import DBHead
     from .det_east_head import EASTHead
     from .det_sast_head import SASTHead
+    from .e2e_pg_head import PGHead

     # rec head
     from .rec_ctc_head import CTCHead
...@@ -30,8 +31,8 @@ def build_head(config):
     from .cls_head import ClsHead
     support_dict = [
         'DBHead', 'EASTHead', 'SASTHead', 'CTCHead', 'ClsHead', 'AttentionHead',
-        'SRNHead'
-    ]
+        'SRNHead', 'PGHead']

     module_name = config.pop('name')
     assert module_name in support_dict, Exception('head only support {}'.format(
......
# copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import math
import paddle
from paddle import nn
import paddle.nn.functional as F
from paddle import ParamAttr
class ConvBNLayer(nn.Layer):
def __init__(self,
in_channels,
out_channels,
kernel_size,
stride,
padding,
groups=1,
if_act=True,
act=None,
name=None):
super(ConvBNLayer, self).__init__()
self.if_act = if_act
self.act = act
self.conv = nn.Conv2D(
in_channels=in_channels,
out_channels=out_channels,
kernel_size=kernel_size,
stride=stride,
padding=padding,
groups=groups,
weight_attr=ParamAttr(name=name + '_weights'),
bias_attr=False)
self.bn = nn.BatchNorm(
num_channels=out_channels,
act=act,
param_attr=ParamAttr(name="bn_" + name + "_scale"),
bias_attr=ParamAttr(name="bn_" + name + "_offset"),
moving_mean_name="bn_" + name + "_mean",
moving_variance_name="bn_" + name + "_variance",
use_global_stats=False)
def forward(self, x):
x = self.conv(x)
x = self.bn(x)
return x
class PGHead(nn.Layer):
"""
"""
def __init__(self, in_channels, **kwargs):
super(PGHead, self).__init__()
self.conv_f_score1 = ConvBNLayer(
in_channels=in_channels,
out_channels=64,
kernel_size=1,
stride=1,
padding=0,
act='relu',
name="conv_f_score{}".format(1))
self.conv_f_score2 = ConvBNLayer(
in_channels=64,
out_channels=64,
kernel_size=3,
stride=1,
padding=1,
act='relu',
name="conv_f_score{}".format(2))
self.conv_f_score3 = ConvBNLayer(
in_channels=64,
out_channels=128,
kernel_size=1,
stride=1,
padding=0,
act='relu',
name="conv_f_score{}".format(3))
self.conv1 = nn.Conv2D(
in_channels=128,
out_channels=1,
kernel_size=3,
stride=1,
padding=1,
groups=1,
weight_attr=ParamAttr(name="conv_f_score{}".format(4)),
bias_attr=False)
self.conv_f_boder1 = ConvBNLayer(
in_channels=in_channels,
out_channels=64,
kernel_size=1,
stride=1,
padding=0,
act='relu',
name="conv_f_boder{}".format(1))
self.conv_f_boder2 = ConvBNLayer(
in_channels=64,
out_channels=64,
kernel_size=3,
stride=1,
padding=1,
act='relu',
name="conv_f_boder{}".format(2))
self.conv_f_boder3 = ConvBNLayer(
in_channels=64,
out_channels=128,
kernel_size=1,
stride=1,
padding=0,
act='relu',
name="conv_f_boder{}".format(3))
self.conv2 = nn.Conv2D(
in_channels=128,
out_channels=4,
kernel_size=3,
stride=1,
padding=1,
groups=1,
weight_attr=ParamAttr(name="conv_f_boder{}".format(4)),
bias_attr=False)
self.conv_f_char1 = ConvBNLayer(
in_channels=in_channels,
out_channels=128,
kernel_size=1,
stride=1,
padding=0,
act='relu',
name="conv_f_char{}".format(1))
self.conv_f_char2 = ConvBNLayer(
in_channels=128,
out_channels=128,
kernel_size=3,
stride=1,
padding=1,
act='relu',
name="conv_f_char{}".format(2))
self.conv_f_char3 = ConvBNLayer(
in_channels=128,
out_channels=256,
kernel_size=1,
stride=1,
padding=0,
act='relu',
name="conv_f_char{}".format(3))
self.conv_f_char4 = ConvBNLayer(
in_channels=256,
out_channels=256,
kernel_size=3,
stride=1,
padding=1,
act='relu',
name="conv_f_char{}".format(4))
self.conv_f_char5 = ConvBNLayer(
in_channels=256,
out_channels=256,
kernel_size=1,
stride=1,
padding=0,
act='relu',
name="conv_f_char{}".format(5))
self.conv3 = nn.Conv2D(
in_channels=256,
out_channels=37,
kernel_size=3,
stride=1,
padding=1,
groups=1,
weight_attr=ParamAttr(name="conv_f_char{}".format(6)),
bias_attr=False)
self.conv_f_direc1 = ConvBNLayer(
in_channels=in_channels,
out_channels=64,
kernel_size=1,
stride=1,
padding=0,
act='relu',
name="conv_f_direc{}".format(1))
self.conv_f_direc2 = ConvBNLayer(
in_channels=64,
out_channels=64,
kernel_size=3,
stride=1,
padding=1,
act='relu',
name="conv_f_direc{}".format(2))
self.conv_f_direc3 = ConvBNLayer(
in_channels=64,
out_channels=128,
kernel_size=1,
stride=1,
padding=0,
act='relu',
name="conv_f_direc{}".format(3))
self.conv4 = nn.Conv2D(
in_channels=128,
out_channels=2,
kernel_size=3,
stride=1,
padding=1,
groups=1,
weight_attr=ParamAttr(name="conv_f_direc{}".format(4)),
bias_attr=False)
def forward(self, x):
f_score = self.conv_f_score1(x)
f_score = self.conv_f_score2(f_score)
f_score = self.conv_f_score3(f_score)
f_score = self.conv1(f_score)
f_score = F.sigmoid(f_score)
# f_border
f_border = self.conv_f_boder1(x)
f_border = self.conv_f_boder2(f_border)
f_border = self.conv_f_boder3(f_border)
f_border = self.conv2(f_border)
f_char = self.conv_f_char1(x)
f_char = self.conv_f_char2(f_char)
f_char = self.conv_f_char3(f_char)
f_char = self.conv_f_char4(f_char)
f_char = self.conv_f_char5(f_char)
f_char = self.conv3(f_char)
f_direction = self.conv_f_direc1(x)
f_direction = self.conv_f_direc2(f_direction)
f_direction = self.conv_f_direc3(f_direction)
f_direction = self.conv4(f_direction)
predicts = {}
predicts['f_score'] = f_score
predicts['f_border'] = f_border
predicts['f_char'] = f_char
predicts['f_direction'] = f_direction
return predicts
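# A minimal shape check for PGHead (hedged sketch, not part of the commit).
# The 160x160 feature size assumes a 640x640 input downsampled 4x by the FPN;
# the 37 f_char channels presumably cover a 36-symbol English dict plus blank.
import paddle
from ppocr.modeling.heads.e2e_pg_head import PGHead

head = PGHead(in_channels=128)
preds = head(paddle.rand([1, 128, 160, 160]))
print(preds['f_score'].shape)      # [1, 1, 160, 160]  sigmoid text-center score
print(preds['f_border'].shape)     # [1, 4, 160, 160]  border offsets
print(preds['f_char'].shape)       # [1, 37, 160, 160] per-pixel character logits
print(preds['f_direction'].shape)  # [1, 2, 160, 160]  reading-direction vectors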
@@ -285,8 +285,7 @@ class PrePostProcessLayer(nn.Layer):
             elif cmd == "n":  # add layer normalization
                 self.functors.append(
                     self.add_sublayer(
-                        "layer_norm_%d" % len(
-                            self.sublayers(include_sublayers=False)),
+                        "layer_norm_%d" % len(self.sublayers()),
                         paddle.nn.LayerNorm(
                             normalized_shape=d_model,
                             weight_attr=fluid.ParamAttr(
@@ -320,9 +319,7 @@ class PrepareEncoder(nn.Layer):
         self.src_emb_dim = src_emb_dim
         self.src_max_len = src_max_len
         self.emb = paddle.nn.Embedding(
-            num_embeddings=self.src_max_len,
-            embedding_dim=self.src_emb_dim,
-            sparse=True)
+            num_embeddings=self.src_max_len, embedding_dim=self.src_emb_dim)
         self.dropout_rate = dropout_rate
     def forward(self, src_word, src_pos):
...
@@ -14,12 +14,14 @@
 __all__ = ['build_neck']
 def build_neck(config):
     from .db_fpn import DBFPN
     from .east_fpn import EASTFPN
     from .sast_fpn import SASTFPN
     from .rnn import SequenceEncoder
-    support_dict = ['DBFPN', 'EASTFPN', 'SASTFPN', 'SequenceEncoder']
+    from .pg_fpn import PGFPN
+    support_dict = ['DBFPN', 'EASTFPN', 'SASTFPN', 'SequenceEncoder', 'PGFPN']
     module_name = config.pop('name')
     assert module_name in support_dict, Exception('neck only support {}'.format(
...
# copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import paddle
from paddle import nn
import paddle.nn.functional as F
from paddle import ParamAttr
class ConvBNLayer(nn.Layer):
def __init__(self,
in_channels,
out_channels,
kernel_size,
stride=1,
groups=1,
is_vd_mode=False,
act=None,
name=None):
super(ConvBNLayer, self).__init__()
self.is_vd_mode = is_vd_mode
self._pool2d_avg = nn.AvgPool2D(
kernel_size=2, stride=2, padding=0, ceil_mode=True)
self._conv = nn.Conv2D(
in_channels=in_channels,
out_channels=out_channels,
kernel_size=kernel_size,
stride=stride,
padding=(kernel_size - 1) // 2,
groups=groups,
weight_attr=ParamAttr(name=name + "_weights"),
bias_attr=False)
if name == "conv1":
bn_name = "bn_" + name
else:
bn_name = "bn" + name[3:]
self._batch_norm = nn.BatchNorm(
out_channels,
act=act,
param_attr=ParamAttr(name=bn_name + '_scale'),
bias_attr=ParamAttr(bn_name + '_offset'),
moving_mean_name=bn_name + '_mean',
moving_variance_name=bn_name + '_variance',
use_global_stats=False)
def forward(self, inputs):
y = self._conv(inputs)
y = self._batch_norm(y)
return y
class DeConvBNLayer(nn.Layer):
def __init__(self,
in_channels,
out_channels,
kernel_size=4,
stride=2,
padding=1,
groups=1,
if_act=True,
act=None,
name=None):
super(DeConvBNLayer, self).__init__()
self.if_act = if_act
self.act = act
self.deconv = nn.Conv2DTranspose(
in_channels=in_channels,
out_channels=out_channels,
kernel_size=kernel_size,
stride=stride,
padding=padding,
groups=groups,
weight_attr=ParamAttr(name=name + '_weights'),
bias_attr=False)
self.bn = nn.BatchNorm(
num_channels=out_channels,
act=act,
param_attr=ParamAttr(name="bn_" + name + "_scale"),
bias_attr=ParamAttr(name="bn_" + name + "_offset"),
moving_mean_name="bn_" + name + "_mean",
moving_variance_name="bn_" + name + "_variance",
use_global_stats=False)
def forward(self, x):
x = self.deconv(x)
x = self.bn(x)
return x
class PGFPN(nn.Layer):
def __init__(self, in_channels, **kwargs):
super(PGFPN, self).__init__()
num_inputs = [2048, 2048, 1024, 512, 256]
num_outputs = [256, 256, 192, 192, 128]
self.out_channels = 128
self.conv_bn_layer_1 = ConvBNLayer(
in_channels=3,
out_channels=32,
kernel_size=3,
stride=1,
act=None,
name='FPN_d1')
self.conv_bn_layer_2 = ConvBNLayer(
in_channels=64,
out_channels=64,
kernel_size=3,
stride=1,
act=None,
name='FPN_d2')
self.conv_bn_layer_3 = ConvBNLayer(
in_channels=256,
out_channels=128,
kernel_size=3,
stride=1,
act=None,
name='FPN_d3')
self.conv_bn_layer_4 = ConvBNLayer(
in_channels=32,
out_channels=64,
kernel_size=3,
stride=2,
act=None,
name='FPN_d4')
self.conv_bn_layer_5 = ConvBNLayer(
in_channels=64,
out_channels=64,
kernel_size=3,
stride=1,
act='relu',
name='FPN_d5')
self.conv_bn_layer_6 = ConvBNLayer(
in_channels=64,
out_channels=128,
kernel_size=3,
stride=2,
act=None,
name='FPN_d6')
self.conv_bn_layer_7 = ConvBNLayer(
in_channels=128,
out_channels=128,
kernel_size=3,
stride=1,
act='relu',
name='FPN_d7')
self.conv_bn_layer_8 = ConvBNLayer(
in_channels=128,
out_channels=128,
kernel_size=1,
stride=1,
act=None,
name='FPN_d8')
self.conv_h0 = ConvBNLayer(
in_channels=num_inputs[0],
out_channels=num_outputs[0],
kernel_size=1,
stride=1,
act=None,
name="conv_h{}".format(0))
self.conv_h1 = ConvBNLayer(
in_channels=num_inputs[1],
out_channels=num_outputs[1],
kernel_size=1,
stride=1,
act=None,
name="conv_h{}".format(1))
self.conv_h2 = ConvBNLayer(
in_channels=num_inputs[2],
out_channels=num_outputs[2],
kernel_size=1,
stride=1,
act=None,
name="conv_h{}".format(2))
self.conv_h3 = ConvBNLayer(
in_channels=num_inputs[3],
out_channels=num_outputs[3],
kernel_size=1,
stride=1,
act=None,
name="conv_h{}".format(3))
self.conv_h4 = ConvBNLayer(
in_channels=num_inputs[4],
out_channels=num_outputs[4],
kernel_size=1,
stride=1,
act=None,
name="conv_h{}".format(4))
self.dconv0 = DeConvBNLayer(
in_channels=num_outputs[0],
out_channels=num_outputs[0 + 1],
name="dconv_{}".format(0))
self.dconv1 = DeConvBNLayer(
in_channels=num_outputs[1],
out_channels=num_outputs[1 + 1],
act=None,
name="dconv_{}".format(1))
self.dconv2 = DeConvBNLayer(
in_channels=num_outputs[2],
out_channels=num_outputs[2 + 1],
act=None,
name="dconv_{}".format(2))
self.dconv3 = DeConvBNLayer(
in_channels=num_outputs[3],
out_channels=num_outputs[3 + 1],
act=None,
name="dconv_{}".format(3))
self.conv_g1 = ConvBNLayer(
in_channels=num_outputs[1],
out_channels=num_outputs[1],
kernel_size=3,
stride=1,
act='relu',
name="conv_g{}".format(1))
self.conv_g2 = ConvBNLayer(
in_channels=num_outputs[2],
out_channels=num_outputs[2],
kernel_size=3,
stride=1,
act='relu',
name="conv_g{}".format(2))
self.conv_g3 = ConvBNLayer(
in_channels=num_outputs[3],
out_channels=num_outputs[3],
kernel_size=3,
stride=1,
act='relu',
name="conv_g{}".format(3))
self.conv_g4 = ConvBNLayer(
in_channels=num_outputs[4],
out_channels=num_outputs[4],
kernel_size=3,
stride=1,
act='relu',
name="conv_g{}".format(4))
self.convf = ConvBNLayer(
in_channels=num_outputs[4],
out_channels=num_outputs[4],
kernel_size=1,
stride=1,
act=None,
name="conv_f{}".format(4))
def forward(self, x):
c0, c1, c2, c3, c4, c5, c6 = x
# FPN_Down_Fusion
f = [c0, c1, c2]
g = [None, None, None]
h = [None, None, None]
h[0] = self.conv_bn_layer_1(f[0])
h[1] = self.conv_bn_layer_2(f[1])
h[2] = self.conv_bn_layer_3(f[2])
g[0] = self.conv_bn_layer_4(h[0])
g[1] = paddle.add(g[0], h[1])
g[1] = F.relu(g[1])
g[1] = self.conv_bn_layer_5(g[1])
g[1] = self.conv_bn_layer_6(g[1])
g[2] = paddle.add(g[1], h[2])
g[2] = F.relu(g[2])
g[2] = self.conv_bn_layer_7(g[2])
f_down = self.conv_bn_layer_8(g[2])
# FPN UP Fusion
f1 = [c6, c5, c4, c3, c2]
g = [None, None, None, None, None]
h = [None, None, None, None, None]
h[0] = self.conv_h0(f1[0])
h[1] = self.conv_h1(f1[1])
h[2] = self.conv_h2(f1[2])
h[3] = self.conv_h3(f1[3])
h[4] = self.conv_h4(f1[4])
g[0] = self.dconv0(h[0])
g[1] = paddle.add(g[0], h[1])
g[1] = F.relu(g[1])
g[1] = self.conv_g1(g[1])
g[1] = self.dconv1(g[1])
g[2] = paddle.add(g[1], h[2])
g[2] = F.relu(g[2])
g[2] = self.conv_g2(g[2])
g[2] = self.dconv2(g[2])
g[3] = paddle.add(g[2], h[3])
g[3] = F.relu(g[3])
g[3] = self.conv_g3(g[3])
g[3] = self.dconv3(g[3])
g[4] = paddle.add(x=g[3], y=h[4])
g[4] = F.relu(g[4])
g[4] = self.conv_g4(g[4])
f_up = self.convf(g[4])
f_common = paddle.add(f_down, f_up)
f_common = F.relu(f_common)
return f_common
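# A rough sanity check for PGFPN (hedged sketch; the module path is assumed
# from the build_neck import above). The seven inputs c0..c6 are backbone
# outputs at strides 1, 2, 4, 8, 16, 32 and 64; the channel counts below are
# inferred from the ConvBNLayer definitions above (conv_bn_layer_1 expects 3
# channels for c0, conv_h0 expects 2048 for c6, and so on).
import paddle
from ppocr.modeling.necks.pg_fpn import PGFPN

fpn = PGFPN(in_channels=None)  # in_channels is accepted but unused here
shapes = [(3, 640), (64, 320), (256, 160), (512, 80),
          (1024, 40), (2048, 20), (2048, 10)]
x = [paddle.rand([1, c, s, s]) for c, s in shapes]
f_common = fpn(x)
print(f_common.shape)  # [1, 128, 160, 160]: fused feature at 1/4 resolution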
@@ -28,10 +28,11 @@ def build_post_process(config, global_config=None):
     from .sast_postprocess import SASTPostProcess
     from .rec_postprocess import CTCLabelDecode, AttnLabelDecode, SRNLabelDecode
     from .cls_postprocess import ClsPostProcess
+    from .pg_postprocess import PGPostProcess
     support_dict = [
         'DBPostProcess', 'EASTPostProcess', 'SASTPostProcess', 'CTCLabelDecode',
-        'AttnLabelDecode', 'ClsPostProcess', 'SRNLabelDecode'
+        'AttnLabelDecode', 'ClsPostProcess', 'SRNLabelDecode', 'PGPostProcess'
     ]
     config = copy.deepcopy(config)
...
@@ -34,12 +34,18 @@ class DBPostProcess(object):
                  max_candidates=1000,
                  unclip_ratio=2.0,
                  use_dilation=False,
+                 score_mode="fast",
                  **kwargs):
         self.thresh = thresh
         self.box_thresh = box_thresh
         self.max_candidates = max_candidates
         self.unclip_ratio = unclip_ratio
         self.min_size = 3
+        self.score_mode = score_mode
+        assert score_mode in [
+            "slow", "fast"
+        ], "Score mode must be in [slow, fast] but got: {}".format(score_mode)
         self.dilation_kernel = None if not use_dilation else np.array(
             [[1, 1], [1, 1]])
@@ -69,7 +75,10 @@ class DBPostProcess(object):
             if sside < self.min_size:
                 continue
             points = np.array(points)
-            score = self.box_score_fast(pred, points.reshape(-1, 2))
+            if self.score_mode == "fast":
+                score = self.box_score_fast(pred, points.reshape(-1, 2))
+            else:
+                score = self.box_score_slow(pred, contour)
             if self.box_thresh > score:
                 continue
@@ -120,6 +129,9 @@ class DBPostProcess(object):
         return box, min(bounding_box[1])
     def box_score_fast(self, bitmap, _box):
+        '''
+        box_score_fast: use bbox mean score as the mean score
+        '''
         h, w = bitmap.shape[:2]
         box = _box.copy()
         xmin = np.clip(np.floor(box[:, 0].min()).astype(np.int), 0, w - 1)
@@ -133,6 +145,27 @@ class DBPostProcess(object):
         cv2.fillPoly(mask, box.reshape(1, -1, 2).astype(np.int32), 1)
         return cv2.mean(bitmap[ymin:ymax + 1, xmin:xmax + 1], mask)[0]
def box_score_slow(self, bitmap, contour):
'''
        box_score_slow: use the polygon mean score as the box score
'''
h, w = bitmap.shape[:2]
contour = contour.copy()
contour = np.reshape(contour, (-1, 2))
xmin = np.clip(np.min(contour[:, 0]), 0, w - 1)
xmax = np.clip(np.max(contour[:, 0]), 0, w - 1)
ymin = np.clip(np.min(contour[:, 1]), 0, h - 1)
ymax = np.clip(np.max(contour[:, 1]), 0, h - 1)
mask = np.zeros((ymax - ymin + 1, xmax - xmin + 1), dtype=np.uint8)
contour[:, 0] = contour[:, 0] - xmin
contour[:, 1] = contour[:, 1] - ymin
cv2.fillPoly(mask, contour.reshape(1, -1, 2).astype(np.int32), 1)
return cv2.mean(bitmap[ymin:ymax + 1, xmin:xmax + 1], mask)[0]
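# Quick numeric illustration of the scoring logic above (standalone sketch,
# independent of the class): both box_score_fast and box_score_slow average
# the probability map inside a filled polygon mask.
import cv2
import numpy as np

bitmap = np.zeros((10, 10), dtype=np.float32)
bitmap[2:8, 2:8] = 0.9  # a bright square region of the probability map
box = np.array([[2, 2], [7, 2], [7, 7], [2, 7]])  # quad in (x, y) order

mask = np.zeros((6, 6), dtype=np.uint8)
cv2.fillPoly(mask, (box - [2, 2]).reshape(1, -1, 2).astype(np.int32), 1)
print(cv2.mean(bitmap[2:8, 2:8], mask)[0])  # 0.9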
     def __call__(self, outs_dict, shape_list):
         pred = outs_dict['maps']
         if isinstance(pred, paddle.Tensor):
...
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import sys
__dir__ = os.path.dirname(__file__)
sys.path.append(__dir__)
sys.path.append(os.path.join(__dir__, '..'))
from ppocr.utils.e2e_utils.pgnet_pp_utils import PGNet_PostProcess
class PGPostProcess(object):
"""
The post process for PGNet.
"""
def __init__(self, character_dict_path, valid_set, score_thresh, mode,
**kwargs):
self.character_dict_path = character_dict_path
self.valid_set = valid_set
self.score_thresh = score_thresh
self.mode = mode
        # c++ la-nms is faster, but it only supports Python 3.5
self.is_python35 = False
if sys.version_info.major == 3 and sys.version_info.minor == 5:
self.is_python35 = True
def __call__(self, outs_dict, shape_list):
post = PGNet_PostProcess(self.character_dict_path, self.valid_set,
self.score_thresh, outs_dict, shape_list)
if self.mode == 'fast':
data = post.pg_postprocess_fast()
else:
data = post.pg_postprocess_slow()
return data
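# Hedged construction sketch for the post-process above (assumes the class is
# in scope); the dict path and 'totaltext' valid_set are illustrative values,
# not mandated by this commit.
pg_postprocess = PGPostProcess(
    character_dict_path='ppocr/utils/ic15_dict.txt',  # illustrative path
    valid_set='totaltext',  # illustrative dataset name
    score_thresh=0.5,
    mode='fast')
# results = pg_postprocess(outs_dict, shape_list)  # outs_dict from PGNet forward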
@@ -28,7 +28,7 @@ class BaseRecLabelDecode(object):
             'ch', 'en', 'EN_symbol', 'french', 'german', 'japan', 'korean',
             'it', 'xi', 'pu', 'ru', 'ar', 'ta', 'ug', 'fa', 'ur', 'rs', 'oc',
             'rsc', 'bg', 'uk', 'be', 'te', 'ka', 'chinese_cht', 'hi', 'mr',
-            'ne', 'EN'
+            'ne', 'EN', 'latin', 'arabic', 'cyrillic', 'devanagari'
         ]
         assert character_type in support_character_type, "Only {} are supported now but get {}".format(
             support_character_type, character_type)
@@ -218,6 +218,7 @@ class SRNLabelDecode(BaseRecLabelDecode):
                  **kwargs):
         super(SRNLabelDecode, self).__init__(character_dict_path,
                                              character_type, use_space_char)
+        self.max_text_length = kwargs.get('max_text_length', 25)
     def __call__(self, preds, label=None, *args, **kwargs):
         pred = preds['predict']
@@ -229,9 +230,9 @@ class SRNLabelDecode(BaseRecLabelDecode):
         preds_idx = np.argmax(pred, axis=1)
         preds_prob = np.max(pred, axis=1)
-        preds_idx = np.reshape(preds_idx, [-1, 25])
+        preds_idx = np.reshape(preds_idx, [-1, self.max_text_length])
-        preds_prob = np.reshape(preds_prob, [-1, 25])
+        preds_prob = np.reshape(preds_prob, [-1, self.max_text_length])
         text = self.decode(preds_idx, preds_prob)
...
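# The hunk above replaces the hard-coded sequence length 25 with a
# configurable max_text_length; a tiny numpy illustration of the reshape
# (toy sizes, not tied to any real model output):
import numpy as np

max_text_length = 25  # kwargs.get('max_text_length', 25)
pred = np.random.rand(2 * max_text_length, 38)  # (batch * T, num_classes)
preds_idx = np.reshape(np.argmax(pred, axis=1), [-1, max_text_length])
print(preds_idx.shape)  # (2, 25): one row of character indices per image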
@@ -18,6 +18,7 @@ from __future__ import print_function
 import os
 import sys
 __dir__ = os.path.dirname(__file__)
 sys.path.append(__dir__)
 sys.path.append(os.path.join(__dir__, '..'))
@@ -49,12 +50,12 @@ class SASTPostProcess(object):
         self.shrink_ratio_of_width = shrink_ratio_of_width
         self.expand_scale = expand_scale
         self.tcl_map_thresh = tcl_map_thresh
         # c++ la-nms is faster, but only support python 3.5
         self.is_python35 = False
         if sys.version_info.major == 3 and sys.version_info.minor == 5:
             self.is_python35 = True
     def point_pair2poly(self, point_pair_list):
         """
         Transfer vertical point_pairs into poly point in clockwise.
@@ -66,31 +67,42 @@ class SASTPostProcess(object):
             point_list[idx] = point_pair[0]
             point_list[point_num - 1 - idx] = point_pair[1]
         return np.array(point_list).reshape(-1, 2)
-    def shrink_quad_along_width(self, quad, begin_width_ratio=0., end_width_ratio=1.):
+    def shrink_quad_along_width(self,
+                                quad,
+                                begin_width_ratio=0.,
+                                end_width_ratio=1.):
         """
         Generate shrink_quad_along_width.
         """
-        ratio_pair = np.array([[begin_width_ratio], [end_width_ratio]], dtype=np.float32)
+        ratio_pair = np.array(
+            [[begin_width_ratio], [end_width_ratio]], dtype=np.float32)
         p0_1 = quad[0] + (quad[1] - quad[0]) * ratio_pair
         p3_2 = quad[3] + (quad[2] - quad[3]) * ratio_pair
         return np.array([p0_1[0], p0_1[1], p3_2[1], p3_2[0]])
     def expand_poly_along_width(self, poly, shrink_ratio_of_width=0.3):
         """
         expand poly along width.
         """
         point_num = poly.shape[0]
-        left_quad = np.array([poly[0], poly[1], poly[-2], poly[-1]], dtype=np.float32)
+        left_quad = np.array(
+            [poly[0], poly[1], poly[-2], poly[-1]], dtype=np.float32)
         left_ratio = -shrink_ratio_of_width * np.linalg.norm(left_quad[0] - left_quad[3]) / \
                      (np.linalg.norm(left_quad[0] - left_quad[1]) + 1e-6)
-        left_quad_expand = self.shrink_quad_along_width(left_quad, left_ratio, 1.0)
-        right_quad = np.array([poly[point_num // 2 - 2], poly[point_num // 2 - 1],
-                               poly[point_num // 2], poly[point_num // 2 + 1]], dtype=np.float32)
+        left_quad_expand = self.shrink_quad_along_width(left_quad, left_ratio,
+                                                        1.0)
+        right_quad = np.array(
+            [
+                poly[point_num // 2 - 2], poly[point_num // 2 - 1],
+                poly[point_num // 2], poly[point_num // 2 + 1]
+            ],
+            dtype=np.float32)
         right_ratio = 1.0 + \
                       shrink_ratio_of_width * np.linalg.norm(right_quad[0] - right_quad[3]) / \
                       (np.linalg.norm(right_quad[0] - right_quad[1]) + 1e-6)
-        right_quad_expand = self.shrink_quad_along_width(right_quad, 0.0, right_ratio)
+        right_quad_expand = self.shrink_quad_along_width(right_quad, 0.0,
+                                                         right_ratio)
         poly[0] = left_quad_expand[0]
         poly[-1] = left_quad_expand[-1]
         poly[point_num // 2 - 1] = right_quad_expand[1]
@@ -100,7 +112,7 @@ class SASTPostProcess(object):
     def restore_quad(self, tcl_map, tcl_map_thresh, tvo_map):
         """Restore quad."""
         xy_text = np.argwhere(tcl_map[:, :, 0] > tcl_map_thresh)
-        xy_text = xy_text[:, ::-1] # (n, 2)
+        xy_text = xy_text[:, ::-1]  # (n, 2)
         # Sort the text boxes via the y axis
         xy_text = xy_text[np.argsort(xy_text[:, 1])]
@@ -112,7 +124,7 @@ class SASTPostProcess(object):
         point_num = int(tvo_map.shape[-1] / 2)
         assert point_num == 4
         tvo_map = tvo_map[xy_text[:, 1], xy_text[:, 0], :]
-        xy_text_tile = np.tile(xy_text, (1, point_num)) # (n, point_num * 2)
+        xy_text_tile = np.tile(xy_text, (1, point_num))  # (n, point_num * 2)
         quads = xy_text_tile - tvo_map
         return scores, quads, xy_text
@@ -121,14 +133,12 @@ class SASTPostProcess(object):
         """
         compute area of a quad.
         """
-        edge = [
-            (quad[1][0] - quad[0][0]) * (quad[1][1] + quad[0][1]),
-            (quad[2][0] - quad[1][0]) * (quad[2][1] + quad[1][1]),
-            (quad[3][0] - quad[2][0]) * (quad[3][1] + quad[2][1]),
-            (quad[0][0] - quad[3][0]) * (quad[0][1] + quad[3][1])
-        ]
+        edge = [(quad[1][0] - quad[0][0]) * (quad[1][1] + quad[0][1]),
+                (quad[2][0] - quad[1][0]) * (quad[2][1] + quad[1][1]),
+                (quad[3][0] - quad[2][0]) * (quad[3][1] + quad[2][1]),
+                (quad[0][0] - quad[3][0]) * (quad[0][1] + quad[3][1])]
         return np.sum(edge) / 2.
     def nms(self, dets):
         if self.is_python35:
             import lanms
@@ -141,7 +151,7 @@ class SASTPostProcess(object):
         """
         Cluster pixels in tcl_map based on quads.
         """
-        instance_count = quads.shape[0] + 1 # contain background
+        instance_count = quads.shape[0] + 1  # contain background
         instance_label_map = np.zeros(tcl_map.shape[:2], dtype=np.int32)
         if instance_count == 1:
             return instance_count, instance_label_map
@@ -149,18 +159,19 @@ class SASTPostProcess(object):
         # predict text center
         xy_text = np.argwhere(tcl_map[:, :, 0] > tcl_map_thresh)
         n = xy_text.shape[0]
-        xy_text = xy_text[:, ::-1] # (n, 2)
-        tco = tco_map[xy_text[:, 1], xy_text[:, 0], :] # (n, 2)
+        xy_text = xy_text[:, ::-1]  # (n, 2)
+        tco = tco_map[xy_text[:, 1], xy_text[:, 0], :]  # (n, 2)
         pred_tc = xy_text - tco
         # get gt text center
         m = quads.shape[0]
-        gt_tc = np.mean(quads, axis=1) # (m, 2)
-        pred_tc_tile = np.tile(pred_tc[:, np.newaxis, :], (1, m, 1)) # (n, m, 2)
-        gt_tc_tile = np.tile(gt_tc[np.newaxis, :, :], (n, 1, 1)) # (n, m, 2)
-        dist_mat = np.linalg.norm(pred_tc_tile - gt_tc_tile, axis=2) # (n, m)
-        xy_text_assign = np.argmin(dist_mat, axis=1) + 1 # (n,)
+        gt_tc = np.mean(quads, axis=1)  # (m, 2)
+        pred_tc_tile = np.tile(pred_tc[:, np.newaxis, :],
+                               (1, m, 1))  # (n, m, 2)
+        gt_tc_tile = np.tile(gt_tc[np.newaxis, :, :], (n, 1, 1))  # (n, m, 2)
+        dist_mat = np.linalg.norm(pred_tc_tile - gt_tc_tile, axis=2)  # (n, m)
+        xy_text_assign = np.argmin(dist_mat, axis=1) + 1  # (n,)
         instance_label_map[xy_text[:, 1], xy_text[:, 0]] = xy_text_assign
         return instance_count, instance_label_map
@@ -169,26 +180,47 @@ class SASTPostProcess(object):
         """
         Estimate sample points number.
         """
-        eh = (np.linalg.norm(quad[0] - quad[3]) + np.linalg.norm(quad[1] - quad[2])) / 2.0
-        ew = (np.linalg.norm(quad[0] - quad[1]) + np.linalg.norm(quad[2] - quad[3])) / 2.0
+        eh = (np.linalg.norm(quad[0] - quad[3]) +
+              np.linalg.norm(quad[1] - quad[2])) / 2.0
+        ew = (np.linalg.norm(quad[0] - quad[1]) +
+              np.linalg.norm(quad[2] - quad[3])) / 2.0
         dense_sample_pts_num = max(2, int(ew))
-        dense_xy_center_line = xy_text[np.linspace(0, xy_text.shape[0] - 1, dense_sample_pts_num,
-                                                   endpoint=True, dtype=np.float32).astype(np.int32)]
+        dense_xy_center_line = xy_text[np.linspace(
+            0,
+            xy_text.shape[0] - 1,
+            dense_sample_pts_num,
+            endpoint=True,
+            dtype=np.float32).astype(np.int32)]
-        dense_xy_center_line_diff = dense_xy_center_line[1:] - dense_xy_center_line[:-1]
-        estimate_arc_len = np.sum(np.linalg.norm(dense_xy_center_line_diff, axis=1))
+        dense_xy_center_line_diff = dense_xy_center_line[
+            1:] - dense_xy_center_line[:-1]
+        estimate_arc_len = np.sum(
+            np.linalg.norm(
+                dense_xy_center_line_diff, axis=1))
         sample_pts_num = max(2, int(estimate_arc_len / eh))
         return sample_pts_num
-    def detect_sast(self, tcl_map, tvo_map, tbo_map, tco_map, ratio_w, ratio_h, src_w, src_h,
-                    shrink_ratio_of_width=0.3, tcl_map_thresh=0.5, offset_expand=1.0, out_strid=4.0):
+    def detect_sast(self,
+                    tcl_map,
+                    tvo_map,
+                    tbo_map,
+                    tco_map,
+                    ratio_w,
+                    ratio_h,
+                    src_w,
+                    src_h,
+                    shrink_ratio_of_width=0.3,
+                    tcl_map_thresh=0.5,
+                    offset_expand=1.0,
+                    out_strid=4.0):
         """
         first resize the tcl_map, tvo_map and tbo_map to the input_size, then restore the polys
         """
         # restore quad
-        scores, quads, xy_text = self.restore_quad(tcl_map, tcl_map_thresh, tvo_map)
+        scores, quads, xy_text = self.restore_quad(tcl_map, tcl_map_thresh,
+                                                   tvo_map)
         dets = np.hstack((quads, scores)).astype(np.float32, copy=False)
         dets = self.nms(dets)
         if dets.shape[0] == 0:
@@ -202,7 +234,8 @@ class SASTPostProcess(object):
         # instance segmentation
         # instance_count, instance_label_map = cv2.connectedComponents(tcl_map.astype(np.uint8), connectivity=8)
-        instance_count, instance_label_map = self.cluster_by_quads_tco(tcl_map, tcl_map_thresh, quads, tco_map)
+        instance_count, instance_label_map = self.cluster_by_quads_tco(
+            tcl_map, tcl_map_thresh, quads, tco_map)
         # restore single poly with tcl instance.
         poly_list = []
@@ -212,10 +245,10 @@ class SASTPostProcess(object):
             q_area = quad_areas[instance_idx - 1]
             if q_area < 5:
                 continue
             #
-            len1 = float(np.linalg.norm(quad[0] -quad[1]))
-            len2 = float(np.linalg.norm(quad[1] -quad[2]))
+            len1 = float(np.linalg.norm(quad[0] - quad[1]))
+            len2 = float(np.linalg.norm(quad[1] - quad[2]))
             min_len = min(len1, len2)
             if min_len < 3:
                 continue
@@ -225,16 +258,18 @@ class SASTPostProcess(object):
                 continue
             # filter low confidence instance
             xy_text_scores = tcl_map[xy_text[:, 1], xy_text[:, 0], 0]
             if np.sum(xy_text_scores) / quad_areas[instance_idx - 1] < 0.1:
                 # if np.sum(xy_text_scores) / quad_areas[instance_idx - 1] < 0.05:
                 continue
             # sort xy_text
-            left_center_pt = np.array([[(quad[0, 0] + quad[-1, 0]) / 2.0,
-                                        (quad[0, 1] + quad[-1, 1]) / 2.0]])  # (1, 2)
-            right_center_pt = np.array([[(quad[1, 0] + quad[2, 0]) / 2.0,
-                                         (quad[1, 1] + quad[2, 1]) / 2.0]])  # (1, 2)
+            left_center_pt = np.array(
+                [[(quad[0, 0] + quad[-1, 0]) / 2.0,
+                  (quad[0, 1] + quad[-1, 1]) / 2.0]])  # (1, 2)
+            right_center_pt = np.array(
+                [[(quad[1, 0] + quad[2, 0]) / 2.0,
+                  (quad[1, 1] + quad[2, 1]) / 2.0]])  # (1, 2)
             proj_unit_vec = (right_center_pt - left_center_pt) / \
                             (np.linalg.norm(right_center_pt - left_center_pt) + 1e-6)
             proj_value = np.sum(xy_text * proj_unit_vec, axis=1)
@@ -245,33 +280,45 @@ class SASTPostProcess(object):
                 sample_pts_num = self.estimate_sample_pts_num(quad, xy_text)
             else:
                 sample_pts_num = self.sample_pts_num
-            xy_center_line = xy_text[np.linspace(0, xy_text.shape[0] - 1, sample_pts_num,
-                                                 endpoint=True, dtype=np.float32).astype(np.int32)]
+            xy_center_line = xy_text[np.linspace(
+                0,
+                xy_text.shape[0] - 1,
+                sample_pts_num,
+                endpoint=True,
+                dtype=np.float32).astype(np.int32)]
             point_pair_list = []
             for x, y in xy_center_line:
                 # get corresponding offset
                 offset = tbo_map[y, x, :].reshape(2, 2)
                 if offset_expand != 1.0:
-                    offset_length = np.linalg.norm(offset, axis=1, keepdims=True)
-                    expand_length = np.clip(offset_length * (offset_expand - 1), a_min=0.5, a_max=3.0)
+                    offset_length = np.linalg.norm(
+                        offset, axis=1, keepdims=True)
+                    expand_length = np.clip(
+                        offset_length * (offset_expand - 1),
+                        a_min=0.5,
+                        a_max=3.0)
                     offset_detal = offset / offset_length * expand_length
                     offset = offset + offset_detal
                 # original point
                 ori_yx = np.array([y, x], dtype=np.float32)
-                point_pair = (ori_yx + offset)[:, ::-1]* out_strid / np.array([ratio_w, ratio_h]).reshape(-1, 2)
+                point_pair = (ori_yx + offset)[:, ::-1] * out_strid / np.array(
+                    [ratio_w, ratio_h]).reshape(-1, 2)
                 point_pair_list.append(point_pair)
             # ndarry: (x, 2), expand poly along width
             detected_poly = self.point_pair2poly(point_pair_list)
-            detected_poly = self.expand_poly_along_width(detected_poly, shrink_ratio_of_width)
-            detected_poly[:, 0] = np.clip(detected_poly[:, 0], a_min=0, a_max=src_w)
-            detected_poly[:, 1] = np.clip(detected_poly[:, 1], a_min=0, a_max=src_h)
+            detected_poly = self.expand_poly_along_width(detected_poly,
+                                                         shrink_ratio_of_width)
+            detected_poly[:, 0] = np.clip(
+                detected_poly[:, 0], a_min=0, a_max=src_w)
+            detected_poly[:, 1] = np.clip(
+                detected_poly[:, 1], a_min=0, a_max=src_h)
             poly_list.append(detected_poly)
         return poly_list
     def __call__(self, outs_dict, shape_list):
         score_list = outs_dict['f_score']
         border_list = outs_dict['f_border']
         tvo_list = outs_dict['f_tvo']
@@ -281,20 +328,28 @@ class SASTPostProcess(object):
             border_list = border_list.numpy()
             tvo_list = tvo_list.numpy()
             tco_list = tco_list.numpy()
         img_num = len(shape_list)
         poly_lists = []
         for ino in range(img_num):
-            p_score = score_list[ino].transpose((1,2,0))
-            p_border = border_list[ino].transpose((1,2,0))
-            p_tvo = tvo_list[ino].transpose((1,2,0))
-            p_tco = tco_list[ino].transpose((1,2,0))
+            p_score = score_list[ino].transpose((1, 2, 0))
+            p_border = border_list[ino].transpose((1, 2, 0))
+            p_tvo = tvo_list[ino].transpose((1, 2, 0))
+            p_tco = tco_list[ino].transpose((1, 2, 0))
             src_h, src_w, ratio_h, ratio_w = shape_list[ino]
-            poly_list = self.detect_sast(p_score, p_tvo, p_border, p_tco, ratio_w, ratio_h, src_w, src_h,
-                                         shrink_ratio_of_width=self.shrink_ratio_of_width,
-                                         tcl_map_thresh=self.tcl_map_thresh, offset_expand=self.expand_scale)
+            poly_list = self.detect_sast(
+                p_score,
+                p_tvo,
+                p_border,
+                p_tco,
+                ratio_w,
+                ratio_h,
+                src_w,
+                src_h,
+                shrink_ratio_of_width=self.shrink_ratio_of_width,
+                tcl_map_thresh=self.tcl_map_thresh,
+                offset_expand=self.expand_scale)
             poly_lists.append({'points': np.array(poly_list)})
         return poly_lists
!
#
$
%
&
'
(
+
,
-
.
/
0
1
2
3
4
5
6
7
8
9
:
?
@
A
B
C
D
E
F
G
H
I
J
K
L
M
N
O
P
Q
R
S
T
U
V
W
X
Y
Z
_
a
b
c
d
e
f
g
h
i
j
k
l
m
n
o
p
q
r
s
t
u
v
w
x
y
z
É
é
ء
آ
أ
ؤ
إ
ئ
ا
ب
ة
ت
ث
ج
ح
خ
د
ذ
ر
ز
س
ش
ص
ض
ط
ظ
ع
غ
ف
ق
ك
ل
م
ن
ه
و
ى
ي
ً
ٌ
ٍ
َ
ُ
ِ
ّ
ْ
ٓ
ٔ
ٰ
ٱ
ٹ
پ
چ
ڈ
ڑ
ژ
ک
ڭ
گ
ں
ھ
ۀ
ہ
ۂ
ۃ
ۆ
ۇ
ۈ
ۋ
ی
ې
ے
ۓ
ە
١
٢
٣
٤
٥
٦
٧
٨
٩
!
#
$
%
&
'
(
+
,
-
.
/
0
1
2
3
4
5
6
7
8
9
:
?
@
A
B
C
D
E
F
G
H
I
J
K
L
M
N
O
P
Q
R
S
T
U
V
W
X
Y
Z
_
a
b
c
d
e
f
g
h
i
j
k
l
m
n
o
p
q
r
s
t
u
v
w
x
y
z
É
é
Ё
Є
І
Ј
Љ
Ў
А
Б
В
Г
Д
Е
Ж
З
И
Й
К
Л
М
Н
О
П
Р
С
Т
У
Ф
Х
Ц
Ч
Ш
Щ
Ъ
Ы
Ь
Э
Ю
Я
а
б
в
г
д
е
ж
з
и
й
к
л
м
н
о
п
р
с
т
у
ф
х
ц
ч
ш
щ
ъ
ы
ь
э
ю
я
ё
ђ
є
і
ј
љ
њ
ћ
ў
џ
Ґ
ґ
!
#
$
%
&
'
(
+
,
-
.
/
0
1
2
3
4
5
6
7
8
9
:
?
@
A
B
C
D
E
F
G
H
I
J
K
L
M
N
O
P
Q
R
S
T
U
V
W
X
Y
Z
_
a
b
c
d
e
f
g
h
i
j
k
l
m
n
o
p
q
r
s
t
u
v
w
x
y
z
É
é
ि
!
"
#
$
%
&
'
(
)
*
+
,
-
.
/
0
1
2
3
4
5
6
7
8
9
:
;
<
=
>
?
@
A
B
C
D
E
F
G
H
I
J
K
L
M
N
O
P
Q
R
S
T
U
V
W
X
Y
Z
[
]
_
`
a
b
c
d
e
f
g
h
i
j
k
l
m
n
o
p
q
r
s
t
u
v
w
x
y
z
{
}
¡
£
§
ª
«
­
°
²
³
´
µ
·
º
»
¿
À
Á
Â
Ä
Å
Ç
È
É
Ê
Ë
Ì
Í
Î
Ï
Ò
Ó
Ô
Õ
Ö
Ú
Ü
Ý
ß
à
á
â
ã
ä
å
æ
ç
è
é
ê
ë
ì
í
î
ï
ñ
ò
ó
ô
õ
ö
ø
ù
ú
û
ü
ý
ą
Ć
ć
Č
č
Đ
đ
ę
ı
Ł
ł
ō
Œ
œ
Š
š
Ÿ
Ž
ž
ʒ
β
δ
ε
з
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
import scipy.io as io
from ppocr.utils.e2e_metric.polygon_fast import iod, area_of_intersection, area
def get_socre_A(gt_dir, pred_dict):
allInputs = 1
def input_reading_mod(pred_dict):
"""This helper reads input from txt files"""
det = []
n = len(pred_dict)
for i in range(n):
points = pred_dict[i]['points']
text = pred_dict[i]['texts']
point = ",".join(map(str, points.reshape(-1, )))
det.append([point, text])
return det
def gt_reading_mod(gt_dict):
"""This helper reads groundtruths from mat files"""
gt = []
n = len(gt_dict)
for i in range(n):
points = gt_dict[i]['points'].tolist()
h = len(points)
text = gt_dict[i]['text']
xx = [
np.array(
['x:'], dtype='<U2'), 0, np.array(
['y:'], dtype='<U2'), 0, np.array(
['#'], dtype='<U1'), np.array(
['#'], dtype='<U1')
]
t_x, t_y = [], []
for j in range(h):
t_x.append(points[j][0])
t_y.append(points[j][1])
xx[1] = np.array([t_x], dtype='int16')
xx[3] = np.array([t_y], dtype='int16')
if text != "":
xx[4] = np.array([text], dtype='U{}'.format(len(text)))
xx[5] = np.array(['c'], dtype='<U1')
gt.append(xx)
return gt
def detection_filtering(detections, groundtruths, threshold=0.5):
for gt_id, gt in enumerate(groundtruths):
if (gt[5] == '#') and (gt[1].shape[1] > 1):
gt_x = list(map(int, np.squeeze(gt[1])))
gt_y = list(map(int, np.squeeze(gt[3])))
for det_id, detection in enumerate(detections):
detection_orig = detection
detection = [float(x) for x in detection[0].split(',')]
detection = list(map(int, detection))
det_x = detection[0::2]
det_y = detection[1::2]
det_gt_iou = iod(det_x, det_y, gt_x, gt_y)
if det_gt_iou > threshold:
detections[det_id] = []
detections[:] = [item for item in detections if item != []]
return detections
def sigma_calculation(det_x, det_y, gt_x, gt_y):
"""
sigma = inter_area / gt_area
"""
return np.round((area_of_intersection(det_x, det_y, gt_x, gt_y) /
area(gt_x, gt_y)), 2)
def tau_calculation(det_x, det_y, gt_x, gt_y):
if area(det_x, det_y) == 0.0:
return 0
return np.round((area_of_intersection(det_x, det_y, gt_x, gt_y) /
area(det_x, det_y)), 2)
##############################Initialization###################################
# global_sigma = []
# global_tau = []
# global_pred_str = []
# global_gt_str = []
###############################################################################
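    # Note: the loop below appears to be vestigial from the original file-based
    # Deteval script; with allInputs == 1 and an integer input_id, the filename
    # filters always evaluate to True.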
for input_id in range(allInputs):
if (input_id != '.DS_Store') and (input_id != 'Pascal_result.txt') and (
input_id != 'Pascal_result_curved.txt') and (input_id != 'Pascal_result_non_curved.txt') and (
input_id != 'Deteval_result.txt') and (input_id != 'Deteval_result_curved.txt') \
and (input_id != 'Deteval_result_non_curved.txt'):
detections = input_reading_mod(pred_dict)
groundtruths = gt_reading_mod(gt_dir)
detections = detection_filtering(
detections,
groundtruths) # filters detections overlapping with DC area
dc_id = []
for i in range(len(groundtruths)):
if groundtruths[i][5] == '#':
dc_id.append(i)
cnt = 0
for a in dc_id:
num = a - cnt
del groundtruths[num]
cnt += 1
local_sigma_table = np.zeros((len(groundtruths), len(detections)))
local_tau_table = np.zeros((len(groundtruths), len(detections)))
local_pred_str = {}
local_gt_str = {}
for gt_id, gt in enumerate(groundtruths):
if len(detections) > 0:
for det_id, detection in enumerate(detections):
detection_orig = detection
detection = [float(x) for x in detection[0].split(',')]
detection = list(map(int, detection))
pred_seq_str = detection_orig[1].strip()
det_x = detection[0::2]
det_y = detection[1::2]
gt_x = list(map(int, np.squeeze(gt[1])))
gt_y = list(map(int, np.squeeze(gt[3])))
gt_seq_str = str(gt[4].tolist()[0])
local_sigma_table[gt_id, det_id] = sigma_calculation(
det_x, det_y, gt_x, gt_y)
local_tau_table[gt_id, det_id] = tau_calculation(
det_x, det_y, gt_x, gt_y)
local_pred_str[det_id] = pred_seq_str
local_gt_str[gt_id] = gt_seq_str
global_sigma = local_sigma_table
global_tau = local_tau_table
global_pred_str = local_pred_str
global_gt_str = local_gt_str
single_data = {}
single_data['sigma'] = global_sigma
single_data['global_tau'] = global_tau
single_data['global_pred_str'] = global_pred_str
single_data['global_gt_str'] = global_gt_str
return single_data
def get_socre_B(gt_dir, img_id, pred_dict):
allInputs = 1
def input_reading_mod(pred_dict):
"""This helper reads input from txt files"""
det = []
n = len(pred_dict)
for i in range(n):
points = pred_dict[i]['points']
text = pred_dict[i]['texts']
point = ",".join(map(str, points.reshape(-1, )))
det.append([point, text])
return det
def gt_reading_mod(gt_dir, gt_id):
gt = io.loadmat('%s/poly_gt_img%s.mat' % (gt_dir, gt_id))
gt = gt['polygt']
return gt
def detection_filtering(detections, groundtruths, threshold=0.5):
for gt_id, gt in enumerate(groundtruths):
if (gt[5] == '#') and (gt[1].shape[1] > 1):
gt_x = list(map(int, np.squeeze(gt[1])))
gt_y = list(map(int, np.squeeze(gt[3])))
for det_id, detection in enumerate(detections):
detection_orig = detection
detection = [float(x) for x in detection[0].split(',')]
detection = list(map(int, detection))
det_x = detection[0::2]
det_y = detection[1::2]
det_gt_iou = iod(det_x, det_y, gt_x, gt_y)
if det_gt_iou > threshold:
detections[det_id] = []
detections[:] = [item for item in detections if item != []]
return detections
def sigma_calculation(det_x, det_y, gt_x, gt_y):
"""
sigma = inter_area / gt_area
"""
return np.round((area_of_intersection(det_x, det_y, gt_x, gt_y) /
area(gt_x, gt_y)), 2)
def tau_calculation(det_x, det_y, gt_x, gt_y):
if area(det_x, det_y) == 0.0:
return 0
return np.round((area_of_intersection(det_x, det_y, gt_x, gt_y) /
area(det_x, det_y)), 2)
##############################Initialization###################################
# global_sigma = []
# global_tau = []
# global_pred_str = []
# global_gt_str = []
###############################################################################
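    # Note: as in get_socre_A, this loop appears to be vestigial from the
    # original file-based Deteval script; the filename filters never fire for
    # an integer input_id.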
for input_id in range(allInputs):
if (input_id != '.DS_Store') and (input_id != 'Pascal_result.txt') and (
input_id != 'Pascal_result_curved.txt') and (input_id != 'Pascal_result_non_curved.txt') and (
input_id != 'Deteval_result.txt') and (input_id != 'Deteval_result_curved.txt') \
and (input_id != 'Deteval_result_non_curved.txt'):
detections = input_reading_mod(pred_dict)
groundtruths = gt_reading_mod(gt_dir, img_id).tolist()
detections = detection_filtering(
detections,
groundtruths) # filters detections overlapping with DC area
dc_id = []
for i in range(len(groundtruths)):
if groundtruths[i][5] == '#':
dc_id.append(i)
cnt = 0
for a in dc_id:
num = a - cnt
del groundtruths[num]
cnt += 1
local_sigma_table = np.zeros((len(groundtruths), len(detections)))
local_tau_table = np.zeros((len(groundtruths), len(detections)))
local_pred_str = {}
local_gt_str = {}
for gt_id, gt in enumerate(groundtruths):
if len(detections) > 0:
for det_id, detection in enumerate(detections):
detection_orig = detection
detection = [float(x) for x in detection[0].split(',')]
detection = list(map(int, detection))
pred_seq_str = detection_orig[1].strip()
det_x = detection[0::2]
det_y = detection[1::2]
gt_x = list(map(int, np.squeeze(gt[1])))
gt_y = list(map(int, np.squeeze(gt[3])))
gt_seq_str = str(gt[4].tolist()[0])
local_sigma_table[gt_id, det_id] = sigma_calculation(
det_x, det_y, gt_x, gt_y)
local_tau_table[gt_id, det_id] = tau_calculation(
det_x, det_y, gt_x, gt_y)
local_pred_str[det_id] = pred_seq_str
local_gt_str[gt_id] = gt_seq_str
global_sigma = local_sigma_table
global_tau = local_tau_table
global_pred_str = local_pred_str
global_gt_str = local_gt_str
single_data = {}
single_data['sigma'] = global_sigma
single_data['global_tau'] = global_tau
single_data['global_pred_str'] = global_pred_str
single_data['global_gt_str'] = global_gt_str
return single_data
def combine_results(all_data):
tr = 0.7
tp = 0.6
fsc_k = 0.8
k = 2
global_sigma = []
global_tau = []
global_pred_str = []
global_gt_str = []
for data in all_data:
global_sigma.append(data['sigma'])
global_tau.append(data['global_tau'])
global_pred_str.append(data['global_pred_str'])
global_gt_str.append(data['global_gt_str'])
global_accumulative_recall = 0
global_accumulative_precision = 0
total_num_gt = 0
total_num_det = 0
hit_str_count = 0
hit_count = 0
def one_to_one(local_sigma_table, local_tau_table,
local_accumulative_recall, local_accumulative_precision,
global_accumulative_recall, global_accumulative_precision,
gt_flag, det_flag, idy):
hit_str_num = 0
for gt_id in range(num_gt):
gt_matching_qualified_sigma_candidates = np.where(
local_sigma_table[gt_id, :] > tr)
gt_matching_num_qualified_sigma_candidates = gt_matching_qualified_sigma_candidates[
0].shape[0]
gt_matching_qualified_tau_candidates = np.where(
local_tau_table[gt_id, :] > tp)
gt_matching_num_qualified_tau_candidates = gt_matching_qualified_tau_candidates[
0].shape[0]
det_matching_qualified_sigma_candidates = np.where(
local_sigma_table[:, gt_matching_qualified_sigma_candidates[0]]
> tr)
det_matching_num_qualified_sigma_candidates = det_matching_qualified_sigma_candidates[
0].shape[0]
det_matching_qualified_tau_candidates = np.where(
local_tau_table[:, gt_matching_qualified_tau_candidates[0]] >
tp)
det_matching_num_qualified_tau_candidates = det_matching_qualified_tau_candidates[
0].shape[0]
if (gt_matching_num_qualified_sigma_candidates == 1) and (gt_matching_num_qualified_tau_candidates == 1) and \
(det_matching_num_qualified_sigma_candidates == 1) and (
det_matching_num_qualified_tau_candidates == 1):
global_accumulative_recall = global_accumulative_recall + 1.0
global_accumulative_precision = global_accumulative_precision + 1.0
local_accumulative_recall = local_accumulative_recall + 1.0
local_accumulative_precision = local_accumulative_precision + 1.0
gt_flag[0, gt_id] = 1
matched_det_id = np.where(local_sigma_table[gt_id, :] > tr)
# recg start
gt_str_cur = global_gt_str[idy][gt_id]
pred_str_cur = global_pred_str[idy][matched_det_id[0].tolist()[
0]]
if pred_str_cur == gt_str_cur:
hit_str_num += 1
else:
if pred_str_cur.lower() == gt_str_cur.lower():
hit_str_num += 1
# recg end
det_flag[0, matched_det_id] = 1
return local_accumulative_recall, local_accumulative_precision, global_accumulative_recall, global_accumulative_precision, gt_flag, det_flag, hit_str_num
def one_to_many(local_sigma_table, local_tau_table,
local_accumulative_recall, local_accumulative_precision,
global_accumulative_recall, global_accumulative_precision,
gt_flag, det_flag, idy):
hit_str_num = 0
for gt_id in range(num_gt):
# skip the following if the groundtruth was matched
if gt_flag[0, gt_id] > 0:
continue
non_zero_in_sigma = np.where(local_sigma_table[gt_id, :] > 0)
num_non_zero_in_sigma = non_zero_in_sigma[0].shape[0]
if num_non_zero_in_sigma >= k:
                #### search for all detections that overlap with this groundtruth
qualified_tau_candidates = np.where((local_tau_table[
gt_id, :] >= tp) & (det_flag[0, :] == 0))
num_qualified_tau_candidates = qualified_tau_candidates[
0].shape[0]
if num_qualified_tau_candidates == 1:
if ((local_tau_table[gt_id, qualified_tau_candidates] >= tp)
and
(local_sigma_table[gt_id, qualified_tau_candidates] >=
tr)):
                        # becomes a one-to-one case
global_accumulative_recall = global_accumulative_recall + 1.0
global_accumulative_precision = global_accumulative_precision + 1.0
local_accumulative_recall = local_accumulative_recall + 1.0
local_accumulative_precision = local_accumulative_precision + 1.0
gt_flag[0, gt_id] = 1
det_flag[0, qualified_tau_candidates] = 1
# recg start
gt_str_cur = global_gt_str[idy][gt_id]
pred_str_cur = global_pred_str[idy][
qualified_tau_candidates[0].tolist()[0]]
if pred_str_cur == gt_str_cur:
hit_str_num += 1
else:
if pred_str_cur.lower() == gt_str_cur.lower():
hit_str_num += 1
# recg end
elif (np.sum(local_sigma_table[gt_id, qualified_tau_candidates])
>= tr):
gt_flag[0, gt_id] = 1
det_flag[0, qualified_tau_candidates] = 1
# recg start
gt_str_cur = global_gt_str[idy][gt_id]
pred_str_cur = global_pred_str[idy][
qualified_tau_candidates[0].tolist()[0]]
if pred_str_cur == gt_str_cur:
hit_str_num += 1
else:
if pred_str_cur.lower() == gt_str_cur.lower():
hit_str_num += 1
# recg end
global_accumulative_recall = global_accumulative_recall + fsc_k
global_accumulative_precision = global_accumulative_precision + num_qualified_tau_candidates * fsc_k
local_accumulative_recall = local_accumulative_recall + fsc_k
local_accumulative_precision = local_accumulative_precision + num_qualified_tau_candidates * fsc_k
return local_accumulative_recall, local_accumulative_precision, global_accumulative_recall, global_accumulative_precision, gt_flag, det_flag, hit_str_num
def many_to_one(local_sigma_table, local_tau_table,
local_accumulative_recall, local_accumulative_precision,
global_accumulative_recall, global_accumulative_precision,
gt_flag, det_flag, idy):
hit_str_num = 0
for det_id in range(num_det):
# skip the following if the detection was matched
if det_flag[0, det_id] > 0:
continue
non_zero_in_tau = np.where(local_tau_table[:, det_id] > 0)
num_non_zero_in_tau = non_zero_in_tau[0].shape[0]
if num_non_zero_in_tau >= k:
                #### search for all groundtruths that overlap with this detection
qualified_sigma_candidates = np.where((
local_sigma_table[:, det_id] >= tp) & (gt_flag[0, :] == 0))
num_qualified_sigma_candidates = qualified_sigma_candidates[
0].shape[0]
if num_qualified_sigma_candidates == 1:
if ((local_tau_table[qualified_sigma_candidates, det_id] >=
tp) and
(local_sigma_table[qualified_sigma_candidates, det_id]
>= tr)):
                        # becomes a one-to-one case
global_accumulative_recall = global_accumulative_recall + 1.0
global_accumulative_precision = global_accumulative_precision + 1.0
local_accumulative_recall = local_accumulative_recall + 1.0
local_accumulative_precision = local_accumulative_precision + 1.0
gt_flag[0, qualified_sigma_candidates] = 1
det_flag[0, det_id] = 1
# recg start
pred_str_cur = global_pred_str[idy][det_id]
gt_len = len(qualified_sigma_candidates[0])
for idx in range(gt_len):
ele_gt_id = qualified_sigma_candidates[0].tolist()[
idx]
if ele_gt_id not in global_gt_str[idy]:
continue
gt_str_cur = global_gt_str[idy][ele_gt_id]
if pred_str_cur == gt_str_cur:
hit_str_num += 1
break
else:
if pred_str_cur.lower() == gt_str_cur.lower():
hit_str_num += 1
break
# recg end
elif (np.sum(local_tau_table[qualified_sigma_candidates,
det_id]) >= tp):
det_flag[0, det_id] = 1
gt_flag[0, qualified_sigma_candidates] = 1
# recg start
pred_str_cur = global_pred_str[idy][det_id]
gt_len = len(qualified_sigma_candidates[0])
for idx in range(gt_len):
ele_gt_id = qualified_sigma_candidates[0].tolist()[idx]
if ele_gt_id not in global_gt_str[idy]:
continue
gt_str_cur = global_gt_str[idy][ele_gt_id]
if pred_str_cur == gt_str_cur:
hit_str_num += 1
break
else:
if pred_str_cur.lower() == gt_str_cur.lower():
hit_str_num += 1
break
# recg end
global_accumulative_recall = global_accumulative_recall + num_qualified_sigma_candidates * fsc_k
global_accumulative_precision = global_accumulative_precision + fsc_k
local_accumulative_recall = local_accumulative_recall + num_qualified_sigma_candidates * fsc_k
local_accumulative_precision = local_accumulative_precision + fsc_k
return local_accumulative_recall, local_accumulative_precision, global_accumulative_recall, global_accumulative_precision, gt_flag, det_flag, hit_str_num
for idx in range(len(global_sigma)):
local_sigma_table = np.array(global_sigma[idx])
local_tau_table = global_tau[idx]
num_gt = local_sigma_table.shape[0]
num_det = local_sigma_table.shape[1]
total_num_gt = total_num_gt + num_gt
total_num_det = total_num_det + num_det
local_accumulative_recall = 0
local_accumulative_precision = 0
gt_flag = np.zeros((1, num_gt))
det_flag = np.zeros((1, num_det))
#######first check for one-to-one case##########
local_accumulative_recall, local_accumulative_precision, global_accumulative_recall, global_accumulative_precision, \
gt_flag, det_flag, hit_str_num = one_to_one(local_sigma_table, local_tau_table,
local_accumulative_recall, local_accumulative_precision,
global_accumulative_recall, global_accumulative_precision,
gt_flag, det_flag, idx)
hit_str_count += hit_str_num
#######then check for one-to-many case##########
local_accumulative_recall, local_accumulative_precision, global_accumulative_recall, global_accumulative_precision, \
gt_flag, det_flag, hit_str_num = one_to_many(local_sigma_table, local_tau_table,
local_accumulative_recall, local_accumulative_precision,
global_accumulative_recall, global_accumulative_precision,
gt_flag, det_flag, idx)
hit_str_count += hit_str_num
#######then check for many-to-one case##########
local_accumulative_recall, local_accumulative_precision, global_accumulative_recall, global_accumulative_precision, \
gt_flag, det_flag, hit_str_num = many_to_one(local_sigma_table, local_tau_table,
local_accumulative_recall, local_accumulative_precision,
global_accumulative_recall, global_accumulative_precision,
gt_flag, det_flag, idx)
hit_str_count += hit_str_num
try:
recall = global_accumulative_recall / total_num_gt
except ZeroDivisionError:
recall = 0
try:
precision = global_accumulative_precision / total_num_det
except ZeroDivisionError:
precision = 0
try:
f_score = 2 * precision * recall / (precision + recall)
except ZeroDivisionError:
f_score = 0
try:
seqerr = 1 - float(hit_str_count) / global_accumulative_recall
except ZeroDivisionError:
seqerr = 1
try:
recall_e2e = float(hit_str_count) / total_num_gt
except ZeroDivisionError:
recall_e2e = 0
try:
precision_e2e = float(hit_str_count) / total_num_det
except ZeroDivisionError:
precision_e2e = 0
try:
f_score_e2e = 2 * precision_e2e * recall_e2e / (
precision_e2e + recall_e2e)
except ZeroDivisionError:
f_score_e2e = 0
final = {
'total_num_gt': total_num_gt,
'total_num_det': total_num_det,
'global_accumulative_recall': global_accumulative_recall,
'hit_str_count': hit_str_count,
'recall': recall,
'precision': precision,
'f_score': f_score,
'seqerr': seqerr,
'recall_e2e': recall_e2e,
'precision_e2e': precision_e2e,
'f_score_e2e': f_score_e2e
}
return final
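# Recap of the metrics assembled above (inferred from the code, for reference):
#   recall        = global_accumulative_recall / total_num_gt
#   precision     = global_accumulative_precision / total_num_det
#   f_score       = harmonic mean of precision and recall
#   seqerr        = 1 - hit_str_count / global_accumulative_recall
#   recall_e2e / precision_e2e / f_score_e2e replace box matches with
#   hit_str_count, i.e. detections whose recognized string matches the GT.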
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
from shapely.geometry import Polygon
"""
:param det_x: [1, N] Xs of detection's vertices
:param det_y: [1, N] Ys of detection's vertices
:param gt_x: [1, N] Xs of groundtruth's vertices
:param gt_y: [1, N] Ys of groundtruth's vertices
##############
All the calculation of 'AREA' in this script is handled by:
1) First generating a binary mask with the polygon area filled up with 1's
2) Summing up all the 1's
"""
def area(x, y):
polygon = Polygon(np.stack([x, y], axis=1))
return float(polygon.area)
def approx_area_of_intersection(det_x, det_y, gt_x, gt_y):
"""
This helper determine if both polygons are intersecting with each others with an approximation method.
Area of intersection represented by the minimum bounding rectangular [xmin, ymin, xmax, ymax]
"""
det_ymax = np.max(det_y)
det_xmax = np.max(det_x)
det_ymin = np.min(det_y)
det_xmin = np.min(det_x)
gt_ymax = np.max(gt_y)
gt_xmax = np.max(gt_x)
gt_ymin = np.min(gt_y)
gt_xmin = np.min(gt_x)
all_min_ymax = np.minimum(det_ymax, gt_ymax)
all_max_ymin = np.maximum(det_ymin, gt_ymin)
intersect_heights = np.maximum(0.0, (all_min_ymax - all_max_ymin))
all_min_xmax = np.minimum(det_xmax, gt_xmax)
all_max_xmin = np.maximum(det_xmin, gt_xmin)
intersect_widths = np.maximum(0.0, (all_min_xmax - all_max_xmin))
return intersect_heights * intersect_widths
def area_of_intersection(det_x, det_y, gt_x, gt_y):
p1 = Polygon(np.stack([det_x, det_y], axis=1)).buffer(0)
p2 = Polygon(np.stack([gt_x, gt_y], axis=1)).buffer(0)
return float(p1.intersection(p2).area)
def area_of_union(det_x, det_y, gt_x, gt_y):
p1 = Polygon(np.stack([det_x, det_y], axis=1)).buffer(0)
p2 = Polygon(np.stack([gt_x, gt_y], axis=1)).buffer(0)
return float(p1.union(p2).area)
def iou(det_x, det_y, gt_x, gt_y):
return area_of_intersection(det_x, det_y, gt_x, gt_y) / (
area_of_union(det_x, det_y, gt_x, gt_y) + 1.0)
def iod(det_x, det_y, gt_x, gt_y):
"""
This helper determine the fraction of intersection area over detection area
"""
return area_of_intersection(det_x, det_y, gt_x, gt_y) / (
area(det_x, det_y) + 1.0)
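# Illustrative example (not part of the original file): two unit squares that
# overlap by half. Note the "+ 1.0" smoothing term in iou()/iod() guards
# against an empty union/detection but biases small-polygon scores low:
#   det_x, det_y = [0, 1, 1, 0], [0, 0, 1, 1]
#   gt_x,  gt_y  = [0.5, 1.5, 1.5, 0.5], [0, 0, 1, 1]
#   area_of_intersection(...) == 0.5, area_of_union(...) == 1.5
#   iou(...) == 0.5 / (1.5 + 1.0) == 0.2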
import paddle
import numpy as np
import copy
def org_tcl_rois(batch_size, pos_lists, pos_masks, label_lists, tcl_bs):
"""
"""
pos_lists_, pos_masks_, label_lists_ = [], [], []
img_bs = batch_size
ngpu = int(batch_size / img_bs)
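    # img_bs equals batch_size here, so ngpu always evaluates to 1; the
    # per-GPU split/merge below is kept for the general multi-card layout.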
img_ids = np.array(pos_lists, dtype=np.int32)[:, 0, 0].copy()
pos_lists_split, pos_masks_split, label_lists_split = [], [], []
for i in range(ngpu):
pos_lists_split.append([])
pos_masks_split.append([])
label_lists_split.append([])
for i in range(img_ids.shape[0]):
img_id = img_ids[i]
gpu_id = int(img_id / img_bs)
img_id = img_id % img_bs
pos_list = pos_lists[i].copy()
pos_list[:, 0] = img_id
pos_lists_split[gpu_id].append(pos_list)
pos_masks_split[gpu_id].append(pos_masks[i].copy())
label_lists_split[gpu_id].append(copy.deepcopy(label_lists[i]))
# repeat or delete
for i in range(ngpu):
vp_len = len(pos_lists_split[i])
if vp_len <= tcl_bs:
for j in range(0, tcl_bs - vp_len):
pos_list = pos_lists_split[i][j].copy()
pos_lists_split[i].append(pos_list)
pos_mask = pos_masks_split[i][j].copy()
pos_masks_split[i].append(pos_mask)
label_list = copy.deepcopy(label_lists_split[i][j])
label_lists_split[i].append(label_list)
else:
for j in range(0, vp_len - tcl_bs):
c_len = len(pos_lists_split[i])
pop_id = np.random.permutation(c_len)[0]
pos_lists_split[i].pop(pop_id)
pos_masks_split[i].pop(pop_id)
label_lists_split[i].pop(pop_id)
# merge
for i in range(ngpu):
pos_lists_.extend(pos_lists_split[i])
pos_masks_.extend(pos_masks_split[i])
label_lists_.extend(label_lists_split[i])
return pos_lists_, pos_masks_, label_lists_
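# Behavior sketch (illustrative): with tcl_bs=4 and 2 valid RoIs in a group,
# the first entries are repeated until the group holds exactly 4 samples;
# with 6 valid RoIs, 2 randomly chosen entries are dropped instead.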
def pre_process(label_list, pos_list, pos_mask, max_text_length, max_text_nums,
pad_num, tcl_bs):
label_list = label_list.numpy()
batch, _, _, _ = label_list.shape
pos_list = pos_list.numpy()
pos_mask = pos_mask.numpy()
pos_list_t = []
pos_mask_t = []
label_list_t = []
for i in range(batch):
for j in range(max_text_nums):
if pos_mask[i, j].any():
pos_list_t.append(pos_list[i][j])
pos_mask_t.append(pos_mask[i][j])
label_list_t.append(label_list[i][j])
pos_list, pos_mask, label_list = org_tcl_rois(batch, pos_list_t, pos_mask_t,
label_list_t, tcl_bs)
label = []
tt = [l.tolist() for l in label_list]
for i in range(tcl_bs):
k = 0
for j in range(max_text_length):
if tt[i][j][0] != pad_num:
k += 1
else:
break
label.append(k)
label = paddle.to_tensor(label)
label = paddle.cast(label, dtype='int64')
pos_list = paddle.to_tensor(pos_list)
pos_mask = paddle.to_tensor(pos_mask)
label_list = paddle.squeeze(paddle.to_tensor(label_list), axis=2)
label_list = paddle.cast(label_list, dtype='int32')
return pos_list, pos_mask, label_list, label
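# Note (inferred from the code above): pre_process flattens the per-image
# text instances, rebalances them to exactly tcl_bs samples via org_tcl_rois,
# and returns paddle tensors plus `label`, the effective length of each label
# sequence (number of leading entries before the first pad_num).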
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Contains various CTC decoders."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import cv2
import math
import numpy as np
from itertools import groupby
from skimage.morphology._skeletonize import thin
def get_dict(character_dict_path):
character_str = ""
with open(character_dict_path, "rb") as fin:
lines = fin.readlines()
for line in lines:
line = line.decode('utf-8').strip("\n").strip("\r\n")
character_str += line
dict_character = list(character_str)
return dict_character
def softmax(logits):
"""
logits: N x d
"""
max_value = np.max(logits, axis=1, keepdims=True)
exp = np.exp(logits - max_value)
exp_sum = np.sum(exp, axis=1, keepdims=True)
dist = exp / exp_sum
return dist
def get_keep_pos_idxs(labels, remove_blank=None):
"""
Remove duplicate and get pos idxs of keep items.
The value of keep_blank should be [None, 95].
"""
duplicate_len_list = []
keep_pos_idx_list = []
keep_char_idx_list = []
for k, v_ in groupby(labels):
current_len = len(list(v_))
if k != remove_blank:
current_idx = int(sum(duplicate_len_list) + current_len // 2)
keep_pos_idx_list.append(current_idx)
keep_char_idx_list.append(k)
duplicate_len_list.append(current_len)
return keep_char_idx_list, keep_pos_idx_list
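# Worked example: labels = [1, 1, 0, 2, 2, 2], remove_blank = 0
#   runs: (1, length 2), (0, length 1), (2, length 3)
#   kept chars [1, 2]; kept positions [1, 4] (the middle of each kept run)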
def remove_blank(labels, blank=0):
new_labels = [x for x in labels if x != blank]
return new_labels
def insert_blank(labels, blank=0):
new_labels = [blank]
for l in labels:
new_labels += [l, blank]
return new_labels
def ctc_greedy_decoder(probs_seq, blank=95, keep_blank_in_idxs=True):
"""
CTC greedy (best path) decoder.
"""
raw_str = np.argmax(np.array(probs_seq), axis=1)
remove_blank_in_pos = None if keep_blank_in_idxs else blank
dedup_str, keep_idx_list = get_keep_pos_idxs(
raw_str, remove_blank=remove_blank_in_pos)
dst_str = remove_blank(dedup_str, blank=blank)
return dst_str, keep_idx_list
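# Worked example with blank = 95:
#   argmax sequence [5, 5, 95, 3, 3] -> dedup to [5, 95, 3] -> drop blanks
#   -> dst_str == [5, 3]; keep_idx_list points at the middle of each run.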
def instance_ctc_greedy_decoder(gather_info, logits_map, pts_num=4):
_, _, C = logits_map.shape
ys, xs = zip(*gather_info)
logits_seq = logits_map[list(ys), list(xs)]
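    # raw logits are used directly: argmax is invariant under softmax, so the
    # normalization step of the slow decoder can be skipped here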
probs_seq = logits_seq
labels = np.argmax(probs_seq, axis=1)
dst_str = [k for k, v_ in groupby(labels) if k != C - 1]
    delta = len(gather_info) // (pts_num - 1)
    keep_idx_list = [0] + [delta * (i + 1) for i in range(pts_num - 2)] + [-1]
keep_gather_list = [gather_info[idx] for idx in keep_idx_list]
return dst_str, keep_gather_list
def ctc_decoder_for_image(gather_info_list,
logits_map,
Lexicon_Table,
pts_num=6):
"""
CTC decoder using multiple processes.
"""
decoder_str = []
decoder_xys = []
for gather_info in gather_info_list:
if len(gather_info) < pts_num:
continue
dst_str, xys_list = instance_ctc_greedy_decoder(
gather_info, logits_map, pts_num=pts_num)
dst_str_readable = ''.join([Lexicon_Table[idx] for idx in dst_str])
if len(dst_str_readable) < 2:
continue
decoder_str.append(dst_str_readable)
decoder_xys.append(xys_list)
return decoder_str, decoder_xys
def sort_with_direction(pos_list, f_direction):
"""
f_direction: h x w x 2
pos_list: [[y, x], [y, x], [y, x] ...]
"""
def sort_part_with_direction(pos_list, point_direction):
pos_list = np.array(pos_list).reshape(-1, 2)
point_direction = np.array(point_direction).reshape(-1, 2)
average_direction = np.mean(point_direction, axis=0, keepdims=True)
pos_proj_leng = np.sum(pos_list * average_direction, axis=1)
sorted_list = pos_list[np.argsort(pos_proj_leng)].tolist()
sorted_direction = point_direction[np.argsort(pos_proj_leng)].tolist()
return sorted_list, sorted_direction
pos_list = np.array(pos_list).reshape(-1, 2)
point_direction = f_direction[pos_list[:, 0], pos_list[:, 1]] # x, y
point_direction = point_direction[:, ::-1] # x, y -> y, x
sorted_point, sorted_direction = sort_part_with_direction(pos_list,
point_direction)
point_num = len(sorted_point)
if point_num >= 16:
middle_num = point_num // 2
first_part_point = sorted_point[:middle_num]
first_point_direction = sorted_direction[:middle_num]
        sorted_first_part_point, sorted_first_part_direction = sort_part_with_direction(
            first_part_point, first_point_direction)
last_part_point = sorted_point[middle_num:]
last_point_direction = sorted_direction[middle_num:]
sorted_last_part_point, sorted_last_part_direction = sort_part_with_direction(
last_part_point, last_point_direction)
        sorted_point = sorted_first_part_point + sorted_last_part_point
        sorted_direction = sorted_first_part_direction + sorted_last_part_direction
return sorted_point, np.array(sorted_direction)
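# How the sort works (summary of the code above): each point is projected
# onto the mean direction vector and points are ordered by projection length;
# for long instances (>= 16 points) the two halves are re-sorted with their
# own local mean directions so curved text lines are followed more closely.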
def add_id(pos_list, image_id=0):
"""
Add id for gather feature, for inference.
"""
new_list = []
for item in pos_list:
new_list.append((image_id, item[0], item[1]))
return new_list
def sort_and_expand_with_direction(pos_list, f_direction):
"""
f_direction: h x w x 2
pos_list: [[y, x], [y, x], [y, x] ...]
"""
h, w, _ = f_direction.shape
sorted_list, point_direction = sort_with_direction(pos_list, f_direction)
point_num = len(sorted_list)
sub_direction_len = max(point_num // 3, 2)
left_direction = point_direction[:sub_direction_len, :]
    right_direction = point_direction[point_num - sub_direction_len:, :]
left_average_direction = -np.mean(left_direction, axis=0, keepdims=True)
left_average_len = np.linalg.norm(left_average_direction)
left_start = np.array(sorted_list[0])
left_step = left_average_direction / (left_average_len + 1e-6)
    right_average_direction = np.mean(right_direction, axis=0, keepdims=True)
right_average_len = np.linalg.norm(right_average_direction)
right_step = right_average_direction / (right_average_len + 1e-6)
right_start = np.array(sorted_list[-1])
append_num = max(
int((left_average_len + right_average_len) / 2.0 * 0.15), 1)
left_list = []
right_list = []
for i in range(append_num):
ly, lx = np.round(left_start + left_step * (i + 1)).flatten().astype(
'int32').tolist()
if ly < h and lx < w and (ly, lx) not in left_list:
left_list.append((ly, lx))
ry, rx = np.round(right_start + right_step * (i + 1)).flatten().astype(
'int32').tolist()
if ry < h and rx < w and (ry, rx) not in right_list:
right_list.append((ry, rx))
all_list = left_list[::-1] + sorted_list + right_list
return all_list
def sort_and_expand_with_direction_v2(pos_list, f_direction, binary_tcl_map):
"""
f_direction: h x w x 2
pos_list: [[y, x], [y, x], [y, x] ...]
binary_tcl_map: h x w
"""
h, w, _ = f_direction.shape
sorted_list, point_direction = sort_with_direction(pos_list, f_direction)
point_num = len(sorted_list)
sub_direction_len = max(point_num // 3, 2)
left_direction = point_direction[:sub_direction_len, :]
    right_direction = point_direction[point_num - sub_direction_len:, :]
left_average_direction = -np.mean(left_direction, axis=0, keepdims=True)
left_average_len = np.linalg.norm(left_average_direction)
left_start = np.array(sorted_list[0])
left_step = left_average_direction / (left_average_len + 1e-6)
    right_average_direction = np.mean(right_direction, axis=0, keepdims=True)
right_average_len = np.linalg.norm(right_average_direction)
right_step = right_average_direction / (right_average_len + 1e-6)
right_start = np.array(sorted_list[-1])
append_num = max(
int((left_average_len + right_average_len) / 2.0 * 0.15), 1)
max_append_num = 2 * append_num
left_list = []
right_list = []
for i in range(max_append_num):
ly, lx = np.round(left_start + left_step * (i + 1)).flatten().astype(
'int32').tolist()
if ly < h and lx < w and (ly, lx) not in left_list:
if binary_tcl_map[ly, lx] > 0.5:
left_list.append((ly, lx))
else:
break
for i in range(max_append_num):
ry, rx = np.round(right_start + right_step * (i + 1)).flatten().astype(
'int32').tolist()
if ry < h and rx < w and (ry, rx) not in right_list:
if binary_tcl_map[ry, rx] > 0.5:
right_list.append((ry, rx))
else:
break
all_list = left_list[::-1] + sorted_list + right_list
return all_list
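# Difference from sort_and_expand_with_direction: this v2 variant tries up to
# twice as many candidate points on each side but stops as soon as a point
# leaves the binary TCL region (binary_tcl_map <= 0.5), so the center line is
# not extrapolated into background.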
def point_pair2poly(point_pair_list):
"""
Transfer vertical point_pairs into poly point in clockwise.
"""
point_num = len(point_pair_list) * 2
point_list = [0] * point_num
for idx, point_pair in enumerate(point_pair_list):
point_list[idx] = point_pair[0]
point_list[point_num - 1 - idx] = point_pair[1]
return np.array(point_list).reshape(-1, 2)
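# Worked example: three point pairs [(t0, b0), (t1, b1), (t2, b2)] become the
# clockwise polygon [t0, t1, t2, b2, b1, b0].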
def shrink_quad_along_width(quad, begin_width_ratio=0., end_width_ratio=1.):
ratio_pair = np.array(
[[begin_width_ratio], [end_width_ratio]], dtype=np.float32)
p0_1 = quad[0] + (quad[1] - quad[0]) * ratio_pair
p3_2 = quad[3] + (quad[2] - quad[3]) * ratio_pair
return np.array([p0_1[0], p0_1[1], p3_2[1], p3_2[0]])
def expand_poly_along_width(poly, shrink_ratio_of_width=0.3):
"""
expand poly along width.
"""
point_num = poly.shape[0]
left_quad = np.array(
[poly[0], poly[1], poly[-2], poly[-1]], dtype=np.float32)
left_ratio = -shrink_ratio_of_width * np.linalg.norm(left_quad[0] - left_quad[3]) / \
(np.linalg.norm(left_quad[0] - left_quad[1]) + 1e-6)
left_quad_expand = shrink_quad_along_width(left_quad, left_ratio, 1.0)
right_quad = np.array(
[
poly[point_num // 2 - 2], poly[point_num // 2 - 1],
poly[point_num // 2], poly[point_num // 2 + 1]
],
dtype=np.float32)
right_ratio = 1.0 + shrink_ratio_of_width * np.linalg.norm(right_quad[0] - right_quad[3]) / \
(np.linalg.norm(right_quad[0] - right_quad[1]) + 1e-6)
right_quad_expand = shrink_quad_along_width(right_quad, 0.0, right_ratio)
poly[0] = left_quad_expand[0]
poly[-1] = left_quad_expand[-1]
poly[point_num // 2 - 1] = right_quad_expand[1]
poly[point_num // 2] = right_quad_expand[2]
return poly
def restore_poly(instance_yxs_list, seq_strs, p_border, ratio_w, ratio_h, src_w,
src_h, valid_set):
poly_list = []
keep_str_list = []
for yx_center_line, keep_str in zip(instance_yxs_list, seq_strs):
if len(keep_str) < 2:
print('--> too short, {}'.format(keep_str))
continue
offset_expand = 1.0
if valid_set == 'totaltext':
offset_expand = 1.2
point_pair_list = []
for y, x in yx_center_line:
offset = p_border[:, y, x].reshape(2, 2) * offset_expand
ori_yx = np.array([y, x], dtype=np.float32)
point_pair = (ori_yx + offset)[:, ::-1] * 4.0 / np.array(
[ratio_w, ratio_h]).reshape(-1, 2)
point_pair_list.append(point_pair)
detected_poly = point_pair2poly(point_pair_list)
detected_poly = expand_poly_along_width(
detected_poly, shrink_ratio_of_width=0.2)
detected_poly[:, 0] = np.clip(detected_poly[:, 0], a_min=0, a_max=src_w)
detected_poly[:, 1] = np.clip(detected_poly[:, 1], a_min=0, a_max=src_h)
keep_str_list.append(keep_str)
if valid_set == 'partvgg':
middle_point = len(detected_poly) // 2
detected_poly = detected_poly[
[0, middle_point - 1, middle_point, -1], :]
poly_list.append(detected_poly)
elif valid_set == 'totaltext':
poly_list.append(detected_poly)
else:
print('--> Not supported format.')
exit(-1)
return poly_list, keep_str_list
def generate_pivot_list_fast(p_score,
p_char_maps,
f_direction,
Lexicon_Table,
score_thresh=0.5):
"""
return center point and end point of TCL instance; filter with the char maps;
"""
p_score = p_score[0]
f_direction = f_direction.transpose(1, 2, 0)
p_tcl_map = (p_score > score_thresh) * 1.0
skeleton_map = thin(p_tcl_map.astype(np.uint8))
instance_count, instance_label_map = cv2.connectedComponents(
skeleton_map.astype(np.uint8), connectivity=8)
# get TCL Instance
all_pos_yxs = []
if instance_count > 0:
for instance_id in range(1, instance_count):
pos_list = []
ys, xs = np.where(instance_label_map == instance_id)
pos_list = list(zip(ys, xs))
if len(pos_list) < 3:
continue
pos_list_sorted = sort_and_expand_with_direction_v2(
pos_list, f_direction, p_tcl_map)
all_pos_yxs.append(pos_list_sorted)
p_char_maps = p_char_maps.transpose([1, 2, 0])
decoded_str, keep_yxs_list = ctc_decoder_for_image(
all_pos_yxs, logits_map=p_char_maps, Lexicon_Table=Lexicon_Table)
return keep_yxs_list, decoded_str
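# Usage sketch (shapes inferred from the transposes above, not guaranteed):
#   p_score:     [1, H, W]  TCL score map
#   p_char_maps: [C, H, W]  per-character logits
#   f_direction: [2, H, W]  direction field
#   keep_yxs_list, decoded_strs = generate_pivot_list_fast(
#       p_score, p_char_maps, f_direction, Lexicon_Table)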
def extract_main_direction(pos_list, f_direction):
"""
f_direction: h x w x 2
pos_list: [[y, x], [y, x], [y, x] ...]
"""
pos_list = np.array(pos_list)
point_direction = f_direction[pos_list[:, 0], pos_list[:, 1]]
point_direction = point_direction[:, ::-1] # x, y -> y, x
average_direction = np.mean(point_direction, axis=0, keepdims=True)
average_direction = average_direction / (
np.linalg.norm(average_direction) + 1e-6)
return average_direction
def sort_by_direction_with_image_id_deprecated(pos_list, f_direction):
"""
f_direction: h x w x 2
pos_list: [[id, y, x], [id, y, x], [id, y, x] ...]
"""
pos_list_full = np.array(pos_list).reshape(-1, 3)
pos_list = pos_list_full[:, 1:]
point_direction = f_direction[pos_list[:, 0], pos_list[:, 1]] # x, y
point_direction = point_direction[:, ::-1] # x, y -> y, x
average_direction = np.mean(point_direction, axis=0, keepdims=True)
pos_proj_leng = np.sum(pos_list * average_direction, axis=1)
sorted_list = pos_list_full[np.argsort(pos_proj_leng)].tolist()
return sorted_list
def sort_by_direction_with_image_id(pos_list, f_direction):
"""
f_direction: h x w x 2
pos_list: [[y, x], [y, x], [y, x] ...]
"""
def sort_part_with_direction(pos_list_full, point_direction):
pos_list_full = np.array(pos_list_full).reshape(-1, 3)
pos_list = pos_list_full[:, 1:]
point_direction = np.array(point_direction).reshape(-1, 2)
average_direction = np.mean(point_direction, axis=0, keepdims=True)
pos_proj_leng = np.sum(pos_list * average_direction, axis=1)
sorted_list = pos_list_full[np.argsort(pos_proj_leng)].tolist()
sorted_direction = point_direction[np.argsort(pos_proj_leng)].tolist()
return sorted_list, sorted_direction
pos_list = np.array(pos_list).reshape(-1, 3)
point_direction = f_direction[pos_list[:, 1], pos_list[:, 2]] # x, y
point_direction = point_direction[:, ::-1] # x, y -> y, x
sorted_point, sorted_direction = sort_part_with_direction(pos_list,
point_direction)
point_num = len(sorted_point)
if point_num >= 16:
middle_num = point_num // 2
first_part_point = sorted_point[:middle_num]
first_point_direction = sorted_direction[:middle_num]
        sorted_first_part_point, sorted_first_part_direction = sort_part_with_direction(
            first_part_point, first_point_direction)
last_part_point = sorted_point[middle_num:]
last_point_direction = sorted_direction[middle_num:]
sorted_last_part_point, sorted_last_part_direction = sort_part_with_direction(
last_part_point, last_point_direction)
        sorted_point = sorted_first_part_point + sorted_last_part_point
        sorted_direction = sorted_first_part_direction + sorted_last_part_direction
return sorted_point
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Contains various CTC decoders."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import cv2
import math
import numpy as np
from itertools import groupby
from skimage.morphology._skeletonize import thin
def get_dict(character_dict_path):
character_str = ""
with open(character_dict_path, "rb") as fin:
lines = fin.readlines()
for line in lines:
line = line.decode('utf-8').strip("\n").strip("\r\n")
character_str += line
dict_character = list(character_str)
return dict_character
def point_pair2poly(point_pair_list):
"""
Transfer vertical point_pairs into poly point in clockwise.
"""
pair_length_list = []
for point_pair in point_pair_list:
pair_length = np.linalg.norm(point_pair[0] - point_pair[1])
pair_length_list.append(pair_length)
pair_length_list = np.array(pair_length_list)
pair_info = (pair_length_list.max(), pair_length_list.min(),
pair_length_list.mean())
point_num = len(point_pair_list) * 2
point_list = [0] * point_num
for idx, point_pair in enumerate(point_pair_list):
point_list[idx] = point_pair[0]
point_list[point_num - 1 - idx] = point_pair[1]
return np.array(point_list).reshape(-1, 2), pair_info
def shrink_quad_along_width(quad, begin_width_ratio=0., end_width_ratio=1.):
"""
Generate shrink_quad_along_width.
"""
ratio_pair = np.array(
[[begin_width_ratio], [end_width_ratio]], dtype=np.float32)
p0_1 = quad[0] + (quad[1] - quad[0]) * ratio_pair
p3_2 = quad[3] + (quad[2] - quad[3]) * ratio_pair
return np.array([p0_1[0], p0_1[1], p3_2[1], p3_2[0]])
def expand_poly_along_width(poly, shrink_ratio_of_width=0.3):
"""
expand poly along width.
"""
point_num = poly.shape[0]
left_quad = np.array(
[poly[0], poly[1], poly[-2], poly[-1]], dtype=np.float32)
left_ratio = -shrink_ratio_of_width * np.linalg.norm(left_quad[0] - left_quad[3]) / \
(np.linalg.norm(left_quad[0] - left_quad[1]) + 1e-6)
left_quad_expand = shrink_quad_along_width(left_quad, left_ratio, 1.0)
right_quad = np.array(
[
poly[point_num // 2 - 2], poly[point_num // 2 - 1],
poly[point_num // 2], poly[point_num // 2 + 1]
],
dtype=np.float32)
right_ratio = 1.0 + \
shrink_ratio_of_width * np.linalg.norm(right_quad[0] - right_quad[3]) / \
(np.linalg.norm(right_quad[0] - right_quad[1]) + 1e-6)
right_quad_expand = shrink_quad_along_width(right_quad, 0.0, right_ratio)
poly[0] = left_quad_expand[0]
poly[-1] = left_quad_expand[-1]
poly[point_num // 2 - 1] = right_quad_expand[1]
poly[point_num // 2] = right_quad_expand[2]
return poly
def softmax(logits):
"""
logits: N x d
"""
max_value = np.max(logits, axis=1, keepdims=True)
exp = np.exp(logits - max_value)
exp_sum = np.sum(exp, axis=1, keepdims=True)
dist = exp / exp_sum
return dist
def get_keep_pos_idxs(labels, remove_blank=None):
"""
Remove duplicate and get pos idxs of keep items.
The value of keep_blank should be [None, 95].
"""
duplicate_len_list = []
keep_pos_idx_list = []
keep_char_idx_list = []
for k, v_ in groupby(labels):
current_len = len(list(v_))
if k != remove_blank:
current_idx = int(sum(duplicate_len_list) + current_len // 2)
keep_pos_idx_list.append(current_idx)
keep_char_idx_list.append(k)
duplicate_len_list.append(current_len)
return keep_char_idx_list, keep_pos_idx_list
def remove_blank(labels, blank=0):
new_labels = [x for x in labels if x != blank]
return new_labels
def insert_blank(labels, blank=0):
new_labels = [blank]
for l in labels:
new_labels += [l, blank]
return new_labels
def ctc_greedy_decoder(probs_seq, blank=95, keep_blank_in_idxs=True):
"""
CTC greedy (best path) decoder.
"""
raw_str = np.argmax(np.array(probs_seq), axis=1)
remove_blank_in_pos = None if keep_blank_in_idxs else blank
dedup_str, keep_idx_list = get_keep_pos_idxs(
raw_str, remove_blank=remove_blank_in_pos)
dst_str = remove_blank(dedup_str, blank=blank)
return dst_str, keep_idx_list
def instance_ctc_greedy_decoder(gather_info,
logits_map,
keep_blank_in_idxs=True):
"""
gather_info: [[x, y], [x, y] ...]
logits_map: H x W X (n_chars + 1)
"""
_, _, C = logits_map.shape
ys, xs = zip(*gather_info)
logits_seq = logits_map[list(ys), list(xs)] # n x 96
probs_seq = softmax(logits_seq)
dst_str, keep_idx_list = ctc_greedy_decoder(
probs_seq, blank=C - 1, keep_blank_in_idxs=keep_blank_in_idxs)
keep_gather_list = [gather_info[idx] for idx in keep_idx_list]
return dst_str, keep_gather_list
def ctc_decoder_for_image(gather_info_list, logits_map,
keep_blank_in_idxs=True):
"""
CTC decoder using multiple processes.
"""
decoder_results = []
for gather_info in gather_info_list:
res = instance_ctc_greedy_decoder(
gather_info, logits_map, keep_blank_in_idxs=keep_blank_in_idxs)
decoder_results.append(res)
return decoder_results
def sort_with_direction(pos_list, f_direction):
"""
f_direction: h x w x 2
pos_list: [[y, x], [y, x], [y, x] ...]
"""
def sort_part_with_direction(pos_list, point_direction):
pos_list = np.array(pos_list).reshape(-1, 2)
point_direction = np.array(point_direction).reshape(-1, 2)
average_direction = np.mean(point_direction, axis=0, keepdims=True)
pos_proj_leng = np.sum(pos_list * average_direction, axis=1)
sorted_list = pos_list[np.argsort(pos_proj_leng)].tolist()
sorted_direction = point_direction[np.argsort(pos_proj_leng)].tolist()
return sorted_list, sorted_direction
pos_list = np.array(pos_list).reshape(-1, 2)
point_direction = f_direction[pos_list[:, 0], pos_list[:, 1]] # x, y
point_direction = point_direction[:, ::-1] # x, y -> y, x
sorted_point, sorted_direction = sort_part_with_direction(pos_list,
point_direction)
point_num = len(sorted_point)
if point_num >= 16:
middle_num = point_num // 2
first_part_point = sorted_point[:middle_num]
first_point_direction = sorted_direction[:middle_num]
        sorted_first_part_point, sorted_first_part_direction = sort_part_with_direction(
            first_part_point, first_point_direction)
last_part_point = sorted_point[middle_num:]
last_point_direction = sorted_direction[middle_num:]
sorted_last_part_point, sorted_last_part_direction = sort_part_with_direction(
last_part_point, last_point_direction)
        sorted_point = sorted_first_part_point + sorted_last_part_point
        sorted_direction = sorted_first_part_direction + sorted_last_part_direction
return sorted_point, np.array(sorted_direction)
def add_id(pos_list, image_id=0):
"""
Add id for gather feature, for inference.
"""
new_list = []
for item in pos_list:
new_list.append((image_id, item[0], item[1]))
return new_list
def sort_and_expand_with_direction(pos_list, f_direction):
"""
f_direction: h x w x 2
pos_list: [[y, x], [y, x], [y, x] ...]
"""
h, w, _ = f_direction.shape
sorted_list, point_direction = sort_with_direction(pos_list, f_direction)
# expand along
point_num = len(sorted_list)
sub_direction_len = max(point_num // 3, 2)
left_direction = point_direction[:sub_direction_len, :]
    right_direction = point_direction[point_num - sub_direction_len:, :]
left_average_direction = -np.mean(left_direction, axis=0, keepdims=True)
left_average_len = np.linalg.norm(left_average_direction)
left_start = np.array(sorted_list[0])
left_step = left_average_direction / (left_average_len + 1e-6)
    right_average_direction = np.mean(right_direction, axis=0, keepdims=True)
right_average_len = np.linalg.norm(right_average_direction)
right_step = right_average_direction / (right_average_len + 1e-6)
right_start = np.array(sorted_list[-1])
append_num = max(
int((left_average_len + right_average_len) / 2.0 * 0.15), 1)
left_list = []
right_list = []
for i in range(append_num):
ly, lx = np.round(left_start + left_step * (i + 1)).flatten().astype(
'int32').tolist()
if ly < h and lx < w and (ly, lx) not in left_list:
left_list.append((ly, lx))
ry, rx = np.round(right_start + right_step * (i + 1)).flatten().astype(
'int32').tolist()
if ry < h and rx < w and (ry, rx) not in right_list:
right_list.append((ry, rx))
all_list = left_list[::-1] + sorted_list + right_list
return all_list
def sort_and_expand_with_direction_v2(pos_list, f_direction, binary_tcl_map):
"""
f_direction: h x w x 2
pos_list: [[y, x], [y, x], [y, x] ...]
binary_tcl_map: h x w
"""
h, w, _ = f_direction.shape
sorted_list, point_direction = sort_with_direction(pos_list, f_direction)
# expand along
point_num = len(sorted_list)
sub_direction_len = max(point_num // 3, 2)
left_direction = point_direction[:sub_direction_len, :]
    right_direction = point_direction[point_num - sub_direction_len:, :]
left_average_direction = -np.mean(left_direction, axis=0, keepdims=True)
left_average_len = np.linalg.norm(left_average_direction)
left_start = np.array(sorted_list[0])
left_step = left_average_direction / (left_average_len + 1e-6)
    right_average_direction = np.mean(right_direction, axis=0, keepdims=True)
right_average_len = np.linalg.norm(right_average_direction)
right_step = right_average_direction / (right_average_len + 1e-6)
right_start = np.array(sorted_list[-1])
append_num = max(
int((left_average_len + right_average_len) / 2.0 * 0.15), 1)
max_append_num = 2 * append_num
left_list = []
right_list = []
for i in range(max_append_num):
ly, lx = np.round(left_start + left_step * (i + 1)).flatten().astype(
'int32').tolist()
if ly < h and lx < w and (ly, lx) not in left_list:
if binary_tcl_map[ly, lx] > 0.5:
left_list.append((ly, lx))
else:
break
for i in range(max_append_num):
ry, rx = np.round(right_start + right_step * (i + 1)).flatten().astype(
'int32').tolist()
if ry < h and rx < w and (ry, rx) not in right_list:
if binary_tcl_map[ry, rx] > 0.5:
right_list.append((ry, rx))
else:
break
all_list = left_list[::-1] + sorted_list + right_list
return all_list
def generate_pivot_list_curved(p_score,
p_char_maps,
f_direction,
score_thresh=0.5,
is_expand=True,
is_backbone=False,
image_id=0):
"""
return center point and end point of TCL instance; filter with the char maps;
"""
p_score = p_score[0]
f_direction = f_direction.transpose(1, 2, 0)
p_tcl_map = (p_score > score_thresh) * 1.0
skeleton_map = thin(p_tcl_map)
instance_count, instance_label_map = cv2.connectedComponents(
skeleton_map.astype(np.uint8), connectivity=8)
# get TCL Instance
all_pos_yxs = []
center_pos_yxs = []
end_points_yxs = []
instance_center_pos_yxs = []
pred_strs = []
if instance_count > 0:
for instance_id in range(1, instance_count):
pos_list = []
ys, xs = np.where(instance_label_map == instance_id)
pos_list = list(zip(ys, xs))
### FIX-ME, eliminate outlier
if len(pos_list) < 3:
continue
if is_expand:
pos_list_sorted = sort_and_expand_with_direction_v2(
pos_list, f_direction, p_tcl_map)
else:
pos_list_sorted, _ = sort_with_direction(pos_list, f_direction)
all_pos_yxs.append(pos_list_sorted)
    # use decoder to filter background points.
p_char_maps = p_char_maps.transpose([1, 2, 0])
decode_res = ctc_decoder_for_image(
all_pos_yxs, logits_map=p_char_maps, keep_blank_in_idxs=True)
for decoded_str, keep_yxs_list in decode_res:
if is_backbone:
keep_yxs_list_with_id = add_id(keep_yxs_list, image_id=image_id)
instance_center_pos_yxs.append(keep_yxs_list_with_id)
pred_strs.append(decoded_str)
else:
end_points_yxs.extend((keep_yxs_list[0], keep_yxs_list[-1]))
center_pos_yxs.extend(keep_yxs_list)
if is_backbone:
return pred_strs, instance_center_pos_yxs
else:
return center_pos_yxs, end_points_yxs
def generate_pivot_list_horizontal(p_score,
p_char_maps,
f_direction,
score_thresh=0.5,
is_backbone=False,
image_id=0):
"""
return center point and end point of TCL instance; filter with the char maps;
"""
p_score = p_score[0]
f_direction = f_direction.transpose(1, 2, 0)
p_tcl_map_bi = (p_score > score_thresh) * 1.0
instance_count, instance_label_map = cv2.connectedComponents(
p_tcl_map_bi.astype(np.uint8), connectivity=8)
# get TCL Instance
all_pos_yxs = []
center_pos_yxs = []
end_points_yxs = []
instance_center_pos_yxs = []
if instance_count > 0:
for instance_id in range(1, instance_count):
pos_list = []
ys, xs = np.where(instance_label_map == instance_id)
pos_list = list(zip(ys, xs))
### FIX-ME, eliminate outlier
if len(pos_list) < 5:
continue
# add rule here
main_direction = extract_main_direction(pos_list,
f_direction) # y x
            reference_direction = np.array([0, 1]).reshape([-1, 2])  # y x
            is_h_angle = abs(np.sum(
                main_direction * reference_direction)) < math.cos(math.pi / 180 *
                                                                  70)
point_yxs = np.array(pos_list)
max_y, max_x = np.max(point_yxs, axis=0)
min_y, min_x = np.min(point_yxs, axis=0)
is_h_len = (max_y - min_y) < 1.5 * (max_x - min_x)
pos_list_final = []
if is_h_len:
xs = np.unique(xs)
for x in xs:
ys = instance_label_map[:, x].copy().reshape((-1, ))
y = int(np.where(ys == instance_id)[0].mean())
pos_list_final.append((y, x))
else:
ys = np.unique(ys)
for y in ys:
xs = instance_label_map[y, :].copy().reshape((-1, ))
x = int(np.where(xs == instance_id)[0].mean())
pos_list_final.append((y, x))
pos_list_sorted, _ = sort_with_direction(pos_list_final,
f_direction)
all_pos_yxs.append(pos_list_sorted)
    # use decoder to filter background points.
p_char_maps = p_char_maps.transpose([1, 2, 0])
decode_res = ctc_decoder_for_image(
all_pos_yxs, logits_map=p_char_maps, keep_blank_in_idxs=True)
for decoded_str, keep_yxs_list in decode_res:
if is_backbone:
keep_yxs_list_with_id = add_id(keep_yxs_list, image_id=image_id)
instance_center_pos_yxs.append(keep_yxs_list_with_id)
else:
end_points_yxs.extend((keep_yxs_list[0], keep_yxs_list[-1]))
center_pos_yxs.extend(keep_yxs_list)
if is_backbone:
return instance_center_pos_yxs
else:
return center_pos_yxs, end_points_yxs
def generate_pivot_list_slow(p_score,
p_char_maps,
f_direction,
score_thresh=0.5,
is_backbone=False,
is_curved=True,
image_id=0):
"""
Warp all the function together.
"""
if is_curved:
return generate_pivot_list_curved(
p_score,
p_char_maps,
f_direction,
score_thresh=score_thresh,
is_expand=True,
is_backbone=is_backbone,
image_id=image_id)
else:
return generate_pivot_list_horizontal(
p_score,
p_char_maps,
f_direction,
score_thresh=score_thresh,
is_backbone=is_backbone,
image_id=image_id)
# for refine module
def extract_main_direction(pos_list, f_direction):
"""
f_direction: h x w x 2
pos_list: [[y, x], [y, x], [y, x] ...]
"""
pos_list = np.array(pos_list)
point_direction = f_direction[pos_list[:, 0], pos_list[:, 1]]
point_direction = point_direction[:, ::-1] # x, y -> y, x
average_direction = np.mean(point_direction, axis=0, keepdims=True)
average_direction = average_direction / (
np.linalg.norm(average_direction) + 1e-6)
return average_direction
def sort_by_direction_with_image_id_deprecated(pos_list, f_direction):
"""
f_direction: h x w x 2
pos_list: [[id, y, x], [id, y, x], [id, y, x] ...]
"""
pos_list_full = np.array(pos_list).reshape(-1, 3)
pos_list = pos_list_full[:, 1:]
point_direction = f_direction[pos_list[:, 0], pos_list[:, 1]] # x, y
point_direction = point_direction[:, ::-1] # x, y -> y, x
average_direction = np.mean(point_direction, axis=0, keepdims=True)
pos_proj_leng = np.sum(pos_list * average_direction, axis=1)
sorted_list = pos_list_full[np.argsort(pos_proj_leng)].tolist()
return sorted_list
def sort_by_direction_with_image_id(pos_list, f_direction):
"""
f_direction: h x w x 2
pos_list: [[y, x], [y, x], [y, x] ...]
"""
def sort_part_with_direction(pos_list_full, point_direction):
pos_list_full = np.array(pos_list_full).reshape(-1, 3)
pos_list = pos_list_full[:, 1:]
point_direction = np.array(point_direction).reshape(-1, 2)
average_direction = np.mean(point_direction, axis=0, keepdims=True)
pos_proj_leng = np.sum(pos_list * average_direction, axis=1)
sorted_list = pos_list_full[np.argsort(pos_proj_leng)].tolist()
sorted_direction = point_direction[np.argsort(pos_proj_leng)].tolist()
return sorted_list, sorted_direction
pos_list = np.array(pos_list).reshape(-1, 3)
point_direction = f_direction[pos_list[:, 1], pos_list[:, 2]] # x, y
point_direction = point_direction[:, ::-1] # x, y -> y, x
sorted_point, sorted_direction = sort_part_with_direction(pos_list,
point_direction)
point_num = len(sorted_point)
if point_num >= 16:
middle_num = point_num // 2
first_part_point = sorted_point[:middle_num]
first_point_direction = sorted_direction[:middle_num]
        sorted_first_part_point, sorted_first_part_direction = sort_part_with_direction(
            first_part_point, first_point_direction)
last_part_point = sorted_point[middle_num:]
last_point_direction = sorted_direction[middle_num:]
sorted_last_part_point, sorted_last_part_direction = sort_part_with_direction(
last_part_point, last_point_direction)
        sorted_point = sorted_first_part_point + sorted_last_part_point
        sorted_direction = sorted_first_part_direction + sorted_last_part_direction
return sorted_point
def generate_pivot_list_tt_inference(p_score,
p_char_maps,
f_direction,
score_thresh=0.5,
is_backbone=False,
is_curved=True,
image_id=0):
"""
return center point and end point of TCL instance; filter with the char maps;
"""
p_score = p_score[0]
f_direction = f_direction.transpose(1, 2, 0)
p_tcl_map = (p_score > score_thresh) * 1.0
skeleton_map = thin(p_tcl_map)
instance_count, instance_label_map = cv2.connectedComponents(
skeleton_map.astype(np.uint8), connectivity=8)
# get TCL Instance
all_pos_yxs = []
if instance_count > 0:
for instance_id in range(1, instance_count):
pos_list = []
ys, xs = np.where(instance_label_map == instance_id)
pos_list = list(zip(ys, xs))
### FIX-ME, eliminate outlier
if len(pos_list) < 3:
continue
pos_list_sorted = sort_and_expand_with_direction_v2(
pos_list, f_direction, p_tcl_map)
pos_list_sorted_with_id = add_id(pos_list_sorted, image_id=image_id)
all_pos_yxs.append(pos_list_sorted_with_id)
return all_pos_yxs
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import paddle
import os
import sys
__dir__ = os.path.dirname(__file__)
sys.path.append(__dir__)
sys.path.append(os.path.join(__dir__, '..'))
from extract_textpoint_slow import *
from extract_textpoint_fast import generate_pivot_list_fast, restore_poly
class PGNet_PostProcess(object):
    # provides two post-processing variants: fast and slow
def __init__(self, character_dict_path, valid_set, score_thresh, outs_dict,
shape_list):
self.Lexicon_Table = get_dict(character_dict_path)
self.valid_set = valid_set
self.score_thresh = score_thresh
self.outs_dict = outs_dict
self.shape_list = shape_list
def pg_postprocess_fast(self):
p_score = self.outs_dict['f_score']
p_border = self.outs_dict['f_border']
p_char = self.outs_dict['f_char']
p_direction = self.outs_dict['f_direction']
if isinstance(p_score, paddle.Tensor):
p_score = p_score[0].numpy()
p_border = p_border[0].numpy()
p_direction = p_direction[0].numpy()
p_char = p_char[0].numpy()
else:
p_score = p_score[0]
p_border = p_border[0]
p_direction = p_direction[0]
p_char = p_char[0]
src_h, src_w, ratio_h, ratio_w = self.shape_list[0]
instance_yxs_list, seq_strs = generate_pivot_list_fast(
p_score,
p_char,
p_direction,
self.Lexicon_Table,
score_thresh=self.score_thresh)
poly_list, keep_str_list = restore_poly(instance_yxs_list, seq_strs,
p_border, ratio_w, ratio_h,
src_w, src_h, self.valid_set)
data = {
'points': poly_list,
'texts': keep_str_list,
}
return data
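    # Usage sketch (hypothetical values; outs_dict comes from the PGNet model
    # forward pass and shape_list from the E2EResizeForTest transform):
    #   post = PGNet_PostProcess(args.e2e_char_dict_path, 'totaltext', 0.5,
    #                            outs_dict, shape_list)
    #   data = post.pg_postprocess_fast()  # {'points': [...], 'texts': [...]}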
def pg_postprocess_slow(self):
p_score = self.outs_dict['f_score']
p_border = self.outs_dict['f_border']
p_char = self.outs_dict['f_char']
p_direction = self.outs_dict['f_direction']
if isinstance(p_score, paddle.Tensor):
p_score = p_score[0].numpy()
p_border = p_border[0].numpy()
p_direction = p_direction[0].numpy()
p_char = p_char[0].numpy()
else:
p_score = p_score[0]
p_border = p_border[0]
p_direction = p_direction[0]
p_char = p_char[0]
src_h, src_w, ratio_h, ratio_w = self.shape_list[0]
is_curved = self.valid_set == "totaltext"
char_seq_idx_set, instance_yxs_list = generate_pivot_list_slow(
p_score,
p_char,
p_direction,
score_thresh=self.score_thresh,
is_backbone=True,
is_curved=is_curved)
seq_strs = []
for char_idx_set in char_seq_idx_set:
pr_str = ''.join([self.Lexicon_Table[pos] for pos in char_idx_set])
seq_strs.append(pr_str)
poly_list = []
keep_str_list = []
all_point_list = []
all_point_pair_list = []
for yx_center_line, keep_str in zip(instance_yxs_list, seq_strs):
if len(yx_center_line) == 1:
yx_center_line.append(yx_center_line[-1])
offset_expand = 1.0
if self.valid_set == 'totaltext':
offset_expand = 1.2
point_pair_list = []
for batch_id, y, x in yx_center_line:
offset = p_border[:, y, x].reshape(2, 2)
if offset_expand != 1.0:
offset_length = np.linalg.norm(
offset, axis=1, keepdims=True)
expand_length = np.clip(
offset_length * (offset_expand - 1),
a_min=0.5,
a_max=3.0)
                    offset_delta = offset / offset_length * expand_length
                    offset = offset + offset_delta
ori_yx = np.array([y, x], dtype=np.float32)
point_pair = (ori_yx + offset)[:, ::-1] * 4.0 / np.array(
[ratio_w, ratio_h]).reshape(-1, 2)
point_pair_list.append(point_pair)
all_point_list.append([
int(round(x * 4.0 / ratio_w)),
int(round(y * 4.0 / ratio_h))
])
all_point_pair_list.append(point_pair.round().astype(np.int32)
.tolist())
detected_poly, pair_length_info = point_pair2poly(point_pair_list)
detected_poly = expand_poly_along_width(
detected_poly, shrink_ratio_of_width=0.2)
detected_poly[:, 0] = np.clip(
detected_poly[:, 0], a_min=0, a_max=src_w)
detected_poly[:, 1] = np.clip(
detected_poly[:, 1], a_min=0, a_max=src_h)
if len(keep_str) < 2:
continue
keep_str_list.append(keep_str)
detected_poly = np.round(detected_poly).astype('int32')
if self.valid_set == 'partvgg':
middle_point = len(detected_poly) // 2
detected_poly = detected_poly[
[0, middle_point - 1, middle_point, -1], :]
poly_list.append(detected_poly)
elif self.valid_set == 'totaltext':
poly_list.append(detected_poly)
else:
print('--> Not supported format.')
exit(-1)
data = {
'points': poly_list,
'texts': keep_str_list,
}
return data
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
import cv2
import time
def resize_image(im, max_side_len=512):
"""
resize image to a size multiple of max_stride which is required by the network
:param im: the resized image
:param max_side_len: limit of max image size to avoid out of memory in gpu
:return: the resized image and the resize ratio
"""
h, w, _ = im.shape
resize_w = w
resize_h = h
if resize_h > resize_w:
ratio = float(max_side_len) / resize_h
else:
ratio = float(max_side_len) / resize_w
resize_h = int(resize_h * ratio)
resize_w = int(resize_w * ratio)
max_stride = 128
resize_h = (resize_h + max_stride - 1) // max_stride * max_stride
resize_w = (resize_w + max_stride - 1) // max_stride * max_stride
im = cv2.resize(im, (int(resize_w), int(resize_h)))
ratio_h = resize_h / float(h)
ratio_w = resize_w / float(w)
return im, (ratio_h, ratio_w)
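# Worked example: a 720 x 1280 (h x w) image with max_side_len=512:
#   ratio = 512 / 1280 = 0.4 -> 288 x 512, rounded up to multiples of 128
#   -> 384 x 512; ratio_h = 384/720 ~ 0.533, ratio_w = 512/1280 = 0.4.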
def resize_image_min(im, max_side_len=512):
"""
"""
h, w, _ = im.shape
resize_w = w
resize_h = h
if resize_h < resize_w:
ratio = float(max_side_len) / resize_h
else:
ratio = float(max_side_len) / resize_w
resize_h = int(resize_h * ratio)
resize_w = int(resize_w * ratio)
max_stride = 128
resize_h = (resize_h + max_stride - 1) // max_stride * max_stride
resize_w = (resize_w + max_stride - 1) // max_stride * max_stride
im = cv2.resize(im, (int(resize_w), int(resize_h)))
ratio_h = resize_h / float(h)
ratio_w = resize_w / float(w)
return im, (ratio_h, ratio_w)
def resize_image_for_totaltext(im, max_side_len=512):
"""
"""
h, w, _ = im.shape
resize_w = w
resize_h = h
ratio = 1.25
if h * ratio > max_side_len:
ratio = float(max_side_len) / resize_h
resize_h = int(resize_h * ratio)
resize_w = int(resize_w * ratio)
max_stride = 128
resize_h = (resize_h + max_stride - 1) // max_stride * max_stride
resize_w = (resize_w + max_stride - 1) // max_stride * max_stride
im = cv2.resize(im, (int(resize_w), int(resize_h)))
ratio_h = resize_h / float(h)
ratio_w = resize_w / float(w)
return im, (ratio_h, ratio_w)
def point_pair2poly(point_pair_list):
"""
Transfer vertical point_pairs into poly point in clockwise.
"""
pair_length_list = []
for point_pair in point_pair_list:
pair_length = np.linalg.norm(point_pair[0] - point_pair[1])
pair_length_list.append(pair_length)
pair_length_list = np.array(pair_length_list)
pair_info = (pair_length_list.max(), pair_length_list.min(),
pair_length_list.mean())
point_num = len(point_pair_list) * 2
point_list = [0] * point_num
for idx, point_pair in enumerate(point_pair_list):
point_list[idx] = point_pair[0]
point_list[point_num - 1 - idx] = point_pair[1]
return np.array(point_list).reshape(-1, 2), pair_info
def shrink_quad_along_width(quad, begin_width_ratio=0., end_width_ratio=1.):
"""
Generate shrink_quad_along_width.
"""
ratio_pair = np.array(
[[begin_width_ratio], [end_width_ratio]], dtype=np.float32)
p0_1 = quad[0] + (quad[1] - quad[0]) * ratio_pair
p3_2 = quad[3] + (quad[2] - quad[3]) * ratio_pair
return np.array([p0_1[0], p0_1[1], p3_2[1], p3_2[0]])
def expand_poly_along_width(poly, shrink_ratio_of_width=0.3):
"""
expand poly along width.
"""
point_num = poly.shape[0]
left_quad = np.array(
[poly[0], poly[1], poly[-2], poly[-1]], dtype=np.float32)
left_ratio = -shrink_ratio_of_width * np.linalg.norm(left_quad[0] - left_quad[3]) / \
(np.linalg.norm(left_quad[0] - left_quad[1]) + 1e-6)
left_quad_expand = shrink_quad_along_width(left_quad, left_ratio, 1.0)
right_quad = np.array(
[
poly[point_num // 2 - 2], poly[point_num // 2 - 1],
poly[point_num // 2], poly[point_num // 2 + 1]
],
dtype=np.float32)
right_ratio = 1.0 + \
shrink_ratio_of_width * np.linalg.norm(right_quad[0] - right_quad[3]) / \
(np.linalg.norm(right_quad[0] - right_quad[1]) + 1e-6)
right_quad_expand = shrink_quad_along_width(right_quad, 0.0, right_ratio)
poly[0] = left_quad_expand[0]
poly[-1] = left_quad_expand[-1]
poly[point_num // 2 - 1] = right_quad_expand[1]
poly[point_num // 2] = right_quad_expand[2]
return poly
def norm2(x, axis=None):
    # compare against None so that axis=0 is handled correctly
    if axis is not None:
        return np.sqrt(np.sum(x**2, axis=axis))
    return np.sqrt(np.sum(x**2))
def cos(p1, p2):
return (p1 * p2).sum() / (norm2(p1) * norm2(p2))
0
1
2
3
4
5
6
7
8
9
:
;
<
=
>
?
@
A
B
C
D
E
F
G
H
I
J
K
L
M
N
O
P
Q
R
S
T
U
V
W
X
Y
Z
[
\
]
^
_
`
a
b
c
d
e
f
g
h
i
j
k
l
m
n
o
p
q
r
s
t
u
v
w
x
y
z
{
|
}
~
!
"
#
$
%
&
'
(
)
*
+
,
-
.
/
...@@ -121,7 +121,7 @@ def init_model(config, model, logger, optimizer=None, lr_scheduler=None): ...@@ -121,7 +121,7 @@ def init_model(config, model, logger, optimizer=None, lr_scheduler=None):
return best_model_dict return best_model_dict
def save_model(net, def save_model(model,
optimizer, optimizer,
model_path, model_path,
logger, logger,
...@@ -133,7 +133,7 @@ def save_model(net, ...@@ -133,7 +133,7 @@ def save_model(net,
""" """
_mkdir_if_not_exist(model_path, logger) _mkdir_if_not_exist(model_path, logger)
model_prefix = os.path.join(model_path, prefix) model_prefix = os.path.join(model_path, prefix)
paddle.save(net.state_dict(), model_prefix + '.pdparams') paddle.save(model.state_dict(), model_prefix + '.pdparams')
paddle.save(optimizer.state_dict(), model_prefix + '.pdopt') paddle.save(optimizer.state_dict(), model_prefix + '.pdopt')
# save metric and config # save metric and config
......
...@@ -61,6 +61,7 @@ def get_image_file_list(img_file): ...@@ -61,6 +61,7 @@ def get_image_file_list(img_file):
imgs_lists.append(file_path) imgs_lists.append(file_path)
if len(imgs_lists) == 0: if len(imgs_lists) == 0:
raise Exception("not found any img file in {}".format(img_file)) raise Exception("not found any img file in {}".format(img_file))
imgs_lists = sorted(imgs_lists)
return imgs_lists return imgs_lists
......
...@@ -3,8 +3,8 @@ scikit-image==0.17.2 ...@@ -3,8 +3,8 @@ scikit-image==0.17.2
imgaug==0.4.0 imgaug==0.4.0
pyclipper pyclipper
lmdb lmdb
opencv-python==4.2.0.32
tqdm tqdm
numpy numpy
visualdl visualdl
python-Levenshtein python-Levenshtein
\ No newline at end of file opencv-contrib-python==4.2.0.32
\ No newline at end of file
...@@ -32,7 +32,7 @@ setup( ...@@ -32,7 +32,7 @@ setup(
package_dir={'paddleocr': ''}, package_dir={'paddleocr': ''},
include_package_data=True, include_package_data=True,
entry_points={"console_scripts": ["paddleocr= paddleocr.paddleocr:main"]}, entry_points={"console_scripts": ["paddleocr= paddleocr.paddleocr:main"]},
version='2.0.3', version='2.0.6',
install_requires=requirements, install_requires=requirements,
license='Apache License 2.0', license='Apache License 2.0',
description='Awesome OCR toolkits based on PaddlePaddle (8.6M ultra-lightweight pre-trained model, support training and deployment among server, mobile, embeded and IoT devices', description='Awesome OCR toolkits based on PaddlePaddle (8.6M ultra-lightweight pre-trained model, support training and deployment among server, mobile, embeded and IoT devices',
......
...@@ -59,10 +59,10 @@ def main(): ...@@ -59,10 +59,10 @@ def main():
eval_class = build_metric(config['Metric']) eval_class = build_metric(config['Metric'])
# start eval # start eval
metirc = program.eval(model, valid_dataloader, post_process_class, metric = program.eval(model, valid_dataloader, post_process_class,
eval_class, use_srn) eval_class, use_srn)
logger.info('metric eval ***************') logger.info('metric eval ***************')
for k, v in metirc.items(): for k, v in metric.items():
logger.info('{}:{}'.format(k, v)) logger.info('{}:{}'.format(k, v))
......
...@@ -31,14 +31,6 @@ from ppocr.utils.logging import get_logger ...@@ -31,14 +31,6 @@ from ppocr.utils.logging import get_logger
from tools.program import load_config, merge_config, ArgsParser from tools.program import load_config, merge_config, ArgsParser
def parse_args():
parser = argparse.ArgumentParser()
parser.add_argument("-c", "--config", help="configuration file to use")
parser.add_argument(
"-o", "--output_path", type=str, default='./output/infer/')
return parser.parse_args()
def main(): def main():
FLAGS = ArgsParser().parse_args() FLAGS = ArgsParser().parse_args()
config = load_config(FLAGS.config) config = load_config(FLAGS.config)
...@@ -61,17 +53,19 @@ def main(): ...@@ -61,17 +53,19 @@ def main():
save_path = '{}/inference'.format(config['Global']['save_inference_dir']) save_path = '{}/inference'.format(config['Global']['save_inference_dir'])
if config['Architecture']['algorithm'] == "SRN": if config['Architecture']['algorithm'] == "SRN":
max_text_length = config['Architecture']['Head']['max_text_length']
other_shape = [ other_shape = [
paddle.static.InputSpec( paddle.static.InputSpec(
shape=[None, 1, 64, 256], dtype='float32'), [ shape=[None, 1, 64, 256], dtype='float32'), [
paddle.static.InputSpec( paddle.static.InputSpec(
shape=[None, 256, 1], shape=[None, 256, 1],
dtype="int64"), paddle.static.InputSpec( dtype="int64"), paddle.static.InputSpec(
shape=[None, 25, 1], shape=[None, max_text_length, 1], dtype="int64"),
dtype="int64"), paddle.static.InputSpec(
shape=[None, 8, 25, 25], dtype="int64"),
paddle.static.InputSpec( paddle.static.InputSpec(
shape=[None, 8, 25, 25], dtype="int64") shape=[None, 8, max_text_length, max_text_length],
dtype="int64"), paddle.static.InputSpec(
shape=[None, 8, max_text_length, max_text_length],
dtype="int64")
] ]
] ]
model = to_static(model, input_spec=other_shape) model = to_static(model, input_spec=other_shape)
......
...@@ -39,7 +39,10 @@ class TextDetector(object): ...@@ -39,7 +39,10 @@ class TextDetector(object):
self.args = args self.args = args
self.det_algorithm = args.det_algorithm self.det_algorithm = args.det_algorithm
pre_process_list = [{ pre_process_list = [{
'DetResizeForTest': None 'DetResizeForTest': {
'limit_side_len': args.det_limit_side_len,
'limit_type': args.det_limit_type
}
}, { }, {
'NormalizeImage': { 'NormalizeImage': {
'std': [0.229, 0.224, 0.225], 'std': [0.229, 0.224, 0.225],
...@@ -62,6 +65,7 @@ class TextDetector(object): ...@@ -62,6 +65,7 @@ class TextDetector(object):
postprocess_params["max_candidates"] = 1000 postprocess_params["max_candidates"] = 1000
postprocess_params["unclip_ratio"] = args.det_db_unclip_ratio postprocess_params["unclip_ratio"] = args.det_db_unclip_ratio
postprocess_params["use_dilation"] = args.use_dilation postprocess_params["use_dilation"] = args.use_dilation
postprocess_params["score_mode"] = args.det_db_score_mode
elif self.det_algorithm == "EAST": elif self.det_algorithm == "EAST":
postprocess_params['name'] = 'EASTPostProcess' postprocess_params['name'] = 'EASTPostProcess'
postprocess_params["score_thresh"] = args.det_east_score_thresh postprocess_params["score_thresh"] = args.det_east_score_thresh
......
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import sys

__dir__ = os.path.dirname(os.path.abspath(__file__))
sys.path.append(__dir__)
sys.path.append(os.path.abspath(os.path.join(__dir__, '../..')))

os.environ["FLAGS_allocator_strategy"] = 'auto_growth'

import cv2
import numpy as np
import time

import tools.infer.utility as utility
from ppocr.utils.logging import get_logger
from ppocr.utils.utility import get_image_file_list, check_and_read_gif
from ppocr.data import create_operators, transform
from ppocr.postprocess import build_post_process

logger = get_logger()


class TextE2E(object):
    def __init__(self, args):
        self.args = args
        self.e2e_algorithm = args.e2e_algorithm
        pre_process_list = [{
            'E2EResizeForTest': {}
        }, {
            'NormalizeImage': {
                'std': [0.229, 0.224, 0.225],
                'mean': [0.485, 0.456, 0.406],
                'scale': '1./255.',
                'order': 'hwc'
            }
        }, {
            'ToCHWImage': None
        }, {
            'KeepKeys': {
                'keep_keys': ['image', 'shape']
            }
        }]
        postprocess_params = {}
        if self.e2e_algorithm == "PGNet":
            pre_process_list[0] = {
                'E2EResizeForTest': {
                    'max_side_len': args.e2e_limit_side_len,
                    'valid_set': 'totaltext'
                }
            }
            postprocess_params['name'] = 'PGPostProcess'
            postprocess_params["score_thresh"] = args.e2e_pgnet_score_thresh
            postprocess_params["character_dict_path"] = args.e2e_char_dict_path
            postprocess_params["valid_set"] = args.e2e_pgnet_valid_set
            postprocess_params["mode"] = args.e2e_pgnet_mode
            self.e2e_pgnet_polygon = args.e2e_pgnet_polygon
        else:
            logger.info("unknown e2e_algorithm:{}".format(self.e2e_algorithm))
            sys.exit(0)

        self.preprocess_op = create_operators(pre_process_list)
        self.postprocess_op = build_post_process(postprocess_params)
        self.predictor, self.input_tensor, self.output_tensors = utility.create_predictor(
            args, 'e2e', logger)  # paddle.jit.load(args.det_model_dir)
        # self.predictor.eval()

    def clip_det_res(self, points, img_height, img_width):
        for pno in range(points.shape[0]):
            points[pno, 0] = int(min(max(points[pno, 0], 0), img_width - 1))
            points[pno, 1] = int(min(max(points[pno, 1], 0), img_height - 1))
        return points

    def filter_tag_det_res_only_clip(self, dt_boxes, image_shape):
        img_height, img_width = image_shape[0:2]
        dt_boxes_new = []
        for box in dt_boxes:
            box = self.clip_det_res(box, img_height, img_width)
            dt_boxes_new.append(box)
        dt_boxes = np.array(dt_boxes_new)
        return dt_boxes

    def __call__(self, img):
        ori_im = img.copy()
        data = {'image': img}
        data = transform(data, self.preprocess_op)
        img, shape_list = data
        if img is None:
            return None, None, 0
        img = np.expand_dims(img, axis=0)
        shape_list = np.expand_dims(shape_list, axis=0)
        img = img.copy()
        starttime = time.time()

        self.input_tensor.copy_from_cpu(img)
        self.predictor.run()
        outputs = []
        for output_tensor in self.output_tensors:
            output = output_tensor.copy_to_cpu()
            outputs.append(output)

        preds = {}
        if self.e2e_algorithm == 'PGNet':
            preds['f_border'] = outputs[0]
            preds['f_char'] = outputs[1]
            preds['f_direction'] = outputs[2]
            preds['f_score'] = outputs[3]
        else:
            raise NotImplementedError
        post_result = self.postprocess_op(preds, shape_list)
        points, strs = post_result['points'], post_result['texts']
        dt_boxes = self.filter_tag_det_res_only_clip(points, ori_im.shape)
        elapse = time.time() - starttime
        return dt_boxes, strs, elapse


if __name__ == "__main__":
    args = utility.parse_args()
    image_file_list = get_image_file_list(args.image_dir)
    text_detector = TextE2E(args)
    count = 0
    total_time = 0
    draw_img_save = "./inference_results"
    if not os.path.exists(draw_img_save):
        os.makedirs(draw_img_save)
    for image_file in image_file_list:
        img, flag = check_and_read_gif(image_file)
        if not flag:
            img = cv2.imread(image_file)
        if img is None:
            logger.info("error in loading image:{}".format(image_file))
            continue
        points, strs, elapse = text_detector(img)
        if count > 0:
            total_time += elapse
        count += 1
        logger.info("Predict time of {}: {}".format(image_file, elapse))
        src_im = utility.draw_e2e_res(points, strs, image_file)
        img_name_pure = os.path.split(image_file)[-1]
        img_path = os.path.join(draw_img_save,
                                "e2e_res_{}".format(img_name_pure))
        cv2.imwrite(img_path, src_im)
        logger.info("The visualized image saved in {}".format(img_path))
    if count > 1:
        logger.info("Avg Time: {}".format(total_time / (count - 1)))
...@@ -41,6 +41,7 @@ class TextRecognizer(object):
        self.character_type = args.rec_char_type
        self.rec_batch_num = args.rec_batch_num
        self.rec_algorithm = args.rec_algorithm
        self.max_text_length = args.max_text_length
        postprocess_params = {
            'name': 'CTCLabelDecode',
            "character_type": args.rec_char_type,
...@@ -186,8 +187,9 @@ class TextRecognizer(object):
                norm_img = norm_img[np.newaxis, :]
                norm_img_batch.append(norm_img)
            else:
                norm_img = self.process_image_srn(img_list[indices[ino]],
                                                  self.rec_image_shape, 8,
                                                  self.max_text_length)
                encoder_word_pos_list = []
                gsrm_word_pos_list = []
                gsrm_slf_attn_bias1_list = []
......
...@@ -13,6 +13,7 @@
# limitations under the License.
import os
import sys
import subprocess

__dir__ = os.path.dirname(os.path.abspath(__file__))
sys.path.append(__dir__)
...@@ -141,6 +142,7 @@ def sorted_boxes(dt_boxes):

def main(args):
    image_file_list = get_image_file_list(args.image_dir)
    image_file_list = image_file_list[args.process_id::args.total_process_num]
    text_sys = TextSystem(args)
    is_visualize = True
    font_path = args.vis_font_path
...@@ -184,4 +186,18 @@ def main(args):

if __name__ == "__main__":
    args = utility.parse_args()
    if args.use_mp:
        p_list = []
        total_process_num = args.total_process_num
        for process_id in range(total_process_num):
            cmd = [sys.executable, "-u"] + sys.argv + [
                "--process_id={}".format(process_id),
                "--use_mp={}".format(False)
            ]
            p = subprocess.Popen(cmd, stdout=sys.stdout, stderr=sys.stdout)
            p_list.append(p)
        for p in p_list:
            p.wait()
    else:
        main(args)
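The image_file_list[args.process_id::args.total_process_num] slice above is what keeps the spawned workers disjoint: each child takes every N-th file starting at its own offset. A self-contained sketch with made-up file names:

    files = ["a.jpg", "b.jpg", "c.jpg", "d.jpg", "e.jpg"]
    total_process_num = 2
    for process_id in range(total_process_num):
        print(process_id, files[process_id::total_process_num])
    # 0 ['a.jpg', 'c.jpg', 'e.jpg']
    # 1 ['b.jpg', 'd.jpg']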
...@@ -48,6 +48,7 @@ def parse_args():
    parser.add_argument("--det_db_unclip_ratio", type=float, default=1.6)
    parser.add_argument("--max_batch_size", type=int, default=10)
    parser.add_argument("--use_dilation", type=bool, default=False)
    parser.add_argument("--det_db_score_mode", type=str, default="fast")
    # EAST params
    parser.add_argument("--det_east_score_thresh", type=float, default=0.8)
    parser.add_argument("--det_east_cover_thresh", type=float, default=0.1)
...@@ -74,6 +75,20 @@ def parse_args():
        "--vis_font_path", type=str, default="./doc/fonts/simfang.ttf")
    parser.add_argument("--drop_score", type=float, default=0.5)

    # params for e2e
    parser.add_argument("--e2e_algorithm", type=str, default='PGNet')
    parser.add_argument("--e2e_model_dir", type=str)
    parser.add_argument("--e2e_limit_side_len", type=float, default=768)
    parser.add_argument("--e2e_limit_type", type=str, default='max')

    # PGNet params
    parser.add_argument("--e2e_pgnet_score_thresh", type=float, default=0.5)
    parser.add_argument(
        "--e2e_char_dict_path", type=str, default="./ppocr/utils/ic15_dict.txt")
    parser.add_argument("--e2e_pgnet_valid_set", type=str, default='totaltext')
    parser.add_argument("--e2e_pgnet_polygon", type=bool, default=True)
    parser.add_argument("--e2e_pgnet_mode", type=str, default='fast')

    # params for text classifier
    parser.add_argument("--use_angle_cls", type=str2bool, default=False)
    parser.add_argument("--cls_model_dir", type=str)
...@@ -85,6 +100,10 @@ def parse_args():
    parser.add_argument("--enable_mkldnn", type=str2bool, default=False)
    parser.add_argument("--use_pdserving", type=str2bool, default=False)

    parser.add_argument("--use_mp", type=str2bool, default=False)
    parser.add_argument("--total_process_num", type=int, default=1)
    parser.add_argument("--process_id", type=int, default=0)

    return parser.parse_args()
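A quick in-process sanity check of the new flags; the sys.argv override is only for illustration, and the printed values are the expected defaults plus the two overrides:

    import sys
    import tools.infer.utility as utility

    sys.argv = ["utility.py", "--use_mp=True", "--total_process_num=4"]
    args = utility.parse_args()
    print(args.e2e_algorithm, args.det_db_score_mode, args.process_id,
          args.total_process_num)
    # expected: PGNet fast 0 4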
...@@ -93,8 +112,10 @@ def create_predictor(args, mode, logger):
        model_dir = args.det_model_dir
    elif mode == 'cls':
        model_dir = args.cls_model_dir
    elif mode == 'rec':
        model_dir = args.rec_model_dir
    else:
        model_dir = args.e2e_model_dir

    if model_dir is None:
        logger.info("cannot find {} model file path: {}".format(mode, model_dir))
...@@ -148,6 +169,22 @@ def create_predictor(args, mode, logger):
    return predictor, input_tensor, output_tensors
def draw_e2e_res(dt_boxes, strs, img_path):
    src_im = cv2.imread(img_path)
    for box, text in zip(dt_boxes, strs):
        box = box.astype(np.int32).reshape((-1, 1, 2))
        cv2.polylines(src_im, [box], True, color=(255, 255, 0), thickness=2)
        cv2.putText(
            src_im,
            text,
            org=(int(box[0, 0, 0]), int(box[0, 0, 1])),
            fontFace=cv2.FONT_HERSHEY_COMPLEX,
            fontScale=0.7,
            color=(0, 255, 0),
            thickness=1)
    return src_im


def draw_text_det_res(dt_boxes, img_path):
    src_im = cv2.imread(img_path)
    for box in dt_boxes:
......
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import numpy as np

import os
import sys

__dir__ = os.path.dirname(os.path.abspath(__file__))
sys.path.append(__dir__)
sys.path.append(os.path.abspath(os.path.join(__dir__, '..')))

os.environ["FLAGS_allocator_strategy"] = 'auto_growth'

import cv2
import json
import paddle

from ppocr.data import create_operators, transform
from ppocr.modeling.architectures import build_model
from ppocr.postprocess import build_post_process
from ppocr.utils.save_load import init_model
from ppocr.utils.utility import get_image_file_list
import tools.program as program


def draw_e2e_res(dt_boxes, strs, config, img, img_name):
    if len(dt_boxes) > 0:
        src_im = img
        for box, text in zip(dt_boxes, strs):
            box = box.astype(np.int32).reshape((-1, 1, 2))
            cv2.polylines(src_im, [box], True, color=(255, 255, 0), thickness=2)
            cv2.putText(
                src_im,
                text,
                org=(int(box[0, 0, 0]), int(box[0, 0, 1])),
                fontFace=cv2.FONT_HERSHEY_COMPLEX,
                fontScale=0.7,
                color=(0, 255, 0),
                thickness=1)
        save_det_path = os.path.dirname(config['Global'][
            'save_res_path']) + "/e2e_results/"
        if not os.path.exists(save_det_path):
            os.makedirs(save_det_path)
        save_path = os.path.join(save_det_path, os.path.basename(img_name))
        cv2.imwrite(save_path, src_im)
        logger.info("The e2e Image saved in {}".format(save_path))


def main():
    global_config = config['Global']

    # build model
    model = build_model(config['Architecture'])

    init_model(config, model, logger)

    # build post process
    post_process_class = build_post_process(config['PostProcess'],
                                            global_config)

    # create data ops
    transforms = []
    for op in config['Eval']['dataset']['transforms']:
        op_name = list(op)[0]
        if 'Label' in op_name:
            continue
        elif op_name == 'KeepKeys':
            op[op_name]['keep_keys'] = ['image', 'shape']
        transforms.append(op)

    ops = create_operators(transforms, global_config)

    save_res_path = config['Global']['save_res_path']
    if not os.path.exists(os.path.dirname(save_res_path)):
        os.makedirs(os.path.dirname(save_res_path))

    model.eval()
    with open(save_res_path, "wb") as fout:
        for file in get_image_file_list(config['Global']['infer_img']):
            logger.info("infer_img: {}".format(file))
            with open(file, 'rb') as f:
                img = f.read()
                data = {'image': img}
            batch = transform(data, ops)
            images = np.expand_dims(batch[0], axis=0)
            shape_list = np.expand_dims(batch[1], axis=0)
            images = paddle.to_tensor(images)
            preds = model(images)
            post_result = post_process_class(preds, shape_list)
            points, strs = post_result['points'], post_result['texts']
            # write result
            dt_boxes_json = []
            for poly, text in zip(points, strs):
                tmp_json = {"transcription": text}
                tmp_json['points'] = poly.tolist()
                dt_boxes_json.append(tmp_json)
            otstr = file + "\t" + json.dumps(dt_boxes_json) + "\n"
            fout.write(otstr.encode())
            src_img = cv2.imread(file)
            draw_e2e_res(points, strs, config, src_img, file)
    logger.info("success!")


if __name__ == '__main__':
    config, device, logger, vdl_writer = program.preprocess()
    main()
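Each line that main() writes pairs an image path with a JSON list of {transcription, points} records, separated by a tab. A minimal sketch of reading the file back; the path stands in for whatever Global.save_res_path is set to:

    import json

    with open("./output/e2e/predicts.txt") as fin:  # placeholder path
        for line in fin:
            img_path, boxes_json = line.rstrip("\n").split("\t", 1)
            for item in json.loads(boxes_json):
                print(img_path, item["transcription"], len(item["points"]))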
...@@ -73,35 +73,45 @@ def main():
    global_config['infer_mode'] = True
    ops = create_operators(transforms, global_config)

    save_res_path = config['Global'].get('save_res_path',
                                         "./output/rec/predicts_rec.txt")
    if not os.path.exists(os.path.dirname(save_res_path)):
        os.makedirs(os.path.dirname(save_res_path))

    model.eval()

    with open(save_res_path, "w") as fout:
        for file in get_image_file_list(config['Global']['infer_img']):
            logger.info("infer_img: {}".format(file))
            with open(file, 'rb') as f:
                img = f.read()
                data = {'image': img}
            batch = transform(data, ops)
            if config['Architecture']['algorithm'] == "SRN":
                encoder_word_pos_list = np.expand_dims(batch[1], axis=0)
                gsrm_word_pos_list = np.expand_dims(batch[2], axis=0)
                gsrm_slf_attn_bias1_list = np.expand_dims(batch[3], axis=0)
                gsrm_slf_attn_bias2_list = np.expand_dims(batch[4], axis=0)

                others = [
                    paddle.to_tensor(encoder_word_pos_list),
                    paddle.to_tensor(gsrm_word_pos_list),
                    paddle.to_tensor(gsrm_slf_attn_bias1_list),
                    paddle.to_tensor(gsrm_slf_attn_bias2_list)
                ]

            images = np.expand_dims(batch[0], axis=0)
            images = paddle.to_tensor(images)
            if config['Architecture']['algorithm'] == "SRN":
                preds = model(images, others)
            else:
                preds = model(images)
            post_result = post_process_class(preds)
            for rec_result in post_result:
                logger.info('\t result: {}'.format(rec_result))
            if len(rec_result) >= 2:
                fout.write(file + "\t" + rec_result[0] + "\t" + str(
                    rec_result[1]) + "\n")
    logger.info("success!")
......
...@@ -18,6 +18,7 @@ from __future__ import print_function

import os
import sys
import platform
import yaml
import time
import shutil
...@@ -159,6 +160,8 @@ def train(config,
    eval_batch_step = config['Global']['eval_batch_step']

    global_step = 0
    if 'global_step' in pre_best_model_dict:
        global_step = pre_best_model_dict['global_step']
    start_eval_step = 0
    if type(eval_batch_step) == list and len(eval_batch_step) >= 2:
        start_eval_step = eval_batch_step[0]
...@@ -196,9 +199,11 @@ def train(config,
        train_reader_cost = 0.0
        batch_sum = 0
        batch_start = time.time()
        max_iter = len(train_dataloader) - 1 if platform.system(
        ) == "Windows" else len(train_dataloader)
        for idx, batch in enumerate(train_dataloader):
            train_reader_cost += time.time() - batch_start
            if idx >= max_iter:
                break
            lr = optimizer.get_lr()
            images = batch[0]
...@@ -287,7 +292,8 @@ def train(config,
                    is_best=True,
                    prefix='best_accuracy',
                    best_model_dict=best_model_dict,
                    epoch=epoch,
                    global_step=global_step)
                best_str = 'best metric, {}'.format(', '.join([
                    '{}: {}'.format(k, v) for k, v in best_model_dict.items()
                ]))
...@@ -309,7 +315,8 @@ def train(config,
                is_best=False,
                prefix='latest',
                best_model_dict=best_model_dict,
                epoch=epoch,
                global_step=global_step)
            if dist.get_rank() == 0 and epoch > 0 and epoch % save_epoch_step == 0:
                save_model(
                    model,
...@@ -319,7 +326,8 @@ def train(config,
                    is_best=False,
                    prefix='iter_epoch_{}'.format(epoch),
                    best_model_dict=best_model_dict,
                    epoch=epoch,
                    global_step=global_step)
        best_str = 'best metric, {}'.format(', '.join(
            ['{}: {}'.format(k, v) for k, v in best_model_dict.items()]))
        logger.info(best_str)
...@@ -335,8 +343,10 @@ def eval(model, valid_dataloader, post_process_class, eval_class,
    total_frame = 0.0
    total_time = 0.0
    pbar = tqdm(total=len(valid_dataloader), desc='eval model:')
    max_iter = len(valid_dataloader) - 1 if platform.system(
    ) == "Windows" else len(valid_dataloader)
    for idx, batch in enumerate(valid_dataloader):
        if idx >= max_iter:
            break
        images = batch[0]
        start = time.time()
...@@ -375,7 +385,8 @@ def preprocess(is_train=False):
    alg = config['Architecture']['algorithm']
    assert alg in [
        'EAST', 'DB', 'SAST', 'Rosetta', 'CRNN', 'STARNet', 'RARE', 'SRN',
        'CLS', 'PGNet'
    ]

    device = 'gpu:{}'.format(dist.ParallelEnv().dev_id) if use_gpu else 'cpu'
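Threading global_step through save_model lets a resumed run keep counting steps from where it stopped instead of restarting at zero. A stand-alone sketch of that round trip, using an illustrative JSON side file rather than the repo's actual checkpoint format:

    import json

    # Saving: persist the metric dict plus the current step (illustrative only).
    state = {"best_acc": 0.91, "global_step": 12000}
    with open("ckpt_meta.json", "w") as f:
        json.dump(state, f)

    # Resuming: same logic as the train() snippet above.
    with open("ckpt_meta.json") as f:
        pre_best_model_dict = json.load(f)
    global_step = 0
    if 'global_step' in pre_best_model_dict:
        global_step = pre_best_model_dict['global_step']
    print(global_step)  # 12000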
......