未验证 提交 b7d99acd 编写于 作者: U user1018 提交者: GitHub

update recovery (#7259)

* update recovery

* update recovery

* update recovery

* update recovery

* update recovery
上级 94710ae3
...@@ -50,7 +50,7 @@ def get_check_global_params(mode): ...@@ -50,7 +50,7 @@ def get_check_global_params(mode):
def _check_image_file(path): def _check_image_file(path):
img_end = {'jpg', 'bmp', 'png', 'jpeg', 'rgb', 'tif', 'tiff', 'gif'} img_end = {'jpg', 'bmp', 'png', 'jpeg', 'rgb', 'tif', 'tiff', 'gif', 'pdf'}
return any([path.lower().endswith(e) for e in img_end]) return any([path.lower().endswith(e) for e in img_end])
...@@ -59,7 +59,7 @@ def get_image_file_list(img_file): ...@@ -59,7 +59,7 @@ def get_image_file_list(img_file):
if img_file is None or not os.path.exists(img_file): if img_file is None or not os.path.exists(img_file):
raise Exception("not found any img file in {}".format(img_file)) raise Exception("not found any img file in {}".format(img_file))
img_end = {'jpg', 'bmp', 'png', 'jpeg', 'rgb', 'tif', 'tiff', 'gif'} img_end = {'jpg', 'bmp', 'png', 'jpeg', 'rgb', 'tif', 'tiff', 'gif', 'pdf'}
if os.path.isfile(img_file) and _check_image_file(img_file): if os.path.isfile(img_file) and _check_image_file(img_file):
imgs_lists.append(img_file) imgs_lists.append(img_file)
elif os.path.isdir(img_file): elif os.path.isdir(img_file):
...@@ -73,7 +73,7 @@ def get_image_file_list(img_file): ...@@ -73,7 +73,7 @@ def get_image_file_list(img_file):
return imgs_lists return imgs_lists
def check_and_read_gif(img_path): def check_and_read(img_path):
if os.path.basename(img_path)[-3:] in ['gif', 'GIF']: if os.path.basename(img_path)[-3:] in ['gif', 'GIF']:
gif = cv2.VideoCapture(img_path) gif = cv2.VideoCapture(img_path)
ret, frame = gif.read() ret, frame = gif.read()
...@@ -84,8 +84,26 @@ def check_and_read_gif(img_path): ...@@ -84,8 +84,26 @@ def check_and_read_gif(img_path):
if len(frame.shape) == 2 or frame.shape[-1] == 1: if len(frame.shape) == 2 or frame.shape[-1] == 1:
frame = cv2.cvtColor(frame, cv2.COLOR_GRAY2RGB) frame = cv2.cvtColor(frame, cv2.COLOR_GRAY2RGB)
imgvalue = frame[:, :, ::-1] imgvalue = frame[:, :, ::-1]
return imgvalue, True return imgvalue, True, False
return None, False elif os.path.basename(img_path)[-3:] in ['pdf']:
import fitz
from PIL import Image
imgs = []
with fitz.open(img_path) as pdf:
for pg in range(0, pdf.pageCount):
page = pdf[pg]
mat = fitz.Matrix(2, 2)
pm = page.getPixmap(matrix=mat, alpha=False)
# if width or height > 2000 pixels, don't enlarge the image
if pm.width > 2000 or pm.height > 2000:
pm = page.getPixmap(matrix=fitz.Matrix(1, 1), alpha=False)
img = Image.frombytes("RGB", [pm.width, pm.height], pm.samples)
img = cv2.cvtColor(np.array(img), cv2.COLOR_RGB2BGR)
imgs.append(img)
return imgs, False, True
return None, False, False
def load_vqa_bio_label_maps(label_map_path): def load_vqa_bio_label_maps(label_map_path):
......
- [1. 简介](#1-简介)
- [2. 安装](#2-安装)
- [2.1 安装PaddlePaddle](#21-安装paddlepaddle)
- [2.2 安装PaddleDetection](#22-安装paddledetection)
- [3. 数据准备](#3-数据准备)
- [3.1 英文数据集](#31-英文数据集)
- [3.2 更多数据集](#32-更多数据集)
- [4. 开始训练](#4-开始训练)
- [4.1 启动训练](#41-启动训练)
- [4.2 FGD蒸馏训练](#42-FGD蒸馏训练)
- [5. 模型评估与预测](#5-模型评估与预测)
- [5.1 指标评估](#51-指标评估)
- [5.2 测试版面分析结果](#52-测试版面分析结果)
- [6 模型导出与预测](#6-模型导出与预测)
- [6.1 模型导出](#61-模型导出)
- [6.2 模型推理](#62-模型推理)
# 版面分析
## 1. 简介
版面分析指的是对图片形式的文档进行区域划分,定位其中的关键区域,如文字、标题、表格、图片等。版面分析算法基于[PaddleDetection](https://github.com/PaddlePaddle/PaddleDetection)的轻量模型PP-PicoDet进行开发。
<div align="center">
<img src="../docs/layout/layout.png" width="800">
</div>
## 2. 安装依赖
### 2.1. 安装PaddlePaddle
- **(1) 安装PaddlePaddle**
```bash
python3 -m pip install --upgrade pip
# GPU安装
python3 -m pip install "paddlepaddle-gpu>=2.2" -i https://mirror.baidu.com/pypi/simple
# CPU安装
python3 -m pip install "paddlepaddle>=2.2" -i https://mirror.baidu.com/pypi/simple
```
更多需求,请参照[安装文档](https://www.paddlepaddle.org.cn/install/quick)中的说明进行操作。
### 2.2. 安装PaddleDetection
- **(1)下载PaddleDetection源码**
```bash
git clone https://github.com/PaddlePaddle/PaddleDetection.git
```
- **(2)安装其他依赖 **
```bash
cd PaddleDetection
python3 -m pip install -r requirements.txt
```
## 3. 数据准备
如果希望直接体验预测过程,可以跳过数据准备,下载我们提供的预训练模型。
### 3.1. 英文数据集
下载文档分析数据集[PubLayNet](https://developer.ibm.com/exchanges/data/all/publaynet/)(数据集96G),包含5个类:`{0: "Text", 1: "Title", 2: "List", 3:"Table", 4:"Figure"}`
```
# 下载数据
wget https://dax-cdn.cdn.appdomain.cloud/dax-publaynet/1.0.0/publaynet.tar.gz
# 解压数据
tar -xvf publaynet.tar.gz
```
解压之后的**目录结构:**
```
|-publaynet
|- test
|- PMC1277013_00004.jpg
|- PMC1291385_00002.jpg
| ...
|- train.json
|- train
|- PMC1291385_00002.jpg
|- PMC1277013_00004.jpg
| ...
|- val.json
|- val
|- PMC538274_00004.jpg
|- PMC539300_00004.jpg
| ...
```
**数据分布:**
| File or Folder | Description | num |
| :------------- | :------------- | ------- |
| `train/` | 训练集图片 | 335,703 |
| `val/` | 验证集图片 | 11,245 |
| `test/` | 测试集图片 | 11,405 |
| `train.json` | 训练集标注文件 | - |
| `val.json` | 验证集标注文件 | - |
**标注格式:**
json文件包含所有图像的标注,数据以字典嵌套的方式存放,包含以下key:
- info,表示标注文件info。
- licenses,表示标注文件licenses。
- images,表示标注文件中图像信息列表,每个元素是一张图像的信息。如下为其中一张图像的信息:
```
{
'file_name': 'PMC4055390_00006.jpg', # file_name
'height': 601, # image height
'width': 792, # image width
'id': 341427 # image id
}
```
- annotations,表示标注文件中目标物体的标注信息列表,每个元素是一个目标物体的标注信息。如下为其中一个目标物体的标注信息:
```
{
'segmentation': # 物体的分割标注
'area': 60518.099043117836, # 物体的区域面积
'iscrowd': 0, # iscrowd
'image_id': 341427, # image id
'bbox': [50.58, 490.86, 240.15, 252.16], # bbox [x1,y1,w,h]
'category_id': 1, # category_id
'id': 3322348 # image id
}
```
### 3.2. 更多数据集
我们提供了CDLA(中文版面分析)、TableBank(表格版面分析)等数据集的下连接,处理为上述标注文件json格式,即可以按相同方式进行训练。
| dataset | 简介 |
| ------------------------------------------------------------ | ------------------------------------------------------------ |
| [cTDaR2019_cTDaR](https://cndplab-founder.github.io/cTDaR2019/) | 用于表格检测(TRACKA)和表格识别(TRACKB)。图片类型包含历史数据集(以cTDaR_t0开头,如cTDaR_t00872.jpg)和现代数据集(以cTDaR_t1开头,cTDaR_t10482.jpg)。 |
| [IIIT-AR-13K](http://cvit.iiit.ac.in/usodi/iiitar13k.php) | 手动注释公开的年度报告中的图形或页面而构建的数据集,包含5类:table, figure, natural image, logo, and signature |
| [CDLA](https://github.com/buptlihang/CDLA) | 中文文档版面分析数据集,面向中文文献类(论文)场景,包含10类:Table、Figure、Figure caption、Table、Table caption、Header、Footer、Reference、Equation |
| [TableBank](https://github.com/doc-analysis/TableBank) | 用于表格检测和识别大型数据集,包含Word和Latex2种文档格式 |
| [DocBank](https://github.com/doc-analysis/DocBank) | 使用弱监督方法构建的大规模数据集(500K文档页面),用于文档布局分析,包含12类:Author、Caption、Date、Equation、Figure、Footer、List、Paragraph、Reference、Section、Table、Title |
## 4. 开始训练
提供了训练脚本、评估脚本和预测脚本,本节将以PubLayNet预训练模型为例进行讲解。
如果不希望训练,直接体验后面的模型评估、预测、动转静、推理的流程,可以下载提供的预训练模型,并跳过本部分。
```
mkdir pretrained_model
cd pretrained_model
# 下载并解压PubLayNet预训练模型
wget https://paddleocr.bj.bcebos.com/ppstructure/models/layout/picodet_lcnet_x1_0_layout.pdparams
```
### 4.1. 启动训练
开始训练:
* 修改配置文件
如果你希望训练自己的数据集,需要修改配置文件中的数据配置、类别数。
`configs/picodet/legacy_model/application/layout_detection/picodet_lcnet_x1_0_layout.yml` 为例,修改的内容如下所示。
```yaml
metric: COCO
# 类别数
num_classes: 5
TrainDataset:
!COCODataSet
# 修改为你自己的训练数据目录
image_dir: train
# 修改为你自己的训练数据标签文件
anno_path: train.json
# 修改为你自己的训练数据根目录
dataset_dir: /root/publaynet/
data_fields: ['image', 'gt_bbox', 'gt_class', 'is_crowd']
EvalDataset:
!COCODataSet
# 修改为你自己的验证数据目录
image_dir: val
# 修改为你自己的验证数据标签文件
anno_path: val.json
# 修改为你自己的验证数据根目录
dataset_dir: /root/publaynet/
TestDataset:
!ImageFolder
# 修改为你自己的测试数据标签文件
anno_path: /root/publaynet/val.json
```
* 开始训练,在训练时,会默认下载PP-PicoDet预训练模型,这里无需预先下载。
```bash
# GPU训练 支持单卡,多卡训练
# 训练日志会自动保存到 log 目录中
# 单卡训练
python3 tools/train.py \
-c configs/picodet/legacy_model/application/layout_detection/picodet_lcnet_x1_0_layout.yml \
--eval
# 多卡训练,通过--gpus参数指定卡号
python3 -m paddle.distributed.launch --gpus '0,1,2,3' tools/train.py \
-c configs/picodet/legacy_model/application/layout_detection/picodet_lcnet_x1_0_layout.yml \
--eval
```
正常启动训练后,会看到以下log输出:
```
[08/15 04:02:30] ppdet.utils.checkpoint INFO: Finish loading model weights: /root/.cache/paddle/weights/LCNet_x1_0_pretrained.pdparams
[08/15 04:02:46] ppdet.engine INFO: Epoch: [0] [ 0/1929] learning_rate: 0.040000 loss_vfl: 1.216707 loss_bbox: 1.142163 loss_dfl: 0.544196 loss: 2.903065 eta: 17 days, 13:50:26 batch_cost: 15.7452 data_cost: 2.9112 ips: 1.5243 images/s
[08/15 04:03:19] ppdet.engine INFO: Epoch: [0] [ 20/1929] learning_rate: 0.064000 loss_vfl: 1.180627 loss_bbox: 0.939552 loss_dfl: 0.442436 loss: 2.628206 eta: 2 days, 12:18:53 batch_cost: 1.5770 data_cost: 0.0008 ips: 15.2184 images/s
[08/15 04:03:47] ppdet.engine INFO: Epoch: [0] [ 40/1929] learning_rate: 0.088000 loss_vfl: 0.543321 loss_bbox: 1.071401 loss_dfl: 0.457817 loss: 2.057003 eta: 2 days, 0:07:03 batch_cost: 1.3190 data_cost: 0.0007 ips: 18.1954 images/s
[08/15 04:04:12] ppdet.engine INFO: Epoch: [0] [ 60/1929] learning_rate: 0.112000 loss_vfl: 0.630989 loss_bbox: 0.859183 loss_dfl: 0.384702 loss: 1.883143 eta: 1 day, 19:01:29 batch_cost: 1.2177 data_cost: 0.0006 ips: 19.7087 images/s
```
- `--eval`表示训练的同时,进行评估, 评估过程中默认将最佳模型,保存为 `output/picodet_lcnet_x1_0_layout/best_accuracy`
**注意,预测/评估时的配置文件请务必与训练一致。**
### 4.2. FGD蒸馏训练
PaddleDetection支持了基于FGD([Focal and Global Knowledge Distillation for Detectors](https://arxiv.org/abs/2111.11837v1))蒸馏的目标检测模型训练过程,FGD蒸馏分为两个部分`Focal``Global``Focal`蒸馏分离图像的前景和背景,让学生模型分别关注教师模型的前景和背景部分特征的关键像素;`Global`蒸馏部分重建不同像素之间的关系并将其从教师转移到学生,以补偿`Focal`蒸馏中丢失的全局信息。
更换数据集,修改【TODO】配置中的数据配置、类别数,具体可以参考4.1。启动训练:
```bash
python3 -m paddle.distributed.launch --gpus '0,1,2,3' tools/train.py \
-c configs/picodet/legacy_model/application/layout_detection/picodet_lcnet_x1_0_layout.yml \
--slim_config configs/picodet/legacy_model/application/layout_detection/picodet_lcnet_x2_5_layout.yml \
--eval
```
- `-c`: 指定模型配置文件。
- `--slim_config`: 指定压缩策略配置文件。
## 5. 模型评估与预测
### 5.1. 指标评估
训练中模型参数默认保存在`output/picodet_lcnet_x1_0_layout`目录下。在评估指标时,需要设置`weights`指向保存的参数文件。评估数据集可以通过 `configs/picodet/legacy_model/application/layout_detection/picodet_lcnet_x1_0_layout.yml` 修改`EvalDataset`中的 `image_dir``anno_path``dataset_dir` 设置。
```bash
# GPU 评估, weights 为待测权重
python3 tools/eval.py \
-c configs/picodet/legacy_model/application/layout_detection/picodet_lcnet_x1_0_layout.yml \
-o weigths=./output/picodet_lcnet_x1_0_layout/best_model
```
会输出以下信息,打印出mAP、AP0.5等信息。
```py
Average Precision (AP) @[ IoU=0.50:0.95 | area= all | maxDets=100 ] = 0.935
Average Precision (AP) @[ IoU=0.50 | area= all | maxDets=100 ] = 0.979
Average Precision (AP) @[ IoU=0.75 | area= all | maxDets=100 ] = 0.956
Average Precision (AP) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = 0.404
Average Precision (AP) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = 0.782
Average Precision (AP) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = 0.969
Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets= 1 ] = 0.539
Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets= 10 ] = 0.938
Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets=100 ] = 0.949
Average Recall (AR) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = 0.495
Average Recall (AR) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = 0.818
Average Recall (AR) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = 0.978
[08/15 07:07:09] ppdet.engine INFO: Total sample number: 11245, averge FPS: 24.405059207157436
[08/15 07:07:09] ppdet.engine INFO: Best test bbox ap is 0.935.
```
使用FGD蒸馏模型进行评估:
```
python3 tools/eval.py \
-c configs/picodet/legacy_model/application/layout_detection/picodet_lcnet_x1_0_layout.yml \
--slim_config configs/picodet/legacy_model/application/layout_detection/picodet_lcnet_x2_5_layout.yml \
-o weights=output/picodet_lcnet_x2_5_layout/best_model
```
- `-c`: 指定模型配置文件。
- `--slim_config`: 指定蒸馏策略配置文件。
- `-o weights`: 指定蒸馏算法训好的模型路径。
### 5.2. 测试版面分析结果
预测使用的配置文件必须与训练一致,如您通过 `python3 tools/train.py -c configs/picodet/legacy_model/application/layout_detection/picodet_lcnet_x1_0_layout.yml` 完成了模型的训练过程。
使用 PaddleDetection 训练好的模型,您可以使用如下命令进行中文模型预测。
```bash
python3 tools/infer.py \
-c configs/picodet/legacy_model/application/layout_detection/picodet_lcnet_x1_0_layout.yml \
-o weights='output/picodet_lcnet_x1_0_layout/best_model.pdparams' \
--infer_img='docs/images/layout.jpg' \
--output_dir=output_dir/ \
--draw_threshold=0.4
```
- `--infer_img`: 推理单张图片,也可以通过`--infer_dir`推理文件中的所有图片。
- `--output_dir`: 指定可视化结果保存路径。
- `--draw_threshold`:指定绘制结果框的NMS阈值。
预测图片如下所示,图片会存储在`output_dir`路径中。
使用FGD蒸馏模型进行测试:
```
python3 tools/infer.py \
-c configs/picodet/legacy_model/application/layout_detection/picodet_lcnet_x1_0_layout.yml \
--slim_config configs/picodet/legacy_model/application/layout_detection/picodet_lcnet_x2_5_layout.yml \
-o weights='output/picodet_lcnet_x2_5_layout/best_model.pdparams' \
--infer_img='docs/images/layout.jpg' \
--output_dir=output_dir/ \
--draw_threshold=0.4
```
## 6. 模型导出与预测
### 6.1 模型导出
inference 模型(`paddle.jit.save`保存的模型) 一般是模型训练,把模型结构和模型参数保存在文件中的固化模型,多用于预测部署场景。 训练过程中保存的模型是checkpoints模型,保存的只有模型的参数,多用于恢复训练等。 与checkpoints模型相比,inference 模型会额外保存模型的结构信息,在预测部署、加速推理上性能优越,灵活方便,适合于实际系统集成。
版面分析模型转inference模型步骤如下:
```bash
python3 tools/export_model.py \
-c configs/picodet/legacy_model/application/layout_detection/picodet_lcnet_x1_0_layout.yml \
-o weights=output/picodet_lcnet_x1_0_layout/best_model \
--output_dir=output_inference/
```
* 如无需导出后处理,请指定:`-o export.benchmark=True`(如果-o已出现过,此处删掉-o)
* 如无需导出NMS,请指定:`-o export.nms=False`
转换成功后,在目录下有三个文件:
```
output_inference/picodet_lcnet_x1_0_layout/
├── model.pdiparams # inference模型的参数文件
├── model.pdiparams.info # inference模型的参数信息,可忽略
└── model.pdmodel # inference模型的模型结构文件
```
FGD蒸馏模型转inference模型步骤如下:
```bash
python3 tools/export_model.py \
-c configs/picodet/legacy_model/application/publayernet_lcnet_x1_5/picodet_student.yml \
--slim_config configs/picodet/legacy_model/application/publayernet_lcnet_x1_5/picodet_teacher.yml \
-o weights=./output/picodet_lcnet_x2_5_layout/best_model \
--output_dir=output_inference/
```
### 6.2 模型推理
版面恢复任务进行推理,可以执行如下命令:
```bash
python3 deploy/python/infer.py \
--model_dir=output_inference/picodet_lcnet_x1_0_layout/ \
--image_file=docs/images/layout.jpg \
--device=CPU
```
- --device:指定GPU、CPU设备
模型推理完成,会看到以下log输出
```
------------------------------------------
----------- Model Configuration -----------
Model Arch: PicoDet
Transform Order:
--transform op: Resize
--transform op: NormalizeImage
--transform op: Permute
--transform op: PadStride
--------------------------------------------
class_id:0, confidence:0.9921, left_top:[20.18,35.66],right_bottom:[341.58,600.99]
class_id:0, confidence:0.9914, left_top:[19.77,611.42],right_bottom:[341.48,901.82]
class_id:0, confidence:0.9904, left_top:[369.36,375.10],right_bottom:[691.29,600.59]
class_id:0, confidence:0.9835, left_top:[369.60,608.60],right_bottom:[691.38,736.72]
class_id:0, confidence:0.9830, left_top:[369.58,805.38],right_bottom:[690.97,901.80]
class_id:0, confidence:0.9716, left_top:[383.68,271.44],right_bottom:[688.93,335.39]
class_id:0, confidence:0.9452, left_top:[370.82,34.48],right_bottom:[688.10,63.54]
class_id:1, confidence:0.8712, left_top:[370.84,771.03],right_bottom:[519.30,789.13]
class_id:3, confidence:0.9856, left_top:[371.28,67.85],right_bottom:[685.73,267.72]
save result to: output/layout.jpg
Test iter 0
------------------ Inference Time Info ----------------------
total_time(ms): 2196.0, img_num: 1
average latency time(ms): 2196.00, QPS: 0.455373
preprocess_time(ms): 2172.50, inference_time(ms): 11.90, postprocess_time(ms): 11.60
```
- Model:模型结构
- Transform Order:预处理操作
- class_id、confidence、left_top、right_bottom:分别表示类别id、置信度、左上角坐标、右下角坐标
- save result to:可视化版面分析结果保存路径,默认保存到`./output`文件夹
- Inference Time Info:推理时间,其中preprocess_time表示预处理耗时,inference_time表示模型预测耗时,postprocess_time表示后处理耗时
可视化版面结果如下图所示
<div align="center">
<img src="../docs/layout/layout_res.jpg" width="800">
</div>
## Citations
```
@inproceedings{zhong2019publaynet,
title={PubLayNet: largest dataset ever for document layout analysis},
author={Zhong, Xu and Tang, Jianbin and Yepes, Antonio Jimeno},
booktitle={2019 International Conference on Document Analysis and Recognition (ICDAR)},
year={2019},
volume={},
number={},
pages={1015-1022},
doi={10.1109/ICDAR.2019.00166},
ISSN={1520-5363},
month={Sep.},
organization={IEEE}
}
@inproceedings{yang2022focal,
title={Focal and global knowledge distillation for detectors},
author={Yang, Zhendong and Li, Zhe and Jiang, Xiaohu and Gong, Yuan and Yuan, Zehuan and Zhao, Danpei and Yuan, Chun},
booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition},
pages={4643--4652},
year={2022}
}
```
...@@ -28,13 +28,12 @@ import time ...@@ -28,13 +28,12 @@ import time
import logging import logging
from copy import deepcopy from copy import deepcopy
from ppocr.utils.utility import get_image_file_list, check_and_read_gif from ppocr.utils.utility import get_image_file_list, check_and_read
from ppocr.utils.logging import get_logger from ppocr.utils.logging import get_logger
from tools.infer.predict_system import TextSystem from tools.infer.predict_system import TextSystem
from ppstructure.layout.predict_layout import LayoutPredictor from ppstructure.layout.predict_layout import LayoutPredictor
from ppstructure.table.predict_table import TableSystem, to_excel from ppstructure.table.predict_table import TableSystem, to_excel
from ppstructure.utility import parse_args, draw_structure_result from ppstructure.utility import parse_args, draw_structure_result
from ppstructure.recovery.recovery_to_doc import convert_info_docx
logger = get_logger() logger = get_logger()
...@@ -78,7 +77,7 @@ class StructureSystem(object): ...@@ -78,7 +77,7 @@ class StructureSystem(object):
elif self.mode == 'vqa': elif self.mode == 'vqa':
raise NotImplementedError raise NotImplementedError
def __call__(self, img, return_ocr_result_in_table=False): def __call__(self, img, img_idx=0, return_ocr_result_in_table=False):
time_dict = { time_dict = {
'image_orientation': 0, 'image_orientation': 0,
'layout': 0, 'layout': 0,
...@@ -143,8 +142,8 @@ class StructureSystem(object): ...@@ -143,8 +142,8 @@ class StructureSystem(object):
time_dict['det'] += ocr_time_dict['det'] time_dict['det'] += ocr_time_dict['det']
time_dict['rec'] += ocr_time_dict['rec'] time_dict['rec'] += ocr_time_dict['rec']
# remove style char, # remove style char,
# when using the recognition model trained on the PubtabNet dataset, # when using the recognition model trained on the PubtabNet dataset,
# it will recognize the text format in the table, such as <b> # it will recognize the text format in the table, such as <b>
style_token = [ style_token = [
'<strike>', '<strike>', '<sup>', '</sub>', '<b>', '<strike>', '<strike>', '<sup>', '</sub>', '<b>',
...@@ -169,7 +168,8 @@ class StructureSystem(object): ...@@ -169,7 +168,8 @@ class StructureSystem(object):
'type': region['label'].lower(), 'type': region['label'].lower(),
'bbox': [x1, y1, x2, y2], 'bbox': [x1, y1, x2, y2],
'img': roi_img, 'img': roi_img,
'res': res 'res': res,
'img_idx': img_idx
}) })
end = time.time() end = time.time()
time_dict['all'] = end - start time_dict['all'] = end - start
...@@ -179,26 +179,29 @@ class StructureSystem(object): ...@@ -179,26 +179,29 @@ class StructureSystem(object):
return None, None return None, None
def save_structure_res(res, save_folder, img_name): def save_structure_res(res, save_folder, img_name, img_idx=0):
excel_save_folder = os.path.join(save_folder, img_name) excel_save_folder = os.path.join(save_folder, img_name)
os.makedirs(excel_save_folder, exist_ok=True) os.makedirs(excel_save_folder, exist_ok=True)
res_cp = deepcopy(res) res_cp = deepcopy(res)
# save res # save res
with open( with open(
os.path.join(excel_save_folder, 'res.txt'), 'w', os.path.join(excel_save_folder, 'res_{}.txt'.format(img_idx)),
'w',
encoding='utf8') as f: encoding='utf8') as f:
for region in res_cp: for region in res_cp:
roi_img = region.pop('img') roi_img = region.pop('img')
f.write('{}\n'.format(json.dumps(region))) f.write('{}\n'.format(json.dumps(region)))
if region['type'] == 'table' and len(region[ if region['type'].lower() == 'table' and len(region[
'res']) > 0 and 'html' in region['res']: 'res']) > 0 and 'html' in region['res']:
excel_path = os.path.join(excel_save_folder, excel_path = os.path.join(
'{}.xlsx'.format(region['bbox'])) excel_save_folder,
'{}_{}.xlsx'.format(region['bbox'], img_idx))
to_excel(region['res']['html'], excel_path) to_excel(region['res']['html'], excel_path)
elif region['type'] == 'figure': elif region['type'].lower() == 'figure':
img_path = os.path.join(excel_save_folder, img_path = os.path.join(
'{}.jpg'.format(region['bbox'])) excel_save_folder,
'{}_{}.jpg'.format(region['bbox'], img_idx))
cv2.imwrite(img_path, roi_img) cv2.imwrite(img_path, roi_img)
...@@ -214,28 +217,75 @@ def main(args): ...@@ -214,28 +217,75 @@ def main(args):
for i, image_file in enumerate(image_file_list): for i, image_file in enumerate(image_file_list):
logger.info("[{}/{}] {}".format(i, img_num, image_file)) logger.info("[{}/{}] {}".format(i, img_num, image_file))
img, flag = check_and_read_gif(image_file) img, flag_gif, flag_pdf = check_and_read(image_file)
img_name = os.path.basename(image_file).split('.')[0] img_name = os.path.basename(image_file).split('.')[0]
if not flag: if not flag_gif and not flag_pdf:
img = cv2.imread(image_file) img = cv2.imread(image_file)
if img is None:
logger.error("error in loading image:{}".format(image_file))
continue
res, time_dict = structure_sys(img)
if structure_sys.mode == 'structure': if not flag_pdf:
save_structure_res(res, save_folder, img_name) if img is None:
draw_img = draw_structure_result(img, res, args.vis_font_path) logger.error("error in loading image:{}".format(image_file))
img_save_path = os.path.join(save_folder, img_name, 'show.jpg') continue
elif structure_sys.mode == 'vqa': res, time_dict = structure_sys(img)
raise NotImplementedError
# draw_img = draw_ser_results(img, res, args.vis_font_path) if structure_sys.mode == 'structure':
# img_save_path = os.path.join(save_folder, img_name + '.jpg') save_structure_res(res, save_folder, img_name)
cv2.imwrite(img_save_path, draw_img) draw_img = draw_structure_result(img, res, args.vis_font_path)
logger.info('result save to {}'.format(img_save_path)) img_save_path = os.path.join(save_folder, img_name, 'show.jpg')
if args.recovery: elif structure_sys.mode == 'vqa':
convert_info_docx(img, res, save_folder, img_name) raise NotImplementedError
# draw_img = draw_ser_results(img, res, args.vis_font_path)
# img_save_path = os.path.join(save_folder, img_name + '.jpg')
cv2.imwrite(img_save_path, draw_img)
logger.info('result save to {}'.format(img_save_path))
if args.recovery:
try:
from ppstructure.recovery.recovery_to_doc import sorted_layout_boxes, convert_info_docx
h, w, _ = img.shape
res = sorted_layout_boxes(res, w)
convert_info_docx(img, res, save_folder, img_name,
args.save_pdf)
except Exception as ex:
logger.error(
"error in layout recovery image:{}, err msg: {}".format(
image_file, ex))
continue
else:
pdf_imgs = img
all_res = []
for index, img in enumerate(pdf_imgs):
res, time_dict = structure_sys(img, index)
if structure_sys.mode == 'structure' and res != []:
save_structure_res(res, save_folder, img_name, index)
draw_img = draw_structure_result(img, res,
args.vis_font_path)
img_save_path = os.path.join(save_folder, img_name,
'show_{}.jpg'.format(index))
elif structure_sys.mode == 'vqa':
raise NotImplementedError
# draw_img = draw_ser_results(img, res, args.vis_font_path)
# img_save_path = os.path.join(save_folder, img_name + '.jpg')
if res != []:
cv2.imwrite(img_save_path, draw_img)
logger.info('result save to {}'.format(img_save_path))
if args.recovery and res != []:
from ppstructure.recovery.recovery_to_doc import sorted_layout_boxes, convert_info_docx
h, w, _ = img.shape
res = sorted_layout_boxes(res, w)
all_res += res
if args.recovery and all_res != []:
try:
convert_info_docx(img, all_res, save_folder, img_name,
args.save_pdf)
except Exception as ex:
logger.error(
"error in layout recovery image:{}, err msg: {}".format(
image_file, ex))
continue
logger.info("Predict time : {:.3f}s".format(time_dict['all'])) logger.info("Predict time : {:.3f}s".format(time_dict['all']))
......
...@@ -78,9 +78,27 @@ wget https://paddleocr.bj.bcebos.com/PP-OCRv3/chinese/ch_PP-OCRv3_det_infer.tar ...@@ -78,9 +78,27 @@ wget https://paddleocr.bj.bcebos.com/PP-OCRv3/chinese/ch_PP-OCRv3_det_infer.tar
wget https://paddleocr.bj.bcebos.com/PP-OCRv3/chinese/ch_PP-OCRv3_rec_infer.tar && tar xf ch_PP-OCRv3_rec_infer.tar wget https://paddleocr.bj.bcebos.com/PP-OCRv3/chinese/ch_PP-OCRv3_rec_infer.tar && tar xf ch_PP-OCRv3_rec_infer.tar
# Download the ultra-lightweight English table inch model and unzip it # Download the ultra-lightweight English table inch model and unzip it
wget https://paddleocr.bj.bcebos.com/dygraph_v2.0/table/en_ppocr_mobile_v2.0_table_structure_infer.tar && tar xf en_ppocr_mobile_v2.0_table_structure_infer.tar wget https://paddleocr.bj.bcebos.com/dygraph_v2.0/table/en_ppocr_mobile_v2.0_table_structure_infer.tar && tar xf en_ppocr_mobile_v2.0_table_structure_infer.tar
# Download the layout model of publaynet dataset and unzip it
wget
https://paddleocr.bj.bcebos.com/ppstructure/models/layout/picodet_lcnet_x1_0_layout_infer.tar && tar picodet_lcnet_x1_0_layout_infer.tar
cd .. cd ..
# run # run
python3 predict_system.py --det_model_dir=inference/en_PP-OCRv3_det_infer --rec_model_dir=inference/en_PP-OCRv3_rec_infer --table_model_dir=inference/en_ppocr_mobile_v2.0_table_structure_infer --rec_char_dict_path=../ppocr/utils/en_dict.txt --table_char_dict_path=../ppocr/utils/dict/table_structure_dict.txt --output ./output/table --rec_image_shape=3,48,320 --vis_font_path=../doc/fonts/simfang.ttf --recovery=True --image_dir=./docs/table/1.png python3 predict_system.py \
--image_dir=./docs/table/1.png \
--det_model_dir=inference/en_PP-OCRv3_det_infer \
--rec_model_dir=inference/en_PP-OCRv3_rec_infe \
--rec_char_dict_path=../ppocr/utils/en_dict.txt \
--output=../output/ \
--table_model_dir=inference/ch_ppstructure_mobile_v2.0_SLANet_infer \
--table_char_dict_path=../ppocr/utils/dict/table_structure_dict.txt \
--table_max_len=488 \
--layout_model_dir=inference/picodet_lcnet_x1_0_layout_infer \
--layout_dict_path=../ppocr/utils/dict/layout_dict/layout_publaynet_dict.txt \
--vis_font_path=../doc/fonts/simfang.ttf \
--recovery=True \
--save_pdf=False
``` ```
After running, the docx of each picture will be saved in the directory specified by the output field After running, the docx of each picture will be saved in the directory specified by the output field
\ No newline at end of file
Recovery table to Word code[table_process.py] reference:https://github.com/pqzx/html2docx.git
\ No newline at end of file
...@@ -35,21 +35,15 @@ ...@@ -35,21 +35,15 @@
python3 -m pip install --upgrade pip python3 -m pip install --upgrade pip
# GPU安装 # GPU安装
python3 -m pip install "paddlepaddle-gpu>=2.2" -i https://mirror.baidu.com/pypi/simple python3 -m pip install "paddlepaddle-gpu>=2.3" -i https://mirror.baidu.com/pypi/simple
# CPU安装 # CPU安装
python3 -m pip install "paddlepaddle>=2.2" -i https://mirror.baidu.com/pypi/simple python3 -m pip install "paddlepaddle>=2.3" -i https://mirror.baidu.com/pypi/simple
``` ```
更多需求,请参照[安装文档](https://www.paddlepaddle.org.cn/install/quick)中的说明进行操作。 更多需求,请参照[安装文档](https://www.paddlepaddle.org.cn/install/quick)中的说明进行操作。
* **(2)安装依赖**
```bash
python3 -m pip install -r ppstructure/recovery/requirements.txt
```
<a name="2.2"></a> <a name="2.2"></a>
### 2.2 安装PaddleOCR ### 2.2 安装PaddleOCR
...@@ -87,11 +81,28 @@ wget https://paddleocr.bj.bcebos.com/PP-OCRv3/chinese/ch_PP-OCRv3_det_infer.tar ...@@ -87,11 +81,28 @@ wget https://paddleocr.bj.bcebos.com/PP-OCRv3/chinese/ch_PP-OCRv3_det_infer.tar
# 下载英文轻量级PP-OCRv3模型的识别模型并解压 # 下载英文轻量级PP-OCRv3模型的识别模型并解压
wget https://paddleocr.bj.bcebos.com/PP-OCRv3/chinese/ch_PP-OCRv3_rec_infer.tar && tar xf ch_PP-OCRv3_rec_infer.tar wget https://paddleocr.bj.bcebos.com/PP-OCRv3/chinese/ch_PP-OCRv3_rec_infer.tar && tar xf ch_PP-OCRv3_rec_infer.tar
# 下载超轻量级英文表格英寸模型并解压 # 下载超轻量级英文表格英寸模型并解压
wget https://paddleocr.bj.bcebos.com/dygraph_v2.0/table/en_ppocr_mobile_v2.0_table_structure_infer.tar && tar xf en_ppocr_mobile_v2.0_table_structure_infer.tar wget https://paddleocr.bj.bcebos.com/ppstructure/models/slanet/ch_ppstructure_mobile_v2.0_SLANet_infer.tar && tar xf ch_ppstructure_mobile_v2.0_SLANet_infer.tar
# 下载英文版面分析模型
wget https://paddleocr.bj.bcebos.com/ppstructure/models/layout/picodet_lcnet_x1_0_layout_infer.tar && tar picodet_lcnet_x1_0_layout_infer.tar
cd .. cd ..
# 执行预测 # 执行预测
python3 predict_system.py --det_model_dir=inference/en_PP-OCRv3_det_infer --rec_model_dir=inference/en_PP-OCRv3_rec_infer --table_model_dir=inference/en_ppocr_mobile_v2.0_table_structure_infer --rec_char_dict_path=../ppocr/utils/en_dict.txt --table_char_dict_path=../ppocr/utils/dict/table_structure_dict.txt --output ./output/table --rec_image_shape=3,48,320 --vis_font_path=../doc/fonts/simfang.ttf --recovery=True --image_dir=./docs/table/1.png python3 predict_system.py \
--image_dir=./docs/table/1.png \
--det_model_dir=inference/en_PP-OCRv3_det_infer \
--rec_model_dir=inference/en_PP-OCRv3_rec_infe \
--rec_char_dict_path=../ppocr/utils/en_dict.txt \
--output=../output/ \
--table_model_dir=inference/ch_ppstructure_mobile_v2.0_SLANet_infer \
--table_char_dict_path=../ppocr/utils/dict/table_structure_dict.txt \
--table_max_len=488 \
--layout_model_dir=inference/picodet_lcnet_x1_0_layout_infer \
--layout_dict_path=../ppocr/utils/dict/layout_dict/layout_publaynet_dict.txt \
--vis_font_path=../doc/fonts/simfang.ttf \
--recovery=True \
--save_pdf=False
``` ```
运行完成后,每张图片的docx文档会保存到output字段指定的目录下 运行完成后,每张图片的docx文档会保存到`output`字段指定的目录下
表格恢复到Word代码[table_process.py]来自:https://github.com/pqzx/html2docx.git
...@@ -22,21 +22,23 @@ from docx import shared ...@@ -22,21 +22,23 @@ from docx import shared
from docx.enum.text import WD_ALIGN_PARAGRAPH from docx.enum.text import WD_ALIGN_PARAGRAPH
from docx.enum.section import WD_SECTION from docx.enum.section import WD_SECTION
from docx.oxml.ns import qn from docx.oxml.ns import qn
from docx.enum.table import WD_TABLE_ALIGNMENT
from table_process import HtmlToDocx
from ppocr.utils.logging import get_logger from ppocr.utils.logging import get_logger
logger = get_logger() logger = get_logger()
def convert_info_docx(img, res, save_folder, img_name): def convert_info_docx(img, res, save_folder, img_name, save_pdf):
doc = Document() doc = Document()
doc.styles['Normal'].font.name = 'Times New Roman' doc.styles['Normal'].font.name = 'Times New Roman'
doc.styles['Normal']._element.rPr.rFonts.set(qn('w:eastAsia'), u'宋体') doc.styles['Normal']._element.rPr.rFonts.set(qn('w:eastAsia'), u'宋体')
doc.styles['Normal'].font.size = shared.Pt(6.5) doc.styles['Normal'].font.size = shared.Pt(6.5)
h, w, _ = img.shape
res = sorted_layout_boxes(res, w)
flag = 1 flag = 1
for i, region in enumerate(res): for i, region in enumerate(res):
img_idx = region['img_idx']
if flag == 2 and region['layout'] == 'single': if flag == 2 and region['layout'] == 'single':
section = doc.add_section(WD_SECTION.CONTINUOUS) section = doc.add_section(WD_SECTION.CONTINUOUS)
section._sectPr.xpath('./w:cols')[0].set(qn('w:num'), '1') section._sectPr.xpath('./w:cols')[0].set(qn('w:num'), '1')
...@@ -46,10 +48,10 @@ def convert_info_docx(img, res, save_folder, img_name): ...@@ -46,10 +48,10 @@ def convert_info_docx(img, res, save_folder, img_name):
section._sectPr.xpath('./w:cols')[0].set(qn('w:num'), '2') section._sectPr.xpath('./w:cols')[0].set(qn('w:num'), '2')
flag = 2 flag = 2
if region['type'] == 'Figure': if region['type'].lower() == 'figure':
excel_save_folder = os.path.join(save_folder, img_name) excel_save_folder = os.path.join(save_folder, img_name)
img_path = os.path.join(excel_save_folder, img_path = os.path.join(excel_save_folder,
'{}.jpg'.format(region['bbox'])) '{}_{}.jpg'.format(region['bbox'], img_idx))
paragraph_pic = doc.add_paragraph() paragraph_pic = doc.add_paragraph()
paragraph_pic.alignment = WD_ALIGN_PARAGRAPH.CENTER paragraph_pic.alignment = WD_ALIGN_PARAGRAPH.CENTER
run = paragraph_pic.add_run("") run = paragraph_pic.add_run("")
...@@ -57,40 +59,38 @@ def convert_info_docx(img, res, save_folder, img_name): ...@@ -57,40 +59,38 @@ def convert_info_docx(img, res, save_folder, img_name):
run.add_picture(img_path, width=shared.Inches(5)) run.add_picture(img_path, width=shared.Inches(5))
elif flag == 2: elif flag == 2:
run.add_picture(img_path, width=shared.Inches(2)) run.add_picture(img_path, width=shared.Inches(2))
elif region['type'] == 'Title': elif region['type'].lower() == 'title':
doc.add_heading(region['res'][0]['text']) doc.add_heading(region['res'][0]['text'])
elif region['type'] == 'Text': elif region['type'].lower() == 'table':
paragraph = doc.add_paragraph()
new_parser = HtmlToDocx()
new_parser.table_style = 'TableGrid'
table = new_parser.handle_table(html=region['res']['html'])
new_table = deepcopy(table)
new_table.alignment = WD_TABLE_ALIGNMENT.CENTER
paragraph.add_run().element.addnext(new_table._tbl)
else:
paragraph = doc.add_paragraph() paragraph = doc.add_paragraph()
paragraph_format = paragraph.paragraph_format paragraph_format = paragraph.paragraph_format
for i, line in enumerate(region['res']): for i, line in enumerate(region['res']):
if i == 0: if i == 0:
paragraph_format.first_line_indent = shared.Inches(0.25) paragraph_format.first_line_indent = shared.Inches(0.25)
text_run = paragraph.add_run(line['text'] + ' ') text_run = paragraph.add_run(line['text'] + ' ')
text_run.font.size = shared.Pt(9) text_run.font.size = shared.Pt(10)
elif region['type'] == 'Table':
pypandoc.convert(
source=region['res']['html'],
format='html',
to='docx',
outputfile='tmp.docx')
tmp_doc = Document('tmp.docx')
paragraph = doc.add_paragraph()
table = tmp_doc.tables[0]
new_table = deepcopy(table)
new_table.style = doc.styles['Table Grid']
from docx.enum.table import WD_TABLE_ALIGNMENT
new_table.alignment = WD_TABLE_ALIGNMENT.CENTER
paragraph.add_run().element.addnext(new_table._tbl)
os.remove('tmp.docx')
else:
continue
# save to docx # save to docx
docx_path = os.path.join(save_folder, '{}.docx'.format(img_name)) docx_path = os.path.join(save_folder, '{}.docx'.format(img_name))
doc.save(docx_path) doc.save(docx_path)
logger.info('docx save to {}'.format(docx_path)) logger.info('docx save to {}'.format(docx_path))
# save to pdf
if save_pdf:
pdf = os.path.join(save_folder, '{}.pdf'.format(img_name))
from docx2pdf import convert
convert(docx_path, pdf_path)
logger.info('pdf save to {}'.format(pdf))
def sorted_layout_boxes(res, w): def sorted_layout_boxes(res, w):
""" """
......
opencv-contrib-python==4.4.0.46
pypandoc pypandoc
python-docx python-docx
\ No newline at end of file docx2pdf
fitz
PyMuPDF
\ No newline at end of file
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This code is refer from:https://github.com/pqzx/html2docx/blob/8f6695a778c68befb302e48ac0ed5201ddbd4524/htmldocx/h2d.py
"""
import re, argparse
import io, os
import urllib.request
from urllib.parse import urlparse
from html.parser import HTMLParser
import docx, docx.table
from docx import Document
from docx.shared import RGBColor, Pt, Inches
from docx.enum.text import WD_COLOR, WD_ALIGN_PARAGRAPH
from docx.oxml import OxmlElement
from docx.oxml.ns import qn
from bs4 import BeautifulSoup
# values in inches
INDENT = 0.25
LIST_INDENT = 0.5
MAX_INDENT = 5.5 # To stop indents going off the page
# Style to use with tables. By default no style is used.
DEFAULT_TABLE_STYLE = None
# Style to use with paragraphs. By default no style is used.
DEFAULT_PARAGRAPH_STYLE = None
def get_filename_from_url(url):
return os.path.basename(urlparse(url).path)
def is_url(url):
"""
Not to be used for actually validating a url, but in our use case we only
care if it's a url or a file path, and they're pretty distinguishable
"""
parts = urlparse(url)
return all([parts.scheme, parts.netloc, parts.path])
def fetch_image(url):
"""
Attempts to fetch an image from a url.
If successful returns a bytes object, else returns None
:return:
"""
try:
with urllib.request.urlopen(url) as response:
# security flaw?
return io.BytesIO(response.read())
except urllib.error.URLError:
return None
def remove_last_occurence(ls, x):
ls.pop(len(ls) - ls[::-1].index(x) - 1)
def remove_whitespace(string, leading=False, trailing=False):
"""Remove white space from a string.
Args:
string(str): The string to remove white space from.
leading(bool, optional): Remove leading new lines when True.
trailing(bool, optional): Remove trailing new lines when False.
Returns:
str: The input string with new line characters removed and white space squashed.
Examples:
Single or multiple new line characters are replaced with space.
>>> remove_whitespace("abc\\ndef")
'abc def'
>>> remove_whitespace("abc\\n\\n\\ndef")
'abc def'
New line characters surrounded by white space are replaced with a single space.
>>> remove_whitespace("abc \\n \\n \\n def")
'abc def'
>>> remove_whitespace("abc \\n \\n \\n def")
'abc def'
Leading and trailing new lines are replaced with a single space.
>>> remove_whitespace("\\nabc")
' abc'
>>> remove_whitespace(" \\n abc")
' abc'
>>> remove_whitespace("abc\\n")
'abc '
>>> remove_whitespace("abc \\n ")
'abc '
Use ``leading=True`` to remove leading new line characters, including any surrounding
white space:
>>> remove_whitespace("\\nabc", leading=True)
'abc'
>>> remove_whitespace(" \\n abc", leading=True)
'abc'
Use ``trailing=True`` to remove trailing new line characters, including any surrounding
white space:
>>> remove_whitespace("abc \\n ", trailing=True)
'abc'
"""
# Remove any leading new line characters along with any surrounding white space
if leading:
string = re.sub(r'^\s*\n+\s*', '', string)
# Remove any trailing new line characters along with any surrounding white space
if trailing:
string = re.sub(r'\s*\n+\s*$', '', string)
# Replace new line characters and absorb any surrounding space.
string = re.sub(r'\s*\n\s*', ' ', string)
# TODO need some way to get rid of extra spaces in e.g. text <span> </span> text
return re.sub(r'\s+', ' ', string)
def delete_paragraph(paragraph):
# https://github.com/python-openxml/python-docx/issues/33#issuecomment-77661907
p = paragraph._element
p.getparent().remove(p)
p._p = p._element = None
font_styles = {
'b': 'bold',
'strong': 'bold',
'em': 'italic',
'i': 'italic',
'u': 'underline',
's': 'strike',
'sup': 'superscript',
'sub': 'subscript',
'th': 'bold',
}
font_names = {
'code': 'Courier',
'pre': 'Courier',
}
styles = {
'LIST_BULLET': 'List Bullet',
'LIST_NUMBER': 'List Number',
}
class HtmlToDocx(HTMLParser):
def __init__(self):
super().__init__()
self.options = {
'fix-html': True,
'images': True,
'tables': True,
'styles': True,
}
self.table_row_selectors = [
'table > tr',
'table > thead > tr',
'table > tbody > tr',
'table > tfoot > tr'
]
self.table_style = DEFAULT_TABLE_STYLE
self.paragraph_style = DEFAULT_PARAGRAPH_STYLE
def set_initial_attrs(self, document=None):
self.tags = {
'span': [],
'list': [],
}
if document:
self.doc = document
else:
self.doc = Document()
self.bs = self.options['fix-html'] # whether or not to clean with BeautifulSoup
self.document = self.doc
self.include_tables = True #TODO add this option back in?
self.include_images = self.options['images']
self.include_styles = self.options['styles']
self.paragraph = None
self.skip = False
self.skip_tag = None
self.instances_to_skip = 0
def copy_settings_from(self, other):
"""Copy settings from another instance of HtmlToDocx"""
self.table_style = other.table_style
self.paragraph_style = other.paragraph_style
def get_cell_html(self, soup):
# Returns string of td element with opening and closing <td> tags removed
# Cannot use find_all as it only finds element tags and does not find text which
# is not inside an element
return ' '.join([str(i) for i in soup.contents])
def add_styles_to_paragraph(self, style):
if 'text-align' in style:
align = style['text-align']
if align == 'center':
self.paragraph.paragraph_format.alignment = WD_ALIGN_PARAGRAPH.CENTER
elif align == 'right':
self.paragraph.paragraph_format.alignment = WD_ALIGN_PARAGRAPH.RIGHT
elif align == 'justify':
self.paragraph.paragraph_format.alignment = WD_ALIGN_PARAGRAPH.JUSTIFY
if 'margin-left' in style:
margin = style['margin-left']
units = re.sub(r'[0-9]+', '', margin)
margin = int(float(re.sub(r'[a-z]+', '', margin)))
if units == 'px':
self.paragraph.paragraph_format.left_indent = Inches(min(margin // 10 * INDENT, MAX_INDENT))
# TODO handle non px units
def add_styles_to_run(self, style):
if 'color' in style:
if 'rgb' in style['color']:
color = re.sub(r'[a-z()]+', '', style['color'])
colors = [int(x) for x in color.split(',')]
elif '#' in style['color']:
color = style['color'].lstrip('#')
colors = tuple(int(color[i:i+2], 16) for i in (0, 2, 4))
else:
colors = [0, 0, 0]
# TODO map colors to named colors (and extended colors...)
# For now set color to black to prevent crashing
self.run.font.color.rgb = RGBColor(*colors)
if 'background-color' in style:
if 'rgb' in style['background-color']:
color = color = re.sub(r'[a-z()]+', '', style['background-color'])
colors = [int(x) for x in color.split(',')]
elif '#' in style['background-color']:
color = style['background-color'].lstrip('#')
colors = tuple(int(color[i:i+2], 16) for i in (0, 2, 4))
else:
colors = [0, 0, 0]
# TODO map colors to named colors (and extended colors...)
# For now set color to black to prevent crashing
self.run.font.highlight_color = WD_COLOR.GRAY_25 #TODO: map colors
def apply_paragraph_style(self, style=None):
try:
if style:
self.paragraph.style = style
elif self.paragraph_style:
self.paragraph.style = self.paragraph_style
except KeyError as e:
raise ValueError(f"Unable to apply style {self.paragraph_style}.") from e
def parse_dict_string(self, string, separator=';'):
new_string = string.replace(" ", '').split(separator)
string_dict = dict([x.split(':') for x in new_string if ':' in x])
return string_dict
def handle_li(self):
# check list stack to determine style and depth
list_depth = len(self.tags['list'])
if list_depth:
list_type = self.tags['list'][-1]
else:
list_type = 'ul' # assign unordered if no tag
if list_type == 'ol':
list_style = styles['LIST_NUMBER']
else:
list_style = styles['LIST_BULLET']
self.paragraph = self.doc.add_paragraph(style=list_style)
self.paragraph.paragraph_format.left_indent = Inches(min(list_depth * LIST_INDENT, MAX_INDENT))
self.paragraph.paragraph_format.line_spacing = 1
def add_image_to_cell(self, cell, image):
# python-docx doesn't have method yet for adding images to table cells. For now we use this
paragraph = cell.add_paragraph()
run = paragraph.add_run()
run.add_picture(image)
def handle_img(self, current_attrs):
if not self.include_images:
self.skip = True
self.skip_tag = 'img'
return
src = current_attrs['src']
# fetch image
src_is_url = is_url(src)
if src_is_url:
try:
image = fetch_image(src)
except urllib.error.URLError:
image = None
else:
image = src
# add image to doc
if image:
try:
if isinstance(self.doc, docx.document.Document):
self.doc.add_picture(image)
else:
self.add_image_to_cell(self.doc, image)
except FileNotFoundError:
image = None
if not image:
if src_is_url:
self.doc.add_paragraph("<image: %s>" % src)
else:
# avoid exposing filepaths in document
self.doc.add_paragraph("<image: %s>" % get_filename_from_url(src))
def handle_table(self, html):
"""
To handle nested tables, we will parse tables manually as follows:
Get table soup
Create docx table
Iterate over soup and fill docx table with new instances of this parser
Tell HTMLParser to ignore any tags until the corresponding closing table tag
"""
doc = Document()
table_soup = BeautifulSoup(html, 'html.parser')
rows, cols_len = self.get_table_dimensions(table_soup)
table = doc.add_table(len(rows), cols_len)
table.style = doc.styles['Table Grid']
cell_row = 0
for index, row in enumerate(rows):
cols = self.get_table_columns(row)
cell_col = 0
for col in cols:
colspan = int(col.attrs.get('colspan', 1))
rowspan = int(col.attrs.get('rowspan', 1))
cell_html = self.get_cell_html(col)
if col.name == 'th':
cell_html = "<b>%s</b>" % cell_html
docx_cell = table.cell(cell_row, cell_col)
while docx_cell.text != '': # Skip the merged cell
cell_col += 1
docx_cell = table.cell(cell_row, cell_col)
cell_to_merge = table.cell(cell_row + rowspan - 1, cell_col + colspan - 1)
if docx_cell != cell_to_merge:
docx_cell.merge(cell_to_merge)
child_parser = HtmlToDocx()
child_parser.copy_settings_from(self)
child_parser.add_html_to_cell(cell_html or ' ', docx_cell) # occupy the position
cell_col += colspan
cell_row += 1
# skip all tags until corresponding closing tag
self.instances_to_skip = len(table_soup.find_all('table'))
self.skip_tag = 'table'
self.skip = True
self.table = None
return table
def handle_link(self, href, text):
# Link requires a relationship
is_external = href.startswith('http')
rel_id = self.paragraph.part.relate_to(
href,
docx.opc.constants.RELATIONSHIP_TYPE.HYPERLINK,
is_external=True # don't support anchor links for this library yet
)
# Create the w:hyperlink tag and add needed values
hyperlink = docx.oxml.shared.OxmlElement('w:hyperlink')
hyperlink.set(docx.oxml.shared.qn('r:id'), rel_id)
# Create sub-run
subrun = self.paragraph.add_run()
rPr = docx.oxml.shared.OxmlElement('w:rPr')
# add default color
c = docx.oxml.shared.OxmlElement('w:color')
c.set(docx.oxml.shared.qn('w:val'), "0000EE")
rPr.append(c)
# add underline
u = docx.oxml.shared.OxmlElement('w:u')
u.set(docx.oxml.shared.qn('w:val'), 'single')
rPr.append(u)
subrun._r.append(rPr)
subrun._r.text = text
# Add subrun to hyperlink
hyperlink.append(subrun._r)
# Add hyperlink to run
self.paragraph._p.append(hyperlink)
def handle_starttag(self, tag, attrs):
if self.skip:
return
if tag == 'head':
self.skip = True
self.skip_tag = tag
self.instances_to_skip = 0
return
elif tag == 'body':
return
current_attrs = dict(attrs)
if tag == 'span':
self.tags['span'].append(current_attrs)
return
elif tag == 'ol' or tag == 'ul':
self.tags['list'].append(tag)
return # don't apply styles for now
elif tag == 'br':
self.run.add_break()
return
self.tags[tag] = current_attrs
if tag in ['p', 'pre']:
self.paragraph = self.doc.add_paragraph()
self.apply_paragraph_style()
elif tag == 'li':
self.handle_li()
elif tag == "hr":
# This implementation was taken from:
# https://github.com/python-openxml/python-docx/issues/105#issuecomment-62806373
self.paragraph = self.doc.add_paragraph()
pPr = self.paragraph._p.get_or_add_pPr()
pBdr = OxmlElement('w:pBdr')
pPr.insert_element_before(pBdr,
'w:shd', 'w:tabs', 'w:suppressAutoHyphens', 'w:kinsoku', 'w:wordWrap',
'w:overflowPunct', 'w:topLinePunct', 'w:autoSpaceDE', 'w:autoSpaceDN',
'w:bidi', 'w:adjustRightInd', 'w:snapToGrid', 'w:spacing', 'w:ind',
'w:contextualSpacing', 'w:mirrorIndents', 'w:suppressOverlap', 'w:jc',
'w:textDirection', 'w:textAlignment', 'w:textboxTightWrap',
'w:outlineLvl', 'w:divId', 'w:cnfStyle', 'w:rPr', 'w:sectPr',
'w:pPrChange'
)
bottom = OxmlElement('w:bottom')
bottom.set(qn('w:val'), 'single')
bottom.set(qn('w:sz'), '6')
bottom.set(qn('w:space'), '1')
bottom.set(qn('w:color'), 'auto')
pBdr.append(bottom)
elif re.match('h[1-9]', tag):
if isinstance(self.doc, docx.document.Document):
h_size = int(tag[1])
self.paragraph = self.doc.add_heading(level=min(h_size, 9))
else:
self.paragraph = self.doc.add_paragraph()
elif tag == 'img':
self.handle_img(current_attrs)
return
elif tag == 'table':
self.handle_table()
return
# set new run reference point in case of leading line breaks
if tag in ['p', 'li', 'pre']:
self.run = self.paragraph.add_run()
# add style
if not self.include_styles:
return
if 'style' in current_attrs and self.paragraph:
style = self.parse_dict_string(current_attrs['style'])
self.add_styles_to_paragraph(style)
def handle_endtag(self, tag):
if self.skip:
if not tag == self.skip_tag:
return
if self.instances_to_skip > 0:
self.instances_to_skip -= 1
return
self.skip = False
self.skip_tag = None
self.paragraph = None
if tag == 'span':
if self.tags['span']:
self.tags['span'].pop()
return
elif tag == 'ol' or tag == 'ul':
remove_last_occurence(self.tags['list'], tag)
return
elif tag == 'table':
self.table_no += 1
self.table = None
self.doc = self.document
self.paragraph = None
if tag in self.tags:
self.tags.pop(tag)
# maybe set relevant reference to None?
def handle_data(self, data):
if self.skip:
return
# Only remove white space if we're not in a pre block.
if 'pre' not in self.tags:
# remove leading and trailing whitespace in all instances
data = remove_whitespace(data, True, True)
if not self.paragraph:
self.paragraph = self.doc.add_paragraph()
self.apply_paragraph_style()
# There can only be one nested link in a valid html document
# You cannot have interactive content in an A tag, this includes links
# https://html.spec.whatwg.org/#interactive-content
link = self.tags.get('a')
if link:
self.handle_link(link['href'], data)
else:
# If there's a link, dont put the data directly in the run
self.run = self.paragraph.add_run(data)
spans = self.tags['span']
for span in spans:
if 'style' in span:
style = self.parse_dict_string(span['style'])
self.add_styles_to_run(style)
# add font style and name
for tag in self.tags:
if tag in font_styles:
font_style = font_styles[tag]
setattr(self.run.font, font_style, True)
if tag in font_names:
font_name = font_names[tag]
self.run.font.name = font_name
def ignore_nested_tables(self, tables_soup):
"""
Returns array containing only the highest level tables
Operates on the assumption that bs4 returns child elements immediately after
the parent element in `find_all`. If this changes in the future, this method will need to be updated
:return:
"""
new_tables = []
nest = 0
for table in tables_soup:
if nest:
nest -= 1
continue
new_tables.append(table)
nest = len(table.find_all('table'))
return new_tables
def get_table_rows(self, table_soup):
# If there's a header, body, footer or direct child tr tags, add row dimensions from there
return table_soup.select(', '.join(self.table_row_selectors), recursive=False)
def get_table_columns(self, row):
# Get all columns for the specified row tag.
return row.find_all(['th', 'td'], recursive=False) if row else []
def get_table_dimensions(self, table_soup):
# Get rows for the table
rows = self.get_table_rows(table_soup)
# Table is either empty or has non-direct children between table and tr tags
# Thus the row dimensions and column dimensions are assumed to be 0
cols = self.get_table_columns(rows[0]) if rows else []
# Add colspan calculation column number
col_count = 0
for col in cols:
colspan = col.attrs.get('colspan', 1)
col_count += int(colspan)
# return len(rows), col_count
return rows, col_count
def get_tables(self):
if not hasattr(self, 'soup'):
self.include_tables = False
return
# find other way to do it, or require this dependency?
self.tables = self.ignore_nested_tables(self.soup.find_all('table'))
self.table_no = 0
def run_process(self, html):
if self.bs and BeautifulSoup:
self.soup = BeautifulSoup(html, 'html.parser')
html = str(self.soup)
if self.include_tables:
self.get_tables()
self.feed(html)
def add_html_to_document(self, html, document):
if not isinstance(html, str):
raise ValueError('First argument needs to be a %s' % str)
elif not isinstance(document, docx.document.Document) and not isinstance(document, docx.table._Cell):
raise ValueError('Second argument needs to be a %s' % docx.document.Document)
self.set_initial_attrs(document)
self.run_process(html)
def add_html_to_cell(self, html, cell):
self.set_initial_attrs(cell)
self.run_process(html)
def parse_html_file(self, filename_html, filename_docx=None):
with open(filename_html, 'r') as infile:
html = infile.read()
self.set_initial_attrs()
self.run_process(html)
if not filename_docx:
path, filename = os.path.split(filename_html)
filename_docx = '%s/new_docx_file_%s' % (path, filename)
self.doc.save('%s.docx' % filename_docx)
def parse_html_string(self, html):
self.set_initial_attrs()
self.run_process(html)
return self.doc
\ No newline at end of file
...@@ -89,6 +89,11 @@ def init_args(): ...@@ -89,6 +89,11 @@ def init_args():
type=bool, type=bool,
default=False, default=False,
help='Whether to enable layout of recovery') help='Whether to enable layout of recovery')
parser.add_argument(
"--save_pdf",
type=bool,
default=False,
help='Whether to save pdf file')
return parser return parser
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册