diff --git a/configs/vqa/re/layoutlmv2.yml b/configs/vqa/re/layoutlmv2.yml
index 2fa5fd1165c20bbfa8d8505bbb53d48744daebef..737dbf6b600b1b414a7f66f422e59f46154d91a9 100644
--- a/configs/vqa/re/layoutlmv2.yml
+++ b/configs/vqa/re/layoutlmv2.yml
@@ -11,7 +11,7 @@ Global:
save_inference_dir:
use_visualdl: False
seed: 2048
- infer_img: doc/vqa/input/zh_val_21.jpg
+ infer_img: ppstructure/docs/vqa/input/zh_val_21.jpg
save_res_path: ./output/re/
Architecture:
@@ -21,7 +21,7 @@ Architecture:
Backbone:
name: LayoutLMv2ForRe
pretrained: True
- checkpoints:
+ checkpoints:
Loss:
name: LossFromOutput
@@ -52,7 +52,7 @@ Train:
name: SimpleDataSet
data_dir: train_data/XFUND/zh_train/image
label_file_list:
- - train_data/XFUND/zh_train/xfun_normalize_train.json
+ - train_data/XFUND/zh_train/train.json
ratio_list: [ 1.0 ]
transforms:
- DecodeImage: # load image
@@ -61,7 +61,7 @@ Train:
- VQATokenLabelEncode: # Class handling label
contains_re: True
algorithm: *algorithm
- class_path: &class_path ppstructure/vqa/labels/labels_ser.txt
+ class_path: &class_path train_data/XFUND/class_list_xfun.txt
- VQATokenPad:
max_seq_len: &max_seq_len 512
return_attention_mask: True
@@ -77,7 +77,7 @@ Train:
order: 'hwc'
- ToCHWImage:
- KeepKeys:
- keep_keys: [ 'input_ids', 'bbox', 'image', 'attention_mask', 'token_type_ids','entities', 'relations'] # dataloader will return list in this order
+ keep_keys: [ 'input_ids', 'bbox', 'attention_mask', 'token_type_ids','image', 'entities', 'relations'] # dataloader will return list in this order
loader:
shuffle: True
drop_last: False
@@ -90,7 +90,7 @@ Eval:
name: SimpleDataSet
data_dir: train_data/XFUND/zh_val/image
label_file_list:
- - train_data/XFUND/zh_val/xfun_normalize_val.json
+ - train_data/XFUND/zh_val/val.json
transforms:
- DecodeImage: # load image
img_mode: RGB
@@ -114,7 +114,7 @@ Eval:
order: 'hwc'
- ToCHWImage:
- KeepKeys:
- keep_keys: [ 'input_ids', 'bbox', 'image', 'attention_mask', 'token_type_ids','entities', 'relations'] # dataloader will return list in this order
+ keep_keys: [ 'input_ids', 'bbox', 'attention_mask', 'token_type_ids', 'image','entities', 'relations'] # dataloader will return list in this order
loader:
shuffle: False
drop_last: False
diff --git a/configs/vqa/re/layoutxlm.yml b/configs/vqa/re/layoutxlm.yml
index ff16120ac1be92e989ebfda6af3ccf346dde89cd..d8585bb72593d55578ff3c6cd1401b5a843bb683 100644
--- a/configs/vqa/re/layoutxlm.yml
+++ b/configs/vqa/re/layoutxlm.yml
@@ -11,7 +11,7 @@ Global:
save_inference_dir:
use_visualdl: False
seed: 2022
- infer_img: doc/vqa/input/zh_val_21.jpg
+ infer_img: ppstructure/docs/vqa/input/zh_val_21.jpg
save_res_path: ./output/re/
Architecture:
@@ -52,7 +52,7 @@ Train:
name: SimpleDataSet
data_dir: train_data/XFUND/zh_train/image
label_file_list:
- - train_data/XFUND/zh_train/xfun_normalize_train.json
+ - train_data/XFUND/zh_train/train.json
ratio_list: [ 1.0 ]
transforms:
- DecodeImage: # load image
@@ -61,7 +61,7 @@ Train:
- VQATokenLabelEncode: # Class handling label
contains_re: True
algorithm: *algorithm
- class_path: &class_path ppstructure/vqa/labels/labels_ser.txt
+ class_path: &class_path train_data/XFUND/class_list_xfun.txt
- VQATokenPad:
max_seq_len: &max_seq_len 512
return_attention_mask: True
@@ -77,7 +77,7 @@ Train:
order: 'hwc'
- ToCHWImage:
- KeepKeys:
- keep_keys: [ 'input_ids', 'bbox', 'image', 'attention_mask', 'token_type_ids','entities', 'relations'] # dataloader will return list in this order
+ keep_keys: [ 'input_ids', 'bbox','attention_mask', 'token_type_ids', 'image', 'entities', 'relations'] # dataloader will return list in this order
loader:
shuffle: True
drop_last: False
@@ -90,7 +90,7 @@ Eval:
name: SimpleDataSet
data_dir: train_data/XFUND/zh_val/image
label_file_list:
- - train_data/XFUND/zh_val/xfun_normalize_val.json
+ - train_data/XFUND/zh_val/val.json
transforms:
- DecodeImage: # load image
img_mode: RGB
@@ -114,7 +114,7 @@ Eval:
order: 'hwc'
- ToCHWImage:
- KeepKeys:
- keep_keys: [ 'input_ids', 'bbox', 'image', 'attention_mask', 'token_type_ids','entities', 'relations'] # dataloader will return list in this order
+ keep_keys: [ 'input_ids', 'bbox', 'attention_mask', 'token_type_ids', 'image', 'entities', 'relations'] # dataloader will return list in this order
loader:
shuffle: False
drop_last: False
diff --git a/configs/vqa/ser/layoutlm.yml b/configs/vqa/ser/layoutlm.yml
index 47ab093e1fce5cb38a75409eb1d9ac67c6426ba4..53e114defd4cdfa427ae27b647603744302eb0e8 100644
--- a/configs/vqa/ser/layoutlm.yml
+++ b/configs/vqa/ser/layoutlm.yml
@@ -11,7 +11,7 @@ Global:
save_inference_dir:
use_visualdl: False
seed: 2022
- infer_img: doc/vqa/input/zh_val_0.jpg
+ infer_img: ppstructure/docs/vqa/input/zh_val_42.jpg
save_res_path: ./output/ser/
Architecture:
@@ -77,7 +77,7 @@ Train:
order: 'hwc'
- ToCHWImage:
- KeepKeys:
- keep_keys: [ 'input_ids','labels', 'bbox', 'image', 'attention_mask', 'token_type_ids'] # dataloader will return list in this order
+ keep_keys: [ 'input_ids', 'bbox', 'attention_mask', 'token_type_ids', 'image', 'labels'] # dataloader will return list in this order
loader:
shuffle: True
drop_last: False
@@ -112,7 +112,7 @@ Eval:
order: 'hwc'
- ToCHWImage:
- KeepKeys:
- keep_keys: [ 'input_ids', 'labels', 'bbox', 'image', 'attention_mask', 'token_type_ids'] # dataloader will return list in this order
+ keep_keys: [ 'input_ids', 'bbox', 'attention_mask', 'token_type_ids', 'image', 'labels'] # dataloader will return list in this order
loader:
shuffle: False
drop_last: False
diff --git a/configs/vqa/ser/layoutlmv2.yml b/configs/vqa/ser/layoutlmv2.yml
index d6a9c03e5ec9683a3a6423ed22a98f361769541f..e48c7469567a740ca74240f0ca9f782ed5bb3c6d 100644
--- a/configs/vqa/ser/layoutlmv2.yml
+++ b/configs/vqa/ser/layoutlmv2.yml
@@ -11,7 +11,7 @@ Global:
save_inference_dir:
use_visualdl: False
seed: 2022
- infer_img: doc/vqa/input/zh_val_0.jpg
+ infer_img: ppstructure/docs/vqa/input/zh_val_42.jpg
save_res_path: ./output/ser/
Architecture:
@@ -78,7 +78,7 @@ Train:
order: 'hwc'
- ToCHWImage:
- KeepKeys:
- keep_keys: [ 'input_ids','labels', 'bbox', 'image', 'attention_mask', 'token_type_ids'] # dataloader will return list in this order
+ keep_keys: [ 'input_ids', 'bbox', 'attention_mask', 'token_type_ids', 'image', 'labels'] # dataloader will return list in this order
loader:
shuffle: True
drop_last: False
@@ -113,7 +113,7 @@ Eval:
order: 'hwc'
- ToCHWImage:
- KeepKeys:
- keep_keys: [ 'input_ids', 'labels', 'bbox', 'image', 'attention_mask', 'token_type_ids'] # dataloader will return list in this order
+ keep_keys: [ 'input_ids', 'bbox', 'attention_mask', 'token_type_ids', 'image', 'labels'] # dataloader will return list in this order
loader:
shuffle: False
drop_last: False
diff --git a/configs/vqa/ser/layoutxlm.yml b/configs/vqa/ser/layoutxlm.yml
index 3686989ccf7481a28584fd49c3969c1a69cd04d4..fa9df192afbc1d638c220cba3ef3640715585b37 100644
--- a/configs/vqa/ser/layoutxlm.yml
+++ b/configs/vqa/ser/layoutxlm.yml
@@ -43,7 +43,7 @@ Optimizer:
PostProcess:
name: VQASerTokenLayoutLMPostProcess
- class_path: &class_path ppstructure/vqa/labels/labels_ser.txt
+ class_path: &class_path train_data/XFUND/class_list_xfun.txt
Metric:
name: VQASerTokenMetric
@@ -78,7 +78,7 @@ Train:
order: 'hwc'
- ToCHWImage:
- KeepKeys:
- keep_keys: [ 'input_ids','labels', 'bbox', 'image', 'attention_mask', 'token_type_ids'] # dataloader will return list in this order
+ keep_keys: [ 'input_ids', 'bbox', 'attention_mask', 'token_type_ids', 'image', 'labels'] # dataloader will return list in this order
loader:
shuffle: True
drop_last: False
@@ -113,7 +113,7 @@ Eval:
order: 'hwc'
- ToCHWImage:
- KeepKeys:
- keep_keys: [ 'input_ids', 'labels', 'bbox', 'image', 'attention_mask', 'token_type_ids'] # dataloader will return list in this order
+ keep_keys: [ 'input_ids', 'bbox', 'attention_mask', 'token_type_ids', 'image', 'labels'] # dataloader will return list in this order
loader:
shuffle: False
drop_last: False
diff --git a/ppocr/data/imaug/label_ops.py b/ppocr/data/imaug/label_ops.py
index c95b326224515c43b03f90ee51c809006399dfff..0723e97ae719690ef2e6a500b327b039c7a46f66 100644
--- a/ppocr/data/imaug/label_ops.py
+++ b/ppocr/data/imaug/label_ops.py
@@ -912,7 +912,7 @@ class VQATokenLabelEncode(object):
label = info['label']
gt_label = self._parse_label(label, encode_res)
-# construct entities for re
+ # construct entities for re
if train_re:
if gt_label[0] != self.label2id_map["O"]:
entity_id_to_index_map[info["id"]] = len(entities)
diff --git a/ppocr/losses/vqa_token_layoutlm_loss.py b/ppocr/losses/vqa_token_layoutlm_loss.py
index 244893d97d0e422c5ca270bdece689e13aba2b07..f9cd4634731a26dd990d6ffac3d8defc8cdf7e97 100755
--- a/ppocr/losses/vqa_token_layoutlm_loss.py
+++ b/ppocr/losses/vqa_token_layoutlm_loss.py
@@ -27,8 +27,8 @@ class VQASerTokenLayoutLMLoss(nn.Layer):
self.ignore_index = self.loss_class.ignore_index
def forward(self, predicts, batch):
- labels = batch[1]
- attention_mask = batch[4]
+ labels = batch[5]
+ attention_mask = batch[2]
if attention_mask is not None:
active_loss = attention_mask.reshape([-1, ]) == 1
active_outputs = predicts.reshape(
diff --git a/ppocr/modeling/backbones/vqa_layoutlm.py b/ppocr/modeling/backbones/vqa_layoutlm.py
index ede5b7a35af65fac351277cefccd89b251f5cdb7..2fd1b1b2a78a98dba1930378f4a06783aadd8834 100644
--- a/ppocr/modeling/backbones/vqa_layoutlm.py
+++ b/ppocr/modeling/backbones/vqa_layoutlm.py
@@ -74,9 +74,9 @@ class LayoutLMForSer(NLPBaseModel):
def forward(self, x):
x = self.model(
input_ids=x[0],
- bbox=x[2],
- attention_mask=x[4],
- token_type_ids=x[5],
+ bbox=x[1],
+ attention_mask=x[2],
+ token_type_ids=x[3],
position_ids=None,
output_hidden_states=False)
return x
@@ -96,13 +96,15 @@ class LayoutLMv2ForSer(NLPBaseModel):
def forward(self, x):
x = self.model(
input_ids=x[0],
- bbox=x[2],
- image=x[3],
- attention_mask=x[4],
- token_type_ids=x[5],
+ bbox=x[1],
+ attention_mask=x[2],
+ token_type_ids=x[3],
+ image=x[4],
position_ids=None,
head_mask=None,
labels=None)
+ if not self.training:
+ return x
return x[0]
@@ -120,13 +122,15 @@ class LayoutXLMForSer(NLPBaseModel):
def forward(self, x):
x = self.model(
input_ids=x[0],
- bbox=x[2],
- image=x[3],
- attention_mask=x[4],
- token_type_ids=x[5],
+ bbox=x[1],
+ attention_mask=x[2],
+ token_type_ids=x[3],
+ image=x[4],
position_ids=None,
head_mask=None,
labels=None)
+ if not self.training:
+ return x
return x[0]
@@ -140,12 +144,12 @@ class LayoutLMv2ForRe(NLPBaseModel):
x = self.model(
input_ids=x[0],
bbox=x[1],
- labels=None,
- image=x[2],
- attention_mask=x[3],
- token_type_ids=x[4],
+ attention_mask=x[2],
+ token_type_ids=x[3],
+ image=x[4],
position_ids=None,
head_mask=None,
+ labels=None,
entities=x[5],
relations=x[6])
return x
@@ -161,12 +165,12 @@ class LayoutXLMForRe(NLPBaseModel):
x = self.model(
input_ids=x[0],
bbox=x[1],
- labels=None,
- image=x[2],
- attention_mask=x[3],
- token_type_ids=x[4],
+ attention_mask=x[2],
+ token_type_ids=x[3],
+ image=x[4],
position_ids=None,
head_mask=None,
+ labels=None,
entities=x[5],
relations=x[6])
return x
diff --git a/ppocr/postprocess/vqa_token_ser_layoutlm_postprocess.py b/ppocr/postprocess/vqa_token_ser_layoutlm_postprocess.py
index 782cdea6c58c69e0d728787e0e21e200c9e13790..8a6669f71f5ae6a7a16931e565b43355de5928d9 100644
--- a/ppocr/postprocess/vqa_token_ser_layoutlm_postprocess.py
+++ b/ppocr/postprocess/vqa_token_ser_layoutlm_postprocess.py
@@ -41,11 +41,13 @@ class VQASerTokenLayoutLMPostProcess(object):
self.id2label_map_for_show[val] = key
def __call__(self, preds, batch=None, *args, **kwargs):
+ if isinstance(preds, tuple):
+ preds = preds[0]
if isinstance(preds, paddle.Tensor):
preds = preds.numpy()
if batch is not None:
- return self._metric(preds, batch[1])
+ return self._metric(preds, batch[5])
else:
return self._infer(preds, **kwargs)
@@ -63,11 +65,11 @@ class VQASerTokenLayoutLMPostProcess(object):
j]])
return decode_out_list, label_decode_out_list
- def _infer(self, preds, attention_masks, segment_offset_ids, ocr_infos):
+ def _infer(self, preds, segment_offset_ids, ocr_infos):
results = []
- for pred, attention_mask, segment_offset_id, ocr_info in zip(
- preds, attention_masks, segment_offset_ids, ocr_infos):
+ for pred, segment_offset_id, ocr_info in zip(preds, segment_offset_ids,
+ ocr_infos):
pred = np.argmax(pred, axis=1)
pred = [self.id2label_map[idx] for idx in pred]
diff --git a/ppstructure/docs/models_list.md b/ppstructure/docs/models_list.md
index c7dab999ff6e370c56c5495e22e91f117b3d1275..dabce3a5149a88833d38a4395e31ac1f82306c4f 100644
--- a/ppstructure/docs/models_list.md
+++ b/ppstructure/docs/models_list.md
@@ -1,11 +1,11 @@
# PP-Structure 系列模型列表
-- [1. 版面分析模型](#1)
-- [2. OCR和表格识别模型](#2)
- - [2.1 OCR](#21)
- - [2.2 表格识别模型](#22)
-- [3. VQA模型](#3)
-- [4. KIE模型](#4)
+- [1. 版面分析模型](#1-版面分析模型)
+- [2. OCR和表格识别模型](#2-ocr和表格识别模型)
+ - [2.1 OCR](#21-ocr)
+ - [2.2 表格识别模型](#22-表格识别模型)
+- [3. VQA模型](#3-vqa模型)
+- [4. KIE模型](#4-kie模型)
@@ -42,11 +42,11 @@
|模型名称|模型简介|推理模型大小|下载地址|
| --- | --- | --- | --- |
-|ser_LayoutXLM_xfun_zh|基于LayoutXLM在xfun中文数据集上训练的SER模型|1.4G|[推理模型 coming soon]() / [训练模型](https://paddleocr.bj.bcebos.com/pplayout/re_LayoutXLM_xfun_zh.tar) |
-|re_LayoutXLM_xfun_zh|基于LayoutXLM在xfun中文数据集上训练的RE模型|1.4G|[推理模型 coming soon]() / [训练模型](https://paddleocr.bj.bcebos.com/pplayout/ser_LayoutXLM_xfun_zh.tar) |
-|ser_LayoutLMv2_xfun_zh|基于LayoutLMv2在xfun中文数据集上训练的SER模型|778M|[推理模型 coming soon]() / [训练模型](https://paddleocr.bj.bcebos.com/pplayout/ser_LayoutLMv2_xfun_zh.tar) |
+|ser_LayoutXLM_xfun_zh|基于LayoutXLM在xfun中文数据集上训练的SER模型|1.4G|[推理模型](https://paddleocr.bj.bcebos.com/pplayout/ser_LayoutXLM_xfun_zh_infer.tar) / [训练模型](https://paddleocr.bj.bcebos.com/pplayout/ser_LayoutXLM_xfun_zh.tar) |
+|re_LayoutXLM_xfun_zh|基于LayoutXLM在xfun中文数据集上训练的RE模型|1.4G|[推理模型 coming soon]() / [训练模型](https://paddleocr.bj.bcebos.com/pplayout/re_LayoutXLM_xfun_zh.tar) |
+|ser_LayoutLMv2_xfun_zh|基于LayoutLMv2在xfun中文数据集上训练的SER模型|778M|[推理模型](https://paddleocr.bj.bcebos.com/pplayout/ser_LayoutLMv2_xfun_zh_infer.tar) / [训练模型](https://paddleocr.bj.bcebos.com/pplayout/ser_LayoutLMv2_xfun_zh.tar) |
|re_LayoutLMv2_xfun_zh|基于LayoutLMv2在xfun中文数据集上训练的RE模型|765M|[推理模型 coming soon]() / [训练模型](https://paddleocr.bj.bcebos.com/pplayout/re_LayoutLMv2_xfun_zh.tar) |
-|ser_LayoutLM_xfun_zh|基于LayoutLM在xfun中文数据集上训练的SER模型|430M|[推理模型 coming soon]() / [训练模型](https://paddleocr.bj.bcebos.com/pplayout/ser_LayoutLM_xfun_zh.tar) |
+|ser_LayoutLM_xfun_zh|基于LayoutLM在xfun中文数据集上训练的SER模型|430M|[推理模型](https://paddleocr.bj.bcebos.com/pplayout/ser_LayoutLM_xfun_zh_infer.tar) / [训练模型](https://paddleocr.bj.bcebos.com/pplayout/ser_LayoutLM_xfun_zh.tar) |
## 4. KIE模型
diff --git a/ppstructure/docs/models_list_en.md b/ppstructure/docs/models_list_en.md
index b92c10c241df72c85649b64f915b4266cd3fe410..e133a0bb2a9b017207b5e92ea444aba4633a7457 100644
--- a/ppstructure/docs/models_list_en.md
+++ b/ppstructure/docs/models_list_en.md
@@ -1,11 +1,11 @@
# PP-Structure Model list
-- [1. Layout Analysis](#1)
-- [2. OCR and Table Recognition](#2)
- - [2.1 OCR](#21)
- - [2.2 Table Recognition](#22)
-- [3. VQA](#3)
-- [4. KIE](#4)
+- [1. Layout Analysis](#1-layout-analysis)
+- [2. OCR and Table Recognition](#2-ocr-and-table-recognition)
+ - [2.1 OCR](#21-ocr)
+ - [2.2 Table Recognition](#22-table-recognition)
+- [3. VQA](#3-vqa)
+- [4. KIE](#4-kie)
@@ -42,11 +42,11 @@ If you need to use other OCR models, you can download the model in [PP-OCR model
|model| description |inference model size|download|
| --- |----------------------------------------------------------------| --- | --- |
-|ser_LayoutXLM_xfun_zh| SER model trained on xfun Chinese dataset based on LayoutXLM |1.4G|[inference model coming soon]() / [trained model](https://paddleocr.bj.bcebos.com/pplayout/re_LayoutXLM_xfun_zh.tar) |
-|re_LayoutXLM_xfun_zh| Re model trained on xfun Chinese dataset based on LayoutXLM |1.4G|[inference model coming soon]() / [trained model](https://paddleocr.bj.bcebos.com/pplayout/ser_LayoutXLM_xfun_zh.tar) |
-|ser_LayoutLMv2_xfun_zh| SER model trained on xfun Chinese dataset based on LayoutXLMv2 |778M|[inference model coming soon]() / [trained model](https://paddleocr.bj.bcebos.com/pplayout/ser_LayoutLMv2_xfun_zh.tar) |
+|ser_LayoutXLM_xfun_zh| SER model trained on xfun Chinese dataset based on LayoutXLM |1.4G|[inference model](https://paddleocr.bj.bcebos.com/pplayout/ser_LayoutXLM_xfun_zh_infer.tar) / [trained model](https://paddleocr.bj.bcebos.com/pplayout/ser_LayoutXLM_xfun_zh.tar) |
+|re_LayoutXLM_xfun_zh| Re model trained on xfun Chinese dataset based on LayoutXLM |1.4G|[inference model coming soon]() / [trained model](https://paddleocr.bj.bcebos.com/pplayout/re_LayoutXLM_xfun_zh.tar) |
+|ser_LayoutLMv2_xfun_zh| SER model trained on xfun Chinese dataset based on LayoutXLMv2 |778M|[inference model](https://paddleocr.bj.bcebos.com/pplayout/ser_LayoutLMv2_xfun_zh_infer.tar) / [trained model](https://paddleocr.bj.bcebos.com/pplayout/ser_LayoutLMv2_xfun_zh.tar) |
|re_LayoutLMv2_xfun_zh| Re model trained on xfun Chinese dataset based on LayoutXLMv2 |765M|[inference model coming soon]() / [trained model](https://paddleocr.bj.bcebos.com/pplayout/re_LayoutLMv2_xfun_zh.tar) |
-|ser_LayoutLM_xfun_zh| SER model trained on xfun Chinese dataset based on LayoutLM |430M|[inference model coming soon]() / [trained model](https://paddleocr.bj.bcebos.com/pplayout/ser_LayoutLM_xfun_zh.tar) |
+|ser_LayoutLM_xfun_zh| SER model trained on xfun Chinese dataset based on LayoutLM |430M|[inference model](https://paddleocr.bj.bcebos.com/pplayout/ser_LayoutLM_xfun_zh_infer.tar) / [trained model](https://paddleocr.bj.bcebos.com/pplayout/ser_LayoutLM_xfun_zh.tar) |
## 4. KIE
diff --git a/ppstructure/utility.py b/ppstructure/utility.py
index 1ad902e7e6be95a6901e3774420fad337f594861..4ae56099b83a46c85ce2dc362c1c6417b324dbe1 100644
--- a/ppstructure/utility.py
+++ b/ppstructure/utility.py
@@ -40,6 +40,13 @@ def init_args():
type=ast.literal_eval,
default=None,
help='label map according to ppstructure/layout/README_ch.md')
+ # params for vqa
+ parser.add_argument("--vqa_algorithm", type=str, default='LayoutXLM')
+ parser.add_argument("--ser_model_dir", type=str)
+ parser.add_argument(
+ "--ser_dict_path",
+ type=str,
+ default="../train_data/XFUND/class_list_xfun.txt")
# params for inference
parser.add_argument(
"--mode",
@@ -65,7 +72,7 @@ def init_args():
"--recovery",
type=bool,
default=False,
- help='Whether to enable layout of recovery')
+ help='Whether to enable layout of recovery')
return parser
diff --git a/ppstructure/vqa/README.md b/ppstructure/vqa/README.md
index 711ffa313865cd5a210143819cd4604dc28ef4f4..05635265b5e5eff18429e2d595fc4195381299f5 100644
--- a/ppstructure/vqa/README.md
+++ b/ppstructure/vqa/README.md
@@ -1,19 +1,15 @@
English | [简体中文](README_ch.md)
-- [Document Visual Question Answering (Doc-VQA)](#Document-Visual-Question-Answering)
- - [1. Introduction](#1-Introduction)
- - [2. Performance](#2-performance)
- - [3. Effect demo](#3-Effect-demo)
- - [3.1 SER](#31-ser)
- - [3.2 RE](#32-re)
- - [4. Install](#4-Install)
- - [4.1 Installation dependencies](#41-Install-dependencies)
- - [4.2 Install PaddleOCR](#42-Install-PaddleOCR)
- - [5. Usage](#5-Usage)
- - [5.1 Data and Model Preparation](#51-Data-and-Model-Preparation)
- - [5.2 SER](#52-ser)
- - [5.3 RE](#53-re)
- - [6. Reference](#6-Reference-Links)
+- [1 Introduction](#1-introduction)
+- [2. Performance](#2-performance)
+- [3. Effect demo](#3-effect-demo)
+ - [3.1 SER](#31-ser)
+ - [3.2 RE](#32-re)
+- [4. Install](#4-install)
+ - [4.1 Install dependencies](#41-install-dependencies)
+ - [5.3 RE](#53-re)
+- [6. Reference Links](#6-reference-links)
+- [License](#license)
# Document Visual Question Answering
@@ -187,17 +183,17 @@ CUDA_VISIBLE_DEVICES=0 python3 tools/eval.py -c configs/vqa/ser/layoutxlm.yml -o
````
Finally, `precision`, `recall`, `hmean` and other indicators will be printed
-* Use `OCR engine + SER` tandem prediction
+* `OCR + SER` tandem prediction based on training engine
-Use the following command to complete the series prediction of `OCR engine + SER`, taking the pretrained SER model as an example:
+Use the following command to complete the series prediction of `OCR engine + SER`, taking the SER model based on LayoutXLM as an example::
```shell
-CUDA_VISIBLE_DEVICES=0 python3 tools/infer_vqa_token_ser.py -c configs/vqa/ser/layoutxlm.yml -o Architecture.Backbone.checkpoints=pretrain/ser_LayoutXLM_xfun_zh/Global.infer_img=doc/vqa/input/zh_val_42.jpg
+python3.7 tools/export_model.py -c configs/vqa/ser/layoutxlm.yml -o Architecture.Backbone.checkpoints=pretrain/ser_LayoutXLM_xfun_zh/ Global.save_inference_dir=output/ser/infer
````
Finally, the prediction result visualization image and the prediction result text file will be saved in the directory configured by the `config.Global.save_res_path` field. The prediction result text file is named `infer_results.txt`.
-* End-to-end evaluation of `OCR engine + SER` prediction system
+* End-to-end evaluation of `OCR + SER` prediction system
First use the `tools/infer_vqa_token_ser.py` script to complete the prediction of the dataset, then use the following command to evaluate.
@@ -205,6 +201,24 @@ First use the `tools/infer_vqa_token_ser.py` script to complete the prediction o
export CUDA_VISIBLE_DEVICES=0
python3 tools/eval_with_label_end2end.py --gt_json_path XFUND/zh_val/xfun_normalize_val.json --pred_json_path output_res/infer_results.txt
````
+* export model
+
+Use the following command to complete the model export of the SER model, taking the SER model based on LayoutXLM as an example:
+
+```shell
+python3.7 tools/export_model.py -c configs/vqa/ser/layoutxlm.yml -o Architecture.Backbone.checkpoints=pretrain/ser_LayoutXLM_xfun_zh/ Global.save_inference_dir=output/ser/infer
+```
+The converted model will be stored in the directory specified by the `Global.save_inference_dir` field.
+
+* `OCR + SER` tandem prediction based on prediction engine
+
+Use the following command to complete the tandem prediction of `OCR + SER` based on the prediction engine, taking the SER model based on LayoutXLM as an example:
+
+```shell
+cd ppstructure
+CUDA_VISIBLE_DEVICES=0 python3.7 vqa/predict_vqa_token_ser.py --vqa_algorithm=LayoutXLM --ser_model_dir=../output/ser/infer --ser_dict_path=../train_data/XFUND/class_list_xfun.txt --image_dir=docs/vqa/input/zh_val_42.jpg --output=output
+```
+After the prediction is successful, the visualization images and results will be saved in the directory specified by the `output` field
### 5.3 RE
@@ -247,11 +261,19 @@ Finally, `precision`, `recall`, `hmean` and other indicators will be printed
Use the following command to complete the series prediction of `OCR engine + SER + RE`, taking the pretrained SER and RE models as an example:
```shell
export CUDA_VISIBLE_DEVICES=0
-python3 tools/infer_vqa_token_ser_re.py -c configs/vqa/re/layoutxlm.yml -o Architecture.Backbone.checkpoints=pretrain/re_LayoutXLM_xfun_zh/Global.infer_img=doc/vqa/input/zh_val_21.jpg -c_ser configs/vqa/ser/layoutxlm. yml -o_ser Architecture.Backbone.checkpoints=pretrain/ser_LayoutXLM_xfun_zh/
+python3 tools/infer_vqa_token_ser_re.py -c configs/vqa/re/layoutxlm.yml -o Architecture.Backbone.checkpoints=pretrain/re_LayoutXLM_xfun_zh/Global.infer_img=ppstructure/docs/vqa/input/zh_val_21.jpg -c_ser configs/vqa/ser/layoutxlm. yml -o_ser Architecture.Backbone.checkpoints=pretrain/ser_LayoutXLM_xfun_zh/
````
Finally, the prediction result visualization image and the prediction result text file will be saved in the directory configured by the `config.Global.save_res_path` field. The prediction result text file is named `infer_results.txt`.
+* export model
+
+cooming soon
+
+* `OCR + SER + RE` tandem prediction based on prediction engine
+
+cooming soon
+
## 6. Reference Links
- LayoutXLM: Multimodal Pre-training for Multilingual Visually-rich Document Understanding, https://arxiv.org/pdf/2104.08836.pdf
diff --git a/ppstructure/vqa/README_ch.md b/ppstructure/vqa/README_ch.md
index 297ba64f82e70eafd4a0b1fee0764899799219ad..b421a82d3a1cbe39f5c740bea486ec26593ab20f 100644
--- a/ppstructure/vqa/README_ch.md
+++ b/ppstructure/vqa/README_ch.md
@@ -1,19 +1,19 @@
[English](README.md) | 简体中文
-- [文档视觉问答(DOC-VQA)](#文档视觉问答doc-vqa)
- - [1. 简介](#1-简介)
- - [2. 性能](#2-性能)
- - [3. 效果演示](#3-效果演示)
- - [3.1 SER](#31-ser)
- - [3.2 RE](#32-re)
- - [4. 安装](#4-安装)
- - [4.1 安装依赖](#41-安装依赖)
- - [4.2 安装PaddleOCR(包含 PP-OCR 和 VQA)](#42-安装paddleocr包含-pp-ocr-和-vqa)
- - [5. 使用](#5-使用)
- - [5.1 数据和预训练模型准备](#51-数据和预训练模型准备)
- - [5.2 SER](#52-ser)
- - [5.3 RE](#53-re)
- - [6. 参考链接](#6-参考链接)
+- [1. 简介](#1-简介)
+- [2. 性能](#2-性能)
+- [3. 效果演示](#3-效果演示)
+ - [3.1 SER](#31-ser)
+ - [3.2 RE](#32-re)
+- [4. 安装](#4-安装)
+ - [4.1 安装依赖](#41-安装依赖)
+ - [4.2 安装PaddleOCR(包含 PP-OCR 和 VQA)](#42-安装paddleocr包含-pp-ocr-和-vqa)
+- [5. 使用](#5-使用)
+ - [5.1 数据和预训练模型准备](#51-数据和预训练模型准备)
+ - [5.2 SER](#52-ser)
+ - [5.3 RE](#53-re)
+- [6. 参考链接](#6-参考链接)
+- [License](#license)
# 文档视觉问答(DOC-VQA)
@@ -183,16 +183,16 @@ CUDA_VISIBLE_DEVICES=0 python3 tools/eval.py -c configs/vqa/ser/layoutxlm.yml -o
```
最终会打印出`precision`, `recall`, `hmean`等指标
-* 使用`OCR引擎 + SER`串联预测
+* 基于训练引擎的`OCR + SER`串联预测
-使用如下命令即可完成`OCR引擎 + SER`的串联预测, 以SER预训练模型为例:
+使用如下命令即可完成基于训练引擎的`OCR + SER`的串联预测, 以基于LayoutXLM的SER模型为例:
```shell
CUDA_VISIBLE_DEVICES=0 python3 tools/infer_vqa_token_ser.py -c configs/vqa/ser/layoutxlm.yml -o Architecture.Backbone.checkpoints=pretrain/ser_LayoutXLM_xfun_zh/ Global.infer_img=doc/vqa/input/zh_val_42.jpg
```
最终会在`config.Global.save_res_path`字段所配置的目录下保存预测结果可视化图像以及预测结果文本文件,预测结果文本文件名为`infer_results.txt`。
-* 对`OCR引擎 + SER`预测系统进行端到端评估
+* 对`OCR + SER`预测系统进行端到端评估
首先使用 `tools/infer_vqa_token_ser.py` 脚本完成数据集的预测,然后使用下面的命令进行评估。
@@ -200,6 +200,24 @@ CUDA_VISIBLE_DEVICES=0 python3 tools/infer_vqa_token_ser.py -c configs/vqa/ser/l
export CUDA_VISIBLE_DEVICES=0
python3 tools/eval_with_label_end2end.py --gt_json_path XFUND/zh_val/xfun_normalize_val.json --pred_json_path output_res/infer_results.txt
```
+* 模型导出
+
+使用如下命令即可完成SER模型的模型导出, 以基于LayoutXLM的SER模型为例:
+
+```shell
+python3.7 tools/export_model.py -c configs/vqa/ser/layoutxlm.yml -o Architecture.Backbone.checkpoints=pretrain/ser_LayoutXLM_xfun_zh/ Global.save_inference_dir=output/ser/infer
+```
+转换后的模型会存放在`Global.save_inference_dir`字段指定的目录下。
+
+* 基于预测引擎的`OCR + SER`串联预测
+
+使用如下命令即可完成基于预测引擎的`OCR + SER`的串联预测, 以基于LayoutXLM的SER模型为例:
+
+```shell
+cd ppstructure
+CUDA_VISIBLE_DEVICES=0 python3.7 vqa/predict_vqa_token_ser.py --vqa_algorithm=LayoutXLM --ser_model_dir=../output/ser/infer --ser_dict_path=../train_data/XFUND/class_list_xfun.txt --image_dir=docs/vqa/input/zh_val_42.jpg --output=output
+```
+预测成功后,可视化图片和结果会保存在`output`字段指定的目录下
### 5.3 RE
@@ -236,16 +254,24 @@ CUDA_VISIBLE_DEVICES=0 python3 tools/eval.py -c configs/vqa/re/layoutxlm.yml -o
```
最终会打印出`precision`, `recall`, `hmean`等指标
-* 使用`OCR引擎 + SER + RE`串联预测
+* 基于训练引擎的`OCR + SER + RE`串联预测
-使用如下命令即可完成`OCR引擎 + SER + RE`的串联预测, 以预训练SER和RE模型为例:
+使用如下命令即可完成基于训练引擎的`OCR + SER + RE`串联预测, 以基于LayoutXLMSER和RE模型为例:
```shell
export CUDA_VISIBLE_DEVICES=0
-python3 tools/infer_vqa_token_ser_re.py -c configs/vqa/re/layoutxlm.yml -o Architecture.Backbone.checkpoints=pretrain/re_LayoutXLM_xfun_zh/ Global.infer_img=doc/vqa/input/zh_val_21.jpg -c_ser configs/vqa/ser/layoutxlm.yml -o_ser Architecture.Backbone.checkpoints=pretrain/ser_LayoutXLM_xfun_zh/
+python3 tools/infer_vqa_token_ser_re.py -c configs/vqa/re/layoutxlm.yml -o Architecture.Backbone.checkpoints=pretrain/re_LayoutXLM_xfun_zh/ Global.infer_img=ppstructure/docs/vqa/input/zh_val_21.jpg -c_ser configs/vqa/ser/layoutxlm.yml -o_ser Architecture.Backbone.checkpoints=pretrain/ser_LayoutXLM_xfun_zh/
```
最终会在`config.Global.save_res_path`字段所配置的目录下保存预测结果可视化图像以及预测结果文本文件,预测结果文本文件名为`infer_results.txt`。
+* 模型导出
+
+cooming soon
+
+* 基于预测引擎的`OCR + SER + RE`串联预测
+
+cooming soon
+
## 6. 参考链接
- LayoutXLM: Multimodal Pre-training for Multilingual Visually-rich Document Understanding, https://arxiv.org/pdf/2104.08836.pdf
diff --git a/ppstructure/vqa/predict_vqa_token_ser.py b/ppstructure/vqa/predict_vqa_token_ser.py
new file mode 100644
index 0000000000000000000000000000000000000000..de0bbfe72d80d9a16de8b09657a98dc5285bb348
--- /dev/null
+++ b/ppstructure/vqa/predict_vqa_token_ser.py
@@ -0,0 +1,169 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import os
+import sys
+
+__dir__ = os.path.dirname(os.path.abspath(__file__))
+sys.path.append(__dir__)
+sys.path.insert(0, os.path.abspath(os.path.join(__dir__, '../..')))
+
+os.environ["FLAGS_allocator_strategy"] = 'auto_growth'
+
+import cv2
+import json
+import numpy as np
+import time
+
+import tools.infer.utility as utility
+from ppocr.data import create_operators, transform
+from ppocr.postprocess import build_post_process
+from ppocr.utils.logging import get_logger
+from ppocr.utils.visual import draw_ser_results
+from ppocr.utils.utility import get_image_file_list, check_and_read_gif
+from ppstructure.utility import parse_args
+
+from paddleocr import PaddleOCR
+
+logger = get_logger()
+
+
+class SerPredictor(object):
+ def __init__(self, args):
+ self.ocr_engine = PaddleOCR(use_angle_cls=False, show_log=False)
+
+ pre_process_list = [{
+ 'VQATokenLabelEncode': {
+ 'algorithm': args.vqa_algorithm,
+ 'class_path': args.ser_dict_path,
+ 'contains_re': False,
+ 'ocr_engine': self.ocr_engine
+ }
+ }, {
+ 'VQATokenPad': {
+ 'max_seq_len': 512,
+ 'return_attention_mask': True
+ }
+ }, {
+ 'VQASerTokenChunk': {
+ 'max_seq_len': 512,
+ 'return_attention_mask': True
+ }
+ }, {
+ 'Resize': {
+ 'size': [224, 224]
+ }
+ }, {
+ 'NormalizeImage': {
+ 'std': [58.395, 57.12, 57.375],
+ 'mean': [123.675, 116.28, 103.53],
+ 'scale': '1',
+ 'order': 'hwc'
+ }
+ }, {
+ 'ToCHWImage': None
+ }, {
+ 'KeepKeys': {
+ 'keep_keys': [
+ 'input_ids', 'bbox', 'attention_mask', 'token_type_ids',
+ 'image', 'labels', 'segment_offset_id', 'ocr_info',
+ 'entities'
+ ]
+ }
+ }]
+ postprocess_params = {
+ 'name': 'VQASerTokenLayoutLMPostProcess',
+ "class_path": args.ser_dict_path,
+ }
+
+ self.preprocess_op = create_operators(pre_process_list,
+ {'infer_mode': True})
+ self.postprocess_op = build_post_process(postprocess_params)
+ self.predictor, self.input_tensor, self.output_tensors, self.config = \
+ utility.create_predictor(args, 'ser', logger)
+
+ def __call__(self, img):
+ ori_im = img.copy()
+ data = {'image': img}
+ data = transform(data, self.preprocess_op)
+ img = data[0]
+ if img is None:
+ return None, 0
+ img = np.expand_dims(img, axis=0)
+ img = img.copy()
+ starttime = time.time()
+
+ for idx in range(len(self.input_tensor)):
+ expand_input = np.expand_dims(data[idx], axis=0)
+ self.input_tensor[idx].copy_from_cpu(expand_input)
+
+ self.predictor.run()
+
+ outputs = []
+ for output_tensor in self.output_tensors:
+ output = output_tensor.copy_to_cpu()
+ outputs.append(output)
+ preds = outputs[0]
+
+ post_result = self.postprocess_op(
+ preds, segment_offset_ids=[data[6]], ocr_infos=[data[7]])
+ elapse = time.time() - starttime
+ return post_result, elapse
+
+
+def main(args):
+ image_file_list = get_image_file_list(args.image_dir)
+ ser_predictor = SerPredictor(args)
+ count = 0
+ total_time = 0
+
+ os.makedirs(args.output, exist_ok=True)
+ with open(
+ os.path.join(args.output, 'infer.txt'), mode='w',
+ encoding='utf-8') as f_w:
+ for image_file in image_file_list:
+ img, flag = check_and_read_gif(image_file)
+ if not flag:
+ img = cv2.imread(image_file)
+ img = img[:, :, ::-1]
+ if img is None:
+ logger.info("error in loading image:{}".format(image_file))
+ continue
+ ser_res, elapse = ser_predictor(img)
+ ser_res = ser_res[0]
+
+ res_str = '{}\t{}\n'.format(
+ image_file,
+ json.dumps(
+ {
+ "ocr_info": ser_res,
+ }, ensure_ascii=False))
+ f_w.write(res_str)
+
+ img_res = draw_ser_results(
+ image_file,
+ ser_res,
+ font_path="../doc/fonts/simfang.ttf", )
+
+ img_save_path = os.path.join(args.output,
+ os.path.basename(image_file))
+ cv2.imwrite(img_save_path, img_res)
+ logger.info("save vis result to {}".format(img_save_path))
+ if count > 0:
+ total_time += elapse
+ count += 1
+ logger.info("Predict time of {}: {}".format(image_file, elapse))
+
+
+if __name__ == "__main__":
+ main(parse_args())
diff --git a/ppstructure/vqa/requirements.txt b/ppstructure/vqa/requirements.txt
index 0042ec0baedcc3e7bbecb922d10b93c95219219d..fcd882274c4402ba2a1d34f20ee6e2befa157121 100644
--- a/ppstructure/vqa/requirements.txt
+++ b/ppstructure/vqa/requirements.txt
@@ -1,4 +1,7 @@
sentencepiece
yacs
seqeval
-paddlenlp>=2.2.1
\ No newline at end of file
+paddlenlp>=2.2.1
+pypandoc
+attrdict
+python_docx
\ No newline at end of file
diff --git a/tools/export_model.py b/tools/export_model.py
index b10d41d5b288258ad895cefa7d8cc243eff10546..65573cf46a9d650b8f833fdec43235de57faf5ac 100755
--- a/tools/export_model.py
+++ b/tools/export_model.py
@@ -97,6 +97,22 @@ def export_single_model(model,
shape=[None, 1, 32, 100], dtype="float32"),
]
model = to_static(model, input_spec=other_shape)
+ elif arch_config["algorithm"] in ["LayoutLM", "LayoutLMv2", "LayoutXLM"]:
+ input_spec = [
+ paddle.static.InputSpec(
+ shape=[None, 512], dtype="int64"), # input_ids
+ paddle.static.InputSpec(
+ shape=[None, 512, 4], dtype="int64"), # bbox
+ paddle.static.InputSpec(
+ shape=[None, 512], dtype="int64"), # attention_mask
+ paddle.static.InputSpec(
+ shape=[None, 512], dtype="int64"), # token_type_ids
+ paddle.static.InputSpec(
+ shape=[None, 3, 224, 224], dtype="int64"), # image
+ ]
+ if arch_config["algorithm"] == "LayoutLM":
+ input_spec.pop(4)
+ model = to_static(model, input_spec=[input_spec])
else:
infer_shape = [3, -1, -1]
if arch_config["model_type"] == "rec":
@@ -172,7 +188,7 @@ def main():
config["Architecture"]["Head"]["out_channels"] = char_num
model = build_model(config["Architecture"])
- load_model(config, model)
+ load_model(config, model, model_type=config['Architecture']["model_type"])
model.eval()
save_path = config["Global"]["save_inference_dir"]
diff --git a/tools/infer/utility.py b/tools/infer/utility.py
index 366212f228eec33f11c825bfaf1e360258af9b2e..7eb77dec74bf283936e1143edcb5b5dfc28365bd 100644
--- a/tools/infer/utility.py
+++ b/tools/infer/utility.py
@@ -153,6 +153,8 @@ def create_predictor(args, mode, logger):
model_dir = args.rec_model_dir
elif mode == 'table':
model_dir = args.table_model_dir
+ elif mode == 'ser':
+ model_dir = args.ser_model_dir
else:
model_dir = args.e2e_model_dir
@@ -316,8 +318,13 @@ def create_predictor(args, mode, logger):
# create predictor
predictor = inference.create_predictor(config)
input_names = predictor.get_input_names()
- for name in input_names:
- input_tensor = predictor.get_input_handle(name)
+ if mode in ['ser', 're']:
+ input_tensor = []
+ for name in input_names:
+ input_tensor.append(predictor.get_input_handle(name))
+ else:
+ for name in input_names:
+ input_tensor = predictor.get_input_handle(name)
output_tensors = get_output_tensors(args, mode, predictor)
return predictor, input_tensor, output_tensors, config
diff --git a/tools/infer_vqa_token_ser.py b/tools/infer_vqa_token_ser.py
index 39ada64a99847a910158b74672c89398ba08f032..0173a554cace31e20ab47dbe36d132a4dbb2127b 100755
--- a/tools/infer_vqa_token_ser.py
+++ b/tools/infer_vqa_token_ser.py
@@ -44,6 +44,7 @@ def to_tensor(data):
from collections import defaultdict
data_dict = defaultdict(list)
to_tensor_idxs = []
+
for idx, v in enumerate(data):
if isinstance(v, (np.ndarray, paddle.Tensor, numbers.Number)):
if idx not in to_tensor_idxs:
@@ -57,6 +58,7 @@ def to_tensor(data):
class SerPredictor(object):
def __init__(self, config):
global_config = config['Global']
+ self.algorithm = config['Architecture']["algorithm"]
# build post process
self.post_process_class = build_post_process(config['PostProcess'],
@@ -70,7 +72,10 @@ class SerPredictor(object):
from paddleocr import PaddleOCR
- self.ocr_engine = PaddleOCR(use_angle_cls=False, show_log=False)
+ self.ocr_engine = PaddleOCR(
+ use_angle_cls=False,
+ show_log=False,
+ use_gpu=global_config['use_gpu'])
# create data ops
transforms = []
@@ -80,8 +85,8 @@ class SerPredictor(object):
op[op_name]['ocr_engine'] = self.ocr_engine
elif op_name == 'KeepKeys':
op[op_name]['keep_keys'] = [
- 'input_ids', 'labels', 'bbox', 'image', 'attention_mask',
- 'token_type_ids', 'segment_offset_id', 'ocr_info',
+ 'input_ids', 'bbox', 'attention_mask', 'token_type_ids',
+ 'image', 'labels', 'segment_offset_id', 'ocr_info',
'entities'
]
@@ -99,11 +104,11 @@ class SerPredictor(object):
batch = transform(data, self.ops)
batch = to_tensor(batch)
preds = self.model(batch)
+ if self.algorithm in ['LayoutLMv2', 'LayoutXLM']:
+ preds = preds[0]
+
post_result = self.post_process_class(
- preds,
- attention_masks=batch[4],
- segment_offset_ids=batch[6],
- ocr_infos=batch[7])
+ preds, segment_offset_ids=batch[6], ocr_infos=batch[7])
return post_result, batch
@@ -138,8 +143,6 @@ if __name__ == '__main__':
save_img_path = os.path.join(
config['Global']['save_res_path'],
os.path.splitext(os.path.basename(img_path))[0] + "_ser.jpg")
- logger.info("process: [{}/{}], save result to {}".format(
- idx, len(infer_imgs), save_img_path))
result, _ = ser_engine(data)
result = result[0]
@@ -149,3 +152,6 @@ if __name__ == '__main__':
}, ensure_ascii=False) + "\n")
img_res = draw_ser_results(img_path, result)
cv2.imwrite(save_img_path, img_res)
+
+ logger.info("process: [{}/{}], save result to {}".format(
+ idx, len(infer_imgs), save_img_path))
diff --git a/tools/infer_vqa_token_ser_re.py b/tools/infer_vqa_token_ser_re.py
index 6210f7f3c24227c9d366b08ce93ccfe4df849ce1..20ab1fe176c3be75f7a7b01a8d77df6419c58c75 100755
--- a/tools/infer_vqa_token_ser_re.py
+++ b/tools/infer_vqa_token_ser_re.py
@@ -38,7 +38,7 @@ from ppocr.utils.save_load import load_model
from ppocr.utils.visual import draw_re_results
from ppocr.utils.logging import get_logger
from ppocr.utils.utility import get_image_file_list, load_vqa_bio_label_maps, print_dict
-from tools.program import ArgsParser, load_config, merge_config, check_gpu
+from tools.program import ArgsParser, load_config, merge_config
from tools.infer_vqa_token_ser import SerPredictor
@@ -107,7 +107,7 @@ def make_input(ser_inputs, ser_results):
# remove ocr_info segment_offset_id and label in ser input
ser_inputs.pop(7)
ser_inputs.pop(6)
- ser_inputs.pop(1)
+ ser_inputs.pop(5)
return ser_inputs, entity_idx_dict_batch
@@ -131,9 +131,7 @@ class SerRePredictor(object):
self.model.eval()
def __call__(self, img_path):
- ser_results, ser_inputs = self.ser_engine(img_path)
- paddle.save(ser_inputs, 'ser_inputs.npy')
- paddle.save(ser_results, 'ser_results.npy')
+ ser_results, ser_inputs = self.ser_engine({'img_path': img_path})
re_input, entity_idx_dict_batch = make_input(ser_inputs, ser_results)
preds = self.model(re_input)
post_result = self.post_process_class(
@@ -155,7 +153,6 @@ def preprocess():
# check if set use_gpu=True in paddlepaddle cpu version
use_gpu = config['Global']['use_gpu']
- check_gpu(use_gpu)
device = 'gpu:{}'.format(dist.ParallelEnv().dev_id) if use_gpu else 'cpu'
device = paddle.set_device(device)
@@ -185,9 +182,7 @@ if __name__ == '__main__':
for idx, img_path in enumerate(infer_imgs):
save_img_path = os.path.join(
config['Global']['save_res_path'],
- os.path.splitext(os.path.basename(img_path))[0] + "_ser.jpg")
- logger.info("process: [{}/{}], save result to {}".format(
- idx, len(infer_imgs), save_img_path))
+ os.path.splitext(os.path.basename(img_path))[0] + "_ser_re.jpg")
result = ser_re_engine(img_path)
result = result[0]
@@ -197,3 +192,6 @@ if __name__ == '__main__':
}, ensure_ascii=False) + "\n")
img_res = draw_re_results(img_path, result)
cv2.imwrite(save_img_path, img_res)
+
+ logger.info("process: [{}/{}], save result to {}".format(
+ idx, len(infer_imgs), save_img_path))