diff --git a/StyleText/engine/text_drawers.py b/StyleText/engine/text_drawers.py index aeec75c3378f91b64b4387ef16971165f0b80ebe..20375c13613f40c298ec83ff8fddf0e8fb73a9b0 100644 --- a/StyleText/engine/text_drawers.py +++ b/StyleText/engine/text_drawers.py @@ -66,6 +66,7 @@ class StdTextDrawer(object): corpus_list.append(corpus[0:i]) text_input_list.append(text_input) corpus = corpus[i:] + i = 0 break draw.text((char_x, 2), char_i, fill=(0, 0, 0), font=font) char_x += char_size @@ -78,7 +79,6 @@ class StdTextDrawer(object): corpus_list.append(corpus[0:i]) text_input_list.append(text_input) - corpus = corpus[i:] break return corpus_list, text_input_list diff --git a/configs/rec/ch_ppocr_v2.1/rec_chinese_lite_train_distillation_v2.1.yml b/configs/rec/ch_ppocr_v2.1/rec_chinese_lite_train_distillation_v2.1.yml index 791b34cf5785d81a0f1346c0ef1ad4485ed3fee8..27ba4fd70b9a7ee7d4d905b3948f6cbf2b7e9469 100644 --- a/configs/rec/ch_ppocr_v2.1/rec_chinese_lite_train_distillation_v2.1.yml +++ b/configs/rec/ch_ppocr_v2.1/rec_chinese_lite_train_distillation_v2.1.yml @@ -17,7 +17,7 @@ Global: character_type: ch max_text_length: 25 infer_mode: false - use_space_char: false + use_space_char: true distributed: true save_res_path: ./output/rec/predicts_chinese_lite_distillation_v2.1.txt @@ -27,28 +27,29 @@ Optimizer: beta1: 0.9 beta2: 0.999 lr: - name: Cosine - learning_rate: 0.0005 + name: Piecewise + decay_epochs : [700, 800] + values : [0.001, 0.0001] warmup_epoch: 5 regularizer: name: L2 - factor: 1.0e-05 + factor: 2.0e-05 + Architecture: + model_type: &model_type "rec" name: DistillationModel algorithm: Distillation Models: - Student: + Teacher: pretrained: freeze_params: false return_all_feats: true - model_type: rec + model_type: *model_type algorithm: CRNN Transform: Backbone: - name: MobileNetV3 + name: MobileNetV1Enhance scale: 0.5 - model_name: small - small_stride: [1, 2, 2, 2] Neck: name: SequenceEncoder encoder_type: rnn @@ -56,19 +57,17 @@ Architecture: Head: name: CTCHead mid_channels: 96 - fc_decay: 0.00001 - Teacher: + fc_decay: 0.00002 + Student: pretrained: freeze_params: false return_all_feats: true - model_type: rec + model_type: *model_type algorithm: CRNN Transform: Backbone: - name: MobileNetV3 + name: MobileNetV1Enhance scale: 0.5 - model_name: small - small_stride: [1, 2, 2, 2] Neck: name: SequenceEncoder encoder_type: rnn @@ -76,7 +75,7 @@ Architecture: Head: name: CTCHead mid_channels: 96 - fc_decay: 0.00001 + fc_decay: 0.00002 Loss: diff --git a/deploy/slim/quantization/export_model.py b/deploy/slim/quantization/export_model.py index 100b107a1deb1ce9932c9cefa50659c060f5803e..e9c1a8d31110ef20dd66be28d78b1e866fcd85ae 100755 --- a/deploy/slim/quantization/export_model.py +++ b/deploy/slim/quantization/export_model.py @@ -37,6 +37,17 @@ from paddleslim.dygraph.quant import QAT from ppocr.data import build_dataloader +def export_single_model(quanter, model, infer_shape, save_path, logger): + quanter.save_quantized_model( + model, + save_path, + input_spec=[ + paddle.static.InputSpec( + shape=[None] + infer_shape, dtype='float32') + ]) + logger.info('inference QAT model is saved to {}'.format(save_path)) + + def main(): ############################################################################################################ # 1. 
quantization configs
@@ -76,7 +87,14 @@ def main():
     # for rec algorithm
     if hasattr(post_process_class, 'character'):
         char_num = len(getattr(post_process_class, 'character'))
-        config['Architecture']["Head"]['out_channels'] = char_num
+        if config['Architecture']["algorithm"] in ["Distillation",
+                                                   ]:  # distillation model
+            for key in config['Architecture']["Models"]:
+                config['Architecture']["Models"][key]["Head"][
+                    'out_channels'] = char_num
+        else:  # base rec model
+            config['Architecture']["Head"]['out_channels'] = char_num
+
     model = build_model(config['Architecture'])
 
     # get QAT model
@@ -92,25 +110,30 @@ def main():
 
     # build dataloader
     valid_dataloader = build_dataloader(config, 'Eval', device, logger)
 
+    use_srn = config['Architecture']['algorithm'] == "SRN"
+    model_type = config['Architecture']['model_type']
     # start eval
-    metirc = program.eval(model, valid_dataloader, post_process_class,
-                          eval_class)
+    metric = program.eval(model, valid_dataloader, post_process_class,
+                          eval_class, model_type, use_srn)
+
     logger.info('metric eval ***************')
-    for k, v in metirc.items():
+    for k, v in metric.items():
         logger.info('{}:{}'.format(k, v))
 
-    save_path = '{}/inference'.format(config['Global']['save_inference_dir'])
     infer_shape = [3, 32, 100] if config['Architecture'][
         'model_type'] != "det" else [3, 640, 640]
 
-    quanter.save_quantized_model(
-        model,
-        save_path,
-        input_spec=[
-            paddle.static.InputSpec(
-                shape=[None] + infer_shape, dtype='float32')
-        ])
-    logger.info('inference QAT model is saved to {}'.format(save_path))
+    save_path = config["Global"]["save_inference_dir"]
+
+    arch_config = config["Architecture"]
+    if arch_config["algorithm"] in ["Distillation", ]:  # distillation model
+        for idx, name in enumerate(model.model_name_list):
+            sub_model_save_path = os.path.join(save_path, name, "inference")
+            export_single_model(quanter, model.model_list[idx], infer_shape,
+                                sub_model_save_path, logger)
+    else:
+        save_path = os.path.join(save_path, "inference")
+        export_single_model(quanter, model, infer_shape, save_path, logger)
 
 
 if __name__ == "__main__":
diff --git a/deploy/slim/quantization/quant.py b/deploy/slim/quantization/quant.py
index 315e3b4321a544e77795c43d493873fcf46e1930..37aab68a0e88afce54e10fb6248c73684b58d808 100755
--- a/deploy/slim/quantization/quant.py
+++ b/deploy/slim/quantization/quant.py
@@ -109,9 +109,18 @@ def main(config, device, logger, vdl_writer):
     # for rec algorithm
     if hasattr(post_process_class, 'character'):
         char_num = len(getattr(post_process_class, 'character'))
-        config['Architecture']["Head"]['out_channels'] = char_num
+        if config['Architecture']["algorithm"] in ["Distillation",
+                                                   ]:  # distillation model
+            for key in config['Architecture']["Models"]:
+                config['Architecture']["Models"][key]["Head"][
+                    'out_channels'] = char_num
+        else:  # base rec model
+            config['Architecture']["Head"]['out_channels'] = char_num
     model = build_model(config['Architecture'])
 
+    quanter = QAT(config=quant_config, act_preprocess=PACT)
+    quanter.quantize(model)
+
     if config['Global']['distributed']:
         model = paddle.DataParallel(model)
 
@@ -132,8 +141,6 @@ def main(config, device, logger, vdl_writer):
     logger.info('train dataloader has {} iters, valid dataloader has {} iters'.
                format(len(train_dataloader), len(valid_dataloader)))
 
-    quanter = QAT(config=quant_config, act_preprocess=PACT)
-    quanter.quantize(model)
 
     # start train
     program.train(config, train_dataloader, valid_dataloader, device, model,
diff --git a/doc/doc_ch/knowledge_distillation.md b/doc/doc_ch/knowledge_distillation.md
new file mode 100644
index 0000000000000000000000000000000000000000..b561f718491011e8dddcd44e66bfd6da62101ba6
--- /dev/null
+++ b/doc/doc_ch/knowledge_distillation.md
@@ -0,0 +1,251 @@
+# Knowledge Distillation
+
+
+## 1. Introduction
+
+### 1.1 What is Knowledge Distillation
+
+In recent years, deep neural networks have proven to be an extremely effective way to solve problems in computer vision, natural language processing, and other fields. With a well-designed, well-trained network, the resulting model generally outperforms traditional algorithms.
+
+When enough data is available, enlarging a model's parameter count through a sensible architecture markedly improves its performance, but it also drives up model complexity sharply, and large models are costly to use in real-world scenarios.
+
+Deep neural networks usually carry substantial parameter redundancy, and several major approaches exist for compressing a model and shrinking its parameter count, such as pruning, quantization, and knowledge distillation. In knowledge distillation, a teacher model guides a student model on a specific task, giving the small model a large accuracy gain without increasing its parameter count.
+
+Mutual learning has also grown out of knowledge distillation: the paper [Deep Mutual Learning](https://arxiv.org/abs/1706.00384) shows that two identical models supervising each other during training achieve better results than a single model trained alone.
+
+### 1.2 Knowledge Distillation in PaddleOCR
+
+Whether a large model distills a small one, or small models learn from each other and update their parameters, the essence is the same: mutual supervision between the outputs or feature maps of different models. The only differences are (1) whether a model's parameters are frozen, and (2) whether a model loads pretrained weights.
+
+When a large model distills a small one, the large model usually loads pretrained weights and has its parameters frozen; when small models distill each other, they usually start without pretrained weights and keep all parameters trainable.
+
+Knowledge distillation is not limited to two models; mutual learning among several models is also very common, so the distillation framework needs to support that case as well.
+
+The knowledge distillation algorithm integrated into PaddleOCR has the following main features:
+- Any networks can learn from each other: the sub-networks need not share the same architecture or have pretrained weights, and there is no limit on their number; adding one only requires a new entry in the config file.
+- Loss functions are freely configurable through the config file: a single loss or a combination of several can be used.
+- Training, prediction, evaluation, and export all support knowledge distillation, which makes it easy to use and deploy.
+
+
+Through knowledge distillation, on the general Chinese and English text recognition task, the model gains more than 3% accuracy without any extra inference cost; combined with a tuned learning-rate schedule and small architecture adjustments, the overall gain exceeds 5%.
+
+
+
+## 2. Config File Walkthrough
+
+During knowledge distillation training, the data preprocessing, optimizer, learning rate, and global settings are unchanged; only the model architecture, loss function, post-processing, and metric sections of the config file need adjustment.
+
+The recognition and detection distillation config files are used below to explain distillation training and configuration.
+
+### 2.1 Recognition Config Walkthrough
+
+The config file is [rec_chinese_lite_train_distillation_v2.1.yml](../../configs/rec/ch_ppocr_v2.1/rec_chinese_lite_train_distillation_v2.1.yml).
+
+#### 2.1.1 Model Architecture
+
+The model architecture for a knowledge distillation task is configured as follows.
+
+```yaml
+Architecture:
+  model_type: &model_type "rec"    # model category (rec, det, ...); every sub-network shares this type via the YAML anchor
+  name: DistillationModel          # architecture name; DistillationModel builds the corresponding distillation structure
+  algorithm: Distillation          # algorithm name
+  Models:                          # the sub-network configurations
+    Teacher:                       # sub-network name; must contain at least `pretrained` and `freeze_params`; the other fields are its constructor arguments
+      pretrained:                  # whether this sub-network loads pretrained weights
+      freeze_params: false         # whether its parameters are frozen
+      return_all_feats: true       # whether to return all features; if false, only the final output is returned
+      model_type: *model_type      # model category
+      algorithm: CRNN              # the sub-network's algorithm; its remaining fields are constructor arguments, identical to a normal training config
+      Transform:
+      Backbone:
+        name: MobileNetV1Enhance
+        scale: 0.5
+      Neck:
+        name: SequenceEncoder
+        encoder_type: rnn
+        hidden_size: 64
+      Head:
+        name: CTCHead
+        mid_channels: 96
+        fc_decay: 0.00002
+    Student:                       # the second sub-network; this is a DML example, so both sub-networks share the same architecture and both learn their parameters
+      pretrained:                  # the fields below mirror the Teacher config above
+      freeze_params: false
+      return_all_feats: true
+      model_type: *model_type
+      algorithm: CRNN
+      Transform:
+      Backbone:
+        name: MobileNetV1Enhance
+        scale: 0.5
+      Neck:
+        name: SequenceEncoder
+        encoder_type: rnn
+        hidden_size: 64
+      Head:
+        name: CTCHead
+        mid_channels: 96
+        fc_decay: 0.00002
+```
+
+To train with more sub-networks, add the corresponding fields to the config file in the same way as `Student` and `Teacher`. For example, for three models that supervise each other and train jointly, `Architecture` can be written as follows.
+
+```yaml
+Architecture:
+  model_type: &model_type "rec"
+  name: DistillationModel
+  algorithm: Distillation
+  Models:
+    Teacher:
+      pretrained:
+      freeze_params: false
+      return_all_feats: true
+      model_type: *model_type
+      algorithm: CRNN
+      Transform:
+      Backbone:
+        name: MobileNetV1Enhance
+        scale: 0.5
+      Neck:
+        name: SequenceEncoder
+        encoder_type: rnn
+        hidden_size: 64
+      Head:
+        name: CTCHead
+        mid_channels: 96
+        fc_decay: 0.00002
+    Student:
+      pretrained:
+      freeze_params: false
+      return_all_feats: true
+      model_type: *model_type
+      algorithm: CRNN
+      Transform:
+      Backbone:
+        name: MobileNetV1Enhance
+        scale: 0.5
+      Neck:
+        name: SequenceEncoder
+        encoder_type: rnn
+        hidden_size: 64
+      Head:
+        name: CTCHead
+        mid_channels: 96
+        fc_decay: 0.00002
+    Student2:                      # the new sub-network introduced for this distillation task; everything else matches the config above
+      pretrained:
+      freeze_params: false
+      return_all_feats: true
+      model_type: *model_type
+      algorithm: CRNN
+      Transform:
+      Backbone:
+        name: MobileNetV1Enhance
+        scale: 0.5
+      Neck:
+        name: SequenceEncoder
+        encoder_type: rnn
+        hidden_size: 64
+      Head:
+        name: CTCHead
+        mid_channels: 96
+        fc_decay: 0.00002
+```
+
+When trained, this model then contains three sub-networks: `Teacher`, `Student`, and `Student2`.
+
+The implementation of the distillation model class `DistillationModel` can be found in [distillation_model.py](../../ppocr/modeling/architectures/distillation_model.py).
+
+The model's `forward` output is a dict whose keys are the sub-network names (here `Student` and `Teacher`) and whose values are the corresponding sub-network outputs, either a `Tensor` (only the network's final layer is returned) or a `dict` (intermediate features are returned as well).
+
+For recognition, to support extra loss functions and keep the distillation method extensible, each sub-network's output is stored as a `dict` containing its sub-module outputs. In this recognition model, every sub-network outputs a `dict` whose keys are `backbone_out`, `neck_out`, and `head_out`, and whose values are the corresponding modules' tensors. For the config above, the `DistillationModel` output has the following format.
+
+```json
+{
+  "Teacher": {
+    "backbone_out": tensor,
+    "neck_out": tensor,
+    "head_out": tensor,
+  },
+  "Student": {
+    "backbone_out": tensor,
+    "neck_out": tensor,
+    "head_out": tensor,
+  }
+}
+```
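+
+To make the output structure concrete, the sketch below shows one way it could be inspected. This is illustrative only: it assumes the `Architecture` section above has already been parsed into a Python dict `arch_config`, and it fills `Head.out_channels` with a placeholder value, which the training entry points normally derive from the character dictionary.
+
+```python
+import paddle
+from ppocr.modeling.architectures import build_model
+
+# Placeholder CTC output size (character number); assumed for illustration.
+for sub_cfg in arch_config["Models"].values():
+    sub_cfg["Head"]["out_channels"] = 6625
+
+model = build_model(arch_config)
+outs = model(paddle.rand([1, 3, 32, 320]))  # dummy recognition batch (NCHW)
+for sub_name, sub_out in outs.items():      # "Teacher", "Student"
+    print(sub_name, {k: tuple(v.shape) for k, v in sub_out.items()})
+```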
+
+#### 2.1.2 Loss Function
+
+The loss function for a knowledge distillation task is configured as follows.
+
+```yaml
+Loss:
+  name: CombinedLoss                           # loss name; this class combines the losses configured below
+  loss_config_list:                            # list of loss configs, required by CombinedLoss
+  - DistillationCTCLoss:                       # distillation CTC loss, derived from the standard CTC loss
+      weight: 1.0                              # loss weight; every entry in loss_config_list must contain this field
+      model_name_list: ["Student", "Teacher"]  # extract these sub-networks' outputs from the model prediction and compute the CTC loss against the ground truth
+      key: head_out                            # which tensor to take from each sub-network's output dict
+  - DistillationDMLLoss:                       # distillation DML loss, derived from the standard DMLLoss
+      weight: 1.0                              # weight
+      act: "softmax"                           # activation applied to the inputs: softmax, sigmoid, or None (default None)
+      model_name_pairs:                        # sub-network pairs for the DML loss; append more pairs here to add DML losses between other sub-networks
+      - ["Student", "Teacher"]
+      key: head_out                            # which tensor to take from each sub-network's output dict
+  - DistillationDistanceLoss:                  # distillation distance loss
+      weight: 1.0                              # weight
+      mode: "l2"                               # distance metric; l1, l2, and smooth_l1 are currently supported
+      model_name_pairs:                        # sub-network pairs for the distance loss
+      - ["Student", "Teacher"]
+      key: backbone_out                        # which tensor to take from each sub-network's output dict
+```
+
+All of these distillation losses inherit from a standard loss class. Their main job is to parse the distillation model's output, locate the intermediate tensors the loss should be computed on, and then apply the standard loss class to them. A sketch of this pattern follows the list below.
+
+With the config above, the final distillation training loss has the following three parts.
+
+- The CTC loss between the ground truth and the final outputs (`head_out`) of `Student` and `Teacher`, with weight 1. Since both sub-networks update their parameters, both need a loss against the ground truth.
+- The DML loss between the final outputs (`head_out`) of `Student` and `Teacher`, with weight 1.
+- The l2 loss between the backbone outputs (`backbone_out`) of `Student` and `Teacher`, with weight 1.
+
+For the implementation of `CombinedLoss`, see [combined_loss.py](../../ppocr/losses/combined_loss.py#L23); for distillation losses such as `DistillationCTCLoss`, see [distillation_loss.py](../../ppocr/losses/distillation_loss.py).
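+
+The shared pattern can be sketched as follows. This is a simplified illustration of the idea, not the code in `distillation_loss.py`; the class name and the MSE stand-in for the "l2" mode are assumptions.
+
+```python
+import paddle.nn as nn
+
+
+class DistillationPairLossSketch(nn.Layer):
+    """Pick one tensor per sub-network by (model name, key) and apply a
+    standard loss to each pair; illustrative only."""
+
+    def __init__(self, model_name_pairs, key, weight=1.0):
+        super().__init__()
+        self.model_name_pairs = model_name_pairs
+        self.key = key
+        self.weight = weight
+        self.base_loss = nn.MSELoss()  # stands in for the "l2" distance mode
+
+    def forward(self, predicts, batch):
+        losses = {}
+        for name1, name2 in self.model_name_pairs:
+            out1 = predicts[name1][self.key]  # e.g. Student backbone_out
+            out2 = predicts[name2][self.key]  # e.g. Teacher backbone_out
+            losses["loss_{}_{}".format(name1, name2)] = \
+                self.weight * self.base_loss(out1, out2)
+        return losses
+```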
+
+
+#### 2.1.3 Post-processing
+
+Post-processing for a knowledge distillation task is configured as follows.
+
+```yaml
+PostProcess:
+  name: DistillationCTCLabelDecode       # CTC decoding for distillation, derived from the standard CTCLabelDecode class
+  model_name: ["Student", "Teacher"]     # extract and decode the outputs of these sub-networks
+  key: head_out                          # which tensor to take from each sub-network's output dict
+```
+
+With this config, the CTC-decoded outputs of both the `Student` and `Teacher` sub-networks are computed, and a `dict` is returned whose keys are the sub-network names and whose values are the corresponding decoded results.
+
+For the implementation of `DistillationCTCLabelDecode`, see [rec_postprocess.py](../../ppocr/postprocess/rec_postprocess.py#L128).
+
+
+#### 2.1.4 Metric Computation
+
+Metric computation for a knowledge distillation task is configured as follows.
+
+```yaml
+Metric:
+  name: DistillationMetric       # metric class for distillation tasks, built on top of a standard metric
+  base_metric_name: RecMetric    # base metric class used to score each sub-network's output
+  main_indicator: acc            # name of the main indicator
+  key: "Student"                 # use this sub-network's main_indicator as the criterion for saving the best model
+```
+
+With this config, the `Student` sub-network's acc metric decides when to save the best model, while the acc of every sub-network is still printed to the log.
+
+For the implementation of `DistillationMetric`, see [distillation_metric.py](../../ppocr/metrics/distillation_metric.py#L24).
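+
+The selection logic can be sketched as follows, assuming the usual accumulate-then-`get_metric()` interface of `RecMetric`; this is an illustration, not the actual implementation.
+
+```python
+from ppocr.metrics.rec_metric import RecMetric
+
+
+class DistillationMetricSketch(object):
+    """Score every sub-network with a base metric, but report only the
+    configured sub-network (`key`) as the main result; illustrative only."""
+
+    def __init__(self, key="Student", main_indicator="acc"):
+        self.key = key
+        self.main_indicator = main_indicator
+        self.metrics = {}  # one base metric per sub-network, built lazily
+
+    def __call__(self, preds, batch):
+        for name, pred in preds.items():
+            if name not in self.metrics:
+                self.metrics[name] = RecMetric(
+                    main_indicator=self.main_indicator)
+            self.metrics[name](pred, batch)  # accumulate per sub-network
+
+    def get_metric(self):
+        # every sub-network is logged elsewhere; only `key` drives
+        # best-model selection
+        return self.metrics[self.key].get_metric()
+```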
+
+
+### 2.2 Detection Config Walkthrough
+
+* coming soon!
diff --git a/ppocr/data/imaug/__init__.py b/ppocr/data/imaug/__init__.py
index ff084a725a27a909fcc1b29d7dc3b309fa0623a2..52194eb964f7a7fd159cc1a42b73d280f8ee5fb4 100644
--- a/ppocr/data/imaug/__init__.py
+++ b/ppocr/data/imaug/__init__.py
@@ -23,6 +23,7 @@ from .random_crop_data import EastRandomCropData, PSERandomCrop
 
 from .rec_img_aug import RecAug, RecResizeImg, ClsResizeImg, SRNRecResizeImg
 from .randaugment import RandAugment
+from .copy_paste import CopyPaste
 from .operators import *
 from .label_ops import *
diff --git a/ppocr/data/imaug/copy_paste.py b/ppocr/data/imaug/copy_paste.py
new file mode 100644
index 0000000000000000000000000000000000000000..bbf62e2a3d813671551efa1a76c03754b1b764f5
--- /dev/null
+++ b/ppocr/data/imaug/copy_paste.py
@@ -0,0 +1,166 @@
+# copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import copy
+import cv2
+import random
+import numpy as np
+from PIL import Image
+from shapely.geometry import Polygon
+
+from ppocr.data.imaug.iaa_augment import IaaAugment
+from ppocr.data.imaug.random_crop_data import is_poly_outside_rect
+from tools.infer.utility import get_rotate_crop_image
+
+
+class CopyPaste(object):
+    def __init__(self, objects_paste_ratio=0.2, limit_paste=True, **kwargs):
+        self.ext_data_num = 1
+        self.objects_paste_ratio = objects_paste_ratio
+        self.limit_paste = limit_paste
+        augmenter_args = [{'type': 'Resize', 'args': {'size': [0.5, 3]}}]
+        self.aug = IaaAugment(augmenter_args)
+
+    def __call__(self, data):
+        src_img = data['image']
+        src_polys = data['polys'].tolist()
+        src_ignores = data['ignore_tags'].tolist()
+        ext_data = data['ext_data'][0]
+        ext_image = ext_data['image']
+        ext_polys = ext_data['polys']
+        ext_ignores = ext_data['ignore_tags']
+
+        indexs = [i for i in range(len(ext_ignores)) if not ext_ignores[i]]
+        select_num = max(
+            1, min(int(self.objects_paste_ratio * len(ext_polys)), 30))
+
+        random.shuffle(indexs)
+        select_idxs = indexs[:select_num]
+        select_polys = ext_polys[select_idxs]
+        select_ignores = ext_ignores[select_idxs]
+
+        src_img = cv2.cvtColor(src_img, cv2.COLOR_BGR2RGB)
+        ext_image = cv2.cvtColor(ext_image, cv2.COLOR_BGR2RGB)
+        src_img = Image.fromarray(src_img).convert('RGBA')
+        for poly, tag in zip(select_polys, select_ignores):
+            box_img = get_rotate_crop_image(ext_image, poly)
+
+            src_img, box = self.paste_img(src_img, box_img, src_polys)
+            if box is not None:
+                src_polys.append(box)
+                src_ignores.append(tag)
+        src_img = cv2.cvtColor(np.array(src_img), cv2.COLOR_RGB2BGR)
+        h, w = src_img.shape[:2]
+        src_polys = np.array(src_polys)
+        src_polys[:, :, 0] = np.clip(src_polys[:, :, 0], 0, w)
+        src_polys[:, :, 1] = np.clip(src_polys[:, :, 1], 0, h)
+        data['image'] = src_img
+        data['polys'] = src_polys
+        data['ignore_tags'] = np.array(src_ignores)
+        return data
+
+    def paste_img(self, src_img, box_img, src_polys):
+        box_img_pil = Image.fromarray(box_img).convert('RGBA')
+        src_w, src_h = src_img.size
+        box_w, box_h = box_img_pil.size
+
+        angle = np.random.randint(0, 360)
+        box = np.array([[[0, 0], [box_w, 0], [box_w, box_h], [0, box_h]]])
+        box = rotate_bbox(box_img, box, angle)[0]
+        box_img_pil = box_img_pil.rotate(angle, expand=1)
+        box_w, box_h = box_img_pil.width, box_img_pil.height
+        if src_w - box_w < 0 or src_h - box_h < 0:
+            return src_img, None
+
+        paste_x, paste_y = self.select_coord(src_polys, box, src_w - box_w,
+                                             src_h - box_h)
+        if paste_x is None:
+            return src_img, None
+        box[:, 0] += paste_x
+        box[:, 1] += paste_y
+        r, g, b, A = box_img_pil.split()
+        src_img.paste(box_img_pil, (paste_x, paste_y), mask=A)
+
+        return src_img, box
+
+    def select_coord(self, src_polys, box, endx, endy):
+        if self.limit_paste:
+            xmin, ymin, xmax, ymax = box[:, 0].min(), box[:, 1].min(
+            ), box[:, 0].max(), box[:, 1].max()
+            for _ in range(50):
+                paste_x = random.randint(0, endx)
+                paste_y = random.randint(0, endy)
+                xmin1 = xmin + paste_x
+                xmax1 = xmax + paste_x
+                ymin1 = ymin + paste_y
+                ymax1 = ymax + paste_y
+
+                num_poly_in_rect = 0
+                for poly in src_polys:
+                    if not is_poly_outside_rect(poly, xmin1, ymin1,
+                                                xmax1 - xmin1, ymax1 - ymin1):
+                        num_poly_in_rect += 1
+                        break
+                if num_poly_in_rect == 0:
+                    return paste_x, paste_y
+            return None, None
+        else:
+            paste_x = random.randint(0, endx)
+            paste_y = random.randint(0, endy)
+            return paste_x, paste_y
+
+
+def get_union(pD, pG):
+    return Polygon(pD).union(Polygon(pG)).area
+
+
+def get_intersection_over_union(pD, pG):
+    return get_intersection(pD, pG) / get_union(pD, pG)
+
+
+def get_intersection(pD, pG):
+    return Polygon(pD).intersection(Polygon(pG)).area
+
+
+def rotate_bbox(img, text_polys, angle, scale=1):
+    """
+    from https://github.com/WenmuZhou/DBNet.pytorch/blob/master/data_loader/modules/augment.py
+    Args:
+        img: np.ndarray
+        text_polys: np.ndarray N*4*2
+        angle: int
+        scale: int
+
+    Returns:
+
+    """
+    w = img.shape[1]
+    h = img.shape[0]
+
+    rangle = np.deg2rad(angle)
+    nw = (abs(np.sin(rangle) * h) + abs(np.cos(rangle) * w))
+    nh = (abs(np.cos(rangle) * h) + abs(np.sin(rangle) * w))
+    rot_mat = cv2.getRotationMatrix2D((nw * 0.5, nh * 0.5), angle, scale)
+    rot_move = np.dot(rot_mat, np.array([(nw - w) * 0.5, (nh - h) * 0.5, 0]))
+    rot_mat[0, 2] += rot_move[0]
+    rot_mat[1, 2] += rot_move[1]
+
+    # ---------------------- rotate box ----------------------
+    rot_text_polys = list()
+    for bbox in text_polys:
+        point1 = np.dot(rot_mat, np.array([bbox[0, 0], bbox[0, 1], 1]))
+        point2 = np.dot(rot_mat, np.array([bbox[1, 0], bbox[1, 1], 1]))
+        point3 = np.dot(rot_mat, np.array([bbox[2, 0], bbox[2, 1], 1]))
+        point4 = np.dot(rot_mat, np.array([bbox[3, 0], bbox[3, 1], 1]))
+        rot_text_polys.append([point1, point2, point3, point4])
+    return np.array(rot_text_polys, dtype=np.float32)
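+
+
+if __name__ == '__main__':
+    # Illustrative smoke test only, not part of the training pipeline: it
+    # fabricates a source sample plus one `ext_data` sample holding a single
+    # axis-aligned text box, then runs the augmentation end to end.
+    def fake_sample():
+        return {
+            'image': np.full((640, 640, 3), 255, dtype=np.uint8),
+            'polys': np.array(
+                [[[10, 10], [200, 10], [200, 50], [10, 50]]],
+                dtype=np.float32),
+            'ignore_tags': np.array([False]),
+        }
+
+    sample = fake_sample()
+    sample['ext_data'] = [fake_sample()]
+    out = CopyPaste(objects_paste_ratio=0.2, limit_paste=True)(sample)
+    print(out['image'].shape, out['polys'].shape, out['ignore_tags'])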
diff --git a/ppocr/data/simple_dataset.py b/ppocr/data/simple_dataset.py
index 8f8fcb4dbdf3c68587875b50cb30a834a3943216..ce9e1b38675ae8df4a2e83b88c1adae4476a10b5 100644
--- a/ppocr/data/simple_dataset.py
+++ b/ppocr/data/simple_dataset.py
@@ -14,6 +14,7 @@
 import numpy as np
 import os
 import random
+import traceback
 from paddle.io import Dataset
 
 from .imaug import transform, create_operators
@@ -69,6 +70,36 @@ class SimpleDataSet(Dataset):
             random.shuffle(self.data_lines)
         return
 
+    def get_ext_data(self):
+        ext_data_num = 0
+        for op in self.ops:
+            if hasattr(op, 'ext_data_num'):
+                ext_data_num = getattr(op, 'ext_data_num')
+                break
+        load_data_ops = self.ops[:2]
+        ext_data = []
+
+        while len(ext_data) < ext_data_num:
+            file_idx = self.data_idx_order_list[np.random.randint(self.__len__(
+            ))]
+            data_line = self.data_lines[file_idx]
+            data_line = data_line.decode('utf-8')
+            substr = data_line.strip("\n").split(self.delimiter)
+            file_name = substr[0]
+            label = substr[1]
+            img_path = os.path.join(self.data_dir, file_name)
+            data = {'img_path': img_path, 'label': label}
+            if not os.path.exists(img_path):
+                continue
+            with open(data['img_path'], 'rb') as f:
+                img = f.read()
+                data['image'] = img
+            data = transform(data, load_data_ops)
+            if data is None:
+                continue
+            ext_data.append(data)
+        return ext_data
+
     def __getitem__(self, idx):
         file_idx = self.data_idx_order_list[idx]
         data_line = self.data_lines[file_idx]
@@ -84,11 +115,13 @@ class SimpleDataSet(Dataset):
             with open(data['img_path'], 'rb') as f:
                 img = f.read()
                 data['image'] = img
+            data['ext_data'] = self.get_ext_data()
             outs = transform(data, self.ops)
-        except Exception as e:
+        except:
+            error_msg = traceback.format_exc()
             self.logger.error(
                 "When parsing line {}, error happened with msg: {}".format(
-                    data_line, e))
+                    data_line, error_msg))
             outs = None
         if outs is None:
             # during evaluation, we should fix the idx to get same results for many times of evaluation.
diff --git a/ppocr/modeling/backbones/__init__.py b/ppocr/modeling/backbones/__init__.py
index 13b70b203371b3be58ee82c6808d744bf6098333..f4fe8c76be0835f55f402f35ad6a91a5ca116d88 100755
--- a/ppocr/modeling/backbones/__init__.py
+++ b/ppocr/modeling/backbones/__init__.py
@@ -12,33 +12,36 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-__all__ = ['build_backbone']
+__all__ = ["build_backbone"]
 
 
 def build_backbone(config, model_type):
-    if model_type == 'det':
+    if model_type == "det":
         from .det_mobilenet_v3 import MobileNetV3
         from .det_resnet_vd import ResNet
         from .det_resnet_vd_sast import ResNet_SAST
-        support_dict = ['MobileNetV3', 'ResNet', 'ResNet_SAST']
-    elif model_type == 'rec' or model_type == 'cls':
+        support_dict = ["MobileNetV3", "ResNet", "ResNet_SAST"]
+    elif model_type == "rec" or model_type == "cls":
         from .rec_mobilenet_v3 import MobileNetV3
         from .rec_resnet_vd import ResNet
         from .rec_resnet_fpn import ResNetFPN
-        support_dict = ['MobileNetV3', 'ResNet', 'ResNetFPN']
-    elif model_type == 'e2e':
+        from .rec_mv1_enhance import MobileNetV1Enhance
+        support_dict = [
+            "MobileNetV1Enhance", "MobileNetV3", "ResNet", "ResNetFPN"
+        ]
+    elif model_type == "e2e":
         from .e2e_resnet_vd_pg import ResNet
-        support_dict = ['ResNet']
+        support_dict = ["ResNet"]
     elif model_type == "table":
         from .table_resnet_vd import ResNet
         from .table_mobilenet_v3 import MobileNetV3
-        support_dict = ['ResNet', 'MobileNetV3']
+        support_dict = ["ResNet", "MobileNetV3"]
     else:
         raise NotImplementedError
 
-    module_name = config.pop('name')
+    module_name = config.pop("name")
     assert module_name in support_dict, Exception(
-        'when model typs is {}, backbone only support {}'.format(model_type,
+        "when model type is {}, backbone only supports {}".format(model_type,
                                                                  support_dict))
     module_class = eval(module_name)(**config)
     return module_class
diff --git a/ppocr/modeling/backbones/rec_mv1_enhance.py b/ppocr/modeling/backbones/rec_mv1_enhance.py
new file mode 100644
index 0000000000000000000000000000000000000000..fe874fac1af439bfb47ba9050a61f02db302e224
--- /dev/null
+++ b/ppocr/modeling/backbones/rec_mv1_enhance.py
@@
-0,0 +1,256 @@ +# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np +import paddle +from paddle import ParamAttr +import paddle.nn as nn +import paddle.nn.functional as F +from paddle.nn import Conv2D, BatchNorm, Linear, Dropout +from paddle.nn import AdaptiveAvgPool2D, MaxPool2D, AvgPool2D +from paddle.nn.initializer import KaimingNormal +import math +import numpy as np +import paddle +from paddle import ParamAttr, reshape, transpose, concat, split +import paddle.nn as nn +import paddle.nn.functional as F +from paddle.nn import Conv2D, BatchNorm, Linear, Dropout +from paddle.nn import AdaptiveAvgPool2D, MaxPool2D, AvgPool2D +from paddle.nn.initializer import KaimingNormal +import math +from paddle.nn.functional import hardswish, hardsigmoid +from paddle.regularizer import L2Decay + + +class ConvBNLayer(nn.Layer): + def __init__(self, + num_channels, + filter_size, + num_filters, + stride, + padding, + channels=None, + num_groups=1, + act='hard_swish'): + super(ConvBNLayer, self).__init__() + + self._conv = Conv2D( + in_channels=num_channels, + out_channels=num_filters, + kernel_size=filter_size, + stride=stride, + padding=padding, + groups=num_groups, + weight_attr=ParamAttr(initializer=KaimingNormal()), + bias_attr=False) + + self._batch_norm = BatchNorm( + num_filters, + act=act, + param_attr=ParamAttr(regularizer=L2Decay(0.0)), + bias_attr=ParamAttr(regularizer=L2Decay(0.0))) + + def forward(self, inputs): + y = self._conv(inputs) + y = self._batch_norm(y) + return y + + +class DepthwiseSeparable(nn.Layer): + def __init__(self, + num_channels, + num_filters1, + num_filters2, + num_groups, + stride, + scale, + dw_size=3, + padding=1, + use_se=False): + super(DepthwiseSeparable, self).__init__() + self.use_se = use_se + self._depthwise_conv = ConvBNLayer( + num_channels=num_channels, + num_filters=int(num_filters1 * scale), + filter_size=dw_size, + stride=stride, + padding=padding, + num_groups=int(num_groups * scale)) + if use_se: + self._se = SEModule(int(num_filters1 * scale)) + self._pointwise_conv = ConvBNLayer( + num_channels=int(num_filters1 * scale), + filter_size=1, + num_filters=int(num_filters2 * scale), + stride=1, + padding=0) + + def forward(self, inputs): + y = self._depthwise_conv(inputs) + if self.use_se: + y = self._se(y) + y = self._pointwise_conv(y) + return y + + +class MobileNetV1Enhance(nn.Layer): + def __init__(self, in_channels=3, scale=0.5, **kwargs): + super().__init__() + self.scale = scale + self.block_list = [] + + self.conv1 = ConvBNLayer( + num_channels=3, + filter_size=3, + channels=3, + num_filters=int(32 * scale), + stride=2, + padding=1) + + conv2_1 = DepthwiseSeparable( + num_channels=int(32 * scale), + num_filters1=32, + num_filters2=64, + num_groups=32, + stride=1, + scale=scale) + self.block_list.append(conv2_1) + + conv2_2 = DepthwiseSeparable( + 
num_channels=int(64 * scale), + num_filters1=64, + num_filters2=128, + num_groups=64, + stride=1, + scale=scale) + self.block_list.append(conv2_2) + + conv3_1 = DepthwiseSeparable( + num_channels=int(128 * scale), + num_filters1=128, + num_filters2=128, + num_groups=128, + stride=1, + scale=scale) + self.block_list.append(conv3_1) + + conv3_2 = DepthwiseSeparable( + num_channels=int(128 * scale), + num_filters1=128, + num_filters2=256, + num_groups=128, + stride=(2, 1), + scale=scale) + self.block_list.append(conv3_2) + + conv4_1 = DepthwiseSeparable( + num_channels=int(256 * scale), + num_filters1=256, + num_filters2=256, + num_groups=256, + stride=1, + scale=scale) + self.block_list.append(conv4_1) + + conv4_2 = DepthwiseSeparable( + num_channels=int(256 * scale), + num_filters1=256, + num_filters2=512, + num_groups=256, + stride=(2, 1), + scale=scale) + self.block_list.append(conv4_2) + + for _ in range(5): + conv5 = DepthwiseSeparable( + num_channels=int(512 * scale), + num_filters1=512, + num_filters2=512, + num_groups=512, + stride=1, + dw_size=5, + padding=2, + scale=scale, + use_se=False) + self.block_list.append(conv5) + + conv5_6 = DepthwiseSeparable( + num_channels=int(512 * scale), + num_filters1=512, + num_filters2=1024, + num_groups=512, + stride=(2, 1), + dw_size=5, + padding=2, + scale=scale, + use_se=True) + self.block_list.append(conv5_6) + + conv6 = DepthwiseSeparable( + num_channels=int(1024 * scale), + num_filters1=1024, + num_filters2=1024, + num_groups=1024, + stride=1, + dw_size=5, + padding=2, + use_se=True, + scale=scale) + self.block_list.append(conv6) + + self.block_list = nn.Sequential(*self.block_list) + + self.pool = nn.MaxPool2D(kernel_size=2, stride=2, padding=0) + self.out_channels = int(1024 * scale) + + def forward(self, inputs): + y = self.conv1(inputs) + y = self.block_list(y) + y = self.pool(y) + return y + + +class SEModule(nn.Layer): + def __init__(self, channel, reduction=4): + super(SEModule, self).__init__() + self.avg_pool = AdaptiveAvgPool2D(1) + self.conv1 = Conv2D( + in_channels=channel, + out_channels=channel // reduction, + kernel_size=1, + stride=1, + padding=0, + weight_attr=ParamAttr(), + bias_attr=ParamAttr()) + self.conv2 = Conv2D( + in_channels=channel // reduction, + out_channels=channel, + kernel_size=1, + stride=1, + padding=0, + weight_attr=ParamAttr(), + bias_attr=ParamAttr()) + + def forward(self, inputs): + outputs = self.avg_pool(inputs) + outputs = self.conv1(outputs) + outputs = F.relu(outputs) + outputs = self.conv2(outputs) + outputs = hardsigmoid(outputs) + return paddle.multiply(x=inputs, y=outputs) diff --git a/ppocr/modeling/transforms/tps.py b/ppocr/modeling/transforms/tps.py index 78338edf67d69e32322912d75dec01ce1e63cb49..dcce6246ac64b4b84229cbd69a4dc53c658b4c7b 100644 --- a/ppocr/modeling/transforms/tps.py +++ b/ppocr/modeling/transforms/tps.py @@ -230,15 +230,8 @@ class GridGenerator(nn.Layer): def build_inv_delta_C_paddle(self, C): """ Return inv_delta_C which is needed to calculate T """ F = self.F - hat_C = paddle.zeros((F, F), dtype='float64') # F x F - for i in range(0, F): - for j in range(i, F): - if i == j: - hat_C[i, j] = 1 - else: - r = paddle.norm(C[i] - C[j]) - hat_C[i, j] = r - hat_C[j, i] = r + hat_eye = paddle.eye(F, dtype='float64') # F x F + hat_C = paddle.norm(C.reshape([1, F, 2]) - C.reshape([F, 1, 2]), axis=2) + hat_eye hat_C = (hat_C**2) * paddle.log(hat_C) delta_C = paddle.concat( # F+3 x F+3 [ diff --git a/ppocr/utils/save_load.py b/ppocr/utils/save_load.py index 
23f5401bb71a2ef50ff2ff2c3c27275d7e10b3c0..1d760e983a635dcc6b48b839ee99434c67b4378d 100644
--- a/ppocr/utils/save_load.py
+++ b/ppocr/utils/save_load.py
@@ -25,7 +25,7 @@ import paddle
 
 from ppocr.utils.logging import get_logger
 
-__all__ = ['init_model', 'save_model', 'load_dygraph_pretrain']
+__all__ = ['init_model', 'save_model', 'load_dygraph_params']
 
 
 def _mkdir_if_not_exist(path, logger):
@@ -89,6 +89,34 @@ def init_model(config, model, optimizer=None, lr_scheduler=None):
     return best_model_dict
 
 
+def load_dygraph_params(config, model, logger, optimizer):
+    ckp = config['Global']['checkpoints']
+    if ckp and os.path.exists(ckp + ".pdparams"):
+        pre_best_model_dict = init_model(config, model, optimizer)
+        return pre_best_model_dict
+    else:
+        pm = config['Global']['pretrained_model']
+        if pm is None:
+            return {}
+        if not os.path.exists(pm) and not os.path.exists(pm + ".pdparams"):
+            logger.info(f"The pretrained_model {pm} does not exist!")
+            return {}
+        pm = pm if pm.endswith('.pdparams') else pm + '.pdparams'
+        params = paddle.load(pm)
+        state_dict = model.state_dict()
+        new_state_dict = {}
+        for k1, k2 in zip(state_dict.keys(), params.keys()):
+            if list(state_dict[k1].shape) == list(params[k2].shape):
+                new_state_dict[k1] = params[k2]
+            else:
+                logger.info(
+                    f"The shape of model params {k1} {state_dict[k1].shape} not matched with loaded params {k2} {params[k2].shape} !"
+                )
+        model.set_state_dict(new_state_dict)
+        logger.info(f"loaded pretrained_model successfully from {pm}")
+        return {}
+
+
 def save_model(model,
                optimizer,
                model_path,
diff --git a/test/ocr_det_params.txt b/test/ocr_det_params.txt
new file mode 100644
index 0000000000000000000000000000000000000000..01ac82d3d7d459ca324ec61cfcaac2386660a211
--- /dev/null
+++ b/test/ocr_det_params.txt
@@ -0,0 +1,35 @@
+model_name:ocr_det
+python:python3.7
+gpu_list:0|0,1
+Global.auto_cast:False
+Global.epoch_num:10
+Global.save_model_dir:./output/
+Global.save_inference_dir:./output/
+Train.loader.batch_size_per_card:
+Global.use_gpu
+Global.pretrained_model
+
+trainer:norm|pact
+norm_train:tools/train.py -c configs/det/det_mv3_db.yml -o Global.pretrained_model=./pretrain_models/MobileNetV3_large_x0_5_pretrained
+quant_train:deploy/slim/quantization/quant.py -c configs/det/det_mv3_db.yml -o Global.pretrained_model=./pretrain_models/det_mv3_db_v2.0_train/best_accuracy
+fpgm_train:null
+distill_train:null
+
+eval:tools/eval.py -c configs/det/det_mv3_db.yml -o
+
+norm_export:tools/export_model.py -c configs/det/det_mv3_db.yml -o
+quant_export:deploy/slim/quantization/export_model.py -c configs/det/det_mv3_db.yml -o
+fpgm_export:deploy/slim/prune/export_prune_model.py
+distill_export:null
+
+inference:tools/infer/predict_det.py
+--use_gpu:True|False
+--enable_mkldnn:True|False
+--cpu_threads:1|6
+--rec_batch_num:1
+--use_tensorrt:True|False
+--precision:fp32|fp16|int8
+--det_model_dir
+--image_dir
+--save_log_path
+
diff --git a/test/prepare.sh b/test/prepare.sh
new file mode 100644
index 0000000000000000000000000000000000000000..150682469641a784f641313d361bb921d6d9dfb8
--- /dev/null
+++ b/test/prepare.sh
@@ -0,0 +1,138 @@
+#!/bin/bash
+FILENAME=$1
+# MODE must be one of ['lite_train_infer' 'whole_infer' 'whole_train_infer', 'infer']
+MODE=$2
+
+dataline=$(cat ${FILENAME})
+
+# parser params
+IFS=$'\n'
+lines=(${dataline})
+function func_parser_key(){
+    strs=$1
+    IFS=":"
+    array=(${strs})
+    tmp=${array[0]}
+    echo ${tmp}
+}
+function func_parser_value(){
+    strs=$1
+    IFS=":"
+    array=(${strs})
+    tmp=${array[1]}
+    echo ${tmp}
+}
+IFS=$'\n'
+# The training
params +model_name=$(func_parser_value "${lines[0]}") +train_model_list=$(func_parser_value "${lines[0]}") +trainer_list=$(func_parser_value "${lines[10]}") + +# MODE be one of ['lite_train_infer' 'whole_infer' 'whole_train_infer'] +MODE=$2 +# prepare pretrained weights and dataset +wget -nc -P ./pretrain_models/ https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/MobileNetV3_large_x0_5_pretrained.pdparams +wget -nc -P ./pretrain_models/ https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/det_mv3_db_v2.0_train.tar +cd pretrain_models && tar xf det_mv3_db_v2.0_train.tar && cd ../ + +if [ ${MODE} = "lite_train_infer" ];then + # pretrain lite train data + rm -rf ./train_data/icdar2015 + wget -nc -P ./train_data/ https://paddleocr.bj.bcebos.com/dygraph_v2.0/test/icdar2015_lite.tar + cd ./train_data/ && tar xf icdar2015_lite.tar + ln -s ./icdar2015_lite ./icdar2015 + cd ../ + epoch=10 + eval_batch_step=10 +elif [ ${MODE} = "whole_train_infer" ];then + rm -rf ./train_data/icdar2015 + wget -nc -P ./train_data/ https://paddleocr.bj.bcebos.com/dygraph_v2.0/test/icdar2015.tar + cd ./train_data/ && tar xf icdar2015.tar && cd ../ + epoch=500 + eval_batch_step=200 +elif [ ${MODE} = "whole_infer" ];then + rm -rf ./train_data/icdar2015 + wget -nc -P ./train_data/ https://paddleocr.bj.bcebos.com/dygraph_v2.0/test/icdar2015_infer.tar + cd ./train_data/ && tar xf icdar2015_infer.tar + ln -s ./icdar2015_infer ./icdar2015 + cd ../ + epoch=10 + eval_batch_step=10 +else + rm -rf ./train_data/icdar2015 + wget -nc -P ./train_data https://paddleocr.bj.bcebos.com/dygraph_v2.0/test/ch_det_data_50.tar + if [ ${model_name} = "ocr_det" ]; then + eval_model_name="ch_ppocr_mobile_v2.0_det_train" + wget -nc -P ./inference https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_mobile_v2.0_det_train.tar + cd ./inference && tar xf ${eval_model_name}.tar && cd ../ + else + eval_model_name="ch_ppocr_mobile_v2.0_rec_train" + wget -nc -P ./inference https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_mobile_v2.0_rec_train.tar + cd ./inference && tar xf ${eval_model_name}.tar && cd ../ + fi +fi + + +IFS='|' +for train_model in ${train_model_list[*]}; do + if [ ${train_model} = "ocr_det" ];then + model_name="ocr_det" + yml_file="configs/det/ch_ppocr_v2.0/ch_det_mv3_db_v2.0.yml" + wget -nc -P ./inference https://paddleocr.bj.bcebos.com/dygraph_v2.0/test/ch_det_data_50.tar + cd ./inference && tar xf ch_det_data_50.tar && cd ../ + img_dir="./inference/ch_det_data_50/all-sum-510" + data_dir=./inference/ch_det_data_50/ + data_label_file=[./inference/ch_det_data_50/test_gt_50.txt] + elif [ ${train_model} = "ocr_rec" ];then + model_name="ocr_rec" + yml_file="configs/rec/rec_mv3_none_bilstm_ctc.yml" + wget -nc -P ./inference https://paddleocr.bj.bcebos.com/dygraph_v2.0/test/ch_rec_data_200.tar + cd ./inference && tar xf ch_rec_data_200.tar && cd ../ + img_dir="./inference/ch_rec_data_200/" + fi + + # eval + for slim_trainer in ${trainer_list[*]}; do + if [ ${slim_trainer} = "norm" ]; then + if [ ${model_name} = "ocr_det" ]; then + eval_model_name="ch_ppocr_mobile_v2.0_det_train" + wget -nc -P ./inference https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_mobile_v2.0_det_train.tar + cd ./inference && tar xf ${eval_model_name}.tar && cd ../ + else + eval_model_name="ch_ppocr_mobile_v2.0_rec_train" + wget -nc -P ./inference https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_mobile_v2.0_rec_train.tar + cd ./inference && tar xf ${eval_model_name}.tar && cd ../ + fi + elif [ ${slim_trainer} = "pact" ]; then + if [ 
${model_name} = "ocr_det" ]; then + eval_model_name="ch_ppocr_mobile_v2.0_det_quant_train" + wget -nc -P ./inference https://paddleocr.bj.bcebos.com/dygraph_v2.0/slim/ch_ppocr_mobile_v2.0_det_quant_train.tar + cd ./inference && tar xf ${eval_model_name}.tar && cd ../ + else + eval_model_name="ch_ppocr_mobile_v2.0_rec_quant_train" + wget -nc -P ./inference https://paddleocr.bj.bcebos.com/dygraph_v2.0/slim/ch_ppocr_mobile_v2.0_rec_quant_train.tar + cd ./inference && tar xf ${eval_model_name}.tar && cd ../ + fi + elif [ ${slim_trainer} = "distill" ]; then + if [ ${model_name} = "ocr_det" ]; then + eval_model_name="ch_ppocr_mobile_v2.0_det_distill_train" + wget -nc -P ./inference https://paddleocr.bj.bcebos.com/dygraph_v2.0/slim/ch_ppocr_mobile_v2.0_det_distill_train.tar + cd ./inference && tar xf ${eval_model_name}.tar && cd ../ + else + eval_model_name="ch_ppocr_mobile_v2.0_rec_distill_train" + wget -nc -P ./inference https://paddleocr.bj.bcebos.com/dygraph_v2.0/slim/ch_ppocr_mobile_v2.0_rec_distill_train.tar + cd ./inference && tar xf ${eval_model_name}.tar && cd ../ + fi + elif [ ${slim_trainer} = "fpgm" ]; then + if [ ${model_name} = "ocr_det" ]; then + eval_model_name="ch_ppocr_mobile_v2.0_det_prune_train" + wget -nc -P ./inference https://paddleocr.bj.bcebos.com/dygraph_v2.0/slim/ch_ppocr_mobile_v2.0_det_prune_train.tar + cd ./inference && tar xf ${eval_model_name}.tar && cd ../ + else + eval_model_name="ch_ppocr_mobile_v2.0_rec_prune_train" + wget -nc -P ./inference https://paddleocr.bj.bcebos.com/dygraph_v2.0/slim/ch_ppocr_mobile_v2.0_rec_prune_train.tar + cd ./inference && tar xf ${eval_model_name}.tar && cd ../ + fi + fi + done +done diff --git a/test/test.sh b/test/test.sh new file mode 100644 index 0000000000000000000000000000000000000000..2a27563ffaa2b1b96b58cbff89546acf7a286210 --- /dev/null +++ b/test/test.sh @@ -0,0 +1,221 @@ +#!/bin/bash +FILENAME=$1 +# MODE be one of ['lite_train_infer' 'whole_infer' 'whole_train_infer', 'infer'] +MODE=$2 + +dataline=$(cat ${FILENAME}) + +# parser params +IFS=$'\n' +lines=(${dataline}) +function func_parser_key(){ + strs=$1 + IFS=":" + array=(${strs}) + tmp=${array[0]} + echo ${tmp} +} +function func_parser_value(){ + strs=$1 + IFS=":" + array=(${strs}) + tmp=${array[1]} + echo ${tmp} +} +function status_check(){ + last_status=$1 # the exit code + run_command=$2 + run_log=$3 + if [ $last_status -eq 0 ]; then + echo -e "\033[33m Run successfully with command - ${run_command}! \033[0m" | tee -a ${run_log} + else + echo -e "\033[33m Run failed with command - ${run_command}! 
\033[0m" | tee -a ${run_log} + fi +} + +IFS=$'\n' +# The training params +model_name=$(func_parser_value "${lines[0]}") +python=$(func_parser_value "${lines[1]}") +gpu_list=$(func_parser_value "${lines[2]}") +autocast_list=$(func_parser_value "${lines[3]}") +autocast_key=$(func_parser_key "${lines[3]}") +epoch_key=$(func_parser_key "${lines[4]}") +save_model_key=$(func_parser_key "${lines[5]}") +save_infer_key=$(func_parser_key "${lines[6]}") +train_batch_key=$(func_parser_key "${lines[7]}") +train_use_gpu_key=$(func_parser_key "${lines[8]}") +pretrain_model_key=$(func_parser_key "${lines[9]}") + +trainer_list=$(func_parser_value "${lines[10]}") +norm_trainer=$(func_parser_value "${lines[11]}") +pact_trainer=$(func_parser_value "${lines[12]}") +fpgm_trainer=$(func_parser_value "${lines[13]}") +distill_trainer=$(func_parser_value "${lines[14]}") + +eval_py=$(func_parser_value "${lines[15]}") +norm_export=$(func_parser_value "${lines[16]}") +pact_export=$(func_parser_value "${lines[17]}") +fpgm_export=$(func_parser_value "${lines[18]}") +distill_export=$(func_parser_value "${lines[19]}") + +inference_py=$(func_parser_value "${lines[20]}") +use_gpu_key=$(func_parser_key "${lines[21]}") +use_gpu_list=$(func_parser_value "${lines[21]}") +use_mkldnn_key=$(func_parser_key "${lines[22]}") +use_mkldnn_list=$(func_parser_value "${lines[22]}") +cpu_threads_key=$(func_parser_key "${lines[23]}") +cpu_threads_list=$(func_parser_value "${lines[23]}") +batch_size_key=$(func_parser_key "${lines[24]}") +batch_size_list=$(func_parser_value "${lines[24]}") +use_trt_key=$(func_parser_key "${lines[25]}") +use_trt_list=$(func_parser_value "${lines[25]}") +precision_key=$(func_parser_key "${lines[26]}") +precision_list=$(func_parser_value "${lines[26]}") +model_dir_key=$(func_parser_key "${lines[27]}") +image_dir_key=$(func_parser_key "${lines[28]}") +save_log_key=$(func_parser_key "${lines[29]}") + +LOG_PATH="./test/output" +mkdir -p ${LOG_PATH} +status_log="${LOG_PATH}/results.log" + +if [ ${MODE} = "lite_train_infer" ]; then + export infer_img_dir="./train_data/icdar2015/text_localization/ch4_test_images/" + export epoch_num=10 +elif [ ${MODE} = "whole_infer" ]; then + export infer_img_dir="./train_data/icdar2015/text_localization/ch4_test_images/" + export epoch_num=10 +elif [ ${MODE} = "whole_train_infer" ]; then + export infer_img_dir="./train_data/icdar2015/text_localization/ch4_test_images/" + export epoch_num=300 +else + export infer_img_dir="./inference/ch_det_data_50/all-sum-510" + export infer_model_dir="./inference/ch_ppocr_mobile_v2.0_det_train/best_accuracy" +fi + + +function func_inference(){ + IFS='|' + _python=$1 + _script=$2 + _model_dir=$3 + _log_path=$4 + _img_dir=$5 + + # inference + for use_gpu in ${use_gpu_list[*]}; do + if [ ${use_gpu} = "False" ]; then + for use_mkldnn in ${use_mkldnn_list[*]}; do + for threads in ${cpu_threads_list[*]}; do + for batch_size in ${batch_size_list[*]}; do + _save_log_path="${_log_path}/infer_cpu_usemkldnn_${use_mkldnn}_threads_${threads}_batchsize_${batch_size}" + command="${_python} ${_script} ${use_gpu_key}=${use_gpu} ${use_mkldnn_key}=${use_mkldnn} ${cpu_threads_key}=${threads} ${model_dir_key}=${_model_dir} ${batch_size_key}=${batch_size} ${image_dir_key}=${_img_dir} ${save_log_key}=${_save_log_path} --benchmark=True" + eval $command + status_check $? 
"${command}" "${status_log}" + done + done + done + else + for use_trt in ${use_trt_list[*]}; do + for precision in ${precision_list[*]}; do + if [ ${use_trt} = "False" ] && [ ${precision} != "fp32" ]; then + continue + fi + for batch_size in ${batch_size_list[*]}; do + _save_log_path="${_log_path}/infer_gpu_usetrt_${use_trt}_precision_${precision}_batchsize_${batch_size}" + command="${_python} ${_script} ${use_gpu_key}=${use_gpu} ${use_trt_key}=${use_trt} ${precision_key}=${precision} ${model_dir_key}=${_model_dir} ${batch_size_key}=${batch_size} ${image_dir_key}=${_img_dir} ${save_log_key}=${_save_log_path} --benchmark=True" + eval $command + status_check $? "${command}" "${status_log}" + done + done + done + fi + done +} + +if [ ${MODE} != "infer" ]; then + +IFS="|" +for gpu in ${gpu_list[*]}; do + train_use_gpu=True + if [ ${gpu} = "-1" ];then + train_use_gpu=False + env="" + elif [ ${#gpu} -le 1 ];then + env="export CUDA_VISIBLE_DEVICES=${gpu}" + elif [ ${#gpu} -le 15 ];then + IFS="," + array=(${gpu}) + env="export CUDA_VISIBLE_DEVICES=${array[0]}" + IFS="|" + else + IFS=";" + array=(${gpu}) + ips=${array[0]} + gpu=${array[1]} + IFS="|" + fi + for autocast in ${autocast_list[*]}; do + for trainer in ${trainer_list[*]}; do + if [ ${trainer} = "pact" ]; then + run_train=${pact_trainer} + run_export=${pact_export} + elif [ ${trainer} = "fpgm" ]; then + run_train=${fpgm_trainer} + run_export=${fpgm_export} + elif [ ${trainer} = "distill" ]; then + run_train=${distill_trainer} + run_export=${distill_export} + else + run_train=${norm_trainer} + run_export=${norm_export} + fi + + if [ ${run_train} = "null" ]; then + continue + fi + if [ ${run_export} = "null" ]; then + continue + fi + + save_log="${LOG_PATH}/${trainer}_gpus_${gpu}_autocast_${autocast}" + if [ ${#gpu} -le 2 ];then # epoch_num #TODO + cmd="${python} ${run_train} ${train_use_gpu_key}=${train_use_gpu} ${autocast_key}=${autocast} ${epoch_key}=${epoch_num} ${save_model_key}=${save_log} " + elif [ ${#gpu} -le 15 ];then + cmd="${python} -m paddle.distributed.launch --gpus=${gpu} ${run_train} ${autocast_key}=${autocast} ${epoch_key}=${epoch_num} ${save_model_key}=${save_log}" + else + cmd="${python} -m paddle.distributed.launch --ips=${ips} --gpus=${gpu} ${run_train} ${autocast_key}=${autocast} ${epoch_key}=${epoch_num} ${save_model_key}=${save_log}" + fi + # run train + eval $cmd + status_check $? "${cmd}" "${status_log}" + + # run eval + eval_cmd="${python} ${eval_py} ${save_model_key}=${save_log} ${pretrain_model_key}=${save_log}/latest" + eval $eval_cmd + status_check $? "${eval_cmd}" "${status_log}" + + # run export model + save_infer_path="${save_log}" + export_cmd="${python} ${run_export} ${save_model_key}=${save_log} ${pretrain_model_key}=${save_log}/latest ${save_infer_key}=${save_infer_path}" + eval $export_cmd + status_check $? "${export_cmd}" "${status_log}" + + #run inference + save_infer_path="${save_log}" + func_inference "${python}" "${inference_py}" "${save_infer_path}" "${LOG_PATH}" "${infer_img_dir}" + done + done +done + +else + save_infer_path="${LOG_PATH}/${MODE}" + run_export=${norm_export} + export_cmd="${python} ${run_export} ${save_model_key}=${save_infer_path} ${pretrain_model_key}=${infer_model_dir} ${save_infer_key}=${save_infer_path}" + eval $export_cmd + status_check $? 
"${export_cmd}" "${status_log}" + + #run inference + func_inference "${python}" "${inference_py}" "${save_infer_path}" "${LOG_PATH}" "${infer_img_dir}" +fi diff --git a/test1/MANIFEST.in b/test1/MANIFEST.in index 2961e722b7cebe8e1912be2dd903fcdecb694019..203e97d3104e5be44a4e58f87fbae08d59d3f537 100644 --- a/test1/MANIFEST.in +++ b/test1/MANIFEST.in @@ -5,5 +5,5 @@ recursive-include ppocr/utils *.txt utility.py logging.py network.py recursive-include ppocr/data/ *.py recursive-include ppocr/postprocess *.py recursive-include tools/infer *.py -recursive-include ppstructure *.py +recursive-include test1 *.py diff --git a/test1/paddlestructure.py b/test1/paddlestructure.py index d8199101bb2d97b7bad063c4bec66eeea656c1fa..bbe69d40460b72def4e4098001638414f80e4f19 100644 --- a/test1/paddlestructure.py +++ b/test1/paddlestructure.py @@ -146,23 +146,3 @@ def main(): logger.info(item['res']) save_res(result, save_folder, img_name) logger.info('result save to {}'.format(os.path.join(save_folder, img_name))) - -if __name__ == '__main__': - table_engine = PaddleStructure(show_log=True) - - img_path = '../test/test_imgs/PMC1173095_006_00.png' - img = cv2.imread(img_path) - result = table_engine(img) - save_res(result, '/Users/zhoujun20/Desktop/工作相关/table/table_pr/PaddleOCR/output/table', - os.path.basename(img_path).split('.')[0]) - - for line in result: - print(line) - - from PIL import Image - - font_path = '../doc/fonts/simfang.ttf' - image = Image.open(img_path).convert('RGB') - im_show = draw_result(image, result, font_path=font_path) - im_show = Image.fromarray(im_show) - im_show.save('result.jpg') \ No newline at end of file diff --git a/test1/setup.py b/test1/setup.py index 0b092c49a4db98def28a7c2942993806b0ffc27c..1d07517790716717d65088ccd854a2557adc8888 100644 --- a/test1/setup.py +++ b/test1/setup.py @@ -20,8 +20,6 @@ import shutil with open('../requirements.txt', encoding="utf-8-sig") as f: requirements = f.readlines() requirements.append('tqdm') - requirements.append('layoutparser') - requirements.append('iopath') def readme(): diff --git a/tools/infer/benchmark_utils.py b/tools/infer/benchmark_utils.py deleted file mode 100644 index 1a241d063368d19567e253bf1dada09801d468bc..0000000000000000000000000000000000000000 --- a/tools/infer/benchmark_utils.py +++ /dev/null @@ -1,232 +0,0 @@ -# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import argparse -import os -import time -import logging - -import paddle -import paddle.inference as paddle_infer - -from pathlib import Path - -CUR_DIR = os.path.dirname(os.path.abspath(__file__)) - - -class PaddleInferBenchmark(object): - def __init__(self, - config, - model_info: dict={}, - data_info: dict={}, - perf_info: dict={}, - resource_info: dict={}, - save_log_path: str="", - **kwargs): - """ - Construct PaddleInferBenchmark Class to format logs. 
- args: - config(paddle.inference.Config): paddle inference config - model_info(dict): basic model info - {'model_name': 'resnet50' - 'precision': 'fp32'} - data_info(dict): input data info - {'batch_size': 1 - 'shape': '3,224,224' - 'data_num': 1000} - perf_info(dict): performance result - {'preprocess_time_s': 1.0 - 'inference_time_s': 2.0 - 'postprocess_time_s': 1.0 - 'total_time_s': 4.0} - resource_info(dict): - cpu and gpu resources - {'cpu_rss': 100 - 'gpu_rss': 100 - 'gpu_util': 60} - """ - # PaddleInferBenchmark Log Version - self.log_version = 1.0 - - # Paddle Version - self.paddle_version = paddle.__version__ - self.paddle_commit = paddle.__git_commit__ - paddle_infer_info = paddle_infer.get_version() - self.paddle_branch = paddle_infer_info.strip().split(': ')[-1] - - # model info - self.model_info = model_info - - # data info - self.data_info = data_info - - # perf info - self.perf_info = perf_info - - try: - self.model_name = model_info['model_name'] - self.precision = model_info['precision'] - - self.batch_size = data_info['batch_size'] - self.shape = data_info['shape'] - self.data_num = data_info['data_num'] - - self.preprocess_time_s = round(perf_info['preprocess_time_s'], 4) - self.inference_time_s = round(perf_info['inference_time_s'], 4) - self.postprocess_time_s = round(perf_info['postprocess_time_s'], 4) - self.total_time_s = round(perf_info['total_time_s'], 4) - except: - self.print_help() - raise ValueError( - "Set argument wrong, please check input argument and its type") - - # conf info - self.config_status = self.parse_config(config) - self.save_log_path = save_log_path - # mem info - if isinstance(resource_info, dict): - self.cpu_rss_mb = int(resource_info.get('cpu_rss_mb', 0)) - self.gpu_rss_mb = int(resource_info.get('gpu_rss_mb', 0)) - self.gpu_util = round(resource_info.get('gpu_util', 0), 2) - else: - self.cpu_rss_mb = 0 - self.gpu_rss_mb = 0 - self.gpu_util = 0 - - # init benchmark logger - self.benchmark_logger() - - def benchmark_logger(self): - """ - benchmark logger - """ - # Init logger - FORMAT = '%(asctime)s - %(name)s - %(levelname)s - %(message)s' - log_output = f"{self.save_log_path}/{self.model_name}.log" - Path(f"{self.save_log_path}").mkdir(parents=True, exist_ok=True) - logging.basicConfig( - level=logging.INFO, - format=FORMAT, - handlers=[ - logging.FileHandler( - filename=log_output, mode='w'), - logging.StreamHandler(), - ]) - self.logger = logging.getLogger(__name__) - self.logger.info( - f"Paddle Inference benchmark log will be saved to {log_output}") - - def parse_config(self, config) -> dict: - """ - parse paddle predictor config - args: - config(paddle.inference.Config): paddle inference config - return: - config_status(dict): dict style config info - """ - config_status = {} - config_status['runtime_device'] = "gpu" if config.use_gpu() else "cpu" - config_status['ir_optim'] = config.ir_optim() - config_status['enable_tensorrt'] = config.tensorrt_engine_enabled() - config_status['precision'] = self.precision - config_status['enable_mkldnn'] = config.mkldnn_enabled() - config_status[ - 'cpu_math_library_num_threads'] = config.cpu_math_library_num_threads( - ) - return config_status - - def report(self, identifier=None): - """ - print log report - args: - identifier(string): identify log - """ - if identifier: - identifier = f"[{identifier}]" - else: - identifier = "" - - self.logger.info("\n") - self.logger.info( - "---------------------- Paddle info ----------------------") - self.logger.info(f"{identifier} paddle_version: 
{self.paddle_version}") - self.logger.info(f"{identifier} paddle_commit: {self.paddle_commit}") - self.logger.info(f"{identifier} paddle_branch: {self.paddle_branch}") - self.logger.info(f"{identifier} log_api_version: {self.log_version}") - self.logger.info( - "----------------------- Conf info -----------------------") - self.logger.info( - f"{identifier} runtime_device: {self.config_status['runtime_device']}" - ) - self.logger.info( - f"{identifier} ir_optim: {self.config_status['ir_optim']}") - self.logger.info(f"{identifier} enable_memory_optim: {True}") - self.logger.info( - f"{identifier} enable_tensorrt: {self.config_status['enable_tensorrt']}" - ) - self.logger.info( - f"{identifier} enable_mkldnn: {self.config_status['enable_mkldnn']}") - self.logger.info( - f"{identifier} cpu_math_library_num_threads: {self.config_status['cpu_math_library_num_threads']}" - ) - self.logger.info( - "----------------------- Model info ----------------------") - self.logger.info(f"{identifier} model_name: {self.model_name}") - self.logger.info(f"{identifier} precision: {self.precision}") - self.logger.info( - "----------------------- Data info -----------------------") - self.logger.info(f"{identifier} batch_size: {self.batch_size}") - self.logger.info(f"{identifier} input_shape: {self.shape}") - self.logger.info(f"{identifier} data_num: {self.data_num}") - self.logger.info( - "----------------------- Perf info -----------------------") - self.logger.info( - f"{identifier} cpu_rss(MB): {self.cpu_rss_mb}, gpu_rss(MB): {self.gpu_rss_mb}, gpu_util: {self.gpu_util}%" - ) - self.logger.info( - f"{identifier} total time spent(s): {self.total_time_s}") - self.logger.info( - f"{identifier} preprocess_time(ms): {round(self.preprocess_time_s*1000, 1)}, inference_time(ms): {round(self.inference_time_s*1000, 1)}, postprocess_time(ms): {round(self.postprocess_time_s*1000, 1)}" - ) - - def print_help(self): - """ - print function help - """ - print("""Usage: - ==== Print inference benchmark logs. 
==== - config = paddle.inference.Config() - model_info = {'model_name': 'resnet50' - 'precision': 'fp32'} - data_info = {'batch_size': 1 - 'shape': '3,224,224' - 'data_num': 1000} - perf_info = {'preprocess_time_s': 1.0 - 'inference_time_s': 2.0 - 'postprocess_time_s': 1.0 - 'total_time_s': 4.0} - resource_info = {'cpu_rss_mb': 100 - 'gpu_rss_mb': 100 - 'gpu_util': 60} - log = PaddleInferBenchmark(config, model_info, data_info, perf_info, resource_info) - log('Test') - """) - - def __call__(self, identifier=None): - """ - __call__ - args: - identifier(string): identify log - """ - self.report(identifier) diff --git a/tools/infer/predict_cls.py b/tools/infer/predict_cls.py index 0037b226df8e1de8edbdb7668e349925a942e8b9..1a0cf3567305f22c07b547bdc56dbdfc5a88d737 100755 --- a/tools/infer/predict_cls.py +++ b/tools/infer/predict_cls.py @@ -48,8 +48,6 @@ class TextClassifier(object): self.predictor, self.input_tensor, self.output_tensors, _ = \ utility.create_predictor(args, 'cls', logger) - self.cls_times = utility.Timer() - def resize_norm_img(self, img): imgC, imgH, imgW = self.cls_image_shape h = img.shape[0] @@ -85,35 +83,28 @@ class TextClassifier(object): cls_res = [['', 0.0]] * img_num batch_num = self.cls_batch_num elapse = 0 - self.cls_times.total_time.start() for beg_img_no in range(0, img_num, batch_num): end_img_no = min(img_num, beg_img_no + batch_num) norm_img_batch = [] max_wh_ratio = 0 + starttime = time.time() for ino in range(beg_img_no, end_img_no): h, w = img_list[indices[ino]].shape[0:2] wh_ratio = w * 1.0 / h max_wh_ratio = max(max_wh_ratio, wh_ratio) - self.cls_times.preprocess_time.start() for ino in range(beg_img_no, end_img_no): norm_img = self.resize_norm_img(img_list[indices[ino]]) norm_img = norm_img[np.newaxis, :] norm_img_batch.append(norm_img) norm_img_batch = np.concatenate(norm_img_batch) norm_img_batch = norm_img_batch.copy() - starttime = time.time() - self.cls_times.preprocess_time.end() - self.cls_times.inference_time.start() self.input_tensor.copy_from_cpu(norm_img_batch) self.predictor.run() prob_out = self.output_tensors[0].copy_to_cpu() - self.cls_times.inference_time.end() - self.cls_times.postprocess_time.start() self.predictor.try_shrink_memory() cls_result = self.postprocess_op(prob_out) - self.cls_times.postprocess_time.end() elapse += time.time() - starttime for rno in range(len(cls_result)): label, score = cls_result[rno] @@ -121,9 +112,7 @@ class TextClassifier(object): if '180' in label and score > self.cls_thresh: img_list[indices[beg_img_no + rno]] = cv2.rotate( img_list[indices[beg_img_no + rno]], 1) - self.cls_times.total_time.end() - self.cls_times.img_num += img_num - elapse = self.cls_times.total_time.value() + elapse = time.time() - starttime return img_list, cls_res, elapse diff --git a/tools/infer/predict_det.py b/tools/infer/predict_det.py index 1b52e717ca1edbbd9ede31260e47ec8973270d3f..bbf3659cbc6c34e550ba08440312a09da6362df0 100755 --- a/tools/infer/predict_det.py +++ b/tools/infer/predict_det.py @@ -31,8 +31,6 @@ from ppocr.utils.utility import get_image_file_list, check_and_read_gif from ppocr.data import create_operators, transform from ppocr.postprocess import build_post_process -# import tools.infer.benchmark_utils as benchmark_utils - logger = get_logger() @@ -100,6 +98,24 @@ class TextDetector(object): self.predictor, self.input_tensor, self.output_tensors, self.config = utility.create_predictor( args, 'det', logger) + if args.benchmark: + import auto_log + pid = os.getpid() + self.autolog = auto_log.AutoLogger( + 
diff --git a/tools/infer/predict_det.py b/tools/infer/predict_det.py
index 1b52e717ca1edbbd9ede31260e47ec8973270d3f..bbf3659cbc6c34e550ba08440312a09da6362df0 100755
--- a/tools/infer/predict_det.py
+++ b/tools/infer/predict_det.py
@@ -31,8 +31,6 @@ from ppocr.utils.utility import get_image_file_list, check_and_read_gif
 from ppocr.data import create_operators, transform
 from ppocr.postprocess import build_post_process
 
-# import tools.infer.benchmark_utils as benchmark_utils
-
 logger = get_logger()
 
 
@@ -100,6 +98,25 @@ class TextDetector(object):
         self.predictor, self.input_tensor, self.output_tensors, self.config = utility.create_predictor(
             args, 'det', logger)
 
+        self.args = args  # kept so __call__ works when imported from predict_system.py
+        if args.benchmark:
+            import auto_log
+            pid = os.getpid()
+            self.autolog = auto_log.AutoLogger(
+                model_name="det",
+                model_precision=args.precision,
+                batch_size=1,
+                data_shape="dynamic",
+                save_path=args.save_log_path,
+                inference_config=self.config,
+                pids=pid,
+                process_name=None,
+                gpu_ids=0,
+                time_keys=[
+                    'preprocess_time', 'inference_time', 'postprocess_time'
+                ],
+                warmup=10)
+
     def order_points_clockwise(self, pts):
         """
         reference from: https://github.com/jrosebr1/imutils/blob/master/imutils/perspective.py
@@ -158,6 +175,10 @@ class TextDetector(object):
         data = {'image': img}
 
         st = time.time()
+
+        if self.args.benchmark:
+            self.autolog.times.start()
+
         data = transform(data, self.preprocess_op)
         img, shape_list = data
         if img is None:
@@ -166,12 +187,17 @@ class TextDetector(object):
         shape_list = np.expand_dims(shape_list, axis=0)
         img = img.copy()
 
+        if self.args.benchmark:
+            self.autolog.times.stamp()
+
         self.input_tensor.copy_from_cpu(img)
         self.predictor.run()
         outputs = []
         for output_tensor in self.output_tensors:
             output = output_tensor.copy_to_cpu()
             outputs.append(output)
+        if self.args.benchmark:
+            self.autolog.times.stamp()
 
         preds = {}
         if self.det_algorithm == "EAST":
@@ -187,7 +213,7 @@ class TextDetector(object):
         else:
             raise NotImplementedError
 
-        self.predictor.try_shrink_memory()
+        # self.predictor.try_shrink_memory()
         post_result = self.postprocess_op(preds, shape_list)
         dt_boxes = post_result[0]['points']
         if self.det_algorithm == "SAST" and self.det_sast_polygon:
@@ -195,6 +221,8 @@ class TextDetector(object):
         else:
             dt_boxes = self.filter_tag_det_res(dt_boxes, ori_im.shape)
 
+        if self.args.benchmark:
+            self.autolog.times.end(stamp=True)
         et = time.time()
         return dt_boxes, et - st
 
@@ -212,8 +240,6 @@ if __name__ == "__main__":
     for i in range(10):
         res = text_detector(img)
 
-    cpu_mem, gpu_mem, gpu_util = 0, 0, 0
-
     if not os.path.exists(draw_img_save):
         os.makedirs(draw_img_save)
     for image_file in image_file_list:
@@ -237,3 +263,6 @@ if __name__ == "__main__":
                     "det_res_{}".format(img_name_pure))
         cv2.imwrite(img_path, src_im)
         logger.info("The visualized image saved in {}".format(img_path))
+
+    if args.benchmark:
+        text_detector.autolog.report()
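The detector now hands timing to the external auto_log package. The constructor keywords below mirror the ones added above, but the values are placeholders; the runtime calls follow the start/stamp/end protocol the hunks insert around preprocess, inference, and postprocess. A condensed sketch:

    import os
    import auto_log  # external AutoLog package, as imported by the patch

    autolog = auto_log.AutoLogger(
        model_name="det", model_precision="fp32", batch_size=1,
        data_shape="dynamic", save_path="./log", inference_config=None,
        pids=os.getpid(), process_name=None, gpu_ids=0,
        time_keys=['preprocess_time', 'inference_time', 'postprocess_time'],
        warmup=10)

    autolog.times.start()          # preprocess window opens
    # ... preprocessing ...
    autolog.times.stamp()          # preprocess ends, inference begins
    # ... predictor.run() ...
    autolog.times.stamp()          # inference ends, postprocess begins
    # ... postprocessing ...
    autolog.times.end(stamp=True)  # close the final window
    autolog.report()               # print the aggregated benchmark log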
diff --git a/tools/infer/predict_rec.py b/tools/infer/predict_rec.py
index 0d847046530c02c9b0591bb4b379fd7ddeac1263..cf671403dcd254310774a4dd27d3c6c5e9fb0886 100755
--- a/tools/infer/predict_rec.py
+++ b/tools/infer/predict_rec.py
@@ -28,7 +28,6 @@ import traceback
 import paddle
 
 import tools.infer.utility as utility
-import tools.infer.benchmark_utils as benchmark_utils
 from ppocr.postprocess import build_post_process
 from ppocr.utils.logging import get_logger
 from ppocr.utils.utility import get_image_file_list, check_and_read_gif
@@ -66,8 +65,6 @@ class TextRecognizer(object):
         self.predictor, self.input_tensor, self.output_tensors, self.config = \
             utility.create_predictor(args, 'rec', logger)
 
-        self.rec_times = utility.Timer()
-
     def resize_norm_img(self, img, max_wh_ratio):
         imgC, imgH, imgW = self.rec_image_shape
         assert imgC == img.shape[2]
@@ -168,14 +165,13 @@ class TextRecognizer(object):
             width_list.append(img.shape[1] / float(img.shape[0]))
         # Sorting can speed up the recognition process
         indices = np.argsort(np.array(width_list))
-        self.rec_times.total_time.start()
         rec_res = [['', 0.0]] * img_num
         batch_num = self.rec_batch_num
+        st = time.time()
         for beg_img_no in range(0, img_num, batch_num):
             end_img_no = min(img_num, beg_img_no + batch_num)
             norm_img_batch = []
             max_wh_ratio = 0
-            self.rec_times.preprocess_time.start()
             for ino in range(beg_img_no, end_img_no):
                 h, w = img_list[indices[ino]].shape[0:2]
                 wh_ratio = w * 1.0 / h
@@ -216,23 +212,18 @@ class TextRecognizer(object):
                     gsrm_slf_attn_bias1_list,
                     gsrm_slf_attn_bias2_list,
                 ]
-                self.rec_times.preprocess_time.end()
-                self.rec_times.inference_time.start()
                 input_names = self.predictor.get_input_names()
                 for i in range(len(input_names)):
                     input_tensor = self.predictor.get_input_handle(input_names[
                         i])
                     input_tensor.copy_from_cpu(inputs[i])
                 self.predictor.run()
-                self.rec_times.inference_time.end()
                 outputs = []
                 for output_tensor in self.output_tensors:
                     output = output_tensor.copy_to_cpu()
                     outputs.append(output)
                 preds = {"predict": outputs[2]}
             else:
-                self.rec_times.preprocess_time.end()
-                self.rec_times.inference_time.start()
                 self.input_tensor.copy_from_cpu(norm_img_batch)
                 self.predictor.run()
 
@@ -241,15 +232,11 @@ class TextRecognizer(object):
                     output = output_tensor.copy_to_cpu()
                     outputs.append(output)
                 preds = outputs[0]
-                self.rec_times.inference_time.end()
-
-            self.rec_times.postprocess_time.start()
             rec_result = self.postprocess_op(preds)
             for rno in range(len(rec_result)):
                 rec_res[indices[beg_img_no + rno]] = rec_result[rno]
-            self.rec_times.postprocess_time.end()
-            self.rec_times.img_num += int(norm_img_batch.shape[0])
-        self.rec_times.total_time.end()
-        return rec_res, self.rec_times.total_time.value()
+
+        return rec_res, time.time() - st
 
 
 def main(args):
@@ -278,12 +265,6 @@ def main(args):
             img_list.append(img)
     try:
         rec_res, _ = text_recognizer(img_list)
-        if args.benchmark:
-            cm, gm, gu = utility.get_current_memory_mb(0)
-            cpu_mem += cm
-            gpu_mem += gm
-            gpu_util += gu
-            count += 1
 
     except Exception as E:
         logger.info(traceback.format_exc())
@@ -292,38 +273,6 @@ def main(args):
     for ino in range(len(img_list)):
         logger.info("Predicts of {}:{}".format(valid_image_file_list[ino],
                                                rec_res[ino]))
-    if args.benchmark:
-        mems = {
-            'cpu_rss_mb': cpu_mem / count,
-            'gpu_rss_mb': gpu_mem / count,
-            'gpu_util': gpu_util * 100 / count
-        }
-    else:
-        mems = None
-    logger.info("The predict time about recognizer module is as follows: ")
-    rec_time_dict = text_recognizer.rec_times.report(average=True)
-    rec_model_name = args.rec_model_dir
-
-    if args.benchmark:
-        # construct log information
-        model_info = {
-            'model_name': args.rec_model_dir.split('/')[-1],
-            'precision': args.precision
-        }
-        data_info = {
-            'batch_size': args.rec_batch_num,
-            'shape': 'dynamic_shape',
-            'data_num': rec_time_dict['img_num']
-        }
-        perf_info = {
-            'preprocess_time_s': rec_time_dict['preprocess_time'],
-            'inference_time_s': rec_time_dict['inference_time'],
-            'postprocess_time_s': rec_time_dict['postprocess_time'],
-            'total_time_s': rec_time_dict['total_time']
-        }
-        benchmark_log = benchmark_utils.PaddleInferBenchmark(
-            text_recognizer.config, model_info, data_info, perf_info, mems)
-        benchmark_log("Rec")
 
 
 if __name__ == "__main__":
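With the Timer gone, a caller of TextRecognizer gets the recognition results plus a single wall-clock elapsed figure. A minimal usage sketch (the image paths are placeholders):

    import cv2
    import tools.infer.utility as utility
    from tools.infer.predict_rec import TextRecognizer

    args = utility.parse_args()  # e.g. --rec_model_dir=./inference/rec/
    text_recognizer = TextRecognizer(args)

    img_list = [cv2.imread(p) for p in ['word_1.png', 'word_2.png']]
    rec_res, elapse = text_recognizer(img_list)  # elapse = time.time() - st
    for text, score in rec_res:
        print(text, score)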
diff --git a/tools/infer/predict_system.py b/tools/infer/predict_system.py
index c008f9679684e2433859cd104261aeff56b410a2..715bd3fa9d596dd60f7f789f3e367734ffec608b 100755
--- a/tools/infer/predict_system.py
+++ b/tools/infer/predict_system.py
@@ -33,8 +33,7 @@ import tools.infer.predict_det as predict_det
 import tools.infer.predict_cls as predict_cls
 from ppocr.utils.utility import get_image_file_list, check_and_read_gif
 from ppocr.utils.logging import get_logger
-from tools.infer.utility import draw_ocr_box_txt, get_current_memory_mb
-import tools.infer.benchmark_utils as benchmark_utils
+from tools.infer.utility import draw_ocr_box_txt, get_rotate_crop_image
 
 logger = get_logger()
 
@@ -50,39 +49,6 @@ class TextSystem(object):
         if self.use_angle_cls:
             self.text_classifier = predict_cls.TextClassifier(args)
 
-    def get_rotate_crop_image(self, img, points):
-        '''
-        img_height, img_width = img.shape[0:2]
-        left = int(np.min(points[:, 0]))
-        right = int(np.max(points[:, 0]))
-        top = int(np.min(points[:, 1]))
-        bottom = int(np.max(points[:, 1]))
-        img_crop = img[top:bottom, left:right, :].copy()
-        points[:, 0] = points[:, 0] - left
-        points[:, 1] = points[:, 1] - top
-        '''
-        img_crop_width = int(
-            max(
-                np.linalg.norm(points[0] - points[1]),
-                np.linalg.norm(points[2] - points[3])))
-        img_crop_height = int(
-            max(
-                np.linalg.norm(points[0] - points[3]),
-                np.linalg.norm(points[1] - points[2])))
-        pts_std = np.float32([[0, 0], [img_crop_width, 0],
-                              [img_crop_width, img_crop_height],
-                              [0, img_crop_height]])
-        M = cv2.getPerspectiveTransform(points, pts_std)
-        dst_img = cv2.warpPerspective(
-            img,
-            M, (img_crop_width, img_crop_height),
-            borderMode=cv2.BORDER_REPLICATE,
-            flags=cv2.INTER_CUBIC)
-        dst_img_height, dst_img_width = dst_img.shape[0:2]
-        if dst_img_height * 1.0 / dst_img_width >= 1.5:
-            dst_img = np.rot90(dst_img)
-        return dst_img
-
     def print_draw_crop_rec_res(self, img_crop_list, rec_res):
         bbox_num = len(img_crop_list)
         for bno in range(bbox_num):
@@ -103,7 +69,7 @@ class TextSystem(object):
 
         for bno in range(len(dt_boxes)):
             tmp_box = copy.deepcopy(dt_boxes[bno])
-            img_crop = self.get_rotate_crop_image(ori_im, tmp_box)
+            img_crop = get_rotate_crop_image(ori_im, tmp_box)
             img_crop_list.append(img_crop)
         if self.use_angle_cls and cls:
             img_crop_list, angle_list, elapse = self.text_classifier(
@@ -158,7 +124,7 @@ def main(args):
     img = np.random.uniform(0, 255, [640, 640, 3]).astype(np.uint8)
     for i in range(10):
         res = text_sys(img)
-
+    total_time = 0
     cpu_mem, gpu_mem, gpu_util = 0, 0, 0
     _st = time.time()
 
@@ -175,12 +141,6 @@ def main(args):
             dt_boxes, rec_res = text_sys(img)
         elapse = time.time() - starttime
         total_time += elapse
-        if args.benchmark and idx % 20 == 0:
-            cm, gm, gu = get_current_memory_mb(0)
-            cpu_mem += cm
-            gpu_mem += gm
-            gpu_util += gu
-            count += 1
 
         logger.info(
             str(idx) + "  Predict time of %s: %.3fs" % (image_file, elapse))
@@ -215,61 +175,5 @@ def main(args):
 
     logger.info("\nThe predict total time is {}".format(total_time))
-    img_num = text_sys.text_detector.det_times.img_num
-    if args.benchmark:
-        mems = {
-            'cpu_rss_mb': cpu_mem / count,
-            'gpu_rss_mb': gpu_mem / count,
-            'gpu_util': gpu_util * 100 / count
-        }
-    else:
-        mems = None
-    det_time_dict = text_sys.text_detector.det_times.report(average=True)
-    rec_time_dict = text_sys.text_recognizer.rec_times.report(average=True)
-    det_model_name = args.det_model_dir
-    rec_model_name = args.rec_model_dir
-
-    # construct det log information
-    model_info = {
-        'model_name': args.det_model_dir.split('/')[-1],
-        'precision': args.precision
-    }
-    data_info = {
-        'batch_size': 1,
-        'shape': 'dynamic_shape',
-        'data_num': det_time_dict['img_num']
-    }
-    perf_info = {
-        'preprocess_time_s': det_time_dict['preprocess_time'],
-        'inference_time_s': det_time_dict['inference_time'],
-        'postprocess_time_s': det_time_dict['postprocess_time'],
-        'total_time_s': det_time_dict['total_time']
-    }
-
-    benchmark_log = benchmark_utils.PaddleInferBenchmark(
-        text_sys.text_detector.config, model_info, data_info, perf_info, mems,
-        args.save_log_path)
-    benchmark_log("Det")
-
-    # construct rec log information
-    model_info = {
-        'model_name': args.rec_model_dir.split('/')[-1],
-        'precision': args.precision
-    }
-    data_info = {
-        'batch_size': args.rec_batch_num,
-        'shape': 'dynamic_shape',
-        'data_num': rec_time_dict['img_num']
-    }
-    perf_info = {
-        'preprocess_time_s': rec_time_dict['preprocess_time'],
-        'inference_time_s': rec_time_dict['inference_time'],
-        'postprocess_time_s': rec_time_dict['postprocess_time'],
-        'total_time_s': rec_time_dict['total_time']
-    }
-    benchmark_log = benchmark_utils.PaddleInferBenchmark(
-        text_sys.text_recognizer.config, model_info, data_info, perf_info, mems,
-        args.save_log_path)
-    benchmark_log("Rec")
 
 
 if __name__ == "__main__":
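For orientation, the pipeline this file drives is: detect text quadrilaterals, perspective-crop each one with the now-shared get_rotate_crop_image, optionally fix 180-degree rotations with the angle classifier, then recognize. Condensed into a usage sketch (the image path is a placeholder):

    import cv2
    import tools.infer.utility as utility
    from tools.infer.predict_system import TextSystem

    args = utility.parse_args()  # det/rec/cls model dirs come from the CLI
    text_sys = TextSystem(args)

    img = cv2.imread('doc_img.jpg')
    dt_boxes, rec_res = text_sys(img)  # boxes from det, (text, score) from rec
    for box, (text, score) in zip(dt_boxes, rec_res):
        print(box.tolist(), text, score)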
diff --git a/tools/infer/utility.py b/tools/infer/utility.py
index 90ac5aa5ba2a33707965159d0486bb42957e3fce..cf14e4abd71f1ac6e2ceec11163e635daef11f4d 100755
--- a/tools/infer/utility.py
+++ b/tools/infer/utility.py
@@ -37,6 +37,7 @@ def init_args():
     parser.add_argument("--use_gpu", type=str2bool, default=True)
     parser.add_argument("--ir_optim", type=str2bool, default=True)
     parser.add_argument("--use_tensorrt", type=str2bool, default=False)
+    parser.add_argument("--min_subgraph_size", type=int, default=3)
     parser.add_argument("--precision", type=str, default="fp32")
     parser.add_argument("--gpu_mem", type=int, default=500)
 
@@ -124,76 +125,6 @@ def parse_args():
     return parser.parse_args()
 
 
-class Times(object):
-    def __init__(self):
-        self.time = 0.
-        self.st = 0.
-        self.et = 0.
-
-    def start(self):
-        self.st = time.time()
-
-    def end(self, accumulative=True):
-        self.et = time.time()
-        if accumulative:
-            self.time += self.et - self.st
-        else:
-            self.time = self.et - self.st
-
-    def reset(self):
-        self.time = 0.
-        self.st = 0.
-        self.et = 0.
-
-    def value(self):
-        return round(self.time, 4)
-
-
-class Timer(Times):
-    def __init__(self):
-        super(Timer, self).__init__()
-        self.total_time = Times()
-        self.preprocess_time = Times()
-        self.inference_time = Times()
-        self.postprocess_time = Times()
-        self.img_num = 0
-
-    def info(self, average=False):
-        logger.info("----------------------- Perf info -----------------------")
-        logger.info("total_time: {}, img_num: {}".format(self.total_time.value(
-        ), self.img_num))
-        preprocess_time = round(self.preprocess_time.value() / self.img_num,
-                                4) if average else self.preprocess_time.value()
-        postprocess_time = round(
-            self.postprocess_time.value() / self.img_num,
-            4) if average else self.postprocess_time.value()
-        inference_time = round(self.inference_time.value() / self.img_num,
-                               4) if average else self.inference_time.value()
-
-        average_latency = self.total_time.value() / self.img_num
-        logger.info("average_latency(ms): {:.2f}, QPS: {:2f}".format(
-            average_latency * 1000, 1 / average_latency))
-        logger.info(
-            "preprocess_latency(ms): {:.2f}, inference_latency(ms): {:.2f}, postprocess_latency(ms): {:.2f}".
-            format(preprocess_time * 1000, inference_time * 1000,
-                   postprocess_time * 1000))
-
-    def report(self, average=False):
-        dic = {}
-        dic['preprocess_time'] = round(
-            self.preprocess_time.value() / self.img_num,
-            4) if average else self.preprocess_time.value()
-        dic['postprocess_time'] = round(
-            self.postprocess_time.value() / self.img_num,
-            4) if average else self.postprocess_time.value()
-        dic['inference_time'] = round(
-            self.inference_time.value() / self.img_num,
-            4) if average else self.inference_time.value()
-        dic['img_num'] = self.img_num
-        dic['total_time'] = round(self.total_time.value(), 4)
-        return dic
-
-
 def create_predictor(args, mode, logger):
     if mode == "det":
         model_dir = args.det_model_dir
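The deleted Timer rolled its own averaging; for anyone replacing it, the arithmetic behind its average-latency and QPS figures reduces to the following (a standalone sketch, not part of the patch):

    def summarize(total_time_s, img_num):
        # Same arithmetic as the removed Timer.info(average=True):
        # per-image latency in milliseconds, and its reciprocal as QPS.
        average_latency = total_time_s / img_num
        return {
            'average_latency_ms': round(average_latency * 1000, 2),
            'qps': round(1 / average_latency, 2),
        }

    print(summarize(4.0, 100))  # {'average_latency_ms': 40.0, 'qps': 25.0}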
@@ -212,11 +143,10 @@ def create_predictor(args, mode, logger):
     model_file_path = model_dir + "/inference.pdmodel"
     params_file_path = model_dir + "/inference.pdiparams"
     if not os.path.exists(model_file_path):
-        logger.info("not find model file path {}".format(model_file_path))
-        sys.exit(0)
+        raise ValueError("model file path {} does not exist".format(
+            model_file_path))
     if not os.path.exists(params_file_path):
-        logger.info("not find params file path {}".format(params_file_path))
-        sys.exit(0)
+        raise ValueError("params file path {} does not exist".format(
+            params_file_path))
 
     config = inference.Config(model_file_path, params_file_path)
 
@@ -236,14 +166,17 @@ def create_predictor(args, mode, logger):
             config.enable_tensorrt_engine(
                 precision_mode=inference.PrecisionType.Float32,
                 max_batch_size=args.max_batch_size,
-                min_subgraph_size=3)  # skip the minimum trt subgraph
-        if mode == "det" and "mobile" in model_file_path:
+                min_subgraph_size=args.min_subgraph_size)
+        # skip the minimum trt subgraph
+        if mode == "det":
             min_input_shape = {
                 "x": [1, 3, 50, 50],
                 "conv2d_92.tmp_0": [1, 96, 20, 20],
                 "conv2d_91.tmp_0": [1, 96, 10, 10],
+                "conv2d_59.tmp_0": [1, 96, 20, 20],
                 "nearest_interp_v2_1.tmp_0": [1, 96, 10, 10],
                 "nearest_interp_v2_2.tmp_0": [1, 96, 20, 20],
+                "conv2d_124.tmp_0": [1, 256, 20, 20],
                 "nearest_interp_v2_3.tmp_0": [1, 24, 20, 20],
                 "nearest_interp_v2_4.tmp_0": [1, 24, 20, 20],
                 "nearest_interp_v2_5.tmp_0": [1, 24, 20, 20],
@@ -254,7 +187,9 @@ def create_predictor(args, mode, logger):
                 "x": [1, 3, 2000, 2000],
                 "conv2d_92.tmp_0": [1, 96, 400, 400],
                 "conv2d_91.tmp_0": [1, 96, 200, 200],
+                "conv2d_59.tmp_0": [1, 96, 400, 400],
                 "nearest_interp_v2_1.tmp_0": [1, 96, 200, 200],
+                "conv2d_124.tmp_0": [1, 256, 400, 400],
                 "nearest_interp_v2_2.tmp_0": [1, 96, 400, 400],
                 "nearest_interp_v2_3.tmp_0": [1, 24, 400, 400],
                 "nearest_interp_v2_4.tmp_0": [1, 24, 400, 400],
@@ -266,39 +201,16 @@ def create_predictor(args, mode, logger):
                 "x": [1, 3, 640, 640],
                 "conv2d_92.tmp_0": [1, 96, 160, 160],
                 "conv2d_91.tmp_0": [1, 96, 80, 80],
+                "conv2d_59.tmp_0": [1, 96, 160, 160],
                 "nearest_interp_v2_1.tmp_0": [1, 96, 80, 80],
                 "nearest_interp_v2_2.tmp_0": [1, 96, 160, 160],
+                "conv2d_124.tmp_0": [1, 256, 160, 160],
                 "nearest_interp_v2_3.tmp_0": [1, 24, 160, 160],
                 "nearest_interp_v2_4.tmp_0": [1, 24, 160, 160],
                 "nearest_interp_v2_5.tmp_0": [1, 24, 160, 160],
                 "elementwise_add_7": [1, 56, 40, 40],
                 "nearest_interp_v2_0.tmp_0": [1, 96, 40, 40]
             }
-            if mode == "det" and "server" in model_file_path:
-                min_input_shape = {
-                    "x": [1, 3, 50, 50],
-                    "conv2d_59.tmp_0": [1, 96, 20, 20],
-                    "nearest_interp_v2_2.tmp_0": [1, 96, 20, 20],
-                    "nearest_interp_v2_3.tmp_0": [1, 24, 20, 20],
-                    "nearest_interp_v2_4.tmp_0": [1, 24, 20, 20],
-                    "nearest_interp_v2_5.tmp_0": [1, 24, 20, 20]
-                }
-                max_input_shape = {
-                    "x": [1, 3, 2000, 2000],
-                    "conv2d_59.tmp_0": [1, 96, 400, 400],
-                    "nearest_interp_v2_2.tmp_0": [1, 96, 400, 400],
-                    "nearest_interp_v2_3.tmp_0": [1, 24, 400, 400],
-                    "nearest_interp_v2_4.tmp_0": [1, 24, 400, 400],
-                    "nearest_interp_v2_5.tmp_0": [1, 24, 400, 400]
-                }
-                opt_input_shape = {
-                    "x": [1, 3, 640, 640],
-                    "conv2d_59.tmp_0": [1, 96, 160, 160],
-                    "nearest_interp_v2_2.tmp_0": [1, 96, 160, 160],
-                    "nearest_interp_v2_3.tmp_0": [1, 24, 160, 160],
-                    "nearest_interp_v2_4.tmp_0": [1, 24, 160, 160],
-                    "nearest_interp_v2_5.tmp_0": [1, 24, 160, 160]
-                }
         elif mode == "rec":
             min_input_shape = {"x": [args.rec_batch_num, 3, 32, 10]}
             max_input_shape = {"x": [args.rec_batch_num, 3, 32, 2000]}
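The min/max/opt dictionaries above feed Paddle Inference's dynamic-shape TensorRT path: each named tensor gets a shape range so TensorRT can build engines once and reuse them across input sizes. A standalone sketch of the same configuration through the public API (model paths and the single tensor name are illustrative):

    from paddle import inference

    config = inference.Config('inference.pdmodel', 'inference.pdiparams')
    config.enable_use_gpu(500, 0)  # memory pool in MB, GPU id
    config.enable_tensorrt_engine(
        precision_mode=inference.PrecisionType.Float32,
        max_batch_size=10,
        min_subgraph_size=3)  # subgraphs smaller than this stay on Paddle ops
    # One entry per subgraph input whose shape varies at runtime.
    min_shape = {'x': [1, 3, 50, 50]}
    max_shape = {'x': [1, 3, 2000, 2000]}
    opt_shape = {'x': [1, 3, 640, 640]}
    config.set_trt_dynamic_shape_info(min_shape, max_shape, opt_shape)
    predictor = inference.create_predictor(config)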
"conv2d_59.tmp_0": [1, 96, 400, 400], - "nearest_interp_v2_2.tmp_0": [1, 96, 400, 400], - "nearest_interp_v2_3.tmp_0": [1, 24, 400, 400], - "nearest_interp_v2_4.tmp_0": [1, 24, 400, 400], - "nearest_interp_v2_5.tmp_0": [1, 24, 400, 400] - } - opt_input_shape = { - "x": [1, 3, 640, 640], - "conv2d_59.tmp_0": [1, 96, 160, 160], - "nearest_interp_v2_2.tmp_0": [1, 96, 160, 160], - "nearest_interp_v2_3.tmp_0": [1, 24, 160, 160], - "nearest_interp_v2_4.tmp_0": [1, 24, 160, 160], - "nearest_interp_v2_5.tmp_0": [1, 24, 160, 160] - } elif mode == "rec": min_input_shape = {"x": [args.rec_batch_num, 3, 32, 10]} max_input_shape = {"x": [args.rec_batch_num, 3, 32, 2000]} @@ -328,11 +240,11 @@ def create_predictor(args, mode, logger): # enable memory optim config.enable_memory_optim() - config.disable_glog_info() + #config.disable_glog_info() config.delete_pass("conv_transpose_eltwiseadd_bn_fuse_pass") if mode == 'table': - config.delete_pass("fc_fuse_pass") # not supported for table + config.delete_pass("fc_fuse_pass") # not supported for table config.switch_use_feed_fetch_ops(False) config.switch_ir_optim(True) @@ -597,29 +509,39 @@ def draw_boxes(image, boxes, scores=None, drop_score=0.5): return image -def get_current_memory_mb(gpu_id=None): - """ - It is used to Obtain the memory usage of the CPU and GPU during the running of the program. - And this function Current program is time-consuming. - """ - import pynvml - import psutil - import GPUtil - pid = os.getpid() - p = psutil.Process(pid) - info = p.memory_full_info() - cpu_mem = info.uss / 1024. / 1024. - gpu_mem = 0 - gpu_percent = 0 - if gpu_id is not None: - GPUs = GPUtil.getGPUs() - gpu_load = GPUs[gpu_id].load - gpu_percent = gpu_load - pynvml.nvmlInit() - handle = pynvml.nvmlDeviceGetHandleByIndex(0) - meminfo = pynvml.nvmlDeviceGetMemoryInfo(handle) - gpu_mem = meminfo.used / 1024. / 1024. 
diff --git a/tools/train.py b/tools/train.py
index b024240b4d5d4973645336c62d3762087ec7bbeb..20f5a670d5c8e666678259e0042b3b790e528590 100755
--- a/tools/train.py
+++ b/tools/train.py
@@ -35,7 +35,7 @@ from ppocr.losses import build_loss
 from ppocr.optimizer import build_optimizer
 from ppocr.postprocess import build_post_process
 from ppocr.metrics import build_metric
-from ppocr.utils.save_load import init_model
+from ppocr.utils.save_load import init_model, load_dygraph_params
 import tools.program as program
 
 dist.get_world_size()
@@ -97,7 +97,7 @@ def main(config, device, logger, vdl_writer):
     # build metric
     eval_class = build_metric(config['Metric'])
     # load pretrain model
-    pre_best_model_dict = init_model(config, model, optimizer)
+    pre_best_model_dict = load_dygraph_params(config, model, logger, optimizer)
 
     logger.info('train dataloader has {} iters'.format(len(train_dataloader)))
     if valid_dataloader is not None: