Unverified commit 64865333 authored by lvjian0706, committed by GitHub

[MOT] add mcfairmot ptq and update trainer.py (#5298)

* add mcfairmot ptq

* update docs
Parent 34a74b39
@@ -44,6 +44,13 @@ PP-Tracking provides an AI Studio public project tutorial. Please refer to this
- MOTA is the average MOTA over the 4 vehicle categories in the VisDrone Vehicle dataset. This dataset is extracted from the VisDrone2019 MOT dataset; the download [link](https://bj.bcebos.com/v1/paddledet/data/mot/visdrone_mcmot_vehicle.zip) is provided here.
- The tracker used in the MCFairMOT model here is ByteTracker.
### MCFairMOT offline quantization results on the VisDrone Vehicle val-set
| Backbone | Compression Strategy | Prediction Latency (T4) | Prediction Latency (V100) | Model Configuration File | Compression Algorithm Configuration File |
| :--------------| :------- | :------: | :----: | :----: | :----: |
| DLA-34 | baseline | 41.3 | 21.9 |[Configuration File](./mcfairmot_dla34_30e_1088x608_visdrone_vehicle_bytetracker.yml)| - |
| DLA-34 | offline quantization | 37.8 | 21.2 |[Configuration File](./mcfairmot_dla34_30e_1088x608_visdrone_vehicle_bytetracker.yml)|[Configuration File](https://github.com/PaddlePaddle/PaddleDetection/blob/release/2.3/configs/slim/post_quant/mcfairmot_ptq.yml)|
## Getting Started
### 1. Training
@@ -95,6 +102,14 @@ python deploy/pptracking/python/mot_jde_infer.py --model_dir=output_inference/mc
- The tracking model predicts on videos and does not support single-image prediction. The visualized tracking video is saved by default; add `--save_mot_txts` to save the txt result files, or `--save_images` to save the visualized images.
- Each line of the tracking results txt file is `frame,id,x1,y1,w,h,score,cls_id,-1,-1`, e.g. `1,3,456,232,80,64,0.93,2,-1,-1` for a class-2 track with id 3 in frame 1.
### 6. Offline quantization
The offline quantization model is calibrated on the VisDrone Vehicle val-set. Run:
```bash
CUDA_VISIBLE_DEVICES=0 python3.7 tools/post_quant.py -c configs/mot/mcfairmot/mcfairmot_dla34_30e_1088x608_visdrone_vehicle_bytetracker.yml --slim_config=configs/slim/post_quant/mcfairmot_ptq.yml
```
**Notes:**
- Offline quantization uses the VisDrone Vehicle val-set and the 4-class vehicle tracking model by default.
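
After calibration, the exported quantized inference model can be deployed with the same tracking pipeline as in section 5. A minimal sketch, assuming the quantized model was exported to `output_inference/mcfairmot_ptq` (this directory name depends on your `post_quant.py` settings, so treat it as a placeholder):
```bash
# Run multi-class tracking with the exported quantized inference model.
# output_inference/mcfairmot_ptq is an assumed path; use the directory
# actually produced by tools/post_quant.py.
python deploy/pptracking/python/mot_jde_infer.py \
    --model_dir=output_inference/mcfairmot_ptq \
    --video_file=test_video.mp4 \
    --device=GPU \
    --save_mot_txts
```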
## Citations
```
......
@@ -43,6 +43,12 @@ PP-Tracking provides an AI Studio public project tutorial; please refer to [PP-Tracking
- MOTA is the average MOTA over the 4 vehicle categories in the VisDrone Vehicle dataset, which consists of the 4 vehicle classes extracted from the VisDrone dataset; the dataset [download link](https://bj.bcebos.com/v1/paddledet/data/mot/visdrone_mcmot_vehicle.zip) is provided here.
- The tracker used in the MCFairMOT model here is ByteTracker.
### MCFairMOT offline quantization results on the VisDrone Vehicle val-set
| Backbone | Compression Strategy | Prediction Latency (T4) | Prediction Latency (V100) | Configuration File | Compression Algorithm Configuration File |
| :--------------| :------- | :------: | :----: | :----: | :----: |
| DLA-34 | baseline | 41.3 | 21.9 |[Configuration File](./mcfairmot_dla34_30e_1088x608_visdrone_vehicle_bytetracker.yml)| - |
| DLA-34 | offline quantization | 37.8 | 21.2 |[Configuration File](./mcfairmot_dla34_30e_1088x608_visdrone_vehicle_bytetracker.yml)|[Configuration File](https://github.com/PaddlePaddle/PaddleDetection/blob/release/2.3/configs/slim/post_quant/mcfairmot_ptq.yml)|
## Getting Started
### 1. Training
@@ -93,6 +99,14 @@ python deploy/pptracking/python/mot_jde_infer.py --model_dir=output_inference/mc
- The tracking model predicts on videos and does not support single-image prediction. The visualized tracking video is saved by default; add `--save_mot_txts` to save the txt result files, or `--save_images` to save the visualized images.
- Each line of the multi-class tracking results txt file is `frame,id,x1,y1,w,h,score,cls_id,-1,-1`.
### 6. Offline quantization
The offline quantization model is calibrated on the VisDrone Vehicle val-set. Run:
```bash
CUDA_VISIBLE_DEVICES=0 python3.7 tools/post_quant.py -c configs/mot/mcfairmot/mcfairmot_dla34_30e_1088x608_visdrone_vehicle_bytetracker.yml --slim_config=configs/slim/post_quant/mcfairmot_ptq.yml
```
**Notes:**
- Offline quantization uses the VisDrone Vehicle val-set and the 4-class vehicle tracking model by default.
## Citations
```
......
# Post-training quantization (PTQ) config for MCFairMOT DLA-34.
weights: https://paddledet.bj.bcebos.com/models/mot/mcfairmot_dla34_30e_1088x608_visdrone_vehicle_bytetracker.pdparams
slim: PTQ

PTQ:
  # Histogram-based activation quantizer: the clipping threshold keeps
  # 99.9% of the observed activation histogram mass.
  ptq_config: {
      'activation_quantizer': 'HistQuantizer',
      'upsample_bins': 127,
      'hist_percent': 0.999}
  quant_batch_num: 10  # number of calibration batches
  fuse: True           # fuse adjacent layers (e.g. conv + BN) before quantization
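
To make the calibration settings concrete: `hist_percent: 0.999` selects the activation clipping threshold so that 99.9% of the observed absolute activation values fall below it. Below is a minimal NumPy sketch of that percentile rule. It only illustrates the idea behind `HistQuantizer`; the function name and bin count are invented for the example, and this is not PaddleSlim's implementation:
```python
import numpy as np

def hist_percentile_threshold(activations, hist_percent=0.999, bins=2048):
    """Illustrative only: choose the clipping threshold that holds
    `hist_percent` of the absolute-activation histogram mass, then
    derive a symmetric int8 scale from it."""
    abs_vals = np.abs(np.asarray(activations, dtype=np.float32).ravel())
    hist, edges = np.histogram(abs_vals, bins=bins)
    cdf = np.cumsum(hist) / hist.sum()
    # First histogram bin whose cumulative mass reaches hist_percent.
    idx = int(np.searchsorted(cdf, hist_percent))
    threshold = edges[min(idx + 1, len(edges) - 1)]
    scale = threshold / 127.0  # int8 symmetric quantization scale
    return threshold, scale

# Activations would normally be gathered over quant_batch_num calibration
# batches; random data stands in for them here.
threshold, scale = hist_percentile_threshold(np.random.randn(100000))
```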
@@ -73,6 +73,10 @@ class Trainer(object):
             logger.error('DeepSORT has no need of training on mot dataset.')
             sys.exit(1)
 
+        if cfg.architecture == 'FairMOT' and self.mode == 'eval':
+            images = self.parse_mot_images(cfg)
+            self.dataset.set_images(images)
+
         if self.mode == 'train':
             self.loader = create('{}Reader'.format(self.mode.capitalize()))(
                 self.dataset, cfg.worker_num)
@@ -114,14 +118,17 @@ class Trainer(object):
         # EvalDataset build with BatchSampler to evaluate in single device
         # TODO: multi-device evaluate
         if self.mode == 'eval':
-            self._eval_batch_sampler = paddle.io.BatchSampler(
-                self.dataset, batch_size=self.cfg.EvalReader['batch_size'])
-            reader_name = '{}Reader'.format(self.mode.capitalize())
-            # If metric is VOC, need to be set collate_batch=False.
-            if cfg.metric == 'VOC':
-                cfg[reader_name]['collate_batch'] = False
-            self.loader = create(reader_name)(self.dataset, cfg.worker_num,
-                                              self._eval_batch_sampler)
+            if cfg.architecture == 'FairMOT':
+                self.loader = create('EvalMOTReader')(self.dataset, 0)
+            else:
+                self._eval_batch_sampler = paddle.io.BatchSampler(
+                    self.dataset, batch_size=self.cfg.EvalReader['batch_size'])
+                reader_name = '{}Reader'.format(self.mode.capitalize())
+                # If metric is VOC, collate_batch must be set to False.
+                if cfg.metric == 'VOC':
+                    cfg[reader_name]['collate_batch'] = False
+                self.loader = create(reader_name)(self.dataset, cfg.worker_num,
+                                                  self._eval_batch_sampler)
         # TestDataset build after user set images, skip loader creation here
 
         # build optimizer in train mode
@@ -759,3 +766,28 @@ class Trainer(object):
         flops = flops(self.model, input_spec) / (1000**3)
         logger.info(" Model FLOPs : {:.6f}G. (image shape is {})".format(
             flops, input_data['image'][0].unsqueeze(0).shape))
+
+    def parse_mot_images(self, cfg):
+        import glob
+        # for quant: collect every frame of the eval MOT sequences
+        dataset_dir = cfg['EvalMOTDataset'].dataset_dir
+        data_root = cfg['EvalMOTDataset'].data_root
+        data_root = '{}/{}'.format(dataset_dir, data_root)
+        seqs = os.listdir(data_root)
+        seqs.sort()
+        all_images = []
+        for seq in seqs:
+            infer_dir = os.path.join(data_root, seq)
+            assert os.path.isdir(infer_dir), \
+                "{} is not a directory".format(infer_dir)
+            images = set()
+            exts = ['jpg', 'jpeg', 'png', 'bmp']
+            exts += [ext.upper() for ext in exts]
+            for ext in exts:
+                images.update(glob.glob('{}/*.{}'.format(infer_dir, ext)))
+            images = list(images)
+            images.sort()
+            assert len(images) > 0, "no image found in {}".format(infer_dir)
+            all_images.extend(images)
+        logger.info("Found {} inference images in total.".format(len(all_images)))
+        return all_images
\ No newline at end of file
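
For reference, `parse_mot_images` implies an input layout (inferred from the code above, not from separate documentation): `EvalMOTDataset`'s `dataset_dir/data_root` must contain one directory per sequence, each holding the raw frames. The names below are hypothetical illustrations:
```
# Hypothetical layout consumed by parse_mot_images:
#
#   dataset/mot/visdrone_mcmot_vehicle/    <- EvalMOTDataset.dataset_dir
#       images/val/                        <- EvalMOTDataset.data_root
#           uav0000086_00000_v/
#               0000001.jpg
#               0000002.jpg
#               ...
#           uav0000117_02622_v/
#               0000001.jpg
#               ...
#
# parse_mot_images returns the per-sequence sorted frame paths,
# concatenated over all sequences.
```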