提交 3f2d384f 编写于 作者: L LDOUBLEV

add doc、infer_det.py、requirments.txt

上级 e3388a24
...@@ -12,6 +12,7 @@ Global: ...@@ -12,6 +12,7 @@ Global:
image_shape: [3, 640, 640] image_shape: [3, 640, 640]
reader_yml: ./configs/det/det_db_icdar15_reader.yml reader_yml: ./configs/det/det_db_icdar15_reader.yml
pretrain_weights: ./pretrain_models/MobileNetV3_pretrained/MobileNetV3_large_x0_5_pretrained/ pretrain_weights: ./pretrain_models/MobileNetV3_pretrained/MobileNetV3_large_x0_5_pretrained/
checkpoints:
save_res_path: ./output/predicts_db.txt save_res_path: ./output/predicts_db.txt
Architecture: Architecture:
......
# 文字检测
本节以icdar15数据集为例,介绍PaddleOCR中检测模型的使用方式。
## 3.1 数据准备
icdar2015数据集可以从[官网](https://rrc.cvc.uab.es/?ch=4&com=downloads)下载到,首次下载需注册。
将下载到的数据集解压到工作目录下,假设解压在/PaddleOCR/train_data/ 下。另外,PaddleOCR将零散的标注文件整理成单独的标注文件
,您可以通过wget的方式进行下载。
```
wget -P /PaddleOCR/train_data/ 训练标注文件链接
wget -P /PaddleOCR/train_data/ 测试标注文件链接
```
解压数据集和下载标注文件后,/PaddleOCR/train_data/ 有两个文件夹和两个文件,分别是:
```
/PaddleOCR/train_data/
└─ icdar_c4_train_imgs/ icdar数据集的训练数据
└─ ch4_test_images/ icdar数据集的测试数据
└─ train_icdar2015_label.txt icdar数据集的训练标注
└─ test_icdar2015_label.txt icdar数据集的测试标注
```
提供的标注文件格式为:
```
" 图像文件名 json.dumps编码的图像标注信息"
ch4_test_images/img_61.jpg [{"transcription": "MASA", "points": [[310, 104], [416, 141], [418, 216], [312, 179]], ...}]
```
json.dumps编码前的图像标注信息是包含多个字典的list,字典中的points表示文本框的位置,如果您想在其他数据集上训练PaddleOCR,
可以按照上述形式构建标注文件。
## 3.2 快速启动训练
首先下载pretrain model,目前支持两种backbone,分别是MobileNetV3、ResNet50,您可以根据需求使用PaddleClas中的模型更换
backbone。
```
# 下载MobileNetV3的预训练模型
wget -P /PaddleOCR/pretrained_model/ 模型链接
# 下载ResNet50的预训练模型
wget -P /PaddleOCR/pretrained_model/ 模型链接
```
**启动训练**
```
cd PaddleOCR/
python3 tools/train.py -c configs/det/det_db_mv3.yml
```
上述指令中,通过-c 选择训练使用configs/det/det_db_mv3.yml配置文件。
有关配置文件的详细解释,请参考[链接]()。
您也可以通过-o参数在不需要修改yml文件的情况下,改变训练的参数,比如,调整训练的学习率为0.0001
```
python3 tools/train.py -c configs/det/det_db_mv3.yml -o Optimizer.base_lr=0.0001
```
## 3.3 指标评估
PaddleOCR计算三个OCR检测相关的指标,分别是:Precision、Recall、Hmean。
运行如下代码,根据配置文件det_db_mv3.yml中save_res_path指定的测试集检测结果文件,计算评估指标。
```
python3 tools/eval.py -c configs/det/det_db_mv3.yml -o checkpoints ./output/best_accuracy
```
## 3.4 测试检测效果
### 2.1 快速安装
我们提供了PaddleOCR开发环境的docker,您可以pull我们提供的docker运行PaddleOCR的环境。
1. 准备docker环境。第一次使用这个镜像,会自动下载该镜像,请耐心等待。
```
# 切换到工作目录下
cd /home/Projects
# 创建一个名字为pdocr的docker容器,并将当前目录映射到容器的/data目录下
sudo nvidia-docker run --name pdocr -v $PWD:/data --network=host -it paddlepaddle/paddle:1.7.2-gpu-cuda10.0-cudnn7 /bin/bash
```
2. 克隆PaddleOCR repo代码
```
apt-get update
apt-get install git
git clone https://github.com/PaddlePaddle/PaddleOCR
```
3. 安装第三方库
```
cd PaddleOCR
pip3 install --upgrade pip
pip3 install -r requirements.txt
```
shapely
imgaug
pyclipper
\ No newline at end of file
...@@ -64,15 +64,39 @@ class TextSystem(object): ...@@ -64,15 +64,39 @@ class TextSystem(object):
if dt_boxes is None: if dt_boxes is None:
return None, None return None, None
img_crop_list = [] img_crop_list = []
dt_boxes = sorted_boxes(dt_boxes)
for bno in range(len(dt_boxes)): for bno in range(len(dt_boxes)):
tmp_box = copy.deepcopy(dt_boxes[bno]) tmp_box = copy.deepcopy(dt_boxes[bno])
img_crop = self.get_rotate_crop_image(ori_im, tmp_box) img_crop = self.get_rotate_crop_image(ori_im, tmp_box)
img_crop_list.append(img_crop) img_crop_list.append(img_crop)
rec_res, elapse = self.text_recognizer(img_crop_list) rec_res, elapse = self.text_recognizer(img_crop_list)
# self.print_draw_crop_rec_res(img_crop_list, rec_res) # self.print_draw_crop_rec_res(img_crop_list, rec_res)
return dt_boxes, rec_res return dt_boxes, rec_res
def sorted_boxes(dt_boxes):
"""
Sort text boxes in order from top to bottom, left to right
args:
dt_boxes(array):detected text boxes with shape [4, 2]
return:
sorted boxes(array) with shape [4, 2]
"""
num_boxes = dt_boxes.shape[0]
sorted_boxes = sorted(dt_boxes, key=lambda x: x[0][1])
_boxes = list(sorted_boxes)
for i in range(num_boxes - 1):
if abs(_boxes[i+1][0][1] - _boxes[i][0][1]) < 10 and \
(_boxes[i + 1][0][0] < _boxes[i][0][0]):
tmp = _boxes[i]
_boxes[i] = _boxes[i + 1]
_boxes[i + 1] = tmp
return _boxes
if __name__ == "__main__": if __name__ == "__main__":
args = utility.parse_args() args = utility.parse_args()
image_file_list = utility.get_image_file_list(args.image_dir) image_file_list = utility.get_image_file_list(args.image_dir)
......
...@@ -106,7 +106,7 @@ def create_predictor(args, mode): ...@@ -106,7 +106,7 @@ def create_predictor(args, mode):
# if args.use_fp16 else AnalysisConfig.Precision.Float32, # if args.use_fp16 else AnalysisConfig.Precision.Float32,
# max_batch_size=args.batch_size) # max_batch_size=args.batch_size)
config.enable_memory_optim() # config.enable_memory_optim()
# use zero copy # use zero copy
config.switch_use_feed_fetch_ops(False) config.switch_use_feed_fetch_ops(False)
predictor = create_paddle_predictor(config) predictor = create_paddle_predictor(config)
...@@ -136,12 +136,16 @@ if __name__ == '__main__': ...@@ -136,12 +136,16 @@ if __name__ == '__main__':
args.det_model_dir = root_path + "test_models/public_v1/ch_det_mv3_db" args.det_model_dir = root_path + "test_models/public_v1/ch_det_mv3_db"
predictor, input_tensor, output_tensors = create_predictor(args, mode='det') predictor, input_tensor, output_tensors = create_predictor(args, mode='det')
print(predictor.get_input_names()) print("det input", predictor.get_input_names())
print(predictor.get_output_names()) print("det output", predictor.get_output_names())
print(predictor.program(), file=open("det_program.txt", 'w')) # print(predictor.program(), file=open("det_program.txt", 'w'))
outputs = []
for output_tensor in output_tensors:
output = output_tensor.copy_to_cpu()
outputs.append(output)
args.rec_model_dir = root_path + "test_models/public_v1/ch_rec_mv3_crnn/" args.rec_model_dir = root_path + "test_models/public_v1/ch_rec_mv3_crnn/"
rec_predictor, input_tensor, output_tensors = create_predictor( rec_predictor, input_tensor, output_tensors = create_predictor(
args, mode='rec') args, mode='rec')
print(rec_predictor.get_input_names()) print("rec input", rec_predictor.get_input_names())
print(rec_predictor.get_output_names()) print("rec output", rec_predictor.get_output_names())
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import sys
import time
import numpy as np
from copy import deepcopy
import json
# from paddle.fluid.contrib.model_stat import summary
def set_paddle_flags(**kwargs):
for key, value in kwargs.items():
if os.environ.get(key, None) is None:
os.environ[key] = str(value)
# NOTE(paddle-dev): All of these flags should be
# set before `import paddle`. Otherwise, it would
# not take any effect.
set_paddle_flags(
FLAGS_eager_delete_tensor_gb=0, # enable GC to save memory
)
from paddle import fluid
from ppocr.utils.utility import create_module
import program
from ppocr.utils.save_load import init_model
from ppocr.data.reader_main import reader_main
from ppocr.utils.utility import initial_logger
logger = initial_logger()
def draw_det_res(dt_boxes, config, img_name, ino):
if len(dt_boxes) > 0:
img_set_path = config['TestReader']['img_set_dir']
img_path = img_set_path + img_name
import cv2
src_im = cv2.imread(img_path)
for box in dt_boxes:
box = box.astype(np.int32).reshape((-1, 1, 2))
cv2.polylines(src_im, [box], True, color=(255, 255, 0), thickness=2)
save_det_path = os.path.basename(config['Global'][
'save_res_path']) + "/det_results/"
if not os.path.exists(save_det_path):
os.makedirs(save_det_path)
save_path = os.path.join(save_det_path, "det_{}.jpg".format(img_name))
cv2.imwrite(save_path, src_im)
logger.info("The detected Image saved in {}".format(save_path))
def main():
config = program.load_config(FLAGS.config)
program.merge_config(FLAGS.opt)
print(config)
# check if set use_gpu=True in paddlepaddle cpu version
use_gpu = config['Global']['use_gpu']
program.check_gpu(use_gpu)
place = fluid.CUDAPlace(0) if use_gpu else fluid.CPUPlace()
exe = fluid.Executor(place)
det_model = create_module(config['Architecture']['function'])(params=config)
startup_prog = fluid.Program()
eval_prog = fluid.Program()
with fluid.program_guard(eval_prog, startup_prog):
with fluid.unique_name.guard():
_, eval_outputs = det_model(mode="test")
fetch_name_list = list(eval_outputs.keys())
eval_fetch_list = [eval_outputs[v].name for v in fetch_name_list]
eval_prog = eval_prog.clone(for_test=True)
exe.run(startup_prog)
# load checkpoints
checkpoints = config['Global'].get('checkpoints')
if checkpoints:
path = checkpoints
fluid.load(eval_prog, path, exe)
logger.info("Finish initing model from {}".format(path))
else:
raise Exception("{} not exists!".format(checkpoints))
save_res_path = config['Global']['save_res_path']
with open(save_res_path, "wb") as fout:
test_reader = reader_main(config=config, mode='test')
tackling_num = 0
for data in test_reader():
img_num = len(data)
tackling_num = tackling_num + img_num
logger.info("tackling_num:%d", tackling_num)
img_list = []
ratio_list = []
img_name_list = []
for ino in range(img_num):
img_list.append(data[ino][0])
ratio_list.append(data[ino][1])
img_name_list.append(data[ino][2])
img_list = np.concatenate(img_list, axis=0)
outs = exe.run(eval_prog,\
feed={'image': img_list},\
fetch_list=eval_fetch_list)
global_params = config['Global']
postprocess_params = deepcopy(config["PostProcess"])
postprocess_params.update(global_params)
postprocess = create_module(postprocess_params['function'])\
(params=postprocess_params)
dt_boxes_list = postprocess(outs, ratio_list)
for ino in range(img_num):
dt_boxes = dt_boxes_list[ino]
img_name = img_name_list[ino]
dt_boxes_json = []
for box in dt_boxes:
tmp_json = {"transcription": ""}
tmp_json['points'] = box.tolist()
dt_boxes_json.append(tmp_json)
otstr = img_name + "\t" + json.dumps(dt_boxes_json) + "\n"
fout.write(otstr.encode())
draw_det_res(dt_boxes, config, img_name, ino)
logger.info("success!")
if __name__ == '__main__':
parser = program.ArgsParser()
FLAGS = parser.parse_args()
main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册