diff --git a/configs/e2e/e2e_r50_vd_pg.yml b/configs/e2e/e2e_r50_vd_pg.yml index 235c62f90a618320f2bc8de622b4a397a504c42b..22548a3c98a8dbe0fb07a9a0e1b721dc1ea1298c 100644 --- a/configs/e2e/e2e_r50_vd_pg.yml +++ b/configs/e2e/e2e_r50_vd_pg.yml @@ -1,14 +1,17 @@ Global: - use_gpu: False + use_gpu: True epoch_num: 600 log_smooth_window: 20 - print_batch_step: 2 + print_batch_step: 10 save_model_dir: ./output/pg_r50_vd_tt/ - save_epoch_step: 1 - # evaluation is run every 5000 iterationss after the 4000th iteration + save_epoch_step: 10 + # evaluation is run every 0 iterationss after the 1000th iteration eval_batch_step: [ 0, 1000 ] - # if pretrained_model is saved in static mode, load_static_weights must set to True - load_static_weights: False + # 1. If pretrained_model is saved in static mode, such as classification pretrained model + # from static branch, load_static_weights must be set as True. + # 2. If you want to finetune the pretrained models we provide in the docs, + # you should set load_static_weights as False. + load_static_weights: True cal_metric_during_train: False pretrained_model: checkpoints: @@ -19,7 +22,7 @@ Global: Architecture: model_type: e2e - algorithm: PG + algorithm: PGNet Transform: Backbone: name: ResNet @@ -34,28 +37,16 @@ Architecture: Loss: name: PGLoss -#Optimizer: -# name: Adam -# beta1: 0.9 -# beta2: 0.999 -# lr: -# name: Cosine -# learning_rate: 0.001 -# warmup_epoch: 1 -# regularizer: -# name: 'L2' -# factor: 0 - Optimizer: - name: RMSProp + name: Adam + beta1: 0.9 + beta2: 0.999 lr: - name: Piecewise learning_rate: 0.001 - decay_epochs: [ 40, 80, 120, 160, 200 ] - values: [ 0.001, 0.00033, 0.0001, 0.000033, 0.00001 ] regularizer: name: 'L2' - factor: 0.00005 + factor: 0 + PostProcess: name: PGPostProcess @@ -65,45 +56,45 @@ PostProcess: Metric: name: E2EMetric + Lexicon_Table: [ '0', '1', '2', '3', '4', '5', '6', '7', '8', '9', 'A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'J', 'K', 'L', 'M', 'N', 'O', 'P', 'Q', 'R', 'S', 'T', 'U', 'V', 'W', 'X', 'Y', 'Z' ] main_indicator: f_score_e2e Train: dataset: name: PGDateSet - label_file_list: - ratio_list: - data_format: textnet # textnet/partvgg - Lexicon_Table: [ '0', '1', '2', '3', '4', '5', '6', '7', '8', '9', 'A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'J', 'K', 'L', 'M', 'N', 'O', 'P', 'Q', 'R', 'S', 'T', 'U', 'V', 'W', 'X', 'Y', 'Z' ] + label_file_list: [./train_data/total_text/train/] + ratio_list: [1.0] + data_format: icdar transforms: - DecodeImage: # load image img_mode: BGR channel_first: False - PGProcessTrain: batch_size: 14 - data_format: icdar - tcl_len: 64 min_crop_size: 24 min_text_size: 4 max_text_size: 512 + Lexicon_Table: [ '0', '1', '2', '3', '4', '5', '6', '7', '8', '9', 'A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'J', 'K', 'L', 'M', 'N', 'O', 'P', 'Q', 'R', 'S', 'T', 'U', 'V', 'W', 'X', 'Y', 'Z' ] - KeepKeys: keep_keys: [ 'images', 'tcl_maps', 'tcl_label_maps', 'border_maps','direction_maps', 'training_masks', 'label_list', 'pos_list', 'pos_mask' ] # dataloader will return list in this order loader: shuffle: True drop_last: True - batch_size_per_card: 1 - num_workers: 8 + batch_size_per_card: 14 + num_workers: 16 Eval: dataset: - name: PGDateSet + name: PGDataSet data_dir: ./train_data/ - label_file_list: + label_file_list: [./train_data/total_text/test/] transforms: - DecodeImage: # load image img_mode: BGR channel_first: False - E2ELabelEncode: - label_list: [ '0', '1', '2', '3', '4', '5', '6', '7', '8', '9', 'A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'J', 'K', 'L', 'M', 'N', 'O', 'P', 'Q', 'R', 'S', 'T', 'U', 'V', 'W', 'X', 'Y', 'Z' ] + Lexicon_Table: [ '0', '1', '2', '3', '4', '5', '6', '7', '8', '9', 'A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'J', 'K', 'L', 'M', 'N', 'O', 'P', 'Q', 'R', 'S', 'T', 'U', 'V', 'W', 'X', 'Y', 'Z' ] + max_len: 50 - E2EResizeForTest: valid_set: totaltext max_side_len: 768 diff --git a/doc/doc_ch/e2e.md b/doc/doc_ch/e2e.md new file mode 100644 index 0000000000000000000000000000000000000000..a0695697e39345b391a9bb37114136ee8e5743dc --- /dev/null +++ b/doc/doc_ch/e2e.md @@ -0,0 +1,120 @@ +# 端到端文字识别 + +本节以partvgg/totaltext数据集为例,介绍PaddleOCR中端到端模型的训练、评估与测试。 + +## 数据准备 +支持两种不同的数据形式textnet / icdar ,分别为四点标注数据和十四点标注数据,十四点标注数据效果要比四点标注效果好 +###数据形式为textnet + +解压数据集和下载标注文件后,PaddleOCR/train_data/part_vgg_synth/train/ 有一个文件夹和一个文件,分别是: +``` +/PaddleOCR/train_data/part_vgg_synth/train/ + └─ image/ partvgg数据集的训练数据 + └─ train_annotation_info.txt partvgg数据集的测试标注 +``` + +提供的标注文件格式如下,中间用"\t"分隔: +``` +" 图像文件名 图像标注信息--四点标注 图像标注信息--识别标注 +119_nile_110_31 140.2 222.5 266.0 194.6 278.7 251.8 152.9 279.7 Path: 32.9 133.1 106.0 130.8 106.4 143.8 33.3 146.1 were 21.8 81.9 106.9 80.4 107.7 123.2 22.6 124.7 why +``` +标注文件txt当中,其中每一行代表一组数据,以第一行为例。第一个代表同级目录image/下面的文件名, 后面每9个代表一组标注信息,前8个代表文本框的四个点坐标(x,y),从左上角的点开始顺时针排列。 +最后一个代表文字的识别结果,**当其内容为“###”时,表示该文本框无效,在训练时会跳过。** + + +###数据形式为icdar +解压数据集和下载标注文件后,PaddleOCR/train_data/total_text/train/ 有两个文件夹,分别是: +``` +/PaddleOCR/train_data/total_text/train/ + └─ rgb/ total_text数据集的训练数据 + └─ poly/ total_text数据集的测试标注 +``` + +提供的标注文件格式如下,中间用"\t"分隔: +``` +" 图像标注信息--十四点标注数据 图像标注信息--识别标注 +1004.0,689.0,1019.0,698.0,1034.0,708.0,1049.0,718.0,1064.0,728.0,1079.0,738.0,1095.0,748.0,1094.0,774.0,1079.0,765.0,1065.0,756.0,1050.0,747.0,1036.0,738.0,1021.0,729.0,1007.0,721.0 EST +1102.0,755.0,1116.0,764.0,1131.0,773.0,1146.0,783.0,1161.0,792.0,1176.0,801.0,1191.0,811.0,1193.0,837.0,1178.0,828.0,1164.0,819.0,1150.0,810.0,1135.0,801.0,1121.0,792.0,1107.0,784.0 1972 +``` +标注文件当中,其中每一个txt文件代表一组数据,文件名同级目录rgb/下面的文件名。以第一行为例,前面28个代表文本框的十四个点坐标(x,y),从左上角的点开始顺时针排列。 +最后一个代表文字的识别结果,**当其内容为“###”时,表示该文本框无效,在训练时会跳过。** +如果您想在其他数据集上训练,可以按照上述形式构建标注文件。 + +## 快速启动训练 + +首先下载模型backbone的pretrain model,PaddleOCR的检测模型目前支持两种backbone,分别是MobileNetV3、ResNet_vd系列, +您可以根据需求使用[PaddleClas](https://github.com/PaddlePaddle/PaddleClas/tree/master/ppcls/modeling/architectures)中的模型更换backbone。 +```shell +cd PaddleOCR/ +下载ResNet50_vd的预训练模型 +wget -P ./pretrain_models/ https://paddle-imagenet-models-name.bj.bcebos.com/ResNet50_vd_ssld_pretrained.tar + +# 解压预训练模型文件,以ResNet50_vd为例 +tar -xf ./pretrain_models/ResNet50_vd_ssld_pretrained.tar ./pretrain_models/ + +# 注:正确解压backbone预训练权重文件后,文件夹下包含众多以网络层命名的权重文件,格式如下: +./pretrain_models/ResNet50_vd_ssld_pretrained/ + └─ conv_last_bn_mean + └─ conv_last_bn_offset + └─ conv_last_bn_scale + └─ conv_last_bn_variance + └─ ...... + +``` + +#### 启动训练 + +*如果您安装的是cpu版本,请将配置文件中的 `use_gpu` 字段修改为false* + +```shell +# 单机单卡训练 e2e 模型 +python3 tools/train.py -c configs/e2e/e2e_r50_vd_pg.yml \ + -o Global.pretrain_weights=./pretrain_models/ResNet50_vd_ssld_pretrained/ Global.load_static_weights=True +# 单机多卡训练,通过 --gpus 参数设置使用的GPU ID +python3 -m paddle.distributed.launch --gpus '0,1,2,3' tools/train.py -c configs/e2e/e2e_r50_vd_pg.yml \ + -o Global.pretrain_weights=./pretrain_models/ResNet50_vd_ssld_pretrained/ Global.load_static_weights=True +``` + + +上述指令中,通过-c 选择训练使用configs/e2e/e2e_r50_vd_pg.yml配置文件。 +有关配置文件的详细解释,请参考[链接](./config.md)。 + +您也可以通过-o参数在不需要修改yml文件的情况下,改变训练的参数,比如,调整训练的学习率为0.0001 +```shell +python3 tools/train.py -c configs/e2e/e2e_r50_vd_pg.yml -o Optimizer.base_lr=0.0001 +``` + +#### 断点训练 + +如果训练程序中断,如果希望加载训练中断的模型从而恢复训练,可以通过指定Global.checkpoints指定要加载的模型路径: +```shell +python3 tools/train.py -c configs/e2e/e2e_r50_vd_pg.yml -o Global.checkpoints=./your/trained/model +``` + +**注意**:`Global.checkpoints`的优先级高于`Global.pretrain_weights`的优先级,即同时指定两个参数时,优先加载`Global.checkpoints`指定的模型,如果`Global.checkpoints`指定的模型路径有误,会加载`Global.pretrain_weights`指定的模型。 + +## 指标评估 + +PaddleOCR计算三个OCR端到端相关的指标,分别是:Precision、Recall、Hmean。 + +运行如下代码,根据配置文件`e2e_r50_vd_pg.yml`中`save_res_path`指定的测试集检测结果文件,计算评估指标。 + +评估时设置后处理参数`max_side_len=768`,使用不同数据集、不同模型训练,可调整参数进行优化 +训练中模型参数默认保存在`Global.save_model_dir`目录下。在评估指标时,需要设置`Global.checkpoints`指向保存的参数文件。 +```shell +python3 tools/eval.py -c configs/e2e/e2e_r50_vd_pg.yml -o Global.checkpoints="{path/to/weights}/best_accuracy" +``` + + + +## 测试端到端效果 + +测试单张图像的端到端识别效果 +```shell +python3 tools/infer_e2e.py -c configs/e2e/e2e_r50_vd_pg.yml -o Global.infer_img="./doc/imgs_en/img_10.jpg" Global.pretrained_model="./output/det_db/best_accuracy" Global.load_static_weights=false +``` + +测试文件夹下所有图像的端到端识别效果 +```shell +python3 tools/infer_e2e.py -c configs/e2e/e2e_r50_vd_pg.yml -o Global.infer_img="./doc/imgs_en/" Global.pretrained_model="./output/det_db/best_accuracy" Global.load_static_weights=false +``` diff --git a/doc/imgs_results/e2e_res_img623_pg.jpg b/doc/imgs_results/e2e_res_img623_pg.jpg new file mode 100644 index 0000000000000000000000000000000000000000..84fca124363353313750984b4cf64ce2c2cad70b Binary files /dev/null and b/doc/imgs_results/e2e_res_img623_pg.jpg differ diff --git a/ppocr/data/__init__.py b/ppocr/data/__init__.py index bcfbf489191b9f3622843faafeadfe2cf9fef803..26a2a8dcfef24f48ab1331ece0b69bea9959f2ea 100644 --- a/ppocr/data/__init__.py +++ b/ppocr/data/__init__.py @@ -34,7 +34,7 @@ import paddle.distributed as dist from ppocr.data.imaug import transform, create_operators from ppocr.data.simple_dataset import SimpleDataSet from ppocr.data.lmdb_dataset import LMDBDataSet -from ppocr.data.pgnet_dataset import PGDateSet +from ppocr.data.pgnet_dataset import PGDataSet __all__ = ['build_dataloader', 'transform', 'create_operators'] @@ -55,8 +55,7 @@ signal.signal(signal.SIGTERM, term_mp) def build_dataloader(config, mode, device, logger, seed=None): config = copy.deepcopy(config) - - support_dict = ['SimpleDataSet', 'LMDBDateSet', 'PGDateSet'] + support_dict = ['SimpleDataSet', 'LMDBDataSet', 'PGDataSet'] module_name = config[mode]['dataset']['name'] assert module_name in support_dict, Exception( 'DataSet only support {}'.format(support_dict)) diff --git a/ppocr/data/imaug/label_ops.py b/ppocr/data/imaug/label_ops.py index 3ae22b40ff0df78c3b4205f51da60d5afd95ad3b..85aa8bb28a1b9c5d945b5d8cfa290975df1d7a48 100644 --- a/ppocr/data/imaug/label_ops.py +++ b/ppocr/data/imaug/label_ops.py @@ -35,9 +35,10 @@ class ClsLabelEncode(object): class E2ELabelEncode(object): - def __init__(self, label_list, **kwargs): - self.label_list = label_list - self.max_len = 50 + def __init__(self, Lexicon_Table, max_len, **kwargs): + self.Lexicon_Table = Lexicon_Table + self.max_len = max_len + self.pad_num = len(self.Lexicon_Table) def __call__(self, data): text_label_index_list, temp_text = [], [] @@ -46,9 +47,10 @@ class E2ELabelEncode(object): text = text.upper() temp_text = [] for c_ in text: - if c_ in self.label_list: - temp_text.append(self.label_list.index(c_)) - temp_text = temp_text + [36] * (self.max_len - len(temp_text)) + if c_ in self.Lexicon_Table: + temp_text.append(self.Lexicon_Table.index(c_)) + temp_text = temp_text + [self.pad_num] * (self.max_len - + len(temp_text)) text_label_index_list.append(temp_text) data['strs'] = np.array(text_label_index_list) return data diff --git a/ppocr/data/imaug/operators.py b/ppocr/data/imaug/operators.py index d4cdad28826c34b885392989aecb771c44e1ea61..9c48b09647527cf718113ea1b5df152ff7befa04 100644 --- a/ppocr/data/imaug/operators.py +++ b/ppocr/data/imaug/operators.py @@ -197,7 +197,6 @@ class DetResizeForTest(object): sys.exit(0) ratio_h = resize_h / float(h) ratio_w = resize_w / float(w) - # return img, np.array([h, w]) return img, [ratio_h, ratio_w] def resize_image_type2(self, img): @@ -206,7 +205,6 @@ class DetResizeForTest(object): resize_w = w resize_h = h - # Fix the longer side if resize_h > resize_w: ratio = float(self.resize_long) / resize_h else: @@ -245,10 +243,8 @@ class E2EResizeForTest(object): return data def resize_image_for_totaltext(self, im, max_side_len=512): - """ - """ - h, w, _ = im.shape + h, w, _ = im.shape resize_w = w resize_h = h ratio = 1.25 diff --git a/ppocr/data/imaug/pg_process.py b/ppocr/data/imaug/pg_process.py index 60abf194fb1d1c7ffeac9ddc6b4e3ad7fbcc26c3..a496ed43b56972c6bf2feff8abc26e620abd643c 100644 --- a/ppocr/data/imaug/pg_process.py +++ b/ppocr/data/imaug/pg_process.py @@ -1,4 +1,4 @@ -# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve. +# copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -15,7 +15,6 @@ import math import cv2 import numpy as np -import os __all__ = ['PGProcessTrain'] @@ -23,15 +22,11 @@ __all__ = ['PGProcessTrain'] class PGProcessTrain(object): def __init__(self, batch_size=14, - data_format='icdar', - tcl_len=64, min_crop_size=24, min_text_size=10, max_text_size=512, **kwargs): self.batch_size = batch_size - self.data_format = data_format - self.tcl_len = tcl_len self.min_crop_size = min_crop_size self.min_text_size = min_text_size self.max_text_size = max_text_size @@ -60,24 +55,22 @@ class PGProcessTrain(object): """ point_num = poly.shape[0] min_area_quad = np.zeros((4, 2), dtype=np.float32) - if True: - rect = cv2.minAreaRect(poly.astype( - np.int32)) # (center (x,y), (width, height), angle of rotation) - center_point = rect[0] - box = np.array(cv2.boxPoints(rect)) - - first_point_idx = 0 - min_dist = 1e4 - for i in range(4): - dist = np.linalg.norm(box[(i + 0) % 4] - poly[0]) + \ - np.linalg.norm(box[(i + 1) % 4] - poly[point_num // 2 - 1]) + \ - np.linalg.norm(box[(i + 2) % 4] - poly[point_num // 2]) + \ - np.linalg.norm(box[(i + 3) % 4] - poly[-1]) - if dist < min_dist: - min_dist = dist - first_point_idx = i - for i in range(4): - min_area_quad[i] = box[(first_point_idx + i) % 4] + rect = cv2.minAreaRect(poly.astype( + np.int32)) # (center (x,y), (width, height), angle of rotation) + box = np.array(cv2.boxPoints(rect)) + + first_point_idx = 0 + min_dist = 1e4 + for i in range(4): + dist = np.linalg.norm(box[(i + 0) % 4] - poly[0]) + \ + np.linalg.norm(box[(i + 1) % 4] - poly[point_num // 2 - 1]) + \ + np.linalg.norm(box[(i + 2) % 4] - poly[point_num // 2]) + \ + np.linalg.norm(box[(i + 3) % 4] - poly[-1]) + if dist < min_dist: + min_dist = dist + first_point_idx = i + for i in range(4): + min_area_quad[i] = box[(first_point_idx + i) % 4] return min_area_quad @@ -235,8 +228,6 @@ class PGProcessTrain(object): ys, xs = np.where(tmp_image > 0) xy_text = np.array(list(zip(xs, ys)), dtype='float32') - # left_center_pt = np.array(key_point_xys[0]).reshape(1, 2) - # right_center_pt = np.array(key_point_xys[-1]).reshape(1, 2) left_center_pt = ( (min_area_quad[0] - min_area_quad[1]) / 2.0).reshape(1, 2) right_center_pt = ( @@ -317,16 +308,6 @@ class PGProcessTrain(object): average_height = max(sum(height_list) / len(height_list), 1.0) return average_height - def encode(self, text): - text_list = [] - for char in text: - if char not in self.dict: - continue - text_list.append([self.dict[char]]) - if len(text_list) == 0: - return None - return text_list - def generate_tcl_ctc_label(self, h, w, @@ -390,8 +371,6 @@ class PGProcessTrain(object): text_label = text_strs[poly_idx] text_label = self.prepare_text_label(text_label, self.Lexicon_Table) - # text = text.decode('utf-8') - # text_label_index_list = self.encode(text) text_label_index_list = [[self.Lexicon_Table.index(c_)] for c_ in text_label @@ -402,22 +381,18 @@ class PGProcessTrain(object): tcl_poly = self.poly2tcl(poly, tcl_ratio) tcl_quads = self.poly2quads(tcl_poly) poly_quads = self.poly2quads(poly) - # stcl map + stcl_quads, quad_index = self.shrink_poly_along_width( tcl_quads, shrink_ratio_of_width=shrink_ratio_of_width, expand_height_ratio=1.0 / tcl_ratio) - # generate tcl map + cv2.fillPoly(score_map, np.round(stcl_quads).astype(np.int32), 1.0) cv2.fillPoly(score_map_big, np.round(stcl_quads / ds_ratio).astype(np.int32), 1.0) - # generate tbo map - # tbo_tcl_poly = poly2tcl(poly, 0.5) - # tbo_tcl_quads = poly2quads(tbo_tcl_poly) - # for idx, quad in enumerate(tbo_tcl_quads): for idx, quad in enumerate(stcl_quads): quad_mask = np.zeros((h, w), dtype=np.float32) quad_mask = cv2.fillPoly( @@ -432,7 +407,6 @@ class PGProcessTrain(object): score_label_map_text_label_list.append(text_pos_list_) label_idx += 1 - # cv2.fillPoly(score_label_map, np.round(poly_quads[np.newaxis, :, :]).astype(np.int32), label_idx) cv2.fillPoly(score_label_map, np.round(poly_quads).astype(np.int32), label_idx) score_label_map_text_label_list.append(text_label_index_list) @@ -641,8 +615,6 @@ class PGProcessTrain(object): d = a1 * b2 - a2 * b1 if d == 0: - # print("line1", line1) - # print("line2", line2) print('Cross point does not exist') return np.array([0, 0], dtype=np.float32) else: diff --git a/ppocr/data/pgnet_dataset.py b/ppocr/data/pgnet_dataset.py index ed970d7ef89215c48c8320f1e8f44a7ba205bb72..2e0989fd29d7d808d8bf43c259dbb68dad0c9294 100644 --- a/ppocr/data/pgnet_dataset.py +++ b/ppocr/data/pgnet_dataset.py @@ -1,4 +1,4 @@ -# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve. +# copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -18,9 +18,9 @@ from .imaug import transform, create_operators import random -class PGDateSet(Dataset): +class PGDataSet(Dataset): def __init__(self, config, mode, logger, seed=None): - super(PGDateSet, self).__init__() + super(PGDataSet, self).__init__() self.logger = logger self.seed = seed @@ -81,7 +81,9 @@ class PGDateSet(Dataset): """ info_list = im_fn.split('\t') img_path = '' - for ext in ['.jpg', '.png', '.jpeg', '.JPG']: + for ext in [ + 'jpg', 'bmp', 'png', 'jpeg', 'rgb', 'tif', 'tiff', 'gif', 'JPG' + ]: if os.path.exists(os.path.join(img_dir, info_list[0] + ext)): img_path = os.path.join(img_dir, info_list[0] + ext) break @@ -111,11 +113,12 @@ class PGDateSet(Dataset): for idx, data_source in enumerate(file_list): image_files = [] if data_format == 'icdar': - image_files = [ - (data_source, x) - for x in os.listdir(os.path.join(data_source, 'rgb')) - if x.split('.')[-1] in ['jpg', 'png', 'jpeg', 'JPG'] - ] + image_files = [(data_source, x) for x in + os.listdir(os.path.join(data_source, 'rgb')) + if x.split('.')[-1] in [ + 'jpg', 'bmp', 'png', 'jpeg', 'rgb', 'tif', + 'tiff', 'gif', 'JPG' + ]] elif data_format == 'textnet': with open(data_source) as f: image_files = [(data_source, x.strip()) diff --git a/ppocr/losses/e2e_pg_loss.py b/ppocr/losses/e2e_pg_loss.py index 05480a9eba8742220e8523c04276b1d6c394fc79..be4614e70e41d2ac7962920844cf30327a8407a3 100644 --- a/ppocr/losses/e2e_pg_loss.py +++ b/ppocr/losses/e2e_pg_loss.py @@ -1,4 +1,4 @@ -# copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve. +# copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -21,20 +21,12 @@ import paddle import numpy as np import copy -from .det_basic_loss import BalanceLoss, MaskL1Loss, DiceLoss +from .det_basic_loss import DiceLoss class PGLoss(nn.Layer): - """ - Differentiable Binarization (DB) Loss Function - args: - param (dict): the super paramter for DB Loss - """ - - def __init__(self, alpha=5, beta=10, eps=1e-6, **kwargs): + def __init__(self, eps=1e-6, **kwargs): super(PGLoss, self).__init__() - self.alpha = alpha - self.beta = beta self.dice_loss = DiceLoss(eps=eps) def org_tcl_rois(self, batch_size, pos_lists, pos_masks, label_lists): @@ -86,27 +78,30 @@ class PGLoss(nn.Layer): return pos_lists_, pos_masks_, label_lists_ def pre_process(self, label_list, pos_list, pos_mask): + max_len = 30 # the max texts in a single image + max_str_len = 50 # the max len in a single text + pad_num = 36 # padding num label_list = label_list.numpy() - b, h, w, c = label_list.shape + batch, _, _, _ = label_list.shape pos_list = pos_list.numpy() pos_mask = pos_mask.numpy() pos_list_t = [] pos_mask_t = [] label_list_t = [] - for i in range(b): - for j in range(30): + for i in range(batch): + for j in range(max_len): if pos_mask[i, j].any(): pos_list_t.append(pos_list[i][j]) pos_mask_t.append(pos_mask[i][j]) label_list_t.append(label_list[i][j]) pos_list, pos_mask, label_list = self.org_tcl_rois( - b, pos_list_t, pos_mask_t, label_list_t) + batch, pos_list_t, pos_mask_t, label_list_t) label = [] tt = [l.tolist() for l in label_list] - for i in range(64): + for i in range(batch): k = 0 - for j in range(50): - if tt[i][j][0] != 36: + for j in range(max_str_len): + if tt[i][j][0] != pad_num: k += 1 else: break diff --git a/ppocr/metrics/e2e_metric.py b/ppocr/metrics/e2e_metric.py index 45248b917ee5a69998d59a439be738463172dea2..c6cd1db94370e966a56fb694265a83e46c5e9ee3 100644 --- a/ppocr/metrics/e2e_metric.py +++ b/ppocr/metrics/e2e_metric.py @@ -1,4 +1,4 @@ -# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -22,12 +22,9 @@ from ppocr.utils.e2e_metric.Deteval import * class E2EMetric(object): - def __init__(self, main_indicator='f_score_e2e', **kwargs): - self.label_list = [ - '0', '1', '2', '3', '4', '5', '6', '7', '8', '9', 'A', 'B', 'C', - 'D', 'E', 'F', 'G', 'H', 'I', 'J', 'K', 'L', 'M', 'N', 'O', 'P', - 'Q', 'R', 'S', 'T', 'U', 'V', 'W', 'X', 'Y', 'Z' - ] + def __init__(self, Lexicon_Table, main_indicator='f_score_e2e', **kwargs): + self.label_list = Lexicon_Table + self.max_index = len(self.label_list) self.main_indicator = main_indicator self.reset() @@ -40,12 +37,12 @@ class E2EMetric(object): for temp_list in temp_gt_strs_batch: t = "" for index in temp_list: - if index < 36: + if index < self.max_index: t += self.label_list[index] gt_strs_batch.append(t) for pred, gt_polyons, gt_strs, ignore_tags in zip( - preds, gt_polyons_batch, gt_strs_batch, ignore_tags_batch): + [preds], gt_polyons_batch, [gt_strs_batch], ignore_tags_batch): # prepare gt gt_info_list = [{ 'points': gt_polyon, @@ -57,7 +54,7 @@ class E2EMetric(object): e2e_info_list = [{ 'points': det_polyon, 'text': pred_str - } for det_polyon, pred_str in zip(pred['points'], preds['strs'])] + } for det_polyon, pred_str in zip(pred['points'], pred['strs'])] result = get_socre(gt_info_list, e2e_info_list) self.results.append(result) diff --git a/ppocr/modeling/backbones/e2e_resnet_vd_pg.py b/ppocr/modeling/backbones/e2e_resnet_vd_pg.py index 8e3697ecfe43fad7b38487d0147c7750f57ca993..97afd3460d03dc078b53064fb45b6fb6d3542df9 100644 --- a/ppocr/modeling/backbones/e2e_resnet_vd_pg.py +++ b/ppocr/modeling/backbones/e2e_resnet_vd_pg.py @@ -1,4 +1,4 @@ -# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve. +# copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -62,8 +62,6 @@ class ConvBNLayer(nn.Layer): moving_variance_name=bn_name + '_variance') def forward(self, inputs): - # if self.is_vd_mode: - # inputs = self._pool2d_avg(inputs) y = self._conv(inputs) y = self._batch_norm(y) return y diff --git a/ppocr/modeling/heads/e2e_pg_head.py b/ppocr/modeling/heads/e2e_pg_head.py index 41ead8e800063b1bf5ddc76fa06836eef7205191..106cdfa689680d8370c9ad2b4d51e0c5a8c74ba7 100644 --- a/ppocr/modeling/heads/e2e_pg_head.py +++ b/ppocr/modeling/heads/e2e_pg_head.py @@ -1,4 +1,4 @@ -# copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve. +# copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -179,7 +179,7 @@ class PGHead(nn.Layer): name="conv_f_char{}".format(5)) self.conv3 = nn.Conv2D( in_channels=256, - out_channels=6625, + out_channels=37, kernel_size=3, stride=1, padding=1, diff --git a/ppocr/modeling/necks/pg_fpn.py b/ppocr/modeling/necks/pg_fpn.py index ba14c1b2ca0450e67b0c97917718a4870362690d..3f64539f790b55bb1f95adc8d3c78b84ca2fccc5 100644 --- a/ppocr/modeling/necks/pg_fpn.py +++ b/ppocr/modeling/necks/pg_fpn.py @@ -1,4 +1,4 @@ -# copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve. +# copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -60,8 +60,6 @@ class ConvBNLayer(nn.Layer): use_global_stats=False) def forward(self, inputs): - # if self.is_vd_mode: - # inputs = self._pool2d_avg(inputs) y = self._conv(inputs) y = self._batch_norm(y) return y @@ -112,7 +110,6 @@ class PGFPN(nn.Layer): num_inputs = [2048, 2048, 1024, 512, 256] num_outputs = [256, 256, 192, 192, 128] self.out_channels = 128 - # print(in_channels) self.conv_bn_layer_1 = ConvBNLayer( in_channels=3, out_channels=32, diff --git a/ppocr/postprocess/pg_postprocess.py b/ppocr/postprocess/pg_postprocess.py index 1b340b42fb1ac97ba2365021c901c50c618f4ca6..1f1ab60e0df044a9f731bbdb3a87ff89da5bdd99 100644 --- a/ppocr/postprocess/pg_postprocess.py +++ b/ppocr/postprocess/pg_postprocess.py @@ -1,4 +1,4 @@ -# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -23,14 +23,9 @@ __dir__ = os.path.dirname(__file__) sys.path.append(__dir__) sys.path.append(os.path.join(__dir__, '..')) -import numpy as np -from .locality_aware_nms import nms_locality from ppocr.utils.e2e_utils.extract_textpoint import * -from ppocr.utils.e2e_utils.ski_thin import * from ppocr.utils.e2e_utils.visual import * import paddle -import cv2 -import time class PGPostProcess(object): @@ -115,7 +110,6 @@ class PGPostProcess(object): if len(yx_center_line) == 1: yx_center_line.append(yx_center_line[-1]) - # expand corresponding offset for total-text. offset_expand = 1.0 if self.valid_set == 'totaltext': offset_expand = 1.2 @@ -137,7 +131,6 @@ class PGPostProcess(object): [ratio_w, ratio_h]).reshape(-1, 2) point_pair_list.append(point_pair) - # for visualization all_point_list.append([ int(round(x * 4.0 / ratio_w)), int(round(y * 4.0 / ratio_h)) @@ -145,7 +138,6 @@ class PGPostProcess(object): all_point_pair_list.append(point_pair.round().astype(np.int32) .tolist()) - # ndarry: (x, 2) detected_poly, pair_length_info = point_pair2poly(point_pair_list) detected_poly = expand_poly_along_width( detected_poly, shrink_ratio_of_width=0.2) diff --git a/ppocr/utils/e2e_metric/Deteval.py b/ppocr/utils/e2e_metric/Deteval.py index 8337e53934710e9d774d1c38afecbd9bc2c650aa..37fa5c00b4fa69bf8217b4840478f6c51752f673 100755 --- a/ppocr/utils/e2e_metric/Deteval.py +++ b/ppocr/utils/e2e_metric/Deteval.py @@ -1,4 +1,4 @@ -# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -11,42 +11,27 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + import numpy as np from ppocr.utils.e2e_metric.polygon_fast import iod, area_of_intersection, area -try: # python2 - range = xrange -except Exception: - # python3 - range = range -""" -Input format: y0,x0, ..... yn,xn. Each detection is separated by the end of line token ('\n')' -""" - -# if len(sys.argv) != 4: -# print('\n usage: test.py pred_dir gt_dir savefile') -# sys.exit() - def get_socre(gt_dict, pred_dict): - # allInputs = listdir(input_dir) allInputs = 1 - def input_reading_mod(pred_dict, input): + def input_reading_mod(pred_dict): """This helper reads input from txt files""" det = [] n = len(pred_dict) for i in range(n): points = pred_dict[i]['points'] text = pred_dict[i]['text'] - # for i in range(len(points)): point = ",".join(map(str, points.reshape(-1, ))) det.append([point, text]) return det - def gt_reading_mod(gt_dict, gt_id): + def gt_reading_mod(gt_dict): """This helper reads groundtruths from mat files""" - # gt_id = gt_id.split('.')[0] gt = [] n = len(gt_dict) for i in range(n): @@ -74,23 +59,12 @@ def get_socre(gt_dict, pred_dict): def detection_filtering(detections, groundtruths, threshold=0.5): for gt_id, gt in enumerate(groundtruths): - print - "liushanshan gt[1] = {}".format(gt[1]) - print - "liushanshan gt[2] = {}".format(gt[2]) - print - "liushanshan gt[3] = {}".format(gt[3]) - print - "liushanshan gt[4] = {}".format(gt[4]) - print - "liushanshan gt[5] = {}".format(gt[5]) if (gt[5] == '#') and (gt[1].shape[1] > 1): gt_x = list(map(int, np.squeeze(gt[1]))) gt_y = list(map(int, np.squeeze(gt[3]))) for det_id, detection in enumerate(detections): detection_orig = detection detection = [float(x) for x in detection[0].split(',')] - # detection = detection.split(',') detection = list(map(int, detection)) det_x = detection[0::2] det_y = detection[1::2] @@ -105,18 +79,10 @@ def get_socre(gt_dict, pred_dict): """ sigma = inter_area / gt_area """ - # print(area_of_intersection(det_x, det_y, gt_x, gt_y)) return np.round((area_of_intersection(det_x, det_y, gt_x, gt_y) / area(gt_x, gt_y)), 2) def tau_calculation(det_x, det_y, gt_x, gt_y): - """ - tau = inter_area / det_area - """ - # print "liushanshan det_x {}".format(det_x) - # print "liushanshan det_y {}".format(det_y) - # print "liushanshan area {}".format(area(det_x, det_y)) - # print "liushanshan tau = {}".format(np.round((area_of_intersection(det_x, det_y, gt_x, gt_y) / area(det_x, det_y)), 2)) if area(det_x, det_y) == 0.0: return 0 return np.round((area_of_intersection(det_x, det_y, gt_x, gt_y) / @@ -141,10 +107,8 @@ def get_socre(gt_dict, pred_dict): input_id != 'Pascal_result_curved.txt') and (input_id != 'Pascal_result_non_curved.txt') and ( input_id != 'Deteval_result.txt') and (input_id != 'Deteval_result_curved.txt') \ and (input_id != 'Deteval_result_non_curved.txt'): - print(input_id) - detections = input_reading_mod(pred_dict, input_id) - # print "liushanshan detections = {}".format(detections) - groundtruths = gt_reading_mod(gt_dict, input_id) + detections = input_reading_mod(pred_dict) + groundtruths = gt_reading_mod(gt_dict) detections = detection_filtering( detections, groundtruths) # filters detections overlapping with DC area @@ -187,10 +151,6 @@ def get_socre(gt_dict, pred_dict): global_tau.append(local_tau_table) global_pred_str.append(local_pred_str) global_gt_str.append(local_gt_str) - print - "liushanshan global_pred_str = {}".format(global_pred_str) - print - "liushanshan global_gt_str = {}".format(global_gt_str) global_accumulative_recall = 0 global_accumulative_precision = 0 @@ -236,17 +196,11 @@ def get_socre(gt_dict, pred_dict): gt_flag[0, gt_id] = 1 matched_det_id = np.where(local_sigma_table[gt_id, :] > tr) # recg start - print - "liushanshan one to one det_id = {}".format(matched_det_id) - print - "liushanshan one to one gt_id = {}".format(gt_id) + gt_str_cur = global_gt_str[idy][gt_id] pred_str_cur = global_pred_str[idy][matched_det_id[0].tolist()[ 0]] - print - "liushanshan one to one gt_str_cur = {}".format(gt_str_cur) - print - "liushanshan one to one pred_str_cur = {}".format(pred_str_cur) + if pred_str_cur == gt_str_cur: hit_str_num += 1 else: @@ -290,20 +244,10 @@ def get_socre(gt_dict, pred_dict): gt_flag[0, gt_id] = 1 det_flag[0, qualified_tau_candidates] = 1 # recg start - print - "liushanshan one to many det_id = {}".format( - qualified_tau_candidates) - print - "liushanshan one to many gt_id = {}".format(gt_id) gt_str_cur = global_gt_str[idy][gt_id] pred_str_cur = global_pred_str[idy][ qualified_tau_candidates[0].tolist()[0]] - print - "liushanshan one to many gt_str_cur = {}".format( - gt_str_cur) - print - "liushanshan one to many pred_str_cur = {}".format( - pred_str_cur) + if pred_str_cur == gt_str_cur: hit_str_num += 1 else: @@ -315,19 +259,11 @@ def get_socre(gt_dict, pred_dict): gt_flag[0, gt_id] = 1 det_flag[0, qualified_tau_candidates] = 1 # recg start - print - "liushanshan one to many det_id = {}".format( - qualified_tau_candidates) - print - "liushanshan one to many gt_id = {}".format(gt_id) + gt_str_cur = global_gt_str[idy][gt_id] pred_str_cur = global_pred_str[idy][ qualified_tau_candidates[0].tolist()[0]] - print - "liushanshan one to many gt_str_cur = {}".format(gt_str_cur) - print - "liushanshan one to many pred_str_cur = {}".format( - pred_str_cur) + if pred_str_cur == gt_str_cur: hit_str_num += 1 else: @@ -377,25 +313,14 @@ def get_socre(gt_dict, pred_dict): gt_flag[0, qualified_sigma_candidates] = 1 det_flag[0, det_id] = 1 # recg start - print - "liushanshan many to one det_id = {}".format(det_id) - print - "liushanshan many to one gt_id = {}".format( - qualified_sigma_candidates) pred_str_cur = global_pred_str[idy][det_id] gt_len = len(qualified_sigma_candidates[0]) for idx in range(gt_len): ele_gt_id = qualified_sigma_candidates[0].tolist()[ idx] - if not global_gt_str[idy].has_key(ele_gt_id): + if ele_gt_id not in global_gt_str[idy]: continue gt_str_cur = global_gt_str[idy][ele_gt_id] - print - "liushanshan many to one gt_str_cur = {}".format( - gt_str_cur) - print - "liushanshan many to one pred_str_cur = {}".format( - pred_str_cur) if pred_str_cur == gt_str_cur: hit_str_num += 1 break @@ -409,24 +334,14 @@ def get_socre(gt_dict, pred_dict): det_flag[0, det_id] = 1 gt_flag[0, qualified_sigma_candidates] = 1 # recg start - print - "liushanshan many to one det_id = {}".format(det_id) - print - "liushanshan many to one gt_id = {}".format( - qualified_sigma_candidates) + pred_str_cur = global_pred_str[idy][det_id] gt_len = len(qualified_sigma_candidates[0]) for idx in range(gt_len): ele_gt_id = qualified_sigma_candidates[0].tolist()[idx] - if not global_gt_str[idy].has_key(ele_gt_id): + if ele_gt_id not in global_gt_str[idy]: continue gt_str_cur = global_gt_str[idy][ele_gt_id] - print - "liushanshan many to one gt_str_cur = {}".format( - gt_str_cur) - print - "liushanshan many to one pred_str_cur = {}".format( - pred_str_cur) if pred_str_cur == gt_str_cur: hit_str_num += 1 break @@ -434,9 +349,6 @@ def get_socre(gt_dict, pred_dict): if pred_str_cur.lower() == gt_str_cur.lower(): hit_str_num += 1 break - else: - print - 'no match' # recg end global_accumulative_recall = global_accumulative_recall + num_qualified_sigma_candidates * fsc_k @@ -448,7 +360,6 @@ def get_socre(gt_dict, pred_dict): single_data = {} for idx in range(len(global_sigma)): - # print(allInputs[idx]) local_sigma_table = global_sigma[idx] local_tau_table = global_tau[idx] @@ -504,8 +415,6 @@ def get_socre(gt_dict, pred_dict): except ZeroDivisionError: local_f_score = 0 - # temp = ('%s: Recall=%.4f, Precision=%.4f, f_score=%.4f\n' % ( - # allInputs[idx], local_recall, local_precision, local_f_score)) single_data['sigma'] = global_sigma single_data['global_tau'] = global_tau single_data['global_pred_str'] = global_pred_str @@ -575,17 +484,9 @@ def combine_results(all_data): gt_flag[0, gt_id] = 1 matched_det_id = np.where(local_sigma_table[gt_id, :] > tr) # recg start - print - "liushanshan one to one det_id = {}".format(matched_det_id) - print - "liushanshan one to one gt_id = {}".format(gt_id) gt_str_cur = global_gt_str[idy][gt_id] pred_str_cur = global_pred_str[idy][matched_det_id[0].tolist()[ 0]] - print - "liushanshan one to one gt_str_cur = {}".format(gt_str_cur) - print - "liushanshan one to one pred_str_cur = {}".format(pred_str_cur) if pred_str_cur == gt_str_cur: hit_str_num += 1 else: @@ -629,20 +530,9 @@ def combine_results(all_data): gt_flag[0, gt_id] = 1 det_flag[0, qualified_tau_candidates] = 1 # recg start - print - "liushanshan one to many det_id = {}".format( - qualified_tau_candidates) - print - "liushanshan one to many gt_id = {}".format(gt_id) gt_str_cur = global_gt_str[idy][gt_id] pred_str_cur = global_pred_str[idy][ qualified_tau_candidates[0].tolist()[0]] - print - "liushanshan one to many gt_str_cur = {}".format( - gt_str_cur) - print - "liushanshan one to many pred_str_cur = {}".format( - pred_str_cur) if pred_str_cur == gt_str_cur: hit_str_num += 1 else: @@ -654,19 +544,9 @@ def combine_results(all_data): gt_flag[0, gt_id] = 1 det_flag[0, qualified_tau_candidates] = 1 # recg start - print - "liushanshan one to many det_id = {}".format( - qualified_tau_candidates) - print - "liushanshan one to many gt_id = {}".format(gt_id) gt_str_cur = global_gt_str[idy][gt_id] pred_str_cur = global_pred_str[idy][ qualified_tau_candidates[0].tolist()[0]] - print - "liushanshan one to many gt_str_cur = {}".format(gt_str_cur) - print - "liushanshan one to many pred_str_cur = {}".format( - pred_str_cur) if pred_str_cur == gt_str_cur: hit_str_num += 1 else: @@ -716,11 +596,6 @@ def combine_results(all_data): gt_flag[0, qualified_sigma_candidates] = 1 det_flag[0, det_id] = 1 # recg start - print - "liushanshan many to one det_id = {}".format(det_id) - print - "liushanshan many to one gt_id = {}".format( - qualified_sigma_candidates) pred_str_cur = global_pred_str[idy][det_id] gt_len = len(qualified_sigma_candidates[0]) for idx in range(gt_len): @@ -729,12 +604,6 @@ def combine_results(all_data): if ele_gt_id not in global_gt_str[idy]: continue gt_str_cur = global_gt_str[idy][ele_gt_id] - print - "liushanshan many to one gt_str_cur = {}".format( - gt_str_cur) - print - "liushanshan many to one pred_str_cur = {}".format( - pred_str_cur) if pred_str_cur == gt_str_cur: hit_str_num += 1 break @@ -748,24 +617,13 @@ def combine_results(all_data): det_flag[0, det_id] = 1 gt_flag[0, qualified_sigma_candidates] = 1 # recg start - print - "liushanshan many to one det_id = {}".format(det_id) - print - "liushanshan many to one gt_id = {}".format( - qualified_sigma_candidates) pred_str_cur = global_pred_str[idy][det_id] gt_len = len(qualified_sigma_candidates[0]) for idx in range(gt_len): ele_gt_id = qualified_sigma_candidates[0].tolist()[idx] - if not global_gt_str[idy].has_key(ele_gt_id): + if ele_gt_id not in global_gt_str[idy]: continue gt_str_cur = global_gt_str[idy][ele_gt_id] - print - "liushanshan many to one gt_str_cur = {}".format( - gt_str_cur) - print - "liushanshan many to one pred_str_cur = {}".format( - pred_str_cur) if pred_str_cur == gt_str_cur: hit_str_num += 1 break @@ -773,9 +631,6 @@ def combine_results(all_data): if pred_str_cur.lower() == gt_str_cur.lower(): hit_str_num += 1 break - else: - print - 'no match' # recg end global_accumulative_recall = global_accumulative_recall + num_qualified_sigma_candidates * fsc_k diff --git a/ppocr/utils/e2e_utils/extract_textpoint.py b/ppocr/utils/e2e_utils/extract_textpoint.py index 1665c7ef6009319fe9d1a1313ded7063c1143926..2d793aa98ebef3835c83efa190ff9dee204771f4 100644 --- a/ppocr/utils/e2e_utils/extract_textpoint.py +++ b/ppocr/utils/e2e_utils/extract_textpoint.py @@ -1,4 +1,4 @@ -# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -16,14 +16,12 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -import os import cv2 -import time import math import numpy as np from itertools import groupby -from ppocr.utils.e2e_utils.ski_thin import thin +from skimage.morphology._skeletonize import thin def softmax(logits): @@ -518,28 +516,6 @@ def generate_pivot_list_tt_inference(p_score, continue pos_list_sorted = sort_and_expand_with_direction_v2( pos_list, f_direction, p_tcl_map) - # pos_list_sorted, _ = sort_with_direction(pos_list, f_direction) pos_list_sorted_with_id = add_id(pos_list_sorted, image_id=image_id) all_pos_yxs.append(pos_list_sorted_with_id) return all_pos_yxs - - -if __name__ == '__main__': - np.random.seed(0) - import time - - logits_map = np.random.random([10, 20, 33]) - # a list of [x, y] - instance_gather_info_1 = [(2, 3), (2, 4), (3, 5)] - instance_gather_info_2 = [(15, 6), (15, 7), (18, 8)] - instance_gather_info_3 = [(8, 8), (8, 8), (8, 8)] - gather_info_list = [ - instance_gather_info_1, instance_gather_info_2, instance_gather_info_3 - ] - - time0 = time.time() - res = ctc_decoder_for_image( - gather_info_list, logits_map, keep_blank_in_idxs=True) - print(res) - print('cost {}'.format(time.time() - time0)) - print('--' * 20) diff --git a/ppocr/utils/e2e_utils/ski_thin.py b/ppocr/utils/e2e_utils/ski_thin.py deleted file mode 100644 index 6b1e5c78f2bb51da1e65ff8b258f90932e25b651..0000000000000000000000000000000000000000 --- a/ppocr/utils/e2e_utils/ski_thin.py +++ /dev/null @@ -1,136 +0,0 @@ -# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import numpy as np -from scipy import ndimage as ndi - -G123_LUT = np.array( - [ - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1, 0, 1, 0, - 0, 0, 0, 0, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 1, 0, 1, 0, 1, 0, 0, 0, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 1, 0, 1, 0, 1, 0, 0, 0, 1, 0, 1, 0, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 1, 0, 1, 0, - 1, 0, 0, 0, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 1, 0, 0, 0, - 1, 1, 0, 0, 1, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 1, 1, 0, 0, 1, 0, 0, 0, 1, 1, 0, 0, 1, 0, 0, 0 - ], - dtype=np.bool) - -G123P_LUT = np.array( - [ - 0, 0, 0, 0, 0, 1, 0, 1, 0, 0, 0, 0, 0, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 1, 0, 1, - 0, 0, 0, 0, 0, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 1, 0, 1, 0, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 1, 0, 0, 0, 0, 0, 1, 0, 1, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 1, 0, 1, 0, 1, 0, 1, 0, 0, 0, 0, 0, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 1, 0, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 - ], - dtype=np.bool) - - -def thin(image, max_iter=None): - """ - Perform morphological thinning of a binary image. - Parameters - ---------- - image : binary (M, N) ndarray - The image to be thinned. - max_iter : int, number of iterations, optional - Regardless of the value of this parameter, the thinned image - is returned immediately if an iteration produces no change. - If this parameter is specified it thus sets an upper bound on - the number of iterations performed. - Returns - ------- - out : ndarray of bool - Thinned image. - See also - -------- - skeletonize, medial_axis - Notes - ----- - This algorithm [1]_ works by making multiple passes over the image, - removing pixels matching a set of criteria designed to thin - connected regions while preserving eight-connected components and - 2 x 2 squares [2]_. In each of the two sub-iterations the algorithm - correlates the intermediate skeleton image with a neighborhood mask, - then looks up each neighborhood in a lookup table indicating whether - the central pixel should be deleted in that sub-iteration. - References - ---------- - .. [1] Z. Guo and R. W. Hall, "Parallel thinning with - two-subiteration algorithms," Comm. ACM, vol. 32, no. 3, - pp. 359-373, 1989. :DOI:`10.1145/62065.62074` - .. [2] Lam, L., Seong-Whan Lee, and Ching Y. Suen, "Thinning - Methodologies-A Comprehensive Survey," IEEE Transactions on - Pattern Analysis and Machine Intelligence, Vol 14, No. 9, - p. 879, 1992. :DOI:`10.1109/34.161346` - Examples - -------- - >>> square = np.zeros((7, 7), dtype=np.uint8) - >>> square[1:-1, 2:-2] = 1 - >>> square[0, 1] = 1 - >>> square - array([[0, 1, 0, 0, 0, 0, 0], - [0, 0, 1, 1, 1, 0, 0], - [0, 0, 1, 1, 1, 0, 0], - [0, 0, 1, 1, 1, 0, 0], - [0, 0, 1, 1, 1, 0, 0], - [0, 0, 1, 1, 1, 0, 0], - [0, 0, 0, 0, 0, 0, 0]], dtype=uint8) - >>> skel = thin(square) - >>> skel.astype(np.uint8) - array([[0, 1, 0, 0, 0, 0, 0], - [0, 0, 1, 0, 0, 0, 0], - [0, 0, 0, 1, 0, 0, 0], - [0, 0, 0, 1, 0, 0, 0], - [0, 0, 0, 1, 0, 0, 0], - [0, 0, 0, 0, 0, 0, 0], - [0, 0, 0, 0, 0, 0, 0]], dtype=uint8) - """ - # convert image to uint8 with values in {0, 1} - skel = np.asanyarray(image, dtype=bool).astype(np.uint8) - - # neighborhood mask - mask = np.array([[8, 4, 2], [16, 0, 1], [32, 64, 128]], dtype=np.uint8) - - # iterate until convergence, up to the iteration limit - max_iter = max_iter or np.inf - n_iter = 0 - n_pts_old, n_pts_new = np.inf, np.sum(skel) - while n_pts_old != n_pts_new and n_iter < max_iter: - n_pts_old = n_pts_new - - # perform the two "subiterations" described in the paper - for lut in [G123_LUT, G123P_LUT]: - # correlate image with neighborhood mask - N = ndi.correlate(skel, mask, mode='constant') - # take deletion decision from this subiteration's LUT - D = np.take(lut, N) - # perform deletion - skel[D] = 0 - - n_pts_new = np.sum(skel) # count points after thinning - n_iter += 1 - - return skel.astype(np.bool) diff --git a/ppocr/utils/e2e_utils/visual.py b/ppocr/utils/e2e_utils/visual.py index 6be2107f0ad8bc57590dc68ff21cbf481dc8f2a0..6f8a429ef0fd85e14413bc1429e13a6ed81fc5f4 100644 --- a/ppocr/utils/e2e_utils/visual.py +++ b/ppocr/utils/e2e_utils/visual.py @@ -1,4 +1,4 @@ -# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -50,7 +50,7 @@ def resize_image(im, max_side_len=512): def resize_image_min(im, max_side_len=512): """ """ - print('--> Using resize_image_min') + # print('--> Using resize_image_min') h, w, _ = im.shape resize_w = w diff --git a/tools/infer_e2e.py b/tools/infer_e2e.py index 9dcde8a957346412d6296cb492bfa93feb8942c6..c40b8e02341afd8d1204aa5335bb7b0963e5899a 100755 --- a/tools/infer_e2e.py +++ b/tools/infer_e2e.py @@ -45,8 +45,14 @@ def draw_e2e_res(dt_boxes, strs, config, img, img_name): for box, str in zip(dt_boxes, strs): box = box.astype(np.int32).reshape((-1, 1, 2)) cv2.polylines(src_im, [box], True, color=(255, 255, 0), thickness=2) - cv2.putText(src_im, str, org=(int(box[0, 0, 0]), int(box[0, 0, 1])), - fontFace=cv2.FONT_HERSHEY_COMPLEX, fontScale=0.7, color=(0, 255, 0), thickness=1) + cv2.putText( + src_im, + str, + org=(int(box[0, 0, 0]), int(box[0, 0, 1])), + fontFace=cv2.FONT_HERSHEY_COMPLEX, + fontScale=0.7, + color=(0, 255, 0), + thickness=1) save_det_path = os.path.dirname(config['Global'][ 'save_res_path']) + "/e2e_results/" if not os.path.exists(save_det_path): @@ -55,6 +61,7 @@ def draw_e2e_res(dt_boxes, strs, config, img, img_name): cv2.imwrite(save_path, src_im) logger.info("The e2e Image saved in {}".format(save_path)) + def main(): global_config = config['Global'] @@ -111,4 +118,4 @@ def main(): if __name__ == '__main__': config, device, logger, vdl_writer = program.preprocess() - main() \ No newline at end of file + main() diff --git a/tools/program.py b/tools/program.py index f0e2aa0812fe67b122fba2811e0588724d386e34..28e541db99e33876c5dd330b4448ec84d2e236a7 100755 --- a/tools/program.py +++ b/tools/program.py @@ -375,7 +375,7 @@ def preprocess(is_train=False): alg = config['Architecture']['algorithm'] assert alg in [ 'EAST', 'DB', 'SAST', 'Rosetta', 'CRNN', 'STARNet', 'RARE', 'SRN', - 'CLS', 'PG' + 'CLS', 'PGNet' ] device = 'gpu:{}'.format(dist.ParallelEnv().dev_id) if use_gpu else 'cpu'