未验证 提交 7c812814 编写于 作者: Z zhoujun 提交者: GitHub

Merge pull request #6763 from WenmuZhou/xlm

[VQA] Add support for dygraph2static  for the ser model of the LayoutLM series in VQA
...@@ -11,7 +11,7 @@ Global: ...@@ -11,7 +11,7 @@ Global:
save_inference_dir: save_inference_dir:
use_visualdl: False use_visualdl: False
seed: 2048 seed: 2048
infer_img: doc/vqa/input/zh_val_21.jpg infer_img: ppstructure/docs/vqa/input/zh_val_21.jpg
save_res_path: ./output/re/ save_res_path: ./output/re/
Architecture: Architecture:
...@@ -21,7 +21,7 @@ Architecture: ...@@ -21,7 +21,7 @@ Architecture:
Backbone: Backbone:
name: LayoutLMv2ForRe name: LayoutLMv2ForRe
pretrained: True pretrained: True
checkpoints: checkpoints:
Loss: Loss:
name: LossFromOutput name: LossFromOutput
...@@ -52,7 +52,7 @@ Train: ...@@ -52,7 +52,7 @@ Train:
name: SimpleDataSet name: SimpleDataSet
data_dir: train_data/XFUND/zh_train/image data_dir: train_data/XFUND/zh_train/image
label_file_list: label_file_list:
- train_data/XFUND/zh_train/xfun_normalize_train.json - train_data/XFUND/zh_train/train.json
ratio_list: [ 1.0 ] ratio_list: [ 1.0 ]
transforms: transforms:
- DecodeImage: # load image - DecodeImage: # load image
...@@ -61,7 +61,7 @@ Train: ...@@ -61,7 +61,7 @@ Train:
- VQATokenLabelEncode: # Class handling label - VQATokenLabelEncode: # Class handling label
contains_re: True contains_re: True
algorithm: *algorithm algorithm: *algorithm
class_path: &class_path ppstructure/vqa/labels/labels_ser.txt class_path: &class_path train_data/XFUND/class_list_xfun.txt
- VQATokenPad: - VQATokenPad:
max_seq_len: &max_seq_len 512 max_seq_len: &max_seq_len 512
return_attention_mask: True return_attention_mask: True
...@@ -77,7 +77,7 @@ Train: ...@@ -77,7 +77,7 @@ Train:
order: 'hwc' order: 'hwc'
- ToCHWImage: - ToCHWImage:
- KeepKeys: - KeepKeys:
keep_keys: [ 'input_ids', 'bbox', 'image', 'attention_mask', 'token_type_ids','entities', 'relations'] # dataloader will return list in this order keep_keys: [ 'input_ids', 'bbox', 'attention_mask', 'token_type_ids','image', 'entities', 'relations'] # dataloader will return list in this order
loader: loader:
shuffle: True shuffle: True
drop_last: False drop_last: False
...@@ -90,7 +90,7 @@ Eval: ...@@ -90,7 +90,7 @@ Eval:
name: SimpleDataSet name: SimpleDataSet
data_dir: train_data/XFUND/zh_val/image data_dir: train_data/XFUND/zh_val/image
label_file_list: label_file_list:
- train_data/XFUND/zh_val/xfun_normalize_val.json - train_data/XFUND/zh_val/val.json
transforms: transforms:
- DecodeImage: # load image - DecodeImage: # load image
img_mode: RGB img_mode: RGB
...@@ -114,7 +114,7 @@ Eval: ...@@ -114,7 +114,7 @@ Eval:
order: 'hwc' order: 'hwc'
- ToCHWImage: - ToCHWImage:
- KeepKeys: - KeepKeys:
keep_keys: [ 'input_ids', 'bbox', 'image', 'attention_mask', 'token_type_ids','entities', 'relations'] # dataloader will return list in this order keep_keys: [ 'input_ids', 'bbox', 'attention_mask', 'token_type_ids', 'image','entities', 'relations'] # dataloader will return list in this order
loader: loader:
shuffle: False shuffle: False
drop_last: False drop_last: False
......
...@@ -11,7 +11,7 @@ Global: ...@@ -11,7 +11,7 @@ Global:
save_inference_dir: save_inference_dir:
use_visualdl: False use_visualdl: False
seed: 2022 seed: 2022
infer_img: doc/vqa/input/zh_val_21.jpg infer_img: ppstructure/docs/vqa/input/zh_val_21.jpg
save_res_path: ./output/re/ save_res_path: ./output/re/
Architecture: Architecture:
...@@ -52,7 +52,7 @@ Train: ...@@ -52,7 +52,7 @@ Train:
name: SimpleDataSet name: SimpleDataSet
data_dir: train_data/XFUND/zh_train/image data_dir: train_data/XFUND/zh_train/image
label_file_list: label_file_list:
- train_data/XFUND/zh_train/xfun_normalize_train.json - train_data/XFUND/zh_train/train.json
ratio_list: [ 1.0 ] ratio_list: [ 1.0 ]
transforms: transforms:
- DecodeImage: # load image - DecodeImage: # load image
...@@ -61,7 +61,7 @@ Train: ...@@ -61,7 +61,7 @@ Train:
- VQATokenLabelEncode: # Class handling label - VQATokenLabelEncode: # Class handling label
contains_re: True contains_re: True
algorithm: *algorithm algorithm: *algorithm
class_path: &class_path ppstructure/vqa/labels/labels_ser.txt class_path: &class_path train_data/XFUND/class_list_xfun.txt
- VQATokenPad: - VQATokenPad:
max_seq_len: &max_seq_len 512 max_seq_len: &max_seq_len 512
return_attention_mask: True return_attention_mask: True
...@@ -77,7 +77,7 @@ Train: ...@@ -77,7 +77,7 @@ Train:
order: 'hwc' order: 'hwc'
- ToCHWImage: - ToCHWImage:
- KeepKeys: - KeepKeys:
keep_keys: [ 'input_ids', 'bbox', 'image', 'attention_mask', 'token_type_ids','entities', 'relations'] # dataloader will return list in this order keep_keys: [ 'input_ids', 'bbox','attention_mask', 'token_type_ids', 'image', 'entities', 'relations'] # dataloader will return list in this order
loader: loader:
shuffle: True shuffle: True
drop_last: False drop_last: False
...@@ -90,7 +90,7 @@ Eval: ...@@ -90,7 +90,7 @@ Eval:
name: SimpleDataSet name: SimpleDataSet
data_dir: train_data/XFUND/zh_val/image data_dir: train_data/XFUND/zh_val/image
label_file_list: label_file_list:
- train_data/XFUND/zh_val/xfun_normalize_val.json - train_data/XFUND/zh_val/val.json
transforms: transforms:
- DecodeImage: # load image - DecodeImage: # load image
img_mode: RGB img_mode: RGB
...@@ -114,7 +114,7 @@ Eval: ...@@ -114,7 +114,7 @@ Eval:
order: 'hwc' order: 'hwc'
- ToCHWImage: - ToCHWImage:
- KeepKeys: - KeepKeys:
keep_keys: [ 'input_ids', 'bbox', 'image', 'attention_mask', 'token_type_ids','entities', 'relations'] # dataloader will return list in this order keep_keys: [ 'input_ids', 'bbox', 'attention_mask', 'token_type_ids', 'image', 'entities', 'relations'] # dataloader will return list in this order
loader: loader:
shuffle: False shuffle: False
drop_last: False drop_last: False
......
...@@ -11,7 +11,7 @@ Global: ...@@ -11,7 +11,7 @@ Global:
save_inference_dir: save_inference_dir:
use_visualdl: False use_visualdl: False
seed: 2022 seed: 2022
infer_img: doc/vqa/input/zh_val_0.jpg infer_img: ppstructure/docs/vqa/input/zh_val_42.jpg
save_res_path: ./output/ser/ save_res_path: ./output/ser/
Architecture: Architecture:
...@@ -77,7 +77,7 @@ Train: ...@@ -77,7 +77,7 @@ Train:
order: 'hwc' order: 'hwc'
- ToCHWImage: - ToCHWImage:
- KeepKeys: - KeepKeys:
keep_keys: [ 'input_ids','labels', 'bbox', 'image', 'attention_mask', 'token_type_ids'] # dataloader will return list in this order keep_keys: [ 'input_ids', 'bbox', 'attention_mask', 'token_type_ids', 'image', 'labels'] # dataloader will return list in this order
loader: loader:
shuffle: True shuffle: True
drop_last: False drop_last: False
...@@ -112,7 +112,7 @@ Eval: ...@@ -112,7 +112,7 @@ Eval:
order: 'hwc' order: 'hwc'
- ToCHWImage: - ToCHWImage:
- KeepKeys: - KeepKeys:
keep_keys: [ 'input_ids', 'labels', 'bbox', 'image', 'attention_mask', 'token_type_ids'] # dataloader will return list in this order keep_keys: [ 'input_ids', 'bbox', 'attention_mask', 'token_type_ids', 'image', 'labels'] # dataloader will return list in this order
loader: loader:
shuffle: False shuffle: False
drop_last: False drop_last: False
......
...@@ -11,7 +11,7 @@ Global: ...@@ -11,7 +11,7 @@ Global:
save_inference_dir: save_inference_dir:
use_visualdl: False use_visualdl: False
seed: 2022 seed: 2022
infer_img: doc/vqa/input/zh_val_0.jpg infer_img: ppstructure/docs/vqa/input/zh_val_42.jpg
save_res_path: ./output/ser/ save_res_path: ./output/ser/
Architecture: Architecture:
...@@ -78,7 +78,7 @@ Train: ...@@ -78,7 +78,7 @@ Train:
order: 'hwc' order: 'hwc'
- ToCHWImage: - ToCHWImage:
- KeepKeys: - KeepKeys:
keep_keys: [ 'input_ids','labels', 'bbox', 'image', 'attention_mask', 'token_type_ids'] # dataloader will return list in this order keep_keys: [ 'input_ids', 'bbox', 'attention_mask', 'token_type_ids', 'image', 'labels'] # dataloader will return list in this order
loader: loader:
shuffle: True shuffle: True
drop_last: False drop_last: False
...@@ -113,7 +113,7 @@ Eval: ...@@ -113,7 +113,7 @@ Eval:
order: 'hwc' order: 'hwc'
- ToCHWImage: - ToCHWImage:
- KeepKeys: - KeepKeys:
keep_keys: [ 'input_ids', 'labels', 'bbox', 'image', 'attention_mask', 'token_type_ids'] # dataloader will return list in this order keep_keys: [ 'input_ids', 'bbox', 'attention_mask', 'token_type_ids', 'image', 'labels'] # dataloader will return list in this order
loader: loader:
shuffle: False shuffle: False
drop_last: False drop_last: False
......
...@@ -43,7 +43,7 @@ Optimizer: ...@@ -43,7 +43,7 @@ Optimizer:
PostProcess: PostProcess:
name: VQASerTokenLayoutLMPostProcess name: VQASerTokenLayoutLMPostProcess
class_path: &class_path ppstructure/vqa/labels/labels_ser.txt class_path: &class_path train_data/XFUND/class_list_xfun.txt
Metric: Metric:
name: VQASerTokenMetric name: VQASerTokenMetric
...@@ -78,7 +78,7 @@ Train: ...@@ -78,7 +78,7 @@ Train:
order: 'hwc' order: 'hwc'
- ToCHWImage: - ToCHWImage:
- KeepKeys: - KeepKeys:
keep_keys: [ 'input_ids','labels', 'bbox', 'image', 'attention_mask', 'token_type_ids'] # dataloader will return list in this order keep_keys: [ 'input_ids', 'bbox', 'attention_mask', 'token_type_ids', 'image', 'labels'] # dataloader will return list in this order
loader: loader:
shuffle: True shuffle: True
drop_last: False drop_last: False
...@@ -113,7 +113,7 @@ Eval: ...@@ -113,7 +113,7 @@ Eval:
order: 'hwc' order: 'hwc'
- ToCHWImage: - ToCHWImage:
- KeepKeys: - KeepKeys:
keep_keys: [ 'input_ids', 'labels', 'bbox', 'image', 'attention_mask', 'token_type_ids'] # dataloader will return list in this order keep_keys: [ 'input_ids', 'bbox', 'attention_mask', 'token_type_ids', 'image', 'labels'] # dataloader will return list in this order
loader: loader:
shuffle: False shuffle: False
drop_last: False drop_last: False
......
...@@ -912,7 +912,7 @@ class VQATokenLabelEncode(object): ...@@ -912,7 +912,7 @@ class VQATokenLabelEncode(object):
label = info['label'] label = info['label']
gt_label = self._parse_label(label, encode_res) gt_label = self._parse_label(label, encode_res)
# construct entities for re # construct entities for re
if train_re: if train_re:
if gt_label[0] != self.label2id_map["O"]: if gt_label[0] != self.label2id_map["O"]:
entity_id_to_index_map[info["id"]] = len(entities) entity_id_to_index_map[info["id"]] = len(entities)
......
...@@ -27,8 +27,8 @@ class VQASerTokenLayoutLMLoss(nn.Layer): ...@@ -27,8 +27,8 @@ class VQASerTokenLayoutLMLoss(nn.Layer):
self.ignore_index = self.loss_class.ignore_index self.ignore_index = self.loss_class.ignore_index
def forward(self, predicts, batch): def forward(self, predicts, batch):
labels = batch[1] labels = batch[5]
attention_mask = batch[4] attention_mask = batch[2]
if attention_mask is not None: if attention_mask is not None:
active_loss = attention_mask.reshape([-1, ]) == 1 active_loss = attention_mask.reshape([-1, ]) == 1
active_outputs = predicts.reshape( active_outputs = predicts.reshape(
......
...@@ -74,9 +74,9 @@ class LayoutLMForSer(NLPBaseModel): ...@@ -74,9 +74,9 @@ class LayoutLMForSer(NLPBaseModel):
def forward(self, x): def forward(self, x):
x = self.model( x = self.model(
input_ids=x[0], input_ids=x[0],
bbox=x[2], bbox=x[1],
attention_mask=x[4], attention_mask=x[2],
token_type_ids=x[5], token_type_ids=x[3],
position_ids=None, position_ids=None,
output_hidden_states=False) output_hidden_states=False)
return x return x
...@@ -96,13 +96,15 @@ class LayoutLMv2ForSer(NLPBaseModel): ...@@ -96,13 +96,15 @@ class LayoutLMv2ForSer(NLPBaseModel):
def forward(self, x): def forward(self, x):
x = self.model( x = self.model(
input_ids=x[0], input_ids=x[0],
bbox=x[2], bbox=x[1],
image=x[3], attention_mask=x[2],
attention_mask=x[4], token_type_ids=x[3],
token_type_ids=x[5], image=x[4],
position_ids=None, position_ids=None,
head_mask=None, head_mask=None,
labels=None) labels=None)
if not self.training:
return x
return x[0] return x[0]
...@@ -120,13 +122,15 @@ class LayoutXLMForSer(NLPBaseModel): ...@@ -120,13 +122,15 @@ class LayoutXLMForSer(NLPBaseModel):
def forward(self, x): def forward(self, x):
x = self.model( x = self.model(
input_ids=x[0], input_ids=x[0],
bbox=x[2], bbox=x[1],
image=x[3], attention_mask=x[2],
attention_mask=x[4], token_type_ids=x[3],
token_type_ids=x[5], image=x[4],
position_ids=None, position_ids=None,
head_mask=None, head_mask=None,
labels=None) labels=None)
if not self.training:
return x
return x[0] return x[0]
...@@ -140,12 +144,12 @@ class LayoutLMv2ForRe(NLPBaseModel): ...@@ -140,12 +144,12 @@ class LayoutLMv2ForRe(NLPBaseModel):
x = self.model( x = self.model(
input_ids=x[0], input_ids=x[0],
bbox=x[1], bbox=x[1],
labels=None, attention_mask=x[2],
image=x[2], token_type_ids=x[3],
attention_mask=x[3], image=x[4],
token_type_ids=x[4],
position_ids=None, position_ids=None,
head_mask=None, head_mask=None,
labels=None,
entities=x[5], entities=x[5],
relations=x[6]) relations=x[6])
return x return x
...@@ -161,12 +165,12 @@ class LayoutXLMForRe(NLPBaseModel): ...@@ -161,12 +165,12 @@ class LayoutXLMForRe(NLPBaseModel):
x = self.model( x = self.model(
input_ids=x[0], input_ids=x[0],
bbox=x[1], bbox=x[1],
labels=None, attention_mask=x[2],
image=x[2], token_type_ids=x[3],
attention_mask=x[3], image=x[4],
token_type_ids=x[4],
position_ids=None, position_ids=None,
head_mask=None, head_mask=None,
labels=None,
entities=x[5], entities=x[5],
relations=x[6]) relations=x[6])
return x return x
...@@ -41,11 +41,13 @@ class VQASerTokenLayoutLMPostProcess(object): ...@@ -41,11 +41,13 @@ class VQASerTokenLayoutLMPostProcess(object):
self.id2label_map_for_show[val] = key self.id2label_map_for_show[val] = key
def __call__(self, preds, batch=None, *args, **kwargs): def __call__(self, preds, batch=None, *args, **kwargs):
if isinstance(preds, tuple):
preds = preds[0]
if isinstance(preds, paddle.Tensor): if isinstance(preds, paddle.Tensor):
preds = preds.numpy() preds = preds.numpy()
if batch is not None: if batch is not None:
return self._metric(preds, batch[1]) return self._metric(preds, batch[5])
else: else:
return self._infer(preds, **kwargs) return self._infer(preds, **kwargs)
...@@ -63,11 +65,11 @@ class VQASerTokenLayoutLMPostProcess(object): ...@@ -63,11 +65,11 @@ class VQASerTokenLayoutLMPostProcess(object):
j]]) j]])
return decode_out_list, label_decode_out_list return decode_out_list, label_decode_out_list
def _infer(self, preds, attention_masks, segment_offset_ids, ocr_infos): def _infer(self, preds, segment_offset_ids, ocr_infos):
results = [] results = []
for pred, attention_mask, segment_offset_id, ocr_info in zip( for pred, segment_offset_id, ocr_info in zip(preds, segment_offset_ids,
preds, attention_masks, segment_offset_ids, ocr_infos): ocr_infos):
pred = np.argmax(pred, axis=1) pred = np.argmax(pred, axis=1)
pred = [self.id2label_map[idx] for idx in pred] pred = [self.id2label_map[idx] for idx in pred]
......
# PP-Structure 系列模型列表 # PP-Structure 系列模型列表
- [1. 版面分析模型](#1) - [1. 版面分析模型](#1-版面分析模型)
- [2. OCR和表格识别模型](#2) - [2. OCR和表格识别模型](#2-ocr和表格识别模型)
- [2.1 OCR](#21) - [2.1 OCR](#21-ocr)
- [2.2 表格识别模型](#22) - [2.2 表格识别模型](#22-表格识别模型)
- [3. VQA模型](#3) - [3. VQA模型](#3-vqa模型)
- [4. KIE模型](#4) - [4. KIE模型](#4-kie模型)
<a name="1"></a> <a name="1"></a>
...@@ -42,11 +42,11 @@ ...@@ -42,11 +42,11 @@
|模型名称|模型简介|推理模型大小|下载地址| |模型名称|模型简介|推理模型大小|下载地址|
| --- | --- | --- | --- | | --- | --- | --- | --- |
|ser_LayoutXLM_xfun_zh|基于LayoutXLM在xfun中文数据集上训练的SER模型|1.4G|[推理模型 coming soon]() / [训练模型](https://paddleocr.bj.bcebos.com/pplayout/re_LayoutXLM_xfun_zh.tar) | |ser_LayoutXLM_xfun_zh|基于LayoutXLM在xfun中文数据集上训练的SER模型|1.4G|[推理模型](https://paddleocr.bj.bcebos.com/pplayout/ser_LayoutXLM_xfun_zh_infer.tar) / [训练模型](https://paddleocr.bj.bcebos.com/pplayout/ser_LayoutXLM_xfun_zh.tar) |
|re_LayoutXLM_xfun_zh|基于LayoutXLM在xfun中文数据集上训练的RE模型|1.4G|[推理模型 coming soon]() / [训练模型](https://paddleocr.bj.bcebos.com/pplayout/ser_LayoutXLM_xfun_zh.tar) | |re_LayoutXLM_xfun_zh|基于LayoutXLM在xfun中文数据集上训练的RE模型|1.4G|[推理模型 coming soon]() / [训练模型](https://paddleocr.bj.bcebos.com/pplayout/re_LayoutXLM_xfun_zh.tar) |
|ser_LayoutLMv2_xfun_zh|基于LayoutLMv2在xfun中文数据集上训练的SER模型|778M|[推理模型 coming soon]() / [训练模型](https://paddleocr.bj.bcebos.com/pplayout/ser_LayoutLMv2_xfun_zh.tar) | |ser_LayoutLMv2_xfun_zh|基于LayoutLMv2在xfun中文数据集上训练的SER模型|778M|[推理模型](https://paddleocr.bj.bcebos.com/pplayout/ser_LayoutLMv2_xfun_zh_infer.tar) / [训练模型](https://paddleocr.bj.bcebos.com/pplayout/ser_LayoutLMv2_xfun_zh.tar) |
|re_LayoutLMv2_xfun_zh|基于LayoutLMv2在xfun中文数据集上训练的RE模型|765M|[推理模型 coming soon]() / [训练模型](https://paddleocr.bj.bcebos.com/pplayout/re_LayoutLMv2_xfun_zh.tar) | |re_LayoutLMv2_xfun_zh|基于LayoutLMv2在xfun中文数据集上训练的RE模型|765M|[推理模型 coming soon]() / [训练模型](https://paddleocr.bj.bcebos.com/pplayout/re_LayoutLMv2_xfun_zh.tar) |
|ser_LayoutLM_xfun_zh|基于LayoutLM在xfun中文数据集上训练的SER模型|430M|[推理模型 coming soon]() / [训练模型](https://paddleocr.bj.bcebos.com/pplayout/ser_LayoutLM_xfun_zh.tar) | |ser_LayoutLM_xfun_zh|基于LayoutLM在xfun中文数据集上训练的SER模型|430M|[推理模型](https://paddleocr.bj.bcebos.com/pplayout/ser_LayoutLM_xfun_zh_infer.tar) / [训练模型](https://paddleocr.bj.bcebos.com/pplayout/ser_LayoutLM_xfun_zh.tar) |
<a name="4"></a> <a name="4"></a>
## 4. KIE模型 ## 4. KIE模型
......
# PP-Structure Model list # PP-Structure Model list
- [1. Layout Analysis](#1) - [1. Layout Analysis](#1-layout-analysis)
- [2. OCR and Table Recognition](#2) - [2. OCR and Table Recognition](#2-ocr-and-table-recognition)
- [2.1 OCR](#21) - [2.1 OCR](#21-ocr)
- [2.2 Table Recognition](#22) - [2.2 Table Recognition](#22-table-recognition)
- [3. VQA](#3) - [3. VQA](#3-vqa)
- [4. KIE](#4) - [4. KIE](#4-kie)
<a name="1"></a> <a name="1"></a>
...@@ -42,11 +42,11 @@ If you need to use other OCR models, you can download the model in [PP-OCR model ...@@ -42,11 +42,11 @@ If you need to use other OCR models, you can download the model in [PP-OCR model
|model| description |inference model size|download| |model| description |inference model size|download|
| --- |----------------------------------------------------------------| --- | --- | | --- |----------------------------------------------------------------| --- | --- |
|ser_LayoutXLM_xfun_zh| SER model trained on xfun Chinese dataset based on LayoutXLM |1.4G|[inference model coming soon]() / [trained model](https://paddleocr.bj.bcebos.com/pplayout/re_LayoutXLM_xfun_zh.tar) | |ser_LayoutXLM_xfun_zh| SER model trained on xfun Chinese dataset based on LayoutXLM |1.4G|[inference model](https://paddleocr.bj.bcebos.com/pplayout/ser_LayoutXLM_xfun_zh_infer.tar) / [trained model](https://paddleocr.bj.bcebos.com/pplayout/ser_LayoutXLM_xfun_zh.tar) |
|re_LayoutXLM_xfun_zh| Re model trained on xfun Chinese dataset based on LayoutXLM |1.4G|[inference model coming soon]() / [trained model](https://paddleocr.bj.bcebos.com/pplayout/ser_LayoutXLM_xfun_zh.tar) | |re_LayoutXLM_xfun_zh| Re model trained on xfun Chinese dataset based on LayoutXLM |1.4G|[inference model coming soon]() / [trained model](https://paddleocr.bj.bcebos.com/pplayout/re_LayoutXLM_xfun_zh.tar) |
|ser_LayoutLMv2_xfun_zh| SER model trained on xfun Chinese dataset based on LayoutXLMv2 |778M|[inference model coming soon]() / [trained model](https://paddleocr.bj.bcebos.com/pplayout/ser_LayoutLMv2_xfun_zh.tar) | |ser_LayoutLMv2_xfun_zh| SER model trained on xfun Chinese dataset based on LayoutXLMv2 |778M|[inference model](https://paddleocr.bj.bcebos.com/pplayout/ser_LayoutLMv2_xfun_zh_infer.tar) / [trained model](https://paddleocr.bj.bcebos.com/pplayout/ser_LayoutLMv2_xfun_zh.tar) |
|re_LayoutLMv2_xfun_zh| Re model trained on xfun Chinese dataset based on LayoutXLMv2 |765M|[inference model coming soon]() / [trained model](https://paddleocr.bj.bcebos.com/pplayout/re_LayoutLMv2_xfun_zh.tar) | |re_LayoutLMv2_xfun_zh| Re model trained on xfun Chinese dataset based on LayoutXLMv2 |765M|[inference model coming soon]() / [trained model](https://paddleocr.bj.bcebos.com/pplayout/re_LayoutLMv2_xfun_zh.tar) |
|ser_LayoutLM_xfun_zh| SER model trained on xfun Chinese dataset based on LayoutLM |430M|[inference model coming soon]() / [trained model](https://paddleocr.bj.bcebos.com/pplayout/ser_LayoutLM_xfun_zh.tar) | |ser_LayoutLM_xfun_zh| SER model trained on xfun Chinese dataset based on LayoutLM |430M|[inference model](https://paddleocr.bj.bcebos.com/pplayout/ser_LayoutLM_xfun_zh_infer.tar) / [trained model](https://paddleocr.bj.bcebos.com/pplayout/ser_LayoutLM_xfun_zh.tar) |
<a name="4"></a> <a name="4"></a>
## 4. KIE ## 4. KIE
......
...@@ -40,6 +40,13 @@ def init_args(): ...@@ -40,6 +40,13 @@ def init_args():
type=ast.literal_eval, type=ast.literal_eval,
default=None, default=None,
help='label map according to ppstructure/layout/README_ch.md') help='label map according to ppstructure/layout/README_ch.md')
# params for vqa
parser.add_argument("--vqa_algorithm", type=str, default='LayoutXLM')
parser.add_argument("--ser_model_dir", type=str)
parser.add_argument(
"--ser_dict_path",
type=str,
default="../train_data/XFUND/class_list_xfun.txt")
# params for inference # params for inference
parser.add_argument( parser.add_argument(
"--mode", "--mode",
...@@ -65,7 +72,7 @@ def init_args(): ...@@ -65,7 +72,7 @@ def init_args():
"--recovery", "--recovery",
type=bool, type=bool,
default=False, default=False,
help='Whether to enable layout of recovery') help='Whether to enable layout of recovery')
return parser return parser
......
English | [简体中文](README_ch.md) English | [简体中文](README_ch.md)
- [Document Visual Question Answering (Doc-VQA)](#Document-Visual-Question-Answering) - [1 Introduction](#1-introduction)
- [1. Introduction](#1-Introduction) - [2. Performance](#2-performance)
- [2. Performance](#2-performance) - [3. Effect demo](#3-effect-demo)
- [3. Effect demo](#3-Effect-demo) - [3.1 SER](#31-ser)
- [3.1 SER](#31-ser) - [3.2 RE](#32-re)
- [3.2 RE](#32-re) - [4. Install](#4-install)
- [4. Install](#4-Install) - [4.1 Install dependencies](#41-install-dependencies)
- [4.1 Installation dependencies](#41-Install-dependencies) - [5.3 RE](#53-re)
- [4.2 Install PaddleOCR](#42-Install-PaddleOCR) - [6. Reference Links](#6-reference-links)
- [5. Usage](#5-Usage) - [License](#license)
- [5.1 Data and Model Preparation](#51-Data-and-Model-Preparation)
- [5.2 SER](#52-ser)
- [5.3 RE](#53-re)
- [6. Reference](#6-Reference-Links)
# Document Visual Question Answering # Document Visual Question Answering
...@@ -187,17 +183,17 @@ CUDA_VISIBLE_DEVICES=0 python3 tools/eval.py -c configs/vqa/ser/layoutxlm.yml -o ...@@ -187,17 +183,17 @@ CUDA_VISIBLE_DEVICES=0 python3 tools/eval.py -c configs/vqa/ser/layoutxlm.yml -o
```` ````
Finally, `precision`, `recall`, `hmean` and other indicators will be printed Finally, `precision`, `recall`, `hmean` and other indicators will be printed
* Use `OCR engine + SER` tandem prediction * `OCR + SER` tandem prediction based on training engine
Use the following command to complete the series prediction of `OCR engine + SER`, taking the pretrained SER model as an example: Use the following command to complete the series prediction of `OCR engine + SER`, taking the SER model based on LayoutXLM as an example::
```shell ```shell
CUDA_VISIBLE_DEVICES=0 python3 tools/infer_vqa_token_ser.py -c configs/vqa/ser/layoutxlm.yml -o Architecture.Backbone.checkpoints=pretrain/ser_LayoutXLM_xfun_zh/Global.infer_img=doc/vqa/input/zh_val_42.jpg python3.7 tools/export_model.py -c configs/vqa/ser/layoutxlm.yml -o Architecture.Backbone.checkpoints=pretrain/ser_LayoutXLM_xfun_zh/ Global.save_inference_dir=output/ser/infer
```` ````
Finally, the prediction result visualization image and the prediction result text file will be saved in the directory configured by the `config.Global.save_res_path` field. The prediction result text file is named `infer_results.txt`. Finally, the prediction result visualization image and the prediction result text file will be saved in the directory configured by the `config.Global.save_res_path` field. The prediction result text file is named `infer_results.txt`.
* End-to-end evaluation of `OCR engine + SER` prediction system * End-to-end evaluation of `OCR + SER` prediction system
First use the `tools/infer_vqa_token_ser.py` script to complete the prediction of the dataset, then use the following command to evaluate. First use the `tools/infer_vqa_token_ser.py` script to complete the prediction of the dataset, then use the following command to evaluate.
...@@ -205,6 +201,24 @@ First use the `tools/infer_vqa_token_ser.py` script to complete the prediction o ...@@ -205,6 +201,24 @@ First use the `tools/infer_vqa_token_ser.py` script to complete the prediction o
export CUDA_VISIBLE_DEVICES=0 export CUDA_VISIBLE_DEVICES=0
python3 tools/eval_with_label_end2end.py --gt_json_path XFUND/zh_val/xfun_normalize_val.json --pred_json_path output_res/infer_results.txt python3 tools/eval_with_label_end2end.py --gt_json_path XFUND/zh_val/xfun_normalize_val.json --pred_json_path output_res/infer_results.txt
```` ````
* export model
Use the following command to complete the model export of the SER model, taking the SER model based on LayoutXLM as an example:
```shell
python3.7 tools/export_model.py -c configs/vqa/ser/layoutxlm.yml -o Architecture.Backbone.checkpoints=pretrain/ser_LayoutXLM_xfun_zh/ Global.save_inference_dir=output/ser/infer
```
The converted model will be stored in the directory specified by the `Global.save_inference_dir` field.
* `OCR + SER` tandem prediction based on prediction engine
Use the following command to complete the tandem prediction of `OCR + SER` based on the prediction engine, taking the SER model based on LayoutXLM as an example:
```shell
cd ppstructure
CUDA_VISIBLE_DEVICES=0 python3.7 vqa/predict_vqa_token_ser.py --vqa_algorithm=LayoutXLM --ser_model_dir=../output/ser/infer --ser_dict_path=../train_data/XFUND/class_list_xfun.txt --image_dir=docs/vqa/input/zh_val_42.jpg --output=output
```
After the prediction is successful, the visualization images and results will be saved in the directory specified by the `output` field
<a name="53"></a> <a name="53"></a>
### 5.3 RE ### 5.3 RE
...@@ -247,11 +261,19 @@ Finally, `precision`, `recall`, `hmean` and other indicators will be printed ...@@ -247,11 +261,19 @@ Finally, `precision`, `recall`, `hmean` and other indicators will be printed
Use the following command to complete the series prediction of `OCR engine + SER + RE`, taking the pretrained SER and RE models as an example: Use the following command to complete the series prediction of `OCR engine + SER + RE`, taking the pretrained SER and RE models as an example:
```shell ```shell
export CUDA_VISIBLE_DEVICES=0 export CUDA_VISIBLE_DEVICES=0
python3 tools/infer_vqa_token_ser_re.py -c configs/vqa/re/layoutxlm.yml -o Architecture.Backbone.checkpoints=pretrain/re_LayoutXLM_xfun_zh/Global.infer_img=doc/vqa/input/zh_val_21.jpg -c_ser configs/vqa/ser/layoutxlm. yml -o_ser Architecture.Backbone.checkpoints=pretrain/ser_LayoutXLM_xfun_zh/ python3 tools/infer_vqa_token_ser_re.py -c configs/vqa/re/layoutxlm.yml -o Architecture.Backbone.checkpoints=pretrain/re_LayoutXLM_xfun_zh/Global.infer_img=ppstructure/docs/vqa/input/zh_val_21.jpg -c_ser configs/vqa/ser/layoutxlm. yml -o_ser Architecture.Backbone.checkpoints=pretrain/ser_LayoutXLM_xfun_zh/
```` ````
Finally, the prediction result visualization image and the prediction result text file will be saved in the directory configured by the `config.Global.save_res_path` field. The prediction result text file is named `infer_results.txt`. Finally, the prediction result visualization image and the prediction result text file will be saved in the directory configured by the `config.Global.save_res_path` field. The prediction result text file is named `infer_results.txt`.
* export model
cooming soon
* `OCR + SER + RE` tandem prediction based on prediction engine
cooming soon
## 6. Reference Links ## 6. Reference Links
- LayoutXLM: Multimodal Pre-training for Multilingual Visually-rich Document Understanding, https://arxiv.org/pdf/2104.08836.pdf - LayoutXLM: Multimodal Pre-training for Multilingual Visually-rich Document Understanding, https://arxiv.org/pdf/2104.08836.pdf
......
[English](README.md) | 简体中文 [English](README.md) | 简体中文
- [文档视觉问答(DOC-VQA)](#文档视觉问答doc-vqa) - [1. 简介](#1-简介)
- [1. 简介](#1-简介) - [2. 性能](#2-性能)
- [2. 性能](#2-性能) - [3. 效果演示](#3-效果演示)
- [3. 效果演示](#3-效果演示) - [3.1 SER](#31-ser)
- [3.1 SER](#31-ser) - [3.2 RE](#32-re)
- [3.2 RE](#32-re) - [4. 安装](#4-安装)
- [4. 安装](#4-安装) - [4.1 安装依赖](#41-安装依赖)
- [4.1 安装依赖](#41-安装依赖) - [4.2 安装PaddleOCR(包含 PP-OCR 和 VQA)](#42-安装paddleocr包含-pp-ocr-和-vqa)
- [4.2 安装PaddleOCR(包含 PP-OCR 和 VQA)](#42-安装paddleocr包含-pp-ocr-和-vqa) - [5. 使用](#5-使用)
- [5. 使用](#5-使用) - [5.1 数据和预训练模型准备](#51-数据和预训练模型准备)
- [5.1 数据和预训练模型准备](#51-数据和预训练模型准备) - [5.2 SER](#52-ser)
- [5.2 SER](#52-ser) - [5.3 RE](#53-re)
- [5.3 RE](#53-re) - [6. 参考链接](#6-参考链接)
- [6. 参考链接](#6-参考链接) - [License](#license)
# 文档视觉问答(DOC-VQA) # 文档视觉问答(DOC-VQA)
...@@ -183,16 +183,16 @@ CUDA_VISIBLE_DEVICES=0 python3 tools/eval.py -c configs/vqa/ser/layoutxlm.yml -o ...@@ -183,16 +183,16 @@ CUDA_VISIBLE_DEVICES=0 python3 tools/eval.py -c configs/vqa/ser/layoutxlm.yml -o
``` ```
最终会打印出`precision`, `recall`, `hmean`等指标 最终会打印出`precision`, `recall`, `hmean`等指标
* 使用`OCR引擎 + SER`串联预测 * 基于训练引擎的`OCR + SER`串联预测
使用如下命令即可完成`OCR引擎 + SER`的串联预测, 以SER预训练模型为例: 使用如下命令即可完成基于训练引擎的`OCR + SER`的串联预测, 以基于LayoutXLM的SER模型为例:
```shell ```shell
CUDA_VISIBLE_DEVICES=0 python3 tools/infer_vqa_token_ser.py -c configs/vqa/ser/layoutxlm.yml -o Architecture.Backbone.checkpoints=pretrain/ser_LayoutXLM_xfun_zh/ Global.infer_img=doc/vqa/input/zh_val_42.jpg CUDA_VISIBLE_DEVICES=0 python3 tools/infer_vqa_token_ser.py -c configs/vqa/ser/layoutxlm.yml -o Architecture.Backbone.checkpoints=pretrain/ser_LayoutXLM_xfun_zh/ Global.infer_img=doc/vqa/input/zh_val_42.jpg
``` ```
最终会在`config.Global.save_res_path`字段所配置的目录下保存预测结果可视化图像以及预测结果文本文件,预测结果文本文件名为`infer_results.txt` 最终会在`config.Global.save_res_path`字段所配置的目录下保存预测结果可视化图像以及预测结果文本文件,预测结果文本文件名为`infer_results.txt`
*`OCR引擎 + SER`预测系统进行端到端评估 *`OCR + SER`预测系统进行端到端评估
首先使用 `tools/infer_vqa_token_ser.py` 脚本完成数据集的预测,然后使用下面的命令进行评估。 首先使用 `tools/infer_vqa_token_ser.py` 脚本完成数据集的预测,然后使用下面的命令进行评估。
...@@ -200,6 +200,24 @@ CUDA_VISIBLE_DEVICES=0 python3 tools/infer_vqa_token_ser.py -c configs/vqa/ser/l ...@@ -200,6 +200,24 @@ CUDA_VISIBLE_DEVICES=0 python3 tools/infer_vqa_token_ser.py -c configs/vqa/ser/l
export CUDA_VISIBLE_DEVICES=0 export CUDA_VISIBLE_DEVICES=0
python3 tools/eval_with_label_end2end.py --gt_json_path XFUND/zh_val/xfun_normalize_val.json --pred_json_path output_res/infer_results.txt python3 tools/eval_with_label_end2end.py --gt_json_path XFUND/zh_val/xfun_normalize_val.json --pred_json_path output_res/infer_results.txt
``` ```
* 模型导出
使用如下命令即可完成SER模型的模型导出, 以基于LayoutXLM的SER模型为例:
```shell
python3.7 tools/export_model.py -c configs/vqa/ser/layoutxlm.yml -o Architecture.Backbone.checkpoints=pretrain/ser_LayoutXLM_xfun_zh/ Global.save_inference_dir=output/ser/infer
```
转换后的模型会存放在`Global.save_inference_dir`字段指定的目录下。
* 基于预测引擎的`OCR + SER`串联预测
使用如下命令即可完成基于预测引擎的`OCR + SER`的串联预测, 以基于LayoutXLM的SER模型为例:
```shell
cd ppstructure
CUDA_VISIBLE_DEVICES=0 python3.7 vqa/predict_vqa_token_ser.py --vqa_algorithm=LayoutXLM --ser_model_dir=../output/ser/infer --ser_dict_path=../train_data/XFUND/class_list_xfun.txt --image_dir=docs/vqa/input/zh_val_42.jpg --output=output
```
预测成功后,可视化图片和结果会保存在`output`字段指定的目录下
### 5.3 RE ### 5.3 RE
...@@ -236,16 +254,24 @@ CUDA_VISIBLE_DEVICES=0 python3 tools/eval.py -c configs/vqa/re/layoutxlm.yml -o ...@@ -236,16 +254,24 @@ CUDA_VISIBLE_DEVICES=0 python3 tools/eval.py -c configs/vqa/re/layoutxlm.yml -o
``` ```
最终会打印出`precision`, `recall`, `hmean`等指标 最终会打印出`precision`, `recall`, `hmean`等指标
* 使用`OCR引擎 + SER + RE`串联预测 * 基于训练引擎的`OCR + SER + RE`串联预测
使用如下命令即可完成`OCR引擎 + SER + RE`的串联预测, 以预训练SER和RE模型为例: 使用如下命令即可完成基于训练引擎的`OCR + SER + RE`串联预测, 以基于LayoutXLMSER和RE模型为例:
```shell ```shell
export CUDA_VISIBLE_DEVICES=0 export CUDA_VISIBLE_DEVICES=0
python3 tools/infer_vqa_token_ser_re.py -c configs/vqa/re/layoutxlm.yml -o Architecture.Backbone.checkpoints=pretrain/re_LayoutXLM_xfun_zh/ Global.infer_img=doc/vqa/input/zh_val_21.jpg -c_ser configs/vqa/ser/layoutxlm.yml -o_ser Architecture.Backbone.checkpoints=pretrain/ser_LayoutXLM_xfun_zh/ python3 tools/infer_vqa_token_ser_re.py -c configs/vqa/re/layoutxlm.yml -o Architecture.Backbone.checkpoints=pretrain/re_LayoutXLM_xfun_zh/ Global.infer_img=ppstructure/docs/vqa/input/zh_val_21.jpg -c_ser configs/vqa/ser/layoutxlm.yml -o_ser Architecture.Backbone.checkpoints=pretrain/ser_LayoutXLM_xfun_zh/
``` ```
最终会在`config.Global.save_res_path`字段所配置的目录下保存预测结果可视化图像以及预测结果文本文件,预测结果文本文件名为`infer_results.txt` 最终会在`config.Global.save_res_path`字段所配置的目录下保存预测结果可视化图像以及预测结果文本文件,预测结果文本文件名为`infer_results.txt`
* 模型导出
cooming soon
* 基于预测引擎的`OCR + SER + RE`串联预测
cooming soon
## 6. 参考链接 ## 6. 参考链接
- LayoutXLM: Multimodal Pre-training for Multilingual Visually-rich Document Understanding, https://arxiv.org/pdf/2104.08836.pdf - LayoutXLM: Multimodal Pre-training for Multilingual Visually-rich Document Understanding, https://arxiv.org/pdf/2104.08836.pdf
......
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import sys
__dir__ = os.path.dirname(os.path.abspath(__file__))
sys.path.append(__dir__)
sys.path.insert(0, os.path.abspath(os.path.join(__dir__, '../..')))
os.environ["FLAGS_allocator_strategy"] = 'auto_growth'
import cv2
import json
import numpy as np
import time
import tools.infer.utility as utility
from ppocr.data import create_operators, transform
from ppocr.postprocess import build_post_process
from ppocr.utils.logging import get_logger
from ppocr.utils.visual import draw_ser_results
from ppocr.utils.utility import get_image_file_list, check_and_read_gif
from ppstructure.utility import parse_args
from paddleocr import PaddleOCR
logger = get_logger()
class SerPredictor(object):
def __init__(self, args):
self.ocr_engine = PaddleOCR(use_angle_cls=False, show_log=False)
pre_process_list = [{
'VQATokenLabelEncode': {
'algorithm': args.vqa_algorithm,
'class_path': args.ser_dict_path,
'contains_re': False,
'ocr_engine': self.ocr_engine
}
}, {
'VQATokenPad': {
'max_seq_len': 512,
'return_attention_mask': True
}
}, {
'VQASerTokenChunk': {
'max_seq_len': 512,
'return_attention_mask': True
}
}, {
'Resize': {
'size': [224, 224]
}
}, {
'NormalizeImage': {
'std': [58.395, 57.12, 57.375],
'mean': [123.675, 116.28, 103.53],
'scale': '1',
'order': 'hwc'
}
}, {
'ToCHWImage': None
}, {
'KeepKeys': {
'keep_keys': [
'input_ids', 'bbox', 'attention_mask', 'token_type_ids',
'image', 'labels', 'segment_offset_id', 'ocr_info',
'entities'
]
}
}]
postprocess_params = {
'name': 'VQASerTokenLayoutLMPostProcess',
"class_path": args.ser_dict_path,
}
self.preprocess_op = create_operators(pre_process_list,
{'infer_mode': True})
self.postprocess_op = build_post_process(postprocess_params)
self.predictor, self.input_tensor, self.output_tensors, self.config = \
utility.create_predictor(args, 'ser', logger)
def __call__(self, img):
ori_im = img.copy()
data = {'image': img}
data = transform(data, self.preprocess_op)
img = data[0]
if img is None:
return None, 0
img = np.expand_dims(img, axis=0)
img = img.copy()
starttime = time.time()
for idx in range(len(self.input_tensor)):
expand_input = np.expand_dims(data[idx], axis=0)
self.input_tensor[idx].copy_from_cpu(expand_input)
self.predictor.run()
outputs = []
for output_tensor in self.output_tensors:
output = output_tensor.copy_to_cpu()
outputs.append(output)
preds = outputs[0]
post_result = self.postprocess_op(
preds, segment_offset_ids=[data[6]], ocr_infos=[data[7]])
elapse = time.time() - starttime
return post_result, elapse
def main(args):
image_file_list = get_image_file_list(args.image_dir)
ser_predictor = SerPredictor(args)
count = 0
total_time = 0
os.makedirs(args.output, exist_ok=True)
with open(
os.path.join(args.output, 'infer.txt'), mode='w',
encoding='utf-8') as f_w:
for image_file in image_file_list:
img, flag = check_and_read_gif(image_file)
if not flag:
img = cv2.imread(image_file)
img = img[:, :, ::-1]
if img is None:
logger.info("error in loading image:{}".format(image_file))
continue
ser_res, elapse = ser_predictor(img)
ser_res = ser_res[0]
res_str = '{}\t{}\n'.format(
image_file,
json.dumps(
{
"ocr_info": ser_res,
}, ensure_ascii=False))
f_w.write(res_str)
img_res = draw_ser_results(
image_file,
ser_res,
font_path="../doc/fonts/simfang.ttf", )
img_save_path = os.path.join(args.output,
os.path.basename(image_file))
cv2.imwrite(img_save_path, img_res)
logger.info("save vis result to {}".format(img_save_path))
if count > 0:
total_time += elapse
count += 1
logger.info("Predict time of {}: {}".format(image_file, elapse))
if __name__ == "__main__":
main(parse_args())
sentencepiece sentencepiece
yacs yacs
seqeval seqeval
paddlenlp>=2.2.1 paddlenlp>=2.2.1
\ No newline at end of file pypandoc
attrdict
python_docx
\ No newline at end of file
...@@ -97,6 +97,22 @@ def export_single_model(model, ...@@ -97,6 +97,22 @@ def export_single_model(model,
shape=[None, 1, 32, 100], dtype="float32"), shape=[None, 1, 32, 100], dtype="float32"),
] ]
model = to_static(model, input_spec=other_shape) model = to_static(model, input_spec=other_shape)
elif arch_config["algorithm"] in ["LayoutLM", "LayoutLMv2", "LayoutXLM"]:
input_spec = [
paddle.static.InputSpec(
shape=[None, 512], dtype="int64"), # input_ids
paddle.static.InputSpec(
shape=[None, 512, 4], dtype="int64"), # bbox
paddle.static.InputSpec(
shape=[None, 512], dtype="int64"), # attention_mask
paddle.static.InputSpec(
shape=[None, 512], dtype="int64"), # token_type_ids
paddle.static.InputSpec(
shape=[None, 3, 224, 224], dtype="int64"), # image
]
if arch_config["algorithm"] == "LayoutLM":
input_spec.pop(4)
model = to_static(model, input_spec=[input_spec])
else: else:
infer_shape = [3, -1, -1] infer_shape = [3, -1, -1]
if arch_config["model_type"] == "rec": if arch_config["model_type"] == "rec":
...@@ -172,7 +188,7 @@ def main(): ...@@ -172,7 +188,7 @@ def main():
config["Architecture"]["Head"]["out_channels"] = char_num config["Architecture"]["Head"]["out_channels"] = char_num
model = build_model(config["Architecture"]) model = build_model(config["Architecture"])
load_model(config, model) load_model(config, model, model_type=config['Architecture']["model_type"])
model.eval() model.eval()
save_path = config["Global"]["save_inference_dir"] save_path = config["Global"]["save_inference_dir"]
......
...@@ -153,6 +153,8 @@ def create_predictor(args, mode, logger): ...@@ -153,6 +153,8 @@ def create_predictor(args, mode, logger):
model_dir = args.rec_model_dir model_dir = args.rec_model_dir
elif mode == 'table': elif mode == 'table':
model_dir = args.table_model_dir model_dir = args.table_model_dir
elif mode == 'ser':
model_dir = args.ser_model_dir
else: else:
model_dir = args.e2e_model_dir model_dir = args.e2e_model_dir
...@@ -316,8 +318,13 @@ def create_predictor(args, mode, logger): ...@@ -316,8 +318,13 @@ def create_predictor(args, mode, logger):
# create predictor # create predictor
predictor = inference.create_predictor(config) predictor = inference.create_predictor(config)
input_names = predictor.get_input_names() input_names = predictor.get_input_names()
for name in input_names: if mode in ['ser', 're']:
input_tensor = predictor.get_input_handle(name) input_tensor = []
for name in input_names:
input_tensor.append(predictor.get_input_handle(name))
else:
for name in input_names:
input_tensor = predictor.get_input_handle(name)
output_tensors = get_output_tensors(args, mode, predictor) output_tensors = get_output_tensors(args, mode, predictor)
return predictor, input_tensor, output_tensors, config return predictor, input_tensor, output_tensors, config
......
...@@ -44,6 +44,7 @@ def to_tensor(data): ...@@ -44,6 +44,7 @@ def to_tensor(data):
from collections import defaultdict from collections import defaultdict
data_dict = defaultdict(list) data_dict = defaultdict(list)
to_tensor_idxs = [] to_tensor_idxs = []
for idx, v in enumerate(data): for idx, v in enumerate(data):
if isinstance(v, (np.ndarray, paddle.Tensor, numbers.Number)): if isinstance(v, (np.ndarray, paddle.Tensor, numbers.Number)):
if idx not in to_tensor_idxs: if idx not in to_tensor_idxs:
...@@ -57,6 +58,7 @@ def to_tensor(data): ...@@ -57,6 +58,7 @@ def to_tensor(data):
class SerPredictor(object): class SerPredictor(object):
def __init__(self, config): def __init__(self, config):
global_config = config['Global'] global_config = config['Global']
self.algorithm = config['Architecture']["algorithm"]
# build post process # build post process
self.post_process_class = build_post_process(config['PostProcess'], self.post_process_class = build_post_process(config['PostProcess'],
...@@ -70,7 +72,10 @@ class SerPredictor(object): ...@@ -70,7 +72,10 @@ class SerPredictor(object):
from paddleocr import PaddleOCR from paddleocr import PaddleOCR
self.ocr_engine = PaddleOCR(use_angle_cls=False, show_log=False) self.ocr_engine = PaddleOCR(
use_angle_cls=False,
show_log=False,
use_gpu=global_config['use_gpu'])
# create data ops # create data ops
transforms = [] transforms = []
...@@ -80,8 +85,8 @@ class SerPredictor(object): ...@@ -80,8 +85,8 @@ class SerPredictor(object):
op[op_name]['ocr_engine'] = self.ocr_engine op[op_name]['ocr_engine'] = self.ocr_engine
elif op_name == 'KeepKeys': elif op_name == 'KeepKeys':
op[op_name]['keep_keys'] = [ op[op_name]['keep_keys'] = [
'input_ids', 'labels', 'bbox', 'image', 'attention_mask', 'input_ids', 'bbox', 'attention_mask', 'token_type_ids',
'token_type_ids', 'segment_offset_id', 'ocr_info', 'image', 'labels', 'segment_offset_id', 'ocr_info',
'entities' 'entities'
] ]
...@@ -99,11 +104,11 @@ class SerPredictor(object): ...@@ -99,11 +104,11 @@ class SerPredictor(object):
batch = transform(data, self.ops) batch = transform(data, self.ops)
batch = to_tensor(batch) batch = to_tensor(batch)
preds = self.model(batch) preds = self.model(batch)
if self.algorithm in ['LayoutLMv2', 'LayoutXLM']:
preds = preds[0]
post_result = self.post_process_class( post_result = self.post_process_class(
preds, preds, segment_offset_ids=batch[6], ocr_infos=batch[7])
attention_masks=batch[4],
segment_offset_ids=batch[6],
ocr_infos=batch[7])
return post_result, batch return post_result, batch
...@@ -138,8 +143,6 @@ if __name__ == '__main__': ...@@ -138,8 +143,6 @@ if __name__ == '__main__':
save_img_path = os.path.join( save_img_path = os.path.join(
config['Global']['save_res_path'], config['Global']['save_res_path'],
os.path.splitext(os.path.basename(img_path))[0] + "_ser.jpg") os.path.splitext(os.path.basename(img_path))[0] + "_ser.jpg")
logger.info("process: [{}/{}], save result to {}".format(
idx, len(infer_imgs), save_img_path))
result, _ = ser_engine(data) result, _ = ser_engine(data)
result = result[0] result = result[0]
...@@ -149,3 +152,6 @@ if __name__ == '__main__': ...@@ -149,3 +152,6 @@ if __name__ == '__main__':
}, ensure_ascii=False) + "\n") }, ensure_ascii=False) + "\n")
img_res = draw_ser_results(img_path, result) img_res = draw_ser_results(img_path, result)
cv2.imwrite(save_img_path, img_res) cv2.imwrite(save_img_path, img_res)
logger.info("process: [{}/{}], save result to {}".format(
idx, len(infer_imgs), save_img_path))
...@@ -38,7 +38,7 @@ from ppocr.utils.save_load import load_model ...@@ -38,7 +38,7 @@ from ppocr.utils.save_load import load_model
from ppocr.utils.visual import draw_re_results from ppocr.utils.visual import draw_re_results
from ppocr.utils.logging import get_logger from ppocr.utils.logging import get_logger
from ppocr.utils.utility import get_image_file_list, load_vqa_bio_label_maps, print_dict from ppocr.utils.utility import get_image_file_list, load_vqa_bio_label_maps, print_dict
from tools.program import ArgsParser, load_config, merge_config, check_gpu from tools.program import ArgsParser, load_config, merge_config
from tools.infer_vqa_token_ser import SerPredictor from tools.infer_vqa_token_ser import SerPredictor
...@@ -107,7 +107,7 @@ def make_input(ser_inputs, ser_results): ...@@ -107,7 +107,7 @@ def make_input(ser_inputs, ser_results):
# remove ocr_info segment_offset_id and label in ser input # remove ocr_info segment_offset_id and label in ser input
ser_inputs.pop(7) ser_inputs.pop(7)
ser_inputs.pop(6) ser_inputs.pop(6)
ser_inputs.pop(1) ser_inputs.pop(5)
return ser_inputs, entity_idx_dict_batch return ser_inputs, entity_idx_dict_batch
...@@ -131,9 +131,7 @@ class SerRePredictor(object): ...@@ -131,9 +131,7 @@ class SerRePredictor(object):
self.model.eval() self.model.eval()
def __call__(self, img_path): def __call__(self, img_path):
ser_results, ser_inputs = self.ser_engine(img_path) ser_results, ser_inputs = self.ser_engine({'img_path': img_path})
paddle.save(ser_inputs, 'ser_inputs.npy')
paddle.save(ser_results, 'ser_results.npy')
re_input, entity_idx_dict_batch = make_input(ser_inputs, ser_results) re_input, entity_idx_dict_batch = make_input(ser_inputs, ser_results)
preds = self.model(re_input) preds = self.model(re_input)
post_result = self.post_process_class( post_result = self.post_process_class(
...@@ -155,7 +153,6 @@ def preprocess(): ...@@ -155,7 +153,6 @@ def preprocess():
# check if set use_gpu=True in paddlepaddle cpu version # check if set use_gpu=True in paddlepaddle cpu version
use_gpu = config['Global']['use_gpu'] use_gpu = config['Global']['use_gpu']
check_gpu(use_gpu)
device = 'gpu:{}'.format(dist.ParallelEnv().dev_id) if use_gpu else 'cpu' device = 'gpu:{}'.format(dist.ParallelEnv().dev_id) if use_gpu else 'cpu'
device = paddle.set_device(device) device = paddle.set_device(device)
...@@ -185,9 +182,7 @@ if __name__ == '__main__': ...@@ -185,9 +182,7 @@ if __name__ == '__main__':
for idx, img_path in enumerate(infer_imgs): for idx, img_path in enumerate(infer_imgs):
save_img_path = os.path.join( save_img_path = os.path.join(
config['Global']['save_res_path'], config['Global']['save_res_path'],
os.path.splitext(os.path.basename(img_path))[0] + "_ser.jpg") os.path.splitext(os.path.basename(img_path))[0] + "_ser_re.jpg")
logger.info("process: [{}/{}], save result to {}".format(
idx, len(infer_imgs), save_img_path))
result = ser_re_engine(img_path) result = ser_re_engine(img_path)
result = result[0] result = result[0]
...@@ -197,3 +192,6 @@ if __name__ == '__main__': ...@@ -197,3 +192,6 @@ if __name__ == '__main__':
}, ensure_ascii=False) + "\n") }, ensure_ascii=False) + "\n")
img_res = draw_re_results(img_path, result) img_res = draw_re_results(img_path, result)
cv2.imwrite(save_img_path, img_res) cv2.imwrite(save_img_path, img_res)
logger.info("process: [{}/{}], save result to {}".format(
idx, len(infer_imgs), save_img_path))
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册