提交 c703a589 编写于 作者: 文幕地方

Fix RE (relation extraction) inference bug

上级 99de0353
...@@ -833,19 +833,20 @@ class VQATokenLabelEncode(object): ...@@ -833,19 +833,20 @@ class VQATokenLabelEncode(object):
segment_offset_id = [] segment_offset_id = []
gt_label_list = [] gt_label_list = []
if self.contains_re: entities = []
# for re
entities = [] # for re
if not self.infer_mode: train_re = self.contains_re and not self.infer_mode
relations = [] if train_re:
id2label = {} relations = []
entity_id_to_index_map = {} id2label = {}
empty_entity = set() entity_id_to_index_map = {}
empty_entity = set()
data['ocr_info'] = copy.deepcopy(ocr_info) data['ocr_info'] = copy.deepcopy(ocr_info)
for info in ocr_info: for info in ocr_info:
if self.contains_re and not self.infer_mode: if train_re:
# for re # for re
if len(info["text"]) == 0: if len(info["text"]) == 0:
empty_entity.add(info["id"]) empty_entity.add(info["id"])
...@@ -872,24 +873,22 @@ class VQATokenLabelEncode(object): ...@@ -872,24 +873,22 @@ class VQATokenLabelEncode(object):
gt_label = self._parse_label(label, encode_res) gt_label = self._parse_label(label, encode_res)
# construct entities for re # construct entities for re
if self.contains_re: if train_re:
if not self.infer_mode: if gt_label[0] != self.label2id_map["O"]:
if gt_label[0] != self.label2id_map["O"]: entity_id_to_index_map[info["id"]] = len(entities)
entity_id_to_index_map[info["id"]] = len(entities) label = label.upper()
label = label.upper()
entities.append({
"start": len(input_ids_list),
"end":
len(input_ids_list) + len(encode_res["input_ids"]),
"label": label.upper(),
})
else:
entities.append({ entities.append({
"start": len(input_ids_list), "start": len(input_ids_list),
"end": "end":
len(input_ids_list) + len(encode_res["input_ids"]), len(input_ids_list) + len(encode_res["input_ids"]),
"label": 'O', "label": label.upper(),
}) })
else:
entities.append({
"start": len(input_ids_list),
"end": len(input_ids_list) + len(encode_res["input_ids"]),
"label": 'O',
})
input_ids_list.extend(encode_res["input_ids"]) input_ids_list.extend(encode_res["input_ids"])
token_type_ids_list.extend(encode_res["token_type_ids"]) token_type_ids_list.extend(encode_res["token_type_ids"])
bbox_list.extend([bbox] * len(encode_res["input_ids"])) bbox_list.extend([bbox] * len(encode_res["input_ids"]))
...@@ -908,19 +907,23 @@ class VQATokenLabelEncode(object): ...@@ -908,19 +907,23 @@ class VQATokenLabelEncode(object):
padding_side=self.tokenizer.padding_side, padding_side=self.tokenizer.padding_side,
pad_token_type_id=self.tokenizer.pad_token_type_id, pad_token_type_id=self.tokenizer.pad_token_type_id,
pad_token_id=self.tokenizer.pad_token_id) pad_token_id=self.tokenizer.pad_token_id)
data['entities'] = entities
if self.contains_re: if train_re:
data['entities'] = entities data['relations'] = relations
if self.infer_mode: data['id2label'] = id2label
data['ocr_info'] = ocr_info data['empty_entity'] = empty_entity
else: data['entity_id_to_index_map'] = entity_id_to_index_map
data['relations'] = relations
data['id2label'] = id2label
data['empty_entity'] = empty_entity
data['entity_id_to_index_map'] = entity_id_to_index_map
return data return data
def _load_ocr_info(self, data): def _load_ocr_info(self, data):
def trans_poly_to_bbox(poly):
x1 = np.min([p[0] for p in poly])
x2 = np.max([p[0] for p in poly])
y1 = np.min([p[1] for p in poly])
y2 = np.max([p[1] for p in poly])
return [x1, y1, x2, y2]
if self.infer_mode: if self.infer_mode:
ocr_result = self.ocr_engine.ocr(data['image'], cls=False) ocr_result = self.ocr_engine.ocr(data['image'], cls=False)
ocr_info = [] ocr_info = []
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册