提交 92f85521 编写于 作者: 文幕地方

Add "<td></td>" to the character dict (and remove "<td>") when merge_no_span_structure is enabled and "<td></td>" is not in the dict file

上级 fb9be201
...@@ -58,6 +58,7 @@ Loss: ...@@ -58,6 +58,7 @@ Loss:
PostProcess: PostProcess:
name: TableLabelDecode name: TableLabelDecode
merge_no_span_structure: &merge_no_span_structure False
Metric: Metric:
name: TableMetric name: TableMetric
...@@ -77,7 +78,7 @@ Train: ...@@ -77,7 +78,7 @@ Train:
channel_first: False channel_first: False
- TableLabelEncode: - TableLabelEncode:
learn_empty_box: False learn_empty_box: False
merge_no_span_structure: False merge_no_span_structure: *merge_no_span_structure
replace_empty_cell_token: False replace_empty_cell_token: False
loc_reg_num: *loc_reg_num loc_reg_num: *loc_reg_num
max_text_length: *max_text_length max_text_length: *max_text_length
...@@ -112,7 +113,7 @@ Eval: ...@@ -112,7 +113,7 @@ Eval:
channel_first: False channel_first: False
- TableLabelEncode: - TableLabelEncode:
learn_empty_box: False learn_empty_box: False
merge_no_span_structure: False merge_no_span_structure: *merge_no_span_structure
replace_empty_cell_token: False replace_empty_cell_token: False
loc_reg_num: *loc_reg_num loc_reg_num: *loc_reg_num
max_text_length: *max_text_length max_text_length: *max_text_length
......
...@@ -61,6 +61,7 @@ Loss: ...@@ -61,6 +61,7 @@ Loss:
PostProcess: PostProcess:
name: TableMasterLabelDecode name: TableMasterLabelDecode
box_shape: pad box_shape: pad
merge_no_span_structure: &merge_no_span_structure True
Metric: Metric:
name: TableMetric name: TableMetric
...@@ -79,7 +80,7 @@ Train: ...@@ -79,7 +80,7 @@ Train:
channel_first: False channel_first: False
- TableMasterLabelEncode: - TableMasterLabelEncode:
learn_empty_box: False learn_empty_box: False
merge_no_span_structure: True merge_no_span_structure: *merge_no_span_structure
replace_empty_cell_token: True replace_empty_cell_token: True
loc_reg_num: *loc_reg_num loc_reg_num: *loc_reg_num
max_text_length: *max_text_length max_text_length: *max_text_length
...@@ -115,7 +116,7 @@ Eval: ...@@ -115,7 +116,7 @@ Eval:
channel_first: False channel_first: False
- TableMasterLabelEncode: - TableMasterLabelEncode:
learn_empty_box: False learn_empty_box: False
merge_no_span_structure: True merge_no_span_structure: *merge_no_span_structure
replace_empty_cell_token: True replace_empty_cell_token: True
loc_reg_num: *loc_reg_num loc_reg_num: *loc_reg_num
max_text_length: *max_text_length max_text_length: *max_text_length
......
...@@ -587,6 +587,12 @@ class TableLabelEncode(AttnLabelEncode): ...@@ -587,6 +587,12 @@ class TableLabelEncode(AttnLabelEncode):
line = line.decode('utf-8').strip("\n").strip("\r\n") line = line.decode('utf-8').strip("\n").strip("\r\n")
dict_character.append(line) dict_character.append(line)
if self.merge_no_span_structure:
if "<td></td>" not in dict_character:
dict_character.append("<td></td>")
if "<td>" in dict_character:
dict_character.remove("<td>")
dict_character = self.add_special_char(dict_character) dict_character = self.add_special_char(dict_character)
self.dict = {} self.dict = {}
for i, char in enumerate(dict_character): for i, char in enumerate(dict_character):
......
...@@ -21,8 +21,28 @@ from .rec_postprocess import AttnLabelDecode ...@@ -21,8 +21,28 @@ from .rec_postprocess import AttnLabelDecode
class TableLabelDecode(AttnLabelDecode): class TableLabelDecode(AttnLabelDecode):
""" """ """ """
def __init__(self, character_dict_path, **kwargs): def __init__(self,
super(TableLabelDecode, self).__init__(character_dict_path) character_dict_path,
merge_no_span_structure=False,
**kwargs):
dict_character = []
with open(character_dict_path, "rb") as fin:
lines = fin.readlines()
for line in lines:
line = line.decode('utf-8').strip("\n").strip("\r\n")
dict_character.append(line)
if merge_no_span_structure:
if "<td></td>" not in dict_character:
dict_character.append("<td></td>")
if "<td>" in dict_character:
dict_character.remove("<td>")
dict_character = self.add_special_char(dict_character)
self.dict = {}
for i, char in enumerate(dict_character):
self.dict[char] = i
self.character = dict_character
self.td_token = ['<td>', '<td', '<td></td>'] self.td_token = ['<td>', '<td', '<td></td>']
def __call__(self, preds, batch=None): def __call__(self, preds, batch=None):
...@@ -122,8 +142,13 @@ class TableLabelDecode(AttnLabelDecode): ...@@ -122,8 +142,13 @@ class TableLabelDecode(AttnLabelDecode):
class TableMasterLabelDecode(TableLabelDecode): class TableMasterLabelDecode(TableLabelDecode):
""" """ """ """
def __init__(self, character_dict_path, box_shape='ori', **kwargs): def __init__(self,
super(TableMasterLabelDecode, self).__init__(character_dict_path) character_dict_path,
box_shape='ori',
merge_no_span_structure=True,
**kwargs):
super(TableMasterLabelDecode, self).__init__(character_dict_path,
merge_no_span_structure)
self.box_shape = box_shape self.box_shape = box_shape
assert box_shape in [ assert box_shape in [
'ori', 'pad' 'ori', 'pad'
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册