From c703a5891a3f2872e381a3cc12aa18a5296703c6 Mon Sep 17 00:00:00 2001 From: WenmuZhou <572459439@qq.com> Date: Wed, 12 Jan 2022 02:29:07 +0000 Subject: [PATCH] fix re infer bug --- ppocr/data/imaug/label_ops.py | 65 ++++++++++++++++++----------------- 1 file changed, 34 insertions(+), 31 deletions(-) diff --git a/ppocr/data/imaug/label_ops.py b/ppocr/data/imaug/label_ops.py index 27b4aca2..786647f1 100644 --- a/ppocr/data/imaug/label_ops.py +++ b/ppocr/data/imaug/label_ops.py @@ -833,19 +833,20 @@ class VQATokenLabelEncode(object): segment_offset_id = [] gt_label_list = [] - if self.contains_re: - # for re - entities = [] - if not self.infer_mode: - relations = [] - id2label = {} - entity_id_to_index_map = {} - empty_entity = set() + entities = [] + + # for re + train_re = self.contains_re and not self.infer_mode + if train_re: + relations = [] + id2label = {} + entity_id_to_index_map = {} + empty_entity = set() data['ocr_info'] = copy.deepcopy(ocr_info) for info in ocr_info: - if self.contains_re and not self.infer_mode: + if train_re: # for re if len(info["text"]) == 0: empty_entity.add(info["id"]) @@ -872,24 +873,22 @@ class VQATokenLabelEncode(object): gt_label = self._parse_label(label, encode_res) # construct entities for re - if self.contains_re: - if not self.infer_mode: - if gt_label[0] != self.label2id_map["O"]: - entity_id_to_index_map[info["id"]] = len(entities) - label = label.upper() - entities.append({ - "start": len(input_ids_list), - "end": - len(input_ids_list) + len(encode_res["input_ids"]), - "label": label.upper(), - }) - else: + if train_re: + if gt_label[0] != self.label2id_map["O"]: + entity_id_to_index_map[info["id"]] = len(entities) + label = label.upper() entities.append({ "start": len(input_ids_list), "end": len(input_ids_list) + len(encode_res["input_ids"]), - "label": 'O', + "label": label.upper(), }) + else: + entities.append({ + "start": len(input_ids_list), + "end": len(input_ids_list) + len(encode_res["input_ids"]), + "label": 'O', + }) input_ids_list.extend(encode_res["input_ids"]) token_type_ids_list.extend(encode_res["token_type_ids"]) bbox_list.extend([bbox] * len(encode_res["input_ids"])) @@ -908,19 +907,23 @@ class VQATokenLabelEncode(object): padding_side=self.tokenizer.padding_side, pad_token_type_id=self.tokenizer.pad_token_type_id, pad_token_id=self.tokenizer.pad_token_id) + data['entities'] = entities - if self.contains_re: - data['entities'] = entities - if self.infer_mode: - data['ocr_info'] = ocr_info - else: - data['relations'] = relations - data['id2label'] = id2label - data['empty_entity'] = empty_entity - data['entity_id_to_index_map'] = entity_id_to_index_map + if train_re: + data['relations'] = relations + data['id2label'] = id2label + data['empty_entity'] = empty_entity + data['entity_id_to_index_map'] = entity_id_to_index_map return data def _load_ocr_info(self, data): + def trans_poly_to_bbox(poly): + x1 = np.min([p[0] for p in poly]) + x2 = np.max([p[0] for p in poly]) + y1 = np.min([p[1] for p in poly]) + y2 = np.max([p[1] for p in poly]) + return [x1, y1, x2, y2] + if self.infer_mode: ocr_result = self.ocr_engine.ocr(data['image'], cls=False) ocr_info = [] -- GitLab