Unverified commit f1a41079, authored by Feng Ni, committed by GitHub

add YOLOX doc and configs (#5741)

* add yolox doc and configs, test=document_fix

* fix bn and update doc, test=document_fix
Parent: 10bf8de7
# YOLOX (YOLOX: Exceeding YOLO Series in 2021)
## Model Zoo
### YOLOX on COCO
| Model | Input Size | Images/GPU | LR Schedule | Inference Time (fps) | Box AP | Download | Config |
| :------------- | :------- | :-------: | :------: | :---------: | :-----: | :-------------: | :-----: |
| YOLOX-nano | 416 | 8 | 300e | ---- | 26.1 | [download](https://paddledet.bj.bcebos.com/models/yolox_nano_300e_coco.pdparams) | [config](./yolox_nano_300e_coco.yml) |
| YOLOX-tiny | 416 | 8 | 300e | ---- | 32.9 | [download](https://paddledet.bj.bcebos.com/models/yolox_tiny_300e_coco.pdparams) | [config](./yolox_tiny_300e_coco.yml) |
| YOLOX-s | 640 | 8 | 300e | ---- | 40.4 | [download](https://paddledet.bj.bcebos.com/models/yolox_s_300e_coco.pdparams) | [config](./yolox_s_300e_coco.yml) |
| YOLOX-m | 640 | 8 | 300e | ---- | 46.9 | [download](https://paddledet.bj.bcebos.com/models/yolox_m_300e_coco.pdparams) | [config](./yolox_m_300e_coco.yml) |
| YOLOX-l | 640 | 8 | 300e | ---- | 50.1 | [download](https://paddledet.bj.bcebos.com/models/yolox_l_300e_coco.pdparams) | [config](./yolox_l_300e_coco.yml) |
| YOLOX-x | 640 | 8 | 300e | ---- | 51.4 | [download](https://paddledet.bj.bcebos.com/models/yolox_x_300e_coco.pdparams) | [config](./yolox_x_300e_coco.yml) |
**Notes:**
- The models above are trained on 8 GPUs by default with a total batch_size of 64, each for 300 epochs.
- To speed up inference while keeping a high mAP, you can set `nms_top_k` to `1000` and `keep_top_k` to `100` in [yolox_cspdarknet.yml](_base_/yolox_cspdarknet.yml); mAP drops by about 0.1~0.2%.
- For a fast demo, you can set `score_threshold` to `0.25` and `nms_threshold` to `0.45` in [yolox_cspdarknet.yml](_base_/yolox_cspdarknet.yml), but mAP drops considerably. See the config snippet after this list.
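
A sketch of what the NMS section of [yolox_cspdarknet.yml](_base_/yolox_cspdarknet.yml) would look like with the faster settings from the notes above (all other fields keep their defaults):

```yaml
# Faster-inference variant of the NMS settings (trades a little mAP for speed);
# only the four values noted above differ from the defaults in yolox_cspdarknet.yml.
YOLOXHead:
  nms:
    name: MultiClassNMS
    nms_top_k: 1000       # default 10000
    keep_top_k: 100       # default 1000
    score_threshold: 0.25 # default 0.001; the higher value is for quick demos
    nms_threshold: 0.45   # default 0.65
```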
## Getting Started
### 1. Training
Run the following command to train YOLOX with mixed precision:
```bash
python -m paddle.distributed.launch --gpus 0,1,2,3,4,5,6,7 tools/train.py -c configs/yolox/yolox_s_300e_coco.yml --fleet --amp --eval
```
**Note:**
Training with the default config requires `--fleet`. Setting `--amp` is also recommended to avoid running out of GPU memory, and `--eval` enables evaluation during training.
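
A single-GPU run (e.g. for debugging) can drop `--fleet`; note that the per-GPU batch size and base learning rate in the config may then need to be scaled down, which is left as an assumption in this sketch:

```bash
# single-GPU training sketch (assumes batch_size/base_lr in the config are adjusted for 1 GPU)
CUDA_VISIBLE_DEVICES=0 python tools/train.py -c configs/yolox/yolox_s_300e_coco.yml --amp --eval
```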
### 2. Evaluation
Run the following command to evaluate on the COCO val2017 dataset with a single GPU:
```bash
CUDA_VISIBLE_DEVICES=0 python tools/eval.py -c configs/yolox/yolox_s_300e_coco.yml -o weights=https://paddledet.bj.bcebos.com/models/yolox_s_300e_coco.pdparams
```
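
To evaluate a locally trained checkpoint instead of the released weights, `weights` can point at the training output directory (the path below matches the `weights` entry in yolox_s_300e_coco.yml):

```bash
# evaluate a locally trained checkpoint
CUDA_VISIBLE_DEVICES=0 python tools/eval.py -c configs/yolox/yolox_s_300e_coco.yml -o weights=output/yolox_s_300e_coco/model_final
```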
### 3. Inference
Run the following commands to run inference on a single GPU: use `--infer_img` for a single image and `--infer_dir` for all images in a directory.
```bash
# inference on a single image
CUDA_VISIBLE_DEVICES=0 python tools/infer.py -c configs/yolox/yolox_s_300e_coco.yml -o weights=https://paddledet.bj.bcebos.com/models/yolox_s_300e_coco.pdparams --infer_img=demo/000000014439_640x640.jpg
# inference on all images in a directory
CUDA_VISIBLE_DEVICES=0 python tools/infer.py -c configs/yolox/yolox_s_300e_coco.yml -o weights=https://paddledet.bj.bcebos.com/models/yolox_s_300e_coco.pdparams --infer_dir=demo
```
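
To filter low-confidence boxes from the visualized output, the `--draw_threshold` flag of `tools/infer.py` can be raised (0.5 shown below); this flag is assumed to behave as for other PaddleDetection models and only affects visualization:

```bash
# only draw detections with score >= 0.5 (visualization only, does not change the predictions)
CUDA_VISIBLE_DEVICES=0 python tools/infer.py -c configs/yolox/yolox_s_300e_coco.yml -o weights=https://paddledet.bj.bcebos.com/models/yolox_s_300e_coco.pdparams --infer_img=demo/000000014439_640x640.jpg --draw_threshold=0.5
```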
### 4. Deployment
For GPU inference deployment or benchmarking, YOLOX must first be exported with `tools/export_model.py`.
Run the following command to export the model:
```bash
python tools/export_model.py -c configs/yolox/yolox_s_300e_coco.yml -o weights=https://paddledet.bj.bcebos.com/models/yolox_s_300e_coco.pdparams
```
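
After export, the inference model is written to `output_inference/yolox_s_300e_coco/`. The listing below shows the files a PaddleDetection export typically produces; the exact file names are an assumption and may vary by version:

```bash
# typical contents of the export directory (names may vary by version)
ls output_inference/yolox_s_300e_coco/
# infer_cfg.yml  model.pdiparams  model.pdiparams.info  model.pdmodel
```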
`deploy/python/infer.py` uses the exported Paddle Inference model above for inference and benchmarking.
```bash
# inference on a single image
python deploy/python/infer.py --model_dir=output_inference/yolox_s_300e_coco --image_file=demo/000000014439_640x640.jpg --device=gpu
# inference on all images in a directory
python deploy/python/infer.py --model_dir=output_inference/yolox_s_300e_coco --image_dir=demo/ --device=gpu
# benchmark
python deploy/python/infer.py --model_dir=output_inference/yolox_s_300e_coco --image_file=demo/000000014439_640x640.jpg --device=gpu --run_benchmark=True
# TensorRT FP32 benchmark
python deploy/python/infer.py --model_dir=output_inference/yolox_s_300e_coco --image_file=demo/000000014439_640x640.jpg --device=gpu --run_benchmark=True --trt_max_shape=640 --trt_min_shape=640 --trt_opt_shape=640 --run_mode=trt_fp32
# TensorRT FP16 benchmark
python deploy/python/infer.py --model_dir=output_inference/yolox_s_300e_coco --image_file=demo/000000014439_640x640.jpg --device=gpu --run_benchmark=True --trt_max_shape=640 --trt_min_shape=640 --trt_opt_shape=640 --run_mode=trt_fp16
```
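
The exported model can also be run on CPU by changing the `--device` flag (the TensorRT options apply only to GPU):

```bash
# CPU inference with the exported model
python deploy/python/infer.py --model_dir=output_inference/yolox_s_300e_coco --image_file=demo/000000014439_640x640.jpg --device=cpu
```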
## Citations
```
@article{yolox2021,
title={YOLOX: Exceeding YOLO Series in 2021},
author={Ge, Zheng and Liu, Songtao and Wang, Feng and Li, Zeming and Sun, Jian},
journal={arXiv preprint arXiv:2107.08430},
year={2021}
}
```
**`configs/yolox/_base_/optimizer_300e.yml`**

```yaml
epoch: 300

LearningRate:
  base_lr: 0.01
  schedulers:
  - !CosineDecay
    max_epochs: 300
    min_lr_ratio: 0.05
    last_plateau_epochs: 15
  - !ExpWarmup
    epochs: 5

OptimizerBuilder:
  optimizer:
    type: Momentum
    momentum: 0.9
    use_nesterov: True
  regularizer:
    factor: 0.0005
    type: L2
```
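
A rough, illustrative Python sketch of the learning-rate curve this file describes, assuming `ExpWarmup` ramps the LR up over the first 5 epochs, `CosineDecay` then decays it towards `base_lr * min_lr_ratio`, and `last_plateau_epochs: 15` holds that floor for the final 15 epochs (the helper below is not the PaddleDetection scheduler code):

```python
import math

def yolox_lr(epoch, base_lr=0.01, max_epochs=300, min_lr_ratio=0.05,
             warmup_epochs=5, last_plateau_epochs=15):
    """Illustrative LR curve matching optimizer_300e.yml (not the real scheduler)."""
    min_lr = base_lr * min_lr_ratio
    if epoch < warmup_epochs:
        # warmup from ~0 up to base_lr over the first 5 epochs
        return base_lr * (epoch / warmup_epochs) ** 2
    if epoch >= max_epochs - last_plateau_epochs:
        # hold the minimum LR for the last 15 epochs
        return min_lr
    # cosine decay between the warmup and the final plateau
    t = (epoch - warmup_epochs) / (max_epochs - last_plateau_epochs - warmup_epochs)
    return min_lr + 0.5 * (base_lr - min_lr) * (1 + math.cos(math.pi * t))

print([round(yolox_lr(e), 5) for e in (0, 5, 150, 285, 299)])
```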
**`configs/yolox/_base_/yolox_cspdarknet.yml`**

```yaml
architecture: YOLOX
norm_type: sync_bn
use_ema: True
ema_decay: 0.9999
ema_decay_type: "exponential"
act: silu
find_unused_parameters: True

depth_mult: 1.0
width_mult: 1.0

YOLOX:
  backbone: CSPDarkNet
  neck: YOLOCSPPAN
  head: YOLOXHead
  size_stride: 32
  size_range: [15, 25] # multi-scale range [480*480 ~ 800*800]

CSPDarkNet:
  arch: "X"
  return_idx: [2, 3, 4]
  depthwise: False

YOLOCSPPAN:
  depthwise: False

YOLOXHead:
  l1_epoch: 285
  depthwise: False
  loss_weight: {cls: 1.0, obj: 1.0, iou: 5.0, l1: 1.0}
  assigner:
    name: SimOTAAssigner
    candidate_topk: 10
    use_vfl: False
  nms:
    name: MultiClassNMS
    nms_top_k: 10000
    keep_top_k: 1000
    score_threshold: 0.001
    nms_threshold: 0.65
    # For speed while keeping high mAP, you can set 'nms_top_k' to 1000 and 'keep_top_k' to 100; mAP will drop by about 0.1%.
    # For a fast demo, you can set 'score_threshold' to 0.25 and 'nms_threshold' to 0.45, but mAP will drop a lot.
```
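
`size_stride` and `size_range` together define the multi-scale training sizes: each candidate input size is `size_stride * k` with `k` drawn from `[15, 25]`, i.e. 480 to 800 pixels in steps of 32, matching the comment above. A one-liner to expand that range:

```python
# multi-scale training sizes implied by size_stride=32 and size_range=[15, 25]
size_stride, size_range = 32, (15, 25)
sizes = [size_stride * k for k in range(size_range[0], size_range[1] + 1)]
print(sizes)  # [480, 512, 544, ..., 800]
```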
**`configs/yolox/_base_/yolox_reader.yml`**

```yaml
worker_num: 4

TrainReader:
  sample_transforms:
    - Decode: {}
    - Mosaic:
        prob: 1.0
        input_dim: [640, 640]
        degrees: [-10, 10]
        scale: [0.1, 2.0]
        shear: [-2, 2]
        translate: [-0.1, 0.1]
        enable_mixup: True
        mixup_prob: 1.0
        mixup_scale: [0.5, 1.5]
    - AugmentHSV: {is_bgr: False, hgain: 5, sgain: 30, vgain: 30}
    - PadResize: {target_size: 640}
    - RandomFlip: {}
  batch_transforms:
    - Permute: {}
  batch_size: 8
  shuffle: True
  drop_last: True
  collate_batch: False
  mosaic_epoch: 285

EvalReader:
  sample_transforms:
    - Decode: {}
    - Resize: {target_size: 640, keep_ratio: True}
    - Pad: {size: 640, fill_value: [114., 114., 114.]}
    - Permute: {}
  batch_size: 4

TestReader:
  inputs_def:
    image_shape: [3, 640, 640]
  sample_transforms:
    - Decode: {}
    - Resize: {target_size: 640, keep_ratio: True}
    - Pad: {size: 640, fill_value: [114., 114., 114.]}
    - Permute: {}
  batch_size: 1
```
**`configs/yolox/yolox_l_300e_coco.yml`**

```yaml
_BASE_: [
  '../datasets/coco_detection.yml',
  '../runtime.yml',
  './_base_/optimizer_300e.yml',
  './_base_/yolox_cspdarknet.yml',
  './_base_/yolox_reader.yml'
]
depth_mult: 1.0
width_mult: 1.0

log_iter: 100
snapshot_epoch: 10
weights: output/yolox_l_300e_coco/model_final
```
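
The per-model configs differ mainly in `depth_mult` and `width_mult`, which shrink or grow the shared CSPDarkNet/PAN/head definition. The sketch below illustrates how such multipliers are commonly applied in YOLOX-style models; the exact rounding rules are an assumption, not PaddleDetection's implementation:

```python
import math

def scale_width(channels: int, width_mult: float, divisor: int = 8) -> int:
    """Scale a channel count and round up to a multiple of `divisor` (assumed convention)."""
    return max(divisor, int(math.ceil(channels * width_mult / divisor) * divisor))

def scale_depth(num_blocks: int, depth_mult: float) -> int:
    """Scale the number of repeated CSP blocks in a stage (assumed convention)."""
    return max(1, round(num_blocks * depth_mult))

# e.g. YOLOX-s (width_mult=0.50, depth_mult=0.33) vs. YOLOX-l (1.0, 1.0)
print(scale_width(256, 0.50), scale_depth(9, 0.33))  # 128 3
print(scale_width(256, 1.00), scale_depth(9, 1.00))  # 256 9
```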
**`configs/yolox/yolox_m_300e_coco.yml`**

```yaml
_BASE_: [
  '../datasets/coco_detection.yml',
  '../runtime.yml',
  './_base_/optimizer_300e.yml',
  './_base_/yolox_cspdarknet.yml',
  './_base_/yolox_reader.yml'
]
depth_mult: 0.67
width_mult: 0.75

log_iter: 100
snapshot_epoch: 10
weights: output/yolox_m_300e_coco/model_final
```
**`configs/yolox/yolox_nano_300e_coco.yml`**

```yaml
_BASE_: [
  '../datasets/coco_detection.yml',
  '../runtime.yml',
  './_base_/optimizer_300e.yml',
  './_base_/yolox_cspdarknet.yml',
  './_base_/yolox_reader.yml'
]
depth_mult: 0.33
width_mult: 0.25

log_iter: 100
snapshot_epoch: 10
weights: output/yolox_nano_300e_coco/model_final

### model config:
# Note: YOLOX-nano uses depthwise conv in the backbone, neck and head.
YOLOX:
  backbone: CSPDarkNet
  neck: YOLOCSPPAN
  head: YOLOXHead
  size_stride: 32
  size_range: [10, 20] # multi-scale range [320*320 ~ 640*640]

CSPDarkNet:
  arch: "X"
  return_idx: [2, 3, 4]
  depthwise: True

YOLOCSPPAN:
  depthwise: True

YOLOXHead:
  depthwise: True

### reader config:
# Note: YOLOX-tiny/nano use 416*416 for evaluation and inference.
# The multi-scale training setting is in the model config; the TrainReader operators use 640*640 by default.
worker_num: 4
TrainReader:
  sample_transforms:
    - Decode: {}
    - Mosaic:
        prob: 0.5 # 1.0 in YOLOX-tiny/s/m/l/x
        input_dim: [640, 640]
        degrees: [-10, 10]
        scale: [0.5, 1.5] # [0.1, 2.0] in YOLOX-s/m/l/x
        shear: [-2, 2]
        translate: [-0.1, 0.1]
        enable_mixup: False # True in YOLOX-s/m/l/x
    - AugmentHSV: {is_bgr: False, hgain: 5, sgain: 30, vgain: 30}
    - PadResize: {target_size: 640}
    - RandomFlip: {}
  batch_transforms:
    - Permute: {}
  batch_size: 8
  shuffle: True
  drop_last: True
  collate_batch: False
  mosaic_epoch: 285

EvalReader:
  sample_transforms:
    - Decode: {}
    - Resize: {target_size: 416, keep_ratio: True}
    - Pad: {size: 416, fill_value: [114., 114., 114.]}
    - Permute: {}
  batch_size: 8

TestReader:
  inputs_def:
    image_shape: [3, 416, 416]
  sample_transforms:
    - Decode: {}
    - Resize: {target_size: 416, keep_ratio: True}
    - Pad: {size: 416, fill_value: [114., 114., 114.]}
    - Permute: {}
  batch_size: 1
```
**`configs/yolox/yolox_s_300e_coco.yml`**

```yaml
_BASE_: [
  '../datasets/coco_detection.yml',
  '../runtime.yml',
  './_base_/optimizer_300e.yml',
  './_base_/yolox_cspdarknet.yml',
  './_base_/yolox_reader.yml'
]
depth_mult: 0.33
width_mult: 0.50

log_iter: 100
snapshot_epoch: 10
weights: output/yolox_s_300e_coco/model_final
```
**`configs/yolox/yolox_tiny_300e_coco.yml`**

```yaml
_BASE_: [
  '../datasets/coco_detection.yml',
  '../runtime.yml',
  './_base_/optimizer_300e.yml',
  './_base_/yolox_cspdarknet.yml',
  './_base_/yolox_reader.yml'
]
depth_mult: 0.33
width_mult: 0.375

log_iter: 100
snapshot_epoch: 10
weights: output/yolox_tiny_300e_coco/model_final

### model config:
YOLOX:
  backbone: CSPDarkNet
  neck: YOLOCSPPAN
  head: YOLOXHead
  size_stride: 32
  size_range: [10, 20] # multi-scale range [320*320 ~ 640*640]

### reader config:
# Note: YOLOX-tiny/nano use 416*416 for evaluation and inference.
# The multi-scale training setting is in the model config; the TrainReader operators use 640*640 by default.
worker_num: 4
TrainReader:
  sample_transforms:
    - Decode: {}
    - Mosaic:
        prob: 1.0
        input_dim: [640, 640]
        degrees: [-10, 10]
        scale: [0.5, 1.5] # [0.1, 2.0] in YOLOX-s/m/l/x
        shear: [-2, 2]
        translate: [-0.1, 0.1]
        enable_mixup: False # True in YOLOX-s/m/l/x
    - AugmentHSV: {is_bgr: False, hgain: 5, sgain: 30, vgain: 30}
    - PadResize: {target_size: 640}
    - RandomFlip: {}
  batch_transforms:
    - Permute: {}
  batch_size: 8
  shuffle: True
  drop_last: True
  collate_batch: False
  mosaic_epoch: 285

EvalReader:
  sample_transforms:
    - Decode: {}
    - Resize: {target_size: 416, keep_ratio: True}
    - Pad: {size: 416, fill_value: [114., 114., 114.]}
    - Permute: {}
  batch_size: 8

TestReader:
  inputs_def:
    image_shape: [3, 416, 416]
  sample_transforms:
    - Decode: {}
    - Resize: {target_size: 416, keep_ratio: True}
    - Pad: {size: 416, fill_value: [114., 114., 114.]}
    - Permute: {}
  batch_size: 1
```
**`configs/yolox/yolox_x_300e_coco.yml`**

```yaml
_BASE_: [
  '../datasets/coco_detection.yml',
  '../runtime.yml',
  './_base_/optimizer_300e.yml',
  './_base_/yolox_cspdarknet.yml',
  './_base_/yolox_reader.yml'
]
depth_mult: 1.33
width_mult: 1.25

log_iter: 100
snapshot_epoch: 10
weights: output/yolox_x_300e_coco/model_final
```
Diff hunk in the `Mosaic` data-transform operator (coerces an integer `input_dim` into an `[h, w]` list):

```diff
@@ -3083,6 +3083,8 @@ class Mosaic(BaseOperator):
                  remove_outside_box=False):
         super(Mosaic, self).__init__()
         self.prob = prob
+        if isinstance(input_dim, Integral):
+            input_dim = [input_dim, input_dim]
         self.input_dim = input_dim
         self.degrees = degrees
         self.translate = translate
```
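
With this change `input_dim` may be passed either as a single integer or as an `[h, w]` pair. A usage sketch, assuming `Mosaic` is importable from `ppdet.data.transform.operators` and that the remaining constructor arguments keep their defaults:

```python
# usage sketch; the module path is assumed from the PaddleDetection layout
from ppdet.data.transform.operators import Mosaic

# after the fix these two constructions are equivalent
m1 = Mosaic(prob=1.0, input_dim=640)
m2 = Mosaic(prob=1.0, input_dim=[640, 640])
assert m1.input_dim == m2.input_dim == [640, 640]
```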
Diff hunk in the `Trainer` class (the BN fix from the commit message: Paddle's `nn.BatchNorm2D` keeps these values internally in `_epsilon` / `_momentum`, so the public attribute names did not take effect):

```diff
@@ -103,8 +103,8 @@ class Trainer(object):
         if cfg.architecture == 'YOLOX':
             for k, m in self.model.named_sublayers():
                 if isinstance(m, nn.BatchNorm2D):
-                    m.epsilon = 1e-3  # for amp(fp16)
-                    m.momentum = 0.97  # 0.03 in pytorch
+                    m._epsilon = 1e-3  # for amp(fp16)
+                    m._momentum = 0.97  # 0.03 in pytorch

         # normalize params for deploy
         if 'slim' in cfg and cfg['slim_type'] == 'OFA':
```