提交 c703a589 编写于 作者: 文幕地方's avatar 文幕地方

fix re infer bug

上级 99de0353
......@@ -833,19 +833,20 @@ class VQATokenLabelEncode(object):
segment_offset_id = []
gt_label_list = []
if self.contains_re:
# for re
entities = []
if not self.infer_mode:
relations = []
id2label = {}
entity_id_to_index_map = {}
empty_entity = set()
entities = []
# for re
train_re = self.contains_re and not self.infer_mode
if train_re:
relations = []
id2label = {}
entity_id_to_index_map = {}
empty_entity = set()
data['ocr_info'] = copy.deepcopy(ocr_info)
for info in ocr_info:
if self.contains_re and not self.infer_mode:
if train_re:
# for re
if len(info["text"]) == 0:
empty_entity.add(info["id"])
......@@ -872,24 +873,22 @@ class VQATokenLabelEncode(object):
gt_label = self._parse_label(label, encode_res)
# construct entities for re
if self.contains_re:
if not self.infer_mode:
if gt_label[0] != self.label2id_map["O"]:
entity_id_to_index_map[info["id"]] = len(entities)
label = label.upper()
entities.append({
"start": len(input_ids_list),
"end":
len(input_ids_list) + len(encode_res["input_ids"]),
"label": label.upper(),
})
else:
if train_re:
if gt_label[0] != self.label2id_map["O"]:
entity_id_to_index_map[info["id"]] = len(entities)
label = label.upper()
entities.append({
"start": len(input_ids_list),
"end":
len(input_ids_list) + len(encode_res["input_ids"]),
"label": 'O',
"label": label.upper(),
})
else:
entities.append({
"start": len(input_ids_list),
"end": len(input_ids_list) + len(encode_res["input_ids"]),
"label": 'O',
})
input_ids_list.extend(encode_res["input_ids"])
token_type_ids_list.extend(encode_res["token_type_ids"])
bbox_list.extend([bbox] * len(encode_res["input_ids"]))
......@@ -908,19 +907,23 @@ class VQATokenLabelEncode(object):
padding_side=self.tokenizer.padding_side,
pad_token_type_id=self.tokenizer.pad_token_type_id,
pad_token_id=self.tokenizer.pad_token_id)
data['entities'] = entities
if self.contains_re:
data['entities'] = entities
if self.infer_mode:
data['ocr_info'] = ocr_info
else:
data['relations'] = relations
data['id2label'] = id2label
data['empty_entity'] = empty_entity
data['entity_id_to_index_map'] = entity_id_to_index_map
if train_re:
data['relations'] = relations
data['id2label'] = id2label
data['empty_entity'] = empty_entity
data['entity_id_to_index_map'] = entity_id_to_index_map
return data
def _load_ocr_info(self, data):
def trans_poly_to_bbox(poly):
x1 = np.min([p[0] for p in poly])
x2 = np.max([p[0] for p in poly])
y1 = np.min([p[1] for p in poly])
y2 = np.max([p[1] for p in poly])
return [x1, y1, x2, y2]
if self.infer_mode:
ocr_result = self.ocr_engine.ocr(data['image'], cls=False)
ocr_info = []
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册