提交 92f85521 编写于 作者: 文幕地方's avatar 文幕地方

add "<td></td>" to dict when "<td></td>" not in file

上级 fb9be201
......@@ -58,6 +58,7 @@ Loss:
PostProcess:
name: TableLabelDecode
merge_no_span_structure: &merge_no_span_structure False
Metric:
name: TableMetric
......@@ -77,7 +78,7 @@ Train:
channel_first: False
- TableLabelEncode:
learn_empty_box: False
merge_no_span_structure: False
merge_no_span_structure: *merge_no_span_structure
replace_empty_cell_token: False
loc_reg_num: *loc_reg_num
max_text_length: *max_text_length
......@@ -112,7 +113,7 @@ Eval:
channel_first: False
- TableLabelEncode:
learn_empty_box: False
merge_no_span_structure: False
merge_no_span_structure: *merge_no_span_structure
replace_empty_cell_token: False
loc_reg_num: *loc_reg_num
max_text_length: *max_text_length
......
......@@ -8,7 +8,7 @@ Global:
eval_batch_step: [0, 6259]
cal_metric_during_train: true
pretrained_model: null
checkpoints:
checkpoints:
save_inference_dir: output/table_master/infer
use_visualdl: false
infer_img: ppstructure/docs/table/table.jpg
......@@ -61,6 +61,7 @@ Loss:
PostProcess:
name: TableMasterLabelDecode
box_shape: pad
merge_no_span_structure: &merge_no_span_structure True
Metric:
name: TableMetric
......@@ -79,7 +80,7 @@ Train:
channel_first: False
- TableMasterLabelEncode:
learn_empty_box: False
merge_no_span_structure: True
merge_no_span_structure: *merge_no_span_structure
replace_empty_cell_token: True
loc_reg_num: *loc_reg_num
max_text_length: *max_text_length
......@@ -115,7 +116,7 @@ Eval:
channel_first: False
- TableMasterLabelEncode:
learn_empty_box: False
merge_no_span_structure: True
merge_no_span_structure: *merge_no_span_structure
replace_empty_cell_token: True
loc_reg_num: *loc_reg_num
max_text_length: *max_text_length
......
......@@ -587,6 +587,12 @@ class TableLabelEncode(AttnLabelEncode):
line = line.decode('utf-8').strip("\n").strip("\r\n")
dict_character.append(line)
if self.merge_no_span_structure:
if "<td></td>" not in dict_character:
dict_character.append("<td></td>")
if "<td>" in dict_character:
dict_character.remove("<td>")
dict_character = self.add_special_char(dict_character)
self.dict = {}
for i, char in enumerate(dict_character):
......
......@@ -21,8 +21,28 @@ from .rec_postprocess import AttnLabelDecode
class TableLabelDecode(AttnLabelDecode):
""" """
def __init__(self, character_dict_path, **kwargs):
super(TableLabelDecode, self).__init__(character_dict_path)
def __init__(self,
character_dict_path,
merge_no_span_structure=False,
**kwargs):
dict_character = []
with open(character_dict_path, "rb") as fin:
lines = fin.readlines()
for line in lines:
line = line.decode('utf-8').strip("\n").strip("\r\n")
dict_character.append(line)
if merge_no_span_structure:
if "<td></td>" not in dict_character:
dict_character.append("<td></td>")
if "<td>" in dict_character:
dict_character.remove("<td>")
dict_character = self.add_special_char(dict_character)
self.dict = {}
for i, char in enumerate(dict_character):
self.dict[char] = i
self.character = dict_character
self.td_token = ['<td>', '<td', '<td></td>']
def __call__(self, preds, batch=None):
......@@ -122,8 +142,13 @@ class TableLabelDecode(AttnLabelDecode):
class TableMasterLabelDecode(TableLabelDecode):
""" """
def __init__(self, character_dict_path, box_shape='ori', **kwargs):
super(TableMasterLabelDecode, self).__init__(character_dict_path)
def __init__(self,
character_dict_path,
box_shape='ori',
merge_no_span_structure=True,
**kwargs):
super(TableMasterLabelDecode, self).__init__(character_dict_path,
merge_no_span_structure)
self.box_shape = box_shape
assert box_shape in [
'ori', 'pad'
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册