提交 78d89d34 编写于 作者: L licx

modify docs for updates

上级 229265e6
...@@ -4,12 +4,11 @@ English | [简体中文](README_cn.md) ...@@ -4,12 +4,11 @@ English | [简体中文](README_cn.md)
PaddleOCR aims to create rich, leading, and practical OCR tools that help users train better models and apply them into practice. PaddleOCR aims to create rich, leading, and practical OCR tools that help users train better models and apply them into practice.
**Recent updates** **Recent updates**
- 2020.8.16 Release text detection algorithm [SAST](https://arxiv.org/abs/1908.05498) and text recognition algorithm [SRN](https://arxiv.org/abs/2003.12294)
- 2020.7.23, Release the playback and PPT of live class on BiliBili station, PaddleOCR Introduction, [address](https://aistudio.baidu.com/aistudio/course/introduce/1519) - 2020.7.23, Release the playback and PPT of live class on BiliBili station, PaddleOCR Introduction, [address](https://aistudio.baidu.com/aistudio/course/introduce/1519)
- 2020.7.15, Add mobile App demo , support both iOS and Android ( based on easyedge and Paddle Lite) - 2020.7.15, Add mobile App demo , support both iOS and Android ( based on easyedge and Paddle Lite)
- 2020.7.15, Improve the deployment ability, add the C + + inference , serving deployment. In addtion, the benchmarks of the ultra-lightweight OCR model are provided. - 2020.7.15, Improve the deployment ability, add the C + + inference , serving deployment. In addtion, the benchmarks of the ultra-lightweight OCR model are provided.
- 2020.7.15, Add several related datasets, data annotation and synthesis tools. - 2020.7.15, Add several related datasets, data annotation and synthesis tools.
- 2020.7.9 Add a new model to support recognize the character "space".
- 2020.7.9 Add the data augument and learning rate decay strategies during training.
- [more](./doc/doc_en/update_en.md) - [more](./doc/doc_en/update_en.md)
## Features ## Features
...@@ -91,7 +90,7 @@ Mobile DEMO experience (based on EasyEdge and Paddle-Lite, supports iOS and Andr ...@@ -91,7 +90,7 @@ Mobile DEMO experience (based on EasyEdge and Paddle-Lite, supports iOS and Andr
PaddleOCR open source text detection algorithms list: PaddleOCR open source text detection algorithms list:
- [x] EAST([paper](https://arxiv.org/abs/1704.03155)) - [x] EAST([paper](https://arxiv.org/abs/1704.03155))
- [x] DB([paper](https://arxiv.org/abs/1911.08947)) - [x] DB([paper](https://arxiv.org/abs/1911.08947))
- [ ] SAST([paper](https://arxiv.org/abs/1908.05498))(Baidu Self-Research, comming soon) - [x] SAST([paper](https://arxiv.org/abs/1908.05498))(Baidu Self-Research)
On the ICDAR2015 dataset, the text detection result is as follows: On the ICDAR2015 dataset, the text detection result is as follows:
...@@ -101,6 +100,7 @@ On the ICDAR2015 dataset, the text detection result is as follows: ...@@ -101,6 +100,7 @@ On the ICDAR2015 dataset, the text detection result is as follows:
|EAST|MobileNetV3|81.67%|79.83%|80.74%|[Download link](https://paddleocr.bj.bcebos.com/det_mv3_east.tar)| |EAST|MobileNetV3|81.67%|79.83%|80.74%|[Download link](https://paddleocr.bj.bcebos.com/det_mv3_east.tar)|
|DB|ResNet50_vd|83.79%|80.65%|82.19%|[Download link](https://paddleocr.bj.bcebos.com/det_r50_vd_db.tar)| |DB|ResNet50_vd|83.79%|80.65%|82.19%|[Download link](https://paddleocr.bj.bcebos.com/det_r50_vd_db.tar)|
|DB|MobileNetV3|75.92%|73.18%|74.53%|[Download link](https://paddleocr.bj.bcebos.com/det_mv3_db.tar)| |DB|MobileNetV3|75.92%|73.18%|74.53%|[Download link](https://paddleocr.bj.bcebos.com/det_mv3_db.tar)|
|SAST|ResNet50_vd|92.18%|82.96%|87.33%|[Download link](https://paddleocr.bj.bcebos.com/SAST/sast_r50_vd_icdar2015.tar)|
For use of [LSVT](https://github.com/PaddlePaddle/PaddleOCR/blob/develop/doc/doc_en/datasets_en.md#1-icdar2019-lsvt) street view dataset with a total of 3w training data,the related configuration and pre-trained models for text detection task are as follows: For use of [LSVT](https://github.com/PaddlePaddle/PaddleOCR/blob/develop/doc/doc_en/datasets_en.md#1-icdar2019-lsvt) street view dataset with a total of 3w training data,the related configuration and pre-trained models for text detection task are as follows:
|Model|Backbone|Configuration file|Pre-trained model| |Model|Backbone|Configuration file|Pre-trained model|
...@@ -120,7 +120,7 @@ PaddleOCR open-source text recognition algorithms list: ...@@ -120,7 +120,7 @@ PaddleOCR open-source text recognition algorithms list:
- [x] Rosetta([paper](https://arxiv.org/abs/1910.05085)) - [x] Rosetta([paper](https://arxiv.org/abs/1910.05085))
- [x] STAR-Net([paper](http://www.bmva.org/bmvc/2016/papers/paper043/index.html)) - [x] STAR-Net([paper](http://www.bmva.org/bmvc/2016/papers/paper043/index.html))
- [x] RARE([paper](https://arxiv.org/abs/1603.03915v1)) - [x] RARE([paper](https://arxiv.org/abs/1603.03915v1))
- [ ] SRN([paper](https://arxiv.org/abs/2003.12294))(Baidu Self-Research, comming soon) - [x] SRN([paper](https://arxiv.org/abs/2003.12294))(Baidu Self-Research)
Refer to [DTRB](https://arxiv.org/abs/1904.01906), the training and evaluation result of these above text recognition (using MJSynth and SynthText for training, evaluate on IIIT, SVT, IC03, IC13, IC15, SVTP, CUTE) is as follow: Refer to [DTRB](https://arxiv.org/abs/1904.01906), the training and evaluation result of these above text recognition (using MJSynth and SynthText for training, evaluate on IIIT, SVT, IC03, IC13, IC15, SVTP, CUTE) is as follow:
...@@ -134,8 +134,14 @@ Refer to [DTRB](https://arxiv.org/abs/1904.01906), the training and evaluation r ...@@ -134,8 +134,14 @@ Refer to [DTRB](https://arxiv.org/abs/1904.01906), the training and evaluation r
|STAR-Net|MobileNetV3|81.56%|rec_mv3_tps_bilstm_ctc|[Download link](https://paddleocr.bj.bcebos.com/rec_mv3_tps_bilstm_ctc.tar)| |STAR-Net|MobileNetV3|81.56%|rec_mv3_tps_bilstm_ctc|[Download link](https://paddleocr.bj.bcebos.com/rec_mv3_tps_bilstm_ctc.tar)|
|RARE|Resnet34_vd|84.90%|rec_r34_vd_tps_bilstm_attn|[Download link](https://paddleocr.bj.bcebos.com/rec_r34_vd_tps_bilstm_attn.tar)| |RARE|Resnet34_vd|84.90%|rec_r34_vd_tps_bilstm_attn|[Download link](https://paddleocr.bj.bcebos.com/rec_r34_vd_tps_bilstm_attn.tar)|
|RARE|MobileNetV3|83.32%|rec_mv3_tps_bilstm_attn|[Download link](https://paddleocr.bj.bcebos.com/rec_mv3_tps_bilstm_attn.tar)| |RARE|MobileNetV3|83.32%|rec_mv3_tps_bilstm_attn|[Download link](https://paddleocr.bj.bcebos.com/rec_mv3_tps_bilstm_attn.tar)|
|SRN|Resnet50_vd_fpn|88.33%|rec_r50fpn_vd_none_srn|[Download link](https://paddleocr.bj.bcebos.com/SRN/rec_r50fpn_vd_none_srn.tar)|
**Note:** SRN model uses data expansion method to expand the two training sets mentioned above, and the expanded data can be downloaded from [Baidu Drive](todo).
The average accuracy of the two-stage training in the original paper is 89.74%, and that of one stage training in paddleocr is 88.33%. Both pre-trained weights can be downloaded [here](https://paddleocr.bj.bcebos.com/SRN/rec_r50fpn_vd_none_srn.tar).
We use [LSVT](https://github.com/PaddlePaddle/PaddleOCR/blob/develop/doc/doc_en/datasets_en.md#1-icdar2019-lsvt) dataset and cropout 30w traning data from original photos by using position groundtruth and make some calibration needed. In addition, based on the LSVT corpus, 500w synthetic data is generated to train the model. The related configuration and pre-trained models are as follows: We use [LSVT](https://github.com/PaddlePaddle/PaddleOCR/blob/develop/doc/doc_en/datasets_en.md#1-icdar2019-lsvt) dataset and cropout 30w traning data from original photos by using position groundtruth and make some calibration needed. In addition, based on the LSVT corpus, 500w synthetic data is generated to train the model. The related configuration and pre-trained models are as follows:
|Model|Backbone|Configuration file|Pre-trained model| |Model|Backbone|Configuration file|Pre-trained model|
|-|-|-|-| |-|-|-|-|
|ultra-lightweight OCR model|MobileNetV3|rec_chinese_lite_train.yml|[Download link](https://paddleocr.bj.bcebos.com/ch_models/ch_rec_mv3_crnn.tar)|[inference model](https://paddleocr.bj.bcebos.com/ch_models/ch_rec_mv3_crnn_enhance_infer.tar) & [pre-trained model](https://paddleocr.bj.bcebos.com/ch_models/ch_rec_mv3_crnn_enhance.tar)| |ultra-lightweight OCR model|MobileNetV3|rec_chinese_lite_train.yml|[Download link](https://paddleocr.bj.bcebos.com/ch_models/ch_rec_mv3_crnn.tar)|[inference model](https://paddleocr.bj.bcebos.com/ch_models/ch_rec_mv3_crnn_enhance_infer.tar) & [pre-trained model](https://paddleocr.bj.bcebos.com/ch_models/ch_rec_mv3_crnn_enhance.tar)|
......
...@@ -4,12 +4,11 @@ ...@@ -4,12 +4,11 @@
PaddleOCR旨在打造一套丰富、领先、且实用的OCR工具库,助力使用者训练出更好的模型,并应用落地。 PaddleOCR旨在打造一套丰富、领先、且实用的OCR工具库,助力使用者训练出更好的模型,并应用落地。
**近期更新** **近期更新**
- 2020.8.16 开源文本检测算法[SAST](https://arxiv.org/abs/1908.05498)和文本识别算法[SRN](https://arxiv.org/abs/2003.12294)
- 2020.7.23 发布7月21日B站直播课回放和PPT,PaddleOCR开源大礼包全面解读,[获取地址](https://aistudio.baidu.com/aistudio/course/introduce/1519) - 2020.7.23 发布7月21日B站直播课回放和PPT,PaddleOCR开源大礼包全面解读,[获取地址](https://aistudio.baidu.com/aistudio/course/introduce/1519)
- 2020.7.15 添加基于EasyEdge和Paddle-Lite的移动端DEMO,支持iOS和Android系统 - 2020.7.15 添加基于EasyEdge和Paddle-Lite的移动端DEMO,支持iOS和Android系统
- 2020.7.15 完善预测部署,添加基于C++预测引擎推理、服务化部署和端侧部署方案,以及超轻量级中文OCR模型预测耗时Benchmark - 2020.7.15 完善预测部署,添加基于C++预测引擎推理、服务化部署和端侧部署方案,以及超轻量级中文OCR模型预测耗时Benchmark
- 2020.7.15 整理OCR相关数据集、常用数据标注以及合成工具 - 2020.7.15 整理OCR相关数据集、常用数据标注以及合成工具
- 2020.7.9 添加支持空格的识别模型,识别效果,预测及训练方式请参考快速开始和文本识别训练相关文档
- 2020.7.9 添加数据增强、学习率衰减策略,具体参考[配置文件](./doc/doc_ch/config.md)
- [more](./doc/doc_ch/update.md) - [more](./doc/doc_ch/update.md)
...@@ -93,7 +92,7 @@ PaddleOCR旨在打造一套丰富、领先、且实用的OCR工具库,助力 ...@@ -93,7 +92,7 @@ PaddleOCR旨在打造一套丰富、领先、且实用的OCR工具库,助力
PaddleOCR开源的文本检测算法列表: PaddleOCR开源的文本检测算法列表:
- [x] EAST([paper](https://arxiv.org/abs/1704.03155)) - [x] EAST([paper](https://arxiv.org/abs/1704.03155))
- [x] DB([paper](https://arxiv.org/abs/1911.08947)) - [x] DB([paper](https://arxiv.org/abs/1911.08947))
- [ ] SAST([paper](https://arxiv.org/abs/1908.05498))(百度自研, coming soon) - [x] SAST([paper](https://arxiv.org/abs/1908.05498))(百度自研)
在ICDAR2015文本检测公开数据集上,算法效果如下: 在ICDAR2015文本检测公开数据集上,算法效果如下:
...@@ -103,8 +102,10 @@ PaddleOCR开源的文本检测算法列表: ...@@ -103,8 +102,10 @@ PaddleOCR开源的文本检测算法列表:
|EAST|MobileNetV3|81.67%|79.83%|80.74%|[下载链接](https://paddleocr.bj.bcebos.com/det_mv3_east.tar)| |EAST|MobileNetV3|81.67%|79.83%|80.74%|[下载链接](https://paddleocr.bj.bcebos.com/det_mv3_east.tar)|
|DB|ResNet50_vd|83.79%|80.65%|82.19%|[下载链接](https://paddleocr.bj.bcebos.com/det_r50_vd_db.tar)| |DB|ResNet50_vd|83.79%|80.65%|82.19%|[下载链接](https://paddleocr.bj.bcebos.com/det_r50_vd_db.tar)|
|DB|MobileNetV3|75.92%|73.18%|74.53%|[下载链接](https://paddleocr.bj.bcebos.com/det_mv3_db.tar)| |DB|MobileNetV3|75.92%|73.18%|74.53%|[下载链接](https://paddleocr.bj.bcebos.com/det_mv3_db.tar)|
|SAST|ResNet50_vd|92.18%|82.96%|87.33%|[下载链接](https://paddleocr.bj.bcebos.com/SAST/sast_r50_vd_icdar2015.tar)|
使用[LSVT](https://github.com/PaddlePaddle/PaddleOCR/blob/develop/doc/doc_ch/datasets.md#1icdar2019-lsvt)街景数据集共3w张数据,训练中文检测模型的相关配置和预训练文件如下: 使用[LSVT](https://github.com/PaddlePaddle/PaddleOCR/blob/develop/doc/doc_ch/datasets.md#1icdar2019-lsvt)街景数据集共3w张数据,训练中文检测模型的相关配置和预训练文件如下:
|模型|骨干网络|配置文件|预训练模型| |模型|骨干网络|配置文件|预训练模型|
|-|-|-|-| |-|-|-|-|
|超轻量中文模型|MobileNetV3|det_mv3_db.yml|[下载链接](https://paddleocr.bj.bcebos.com/ch_models/ch_det_mv3_db.tar)| |超轻量中文模型|MobileNetV3|det_mv3_db.yml|[下载链接](https://paddleocr.bj.bcebos.com/ch_models/ch_det_mv3_db.tar)|
...@@ -124,9 +125,6 @@ PaddleOCR开源的文本识别算法列表: ...@@ -124,9 +125,6 @@ PaddleOCR开源的文本识别算法列表:
- [x] RARE([paper](https://arxiv.org/abs/1603.03915v1)) - [x] RARE([paper](https://arxiv.org/abs/1603.03915v1))
- [x] SRN([paper](https://arxiv.org/abs/2003.12294))(百度自研) - [x] SRN([paper](https://arxiv.org/abs/2003.12294))(百度自研)
*备注:* SRN模型使用了数据扰动方法对上述提到对两个训练集进行增广,增广后的数据可以在[百度网盘](todo)上下载。
原始论文使用两阶段训练平均精度为89.74%,PaddleOCR中使用one-stage训练,平均精度为88.33%。两种预训练权重均在[下载链接](https://paddleocr.bj.bcebos.com/SRN/rec_r50fpn_vd_none_srn.tar)中。
参考[DTRB](https://arxiv.org/abs/1904.01906)文字识别训练和评估流程,使用MJSynth和SynthText两个文字识别数据集训练,在IIIT, SVT, IC03, IC13, IC15, SVTP, CUTE数据集上进行评估,算法效果如下: 参考[DTRB](https://arxiv.org/abs/1904.01906)文字识别训练和评估流程,使用MJSynth和SynthText两个文字识别数据集训练,在IIIT, SVT, IC03, IC13, IC15, SVTP, CUTE数据集上进行评估,算法效果如下:
|模型|骨干网络|Avg Accuracy|模型存储命名|下载链接| |模型|骨干网络|Avg Accuracy|模型存储命名|下载链接|
...@@ -141,6 +139,9 @@ PaddleOCR开源的文本识别算法列表: ...@@ -141,6 +139,9 @@ PaddleOCR开源的文本识别算法列表:
|RARE|MobileNetV3|83.32%|rec_mv3_tps_bilstm_attn|[下载链接](https://paddleocr.bj.bcebos.com/rec_mv3_tps_bilstm_attn.tar)| |RARE|MobileNetV3|83.32%|rec_mv3_tps_bilstm_attn|[下载链接](https://paddleocr.bj.bcebos.com/rec_mv3_tps_bilstm_attn.tar)|
|SRN|Resnet50_vd_fpn|88.33%|rec_r50fpn_vd_none_srn|[下载链接](https://paddleocr.bj.bcebos.com/SRN/rec_r50fpn_vd_none_srn.tar)| |SRN|Resnet50_vd_fpn|88.33%|rec_r50fpn_vd_none_srn|[下载链接](https://paddleocr.bj.bcebos.com/SRN/rec_r50fpn_vd_none_srn.tar)|
**说明:** SRN模型使用了数据扰动方法对上述提到对两个训练集进行增广,增广后的数据可以在[百度网盘](todo)上下载。
原始论文使用两阶段训练平均精度为89.74%,PaddleOCR中使用one-stage训练,平均精度为88.33%。两种预训练权重均在[下载链接](https://paddleocr.bj.bcebos.com/SRN/rec_r50fpn_vd_none_srn.tar)中。
使用[LSVT](https://github.com/PaddlePaddle/PaddleOCR/blob/develop/doc/doc_ch/datasets.md#1icdar2019-lsvt)街景数据集根据真值将图crop出来30w数据,进行位置校准。此外基于LSVT语料生成500w合成数据训练中文模型,相关配置和预训练文件如下: 使用[LSVT](https://github.com/PaddlePaddle/PaddleOCR/blob/develop/doc/doc_ch/datasets.md#1icdar2019-lsvt)街景数据集根据真值将图crop出来30w数据,进行位置校准。此外基于LSVT语料生成500w合成数据训练中文模型,相关配置和预训练文件如下:
|模型|骨干网络|配置文件|预训练模型| |模型|骨干网络|配置文件|预训练模型|
......
# 更新 # 更新
- 2020.8.16 开源文本检测算法[SAST](https://arxiv.org/abs/1908.05498)和文本识别算法[SRN](https://arxiv.org/abs/2003.12294)
- 2020.7.23 发布7月21日B站直播课回放和PPT,PaddleOCR开源大礼包全面解读,[获取地址](https://aistudio.baidu.com/aistudio/course/introduce/1519) - 2020.7.23 发布7月21日B站直播课回放和PPT,PaddleOCR开源大礼包全面解读,[获取地址](https://aistudio.baidu.com/aistudio/course/introduce/1519)
- 2020.7.15 添加基于EasyEdge和Paddle-Lite的移动端DEMO,支持iOS和Android系统 - 2020.7.15 添加基于EasyEdge和Paddle-Lite的移动端DEMO,支持iOS和Android系统
- 2020.7.15 完善预测部署,添加基于C++预测引擎推理、服务化部署和端侧部署方案,以及超轻量级中文OCR模型预测耗时Benchmark - 2020.7.15 完善预测部署,添加基于C++预测引擎推理、服务化部署和端侧部署方案,以及超轻量级中文OCR模型预测耗时Benchmark
......
# RECENT UPDATES # RECENT UPDATES
- 2020.8.16 Release text detection algorithm [SAST](https://arxiv.org/abs/1908.05498) and text recognition algorithm [SRN](https://arxiv.org/abs/2003.12294)
- 2020.7.23, Release the playback and PPT of live class on BiliBili station, PaddleOCR Introduction, [address](https://aistudio.baidu.com/aistudio/course/introduce/1519) - 2020.7.23, Release the playback and PPT of live class on BiliBili station, PaddleOCR Introduction, [address](https://aistudio.baidu.com/aistudio/course/introduce/1519)
- 2020.7.15, Add mobile App demo , support both iOS and Android ( based on easyedge and Paddle Lite) - 2020.7.15, Add mobile App demo , support both iOS and Android ( based on easyedge and Paddle Lite)
- 2020.7.15, Improve the deployment ability, add the C + + inference , serving deployment. In addtion, the benchmarks of the ultra-lightweight Chinese OCR model are provided. - 2020.7.15, Improve the deployment ability, add the C + + inference , serving deployment. In addtion, the benchmarks of the ultra-lightweight Chinese OCR model are provided.
......
...@@ -73,7 +73,7 @@ class TrainReader(object): ...@@ -73,7 +73,7 @@ class TrainReader(object):
data_size_list.append(len(image_files)) data_size_list.append(len(image_files))
fetch_record_list.append(0) fetch_record_list.append(0)
image_batch, poly_batch = [], [] image_batch = []
# get a batch of img_fns and poly_fns # get a batch of img_fns and poly_fns
for i in range(0, len(batch_size_list)): for i in range(0, len(batch_size_list)):
bs = batch_size_list[i] bs = batch_size_list[i]
......
...@@ -593,38 +593,6 @@ class SASTProcessTrain(object): ...@@ -593,38 +593,6 @@ class SASTProcessTrain(object):
return np.array(quad_list) return np.array(quad_list)
def rotate_im_poly(self, im, text_polys):
"""
rotate image with 90 / 180 / 270 degre
"""
im_w, im_h = im.shape[1], im.shape[0]
dst_im = im.copy()
dst_polys = []
rand_degree_ratio = np.random.rand()
rand_degree_cnt = 1
#if rand_degree_ratio > 0.333 and rand_degree_ratio < 0.666:
# rand_degree_cnt = 2
#elif rand_degree_ratio > 0.666:
if rand_degree_ratio > 0.5:
rand_degree_cnt = 3
for i in range(rand_degree_cnt):
dst_im = np.rot90(dst_im)
rot_degree = -90 * rand_degree_cnt
rot_angle = rot_degree * math.pi / 180.0
n_poly = text_polys.shape[0]
cx, cy = 0.5 * im_w, 0.5 * im_h
ncx, ncy = 0.5 * dst_im.shape[1], 0.5 * dst_im.shape[0]
for i in range(n_poly):
wordBB = text_polys[i]
poly = []
for j in range(4):#16->4
sx, sy = wordBB[j][0], wordBB[j][1]
dx = math.cos(rot_angle) * (sx - cx) - math.sin(rot_angle) * (sy - cy) + ncx
dy = math.sin(rot_angle) * (sx - cx) + math.cos(rot_angle) * (sy - cy) + ncy
poly.append([dx, dy])
dst_polys.append(poly)
return dst_im, np.array(dst_polys, dtype=np.float32)
def extract_polys(self, poly_txt_path): def extract_polys(self, poly_txt_path):
""" """
Read text_polys, txt_tags, txts from give txt file. Read text_polys, txt_tags, txts from give txt file.
...@@ -653,39 +621,13 @@ class SASTProcessTrain(object): ...@@ -653,39 +621,13 @@ class SASTProcessTrain(object):
return None return None
if text_polys.shape[0] == 0: if text_polys.shape[0] == 0:
return None return None
# #add rotate cases
# if np.random.rand() < 0.5:
# im, text_polys = self.rotate_im_poly(im, text_polys)
h, w, _ = im.shape h, w, _ = im.shape
# text_polys, text_tags = self.check_and_validate_polys(text_polys,
# text_tags, h, w)
text_polys, text_tags, hv_tags = self.check_and_validate_polys(text_polys, text_tags, (h, w)) text_polys, text_tags, hv_tags = self.check_and_validate_polys(text_polys, text_tags, (h, w))
if text_polys.shape[0] == 0: if text_polys.shape[0] == 0:
return None return None
# # random scale this image
# rd_scale = np.random.choice(self.random_scale)
# im = cv2.resize(im, dsize=None, fx=rd_scale, fy=rd_scale)
# text_polys *= rd_scale
# if np.random.rand() < self.background_ratio:
# outs = self.crop_background_infor(im, text_polys, text_tags,
# text_strs)
# else:
# outs = self.crop_foreground_infor(im, text_polys, text_tags,
# text_strs)
# if outs is None:
# return None
# im, score_map, geo_map, training_mask = outs
# score_map = score_map[np.newaxis, ::4, ::4].astype(np.float32)
# geo_map = np.swapaxes(geo_map, 1, 2)
# geo_map = np.swapaxes(geo_map, 1, 0)
# geo_map = geo_map[:, ::4, ::4].astype(np.float32)
# training_mask = training_mask[np.newaxis, ::4, ::4]
# training_mask = training_mask.astype(np.float32)
# return im, score_map, geo_map, training_mask
#set aspect ratio and keep area fix #set aspect ratio and keep area fix
asp_scales = np.arange(1.0, 1.55, 0.1) asp_scales = np.arange(1.0, 1.55, 0.1)
asp_scale = np.random.choice(asp_scales) asp_scale = np.random.choice(asp_scales)
...@@ -781,28 +723,9 @@ class SASTProcessTrain(object): ...@@ -781,28 +723,9 @@ class SASTProcessTrain(object):
im_padded[:, :, 0] /= (255.0 * 0.225) im_padded[:, :, 0] /= (255.0 * 0.225)
im_padded = im_padded.transpose((2, 0, 1)) im_padded = im_padded.transpose((2, 0, 1))
# images.append(im_padded[::-1, :, :])
# tcl_maps.append(score_map[np.newaxis, :, :])
# border_maps.append(border_map.transpose((2, 0, 1)))
# training_masks.append(training_mask[np.newaxis, :, :])
# tvos.append(tvo_map.transpose((2, 0, 1)))
# tcos.append(tco_map.transpose((2, 0, 1)))
# # After a batch should begin
# if len(images) == batch_size:
# yield np.array(images, dtype=np.float32), \
# np.array(tcl_maps, dtype=np.float32), \
# np.array(tvos, dtype=np.float32), \
# np.array(tcos, dtype=np.float32), \
# np.array(border_maps, dtype=np.float32), \
# np.array(training_masks, dtype=np.float32), \
# images, tcl_maps, border_maps, training_masks = [], [], [], []
# tvos, tcos = [], []
# return im_padded, score_map, border_map, training_mask, tvo_map, tco_map
return im_padded[::-1, :, :], score_map[np.newaxis, :, :], border_map.transpose((2, 0, 1)), training_mask[np.newaxis, :, :], tvo_map.transpose((2, 0, 1)), tco_map.transpose((2, 0, 1)) return im_padded[::-1, :, :], score_map[np.newaxis, :, :], border_map.transpose((2, 0, 1)), training_mask[np.newaxis, :, :], tvo_map.transpose((2, 0, 1)), tco_map.transpose((2, 0, 1))
class SASTProcessTest(object): class SASTProcessTest(object):
""" """
SAST process function for test SAST process function for test
...@@ -814,46 +737,6 @@ class SASTProcessTest(object): ...@@ -814,46 +737,6 @@ class SASTProcessTest(object):
else: else:
self.max_side_len = 2400 self.max_side_len = 2400
# def resize_image(self, im):
# """
# resize image to a size multiple of 32 which is required by the network
# :param im: the resized image
# :param max_side_len: limit of max image size to avoid out of memory in gpu
# :return: the resized image and the resize ratio
# """
# max_side_len = self.max_side_len
# h, w, _ = im.shape
# resize_w = w
# resize_h = h
# # limit the max side
# if max(resize_h, resize_w) > max_side_len:
# if resize_h > resize_w:
# ratio = float(max_side_len) / resize_h
# else:
# ratio = float(max_side_len) / resize_w
# else:
# ratio = 1.
# resize_h = int(resize_h * ratio)
# resize_w = int(resize_w * ratio)
# if resize_h % 32 == 0:
# resize_h = resize_h
# elif resize_h // 32 <= 1:
# resize_h = 32
# else:
# resize_h = (resize_h // 32 - 1) * 32
# if resize_w % 32 == 0:
# resize_w = resize_w
# elif resize_w // 32 <= 1:
# resize_w = 32
# else:
# resize_w = (resize_w // 32 - 1) * 32
# im = cv2.resize(im, (int(resize_w), int(resize_h)))
# ratio_h = resize_h / float(h)
# ratio_w = resize_w / float(w)
# return im, (ratio_h, ratio_w)
def resize_image(self, im): def resize_image(self, im):
""" """
resize image to a size multiple of max_stride which is required by the network resize image to a size multiple of max_stride which is required by the network
......
...@@ -105,7 +105,6 @@ class DetModel(object): ...@@ -105,7 +105,6 @@ class DetModel(object):
input_mask = fluid.layers.data( input_mask = fluid.layers.data(
name='mask', shape=[1, 128, 128], dtype='float32') name='mask', shape=[1, 128, 128], dtype='float32')
input_tvo = fluid.layers.data( input_tvo = fluid.layers.data(
# name='tvo', shape=[5, 128, 128], dtype='float32')
name='tvo', shape=[9, 128, 128], dtype='float32') name='tvo', shape=[9, 128, 128], dtype='float32')
input_tco = fluid.layers.data( input_tco = fluid.layers.data(
name='tco', shape=[3, 128, 128], dtype='float32') name='tco', shape=[3, 128, 128], dtype='float32')
......
...@@ -24,7 +24,7 @@ from collections import OrderedDict ...@@ -24,7 +24,7 @@ from collections import OrderedDict
class SASTHead(object): class SASTHead(object):
""" """
SAST: SAST:
see arxiv: https:// see arxiv: https://arxiv.org/abs/1908.05498
args: args:
params(dict): the super parameters for network build params(dict): the super parameters for network build
""" """
...@@ -89,7 +89,7 @@ class SASTHead(object): ...@@ -89,7 +89,7 @@ class SASTHead(object):
g[i] = fluid.layers.relu(g[i]) g[i] = fluid.layers.relu(g[i])
g[i] = conv_bn_layer(input=g[i], num_filters=num_outputs[i], filter_size=3, stride=1, act='relu', name='fpn_down_g%d_1'%i) g[i] = conv_bn_layer(input=g[i], num_filters=num_outputs[i], filter_size=3, stride=1, act='relu', name='fpn_down_g%d_1'%i)
g[i] = conv_bn_layer(input=g[i], num_filters=num_outputs[i+1], filter_size=3, stride=2, act=None, name='fpn_down_g%d_2'%i) g[i] = conv_bn_layer(input=g[i], num_filters=num_outputs[i+1], filter_size=3, stride=2, act=None, name='fpn_down_g%d_2'%i)
print("g[{}] shape: {}".format(i, g[i].shape)) # print("g[{}] shape: {}".format(i, g[i].shape))
g[2] = fluid.layers.elementwise_add(x=g[1], y=h[2]) g[2] = fluid.layers.elementwise_add(x=g[1], y=h[2])
g[2] = fluid.layers.relu(g[2]) g[2] = fluid.layers.relu(g[2])
g[2] = conv_bn_layer(input=g[2], num_filters=num_outputs[2], g[2] = conv_bn_layer(input=g[2], num_filters=num_outputs[2],
...@@ -106,14 +106,14 @@ class SASTHead(object): ...@@ -106,14 +106,14 @@ class SASTHead(object):
f_score = conv_bn_layer(input=f_score, num_filters=128, filter_size=1, stride=1, act='relu', name='f_score3') f_score = conv_bn_layer(input=f_score, num_filters=128, filter_size=1, stride=1, act='relu', name='f_score3')
f_score = conv_bn_layer(input=f_score, num_filters=1, filter_size=3, stride=1, name='f_score4') f_score = conv_bn_layer(input=f_score, num_filters=1, filter_size=3, stride=1, name='f_score4')
f_score = fluid.layers.sigmoid(f_score) f_score = fluid.layers.sigmoid(f_score)
print("f_score shape: {}".format(f_score.shape)) # print("f_score shape: {}".format(f_score.shape))
#f_boder #f_boder
f_border = conv_bn_layer(input=f_common, num_filters=64, filter_size=1, stride=1, act='relu', name='f_border1') f_border = conv_bn_layer(input=f_common, num_filters=64, filter_size=1, stride=1, act='relu', name='f_border1')
f_border = conv_bn_layer(input=f_border, num_filters=64, filter_size=3, stride=1, act='relu', name='f_border2') f_border = conv_bn_layer(input=f_border, num_filters=64, filter_size=3, stride=1, act='relu', name='f_border2')
f_border = conv_bn_layer(input=f_border, num_filters=128, filter_size=1, stride=1, act='relu', name='f_border3') f_border = conv_bn_layer(input=f_border, num_filters=128, filter_size=1, stride=1, act='relu', name='f_border3')
f_border = conv_bn_layer(input=f_border, num_filters=4, filter_size=3, stride=1, name='f_border4') f_border = conv_bn_layer(input=f_border, num_filters=4, filter_size=3, stride=1, name='f_border4')
print("f_border shape: {}".format(f_border.shape)) # print("f_border shape: {}".format(f_border.shape))
return f_score, f_border return f_score, f_border
...@@ -124,14 +124,14 @@ class SASTHead(object): ...@@ -124,14 +124,14 @@ class SASTHead(object):
f_tvo = conv_bn_layer(input=f_tvo, num_filters=64, filter_size=3, stride=1, act='relu', name='f_tvo2') f_tvo = conv_bn_layer(input=f_tvo, num_filters=64, filter_size=3, stride=1, act='relu', name='f_tvo2')
f_tvo = conv_bn_layer(input=f_tvo, num_filters=128, filter_size=1, stride=1, act='relu', name='f_tvo3') f_tvo = conv_bn_layer(input=f_tvo, num_filters=128, filter_size=1, stride=1, act='relu', name='f_tvo3')
f_tvo = conv_bn_layer(input=f_tvo, num_filters=8, filter_size=3, stride=1, name='f_tvo4') f_tvo = conv_bn_layer(input=f_tvo, num_filters=8, filter_size=3, stride=1, name='f_tvo4')
print("f_tvo shape: {}".format(f_tvo.shape)) # print("f_tvo shape: {}".format(f_tvo.shape))
#f_tco #f_tco
f_tco = conv_bn_layer(input=f_common, num_filters=64, filter_size=1, stride=1, act='relu', name='f_tco1') f_tco = conv_bn_layer(input=f_common, num_filters=64, filter_size=1, stride=1, act='relu', name='f_tco1')
f_tco = conv_bn_layer(input=f_tco, num_filters=64, filter_size=3, stride=1, act='relu', name='f_tco2') f_tco = conv_bn_layer(input=f_tco, num_filters=64, filter_size=3, stride=1, act='relu', name='f_tco2')
f_tco = conv_bn_layer(input=f_tco, num_filters=128, filter_size=1, stride=1, act='relu', name='f_tco3') f_tco = conv_bn_layer(input=f_tco, num_filters=128, filter_size=1, stride=1, act='relu', name='f_tco3')
f_tco = conv_bn_layer(input=f_tco, num_filters=2, filter_size=3, stride=1, name='f_tco4') f_tco = conv_bn_layer(input=f_tco, num_filters=2, filter_size=3, stride=1, name='f_tco4')
print("f_tco shape: {}".format(f_tco.shape)) # print("f_tco shape: {}".format(f_tco.shape))
return f_tvo, f_tco return f_tvo, f_tco
...@@ -161,7 +161,7 @@ class SASTHead(object): ...@@ -161,7 +161,7 @@ class SASTHead(object):
#weighted sum #weighted sum
fh_weight = fluid.layers.matmul(fh_attn, fh_g) fh_weight = fluid.layers.matmul(fh_attn, fh_g)
fh_weight = fluid.layers.reshape(fh_weight, [f_shape[0], f_shape[2], f_shape[3], 128]) fh_weight = fluid.layers.reshape(fh_weight, [f_shape[0], f_shape[2], f_shape[3], 128])
print("fh_weight: {}".format(fh_weight.shape)) # print("fh_weight: {}".format(fh_weight.shape))
fh_weight = fluid.layers.transpose(fh_weight, [0, 3, 1, 2]) fh_weight = fluid.layers.transpose(fh_weight, [0, 3, 1, 2])
fh_weight = conv_bn_layer(input=fh_weight, num_filters=128, filter_size=1, stride=1, name='fh_weight') fh_weight = conv_bn_layer(input=fh_weight, num_filters=128, filter_size=1, stride=1, name='fh_weight')
#short cut #short cut
...@@ -187,7 +187,7 @@ class SASTHead(object): ...@@ -187,7 +187,7 @@ class SASTHead(object):
#weighted sum #weighted sum
fv_weight = fluid.layers.matmul(fv_attn, fv_g) fv_weight = fluid.layers.matmul(fv_attn, fv_g)
fv_weight = fluid.layers.reshape(fv_weight, [f_shape[0], f_shape[3], f_shape[2], 128]) fv_weight = fluid.layers.reshape(fv_weight, [f_shape[0], f_shape[3], f_shape[2], 128])
print("fv_weight: {}".format(fv_weight.shape)) # print("fv_weight: {}".format(fv_weight.shape))
fv_weight = fluid.layers.transpose(fv_weight, [0, 3, 2, 1]) fv_weight = fluid.layers.transpose(fv_weight, [0, 3, 2, 1])
fv_weight = conv_bn_layer(input=fv_weight, num_filters=128, filter_size=1, stride=1, name='fv_weight') fv_weight = conv_bn_layer(input=fv_weight, num_filters=128, filter_size=1, stride=1, name='fv_weight')
#short cut #short cut
...@@ -199,22 +199,22 @@ class SASTHead(object): ...@@ -199,22 +199,22 @@ class SASTHead(object):
return f_attn return f_attn
def __call__(self, blocks, with_cab=False): def __call__(self, blocks, with_cab=False):
for k, v in blocks.items(): # for k, v in blocks.items():
print(k, v.shape) # print(k, v.shape)
#down fpn #down fpn
f_down = self.FPN_Down_Fusion(blocks) f_down = self.FPN_Down_Fusion(blocks)
print("f_down shape: {}".format(f_down.shape)) # print("f_down shape: {}".format(f_down.shape))
#up fpn #up fpn
f_up = self.FPN_Up_Fusion(blocks) f_up = self.FPN_Up_Fusion(blocks)
print("f_up shape: {}".format(f_up.shape)) # print("f_up shape: {}".format(f_up.shape))
#fusion #fusion
f_common = fluid.layers.elementwise_add(x=f_down, y=f_up) f_common = fluid.layers.elementwise_add(x=f_down, y=f_up)
f_common = fluid.layers.relu(f_common) f_common = fluid.layers.relu(f_common)
print("f_common: {}".format(f_common.shape)) # print("f_common: {}".format(f_common.shape))
if self.with_cab: if self.with_cab:
print('enhence f_common with CAB.') # print('enhence f_common with CAB.')
f_common = self.cross_attention(f_common) f_common = self.cross_attention(f_common)
f_score, f_border= self.SAST_Header1(f_common) f_score, f_border= self.SAST_Header1(f_common)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册