# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import argparse
import cv2
import random
import numpy as np
import imghdr
from copy import deepcopy

import paddle

from PIL import Image, ImageDraw, ImageFont


def set_seed(seed):
    """Seed the Python, NumPy and Paddle random number generators."""
    random.seed(seed)
    np.random.seed(seed)
    paddle.seed(seed)


def get_bio_label_maps(label_map_path):
    """Build BIO label <-> id mappings from a label file.

    The file is read as UTF-8 with one label per line. "O" is inserted at
    the front when it is not listed; every other label is expanded into a
    "B-<label>" / "I-<label>" pair, in file order.

    Args:
        label_map_path: path of the label file.

    Returns:
        A tuple ``(label2id_map, id2label_map)``.
    """
    with open(label_map_path, "r", encoding='utf-8') as fin:
        names = [line.strip() for line in fin.readlines()]
    if "O" not in names:
        names.insert(0, "O")

    labels = []
    for name in names:
        if name == "O":
            labels.append("O")
        else:
            labels.extend(["B-" + name, "I-" + name])

    label2id_map = {label: idx for idx, label in enumerate(labels)}
    id2label_map = dict(enumerate(labels))
    return label2id_map, id2label_map


def get_image_file_list(img_file):
    """Collect image paths from a single file or a directory (non-recursive).

    A path is kept when ``imghdr.what`` reports one of the accepted formats.

    Args:
        img_file: path of an image file or of a directory containing images.

    Returns:
        Sorted list of image file paths.

    Raises:
        Exception: if ``img_file`` is None, does not exist, or no image
            file was found.
    """
    if img_file is None or not os.path.exists(img_file):
        raise Exception("not found any img file in {}".format(img_file))

    img_end = {'jpg', 'bmp', 'png', 'jpeg', 'rgb', 'tif', 'tiff', 'gif', 'GIF'}

    def _is_image(path):
        return os.path.isfile(path) and imghdr.what(path) in img_end

    imgs_lists = []
    if _is_image(img_file):
        imgs_lists.append(img_file)
    elif os.path.isdir(img_file):
        for single_file in os.listdir(img_file):
            file_path = os.path.join(img_file, single_file)
            if _is_image(file_path):
                imgs_lists.append(file_path)

    if not imgs_lists:
        raise Exception("not found any img file in {}".format(img_file))
    return sorted(imgs_lists)


def draw_ser_results(image,
                     ocr_results,
                     font_path="../../doc/fonts/simfang.ttf",
                     font_size=18):
    """Draw SER predictions (coloured boxes plus "pred: text" tags).

    Args:
        image: ``np.ndarray`` or PIL image to draw on (not modified).
        ocr_results: iterable of dicts providing "pred_id", "pred", "text"
            and "bbox" (x1, y1, x2, y2).
        font_path: TrueType font used for the text tags.
        font_size: font size, also used as the tag height.

    Returns:
        ``np.ndarray`` of the input blended 50/50 with the annotated copy.
    """
    # A private RandomState gives the same palette as seeding the global
    # generator with 2021 would, without clobbering the global NumPy RNG
    # state that set_seed() configures.
    rng = np.random.RandomState(2021)
    channels = (rng.permutation(range(255)), rng.permutation(range(255)),
                rng.permutation(range(255)))
    color_map = {
        idx: (int(channels[0][idx]), int(channels[1][idx]),
              int(channels[2][idx]))
        for idx in range(1, 255)
    }

    if isinstance(image, np.ndarray):
        image = Image.fromarray(image)
    img_new = image.copy()
    draw = ImageDraw.Draw(img_new)

    font = ImageFont.truetype(font_path, font_size, encoding="utf-8")
    for ocr_info in ocr_results:
        # ids outside the palette (including 0) are not drawn
        if ocr_info["pred_id"] not in color_map:
            continue
        box_color = color_map[ocr_info["pred_id"]]
        text = "{}: {}".format(ocr_info["pred"], ocr_info["text"])
        draw_box_txt(ocr_info["bbox"], text, draw, font, font_size, box_color)

    img_new = Image.blend(image, img_new, 0.5)
    return np.array(img_new)


def draw_box_txt(bbox, text, draw, font, font_size, color):
    """Draw a filled box and a text tag placed just above it.

    Args:
        bbox: box as (x1, y1, x2, y2).
        text: text written in the tag.
        draw: ``ImageDraw`` object to draw with.
        font: PIL font used for the text.
        font_size: tag height in pixels.
        color: fill colour of the box.
    """
    # draw ocr results outline
    bbox = ((bbox[0], bbox[1]), (bbox[2], bbox[3]))
    draw.rectangle(bbox, fill=color)

    # draw ocr results
    start_y = max(0, bbox[0][1] - font_size)
    # ImageFont.getsize() was removed in Pillow 10; getbbox() is its
    # replacement and (right - left) equals the width getsize() returned.
    if hasattr(font, "getbbox"):
        left, _, right, _ = font.getbbox(text)
        tw = right - left
    else:
        tw = font.getsize(text)[0]
    draw.rectangle(
        [(bbox[0][0] + 1, start_y), (bbox[0][0] + tw + 1,
                                     start_y + font_size)],
        fill=(0, 0, 255))
    draw.text((bbox[0][0] + 1, start_y), text, fill=(255, 255, 255), font=font)


def draw_re_results(image,
                    result,
                    font_path="../../doc/fonts/simfang.ttf",
                    font_size=18):
    """Draw relation-extraction results: head/tail boxes joined by a line.

    Args:
        image: ``np.ndarray`` or PIL image to draw on (not modified).
        result: iterable of ``(head, tail)`` pairs; each item is a dict
            providing "bbox" (x1, y1, x2, y2) and "text".
        font_path: TrueType font used for the text tags.
        font_size: font size, also used as the tag height.

    Returns:
        ``np.ndarray`` of the input blended 50/50 with the annotated copy.
    """
    # No randomness is used here, so the global NumPy RNG is left untouched.
    if isinstance(image, np.ndarray):
        image = Image.fromarray(image)
    img_new = image.copy()
    draw = ImageDraw.Draw(img_new)

    font = ImageFont.truetype(font_path, font_size, encoding="utf-8")
    color_head = (0, 0, 255)
    color_tail = (255, 0, 0)
    color_line = (0, 255, 0)

    def _center(box):
        return ((box[0] + box[2]) // 2, (box[1] + box[3]) // 2)

    for ocr_info_head, ocr_info_tail in result:
        draw_box_txt(ocr_info_head["bbox"], ocr_info_head["text"], draw, font,
                     font_size, color_head)
        draw_box_txt(ocr_info_tail["bbox"], ocr_info_tail["text"], draw, font,
                     font_size, color_tail)

        center_head = _center(ocr_info_head['bbox'])
        center_tail = _center(ocr_info_tail['bbox'])
        draw.line([center_head, center_tail], fill=color_line, width=5)

    img_new = Image.blend(image, img_new, 0.5)
    return np.array(img_new)


# pad sentences
def pad_sentences(tokenizer,
                  encoded_inputs,
                  max_seq_len=512,
                  pad_to_max_seq_len=True,
                  return_attention_mask=True,
                  return_token_type_ids=True,
                  return_overflowing_tokens=False,
                  return_special_tokens_mask=False):
    """Pad ``encoded_inputs`` in place so it can be split into chunks.

    The target length is ``(len(input_ids) // max_seq_len + 1) * max_seq_len``,
    so an input whose length is an exact multiple of ``max_seq_len`` receives
    one additional chunk made only of padding.

    Args:
        tokenizer: object providing ``padding_side``, ``pad_token_id`` and
            ``pad_token_type_id``.
        encoded_inputs: dict with "input_ids" and "bbox" (plus
            "token_type_ids" / "special_tokens_mask" when the matching
            flag is set).
        max_seq_len: chunk length.
        pad_to_max_seq_len: whether to pad at all.
        return_attention_mask: whether to (re)build "attention_mask".
        return_token_type_ids: whether to pad "token_type_ids".
        return_overflowing_tokens: unused, kept for interface compatibility.
        return_special_tokens_mask: whether to pad "special_tokens_mask".

    Returns:
        The same ``encoded_inputs`` dict.
    """
    n_tokens = len(encoded_inputs["input_ids"])
    # Padding with larger size, reshape is carried out
    padded_len = (n_tokens // max_seq_len + 1) * max_seq_len

    if pad_to_max_seq_len and padded_len and n_tokens < padded_len:
        difference = padded_len - n_tokens
        # NOTE: only right-side padding is implemented; with any other
        # padding side the inputs are returned unpadded.
        if tokenizer.padding_side == 'right':
            if return_attention_mask:
                encoded_inputs["attention_mask"] = (
                    [1] * n_tokens + [0] * difference)
            if return_token_type_ids:
                encoded_inputs["token_type_ids"] = (
                    encoded_inputs["token_type_ids"] +
                    [tokenizer.pad_token_type_id] * difference)
            if return_special_tokens_mask:
                encoded_inputs["special_tokens_mask"] = (
                    encoded_inputs["special_tokens_mask"] + [1] * difference)
            encoded_inputs["input_ids"] = (
                encoded_inputs["input_ids"] +
                [tokenizer.pad_token_id] * difference)
            encoded_inputs["bbox"] = (
                encoded_inputs["bbox"] + [[0, 0, 0, 0]] * difference)
    elif return_attention_mask:
        encoded_inputs["attention_mask"] = [1] * n_tokens

    return encoded_inputs


def split_page(encoded_inputs, max_seq_len=512):
    """Convert the inputs to tensors reshaped into ``max_seq_len`` chunks.

    truncate is often used in training process

    Every value except "entities" becomes a ``paddle`` tensor: tensors with
    at most one dimension are reshaped to ``[-1, max_seq_len]``, the others
    (bbox) to ``[-1, max_seq_len, 4]``. "entities" is wrapped in a list.

    Args:
        encoded_inputs: dict of padded sequences, updated in place.
        max_seq_len: chunk length; sequence lengths must be a multiple of it.

    Returns:
        The same ``encoded_inputs`` dict.
    """
    for key in encoded_inputs:
        if key == 'entities':
            encoded_inputs[key] = [encoded_inputs[key]]
            continue
        tensor = paddle.to_tensor(encoded_inputs[key])
        if tensor.ndim <= 1:
            # for input_ids, att_mask and so on
            tensor = tensor.reshape([-1, max_seq_len])
        else:
            # for bbox
            tensor = tensor.reshape([-1, max_seq_len, 4])
        encoded_inputs[key] = tensor
    return encoded_inputs


def preprocess(
        tokenizer,
        ori_img,
        ocr_info,
        img_size=(224, 224),
        pad_token_label_id=-100,
        max_seq_len=512,
        add_special_ids=False,
        return_attention_mask=True, ):
    """Turn an image and its OCR results into chunked model inputs.

    Args:
        tokenizer: tokenizer whose ``encode`` returns a dict with
            "input_ids", "token_type_ids" and "attention_mask".
        ori_img: image array indexed as (height, width, channel).
        ocr_info: list of dicts with "bbox" (x1, y1, x2, y2 in pixels) and
            "text". It is deep-copied, the caller's data is not modified.
        img_size: size the image is resized to.
        pad_token_label_id: unused, kept for interface compatibility.
        max_seq_len: chunk length used for padding and splitting.
        add_special_ids: when False, the first and last id of every encoded
            text are dropped (presumably the special tokens added by the
            tokenizer -- see the TODO below).
        return_attention_mask: forwarded to ``pad_sentences``.

    Returns:
        Dict with tensors "input_ids", "token_type_ids", "bbox",
        "attention_mask" and "image", plus "entities" and
        "segment_offset_id" (cumulative token count after each OCR segment).
    """
    ocr_info = deepcopy(ocr_info)
    height = ori_img.shape[0]
    width = ori_img.shape[1]

    img = cv2.resize(ori_img,
                     img_size).transpose([2, 0, 1]).astype(np.float32)

    segment_offset_id = []
    bbox_list = []
    input_ids_list = []
    token_type_ids_list = []
    entities = []

    for info in ocr_info:
        # x1, y1, x2, y2 -- rescaled to a 0..1000 range relative to the image
        bbox = info["bbox"]
        bbox[0] = int(bbox[0] * 1000.0 / width)
        bbox[2] = int(bbox[2] * 1000.0 / width)
        bbox[1] = int(bbox[1] * 1000.0 / height)
        bbox[3] = int(bbox[3] * 1000.0 / height)

        text = info["text"]
        encode_res = tokenizer.encode(
            text, pad_to_max_seq_len=False, return_attention_mask=True)

        if not add_special_ids:
            # TODO: use tok.all_special_ids to remove
            encode_res["input_ids"] = encode_res["input_ids"][1:-1]
            encode_res["token_type_ids"] = encode_res["token_type_ids"][1:-1]
            encode_res["attention_mask"] = encode_res["attention_mask"][1:-1]

        n_ids = len(encode_res["input_ids"])
        # for re
        entities.append({
            "start": len(input_ids_list),
            "end": len(input_ids_list) + n_ids,
            "label": "O",
        })

        input_ids_list.extend(encode_res["input_ids"])
        token_type_ids_list.extend(encode_res["token_type_ids"])
        bbox_list.extend([bbox] * n_ids)
        segment_offset_id.append(len(input_ids_list))

    encoded_inputs = {
        "input_ids": input_ids_list,
        "token_type_ids": token_type_ids_list,
        "bbox": bbox_list,
        "attention_mask": [1] * len(input_ids_list),
        "entities": entities
    }

    encoded_inputs = pad_sentences(
        tokenizer,
        encoded_inputs,
        max_seq_len=max_seq_len,
        return_attention_mask=return_attention_mask)

    # The chunk length must match the one used for padding, otherwise the
    # reshape is wrong (or fails) for any max_seq_len other than 512.
    encoded_inputs = split_page(encoded_inputs, max_seq_len=max_seq_len)

    fake_bs = encoded_inputs["input_ids"].shape[0]

    # the same image is repeated for every chunk
    encoded_inputs["image"] = paddle.to_tensor(img).unsqueeze(0).expand(
        [fake_bs] + list(img.shape))

    encoded_inputs["segment_offset_id"] = segment_offset_id

    return encoded_inputs


def postprocess(attention_mask, preds, id2label_map):
    """Convert logits into label names, skipping masked positions.

    Args:
        attention_mask: 2-D mask indexed as ``[batch][position]``; only
            positions equal to 1 are kept.
        preds: ``paddle.Tensor`` or array of scores whose last (third) axis
            is reduced with ``argmax``.
        id2label_map: mapping from class id to label name.

    Returns:
        One list of label names per batch row.
    """
    if isinstance(preds, paddle.Tensor):
        preds = preds.numpy()
    preds = np.argmax(preds, axis=2)

    n_rows, n_cols = preds.shape[0], preds.shape[1]
    preds_list = [[] for _ in range(n_rows)]

    # keep batch info
    for i in range(n_rows):
        row_labels = preds_list[i]
        for j in range(n_cols):
            if attention_mask[i][j] == 1:
                row_labels.append(id2label_map[preds[i][j]])

    return preds_list


def merge_preds_list_with_ocr_info(ocr_info, segment_offset_id, preds_list,
                                   label2id_map_for_draw):
    """Attach one predicted label to every OCR segment by majority vote.

    Args:
        ocr_info: list of OCR result dicts, updated in place with
            "pred_id" and "pred".
        segment_offset_id: cumulative token count after each segment.
        preds_list: per-chunk lists of token labels; they are flattened and
            must come from the same image as ``ocr_info``.
        label2id_map_for_draw: mapping from token label to the id used for
            drawing.

    Returns:
        The same ``ocr_info`` list.
    """
    # must ensure the preds_list is generated from the same image
    preds = [p for pred in preds_list for p in pred]

    # id -> label name with the "B-" / "I-" prefix removed
    id2label_map = {
        val: (key[2:] if key.startswith(("B-", "I-")) else key)
        for key, val in label2id_map_for_draw.items()
    }

    start_id = 0
    for idx, end_id in enumerate(segment_offset_id):
        curr_pred = [label2id_map_for_draw[p] for p in preds[start_id:end_id]]

        if not curr_pred:
            # segments without tokens fall back to id 0
            pred_id = 0
        else:
            pred_id = np.argmax(np.bincount(curr_pred))

        ocr_info[idx]["pred_id"] = int(pred_id)
        ocr_info[idx]["pred"] = id2label_map[int(pred_id)]
        start_id = end_id
    return ocr_info


def print_arguments(args, logger=None):
    """print arguments

    Args:
        args: namespace whose attributes are printed sorted by name.
        logger: optional logger; ``logger.info`` is used when given,
            otherwise ``print``.
    """
    print_func = logger.info if logger is not None else print
    print_func('-----------  Configuration Arguments -----------')
    for arg, value in sorted(vars(args).items()):
        print_func('%s: %s' % (arg, value))
    print_func('------------------------------------------------')


def parse_args():
    """Parse the command-line arguments of the VQA scripts.

    Returns:
        The ``argparse.Namespace`` produced from ``sys.argv``.
    """
    parser = argparse.ArgumentParser()
    # Required parameters
    # yapf: disable
    parser.add_argument("--model_name_or_path",
                        default=None, type=str, required=True,)
    parser.add_argument("--ser_model_type",
                        default='LayoutXLM', type=str)
    parser.add_argument("--re_model_name_or_path",
                        default=None, type=str, required=False,)
    parser.add_argument("--train_data_dir", default=None,
                        type=str, required=False,)
    parser.add_argument("--train_label_path", default=None,
                        type=str, required=False,)
    parser.add_argument("--eval_data_dir", default=None,
                        type=str, required=False,)
    parser.add_argument("--eval_label_path", default=None,
                        type=str, required=False,)
    parser.add_argument("--output_dir", default=None, type=str, required=True,)
    parser.add_argument("--max_seq_length", default=512, type=int,)
    parser.add_argument("--evaluate_during_training", action="store_true",)
    parser.add_argument("--num_workers", default=8, type=int,)
    parser.add_argument("--per_gpu_train_batch_size", default=1,
                        type=int, help="Batch size per GPU/CPU for training.",)
    parser.add_argument("--per_gpu_eval_batch_size", default=1,
                        type=int, help="Batch size per GPU/CPU for eval.",)
    parser.add_argument("--learning_rate", default=5e-5,
                        type=float, help="The initial learning rate for Adam.",)
    parser.add_argument("--weight_decay", default=0.0,
                        type=float, help="Weight decay if we apply some.",)
    parser.add_argument("--adam_epsilon", default=1e-8,
                        type=float, help="Epsilon for Adam optimizer.",)
    parser.add_argument("--max_grad_norm", default=1.0,
                        type=float, help="Max gradient norm.",)
    parser.add_argument("--num_train_epochs", default=3, type=int,
                        help="Total number of training epochs to perform.",)
    parser.add_argument("--warmup_steps", default=0, type=int,
                        help="Linear warmup over warmup_steps.",)
    parser.add_argument("--eval_steps", type=int, default=10,
                        help="eval every X updates steps.",)
    parser.add_argument("--seed", type=int, default=2048,
                        help="random seed for initialization",)

    parser.add_argument("--rec_model_dir", default=None, type=str, )
    parser.add_argument("--det_model_dir", default=None, type=str, )
    parser.add_argument(
        "--label_map_path", default="./labels/labels_ser.txt", type=str, required=False, )
    parser.add_argument("--infer_imgs", default=None, type=str, required=False)
    parser.add_argument("--resume", action='store_true')
    parser.add_argument("--ocr_json_path", default=None,
                        type=str, required=False, help="ocr prediction results")
    # yapf: enable
    args = parser.parse_args()
    return args