Commit 434ab122 authored by 文幕地方

merge infer and train

Parent 5b307b4b
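This commit folds the separate `_train` and `_infer` branches of `VQATokenLabelEncode.__call__` into a single code path. OCR information now comes from a new `_load_ocr_info` helper (the annotation JSON in `data['label']` during training, `self.ocr_engine` at inference), box normalization moves into `_smooth_box`, and SER tag construction moves into `_parse_label`, which runs only when `infer_mode` is off; the RE-related outputs are likewise gated on `infer_mode`. Two standalone sketches of the helper logic follow the diff.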
@@ -821,54 +821,44 @@ class VQATokenLabelEncode(object):
         self.ocr_engine = ocr_engine
 
     def __call__(self, data):
-        if self.infer_mode == False:
-            return self._train(data)
-        else:
-            return self._infer(data)
-
-    def _train(self, data):
-        info = data['label']
-
-        # read text info
-        info_dict = json.loads(info)
-        height = info_dict["height"]
-        width = info_dict["width"]
+        # load bbox and label info
+        ocr_info = self._load_ocr_info(data)
+
+        height, width, _ = data['image'].shape
 
         words_list = []
         bbox_list = []
         input_ids_list = []
         token_type_ids_list = []
+        segment_offset_id = []
         gt_label_list = []
 
         if self.contains_re:
             # for re
             entities = []
-            relations = []
-            id2label = {}
-            entity_id_to_index_map = {}
-            empty_entity = set()
-        for info in info_dict["ocr_info"]:
-            if self.contains_re:
+            if not self.infer_mode:
+                relations = []
+                id2label = {}
+                entity_id_to_index_map = {}
+                empty_entity = set()
+
+        data['ocr_info'] = copy.deepcopy(ocr_info)
+
+        for info in ocr_info:
+            if self.contains_re and not self.infer_mode:
                 # for re
                 if len(info["text"]) == 0:
                     empty_entity.add(info["id"])
                     continue
                 id2label[info["id"]] = info["label"]
                 relations.extend([tuple(sorted(l)) for l in info["linking"]])
 
-            # x1, y1, x2, y2
-            bbox = info["bbox"]
-            label = info["label"]
-            bbox[0] = int(bbox[0] * 1000.0 / width)
-            bbox[2] = int(bbox[2] * 1000.0 / width)
-            bbox[1] = int(bbox[1] * 1000.0 / height)
-            bbox[3] = int(bbox[3] * 1000.0 / height)
+            # smooth_box
+            bbox = self._smooth_box(info["bbox"], height, width)
+
             text = info["text"]
             encode_res = self.tokenizer.encode(
                 text, pad_to_max_seq_len=False, return_attention_mask=True)
-            gt_label = []
+
             if not self.add_special_ids:
                 # TODO: use tok.all_special_ids to remove
                 encode_res["input_ids"] = encode_res["input_ids"][1:-1]
@@ -876,35 +866,44 @@ class VQATokenLabelEncode(object):
                                                                             -1]
                 encode_res["attention_mask"] = encode_res["attention_mask"][1:
                                                                             -1]
-            if label.lower() == "other":
-                gt_label.extend([0] * len(encode_res["input_ids"]))
-            else:
-                gt_label.append(self.label2id_map[("b-" + label).upper()])
-                gt_label.extend([self.label2id_map[("i-" + label).upper()]] *
-                                (len(encode_res["input_ids"]) - 1))
+            # parse label
+            if not self.infer_mode:
+                label = info['label']
+                gt_label = self._parse_label(label, encode_res)
+
+            # construct entities for re
             if self.contains_re:
-                if gt_label[0] != self.label2id_map["O"]:
-                    entity_id_to_index_map[info["id"]] = len(entities)
-                    entities.append({
-                        "start": len(input_ids_list),
-                        "end":
-                        len(input_ids_list) + len(encode_res["input_ids"]),
-                        "label": label.upper(),
-                    })
+                if not self.infer_mode:
+                    if gt_label[0] != self.label2id_map["O"]:
+                        entity_id_to_index_map[info["id"]] = len(entities)
+                        label = label.upper()
+                        entities.append({
+                            "start": len(input_ids_list),
+                            "end":
+                            len(input_ids_list) + len(encode_res["input_ids"]),
+                            "label": label.upper(),
+                        })
+                else:
+                    entities.append({
+                        "start": len(input_ids_list),
+                        "end":
+                        len(input_ids_list) + len(encode_res["input_ids"]),
+                        "label": 'O',
+                    })
+
             input_ids_list.extend(encode_res["input_ids"])
             token_type_ids_list.extend(encode_res["token_type_ids"])
             bbox_list.extend([bbox] * len(encode_res["input_ids"]))
-            gt_label_list.extend(gt_label)
             words_list.append(text)
+            segment_offset_id.append(len(input_ids_list))
+            if not self.infer_mode:
+                gt_label_list.extend(gt_label)
 
-        encoded_inputs = {
-            "input_ids": input_ids_list,
-            "labels": gt_label_list,
-            "token_type_ids": token_type_ids_list,
-            "bbox": bbox_list,
-            "attention_mask": [1] * len(input_ids_list),
-        }
-        data.update(encoded_inputs)
+        data['input_ids'] = input_ids_list
+        data['token_type_ids'] = token_type_ids_list
+        data['bbox'] = bbox_list
+        data['attention_mask'] = [1] * len(input_ids_list)
+        data['labels'] = gt_label_list
+        data['segment_offset_id'] = segment_offset_id
         data['tokenizer_params'] = dict(
             padding_side=self.tokenizer.padding_side,
             pad_token_type_id=self.tokenizer.pad_token_type_id,
@@ -912,79 +911,45 @@ class VQATokenLabelEncode(object):
         if self.contains_re:
             data['entities'] = entities
-            data['relations'] = relations
-            data['id2label'] = id2label
-            data['empty_entity'] = empty_entity
-            data['entity_id_to_index_map'] = entity_id_to_index_map
+            if self.infer_mode:
+                data['ocr_info'] = ocr_info
+            else:
+                data['relations'] = relations
+                data['id2label'] = id2label
+                data['empty_entity'] = empty_entity
+                data['entity_id_to_index_map'] = entity_id_to_index_map
         return data
 
-    def _infer(self, data):
-        def trans_poly_to_bbox(poly):
-            x1 = np.min([p[0] for p in poly])
-            x2 = np.max([p[0] for p in poly])
-            y1 = np.min([p[1] for p in poly])
-            y2 = np.max([p[1] for p in poly])
-            return [x1, y1, x2, y2]
-
-        height, width, _ = data['image'].shape
-        ocr_result = self.ocr_engine.ocr(data['image'], cls=False)
-        ocr_info = []
-        for res in ocr_result:
-            ocr_info.append({
-                "text": res[1][0],
-                "bbox": trans_poly_to_bbox(res[0]),
-                "poly": res[0],
-            })
-
-        segment_offset_id = []
-        words_list = []
-        bbox_list = []
-        input_ids_list = []
-        token_type_ids_list = []
-        entities = []
-
-        for info in ocr_info:
-            # x1, y1, x2, y2
-            bbox = copy.deepcopy(info["bbox"])
-            bbox[0] = int(bbox[0] * 1000.0 / width)
-            bbox[2] = int(bbox[2] * 1000.0 / width)
-            bbox[1] = int(bbox[1] * 1000.0 / height)
-            bbox[3] = int(bbox[3] * 1000.0 / height)
-
-            text = info["text"]
-            encode_res = self.tokenizer.encode(
-                text, pad_to_max_seq_len=False, return_attention_mask=True)
-            if not self.add_special_ids:
-                # TODO: use tok.all_special_ids to remove
-                encode_res["input_ids"] = encode_res["input_ids"][1:-1]
-                encode_res["token_type_ids"] = encode_res["token_type_ids"][1:
-                                                                            -1]
-                encode_res["attention_mask"] = encode_res["attention_mask"][1:
-                                                                            -1]
-
-            # for re
-            entities.append({
-                "start": len(input_ids_list),
-                "end": len(input_ids_list) + len(encode_res["input_ids"]),
-                "label": "O",
-            })
-
-            input_ids_list.extend(encode_res["input_ids"])
-            token_type_ids_list.extend(encode_res["token_type_ids"])
-            bbox_list.extend([bbox] * len(encode_res["input_ids"]))
-            words_list.append(text)
-            segment_offset_id.append(len(input_ids_list))
-
-        encoded_inputs = {
-            "input_ids": input_ids_list,
-            "token_type_ids": token_type_ids_list,
-            "bbox": bbox_list,
-            "attention_mask": [1] * len(input_ids_list),
-            "entities": entities,
-            'labels': None,
-            'segment_offset_id': segment_offset_id,
-            'ocr_info': ocr_info
-        }
-        data.update(encoded_inputs)
-        return data
+    def _load_ocr_info(self, data):
+        if self.infer_mode:
+            ocr_result = self.ocr_engine.ocr(data['image'], cls=False)
+            ocr_info = []
+            for res in ocr_result:
+                ocr_info.append({
+                    "text": res[1][0],
+                    "bbox": trans_poly_to_bbox(res[0]),
+                    "poly": res[0],
+                })
+            return ocr_info
+        else:
+            info = data['label']
+            # read text info
+            info_dict = json.loads(info)
+            return info_dict["ocr_info"]
+
+    def _smooth_box(self, bbox, height, width):
+        bbox[0] = int(bbox[0] * 1000.0 / width)
+        bbox[2] = int(bbox[2] * 1000.0 / width)
+        bbox[1] = int(bbox[1] * 1000.0 / height)
+        bbox[3] = int(bbox[3] * 1000.0 / height)
+        return bbox
+
+    def _parse_label(self, label, encode_res):
+        gt_label = []
+        if label.lower() == "other":
+            gt_label.extend([0] * len(encode_res["input_ids"]))
+        else:
+            gt_label.append(self.label2id_map[("b-" + label).upper()])
+            gt_label.extend([self.label2id_map[("i-" + label).upper()]] *
                            (len(encode_res["input_ids"]) - 1))
+        return gt_label
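For reference, `_smooth_box` rescales pixel coordinates onto the 0-1000 grid that LayoutLM-family models expect, and in infer mode the OCR polygons are first collapsed to axis-aligned boxes (the `trans_poly_to_bbox` helper that used to live inside `_infer` and is still called from the new `_load_ocr_info`). A minimal standalone sketch of that arithmetic, restated outside the class for illustration (not the repository code):

    import numpy as np

    def trans_poly_to_bbox(poly):
        # Collapse a 4-point OCR polygon [[x, y], ...] into [x1, y1, x2, y2].
        x1 = np.min([p[0] for p in poly])
        x2 = np.max([p[0] for p in poly])
        y1 = np.min([p[1] for p in poly])
        y2 = np.max([p[1] for p in poly])
        return [x1, y1, x2, y2]

    def smooth_box(bbox, height, width):
        # Rescale pixel coordinates onto the 0-1000 grid; mutates bbox in place,
        # just like the _smooth_box method in the diff.
        bbox[0] = int(bbox[0] * 1000.0 / width)
        bbox[2] = int(bbox[2] * 1000.0 / width)
        bbox[1] = int(bbox[1] * 1000.0 / height)
        bbox[3] = int(bbox[3] * 1000.0 / height)
        return bbox

    # A 200x100 (width x height) page with a text box covering its right half:
    poly = [[100, 0], [200, 0], [200, 100], [100, 100]]
    print(smooth_box(trans_poly_to_bbox(poly), height=100, width=200))
    # -> [500, 0, 1000, 1000]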
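`_parse_label` is the other piece that used to be inlined in `_train`: a segment-level SER label is expanded over the segment's subword tokens, with "other" mapped to class 0 everywhere, and any other label tagged B- on the first token and I- on the rest. A standalone restatement with a made-up `label2id_map` (the real map is loaded from the dataset's class list):

    def parse_label(label, encode_res, label2id_map):
        # One gt label per subword token: B- for the first token of the segment,
        # I- for the remainder; "other" becomes class 0 for every token.
        gt_label = []
        if label.lower() == "other":
            gt_label.extend([0] * len(encode_res["input_ids"]))
        else:
            gt_label.append(label2id_map[("b-" + label).upper()])
            gt_label.extend([label2id_map[("i-" + label).upper()]] *
                            (len(encode_res["input_ids"]) - 1))
        return gt_label

    # Illustrative mapping only.
    label2id_map = {"O": 0, "B-QUESTION": 1, "I-QUESTION": 2, "B-ANSWER": 3, "I-ANSWER": 4}

    print(parse_label("question", {"input_ids": [101, 102, 103]}, label2id_map))  # [1, 2, 2]
    print(parse_label("other", {"input_ids": [101, 102, 103]}, label2id_map))     # [0, 0, 0]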