From 92f85521218993f9739079fc26ccc34519f5bbf9 Mon Sep 17 00:00:00 2001 From: WenmuZhou <572459439@qq.com> Date: Tue, 9 Aug 2022 10:50:30 +0000 Subject: [PATCH] add "" to dict when "" not in file --- configs/table/SLANet.yml | 5 ++-- configs/table/table_master.yml | 7 +++--- ppocr/data/imaug/label_ops.py | 6 +++++ ppocr/postprocess/table_postprocess.py | 33 ++++++++++++++++++++++---- 4 files changed, 42 insertions(+), 9 deletions(-) diff --git a/configs/table/SLANet.yml b/configs/table/SLANet.yml index 46cc22d0..105fb5fa 100644 --- a/configs/table/SLANet.yml +++ b/configs/table/SLANet.yml @@ -58,6 +58,7 @@ Loss: PostProcess: name: TableLabelDecode + merge_no_span_structure: &merge_no_span_structure False Metric: name: TableMetric @@ -77,7 +78,7 @@ Train: channel_first: False - TableLabelEncode: learn_empty_box: False - merge_no_span_structure: False + merge_no_span_structure: *merge_no_span_structure replace_empty_cell_token: False loc_reg_num: *loc_reg_num max_text_length: *max_text_length @@ -112,7 +113,7 @@ Eval: channel_first: False - TableLabelEncode: learn_empty_box: False - merge_no_span_structure: False + merge_no_span_structure: *merge_no_span_structure replace_empty_cell_token: False loc_reg_num: *loc_reg_num max_text_length: *max_text_length diff --git a/configs/table/table_master.yml b/configs/table/table_master.yml index 1844c319..1f50d20b 100755 --- a/configs/table/table_master.yml +++ b/configs/table/table_master.yml @@ -8,7 +8,7 @@ Global: eval_batch_step: [0, 6259] cal_metric_during_train: true pretrained_model: null - checkpoints: + checkpoints: save_inference_dir: output/table_master/infer use_visualdl: false infer_img: ppstructure/docs/table/table.jpg @@ -61,6 +61,7 @@ Loss: PostProcess: name: TableMasterLabelDecode box_shape: pad + merge_no_span_structure: &merge_no_span_structure True Metric: name: TableMetric @@ -79,7 +80,7 @@ Train: channel_first: False - TableMasterLabelEncode: learn_empty_box: False - merge_no_span_structure: True + merge_no_span_structure: *merge_no_span_structure replace_empty_cell_token: True loc_reg_num: *loc_reg_num max_text_length: *max_text_length @@ -115,7 +116,7 @@ Eval: channel_first: False - TableMasterLabelEncode: learn_empty_box: False - merge_no_span_structure: True + merge_no_span_structure: *merge_no_span_structure replace_empty_cell_token: True loc_reg_num: *loc_reg_num max_text_length: *max_text_length diff --git a/ppocr/data/imaug/label_ops.py b/ppocr/data/imaug/label_ops.py index 1c1fc00b..ce539dce 100644 --- a/ppocr/data/imaug/label_ops.py +++ b/ppocr/data/imaug/label_ops.py @@ -587,6 +587,12 @@ class TableLabelEncode(AttnLabelEncode): line = line.decode('utf-8').strip("\n").strip("\r\n") dict_character.append(line) + if self.merge_no_span_structure: + if "" not in dict_character: + dict_character.append("") + if "" in dict_character: + dict_character.remove("") + dict_character = self.add_special_char(dict_character) self.dict = {} for i, char in enumerate(dict_character): diff --git a/ppocr/postprocess/table_postprocess.py b/ppocr/postprocess/table_postprocess.py index ce254f31..a47061f9 100644 --- a/ppocr/postprocess/table_postprocess.py +++ b/ppocr/postprocess/table_postprocess.py @@ -21,8 +21,28 @@ from .rec_postprocess import AttnLabelDecode class TableLabelDecode(AttnLabelDecode): """ """ - def __init__(self, character_dict_path, **kwargs): - super(TableLabelDecode, self).__init__(character_dict_path) + def __init__(self, + character_dict_path, + merge_no_span_structure=False, + **kwargs): + dict_character = [] + with open(character_dict_path, "rb") as fin: + lines = fin.readlines() + for line in lines: + line = line.decode('utf-8').strip("\n").strip("\r\n") + dict_character.append(line) + + if merge_no_span_structure: + if "" not in dict_character: + dict_character.append("") + if "" in dict_character: + dict_character.remove("") + + dict_character = self.add_special_char(dict_character) + self.dict = {} + for i, char in enumerate(dict_character): + self.dict[char] = i + self.character = dict_character self.td_token = ['', ''] def __call__(self, preds, batch=None): @@ -122,8 +142,13 @@ class TableLabelDecode(AttnLabelDecode): class TableMasterLabelDecode(TableLabelDecode): """ """ - def __init__(self, character_dict_path, box_shape='ori', **kwargs): - super(TableMasterLabelDecode, self).__init__(character_dict_path) + def __init__(self, + character_dict_path, + box_shape='ori', + merge_no_span_structure=True, + **kwargs): + super(TableMasterLabelDecode, self).__init__(character_dict_path, + merge_no_span_structure) self.box_shape = box_shape assert box_shape in [ 'ori', 'pad' -- GitLab