From 3f2d384faa95a1ad90b0475494f0624f2b900670 Mon Sep 17 00:00:00 2001
From: LDOUBLEV
Date: Mon, 11 May 2020 15:27:52 +0800
Subject: [PATCH] =?UTF-8?q?add=20doc=E3=80=81infer=5Fdet.py=E3=80=81requir?=
 =?UTF-8?q?ments.txt?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 configs/det/det_db_mv3.yml    |   1 +
 doc/detection.md              |  69 ++++++++++++++++
 doc/installation.md           |  25 ++++++
 requirments.txt               |   3 +
 tools/infer/predict_system.py |  26 +++++-
 tools/infer/utility.py        |  16 ++--
 tools/infer_det.py            | 148 ++++++++++++++++++++++++++++++++++
 7 files changed, 281 insertions(+), 7 deletions(-)
 create mode 100644 doc/detection.md
 create mode 100644 doc/installation.md
 create mode 100644 requirments.txt
 create mode 100755 tools/infer_det.py

diff --git a/configs/det/det_db_mv3.yml b/configs/det/det_db_mv3.yml
index a41c901e..197b1204 100755
--- a/configs/det/det_db_mv3.yml
+++ b/configs/det/det_db_mv3.yml
@@ -12,6 +12,7 @@ Global:
   image_shape: [3, 640, 640]
   reader_yml: ./configs/det/det_db_icdar15_reader.yml
   pretrain_weights: ./pretrain_models/MobileNetV3_pretrained/MobileNetV3_large_x0_5_pretrained/
+  checkpoints:
   save_res_path: ./output/predicts_db.txt
 
 Architecture:
diff --git a/doc/detection.md b/doc/detection.md
new file mode 100644
index 00000000..b8b8a9cc
--- /dev/null
+++ b/doc/detection.md
@@ -0,0 +1,69 @@
+# Text Detection
+
+This section uses the icdar15 dataset as an example to introduce how to use the detection models in PaddleOCR.
+
+## 3.1 Data Preparation
+The icdar2015 dataset can be downloaded from the [official website](https://rrc.cvc.uab.es/?ch=4&com=downloads); registration is required before downloading for the first time.
+
+Decompress the downloaded dataset into the working directory, assuming it is extracted under /PaddleOCR/train_data/.
+In addition, PaddleOCR consolidates the scattered annotation files into single annotation files, which you can download with wget:
+```
+wget -P /PaddleOCR/train_data/  <training annotation file link>
+wget -P /PaddleOCR/train_data/  <test annotation file link>
+```
+
+After decompressing the dataset and downloading the annotation files, /PaddleOCR/train_data/ contains two folders and two files:
+```
+/PaddleOCR/train_data/
+  └─ icdar_c4_train_imgs/         training images of the icdar dataset
+  └─ ch4_test_images/             test images of the icdar dataset
+  └─ train_icdar2015_label.txt    training annotations of the icdar dataset
+  └─ test_icdar2015_label.txt     test annotations of the icdar dataset
+```
+
+The provided annotation files use the following format:
+```
+" image file name              image annotation information encoded by json.dumps"
+ch4_test_images/img_61.jpg    [{"transcription": "MASA", "points": [[310, 104], [416, 141], [418, 216], [312, 179]], ...}]
+```
+Before json.dumps encoding, the image annotation information is a list of dictionaries; the "points" field of each dictionary gives the position of a text box.
+If you want to train PaddleOCR on other datasets, you can build the annotation files in the format described above.
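As an illustration of the annotation format above, here is a minimal, hypothetical Python sketch that writes one label line (image path, a separator, and the json.dumps-encoded annotation list); the image path and coordinates are just the sample values from the doc, and the tab separator is an assumption:

```
import json

# One dictionary per text region: "transcription" is the text content,
# "points" lists the four corner points of the text box.
annotations = [
    {"transcription": "MASA",
     "points": [[310, 104], [416, 141], [418, 216], [312, 179]]},
]

# Append one label line: "<image path>\t<json.dumps of the annotation list>".
with open("train_icdar2015_label.txt", "a", encoding="utf-8") as f:
    f.write("ch4_test_images/img_61.jpg" + "\t"
            + json.dumps(annotations, ensure_ascii=False) + "\n")
```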
+
+
+## 3.2 Quick Start Training
+
+First download the pretrained model. Two backbones are currently supported, MobileNetV3 and ResNet50; you can replace the backbone with models from PaddleClas as needed.
+```
+# Download the pretrained model for MobileNetV3
+wget -P /PaddleOCR/pretrained_model/ <model link>
+# Download the pretrained model for ResNet50
+wget -P /PaddleOCR/pretrained_model/ <model link>
+```
+
+**Start training**
+```
+cd PaddleOCR/
+python3 tools/train.py -c configs/det/det_db_mv3.yml
+```
+
+In the command above, -c selects configs/det/det_db_mv3.yml as the training configuration file.
+For a detailed explanation of the configuration file, please refer to [this link]().
+
+You can also use the -o option to change training parameters without modifying the yml file. For example, to set the training learning rate to 0.0001:
+```
+python3 tools/train.py -c configs/det/det_db_mv3.yml -o Optimizer.base_lr=0.0001
+```
+
+
+## 3.3 Metric Evaluation
+
+PaddleOCR computes three metrics related to OCR detection: Precision, Recall, and Hmean.
+
+Run the following command to compute the evaluation metrics based on the test-set detection result file specified by save_res_path in the configuration file det_db_mv3.yml:
+
+```
+python3 tools/eval.py -c configs/det/det_db_mv3.yml -o Global.checkpoints=./output/best_accuracy
+```
+
+## 3.4 Test the Detection Results
diff --git a/doc/installation.md b/doc/installation.md
new file mode 100644
index 00000000..5eca1dfe
--- /dev/null
+++ b/doc/installation.md
@@ -0,0 +1,25 @@
+### 2.1 Quick Installation
+
+We provide a docker image with the PaddleOCR development environment; you can pull this image and run PaddleOCR inside it.
+
+1. Prepare the docker environment. The first time you use this image it will be downloaded automatically; please be patient.
+```
+# Switch to the working directory
+cd /home/Projects
+# Create a docker container named pdocr and map the current directory to the /data directory inside the container
+sudo nvidia-docker run --name pdocr -v $PWD:/data --network=host -it paddlepaddle/paddle:1.7.2-gpu-cuda10.0-cudnn7 /bin/bash
+```
+
+2. Clone the PaddleOCR repo
+```
+apt-get update
+apt-get install git
+git clone https://github.com/PaddlePaddle/PaddleOCR
+```
+
+3. Install third-party libraries
+```
+cd PaddleOCR
+pip3 install --upgrade pip
+pip3 install -r requirements.txt
+```
diff --git a/requirments.txt b/requirments.txt
new file mode 100644
index 00000000..e9bf32a9
--- /dev/null
+++ b/requirments.txt
@@ -0,0 +1,3 @@
+shapely
+imgaug
+pyclipper
\ No newline at end of file
diff --git a/tools/infer/predict_system.py b/tools/infer/predict_system.py
index defa0615..4907a7cc 100755
--- a/tools/infer/predict_system.py
+++ b/tools/infer/predict_system.py
@@ -64,15 +64,39 @@ class TextSystem(object):
         if dt_boxes is None:
             return None, None
         img_crop_list = []
+
+        dt_boxes = sorted_boxes(dt_boxes)
+
         for bno in range(len(dt_boxes)):
             tmp_box = copy.deepcopy(dt_boxes[bno])
             img_crop = self.get_rotate_crop_image(ori_im, tmp_box)
             img_crop_list.append(img_crop)
         rec_res, elapse = self.text_recognizer(img_crop_list)
-        # self.print_draw_crop_rec_res(img_crop_list, rec_res)
+        # self.print_draw_crop_rec_res(img_crop_list, rec_res)
         return dt_boxes, rec_res
 
 
+def sorted_boxes(dt_boxes):
+    """
+    Sort text boxes in order from top to bottom, left to right
+    args:
+        dt_boxes(array): detected text boxes with shape [num_boxes, 4, 2]
+    return:
+        sorted boxes(list of arrays), each with shape [4, 2]
+    """
+    num_boxes = dt_boxes.shape[0]
+    sorted_boxes = sorted(dt_boxes, key=lambda x: x[0][1])
+    _boxes = list(sorted_boxes)
+
+    for i in range(num_boxes - 1):
+        if abs(_boxes[i + 1][0][1] - _boxes[i][0][1]) < 10 and \
+                (_boxes[i + 1][0][0] < _boxes[i][0][0]):
+            tmp = _boxes[i]
+            _boxes[i] = _boxes[i + 1]
+            _boxes[i + 1] = tmp
+    return _boxes
+
+
 if __name__ == "__main__":
     args = utility.parse_args()
     image_file_list = utility.get_image_file_list(args.image_dir)
diff --git a/tools/infer/utility.py b/tools/infer/utility.py
index a4f9f03d..ff9586ab 100755
--- a/tools/infer/utility.py
+++ b/tools/infer/utility.py
@@ -106,7 +106,7 @@ def create_predictor(args, mode):
     #     if args.use_fp16 else AnalysisConfig.Precision.Float32,
     #     max_batch_size=args.batch_size)
 
-    config.enable_memory_optim()
+    # config.enable_memory_optim()
     # use zero copy
     config.switch_use_feed_fetch_ops(False)
     predictor = create_paddle_predictor(config)
@@ -136,12 +136,16 @@ if __name__ == '__main__':
     args.det_model_dir = root_path + "test_models/public_v1/ch_det_mv3_db"
     predictor, input_tensor, output_tensors = create_predictor(args, mode='det')
-    print(predictor.get_input_names())
-    print(predictor.get_output_names())
-    print(predictor.program(), file=open("det_program.txt", 'w'))
+    print("det input", predictor.get_input_names())
+    print("det output", predictor.get_output_names())
+    # print(predictor.program(), file=open("det_program.txt", 'w'))
+    outputs = []
+    for output_tensor in output_tensors:
+        output = output_tensor.copy_to_cpu()
+        outputs.append(output)
     args.rec_model_dir = root_path + "test_models/public_v1/ch_rec_mv3_crnn/"
     rec_predictor, input_tensor, output_tensors = create_predictor(
         args, mode='rec')
-    print(rec_predictor.get_input_names())
-    print(rec_predictor.get_output_names())
+    print("rec input", rec_predictor.get_input_names())
+    print("rec output", rec_predictor.get_output_names())
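For context on the `sorted_boxes` helper added to `tools/infer/predict_system.py` above, here is a small self-contained sketch of how that ordering behaves; the box coordinates are made-up sample values, and the 10-pixel same-line threshold mirrors the code in the patch:

```
import numpy as np

def sorted_boxes(dt_boxes):
    """Sort text boxes from top to bottom, then left to right (same logic as in the patch)."""
    num_boxes = dt_boxes.shape[0]
    # Primary sort: y coordinate of each box's first (top-left) point.
    _boxes = list(sorted(dt_boxes, key=lambda x: x[0][1]))
    for i in range(num_boxes - 1):
        # Boxes on roughly the same line (y difference < 10 px) are ordered by x.
        if abs(_boxes[i + 1][0][1] - _boxes[i][0][1]) < 10 and \
                _boxes[i + 1][0][0] < _boxes[i][0][0]:
            _boxes[i], _boxes[i + 1] = _boxes[i + 1], _boxes[i]
    return _boxes

# Three hypothetical boxes with shape [num_boxes, 4, 2].
boxes = np.array([
    [[200, 105], [300, 105], [300, 135], [200, 135]],  # first line, right
    [[ 50, 100], [150, 100], [150, 130], [ 50, 130]],  # first line, left
    [[ 60, 200], [160, 200], [160, 230], [ 60, 230]],  # second line
])

for box in sorted_boxes(boxes):
    print(box[0])  # top-left points come out top-to-bottom, left-to-right
```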
diff --git a/tools/infer_det.py b/tools/infer_det.py
new file mode 100755
index 00000000..d616323d
--- /dev/null
+++ b/tools/infer_det.py
@@ -0,0 +1,148 @@
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import os
+import sys
+import time
+import numpy as np
+from copy import deepcopy
+import json
+
+# from paddle.fluid.contrib.model_stat import summary
+
+
+def set_paddle_flags(**kwargs):
+    for key, value in kwargs.items():
+        if os.environ.get(key, None) is None:
+            os.environ[key] = str(value)
+
+
+# NOTE(paddle-dev): All of these flags should be
+# set before `import paddle`. Otherwise, it would
+# not take any effect.
+set_paddle_flags(
+    FLAGS_eager_delete_tensor_gb=0,  # enable GC to save memory
+)
+
+from paddle import fluid
+from ppocr.utils.utility import create_module
+import program
+from ppocr.utils.save_load import init_model
+from ppocr.data.reader_main import reader_main
+
+from ppocr.utils.utility import initial_logger
+logger = initial_logger()
+
+
+def draw_det_res(dt_boxes, config, img_name, ino):
+    if len(dt_boxes) > 0:
+        img_set_path = config['TestReader']['img_set_dir']
+        img_path = img_set_path + img_name
+        import cv2
+        src_im = cv2.imread(img_path)
+        for box in dt_boxes:
+            box = box.astype(np.int32).reshape((-1, 1, 2))
+            cv2.polylines(src_im, [box], True, color=(255, 255, 0), thickness=2)
+        save_det_path = os.path.dirname(config['Global'][
+            'save_res_path']) + "/det_results/"
+        if not os.path.exists(save_det_path):
+            os.makedirs(save_det_path)
+        save_path = os.path.join(save_det_path, "det_{}.jpg".format(img_name))
+        cv2.imwrite(save_path, src_im)
+        logger.info("The detected image is saved in {}".format(save_path))
+
+
+def main():
+    config = program.load_config(FLAGS.config)
+    program.merge_config(FLAGS.opt)
+    print(config)
+
+    # check if set use_gpu=True in paddlepaddle cpu version
+    use_gpu = config['Global']['use_gpu']
+    program.check_gpu(use_gpu)
+
+    place = fluid.CUDAPlace(0) if use_gpu else fluid.CPUPlace()
+    exe = fluid.Executor(place)
+
+    det_model = create_module(config['Architecture']['function'])(params=config)
+
+    startup_prog = fluid.Program()
+    eval_prog = fluid.Program()
+    with fluid.program_guard(eval_prog, startup_prog):
+        with fluid.unique_name.guard():
+            _, eval_outputs = det_model(mode="test")
+            fetch_name_list = list(eval_outputs.keys())
+            eval_fetch_list = [eval_outputs[v].name for v in fetch_name_list]
+
+    eval_prog = eval_prog.clone(for_test=True)
+    exe.run(startup_prog)
+
+    # load checkpoints
+    checkpoints = config['Global'].get('checkpoints')
+    if checkpoints:
+        path = checkpoints
+        fluid.load(eval_prog, path, exe)
+        logger.info("Finished loading model from {}".format(path))
+    else:
+        raise Exception("Global.checkpoints is not set; a trained model is required for inference!")
+
+    save_res_path = config['Global']['save_res_path']
+    with open(save_res_path, "wb") as fout:
+        test_reader = reader_main(config=config, mode='test')
+        tackling_num = 0
+        for data in test_reader():
+            img_num = len(data)
+            tackling_num = tackling_num + img_num
+            logger.info("tackling_num:%d", tackling_num)
+            img_list = []
+            ratio_list = []
+            img_name_list = []
+            for ino in range(img_num):
+                img_list.append(data[ino][0])
+                ratio_list.append(data[ino][1])
+                img_name_list.append(data[ino][2])
+            img_list = np.concatenate(img_list, axis=0)
+            outs = exe.run(eval_prog,\
+                feed={'image': img_list},\
+                fetch_list=eval_fetch_list)
+
+            global_params = config['Global']
+            postprocess_params = deepcopy(config["PostProcess"])
+            postprocess_params.update(global_params)
+            postprocess = create_module(postprocess_params['function'])\
+                (params=postprocess_params)
+            dt_boxes_list = postprocess(outs, ratio_list)
+            for ino in range(img_num):
+                dt_boxes = dt_boxes_list[ino]
+                img_name = img_name_list[ino]
+                dt_boxes_json = []
+                for box in dt_boxes:
+                    tmp_json = {"transcription": ""}
+                    tmp_json['points'] = box.tolist()
+                    dt_boxes_json.append(tmp_json)
+                otstr = img_name + "\t" + json.dumps(dt_boxes_json) + "\n"
+                fout.write(otstr.encode())
+                draw_det_res(dt_boxes, config, img_name, ino)
+
+    logger.info("success!")
+
+
+if __name__ == '__main__':
+    parser = program.ArgsParser()
+    FLAGS = parser.parse_args()
+    main()
--
GitLab
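For reference, a minimal sketch of reading back the result file that tools/infer_det.py writes (one line per image: the image name, a tab, and the json.dumps-encoded box list); the path below simply follows the default save_res_path in configs/det/det_db_mv3.yml and is otherwise an assumption:

```
import json

# Default save_res_path from configs/det/det_db_mv3.yml; adjust if the config was changed.
save_res_path = "./output/predicts_db.txt"

with open(save_res_path, "rb") as fin:
    for line in fin:
        img_name, boxes_json = line.decode("utf-8").rstrip("\n").split("\t")
        dt_boxes = json.loads(boxes_json)  # list of {"transcription": "", "points": [...]}
        for box in dt_boxes:
            # "points" holds the four [x, y] corner points of one detected text box.
            print(img_name, box["points"])
```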