diff --git a/configs/e2e/e2e_r50_vd_pg.yml b/configs/e2e/e2e_r50_vd_pg.yml index 22548a3c98a8dbe0fb07a9a0e1b721dc1ea1298c..08c485a638759e6436bd1613fff81fa14c8a6db8 100644 --- a/configs/e2e/e2e_r50_vd_pg.yml +++ b/configs/e2e/e2e_r50_vd_pg.yml @@ -3,7 +3,7 @@ Global: epoch_num: 600 log_smooth_window: 20 print_batch_step: 10 - save_model_dir: ./output/pg_r50_vd_tt/ + save_model_dir: ./output/pgnet_r50_vd_totaltext/ save_epoch_step: 10 # evaluation is run every 0 iterationss after the 1000th iteration eval_batch_step: [ 0, 1000 ] @@ -18,7 +18,11 @@ Global: save_inference_dir: use_visualdl: False infer_img: - save_res_path: ./output/pg_r50_vd_tt/predicts_pg.txt + valid_set: totaltext #two mode: totaltext valid curved words, partvgg valid non-curved words + save_res_path: ./output/pgnet_r50_vd_totaltext/predicts_pgnet.txt + character_dict_path: ppocr/utils/pgnet_dict.txt + character_type: EN + max_text_length: 50 Architecture: model_type: e2e @@ -51,30 +55,26 @@ Optimizer: PostProcess: name: PGPostProcess score_thresh: 0.8 - cover_thresh: 0.1 - nms_thresh: 0.2 - Metric: name: E2EMetric - Lexicon_Table: [ '0', '1', '2', '3', '4', '5', '6', '7', '8', '9', 'A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'J', 'K', 'L', 'M', 'N', 'O', 'P', 'Q', 'R', 'S', 'T', 'U', 'V', 'W', 'X', 'Y', 'Z' ] + character_dict_path: ppocr/utils/pgnet_dict.txt main_indicator: f_score_e2e Train: dataset: - name: PGDateSet - label_file_list: [./train_data/total_text/train/] + name: PGDataSet + label_file_list: [.././train_data/total_text/train/] ratio_list: [1.0] - data_format: icdar + data_format: icdar #two data format: icdar/textnet transforms: - DecodeImage: # load image img_mode: BGR channel_first: False - PGProcessTrain: - batch_size: 14 + batch_size: 14 # same as loader: batch_size_per_card min_crop_size: 24 min_text_size: 4 max_text_size: 512 - Lexicon_Table: [ '0', '1', '2', '3', '4', '5', '6', '7', '8', '9', 'A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'J', 'K', 'L', 'M', 'N', 'O', 'P', 'Q', 'R', 'S', 'T', 'U', 'V', 'W', 'X', 'Y', 'Z' ] - KeepKeys: keep_keys: [ 'images', 'tcl_maps', 'tcl_label_maps', 'border_maps','direction_maps', 'training_masks', 'label_list', 'pos_list', 'pos_mask' ] # dataloader will return list in this order loader: @@ -93,10 +93,7 @@ Eval: img_mode: BGR channel_first: False - E2ELabelEncode: - Lexicon_Table: [ '0', '1', '2', '3', '4', '5', '6', '7', '8', '9', 'A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'J', 'K', 'L', 'M', 'N', 'O', 'P', 'Q', 'R', 'S', 'T', 'U', 'V', 'W', 'X', 'Y', 'Z' ] - max_len: 50 - E2EResizeForTest: - valid_set: totaltext max_side_len: 768 - NormalizeImage: scale: 1./255. diff --git a/doc/doc_ch/inference.md b/doc/doc_ch/inference.md index 7968b355ea936d465b3c173c0fcdb3e08f12f16e..40f4d8c5119fb4be72573dd6a1f99ca59aeaf7aa 100755 --- a/doc/doc_ch/inference.md +++ b/doc/doc_ch/inference.md @@ -12,7 +12,8 @@ inference 模型(`paddle.jit.save`保存的模型) - [一、训练模型转inference模型](#训练模型转inference模型) - [检测模型转inference模型](#检测模型转inference模型) - [识别模型转inference模型](#识别模型转inference模型) - - [方向分类模型转inference模型](#方向分类模型转inference模型) + - [方向分类模型转inference模型](#方向分类模型转inference模型) + - [端到端模型转inference模型](#端到端模型转inference模型) - [二、文本检测模型推理](#文本检测模型推理) - [1. 超轻量中文检测模型推理](#超轻量中文检测模型推理) @@ -27,10 +28,13 @@ inference 模型(`paddle.jit.save`保存的模型) - [4. 自定义文本识别字典的推理](#自定义文本识别字典的推理) - [5. 多语言模型的推理](#多语言模型的推理) -- [四、方向分类模型推理](#方向识别模型推理) +- [四、端到端模型推理](#端到端模型推理) + - [1. PGNet端到端模型推理](#SAST文本检测模型推理) + +- [五、方向分类模型推理](#方向识别模型推理) - [1. 方向分类模型推理](#方向分类模型推理) -- [五、文本检测、方向分类和文字识别串联推理](#文本检测、方向分类和文字识别串联推理) +- [六、文本检测、方向分类和文字识别串联推理](#文本检测、方向分类和文字识别串联推理) - [1. 超轻量中文OCR模型推理](#超轻量中文OCR模型推理) - [2. 其他模型推理](#其他模型推理) @@ -118,6 +122,32 @@ python3 tools/export_model.py -c configs/cls/cls_mv3.yml -o Global.pretrained_mo ├── inference.pdiparams.info # 分类inference模型的参数信息,可忽略 └── inference.pdmodel # 分类inference模型的program文件 ``` + +### 端到端模型转inference模型 + +下载端到端模型: +``` +wget -P ./ch_lite/ https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_mobile_v2.0_cls_train.tar && tar xf ./ch_lite/ch_ppocr_mobile_v2.0_cls_train.tar -C ./ch_lite/ +``` + +端到端模型转inference模型与检测的方式相同,如下: +``` +# -c 后面设置训练算法的yml配置文件 +# -o 配置可选参数 +# Global.pretrained_model 参数设置待转换的训练模型地址,不用添加文件后缀 .pdmodel,.pdopt或.pdparams。 +# Global.load_static_weights 参数需要设置为 False。 +# Global.save_inference_dir参数设置转换的模型将保存的地址。 + +python3 tools/export_model.py -c configs/e2e/e2e_r50_vd_pg.yml -o Global.pretrained_model=./ch_lite/ch_ppocr_mobile_v2.0_cls_train/best_accuracy Global.load_static_weights=False Global.save_inference_dir=./inference/e2e/ +``` + +转换成功后,在目录下有三个文件: +``` +/inference/e2e/ + ├── inference.pdiparams # 分类inference模型的参数文件 + ├── inference.pdiparams.info # 分类inference模型的参数信息,可忽略 + └── inference.pdmodel # 分类inference模型的program文件 +``` ## 二、文本检测模型推理 @@ -332,8 +362,45 @@ python3 tools/infer/predict_rec.py --image_dir="./doc/imgs_words/korean/1.jpg" - Predicts of ./doc/imgs_words/korean/1.jpg:('바탕으로', 0.9948904) ``` + +## 四、端到端模型推理 + +端到端模型推理,默认使用PGNet模型的配置参数。当不使用PGNet模型时,在推理时,需要通过传入相应的参数进行算法适配,细节参考下文。 + +### 1. PGNet端到端模型推理 +#### (1). 四边形文本检测模型(ICDAR2015) +首先将PGNet端到端训练过程中保存的模型,转换成inference model。以基于Resnet50_vd骨干网络,在ICDAR2015英文数据集训练的模型为例([模型下载地址](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/det_r50_vd_sast_icdar15_v2.0_train.tar)),可以使用如下命令进行转换: +``` +python3 tools/export_model.py -c configs/e2e/e2e_r50_vd_pg.yml -o Global.pretrained_model=./det_r50_vd_sast_icdar15_v2.0_train/best_accuracy Global.load_static_weights=False Global.save_inference_dir=./inference/e2e +``` +**PGNet端到端模型推理,需要设置参数`--e2e_algorithm="PGNet"`**,可以执行如下命令: +``` +python3 tools/infer/predict_e2e.py --e2e_algorithm="PGNet" --image_dir="./doc/imgs_en/img_10.jpg" --e2e_model_dir="./inference/e2e_pgnet_ic15/" +``` +可视化文本检测结果默认保存到`./inference_results`文件夹里面,结果文件的名称前缀为'e2e_res'。结果示例如下: + +![](../imgs_results/det_res_img_10_sast.jpg) + +#### (2). 弯曲文本检测模型(Total-Text) +首先将PGNet端到端训练过程中保存的模型,转换成inference model。以基于Resnet50_vd骨干网络,在Total-Text英文数据集训练的模型为例([模型下载地址](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/det_r50_vd_sast_totaltext_v2.0_train.tar)),可以使用如下命令进行转换: + +``` +python3 tools/export_model.py -c configs/e2e/e2e_r50_vd_pg.yml -o Global.pretrained_model=./det_r50_vd_sast_totaltext_v2.0_train/best_accuracy Global.load_static_weights=False Global.save_inference_dir=./inference/e2e_pgnet_tt +``` + +**PGNet端到端模型推理,需要设置参数`--e2e_algorithm="PGNet"`,同时,还需要增加参数`--e2e_pgnet_polygon=True`,**可以执行如下命令: +``` +python3 tools/infer/predict_e2e.py --e2e_algorithm="PGNet" --image_dir="./doc/imgs_en/img623.jpg" --e2e_model_dir="./inference/e2e_pgnet_tt/" --e2e_pgnet_polygon=True +``` +可视化文本端到端结果默认保存到`./inference_results`文件夹里面,结果文件的名称前缀为'e2e_res'。结果示例如下: + +![](../imgs_results/e2e_res_img623_pg.jpg) + +**注意**:本代码库中,SAST后处理Locality-Aware NMS有python和c++两种版本,c++版速度明显快于python版。由于c++版本nms编译版本问题,只有python3.5环境下会调用c++版nms,其他情况将调用python版nms。 + + -## 四、方向分类模型推理 +## 五、方向分类模型推理 下面将介绍方向分类模型推理。 @@ -358,7 +425,7 @@ Predicts of ./doc/imgs_words/ch/word_4.jpg:['0', 0.9999982] ``` -## 五、文本检测、方向分类和文字识别串联推理 +## 六、文本检测、方向分类和文字识别串联推理 ### 1. 超轻量中文OCR模型推理 diff --git a/ppocr/data/__init__.py b/ppocr/data/__init__.py index 26a2a8dcfef24f48ab1331ece0b69bea9959f2ea..728b8317f54687ee76b519cba18f4d7807493821 100644 --- a/ppocr/data/__init__.py +++ b/ppocr/data/__init__.py @@ -73,14 +73,14 @@ def build_dataloader(config, mode, device, logger, seed=None): else: use_shared_memory = True if mode == "Train": - #Distribute data to multiple cards + # Distribute data to multiple cards batch_sampler = DistributedBatchSampler( dataset=dataset, batch_size=batch_size, shuffle=shuffle, drop_last=drop_last) else: - #Distribute data to single card + # Distribute data to single card batch_sampler = BatchSampler( dataset=dataset, batch_size=batch_size, diff --git a/ppocr/data/imaug/label_ops.py b/ppocr/data/imaug/label_ops.py index 85aa8bb28a1b9c5d945b5d8cfa290975df1d7a48..6c2fc8e4bbfcedb5150dd7baf13db267d1d74aa2 100644 --- a/ppocr/data/imaug/label_ops.py +++ b/ppocr/data/imaug/label_ops.py @@ -34,28 +34,6 @@ class ClsLabelEncode(object): return data -class E2ELabelEncode(object): - def __init__(self, Lexicon_Table, max_len, **kwargs): - self.Lexicon_Table = Lexicon_Table - self.max_len = max_len - self.pad_num = len(self.Lexicon_Table) - - def __call__(self, data): - text_label_index_list, temp_text = [], [] - texts = data['strs'] - for text in texts: - text = text.upper() - temp_text = [] - for c_ in text: - if c_ in self.Lexicon_Table: - temp_text.append(self.Lexicon_Table.index(c_)) - temp_text = temp_text + [self.pad_num] * (self.max_len - - len(temp_text)) - text_label_index_list.append(temp_text) - data['strs'] = np.array(text_label_index_list) - return data - - class DetLabelEncode(object): def __init__(self, **kwargs): pass @@ -209,6 +187,32 @@ class CTCLabelEncode(BaseRecLabelEncode): return dict_character +class E2ELabelEncode(BaseRecLabelEncode): + def __init__(self, + max_text_length, + character_dict_path=None, + character_type='EN', + use_space_char=False, + **kwargs): + super(E2ELabelEncode, + self).__init__(max_text_length, character_dict_path, + character_type, use_space_char) + + def __call__(self, data): + texts = data['strs'] + temp_texts = [] + for text in texts: + text = text.upper() + text = self.encode(text) + if text is None: + return None + text = text + [36] * (self.max_text_len - len(text) + ) # use 36 to pad + temp_texts.append(text) + data['strs'] = np.array(temp_texts) + return data + + class AttnLabelEncode(BaseRecLabelEncode): """ Convert between text-label and text-index """ diff --git a/ppocr/data/imaug/pg_process.py b/ppocr/data/imaug/pg_process.py index a496ed43b56972c6bf2feff8abc26e620abd643c..58837d7bfa3643babf0dd66951d6b15c5e32865e 100644 --- a/ppocr/data/imaug/pg_process.py +++ b/ppocr/data/imaug/pg_process.py @@ -21,6 +21,7 @@ __all__ = ['PGProcessTrain'] class PGProcessTrain(object): def __init__(self, + character_dict_path, batch_size=14, min_crop_size=24, min_text_size=10, @@ -30,13 +31,19 @@ class PGProcessTrain(object): self.min_crop_size = min_crop_size self.min_text_size = min_text_size self.max_text_size = max_text_size - self.Lexicon_Table = [ - '0', '1', '2', '3', '4', '5', '6', '7', '8', '9', 'A', 'B', 'C', - 'D', 'E', 'F', 'G', 'H', 'I', 'J', 'K', 'L', 'M', 'N', 'O', 'P', - 'Q', 'R', 'S', 'T', 'U', 'V', 'W', 'X', 'Y', 'Z' - ] + self.Lexicon_Table = self.get_dict(character_dict_path) self.img_id = 0 + def get_dict(self, character_dict_path): + character_str = "" + with open(character_dict_path, "rb") as fin: + lines = fin.readlines() + for line in lines: + line = line.decode('utf-8').strip("\n").strip("\r\n") + character_str += line + dict_character = list(character_str) + return dict_character + def quad_area(self, poly): """ compute area of a polygon @@ -853,7 +860,7 @@ class PGProcessTrain(object): for i in range(len(label_list)): label_list[i] = np.array(label_list[i]) - if len(pos_list) <= 0 or len(pos_list) > 30: + if len(pos_list) <= 0 or len(pos_list) > 30: #一张图片中最多存在30行文本 return None for __ in range(30 - len(pos_list), 0, -1): pos_list.append(pos_list_temp) diff --git a/ppocr/metrics/e2e_metric.py b/ppocr/metrics/e2e_metric.py index c6cd1db94370e966a56fb694265a83e46c5e9ee3..04b73e0c4652c263d59380a0feff1f29da6c6817 100644 --- a/ppocr/metrics/e2e_metric.py +++ b/ppocr/metrics/e2e_metric.py @@ -19,11 +19,15 @@ from __future__ import print_function __all__ = ['E2EMetric'] from ppocr.utils.e2e_metric.Deteval import * +from ppocr.utils.e2e_utils.extract_textpoint import * class E2EMetric(object): - def __init__(self, Lexicon_Table, main_indicator='f_score_e2e', **kwargs): - self.label_list = Lexicon_Table + def __init__(self, + character_dict_path, + main_indicator='f_score_e2e', + **kwargs): + self.label_list = get_dict(character_dict_path) self.max_index = len(self.label_list) self.main_indicator = main_indicator self.reset() diff --git a/ppocr/modeling/heads/e2e_pg_head.py b/ppocr/modeling/heads/e2e_pg_head.py index 106cdfa689680d8370c9ad2b4d51e0c5a8c74ba7..a3bc39aa6711f79f172e86912c42019d92543ed4 100644 --- a/ppocr/modeling/heads/e2e_pg_head.py +++ b/ppocr/modeling/heads/e2e_pg_head.py @@ -228,11 +228,11 @@ class PGHead(nn.Layer): f_score = self.conv1(f_score) f_score = F.sigmoid(f_score) - # f_boder - f_boder = self.conv_f_boder1(x) - f_boder = self.conv_f_boder2(f_boder) - f_boder = self.conv_f_boder3(f_boder) - f_boder = self.conv2(f_boder) + # f_border + f_border = self.conv_f_boder1(x) + f_border = self.conv_f_boder2(f_border) + f_border = self.conv_f_boder3(f_border) + f_border = self.conv2(f_border) f_char = self.conv_f_char1(x) f_char = self.conv_f_char2(f_char) @@ -246,4 +246,9 @@ class PGHead(nn.Layer): f_direction = self.conv_f_direc3(f_direction) f_direction = self.conv4(f_direction) - return f_score, f_boder, f_direction, f_char + predicts = {} + predicts['f_score'] = f_score + predicts['f_border'] = f_border + predicts['f_char'] = f_char + predicts['f_direction'] = f_direction + return predicts diff --git a/ppocr/postprocess/pg_postprocess.py b/ppocr/postprocess/pg_postprocess.py index 1f1ab60e0df044a9f731bbdb3a87ff89da5bdd99..6d1b7d7a106c76fcf7104abe432f99588a4043eb 100644 --- a/ppocr/postprocess/pg_postprocess.py +++ b/ppocr/postprocess/pg_postprocess.py @@ -30,30 +30,14 @@ import paddle class PGPostProcess(object): """ - The post process for SAST. + The post process for PGNet. """ - def __init__(self, - score_thresh=0.5, - nms_thresh=0.2, - sample_pts_num=2, - shrink_ratio_of_width=0.3, - expand_scale=1.0, - tcl_map_thresh=0.5, - **kwargs): - self.result_path = "" - self.valid_set = 'totaltext' - self.Lexicon_Table = [ - '0', '1', '2', '3', '4', '5', '6', '7', '8', '9', 'A', 'B', 'C', - 'D', 'E', 'F', 'G', 'H', 'I', 'J', 'K', 'L', 'M', 'N', 'O', 'P', - 'Q', 'R', 'S', 'T', 'U', 'V', 'W', 'X', 'Y', 'Z' - ] + def __init__(self, character_dict_path, valid_set, score_thresh, **kwargs): + + self.Lexicon_Table = get_dict(character_dict_path) + self.valid_set = valid_set self.score_thresh = score_thresh - self.nms_thresh = nms_thresh - self.sample_pts_num = sample_pts_num - self.shrink_ratio_of_width = shrink_ratio_of_width - self.expand_scale = expand_scale - self.tcl_map_thresh = tcl_map_thresh # c++ la-nms is faster, but only support python 3.5 self.is_python35 = False @@ -61,16 +45,23 @@ class PGPostProcess(object): self.is_python35 = True def __call__(self, outs_dict, shape_list): - p_score, p_border, p_direction, p_char = outs_dict[:4] - p_score = p_score[0].numpy() - p_border = p_border[0].numpy() - p_direction = p_direction[0].numpy() - p_char = p_char[0].numpy() - src_h, src_w, ratio_h, ratio_w = shape_list[0] - if self.valid_set != 'totaltext': - is_curved = False + p_score = outs_dict['f_score'] + p_border = outs_dict['f_border'] + p_char = outs_dict['f_char'] + p_direction = outs_dict['f_direction'] + if isinstance(p_score, paddle.Tensor): + p_score = p_score[0].numpy() + p_border = p_border[0].numpy() + p_direction = p_direction[0].numpy() + p_char = p_char[0].numpy() else: - is_curved = True + p_score = p_score[0] + p_border = p_border[0] + p_direction = p_direction[0] + p_char = p_char[0] + + src_h, src_w, ratio_h, ratio_w = shape_list[0] + is_curved = self.valid_set == "totaltext" instance_yxs_list = generate_pivot_list( p_score, p_char, diff --git a/ppocr/utils/e2e_metric/polygon_fast.py b/ppocr/utils/e2e_metric/polygon_fast.py index c78e2a1e5f81850a7e3d7f1bb3e825aba0573cc0..81c9ad70675bb37a95968283b6dc6f42f709df27 100755 --- a/ppocr/utils/e2e_metric/polygon_fast.py +++ b/ppocr/utils/e2e_metric/polygon_fast.py @@ -1,4 +1,4 @@ -# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/ppocr/utils/e2e_utils/extract_textpoint.py b/ppocr/utils/e2e_utils/extract_textpoint.py index 2d793aa98ebef3835c83efa190ff9dee204771f4..5355280946c0eadc8bc097e5409f755e8a390e5a 100644 --- a/ppocr/utils/e2e_utils/extract_textpoint.py +++ b/ppocr/utils/e2e_utils/extract_textpoint.py @@ -24,6 +24,17 @@ from itertools import groupby from skimage.morphology._skeletonize import thin +def get_dict(character_dict_path): + character_str = "" + with open(character_dict_path, "rb") as fin: + lines = fin.readlines() + for line in lines: + line = line.decode('utf-8').strip("\n").strip("\r\n") + character_str += line + dict_character = list(character_str) + return dict_character + + def softmax(logits): """ logits: N x d @@ -164,7 +175,6 @@ def sort_and_expand_with_direction(pos_list, f_direction): h, w, _ = f_direction.shape sorted_list, point_direction = sort_with_direction(pos_list, f_direction) - # expand along point_num = len(sorted_list) sub_direction_len = max(point_num // 3, 2) left_direction = point_direction[:sub_direction_len, :] @@ -207,7 +217,6 @@ def sort_and_expand_with_direction_v2(pos_list, f_direction, binary_tcl_map): h, w, _ = f_direction.shape sorted_list, point_direction = sort_with_direction(pos_list, f_direction) - # expand along point_num = len(sorted_list) sub_direction_len = max(point_num // 3, 2) left_direction = point_direction[:sub_direction_len, :] @@ -268,7 +277,6 @@ def generate_pivot_list_curved(p_score, instance_count, instance_label_map = cv2.connectedComponents( skeleton_map.astype(np.uint8), connectivity=8) - # get TCL Instance all_pos_yxs = [] center_pos_yxs = [] end_points_yxs = [] @@ -279,7 +287,6 @@ def generate_pivot_list_curved(p_score, ys, xs = np.where(instance_label_map == instance_id) pos_list = list(zip(ys, xs)) - ### FIX-ME, eliminate outlier if len(pos_list) < 3: continue @@ -290,7 +297,6 @@ def generate_pivot_list_curved(p_score, pos_list_sorted, _ = sort_with_direction(pos_list, f_direction) all_pos_yxs.append(pos_list_sorted) - # use decoder to filter backgroud points. p_char_maps = p_char_maps.transpose([1, 2, 0]) decode_res = ctc_decoder_for_image( all_pos_yxs, logits_map=p_char_maps, keep_blank_in_idxs=True) @@ -335,11 +341,9 @@ def generate_pivot_list_horizontal(p_score, ys, xs = np.where(instance_label_map == instance_id) pos_list = list(zip(ys, xs)) - ### FIX-ME, eliminate outlier if len(pos_list) < 5: continue - # add rule here main_direction = extract_main_direction(pos_list, f_direction) # y x reference_directin = np.array([0, 1]).reshape([-1, 2]) # y x @@ -370,7 +374,6 @@ def generate_pivot_list_horizontal(p_score, f_direction) all_pos_yxs.append(pos_list_sorted) - # use decoder to filter backgroud points. p_char_maps = p_char_maps.transpose([1, 2, 0]) decode_res = ctc_decoder_for_image( all_pos_yxs, logits_map=p_char_maps, keep_blank_in_idxs=True) @@ -417,7 +420,6 @@ def generate_pivot_list(p_score, image_id=image_id) -# for refine module def extract_main_direction(pos_list, f_direction): """ f_direction: h x w x 2 @@ -504,14 +506,12 @@ def generate_pivot_list_tt_inference(p_score, instance_count, instance_label_map = cv2.connectedComponents( skeleton_map.astype(np.uint8), connectivity=8) - # get TCL Instance all_pos_yxs = [] if instance_count > 0: for instance_id in range(1, instance_count): pos_list = [] ys, xs = np.where(instance_label_map == instance_id) pos_list = list(zip(ys, xs)) - ### FIX-ME, eliminate outlier if len(pos_list) < 3: continue pos_list_sorted = sort_and_expand_with_direction_v2( diff --git a/ppocr/utils/e2e_utils/visual.py b/ppocr/utils/e2e_utils/visual.py index 6f8a429ef0fd85e14413bc1429e13a6ed81fc5f4..e6e4fd0667dbf4a42dbc0fd9bf26e6fd91be0d82 100644 --- a/ppocr/utils/e2e_utils/visual.py +++ b/ppocr/utils/e2e_utils/visual.py @@ -28,7 +28,6 @@ def resize_image(im, max_side_len=512): resize_w = w resize_h = h - # Fix the longer side if resize_h > resize_w: ratio = float(max_side_len) / resize_h else: @@ -50,13 +49,11 @@ def resize_image(im, max_side_len=512): def resize_image_min(im, max_side_len=512): """ """ - # print('--> Using resize_image_min') h, w, _ = im.shape resize_w = w resize_h = h - # Fix the longer side if resize_h < resize_w: ratio = float(max_side_len) / resize_h else: @@ -84,12 +81,7 @@ def resize_image_for_totaltext(im, max_side_len=512): ratio = 1.25 if h * ratio > max_side_len: ratio = float(max_side_len) / resize_h - # Fix the longer side - # if resize_h > resize_w: - # ratio = float(max_side_len) / resize_h - # else: - # ratio = float(max_side_len) / resize_w - ### + resize_h = int(resize_h * ratio) resize_w = int(resize_w * ratio) @@ -114,7 +106,6 @@ def point_pair2poly(point_pair_list): pair_info = (pair_length_list.max(), pair_length_list.min(), pair_length_list.mean()) - # constract poly point_num = len(point_pair_list) * 2 point_list = [0] * point_num for idx, point_pair in enumerate(point_pair_list): diff --git a/ppocr/utils/pgnet_dict.txt b/ppocr/utils/pgnet_dict.txt new file mode 100644 index 0000000000000000000000000000000000000000..b52d16e64f1004e1fceccac2280bc6f6eabd6af3 --- /dev/null +++ b/ppocr/utils/pgnet_dict.txt @@ -0,0 +1,36 @@ +0 +1 +2 +3 +4 +5 +6 +7 +8 +9 +A +B +C +D +E +F +G +H +I +J +K +L +M +N +O +P +Q +R +S +T +U +V +W +X +Y +Z \ No newline at end of file diff --git a/tools/infer/predict_e2e.py b/tools/infer/predict_e2e.py new file mode 100755 index 0000000000000000000000000000000000000000..1a92c4ab518a1c01d9d282443c09f7e3a7ecf008 --- /dev/null +++ b/tools/infer/predict_e2e.py @@ -0,0 +1,168 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import os +import sys + +__dir__ = os.path.dirname(os.path.abspath(__file__)) +sys.path.append(__dir__) +sys.path.append(os.path.abspath(os.path.join(__dir__, '../..'))) + +os.environ["FLAGS_allocator_strategy"] = 'auto_growth' + +import cv2 +import numpy as np +import time +import sys + +import tools.infer.utility as utility +from ppocr.utils.logging import get_logger +from ppocr.utils.utility import get_image_file_list, check_and_read_gif +from ppocr.data import create_operators, transform +from ppocr.postprocess import build_post_process + +logger = get_logger() + + +class TextE2e(object): + def __init__(self, args): + self.args = args + self.e2e_algorithm = args.e2e_algorithm + pre_process_list = [{ + 'E2EResizeForTest': { + 'max_side_len': 768, + 'valid_set': 'totaltext' + } + }, { + 'NormalizeImage': { + 'std': [0.229, 0.224, 0.225], + 'mean': [0.485, 0.456, 0.406], + 'scale': '1./255.', + 'order': 'hwc' + } + }, { + 'ToCHWImage': None + }, { + 'KeepKeys': { + 'keep_keys': ['image', 'shape'] + } + }] + postprocess_params = {} + if self.e2e_algorithm == "PGNet": + pre_process_list[0] = { + 'E2EResizeForTest': { + 'max_side_len': args.e2e_limit_side_len, + 'valid_set': 'totaltext' + } + } + postprocess_params['name'] = 'PGPostProcess' + postprocess_params["score_thresh"] = args.e2e_pgnet_score_thresh + postprocess_params["character_dict_path"] = args.e2e_char_dict_path + postprocess_params["valid_set"] = args.e2e_pgnet_valid_set + self.e2e_pgnet_polygon = args.e2e_pgnet_polygon + if self.e2e_pgnet_polygon: + postprocess_params["expand_scale"] = 1.2 + postprocess_params["shrink_ratio_of_width"] = 0.2 + else: + postprocess_params["expand_scale"] = 1.0 + postprocess_params["shrink_ratio_of_width"] = 0.3 + else: + logger.info("unknown e2e_algorithm:{}".format(self.e2e_algorithm)) + sys.exit(0) + + self.preprocess_op = create_operators(pre_process_list) + self.postprocess_op = build_post_process(postprocess_params) + self.predictor, self.input_tensor, self.output_tensors = utility.create_predictor( + args, 'e2e', logger) # paddle.jit.load(args.det_model_dir) + # self.predictor.eval() + + def clip_det_res(self, points, img_height, img_width): + for pno in range(points.shape[0]): + points[pno, 0] = int(min(max(points[pno, 0], 0), img_width - 1)) + points[pno, 1] = int(min(max(points[pno, 1], 0), img_height - 1)) + return points + + def filter_tag_det_res_only_clip(self, dt_boxes, image_shape): + img_height, img_width = image_shape[0:2] + dt_boxes_new = [] + for box in dt_boxes: + box = self.clip_det_res(box, img_height, img_width) + dt_boxes_new.append(box) + dt_boxes = np.array(dt_boxes_new) + return dt_boxes + + def __call__(self, img): + ori_im = img.copy() + data = {'image': img} + data = transform(data, self.preprocess_op) + img, shape_list = data + if img is None: + return None, 0 + img = np.expand_dims(img, axis=0) + print(img.shape) + shape_list = np.expand_dims(shape_list, axis=0) + img = img.copy() + starttime = time.time() + + self.input_tensor.copy_from_cpu(img) + self.predictor.run() + outputs = [] + for output_tensor in self.output_tensors: + output = output_tensor.copy_to_cpu() + outputs.append(output) + + preds = {} + if self.e2e_algorithm == 'PGNet': + preds['f_score'] = outputs[0] + preds['f_border'] = outputs[1] + preds['f_direction'] = outputs[2] + preds['f_char'] = outputs[3] + else: + raise NotImplementedError + + post_result = self.postprocess_op(preds, shape_list) + points, strs = post_result['points'], post_result['strs'] + dt_boxes = self.filter_tag_det_res_only_clip(points, ori_im.shape) + elapse = time.time() - starttime + return dt_boxes, strs, elapse + + +if __name__ == "__main__": + args = utility.parse_args() + image_file_list = get_image_file_list(args.image_dir) + text_detector = TextE2e(args) + count = 0 + total_time = 0 + draw_img_save = "./inference_results" + if not os.path.exists(draw_img_save): + os.makedirs(draw_img_save) + for image_file in image_file_list: + img, flag = check_and_read_gif(image_file) + if not flag: + img = cv2.imread(image_file) + if img is None: + logger.info("error in loading image:{}".format(image_file)) + continue + points, strs, elapse = text_detector(img) + if count > 0: + total_time += elapse + count += 1 + logger.info("Predict time of {}: {}".format(image_file, elapse)) + src_im = utility.draw_e2e_res(points, strs, image_file) + img_name_pure = os.path.split(image_file)[-1] + img_path = os.path.join(draw_img_save, + "e2e_res_{}".format(img_name_pure)) + cv2.imwrite(img_path, src_im) + logger.info("The visualized image saved in {}".format(img_path)) + if count > 1: + logger.info("Avg Time: {}".format(total_time / (count - 1))) diff --git a/tools/infer/utility.py b/tools/infer/utility.py index a4a91efdd2ec04e2a2959c77a444549ad413c13d..9aa0afed635481859cd31d461a97c451ca72acdc 100755 --- a/tools/infer/utility.py +++ b/tools/infer/utility.py @@ -74,6 +74,21 @@ def parse_args(): "--vis_font_path", type=str, default="./doc/fonts/simfang.ttf") parser.add_argument("--drop_score", type=float, default=0.5) + # params for e2e + parser.add_argument("--e2e_algorithm", type=str, default='PGNet') + parser.add_argument("--e2e_model_dir", type=str) + parser.add_argument("--e2e_limit_side_len", type=float, default=768) + parser.add_argument("--e2e_limit_type", type=str, default='max') + + # PGNet parmas + parser.add_argument("--e2e_pgnet_score_thresh", type=float, default=0.5) + parser.add_argument( + "--e2e_char_dict_path", + type=str, + default="./ppocr/utils/pgnet_dict.txt") + parser.add_argument("--e2e_pgnet_valid_set", type=str, default='totaltext') + parser.add_argument("--e2e_pgnet_polygon", type=bool, default=False) + # params for text classifier parser.add_argument("--use_angle_cls", type=str2bool, default=False) parser.add_argument("--cls_model_dir", type=str) @@ -93,8 +108,10 @@ def create_predictor(args, mode, logger): model_dir = args.det_model_dir elif mode == 'cls': model_dir = args.cls_model_dir - else: + elif mode == 'rec': model_dir = args.rec_model_dir + else: + model_dir = args.e2e_model_dir if model_dir is None: logger.info("not find {} model file path {}".format(mode, model_dir)) @@ -147,6 +164,22 @@ def create_predictor(args, mode, logger): return predictor, input_tensor, output_tensors +def draw_e2e_res(dt_boxes, strs, img_path): + src_im = cv2.imread(img_path) + for box, str in zip(dt_boxes, strs): + box = box.astype(np.int32).reshape((-1, 1, 2)) + cv2.polylines(src_im, [box], True, color=(255, 255, 0), thickness=2) + cv2.putText( + src_im, + str, + org=(int(box[0, 0, 0]), int(box[0, 0, 1])), + fontFace=cv2.FONT_HERSHEY_COMPLEX, + fontScale=0.7, + color=(0, 255, 0), + thickness=1) + return src_im + + def draw_text_det_res(dt_boxes, img_path): src_im = cv2.imread(img_path) for box in dt_boxes: diff --git a/tools/infer_e2e.py b/tools/infer_e2e.py index c40b8e02341afd8d1204aa5335bb7b0963e5899a..b7503adb94eb797d4fb12cf47b377fa72d02158b 100755 --- a/tools/infer_e2e.py +++ b/tools/infer_e2e.py @@ -71,7 +71,8 @@ def main(): init_model(config, model, logger) # build post process - post_process_class = build_post_process(config['PostProcess']) + post_process_class = build_post_process(config['PostProcess'], + global_config) # create data ops transforms = [] diff --git a/train_data/total_text/train/poly/2.txt b/train_data/total_text/train/poly/2.txt new file mode 100644 index 0000000000000000000000000000000000000000..961d9680b203a498405fb2276e03adcc9ba3ec8c --- /dev/null +++ b/train_data/total_text/train/poly/2.txt @@ -0,0 +1,2 @@ +2.0,165.0,20.0,167.0,39.0,170.0,57.0,173.0,76.0,176.0,94.0,179.0,113.0,182.0,109.0,218.0,90.0,215.0,72.0,213.0,54.0,210.0,36.0,208.0,18.0,205.0,0.0,203.0 izza +2.0,411.0,30.0,412.0,58.0,414.0,87.0,416.0,115.0,418.0,143.0,420.0,172.0,422.0,172.0,476.0,143.0,474.0,114.0,472.0,86.0,471.0,57.0,469.0,28.0,467.0,0.0,466.0 ISA diff --git a/train_data/total_text/train/rgb/2.jpg b/train_data/total_text/train/rgb/2.jpg new file mode 100644 index 0000000000000000000000000000000000000000..f3bc7a06e911ef87c0831e779d20e44b6d2bbea5 Binary files /dev/null and b/train_data/total_text/train/rgb/2.jpg differ