From 5161d8251eaac12395db01122cd08e7a3763939c Mon Sep 17 00:00:00 2001 From: zhiboniu Date: Tue, 21 Jun 2022 06:46:03 +0000 Subject: [PATCH] delete other rec algorithm --- deploy/pphuman/config/infer_cfg_ppvehicle.yml | 1 - deploy/pphuman/ppvehicle/vehicle_plate.py | 279 +------- .../pphuman/ppvehicle/vehicle_plateutils.py | 11 +- .../ppvehicle/vehicleplate_postprocess.py | 659 +----------------- 4 files changed, 22 insertions(+), 928 deletions(-) diff --git a/deploy/pphuman/config/infer_cfg_ppvehicle.yml b/deploy/pphuman/config/infer_cfg_ppvehicle.yml index b3aeedd5c..da1e9253d 100644 --- a/deploy/pphuman/config/infer_cfg_ppvehicle.yml +++ b/deploy/pphuman/config/infer_cfg_ppvehicle.yml @@ -19,7 +19,6 @@ VEHICLE_PLATE: det_model_dir: output_inference/ch_PP-OCRv3_det_infer/ det_limit_side_len: 480 det_limit_type: "max" - rec_algorithm: "SVTR_LCNet" rec_model_dir: output_inference/ch_PP-OCRv3_rec_infer/ rec_image_shape: [3, 48, 320] rec_batch_num: 6 diff --git a/deploy/pphuman/ppvehicle/vehicle_plate.py b/deploy/pphuman/ppvehicle/vehicle_plate.py index e8787dfe0..5e0c5af95 100644 --- a/deploy/pphuman/ppvehicle/vehicle_plate.py +++ b/deploy/pphuman/ppvehicle/vehicle_plate.py @@ -151,7 +151,6 @@ class TextRecognizer(object): def __init__(self, args, cfg, use_gpu=True): self.rec_image_shape = cfg['rec_image_shape'] self.rec_batch_num = cfg['rec_batch_num'] - self.rec_algorithm = cfg['rec_algorithm'] word_dict_path = cfg['word_dict_path'] use_space_char = True @@ -160,30 +159,6 @@ class TextRecognizer(object): "character_dict_path": word_dict_path, "use_space_char": use_space_char } - if self.rec_algorithm == "SRN": - postprocess_params = { - 'name': 'SRNLabelDecode', - "character_dict_path": word_dict_path, - "use_space_char": use_space_char - } - elif self.rec_algorithm == "RARE": - postprocess_params = { - 'name': 'AttnLabelDecode', - "character_dict_path": word_dict_path, - "use_space_char": use_space_char - } - elif self.rec_algorithm == 'NRTR': - postprocess_params = { - 'name': 'NRTRLabelDecode', - "character_dict_path": word_dict_path, - "use_space_char": use_space_char - } - elif self.rec_algorithm == "SAR": - postprocess_params = { - 'name': 'SARLabelDecode', - "character_dict_path": word_dict_path, - "use_space_char": use_space_char - } self.postprocess_op = build_post_process(postprocess_params) self.predictor, self.input_tensor, self.output_tensors, self.config = \ create_predictor(args, cfg, 'rec') @@ -191,15 +166,6 @@ class TextRecognizer(object): def resize_norm_img(self, img, max_wh_ratio): imgC, imgH, imgW = self.rec_image_shape - if self.rec_algorithm == 'NRTR': - img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY) - # return padding_im - image_pil = Image.fromarray(np.uint8(img)) - img = image_pil.resize([100, 32], Image.ANTIALIAS) - img = np.array(img) - norm_img = np.expand_dims(img, -1) - norm_img = norm_img.transpose((2, 0, 1)) - return norm_img.astype(np.float32) / 128. - 1. assert imgC == img.shape[2] imgW = int((imgH * max_wh_ratio)) @@ -214,10 +180,6 @@ class TextRecognizer(object): resized_w = imgW else: resized_w = int(math.ceil(imgH * ratio)) - if self.rec_algorithm == 'RARE': - if resized_w > self.rec_image_shape[2]: - resized_w = self.rec_image_shape[2] - imgW = self.rec_image_shape[2] resized_image = cv2.resize(img, (resized_w, imgH)) resized_image = resized_image.astype('float32') resized_image = resized_image.transpose((2, 0, 1)) / 255 @@ -227,124 +189,6 @@ class TextRecognizer(object): padding_im[:, :, 0:resized_w] = resized_image return padding_im - def resize_norm_img_svtr(self, img, image_shape): - - imgC, imgH, imgW = image_shape - resized_image = cv2.resize( - img, (imgW, imgH), interpolation=cv2.INTER_LINEAR) - resized_image = resized_image.astype('float32') - resized_image = resized_image.transpose((2, 0, 1)) / 255 - resized_image -= 0.5 - resized_image /= 0.5 - return resized_image - - def resize_norm_img_srn(self, img, image_shape): - imgC, imgH, imgW = image_shape - - img_black = np.zeros((imgH, imgW)) - im_hei = img.shape[0] - im_wid = img.shape[1] - - if im_wid <= im_hei * 1: - img_new = cv2.resize(img, (imgH * 1, imgH)) - elif im_wid <= im_hei * 2: - img_new = cv2.resize(img, (imgH * 2, imgH)) - elif im_wid <= im_hei * 3: - img_new = cv2.resize(img, (imgH * 3, imgH)) - else: - img_new = cv2.resize(img, (imgW, imgH)) - - img_np = np.asarray(img_new) - img_np = cv2.cvtColor(img_np, cv2.COLOR_BGR2GRAY) - img_black[:, 0:img_np.shape[1]] = img_np - img_black = img_black[:, :, np.newaxis] - - row, col, c = img_black.shape - c = 1 - - return np.reshape(img_black, (c, row, col)).astype(np.float32) - - def srn_other_inputs(self, image_shape, num_heads, max_text_length): - - imgC, imgH, imgW = image_shape - feature_dim = int((imgH / 8) * (imgW / 8)) - - encoder_word_pos = np.array(range(0, feature_dim)).reshape( - (feature_dim, 1)).astype('int64') - gsrm_word_pos = np.array(range(0, max_text_length)).reshape( - (max_text_length, 1)).astype('int64') - - gsrm_attn_bias_data = np.ones((1, max_text_length, max_text_length)) - gsrm_slf_attn_bias1 = np.triu(gsrm_attn_bias_data, 1).reshape( - [-1, 1, max_text_length, max_text_length]) - gsrm_slf_attn_bias1 = np.tile( - gsrm_slf_attn_bias1, - [1, num_heads, 1, 1]).astype('float32') * [-1e9] - - gsrm_slf_attn_bias2 = np.tril(gsrm_attn_bias_data, -1).reshape( - [-1, 1, max_text_length, max_text_length]) - gsrm_slf_attn_bias2 = np.tile( - gsrm_slf_attn_bias2, - [1, num_heads, 1, 1]).astype('float32') * [-1e9] - - encoder_word_pos = encoder_word_pos[np.newaxis, :] - gsrm_word_pos = gsrm_word_pos[np.newaxis, :] - - return [ - encoder_word_pos, gsrm_word_pos, gsrm_slf_attn_bias1, - gsrm_slf_attn_bias2 - ] - - def process_image_srn(self, img, image_shape, num_heads, max_text_length): - norm_img = self.resize_norm_img_srn(img, image_shape) - norm_img = norm_img[np.newaxis, :] - - [encoder_word_pos, gsrm_word_pos, gsrm_slf_attn_bias1, gsrm_slf_attn_bias2] = \ - self.srn_other_inputs(image_shape, num_heads, max_text_length) - - gsrm_slf_attn_bias1 = gsrm_slf_attn_bias1.astype(np.float32) - gsrm_slf_attn_bias2 = gsrm_slf_attn_bias2.astype(np.float32) - encoder_word_pos = encoder_word_pos.astype(np.int64) - gsrm_word_pos = gsrm_word_pos.astype(np.int64) - - return (norm_img, encoder_word_pos, gsrm_word_pos, gsrm_slf_attn_bias1, - gsrm_slf_attn_bias2) - - def resize_norm_img_sar(self, img, image_shape, - width_downsample_ratio=0.25): - imgC, imgH, imgW_min, imgW_max = image_shape - h = img.shape[0] - w = img.shape[1] - valid_ratio = 1.0 - # make sure new_width is an integral multiple of width_divisor. - width_divisor = int(1 / width_downsample_ratio) - # resize - ratio = w / float(h) - resize_w = math.ceil(imgH * ratio) - if resize_w % width_divisor != 0: - resize_w = round(resize_w / width_divisor) * width_divisor - if imgW_min is not None: - resize_w = max(imgW_min, resize_w) - if imgW_max is not None: - valid_ratio = min(1.0, 1.0 * resize_w / imgW_max) - resize_w = min(imgW_max, resize_w) - resized_image = cv2.resize(img, (resize_w, imgH)) - resized_image = resized_image.astype('float32') - # norm - if image_shape[0] == 1: - resized_image = resized_image / 255 - resized_image = resized_image[np.newaxis, :] - else: - resized_image = resized_image.transpose((2, 0, 1)) / 255 - resized_image -= 0.5 - resized_image /= 0.5 - resize_shape = resized_image.shape - padding_im = -1.0 * np.ones((imgC, imgH, imgW_max), dtype=np.float32) - padding_im[:, :, 0:resize_w] = resized_image - pad_shape = padding_im.shape - - return padding_im, resize_shape, pad_shape, valid_ratio - def predict_text(self, img_list): img_num = len(img_list) # Calculate the aspect ratio of all text bars @@ -367,115 +211,28 @@ class TextRecognizer(object): wh_ratio = w * 1.0 / h max_wh_ratio = max(max_wh_ratio, wh_ratio) for ino in range(beg_img_no, end_img_no): - - if self.rec_algorithm == "SAR": - norm_img, _, _, valid_ratio = self.resize_norm_img_sar( - img_list[indices[ino]], self.rec_image_shape) - norm_img = norm_img[np.newaxis, :] - valid_ratio = np.expand_dims(valid_ratio, axis=0) - valid_ratios = [] - valid_ratios.append(valid_ratio) - norm_img_batch.append(norm_img) - elif self.rec_algorithm == "SRN": - norm_img = self.process_image_srn( - img_list[indices[ino]], self.rec_image_shape, 8, 25) - encoder_word_pos_list = [] - gsrm_word_pos_list = [] - gsrm_slf_attn_bias1_list = [] - gsrm_slf_attn_bias2_list = [] - encoder_word_pos_list.append(norm_img[1]) - gsrm_word_pos_list.append(norm_img[2]) - gsrm_slf_attn_bias1_list.append(norm_img[3]) - gsrm_slf_attn_bias2_list.append(norm_img[4]) - norm_img_batch.append(norm_img[0]) - elif self.rec_algorithm == "SVTR": - norm_img = self.resize_norm_img_svtr(img_list[indices[ino]], - self.rec_image_shape) - norm_img = norm_img[np.newaxis, :] - norm_img_batch.append(norm_img) - else: - norm_img = self.resize_norm_img(img_list[indices[ino]], - max_wh_ratio) - norm_img = norm_img[np.newaxis, :] - norm_img_batch.append(norm_img) + norm_img = self.resize_norm_img(img_list[indices[ino]], + max_wh_ratio) + norm_img = norm_img[np.newaxis, :] + norm_img_batch.append(norm_img) norm_img_batch = np.concatenate(norm_img_batch) norm_img_batch = norm_img_batch.copy() - - if self.rec_algorithm == "SRN": - encoder_word_pos_list = np.concatenate(encoder_word_pos_list) - gsrm_word_pos_list = np.concatenate(gsrm_word_pos_list) - gsrm_slf_attn_bias1_list = np.concatenate( - gsrm_slf_attn_bias1_list) - gsrm_slf_attn_bias2_list = np.concatenate( - gsrm_slf_attn_bias2_list) - - inputs = [ - norm_img_batch, - encoder_word_pos_list, - gsrm_word_pos_list, - gsrm_slf_attn_bias1_list, - gsrm_slf_attn_bias2_list, - ] - if self.use_onnx: - input_dict = {} - input_dict[self.input_tensor.name] = norm_img_batch - outputs = self.predictor.run(self.output_tensors, - input_dict) - preds = {"predict": outputs[2]} - else: - input_names = self.predictor.get_input_names() - for i in range(len(input_names)): - input_tensor = self.predictor.get_input_handle( - input_names[i]) - input_tensor.copy_from_cpu(inputs[i]) - self.predictor.run() - outputs = [] - for output_tensor in self.output_tensors: - output = output_tensor.copy_to_cpu() - outputs.append(output) - preds = {"predict": outputs[2]} - elif self.rec_algorithm == "SAR": - valid_ratios = np.concatenate(valid_ratios) - inputs = [ - norm_img_batch, - valid_ratios, - ] - if self.use_onnx: - input_dict = {} - input_dict[self.input_tensor.name] = norm_img_batch - outputs = self.predictor.run(self.output_tensors, - input_dict) - preds = outputs[0] - else: - input_names = self.predictor.get_input_names() - for i in range(len(input_names)): - input_tensor = self.predictor.get_input_handle( - input_names[i]) - input_tensor.copy_from_cpu(inputs[i]) - self.predictor.run() - outputs = [] - for output_tensor in self.output_tensors: - output = output_tensor.copy_to_cpu() - outputs.append(output) - preds = outputs[0] + if self.use_onnx: + input_dict = {} + input_dict[self.input_tensor.name] = norm_img_batch + outputs = self.predictor.run(self.output_tensors, input_dict) + preds = outputs[0] else: - if self.use_onnx: - input_dict = {} - input_dict[self.input_tensor.name] = norm_img_batch - outputs = self.predictor.run(self.output_tensors, - input_dict) - preds = outputs[0] + self.input_tensor.copy_from_cpu(norm_img_batch) + self.predictor.run() + outputs = [] + for output_tensor in self.output_tensors: + output = output_tensor.copy_to_cpu() + outputs.append(output) + if len(outputs) != 1: + preds = outputs else: - self.input_tensor.copy_from_cpu(norm_img_batch) - self.predictor.run() - outputs = [] - for output_tensor in self.output_tensors: - output = output_tensor.copy_to_cpu() - outputs.append(output) - if len(outputs) != 1: - preds = outputs - else: - preds = outputs[0] + preds = outputs[0] rec_result = self.postprocess_op(preds) for rno in range(len(rec_result)): rec_res[indices[beg_img_no + rno]] = rec_result[rno] diff --git a/deploy/pphuman/ppvehicle/vehicle_plateutils.py b/deploy/pphuman/ppvehicle/vehicle_plateutils.py index c4c7cf2b2..431b64720 100644 --- a/deploy/pphuman/ppvehicle/vehicle_plateutils.py +++ b/deploy/pphuman/ppvehicle/vehicle_plateutils.py @@ -185,14 +185,9 @@ def create_predictor(args, cfg, mode): def get_output_tensors(cfg, mode, predictor): output_names = predictor.get_output_names() output_tensors = [] - if mode == "rec" and cfg['rec_algorithm'] in ["CRNN", "SVTR_LCNet"]: - output_name = 'softmax_0.tmp_0' - if output_name in output_names: - return [predictor.get_output_handle(output_name)] - else: - for output_name in output_names: - output_tensor = predictor.get_output_handle(output_name) - output_tensors.append(output_tensor) + output_name = 'softmax_0.tmp_0' + if output_name in output_names: + return [predictor.get_output_handle(output_name)] else: for output_name in output_names: output_tensor = predictor.get_output_handle(output_name) diff --git a/deploy/pphuman/ppvehicle/vehicleplate_postprocess.py b/deploy/pphuman/ppvehicle/vehicleplate_postprocess.py index 04a4b6216..7aa67a09a 100644 --- a/deploy/pphuman/ppvehicle/vehicleplate_postprocess.py +++ b/deploy/pphuman/ppvehicle/vehicleplate_postprocess.py @@ -23,16 +23,7 @@ import copy def build_post_process(config, global_config=None): - support_dict = [ - 'DBPostProcess', 'CTCLabelDecode', 'AttnLabelDecode', 'SRNLabelDecode', - 'DistillationCTCLabelDecode', 'TableLabelDecode', 'NRTRLabelDecode', - 'SARLabelDecode', 'SEEDLabelDecode', 'PRENLabelDecode', - 'DistillationSARLabelDecode' - ] - - if config['name'] == 'PSEPostProcess': - from .pse_postprocess import PSEPostProcess - support_dict.append('PSEPostProcess') + support_dict = ['DBPostProcess', 'CTCLabelDecode'] config = copy.deepcopy(config) module_name = config.pop('name') @@ -298,651 +289,3 @@ class CTCLabelDecode(BaseRecLabelDecode): def add_special_char(self, dict_character): dict_character = ['blank'] + dict_character return dict_character - - -class DistillationCTCLabelDecode(CTCLabelDecode): - """ - Convert - Convert between text-label and text-index - """ - - def __init__(self, - character_dict_path=None, - use_space_char=False, - model_name=["student"], - key=None, - multi_head=False, - **kwargs): - super(DistillationCTCLabelDecode, self).__init__(character_dict_path, - use_space_char) - if not isinstance(model_name, list): - model_name = [model_name] - self.model_name = model_name - - self.key = key - self.multi_head = multi_head - - def __call__(self, preds, label=None, *args, **kwargs): - output = dict() - for name in self.model_name: - pred = preds[name] - if self.key is not None: - pred = pred[self.key] - if self.multi_head and isinstance(pred, dict): - pred = pred['ctc'] - output[name] = super().__call__(pred, label=label, *args, **kwargs) - return output - - -class NRTRLabelDecode(BaseRecLabelDecode): - """ Convert between text-label and text-index """ - - def __init__(self, character_dict_path=None, use_space_char=True, **kwargs): - super(NRTRLabelDecode, self).__init__(character_dict_path, - use_space_char) - - def __call__(self, preds, label=None, *args, **kwargs): - - if len(preds) == 2: - preds_id = preds[0] - preds_prob = preds[1] - if isinstance(preds_id, paddle.Tensor): - preds_id = preds_id.numpy() - if isinstance(preds_prob, paddle.Tensor): - preds_prob = preds_prob.numpy() - if preds_id[0][0] == 2: - preds_idx = preds_id[:, 1:] - preds_prob = preds_prob[:, 1:] - else: - preds_idx = preds_id - text = self.decode(preds_idx, preds_prob, is_remove_duplicate=False) - if label is None: - return text - label = self.decode(label[:, 1:]) - else: - if isinstance(preds, paddle.Tensor): - preds = preds.numpy() - preds_idx = preds.argmax(axis=2) - preds_prob = preds.max(axis=2) - text = self.decode(preds_idx, preds_prob, is_remove_duplicate=False) - if label is None: - return text - label = self.decode(label[:, 1:]) - return text, label - - def add_special_char(self, dict_character): - dict_character = ['blank', '', '', ''] + dict_character - return dict_character - - def decode(self, text_index, text_prob=None, is_remove_duplicate=False): - """ convert text-index into text-label. """ - result_list = [] - batch_size = len(text_index) - for batch_idx in range(batch_size): - char_list = [] - conf_list = [] - for idx in range(len(text_index[batch_idx])): - if text_index[batch_idx][idx] == 3: # end - break - try: - char_list.append(self.character[int(text_index[batch_idx][ - idx])]) - except: - continue - if text_prob is not None: - conf_list.append(text_prob[batch_idx][idx]) - else: - conf_list.append(1) - text = ''.join(char_list) - result_list.append((text.lower(), np.mean(conf_list).tolist())) - return result_list - - -class AttnLabelDecode(BaseRecLabelDecode): - """ Convert between text-label and text-index """ - - def __init__(self, character_dict_path=None, use_space_char=False, - **kwargs): - super(AttnLabelDecode, self).__init__(character_dict_path, - use_space_char) - - def add_special_char(self, dict_character): - self.beg_str = "sos" - self.end_str = "eos" - dict_character = dict_character - dict_character = [self.beg_str] + dict_character + [self.end_str] - return dict_character - - def decode(self, text_index, text_prob=None, is_remove_duplicate=False): - """ convert text-index into text-label. """ - result_list = [] - ignored_tokens = self.get_ignored_tokens() - [beg_idx, end_idx] = self.get_ignored_tokens() - batch_size = len(text_index) - for batch_idx in range(batch_size): - char_list = [] - conf_list = [] - for idx in range(len(text_index[batch_idx])): - if text_index[batch_idx][idx] in ignored_tokens: - continue - if int(text_index[batch_idx][idx]) == int(end_idx): - break - if is_remove_duplicate: - # only for predict - if idx > 0 and text_index[batch_idx][idx - 1] == text_index[ - batch_idx][idx]: - continue - char_list.append(self.character[int(text_index[batch_idx][ - idx])]) - if text_prob is not None: - conf_list.append(text_prob[batch_idx][idx]) - else: - conf_list.append(1) - text = ''.join(char_list) - result_list.append((text, np.mean(conf_list).tolist())) - return result_list - - def __call__(self, preds, label=None, *args, **kwargs): - """ - text = self.decode(text) - if label is None: - return text - else: - label = self.decode(label, is_remove_duplicate=False) - return text, label - """ - if isinstance(preds, paddle.Tensor): - preds = preds.numpy() - - preds_idx = preds.argmax(axis=2) - preds_prob = preds.max(axis=2) - text = self.decode(preds_idx, preds_prob, is_remove_duplicate=False) - if label is None: - return text - label = self.decode(label, is_remove_duplicate=False) - return text, label - - def get_ignored_tokens(self): - beg_idx = self.get_beg_end_flag_idx("beg") - end_idx = self.get_beg_end_flag_idx("end") - return [beg_idx, end_idx] - - def get_beg_end_flag_idx(self, beg_or_end): - if beg_or_end == "beg": - idx = np.array(self.dict[self.beg_str]) - elif beg_or_end == "end": - idx = np.array(self.dict[self.end_str]) - else: - assert False, "unsupport type %s in get_beg_end_flag_idx" \ - % beg_or_end - return idx - - -class SEEDLabelDecode(BaseRecLabelDecode): - """ Convert between text-label and text-index """ - - def __init__(self, character_dict_path=None, use_space_char=False, - **kwargs): - super(SEEDLabelDecode, self).__init__(character_dict_path, - use_space_char) - - def add_special_char(self, dict_character): - self.padding_str = "padding" - self.end_str = "eos" - self.unknown = "unknown" - dict_character = dict_character + [ - self.end_str, self.padding_str, self.unknown - ] - return dict_character - - def get_ignored_tokens(self): - end_idx = self.get_beg_end_flag_idx("eos") - return [end_idx] - - def get_beg_end_flag_idx(self, beg_or_end): - if beg_or_end == "sos": - idx = np.array(self.dict[self.beg_str]) - elif beg_or_end == "eos": - idx = np.array(self.dict[self.end_str]) - else: - assert False, "unsupport type %s in get_beg_end_flag_idx" % beg_or_end - return idx - - def decode(self, text_index, text_prob=None, is_remove_duplicate=False): - """ convert text-index into text-label. """ - result_list = [] - [end_idx] = self.get_ignored_tokens() - batch_size = len(text_index) - for batch_idx in range(batch_size): - char_list = [] - conf_list = [] - for idx in range(len(text_index[batch_idx])): - if int(text_index[batch_idx][idx]) == int(end_idx): - break - if is_remove_duplicate: - # only for predict - if idx > 0 and text_index[batch_idx][idx - 1] == text_index[ - batch_idx][idx]: - continue - char_list.append(self.character[int(text_index[batch_idx][ - idx])]) - if text_prob is not None: - conf_list.append(text_prob[batch_idx][idx]) - else: - conf_list.append(1) - text = ''.join(char_list) - result_list.append((text, np.mean(conf_list).tolist())) - return result_list - - def __call__(self, preds, label=None, *args, **kwargs): - """ - text = self.decode(text) - if label is None: - return text - else: - label = self.decode(label, is_remove_duplicate=False) - return text, label - """ - preds_idx = preds["rec_pred"] - if isinstance(preds_idx, paddle.Tensor): - preds_idx = preds_idx.numpy() - if "rec_pred_scores" in preds: - preds_idx = preds["rec_pred"] - preds_prob = preds["rec_pred_scores"] - else: - preds_idx = preds["rec_pred"].argmax(axis=2) - preds_prob = preds["rec_pred"].max(axis=2) - text = self.decode(preds_idx, preds_prob, is_remove_duplicate=False) - if label is None: - return text - label = self.decode(label, is_remove_duplicate=False) - return text, label - - -class SRNLabelDecode(BaseRecLabelDecode): - """ Convert between text-label and text-index """ - - def __init__(self, character_dict_path=None, use_space_char=False, - **kwargs): - super(SRNLabelDecode, self).__init__(character_dict_path, - use_space_char) - self.max_text_length = kwargs.get('max_text_length', 25) - - def __call__(self, preds, label=None, *args, **kwargs): - pred = preds['predict'] - char_num = len(self.character_str) + 2 - if isinstance(pred, paddle.Tensor): - pred = pred.numpy() - pred = np.reshape(pred, [-1, char_num]) - - preds_idx = np.argmax(pred, axis=1) - preds_prob = np.max(pred, axis=1) - - preds_idx = np.reshape(preds_idx, [-1, self.max_text_length]) - - preds_prob = np.reshape(preds_prob, [-1, self.max_text_length]) - - text = self.decode(preds_idx, preds_prob) - - if label is None: - text = self.decode(preds_idx, preds_prob, is_remove_duplicate=False) - return text - label = self.decode(label) - return text, label - - def decode(self, text_index, text_prob=None, is_remove_duplicate=False): - """ convert text-index into text-label. """ - result_list = [] - ignored_tokens = self.get_ignored_tokens() - batch_size = len(text_index) - - for batch_idx in range(batch_size): - char_list = [] - conf_list = [] - for idx in range(len(text_index[batch_idx])): - if text_index[batch_idx][idx] in ignored_tokens: - continue - if is_remove_duplicate: - # only for predict - if idx > 0 and text_index[batch_idx][idx - 1] == text_index[ - batch_idx][idx]: - continue - char_list.append(self.character[int(text_index[batch_idx][ - idx])]) - if text_prob is not None: - conf_list.append(text_prob[batch_idx][idx]) - else: - conf_list.append(1) - - text = ''.join(char_list) - result_list.append((text, np.mean(conf_list).tolist())) - return result_list - - def add_special_char(self, dict_character): - dict_character = dict_character + [self.beg_str, self.end_str] - return dict_character - - def get_ignored_tokens(self): - beg_idx = self.get_beg_end_flag_idx("beg") - end_idx = self.get_beg_end_flag_idx("end") - return [beg_idx, end_idx] - - def get_beg_end_flag_idx(self, beg_or_end): - if beg_or_end == "beg": - idx = np.array(self.dict[self.beg_str]) - elif beg_or_end == "end": - idx = np.array(self.dict[self.end_str]) - else: - assert False, "unsupport type %s in get_beg_end_flag_idx" \ - % beg_or_end - return idx - - -class TableLabelDecode(object): - """ """ - - def __init__(self, character_dict_path, **kwargs): - list_character, list_elem = self.load_char_elem_dict( - character_dict_path) - list_character = self.add_special_char(list_character) - list_elem = self.add_special_char(list_elem) - self.dict_character = {} - self.dict_idx_character = {} - for i, char in enumerate(list_character): - self.dict_idx_character[i] = char - self.dict_character[char] = i - self.dict_elem = {} - self.dict_idx_elem = {} - for i, elem in enumerate(list_elem): - self.dict_idx_elem[i] = elem - self.dict_elem[elem] = i - - def load_char_elem_dict(self, character_dict_path): - list_character = [] - list_elem = [] - with open(character_dict_path, "rb") as fin: - lines = fin.readlines() - substr = lines[0].decode('utf-8').strip("\n").strip("\r\n").split( - "\t") - character_num = int(substr[0]) - elem_num = int(substr[1]) - for cno in range(1, 1 + character_num): - character = lines[cno].decode('utf-8').strip("\n").strip("\r\n") - list_character.append(character) - for eno in range(1 + character_num, 1 + character_num + elem_num): - elem = lines[eno].decode('utf-8').strip("\n").strip("\r\n") - list_elem.append(elem) - return list_character, list_elem - - def add_special_char(self, list_character): - self.beg_str = "sos" - self.end_str = "eos" - list_character = [self.beg_str] + list_character + [self.end_str] - return list_character - - def __call__(self, preds): - structure_probs = preds['structure_probs'] - loc_preds = preds['loc_preds'] - if isinstance(structure_probs, paddle.Tensor): - structure_probs = structure_probs.numpy() - if isinstance(loc_preds, paddle.Tensor): - loc_preds = loc_preds.numpy() - structure_idx = structure_probs.argmax(axis=2) - structure_probs = structure_probs.max(axis=2) - structure_str, structure_pos, result_score_list, result_elem_idx_list = self.decode( - structure_idx, structure_probs, 'elem') - res_html_code_list = [] - res_loc_list = [] - batch_num = len(structure_str) - for bno in range(batch_num): - res_loc = [] - for sno in range(len(structure_str[bno])): - text = structure_str[bno][sno] - if text in ['', ' 0 and tmp_elem_idx == end_idx: - break - if tmp_elem_idx in ignored_tokens: - continue - - char_list.append(current_dict[tmp_elem_idx]) - elem_pos_list.append(idx) - score_list.append(structure_probs[batch_idx, idx]) - elem_idx_list.append(tmp_elem_idx) - result_list.append(char_list) - result_pos_list.append(elem_pos_list) - result_score_list.append(score_list) - result_elem_idx_list.append(elem_idx_list) - return result_list, result_pos_list, result_score_list, result_elem_idx_list - - def get_ignored_tokens(self, char_or_elem): - beg_idx = self.get_beg_end_flag_idx("beg", char_or_elem) - end_idx = self.get_beg_end_flag_idx("end", char_or_elem) - return [beg_idx, end_idx] - - def get_beg_end_flag_idx(self, beg_or_end, char_or_elem): - if char_or_elem == "char": - if beg_or_end == "beg": - idx = self.dict_character[self.beg_str] - elif beg_or_end == "end": - idx = self.dict_character[self.end_str] - else: - assert False, "Unsupport type %s in get_beg_end_flag_idx of char" \ - % beg_or_end - elif char_or_elem == "elem": - if beg_or_end == "beg": - idx = self.dict_elem[self.beg_str] - elif beg_or_end == "end": - idx = self.dict_elem[self.end_str] - else: - assert False, "Unsupport type %s in get_beg_end_flag_idx of elem" \ - % beg_or_end - else: - assert False, "Unsupport type %s in char_or_elem" \ - % char_or_elem - return idx - - -class SARLabelDecode(BaseRecLabelDecode): - """ Convert between text-label and text-index """ - - def __init__(self, character_dict_path=None, use_space_char=False, - **kwargs): - super(SARLabelDecode, self).__init__(character_dict_path, - use_space_char) - - self.rm_symbol = kwargs.get('rm_symbol', False) - - def add_special_char(self, dict_character): - beg_end_str = "" - unknown_str = "" - padding_str = "" - dict_character = dict_character + [unknown_str] - self.unknown_idx = len(dict_character) - 1 - dict_character = dict_character + [beg_end_str] - self.start_idx = len(dict_character) - 1 - self.end_idx = len(dict_character) - 1 - dict_character = dict_character + [padding_str] - self.padding_idx = len(dict_character) - 1 - return dict_character - - def decode(self, text_index, text_prob=None, is_remove_duplicate=False): - """ convert text-index into text-label. """ - result_list = [] - ignored_tokens = self.get_ignored_tokens() - - batch_size = len(text_index) - for batch_idx in range(batch_size): - char_list = [] - conf_list = [] - for idx in range(len(text_index[batch_idx])): - if text_index[batch_idx][idx] in ignored_tokens: - continue - if int(text_index[batch_idx][idx]) == int(self.end_idx): - if text_prob is None and idx == 0: - continue - else: - break - if is_remove_duplicate: - # only for predict - if idx > 0 and text_index[batch_idx][idx - 1] == text_index[ - batch_idx][idx]: - continue - char_list.append(self.character[int(text_index[batch_idx][ - idx])]) - if text_prob is not None: - conf_list.append(text_prob[batch_idx][idx]) - else: - conf_list.append(1) - text = ''.join(char_list) - if self.rm_symbol: - comp = re.compile('[^A-Z^a-z^0-9^\u4e00-\u9fa5]') - text = text.lower() - text = comp.sub('', text) - result_list.append((text, np.mean(conf_list).tolist())) - return result_list - - def __call__(self, preds, label=None, *args, **kwargs): - if isinstance(preds, paddle.Tensor): - preds = preds.numpy() - preds_idx = preds.argmax(axis=2) - preds_prob = preds.max(axis=2) - - text = self.decode(preds_idx, preds_prob, is_remove_duplicate=False) - - if label is None: - return text - label = self.decode(label, is_remove_duplicate=False) - return text, label - - def get_ignored_tokens(self): - return [self.padding_idx] - - -class DistillationSARLabelDecode(SARLabelDecode): - """ - Convert - Convert between text-label and text-index - """ - - def __init__(self, - character_dict_path=None, - use_space_char=False, - model_name=["student"], - key=None, - multi_head=False, - **kwargs): - super(DistillationSARLabelDecode, self).__init__(character_dict_path, - use_space_char) - if not isinstance(model_name, list): - model_name = [model_name] - self.model_name = model_name - - self.key = key - self.multi_head = multi_head - - def __call__(self, preds, label=None, *args, **kwargs): - output = dict() - for name in self.model_name: - pred = preds[name] - if self.key is not None: - pred = pred[self.key] - if self.multi_head and isinstance(pred, dict): - pred = pred['sar'] - output[name] = super().__call__(pred, label=label, *args, **kwargs) - return output - - -class PRENLabelDecode(BaseRecLabelDecode): - """ Convert between text-label and text-index """ - - def __init__(self, character_dict_path=None, use_space_char=False, - **kwargs): - super(PRENLabelDecode, self).__init__(character_dict_path, - use_space_char) - - def add_special_char(self, dict_character): - padding_str = '' # 0 - end_str = '' # 1 - unknown_str = '' # 2 - - dict_character = [padding_str, end_str, unknown_str] + dict_character - self.padding_idx = 0 - self.end_idx = 1 - self.unknown_idx = 2 - - return dict_character - - def decode(self, text_index, text_prob=None): - """ convert text-index into text-label. """ - result_list = [] - batch_size = len(text_index) - - for batch_idx in range(batch_size): - char_list = [] - conf_list = [] - for idx in range(len(text_index[batch_idx])): - if text_index[batch_idx][idx] == self.end_idx: - break - if text_index[batch_idx][idx] in \ - [self.padding_idx, self.unknown_idx]: - continue - char_list.append(self.character[int(text_index[batch_idx][ - idx])]) - if text_prob is not None: - conf_list.append(text_prob[batch_idx][idx]) - else: - conf_list.append(1) - - text = ''.join(char_list) - if len(text) > 0: - result_list.append((text, np.mean(conf_list).tolist())) - else: - # here confidence of empty recog result is 1 - result_list.append(('', 1)) - return result_list - - def __call__(self, preds, label=None, *args, **kwargs): - preds = preds.numpy() - preds_idx = preds.argmax(axis=2) - preds_prob = preds.max(axis=2) - text = self.decode(preds_idx, preds_prob) - if label is None: - return text - label = self.decode(label) - return text, label -- GitLab