提交 c703a589 编写于 作者: 文幕地方's avatar 文幕地方

fix re infer bug

上级 99de0353
...@@ -833,10 +833,11 @@ class VQATokenLabelEncode(object): ...@@ -833,10 +833,11 @@ class VQATokenLabelEncode(object):
segment_offset_id = [] segment_offset_id = []
gt_label_list = [] gt_label_list = []
if self.contains_re:
# for re
entities = [] entities = []
if not self.infer_mode:
# for re
train_re = self.contains_re and not self.infer_mode
if train_re:
relations = [] relations = []
id2label = {} id2label = {}
entity_id_to_index_map = {} entity_id_to_index_map = {}
...@@ -845,7 +846,7 @@ class VQATokenLabelEncode(object): ...@@ -845,7 +846,7 @@ class VQATokenLabelEncode(object):
data['ocr_info'] = copy.deepcopy(ocr_info) data['ocr_info'] = copy.deepcopy(ocr_info)
for info in ocr_info: for info in ocr_info:
if self.contains_re and not self.infer_mode: if train_re:
# for re # for re
if len(info["text"]) == 0: if len(info["text"]) == 0:
empty_entity.add(info["id"]) empty_entity.add(info["id"])
...@@ -872,8 +873,7 @@ class VQATokenLabelEncode(object): ...@@ -872,8 +873,7 @@ class VQATokenLabelEncode(object):
gt_label = self._parse_label(label, encode_res) gt_label = self._parse_label(label, encode_res)
# construct entities for re # construct entities for re
if self.contains_re: if train_re:
if not self.infer_mode:
if gt_label[0] != self.label2id_map["O"]: if gt_label[0] != self.label2id_map["O"]:
entity_id_to_index_map[info["id"]] = len(entities) entity_id_to_index_map[info["id"]] = len(entities)
label = label.upper() label = label.upper()
...@@ -886,8 +886,7 @@ class VQATokenLabelEncode(object): ...@@ -886,8 +886,7 @@ class VQATokenLabelEncode(object):
else: else:
entities.append({ entities.append({
"start": len(input_ids_list), "start": len(input_ids_list),
"end": "end": len(input_ids_list) + len(encode_res["input_ids"]),
len(input_ids_list) + len(encode_res["input_ids"]),
"label": 'O', "label": 'O',
}) })
input_ids_list.extend(encode_res["input_ids"]) input_ids_list.extend(encode_res["input_ids"])
...@@ -908,12 +907,9 @@ class VQATokenLabelEncode(object): ...@@ -908,12 +907,9 @@ class VQATokenLabelEncode(object):
padding_side=self.tokenizer.padding_side, padding_side=self.tokenizer.padding_side,
pad_token_type_id=self.tokenizer.pad_token_type_id, pad_token_type_id=self.tokenizer.pad_token_type_id,
pad_token_id=self.tokenizer.pad_token_id) pad_token_id=self.tokenizer.pad_token_id)
if self.contains_re:
data['entities'] = entities data['entities'] = entities
if self.infer_mode:
data['ocr_info'] = ocr_info if train_re:
else:
data['relations'] = relations data['relations'] = relations
data['id2label'] = id2label data['id2label'] = id2label
data['empty_entity'] = empty_entity data['empty_entity'] = empty_entity
...@@ -921,6 +917,13 @@ class VQATokenLabelEncode(object): ...@@ -921,6 +917,13 @@ class VQATokenLabelEncode(object):
return data return data
def _load_ocr_info(self, data): def _load_ocr_info(self, data):
def trans_poly_to_bbox(poly):
x1 = np.min([p[0] for p in poly])
x2 = np.max([p[0] for p in poly])
y1 = np.min([p[1] for p in poly])
y2 = np.max([p[1] for p in poly])
return [x1, y1, x2, y2]
if self.infer_mode: if self.infer_mode:
ocr_result = self.ocr_engine.ocr(data['image'], cls=False) ocr_result = self.ocr_engine.ocr(data['image'], cls=False)
ocr_info = [] ocr_info = []
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册