提交 95567abe 编写于 作者: L LDOUBLEV

TestReader.infer_img to Global.infer_img

上级 5e6fc91b
......@@ -15,7 +15,6 @@ EvalReader:
TestReader:
reader_function: ppocr.data.det.dataset_traversal,EvalTestReader
process_function: ppocr.data.det.db_process,DBProcessTest
infer_img:
img_set_dir: ./train_data/icdar2015/text_localization/
label_file_path: ./train_data/icdar2015/text_localization/test_icdar2015_label.txt
do_eval: True
......@@ -17,7 +17,6 @@ EvalReader:
TestReader:
reader_function: ppocr.data.det.dataset_traversal,EvalTestReader
process_function: ppocr.data.det.east_process,EASTProcessTest
infer_img:
img_set_dir: ./train_data/icdar2015/text_localization/
label_file_path: ./train_data/icdar2015/text_localization/test_icdar2015_label.txt
do_eval: True
......@@ -16,6 +16,7 @@ Global:
checkpoints:
save_res_path: ./output/det_db/predicts_db.txt
save_inference_dir:
infer_img:
Architecture:
function: ppocr.modeling.architectures.det_model,DetModel
......
......@@ -16,6 +16,7 @@ Global:
checkpoints:
save_res_path: ./output/det_db/predicts_db.txt
save_inference_dir:
infer_img:
Architecture:
function: ppocr.modeling.architectures.det_model,DetModel
......
......@@ -15,6 +15,7 @@ Global:
checkpoints:
save_res_path: ./output/det_east/predicts_east.txt
save_inference_dir:
infer_img:
Architecture:
function: ppocr.modeling.architectures.det_model,DetModel
......
......@@ -15,6 +15,7 @@ Global:
save_res_path: ./output/det_r18_vd_db/predicts_db.txt
checkpoints:
save_inference_dir:
infer_img:
Architecture:
function: ppocr.modeling.architectures.det_model,DetModel
......
......@@ -15,6 +15,7 @@ Global:
save_res_path: ./output/det_db/predicts_db.txt
checkpoints:
save_inference_dir:
infer_img:
Architecture:
function: ppocr.modeling.architectures.det_model,DetModel
......
......@@ -15,6 +15,7 @@ Global:
save_res_path: ./output/det_east/predicts_east.txt
checkpoints:
save_inference_dir:
infer_img:
Architecture:
function: ppocr.modeling.architectures.det_model,DetModel
......
......@@ -15,6 +15,7 @@ Global:
save_res_path: ./output/det_sast/predicts_sast.txt
checkpoints:
save_inference_dir:
infer_img:
Architecture:
function: ppocr.modeling.architectures.det_model,DetModel
......
......@@ -15,6 +15,7 @@ Global:
save_res_path: ./output/det_sast/predicts_sast.txt
checkpoints:
save_inference_dir:
infer_img:
Architecture:
function: ppocr.modeling.architectures.det_model,DetModel
......
......@@ -20,5 +20,4 @@ EvalReader:
TestReader:
reader_function: ppocr.data.det.dataset_traversal,EvalTestReader
process_function: ppocr.data.det.sast_process,SASTProcessTest
infer_img: ./train_data/icdar2015/text_localization/ch4_test_images/img_11.jpg
max_side_len: 1536
......@@ -20,5 +20,4 @@ EvalReader:
TestReader:
reader_function: ppocr.data.det.dataset_traversal,EvalTestReader
process_function: ppocr.data.det.sast_process,SASTProcessTest
infer_img: ./train_data/afs/total_text/Images/Test/img623.jpg
max_side_len: 768
......@@ -17,7 +17,7 @@ wget -P ./train_data/ https://paddleocr.bj.bcebos.com/dataset/test_icdar2015_la
PaddleOCR 也提供了数据格式转换脚本,可以将官网 label 转换支持的数据格式。 数据转换工具在 `train_data/gen_label.py`, 这里以训练集为例:
```
# 将官网下载的标签文件转换为 train_icdar2015_label.txt
# 将官网下载的标签文件转换为 train_icdar2015_label.txt
python gen_label.py --mode="det" --root_path="icdar_c4_train_imgs/" \
--input_path="ch4_training_localization_transcription_gt" \
--output_label="train_icdar2015_label.txt"
......@@ -74,7 +74,7 @@ tar -xf ./pretrain_models/MobileNetV3_large_x0_5_pretrained.tar ./pretrain_model
```shell
# 训练 mv3_db 模型,并将训练日志保存为 tain_det.log
python3 tools/train.py -c configs/det/det_mv3_db_v1.1.yml \
python3 tools/train.py -c configs/det/det_mv3_db_v1.1.yml \
-o Global.pretrain_weights=./pretrain_models/MobileNetV3_large_x0_5_pretrained/ \
2>&1 | tee train_det.log
```
......@@ -119,16 +119,16 @@ python3 tools/eval.py -c configs/det/det_mv3_db_v1.1.yml -o Global.checkpoints=
测试单张图像的检测效果
```shell
python3 tools/infer_det.py -c configs/det/det_mv3_db_v1.1.yml -o TestReader.infer_img="./doc/imgs_en/img_10.jpg" Global.checkpoints="./output/det_db/best_accuracy"
python3 tools/infer_det.py -c configs/det/det_mv3_db_v1.1.yml -o Global.infer_img="./doc/imgs_en/img_10.jpg" Global.checkpoints="./output/det_db/best_accuracy"
```
测试DB模型时,调整后处理阈值,
```shell
python3 tools/infer_det.py -c configs/det/det_mv3_db_v1.1.yml -o TestReader.infer_img="./doc/imgs_en/img_10.jpg" Global.checkpoints="./output/det_db/best_accuracy" PostProcess.box_thresh=0.6 PostProcess.unclip_ratio=1.5
python3 tools/infer_det.py -c configs/det/det_mv3_db_v1.1.yml -o Global.infer_img="./doc/imgs_en/img_10.jpg" Global.checkpoints="./output/det_db/best_accuracy" PostProcess.box_thresh=0.6 PostProcess.unclip_ratio=1.5
```
测试文件夹下所有图像的检测效果
```shell
python3 tools/infer_det.py -c configs/det/det_mv3_db_v1.1.yml -o TestReader.infer_img="./doc/imgs_en/" Global.checkpoints="./output/det_db/best_accuracy"
python3 tools/infer_det.py -c configs/det/det_mv3_db_v1.1.yml -o Global.infer_img="./doc/imgs_en/" Global.checkpoints="./output/det_db/best_accuracy"
```
......@@ -27,7 +27,7 @@ The provided annotation file format is as follow, seperated by "\t":
" Image file name Image annotation information encoded by json.dumps"
ch4_test_images/img_61.jpg [{"transcription": "MASA", "points": [[310, 104], [416, 141], [418, 216], [312, 179]]}, {...}]
```
The image annotation after **json.dumps()** encoding is a list containing multiple dictionaries.
The image annotation after **json.dumps()** encoding is a list containing multiple dictionaries.
The `points` in the dictionary represent the coordinates (x, y) of the four points of the text box, arranged clockwise from the point at the upper left corner.
......@@ -110,16 +110,16 @@ python3 tools/eval.py -c configs/det/det_mv3_db_v1.1.yml -o Global.checkpoints=
Test the detection result on a single image:
```shell
python3 tools/infer_det.py -c configs/det/det_mv3_db_v1.1.yml -o TestReader.infer_img="./doc/imgs_en/img_10.jpg" Global.checkpoints="./output/det_db/best_accuracy"
python3 tools/infer_det.py -c configs/det/det_mv3_db_v1.1.yml -o Global.infer_img="./doc/imgs_en/img_10.jpg" Global.checkpoints="./output/det_db/best_accuracy"
```
When testing the DB model, adjust the post-processing threshold:
```shell
python3 tools/infer_det.py -c configs/det/det_mv3_db_v1.1.yml -o TestReader.infer_img="./doc/imgs_en/img_10.jpg" Global.checkpoints="./output/det_db/best_accuracy" PostProcess.box_thresh=0.6 PostProcess.unclip_ratio=1.5
python3 tools/infer_det.py -c configs/det/det_mv3_db_v1.1.yml -o Global.infer_img="./doc/imgs_en/img_10.jpg" Global.checkpoints="./output/det_db/best_accuracy" PostProcess.box_thresh=0.6 PostProcess.unclip_ratio=1.5
```
Test the detection result on all images in the folder:
```shell
python3 tools/infer_det.py -c configs/det/det_mv3_db_v1.1.yml -o TestReader.infer_img="./doc/imgs_en/" Global.checkpoints="./output/det_db/best_accuracy"
python3 tools/infer_det.py -c configs/det/det_mv3_db_v1.1.yml -o Global.infer_img="./doc/imgs_en/" Global.checkpoints="./output/det_db/best_accuracy"
```
......@@ -294,7 +294,7 @@ The default prediction picture is stored in `infer_img`, and the weight is speci
```
# Predict English results
python3 tools/infer_rec.py -c configs/rec/ch_ppocr_v1.1/rec_chinese_lite_train_v1.1.yml -o Global.checkpoints={path/to/weights}/best_accuracy TestReader.infer_img=doc/imgs_words/en/word_1.jpg
python3 tools/infer_rec.py -c configs/rec/ch_ppocr_v1.1/rec_chinese_lite_train_v1.1.yml -o Global.checkpoints={path/to/weights}/best_accuracy Global.infer_img=doc/imgs_words/en/word_1.jpg
```
Input image:
......@@ -313,7 +313,7 @@ The configuration file used for prediction must be consistent with the training.
```
# Predict Chinese results
python3 tools/infer_rec.py -c configs/rec/ch_ppocr_v1.1/rec_chinese_lite_train_v1.1.yml -o Global.checkpoints={path/to/weights}/best_accuracy TestReader.infer_img=doc/imgs_words/ch/word_1.jpg
python3 tools/infer_rec.py -c configs/rec/ch_ppocr_v1.1/rec_chinese_lite_train_v1.1.yml -o Global.checkpoints={path/to/weights}/best_accuracy Global.infer_img=doc/imgs_words/ch/word_1.jpg
```
Input image:
......
......@@ -104,8 +104,8 @@ def main():
save_res_path = config['Global']['save_res_path']
if not os.path.exists(os.path.dirname(save_res_path)):
os.makedirs(os.path.dirname(save_res_path))
with open(save_res_path, "wb") as fout:
with open(save_res_path, "wb") as fout:
test_reader = reader_main(config=config, mode='test')
tackling_num = 0
for data in test_reader():
......@@ -135,9 +135,15 @@ def main():
elif config['Global']['algorithm'] == 'DB':
dic = {'maps': outs[0]}
elif config['Global']['algorithm'] == 'SAST':
dic = {'f_score': outs[0], 'f_border': outs[1], 'f_tvo': outs[2], 'f_tco': outs[3]}
dic = {
'f_score': outs[0],
'f_border': outs[1],
'f_tvo': outs[2],
'f_tco': outs[3]
}
else:
raise Exception("only support algorithm: ['EAST', 'DB', 'SAST']")
raise Exception(
"only support algorithm: ['EAST', 'DB', 'SAST']")
dt_boxes_list = postprocess(dic, ratio_list)
for ino in range(img_num):
dt_boxes = dt_boxes_list[ino]
......@@ -151,7 +157,7 @@ def main():
fout.write(otstr.encode())
src_img = cv2.imread(img_name)
draw_det_res(dt_boxes, config, src_img, img_name)
logger.info("success!")
......
......@@ -121,7 +121,10 @@ def merge_config(config):
global_config.keys(), sub_keys[0])
cur = global_config[sub_keys[0]]
for idx, sub_key in enumerate(sub_keys[1:]):
assert (sub_key in cur)
assert (
sub_key in cur
), "key {} not in sub_keys: {}, please check your running command.".format(
sub_key, cur)
if idx == len(sub_keys) - 2:
cur[sub_key] = value
else:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册