diff --git a/configs/table/SLANet.yml b/configs/table/SLANet.yml index 46cc22d0a6205d191eeddf682ef7f6614c346402..105fb5fad287ba014a16c1138f5f9a3a25ad609f 100644 --- a/configs/table/SLANet.yml +++ b/configs/table/SLANet.yml @@ -58,6 +58,7 @@ Loss: PostProcess: name: TableLabelDecode + merge_no_span_structure: &merge_no_span_structure False Metric: name: TableMetric @@ -77,7 +78,7 @@ Train: channel_first: False - TableLabelEncode: learn_empty_box: False - merge_no_span_structure: False + merge_no_span_structure: *merge_no_span_structure replace_empty_cell_token: False loc_reg_num: *loc_reg_num max_text_length: *max_text_length @@ -112,7 +113,7 @@ Eval: channel_first: False - TableLabelEncode: learn_empty_box: False - merge_no_span_structure: False + merge_no_span_structure: *merge_no_span_structure replace_empty_cell_token: False loc_reg_num: *loc_reg_num max_text_length: *max_text_length diff --git a/configs/table/table_master.yml b/configs/table/table_master.yml index 1844c3197c7aef785a18cc38f79f6a174fa867e3..1f50d20bfa4b2f3b28a6375580529e25f24aaa2d 100755 --- a/configs/table/table_master.yml +++ b/configs/table/table_master.yml @@ -8,7 +8,7 @@ Global: eval_batch_step: [0, 6259] cal_metric_during_train: true pretrained_model: null - checkpoints: + checkpoints: save_inference_dir: output/table_master/infer use_visualdl: false infer_img: ppstructure/docs/table/table.jpg @@ -61,6 +61,7 @@ Loss: PostProcess: name: TableMasterLabelDecode box_shape: pad + merge_no_span_structure: &merge_no_span_structure True Metric: name: TableMetric @@ -79,7 +80,7 @@ Train: channel_first: False - TableMasterLabelEncode: learn_empty_box: False - merge_no_span_structure: True + merge_no_span_structure: *merge_no_span_structure replace_empty_cell_token: True loc_reg_num: *loc_reg_num max_text_length: *max_text_length @@ -115,7 +116,7 @@ Eval: channel_first: False - TableMasterLabelEncode: learn_empty_box: False - merge_no_span_structure: True + merge_no_span_structure: *merge_no_span_structure replace_empty_cell_token: True loc_reg_num: *loc_reg_num max_text_length: *max_text_length diff --git a/ppocr/data/imaug/label_ops.py b/ppocr/data/imaug/label_ops.py index 1c1fc00be755ea237d603ceab8721734b2386f5b..ce539dcea9608762f725e5a3ae501e384360d04d 100644 --- a/ppocr/data/imaug/label_ops.py +++ b/ppocr/data/imaug/label_ops.py @@ -587,6 +587,12 @@ class TableLabelEncode(AttnLabelEncode): line = line.decode('utf-8').strip("\n").strip("\r\n") dict_character.append(line) + if self.merge_no_span_structure: + if "" not in dict_character: + dict_character.append("") + if "" in dict_character: + dict_character.remove("") + dict_character = self.add_special_char(dict_character) self.dict = {} for i, char in enumerate(dict_character): diff --git a/ppocr/postprocess/table_postprocess.py b/ppocr/postprocess/table_postprocess.py index ce254f314318b1d663841c4a0ab2c4439a9a1572..a47061f935e31b24fdb624df170f8abb38e01f40 100644 --- a/ppocr/postprocess/table_postprocess.py +++ b/ppocr/postprocess/table_postprocess.py @@ -21,8 +21,28 @@ from .rec_postprocess import AttnLabelDecode class TableLabelDecode(AttnLabelDecode): """ """ - def __init__(self, character_dict_path, **kwargs): - super(TableLabelDecode, self).__init__(character_dict_path) + def __init__(self, + character_dict_path, + merge_no_span_structure=False, + **kwargs): + dict_character = [] + with open(character_dict_path, "rb") as fin: + lines = fin.readlines() + for line in lines: + line = line.decode('utf-8').strip("\n").strip("\r\n") + dict_character.append(line) + + if merge_no_span_structure: + if "" not in dict_character: + dict_character.append("") + if "" in dict_character: + dict_character.remove("") + + dict_character = self.add_special_char(dict_character) + self.dict = {} + for i, char in enumerate(dict_character): + self.dict[char] = i + self.character = dict_character self.td_token = ['', ''] def __call__(self, preds, batch=None): @@ -122,8 +142,13 @@ class TableLabelDecode(AttnLabelDecode): class TableMasterLabelDecode(TableLabelDecode): """ """ - def __init__(self, character_dict_path, box_shape='ori', **kwargs): - super(TableMasterLabelDecode, self).__init__(character_dict_path) + def __init__(self, + character_dict_path, + box_shape='ori', + merge_no_span_structure=True, + **kwargs): + super(TableMasterLabelDecode, self).__init__(character_dict_path, + merge_no_span_structure) self.box_shape = box_shape assert box_shape in [ 'ori', 'pad'