未验证 提交 a9e7c8f9 编写于 作者: M MissPenguin 提交者: GitHub

Merge pull request #793 from MissPenguin/develop

update detection.md for v1.1
......@@ -44,13 +44,15 @@ json.dumps编码前的图像标注信息是包含多个字典的list,字典中
## 快速启动训练
首先下载模型backbone的pretrain model,PaddleOCR的检测模型目前支持两种backbone,分别是MobileNetV3、ResNet50_vd
首先下载模型backbone的pretrain model,PaddleOCR的检测模型目前支持两种backbone,分别是MobileNetV3、ResNet_vd系列
您可以根据需求使用[PaddleClas](https://github.com/PaddlePaddle/PaddleClas/tree/master/ppcls/modeling/architectures)中的模型更换backbone。
```shell
cd PaddleOCR/
# 下载MobileNetV3的预训练模型
wget -P ./pretrain_models/ https://paddle-imagenet-models-name.bj.bcebos.com/MobileNetV3_large_x0_5_pretrained.tar
# 下载ResNet50的预训练模型
# 或,下载ResNet18_vd的预训练模型
wget -P ./pretrain_models/ https://paddle-imagenet-models-name.bj.bcebos.com/ResNet18_vd_pretrained.tar
# 或,下载ResNet50_vd的预训练模型
wget -P ./pretrain_models/ https://paddle-imagenet-models-name.bj.bcebos.com/ResNet50_vd_ssld_pretrained.tar
# 解压预训练模型文件,以MobileNetV3为例
......@@ -72,24 +74,24 @@ tar -xf ./pretrain_models/MobileNetV3_large_x0_5_pretrained.tar ./pretrain_model
```shell
# 训练 mv3_db 模型,并将训练日志保存为 tain_det.log
python3 tools/train.py -c configs/det/det_mv3_db.yml \
python3 tools/train.py -c configs/det/det_mv3_db_v1.1.yml \
-o Global.pretrain_weights=./pretrain_models/MobileNetV3_large_x0_5_pretrained/ \
2>&1 | tee train_det.log
```
上述指令中,通过-c 选择训练使用configs/det/det_db_mv3.yml配置文件。
上述指令中,通过-c 选择训练使用configs/det/det_db_mv3_v1.1.yml配置文件。
有关配置文件的详细解释,请参考[链接](./config.md)
您也可以通过-o参数在不需要修改yml文件的情况下,改变训练的参数,比如,调整训练的学习率为0.0001
```shell
python3 tools/train.py -c configs/det/det_mv3_db.yml -o Optimizer.base_lr=0.0001
python3 tools/train.py -c configs/det/det_mv3_db_v1.1.yml -o Optimizer.base_lr=0.0001
```
#### 断点训练
如果训练程序中断,如果希望加载训练中断的模型从而恢复训练,可以通过指定Global.checkpoints指定要加载的模型路径:
```shell
python3 tools/train.py -c configs/det/det_mv3_db.yml -o Global.checkpoints=./your/trained/model
python3 tools/train.py -c configs/det/det_mv3_db_v1.1.yml -o Global.checkpoints=./your/trained/model
```
**注意**`Global.checkpoints`的优先级高于`Global.pretrain_weights`的优先级,即同时指定两个参数时,优先加载`Global.checkpoints`指定的模型,如果`Global.checkpoints`指定的模型路径有误,会加载`Global.pretrain_weights`指定的模型。
......@@ -98,17 +100,17 @@ python3 tools/train.py -c configs/det/det_mv3_db.yml -o Global.checkpoints=./you
PaddleOCR计算三个OCR检测相关的指标,分别是:Precision、Recall、Hmean。
运行如下代码,根据配置文件`det_db_mv3.yml``save_res_path`指定的测试集检测结果文件,计算评估指标。
运行如下代码,根据配置文件`det_db_mv3_v1.1.yml``save_res_path`指定的测试集检测结果文件,计算评估指标。
评估时设置后处理参数`box_thresh=0.6``unclip_ratio=1.5`,使用不同数据集、不同模型训练,可调整这两个参数进行优化
```shell
python3 tools/eval.py -c configs/det/det_mv3_db.yml -o Global.checkpoints="{path/to/weights}/best_accuracy" PostProcess.box_thresh=0.6 PostProcess.unclip_ratio=1.5
python3 tools/eval.py -c configs/det/det_mv3_db_v1.1.yml -o Global.checkpoints="{path/to/weights}/best_accuracy" PostProcess.box_thresh=0.6 PostProcess.unclip_ratio=1.5
```
训练中模型参数默认保存在`Global.save_model_dir`目录下。在评估指标时,需要设置`Global.checkpoints`指向保存的参数文件。
比如:
```shell
python3 tools/eval.py -c configs/det/det_mv3_db.yml -o Global.checkpoints="./output/det_db/best_accuracy" PostProcess.box_thresh=0.6 PostProcess.unclip_ratio=1.5
python3 tools/eval.py -c configs/det/det_mv3_db_v1.1.yml -o Global.checkpoints="./output/det_db/best_accuracy" PostProcess.box_thresh=0.6 PostProcess.unclip_ratio=1.5
```
* 注:`box_thresh``unclip_ratio`是DB后处理所需要的参数,在评估EAST模型时不需要设置
......@@ -117,16 +119,16 @@ python3 tools/eval.py -c configs/det/det_mv3_db.yml -o Global.checkpoints="./ou
测试单张图像的检测效果
```shell
python3 tools/infer_det.py -c configs/det/det_mv3_db.yml -o TestReader.infer_img="./doc/imgs_en/img_10.jpg" Global.checkpoints="./output/det_db/best_accuracy"
python3 tools/infer_det.py -c configs/det/det_mv3_db_v1.1.yml -o TestReader.infer_img="./doc/imgs_en/img_10.jpg" Global.checkpoints="./output/det_db/best_accuracy"
```
测试DB模型时,调整后处理阈值,
```shell
python3 tools/infer_det.py -c configs/det/det_mv3_db.yml -o TestReader.infer_img="./doc/imgs_en/img_10.jpg" Global.checkpoints="./output/det_db/best_accuracy" PostProcess.box_thresh=0.6 PostProcess.unclip_ratio=1.5
python3 tools/infer_det.py -c configs/det/det_mv3_db_v1.1.yml -o TestReader.infer_img="./doc/imgs_en/img_10.jpg" Global.checkpoints="./output/det_db/best_accuracy" PostProcess.box_thresh=0.6 PostProcess.unclip_ratio=1.5
```
测试文件夹下所有图像的检测效果
```shell
python3 tools/infer_det.py -c configs/det/det_mv3_db.yml -o TestReader.infer_img="./doc/imgs_en/" Global.checkpoints="./output/det_db/best_accuracy"
python3 tools/infer_det.py -c configs/det/det_mv3_db_v1.1.yml -o TestReader.infer_img="./doc/imgs_en/" Global.checkpoints="./output/det_db/best_accuracy"
```
......@@ -38,12 +38,14 @@ If you want to train PaddleOCR on other datasets, please build the annotation fi
## TRAINING
First download the pretrained model. The detection model of PaddleOCR currently supports two backbones, namely MobileNetV3 and ResNet50_vd. You can use the model in [PaddleClas](https://github.com/PaddlePaddle/PaddleClas/tree/master/ppcls/modeling/architectures) to replace backbone according to your needs.
First download the pretrained model. The detection model of PaddleOCR currently supports 3 backbones, namely MobileNetV3, ResNet18_vd and ResNet50_vd. You can use the model in [PaddleClas](https://github.com/PaddlePaddle/PaddleClas/tree/master/ppcls/modeling/architectures) to replace backbone according to your needs.
```shell
cd PaddleOCR/
# Download the pre-trained model of MobileNetV3
wget -P ./pretrain_models/ https://paddle-imagenet-models-name.bj.bcebos.com/MobileNetV3_large_x0_5_pretrained.tar
# Download the pre-trained model of ResNet50
# or, download the pre-trained model of ResNet18_vd
wget -P ./pretrain_models/ https://paddle-imagenet-models-name.bj.bcebos.com/ResNet18_vd_pretrained.tar
# or, download the pre-trained model of ResNet50_vd
wget -P ./pretrain_models/ https://paddle-imagenet-models-name.bj.bcebos.com/ResNet50_vd_ssld_pretrained.tar
# decompressing the pre-training model file, take MobileNetV3 as an example
......@@ -62,7 +64,7 @@ tar -xf ./pretrain_models/MobileNetV3_large_x0_5_pretrained.tar ./pretrain_model
#### START TRAINING
*If CPU version installed, please set the parameter `use_gpu` to `false` in the configuration.*
```shell
python3 tools/train.py -c configs/det/det_mv3_db.yml 2>&1 | tee train_det.log
python3 tools/train.py -c configs/det/det_mv3_db_v1.1.yml 2>&1 | tee train_det.log
```
In the above instruction, use `-c` to select the training to use the `configs/det/det_db_mv3.yml` configuration file.
......@@ -70,7 +72,7 @@ For a detailed explanation of the configuration file, please refer to [config](.
You can also use `-o` to change the training parameters without modifying the yml file. For example, adjust the training learning rate to 0.0001
```shell
python3 tools/train.py -c configs/det/det_mv3_db.yml -o Optimizer.base_lr=0.0001
python3 tools/train.py -c configs/det/det_mv3_db_v1.1.yml -o Optimizer.base_lr=0.0001
```
#### load trained model and continue training
......@@ -78,7 +80,7 @@ If you expect to load trained model and continue the training again, you can spe
For example:
```shell
python3 tools/train.py -c configs/det/det_mv3_db.yml -o Global.checkpoints=./your/trained/model
python3 tools/train.py -c configs/det/det_mv3_db_v1.1.yml -o Global.checkpoints=./your/trained/model
```
**Note**: The priority of `Global.checkpoints` is higher than that of `Global.pretrain_weights`, that is, when two parameters are specified at the same time, the model specified by `Global.checkpoints` will be loaded first. If the model path specified by `Global.checkpoints` is wrong, the one specified by `Global.pretrain_weights` will be loaded.
......@@ -88,18 +90,18 @@ python3 tools/train.py -c configs/det/det_mv3_db.yml -o Global.checkpoints=./you
PaddleOCR calculates three indicators for evaluating performance of OCR detection task: Precision, Recall, and Hmean.
Run the following code to calculate the evaluation indicators. The result will be saved in the test result file specified by `save_res_path` in the configuration file `det_db_mv3.yml`
Run the following code to calculate the evaluation indicators. The result will be saved in the test result file specified by `save_res_path` in the configuration file `det_db_mv3_v1.1.yml`
When evaluating, set post-processing parameters `box_thresh=0.6`, `unclip_ratio=1.5`. If you use different datasets, different models for training, these two parameters should be adjusted for better result.
```shell
python3 tools/eval.py -c configs/det/det_mv3_db.yml -o Global.checkpoints="{path/to/weights}/best_accuracy" PostProcess.box_thresh=0.6 PostProcess.unclip_ratio=1.5
python3 tools/eval.py -c configs/det/det_mv3_db_v1.1.yml -o Global.checkpoints="{path/to/weights}/best_accuracy" PostProcess.box_thresh=0.6 PostProcess.unclip_ratio=1.5
```
The model parameters during training are saved in the `Global.save_model_dir` directory by default. When evaluating indicators, you need to set `Global.checkpoints` to point to the saved parameter file.
Such as:
```shell
python3 tools/eval.py -c configs/det/det_mv3_db.yml -o Global.checkpoints="./output/det_db/best_accuracy" PostProcess.box_thresh=0.6 PostProcess.unclip_ratio=1.5
python3 tools/eval.py -c configs/det/det_mv3_db_v1.1.yml -o Global.checkpoints="./output/det_db/best_accuracy" PostProcess.box_thresh=0.6 PostProcess.unclip_ratio=1.5
```
* Note: `box_thresh` and `unclip_ratio` are parameters required for DB post-processing, and not need to be set when evaluating the EAST model.
......@@ -108,16 +110,16 @@ python3 tools/eval.py -c configs/det/det_mv3_db.yml -o Global.checkpoints="./ou
Test the detection result on a single image:
```shell
python3 tools/infer_det.py -c configs/det/det_mv3_db.yml -o TestReader.infer_img="./doc/imgs_en/img_10.jpg" Global.checkpoints="./output/det_db/best_accuracy"
python3 tools/infer_det.py -c configs/det/det_mv3_db_v1.1.yml -o TestReader.infer_img="./doc/imgs_en/img_10.jpg" Global.checkpoints="./output/det_db/best_accuracy"
```
When testing the DB model, adjust the post-processing threshold:
```shell
python3 tools/infer_det.py -c configs/det/det_mv3_db.yml -o TestReader.infer_img="./doc/imgs_en/img_10.jpg" Global.checkpoints="./output/det_db/best_accuracy" PostProcess.box_thresh=0.6 PostProcess.unclip_ratio=1.5
python3 tools/infer_det.py -c configs/det/det_mv3_db_v1.1.yml -o TestReader.infer_img="./doc/imgs_en/img_10.jpg" Global.checkpoints="./output/det_db/best_accuracy" PostProcess.box_thresh=0.6 PostProcess.unclip_ratio=1.5
```
Test the detection result on all images in the folder:
```shell
python3 tools/infer_det.py -c configs/det/det_mv3_db.yml -o TestReader.infer_img="./doc/imgs_en/" Global.checkpoints="./output/det_db/best_accuracy"
python3 tools/infer_det.py -c configs/det/det_mv3_db_v1.1.yml -o TestReader.infer_img="./doc/imgs_en/" Global.checkpoints="./output/det_db/best_accuracy"
```
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册