From 92f85521218993f9739079fc26ccc34519f5bbf9 Mon Sep 17 00:00:00 2001
From: WenmuZhou <572459439@qq.com>
Date: Tue, 9 Aug 2022 10:50:30 +0000
Subject: [PATCH] add "<td></td>" to dict when "<td></td>" not in file
---
configs/table/SLANet.yml | 5 ++--
configs/table/table_master.yml | 7 +++---
ppocr/data/imaug/label_ops.py | 6 +++++
ppocr/postprocess/table_postprocess.py | 33 ++++++++++++++++++++++----
4 files changed, 42 insertions(+), 9 deletions(-)
diff --git a/configs/table/SLANet.yml b/configs/table/SLANet.yml
index 46cc22d0..105fb5fa 100644
--- a/configs/table/SLANet.yml
+++ b/configs/table/SLANet.yml
@@ -58,6 +58,7 @@ Loss:
PostProcess:
name: TableLabelDecode
+ merge_no_span_structure: &merge_no_span_structure False
Metric:
name: TableMetric
@@ -77,7 +78,7 @@ Train:
channel_first: False
- TableLabelEncode:
learn_empty_box: False
- merge_no_span_structure: False
+ merge_no_span_structure: *merge_no_span_structure
replace_empty_cell_token: False
loc_reg_num: *loc_reg_num
max_text_length: *max_text_length
@@ -112,7 +113,7 @@ Eval:
channel_first: False
- TableLabelEncode:
learn_empty_box: False
- merge_no_span_structure: False
+ merge_no_span_structure: *merge_no_span_structure
replace_empty_cell_token: False
loc_reg_num: *loc_reg_num
max_text_length: *max_text_length
diff --git a/configs/table/table_master.yml b/configs/table/table_master.yml
index 1844c319..1f50d20b 100755
--- a/configs/table/table_master.yml
+++ b/configs/table/table_master.yml
@@ -8,7 +8,7 @@ Global:
eval_batch_step: [0, 6259]
cal_metric_during_train: true
pretrained_model: null
- checkpoints:
+ checkpoints:
save_inference_dir: output/table_master/infer
use_visualdl: false
infer_img: ppstructure/docs/table/table.jpg
@@ -61,6 +61,7 @@ Loss:
PostProcess:
name: TableMasterLabelDecode
box_shape: pad
+ merge_no_span_structure: &merge_no_span_structure True
Metric:
name: TableMetric
@@ -79,7 +80,7 @@ Train:
channel_first: False
- TableMasterLabelEncode:
learn_empty_box: False
- merge_no_span_structure: True
+ merge_no_span_structure: *merge_no_span_structure
replace_empty_cell_token: True
loc_reg_num: *loc_reg_num
max_text_length: *max_text_length
@@ -115,7 +116,7 @@ Eval:
channel_first: False
- TableMasterLabelEncode:
learn_empty_box: False
- merge_no_span_structure: True
+ merge_no_span_structure: *merge_no_span_structure
replace_empty_cell_token: True
loc_reg_num: *loc_reg_num
max_text_length: *max_text_length
diff --git a/ppocr/data/imaug/label_ops.py b/ppocr/data/imaug/label_ops.py
index 1c1fc00b..ce539dce 100644
--- a/ppocr/data/imaug/label_ops.py
+++ b/ppocr/data/imaug/label_ops.py
@@ -587,6 +587,12 @@ class TableLabelEncode(AttnLabelEncode):
line = line.decode('utf-8').strip("\n").strip("\r\n")
dict_character.append(line)
+ if self.merge_no_span_structure:
+ if " | " not in dict_character:
+ dict_character.append(" | ")
+ if "" in dict_character:
+ dict_character.remove(" | ")
+
dict_character = self.add_special_char(dict_character)
self.dict = {}
for i, char in enumerate(dict_character):
diff --git a/ppocr/postprocess/table_postprocess.py b/ppocr/postprocess/table_postprocess.py
index ce254f31..a47061f9 100644
--- a/ppocr/postprocess/table_postprocess.py
+++ b/ppocr/postprocess/table_postprocess.py
@@ -21,8 +21,28 @@ from .rec_postprocess import AttnLabelDecode
class TableLabelDecode(AttnLabelDecode):
""" """
- def __init__(self, character_dict_path, **kwargs):
- super(TableLabelDecode, self).__init__(character_dict_path)
+ def __init__(self,
+ character_dict_path,
+ merge_no_span_structure=False,
+ **kwargs):
+ dict_character = []
+ with open(character_dict_path, "rb") as fin:
+ lines = fin.readlines()
+ for line in lines:
+ line = line.decode('utf-8').strip("\n").strip("\r\n")
+ dict_character.append(line)
+
+ if merge_no_span_structure:
+ if " | | " not in dict_character:
+ dict_character.append(" | ")
+ if "" in dict_character:
+ dict_character.remove(" | ")
+
+ dict_character = self.add_special_char(dict_character)
+ self.dict = {}
+ for i, char in enumerate(dict_character):
+ self.dict[char] = i
+ self.character = dict_character
self.td_token = [' | ', ' | | ']
def __call__(self, preds, batch=None):
@@ -122,8 +142,13 @@ class TableLabelDecode(AttnLabelDecode):
class TableMasterLabelDecode(TableLabelDecode):
""" """
- def __init__(self, character_dict_path, box_shape='ori', **kwargs):
- super(TableMasterLabelDecode, self).__init__(character_dict_path)
+ def __init__(self,
+ character_dict_path,
+ box_shape='ori',
+ merge_no_span_structure=True,
+ **kwargs):
+ super(TableMasterLabelDecode, self).__init__(character_dict_path,
+ merge_no_span_structure)
self.box_shape = box_shape
assert box_shape in [
'ori', 'pad'
--
GitLab