Commit 268ed03b authored by: L LDOUBLEV

update detection doc and infer code

Parent af399022
......@@ -32,8 +32,8 @@ The image annotation information before json.dumps encoding is a list containing multiple dicts, where each dict
## 3.2 Quick Start Training
First download the pretrained models. Two backbones are currently supported, MobileNetV3 and ResNet50; you can swap in a model from PaddleClas as the
backbone according to your needs.
First download the pretrained models. PaddleOCR's detection model currently supports two backbones, MobileNetV3 and ResNet50_vd;
you can swap in a model from [PaddleClas](https://github.com/PaddlePaddle/PaddleClas/tree/master/ppcls/modeling/architectures) as the backbone according to your needs.
```
# Download the pretrained model for MobileNetV3
wget -P /PaddleOCR/pretrained_model/ <model link>
......@@ -63,7 +63,17 @@ PaddleOCR computes three OCR-detection metrics: Precision, Recall
Run the following command to compute the evaluation metrics from the test-set detection results file specified by save_res_path in the config file det_db_mv3.yml.
```
python3 tools/eval.py -c configs/det/det_db_mv3.yml -o checkpoints ./output/best_accuracy
python3 tools/eval.py -c configs/det/det_db_mv3.yml -o checkpoints="./output/best_accuracy"
```
## 3.4 Test Detection Results
Test the detection result of a single image:
```
python3 tools/infer_det.py -c configs/det/det_db_mv3.yml -o TestReader.single_img_path="./demo.jpg"
```
Test the detection results of all images in a folder:
```
python3 tools/infer_det.py -c configs/det/det_db_mv3.yml -o TestReader.single_img_path="./demo_img/"
```
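Both tools/infer_det.py and tools/eval.py accept `-o key=value` overrides that are merged into the loaded YAML config, with dots for nested keys, as in `TestReader.single_img_path` above. A rough sketch of that merging idea, not PaddleOCR's exact `program.merge_config` implementation:
```python
def merge_overrides(config, overrides):
    # overrides: dict parsed from "-o key=value" pairs, e.g.
    # {"TestReader.single_img_path": "./demo.jpg",
    #  "checkpoints": "./output/best_accuracy"}
    for key, value in overrides.items():
        node = config
        parts = key.split(".")
        for part in parts[:-1]:
            node = node.setdefault(part, {})
        node[parts[-1]] = value
    return config
```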
......@@ -124,9 +124,6 @@ class DBProcessTest(object):
    def resize_image_type0(self, im):
        """
        resize the image to a size that is a multiple of 32, as required by the network
        :param im: the resized image
        :param max_side_len: limit of max image size to avoid out of memory in gpu
        :return: the resized image and the resize ratio
        """
        max_side_len = self.max_side_len
        h, w, _ = im.shape
......
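For reference, a minimal sketch of the resize-to-multiple-of-32 scheme the docstring describes; the exact scaling rules inside DBProcessTest may differ, so treat the helper below as an illustration only:
```python
import cv2

def resize_to_multiple_of_32(im, max_side_len=2400):
    # Shrink so the longer side stays within max_side_len (GPU memory guard),
    # then round each side to the nearest multiple of 32 for the network.
    h, w, _ = im.shape
    ratio = float(max_side_len) / max(h, w) if max(h, w) > max_side_len else 1.0
    resize_h = max(32, int(round(h * ratio / 32)) * 32)
    resize_w = max(32, int(round(w * ratio / 32)) * 32)
    im = cv2.resize(im, (resize_w, resize_h))  # cv2.resize takes (width, height)
    return im, (resize_h / float(h), resize_w / float(w))
```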
......@@ -73,9 +73,3 @@ def reader_main(config=None, mode=None):
        return paddle.reader.multiprocess_reader(readers, False)
    else:
        return function(mode)


def test_reader(image_shape, img_path):
    img = cv2.imread(img_path)
    norm_img = process_image(img, image_shape)
    return norm_img
......@@ -3,6 +3,10 @@
from collections import namedtuple
import numpy as np
from shapely.geometry import Polygon
"""
reference from :
https://github.com/MhLiao/DB/blob/3c32b808d4412680310d3d28eeb6a2d5bf1566c5/concern/icdar2015_eval/detection/iou.py#L8
"""
class DetectionIoUEvaluator(object):
......
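As context for the evaluator added above, here is a minimal sketch of polygon IoU with shapely, in the spirit of the referenced DB script; DetectionIoUEvaluator's actual matching logic is more involved (ignore regions, one-to-one matching of detections to ground truth):
```python
from shapely.geometry import Polygon

def polygon_iou(pts_a, pts_b):
    # pts_*: list of (x, y) vertices of a ground-truth or detected region.
    a, b = Polygon(pts_a), Polygon(pts_b)
    if not a.is_valid or not b.is_valid:
        return 0.0
    union = a.union(b).area
    return a.intersection(b).area / union if union > 0 else 0.0
```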
......@@ -98,6 +98,14 @@ def load_label_infor(label_file_path, do_ignore=False):
def cal_det_metrics(gt_label_path, save_res_path):
"""
calculate the detection metrics
Args:
gt_label_path(string): The groundtruth detection label file path
save_res_path(string): The saved predicted detection label path
return:
claculated metrics including Hmean、precision and recall
"""
evaluator = DetectionIoUEvaluator()
gt_label_infor = load_label_infor(gt_label_path, do_ignore=True)
dt_label_infor = load_label_infor(save_res_path)
......
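A hypothetical call of cal_det_metrics; the paths are placeholders, and the exact keys of the returned result depend on DetectionIoUEvaluator:
```python
# Ground-truth label file and the save_res_path file produced by
# tools/infer_det.py or tools/eval.py (both paths are placeholders).
metrics = cal_det_metrics("./train_data/det_gt_test.txt",
                          "./output/det_db/predicts.txt")
print(metrics)  # expected to contain precision, recall and Hmean
```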
......@@ -99,15 +99,7 @@ def create_predictor(args, mode):
        config.disable_gpu()
    config.disable_glog_info()
    # config.switch_ir_optim(args.ir_optim)
    # if args.use_tensorrt:
    #     config.enable_tensorrt_engine(
    #         precision_mode=AnalysisConfig.Precision.Half
    #         if args.use_fp16 else AnalysisConfig.Precision.Float32,
    #         max_batch_size=args.batch_size)
    # config.enable_memory_optim()
    # use zero copy
    config.switch_use_feed_fetch_ops(False)
    predictor = create_paddle_predictor(config)
......
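Since switch_use_feed_fetch_ops(False) enables the zero-copy path, inference then goes through the Paddle 1.x ZeroCopyTensor API rather than ordinary feed/fetch. A sketch under that assumption; the model's input/output layout here is an assumption, not taken from this commit:
```python
import numpy as np

def run_det(predictor, img_batch):
    # img_batch: float32 NCHW ndarray, already resized and normalized.
    input_name = predictor.get_input_names()[0]
    input_tensor = predictor.get_input_tensor(input_name)
    input_tensor.copy_from_cpu(img_batch.astype(np.float32))
    predictor.zero_copy_run()
    output_name = predictor.get_output_names()[0]
    return predictor.get_output_tensor(output_name).copy_to_cpu()
```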
......@@ -44,6 +44,7 @@ from ppocr.utils.utility import create_module
import program
from ppocr.utils.save_load import init_model
from ppocr.data.reader_main import reader_main
import cv2
from ppocr.utils.utility import initial_logger
logger = initial_logger()
......@@ -67,6 +68,50 @@ def draw_det_res(dt_boxes, config, img_name, ino):
logger.info("The detected Image saved in {}".format(save_path))
def simple_reader(img_file, config):
    imgs_lists = []
    if img_file is None or not os.path.exists(img_file):
        raise Exception("no image file found in {}".format(img_file))
    img_end = ['jpg', 'png', 'jpeg', 'JPEG', 'JPG', 'bmp']
    if os.path.isfile(img_file) and img_file.split('.')[-1] in img_end:
        imgs_lists.append(img_file)
    elif os.path.isdir(img_file):
        for single_file in os.listdir(img_file):
            if single_file.split('.')[-1] in img_end:
                imgs_lists.append(os.path.join(img_file, single_file))
    if len(imgs_lists) == 0:
        raise Exception("no image file found in {}".format(img_file))
    batch_size = config['Global']['test_batch_size_per_card']
    global_params = config['Global']
    params = deepcopy(config['TestReader'])
    params.update(global_params)
    reader_function = params['process_function']
    process_function = create_module(reader_function)(params)

    def batch_iter_reader():
        batch_outs = []
        for img_path in imgs_lists:
            img = cv2.imread(img_path)
            # Check the read result before touching img.shape:
            # cv2.imread returns None on failure.
            if img is None:
                logger.info("load image error: " + img_path)
                continue
            if len(img.shape) == 2 or img.shape[-1] == 1:
                img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
            outs = process_function(img)
            outs.append(os.path.basename(img_path))
            batch_outs.append(outs)
            if len(batch_outs) == batch_size:
                yield batch_outs
                batch_outs = []
        if len(batch_outs) != 0:
            yield batch_outs

    return batch_iter_reader
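A sketch of how the reader plugs into the flow below; per the loop in main(), each batch entry holds (image, resize ratio, file name). The standalone usage here is hypothetical, since in main() the config comes from program.load_config(FLAGS.config):
```python
test_reader = simple_reader(img_file="./demo_img/", config=config)
for batch in test_reader():
    for img, ratio, name in batch:
        print(name, img.shape, ratio)
```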
def main():
    config = program.load_config(FLAGS.config)
    program.merge_config(FLAGS.opt)
......@@ -103,7 +148,9 @@ def main():
    save_res_path = config['Global']['save_res_path']
    with open(save_res_path, "wb") as fout:
        test_reader = reader_main(config=config, mode='test')
        # test_reader = reader_main(config=config, mode='test')
        single_img_path = config['TestReader']['single_img_path']
        test_reader = simple_reader(img_file=single_img_path, config=config)
        tackling_num = 0
        for data in test_reader():
            img_num = len(data)
......@@ -116,6 +163,7 @@ def main():
                img_list.append(data[ino][0])
                ratio_list.append(data[ino][1])
                img_name_list.append(data[ino][2])

            img_list = np.concatenate(img_list, axis=0)
            outs = exe.run(eval_prog,\
                feed={'image': img_list},\
......@@ -126,7 +174,7 @@ def main():
            postprocess_params.update(global_params)
            postprocess = create_module(postprocess_params['function'])\
                (params=postprocess_params)
            dt_boxes_list = postprocess(outs, ratio_list)
            dt_boxes_list = postprocess({"maps": outs[0]}, ratio_list)
            for ino in range(img_num):
                dt_boxes = dt_boxes_list[ino]
                img_name = img_name_list[ino]
......
......@@ -185,22 +185,6 @@ def build(config, main_prog, startup_prog, mode):
def build_export(config, main_prog, startup_prog):
    """
    Build a program using a model and an optimizer
    1. create feeds
    2. create a dataloader
    3. create a model
    4. create fetchs
    5. create an optimizer
    Args:
        config(dict): config
        main_prog(): main program
        startup_prog(): startup program
        is_train(bool): train or valid
    Returns:
        dataloader(): a bridge between the model and the data
        fetchs(dict): dict of model outputs (included loss and measures)
    """
    with fluid.program_guard(main_prog, startup_prog):
        with fluid.unique_name.guard():
......