diff --git a/configs/mot/README.md b/configs/mot/README.md index 06c077dc1c9b841b5c88ac7d187334a94526244f..f5debaab4e74f64565f141808ca9bcd403d7dfd8 100644 --- a/configs/mot/README.md +++ b/configs/mot/README.md @@ -6,6 +6,7 @@ English | [简体中文](README_cn.md) - [Introduction](#Introduction) - [Installation](#Installation) - [Model Zoo](#Model_Zoo) +- [Feature Tracking Model](#Feature_Tracking_Model) - [Dataset Preparation](#Dataset_Preparation) - [Getting Start](#Getting_Start) - [Citations](#Citations) @@ -131,6 +132,28 @@ If you use a stronger detection model, you can get better results. Each txt is t FairMOT used 8 GPUs for training and mini-batch size as 6 on each GPU, and trained for 30 epoches. +## Feature Tracking Model + +### 【Head Tracking](./headtracking21/README.md) + +### FairMOT Results on HT-21 Training Set +| backbone | input shape | MOTA | IDF1 | IDS | FP | FN | FPS | download | config | +| :--------------| :------- | :----: | :----: | :---: | :----: | :---: | :------: | :----: |:----: | +| DLA-34 | 1088x608 | 67.2 | 70.4 | 9403 | 124840 | 255007 | - | [model](https://paddledet.bj.bcebos.com/models/mot/fairmot_dla34_30e_1088x608_headtracking21.pdparams) | [config](https://github.com/PaddlePaddle/PaddleDetection/tree/release/2.2/configs/mot/headtracking21/fairmot_dla34_30e_1088x608_headtracking21.yml) | + +### FairMOT Results on HT-21 Test Set +| backbone | input shape | MOTA | IDF1 | IDS | FP | FN | FPS | download | config | +| :--------------| :------- | :----: | :----: | :----: | :----: | :----: |:-------: | :----: | :----: | +| DLA-34 | 1088x608 | 58.2 | 61.3 | 13166 | 141872 | 197074 | - | [model](https://paddledet.bj.bcebos.com/models/mot/fairmot_dla34_30e_1088x608_headtracking21.pdparams) | [config](https://github.com/PaddlePaddle/PaddleDetection/tree/release/2.2/configs/mot/headtracking21/fairmot_dla34_30e_1088x608_headtracking21.yml) | + +### [Vehicle Tracking](./kitticars/README.md) +### FairMOT Results on KITTI tracking (2D bounding-boxes) Training Set (Car) + +| backbone | input shape | MOTA | FPS | download | config | +| :--------------| :------- | :-----: | :-----: | :------: | :----: | +| DLA-34 | 1088x608 | 67.9 | - |[model](https://paddledet.bj.bcebos.com/models/mot/fairmot_dla34_30e_1088x608_kitticars.pdparams) | [config](https://github.com/PaddlePaddle/PaddleDetection/tree/release/2.2/configs/mot/kitticars/fairmot_dla34_30e_1088x608_kitticars.yml) | + + ## Dataset Preparation ### MOT Dataset diff --git a/configs/mot/README_cn.md b/configs/mot/README_cn.md index 167f589e89c56532ee941f4383dea5aa9b333a96..ff430d899075be9e398318f821c38e2d09f23bf2 100644 --- a/configs/mot/README_cn.md +++ b/configs/mot/README_cn.md @@ -7,6 +7,7 @@ - [安装依赖](#安装依赖) - [模型库](#模型库) - [数据集准备](#数据集准备) +- [特色垂类跟踪模型](#特色垂类跟踪模型) - [快速开始](#快速开始) - [引用](#引用) @@ -131,6 +132,28 @@ wget https://dataset.bj.bcebos.com/mot/det_results_dir.zip FairMOT使用8个GPU进行训练,每个GPU上batch size为6,训练30个epoch。 +## 特色垂类跟踪模型 + +### 【人头跟踪(Head Tracking)](./headtracking21/README.md) + +### FairMOT在HT-21 Training Set上结果 +| 骨干网络 | 输入尺寸 | MOTA | IDF1 | IDS | FP | FN | FPS | 下载链接 | 配置文件 | +| :--------------| :------- | :----: | :----: | :---: | :----: | :---: | :------: | :----: |:----: | +| DLA-34 | 1088x608 | 67.2 | 70.4 | 9403 | 124840 | 255007 | - | [下载链接](https://paddledet.bj.bcebos.com/models/mot/fairmot_dla34_30e_1088x608_headtracking21.pdparams) | [配置文件](https://github.com/PaddlePaddle/PaddleDetection/tree/release/2.2/configs/mot/headtracking21/fairmot_dla34_30e_1088x608_headtracking21.yml) | + +### FairMOT在HT-21 
Test Set上结果 +| 骨干网络 | 输入尺寸 | MOTA | IDF1 | IDS | FP | FN | FPS | 下载链接 | 配置文件 | +| :--------------| :------- | :----: | :----: | :----: | :----: | :----: |:-------: | :----: | :----: | +| DLA-34 | 1088x608 | 58.2 | 61.3 | 13166 | 141872 | 197074 | - | [下载链接](https://paddledet.bj.bcebos.com/models/mot/fairmot_dla34_30e_1088x608_headtracking21.pdparams) | [配置文件](https://github.com/PaddlePaddle/PaddleDetection/tree/release/2.2/configs/mot/headtracking21/fairmot_dla34_30e_1088x608_headtracking21.yml) | + +### [车辆跟踪 (Vehicle Tracking)](./kitticars/README.md) +### FairMOT在KITTI tracking (2D bounding-boxes) Training Set上Car类别的结果 + +| 骨干网络 | 输入尺寸 | MOTA | FPS | 下载链接 | 配置文件 | +| :--------------| :------- | :-----: | :-----: | :------: | :----: | +| DLA-34 | 1088x608 | 67.9 | - |[下载链接](https://paddledet.bj.bcebos.com/models/mot/fairmot_dla34_30e_1088x608_kitticars.pdparams) | [配置文件](https://github.com/PaddlePaddle/PaddleDetection/tree/release/2.2/configs/mot/kitticars/fairmot_dla34_30e_1088x608_kitticars.yml) | + + ## 数据集准备 ### MOT数据集 diff --git a/configs/mot/headtracking21/README.md b/configs/mot/headtracking21/README.md new file mode 120000 index 0000000000000000000000000000000000000000..4015683cfa5969297febc12e7ca1264afabbc0b5 --- /dev/null +++ b/configs/mot/headtracking21/README.md @@ -0,0 +1 @@ +README_cn.md \ No newline at end of file diff --git a/configs/mot/fairmot/headtracking21/README_cn.md b/configs/mot/headtracking21/README_cn.md similarity index 62% rename from configs/mot/fairmot/headtracking21/README_cn.md rename to configs/mot/headtracking21/README_cn.md index 4a82832e1ad0a846b58ccd2c7dadf7e34aa8c2ec..281c5ab57143b4b4019985d6dd69a8d0238ac29b 100644 --- a/configs/mot/fairmot/headtracking21/README_cn.md +++ b/configs/mot/headtracking21/README_cn.md @@ -1,35 +1,24 @@ -简体中文 | [English](README.md) +[English](README.md) | 简体中文 +# 特色垂类跟踪模型 -# FairMOT (FairMOT: On the Fairness of Detection and Re-Identification in Multiple Object Tracking) - -## 内容 -- [简介](#简介) -- [模型库](#模型库) -- [快速开始](#快速开始) -- [引用](#引用) - -## 简介 - -[FairMOT](https://arxiv.org/abs/2004.01888)以Anchor Free的CenterNet检测器为基础,克服了Anchor-Based的检测框架中anchor和特征不对齐问题,深浅层特征融合使得检测和ReID任务各自获得所需要的特征,并且使用低维度ReID特征,提出了一种由两个同质分支组成的简单baseline来预测像素级目标得分和ReID特征,实现了两个任务之间的公平性,并获得了更高水平的实时多目标跟踪精度。 +## 人头跟踪(Head Tracking) +现有行人跟踪器对高人群密度场景表现不佳,人头跟踪更适用于密集场景的跟踪。 +[HT-21](https://motchallenge.net/data/Head_Tracking_21)是一个高人群密度拥挤场景的人头跟踪数据集,场景包括不同的光线和环境条件下的拥挤的室内和室外场景,所有序列的帧速率都是25fps。
## 模型库 - ### FairMOT在HT-21 Training Set上结果 - | 骨干网络 | 输入尺寸 | MOTA | IDF1 | IDS | FP | FN | FPS | 下载链接 | 配置文件 | | :--------------| :------- | :----: | :----: | :---: | :----: | :---: | :------: | :----: |:----: | -| DLA-34 | 1088x608 | 67.2 | 70.4 | 9403 | 124840 | 255007 | - | [下载链接](https://paddledet.bj.bcebos.com/models/mot/fairmot_dla34_30e_1088x608_headtracking21.pdparams) | [配置文件](https://github.com/PaddlePaddle/PaddleDetection/tree/release/2.2/configs/mot/fairmot/headtracking21/fairmot_dla34_30e_1088x608_headtracking21.yml) | - +| DLA-34 | 1088x608 | 67.2 | 70.4 | 9403 | 124840 | 255007 | - | [下载链接](https://paddledet.bj.bcebos.com/models/mot/fairmot_dla34_30e_1088x608_headtracking21.pdparams) | [配置文件](https://github.com/PaddlePaddle/PaddleDetection/tree/release/2.2/configs/mot/headtracking21/fairmot_dla34_30e_1088x608_headtracking21.yml) | ### FairMOT在HT-21 Test Set上结果 - | 骨干网络 | 输入尺寸 | MOTA | IDF1 | IDS | FP | FN | FPS | 下载链接 | 配置文件 | | :--------------| :------- | :----: | :----: | :----: | :----: | :----: |:-------: | :----: | :----: | -| DLA-34 | 1088x608 | 58.2 | 61.3 | 13166 | 141872 | 197074 | - | [下载链接](https://paddledet.bj.bcebos.com/models/mot/fairmot_dla34_30e_1088x608_headtracking21.pdparams) | [配置文件](https://github.com/PaddlePaddle/PaddleDetection/tree/release/2.2/configs/mot/fairmot/headtracking21/fairmot_dla34_30e_1088x608_headtracking21.yml) | +| DLA-34 | 1088x608 | 58.2 | 61.3 | 13166 | 141872 | 197074 | - | [下载链接](https://paddledet.bj.bcebos.com/models/mot/fairmot_dla34_30e_1088x608_headtracking21.pdparams) | [配置文件](https://github.com/PaddlePaddle/PaddleDetection/tree/release/2.2/configs/mot/headtracking21/fairmot_dla34_30e_1088x608_headtracking21.yml) | **注意:** FairMOT使用8个GPU进行训练,每个GPU上batch size为6,训练30个epoch。 @@ -37,53 +26,42 @@ ## 快速开始 ### 1. 训练 - 使用8GPU通过如下命令一键式启动训练 - ```bash -python -m paddle.distributed.launch --log_dir=./fairmot_dla34_30e_1088x608_headtracking21/ --gpus 0,1,2,3,4,5,6,7 tools/train.py -c configs/mot/fairmot/headtracking21/fairmot_dla34_30e_1088x608_headtracking21.yml +python -m paddle.distributed.launch --log_dir=./fairmot_dla34_30e_1088x608_headtracking21/ --gpus 0,1,2,3,4,5,6,7 tools/train.py -c configs/mot/headtracking21/fairmot_dla34_30e_1088x608_headtracking21.yml ``` - ### 2. 评估 - 使用单张GPU通过如下命令一键式启动评估 - ```bash # 使用PaddleDetection发布的权重 -CUDA_VISIBLE_DEVICES=0 python tools/eval_mot.py -c configs/mot/fairmot/headtracking21/fairmot_dla34_30e_1088x608_headtracking21.yml -o weights=https://paddledet.bj.bcebos.com/models/mot/fairmot_dla34_30e_1088x608_headtracking21.pdparams +CUDA_VISIBLE_DEVICES=0 python tools/eval_mot.py -c configs/mot/headtracking21/fairmot_dla34_30e_1088x608_headtracking21.yml -o weights=https://paddledet.bj.bcebos.com/models/mot/fairmot_dla34_30e_1088x608_headtracking21.pdparams # 使用训练保存的checkpoint -CUDA_VISIBLE_DEVICES=0 python tools/eval_mot.py -c configs/mot/fairmot/headtracking21/fairmot_dla34_30e_1088x608_headtracking21.yml -o weights=output/fairmot_dla34_30e_1088x608_headtracking21/model_final.pdparams +CUDA_VISIBLE_DEVICES=0 python tools/eval_mot.py -c configs/mot/headtracking21/fairmot_dla34_30e_1088x608_headtracking21.yml -o weights=output/fairmot_dla34_30e_1088x608_headtracking21/model_final.pdparams ``` - ### 3. 
预测 - 使用单个GPU通过如下命令预测一个视频,并保存为视频 - ```bash # 预测一个视频 -CUDA_VISIBLE_DEVICES=0 python tools/infer_mot.py -c configs/mot/fairmot/headtracking21/fairmot_dla34_30e_1088x608_headtracking21.yml -o weights=https://paddledet.bj.bcebos.com/models/mot/fairmot_dla34_30e_1088x608_headtracking21.pdparams --video_file={your video name}.mp4 --save_videos +CUDA_VISIBLE_DEVICES=0 python tools/infer_mot.py -c configs/mot/headtracking21/fairmot_dla34_30e_1088x608_headtracking21.yml -o weights=https://paddledet.bj.bcebos.com/models/mot/fairmot_dla34_30e_1088x608_headtracking21.pdparams --video_file={your video name}.mp4 --save_videos ``` **注意:** 请先确保已经安装了[ffmpeg](https://ffmpeg.org/ffmpeg.html), Linux(Ubuntu)平台可以直接用以下命令安装:`apt-get update && apt-get install -y ffmpeg`。 ### 4. 导出预测模型 - ```bash -CUDA_VISIBLE_DEVICES=0 python tools/export_model.py -c configs/mot/fairmot/headtracking21/fairmot_dla34_30e_1088x608_headtracking21.yml -o weights=https://paddledet.bj.bcebos.com/models/mot/fairmot_dla34_30e_1088x608_headtracking21.pdparams +CUDA_VISIBLE_DEVICES=0 python tools/export_model.py -c configs/mot/headtracking21/fairmot_dla34_30e_1088x608_headtracking21.yml -o weights=https://paddledet.bj.bcebos.com/models/mot/fairmot_dla34_30e_1088x608_headtracking21.pdparams ``` ### 5. 用导出的模型基于Python去预测 - ```bash python deploy/python/mot_jde_infer.py --model_dir=output_inference/fairmot_dla34_30e_1088x608_headtracking21 --video_file={your video name}.mp4 --device=GPU --save_mot_txts ``` **注意:** 跟踪模型是对视频进行预测,不支持单张图的预测,默认保存跟踪结果可视化后的视频,可添加`--save_mot_txts`表示保存跟踪结果的txt文件,或`--save_images`表示保存跟踪结果可视化图片。 - ## 引用 ``` @article{zhang2020fair, diff --git a/configs/mot/fairmot/headtracking21/fairmot_dla34_30e_1088x608_headtracking21.yml b/configs/mot/headtracking21/fairmot_dla34_30e_1088x608_headtracking21.yml similarity index 80% rename from configs/mot/fairmot/headtracking21/fairmot_dla34_30e_1088x608_headtracking21.yml rename to configs/mot/headtracking21/fairmot_dla34_30e_1088x608_headtracking21.yml index ac3aa657d565d284f123153a3e6f4db032098b6e..8bfbc7ca8b7b76f4b0dbab42999cd6e15f392aaa 100755 --- a/configs/mot/fairmot/headtracking21/fairmot_dla34_30e_1088x608_headtracking21.yml +++ b/configs/mot/headtracking21/fairmot_dla34_30e_1088x608_headtracking21.yml @@ -1,7 +1,8 @@ _BASE_: [ - '../fairmot_dla34_30e_1088x608.yml' + '../fairmot/fairmot_dla34_30e_1088x608.yml' ] +weights: output/fairmot_dla34_30e_1088x608_headtracking21/model_final # for MOT training TrainDataset: @@ -11,7 +12,7 @@ TrainDataset: data_fields: ['image', 'gt_bbox', 'gt_class', 'gt_ide'] # for MOT evaluation -# If you want to change the MOT evaluation dataset, please modify 'task' and 'data_root' +# If you want to change the MOT evaluation dataset, please modify 'data_root' EvalMOTDataset: !MOTImageFolder dataset_dir: dataset/mot @@ -20,6 +21,6 @@ EvalMOTDataset: # for MOT video inference TestMOTDataset: - !MOTVideoDataset + !MOTImageFolder dataset_dir: dataset/mot keep_ori_im: True # set True if save visualization images or video diff --git a/configs/mot/kitticars/README.md b/configs/mot/kitticars/README.md new file mode 120000 index 0000000000000000000000000000000000000000..4015683cfa5969297febc12e7ca1264afabbc0b5 --- /dev/null +++ b/configs/mot/kitticars/README.md @@ -0,0 +1 @@ +README_cn.md \ No newline at end of file diff --git a/configs/mot/kitticars/README_cn.md b/configs/mot/kitticars/README_cn.md new file mode 100644 index 0000000000000000000000000000000000000000..bd9f0bade6ba83f5efe788238fc6b961c12e2704 --- /dev/null +++ b/configs/mot/kitticars/README_cn.md @@ 
-0,0 +1,76 @@ +[English](README.md) | 简体中文 +# 特色垂类跟踪模型 + +## 车辆跟踪 (Vehicle Tracking) + +车辆跟踪的主要应用之一是交通监控。 +[KITTI-Tracking](http://www.cvlibs.net/datasets/kitti/eval_tracking.php)是一个包含市区、乡村和高速公路等场景采集的数据集,每张图像中最多达15辆车和30个行人,还有各种程度的遮挡与截断。其中用于目标跟踪的数据集一共有50个视频序列,21个为训练集,29个为测试集,目标是估计类别“Car”和”Pedestrian“的目标轨迹,此处只使用类别“Car”。 +
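For reference, the sketch below (added for illustration, not part of the PR itself) shows the KITTI-format result line this PR's tracker writes for the Car class. The field layout follows the comment in `KITTIEvaluation._loadData` and the `save_format` string updated in `ppdet/engine/tracker.py`; the helper name `format_kitti_car_line` is hypothetical.

```python
# Hypothetical helper mirroring the save_format string in ppdet/engine/tracker.py.
# KITTI tracking fields: frame, track_id, type, truncation, occlusion, alpha,
# x1, y1, x2, y2, h, w, l, X, Y, Z, rotation_y; a 2D tracker fills the unknown
# 3D fields with the placeholder values -10 / -1000.
def format_kitti_car_line(frame, track_id, x1, y1, x2, y2):
    return ('{frame} {id} car 0 0 -10 {x1} {y1} {x2} {y2} '
            '-10 -10 -10 -1000 -1000 -1000 -10\n').format(
                frame=frame, id=track_id, x1=x1, y1=y1, x2=x2, y2=y2)

print(format_kitti_car_line(0, 1, 100.0, 120.0, 260.0, 200.0), end='')
# 0 1 car 0 0 -10 100.0 120.0 260.0 200.0 -10 -10 -10 -1000 -1000 -1000 -10
```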
+ +## 模型库 + +### FairMOT在KITTI tracking (2D bounding-boxes) Training Set上Car类别的结果 + +| 骨干网络 | 输入尺寸 | MOTA | FPS | 下载链接 | 配置文件 | +| :--------------| :------- | :-----: | :-----: | :------: | :----: | +| DLA-34 | 1088x608 | 67.9 | - |[下载链接](https://paddledet.bj.bcebos.com/models/mot/fairmot_dla34_30e_1088x608_kitticars.pdparams) | [配置文件](https://github.com/PaddlePaddle/PaddleDetection/tree/release/2.2/configs/mot/kitticars/fairmot_dla34_30e_1088x608_kitticars.yml) | + +**注意:** + FairMOT使用8个GPU进行训练,每个GPU上batch size为6,训练30个epoch。 + +## 快速开始 + +### 1. 训练 +使用8GPU通过如下命令一键式启动训练 +```bash +python -m paddle.distributed.launch --log_dir=./fairmot_dla34_30e_1088x608_kitticars/ --gpus 0,1,2,3,4,5,6,7 tools/train.py -c configs/mot/kitticars/fairmot_dla34_30e_1088x608_kitticars.yml +``` + +### 2. 评估 +使用单张GPU通过如下命令一键式启动评估 +```bash +# 使用PaddleDetection发布的权重 +CUDA_VISIBLE_DEVICES=0 python tools/eval_mot.py -c configs/mot/kitticars/fairmot_dla34_30e_1088x608_kitticars.yml -o weights=https://paddledet.bj.bcebos.com/models/mot/fairmot_dla34_30e_1088x608_kitticars.pdparams + +# 使用训练保存的checkpoint +CUDA_VISIBLE_DEVICES=0 python tools/eval_mot.py -c configs/mot/kitticars/fairmot_dla34_30e_1088x608_kitticars.yml -o weights=output/fairmot_dla34_30e_1088x608_kitticars/model_final.pdparams +``` + +### 3. 预测 +使用单个GPU通过如下命令预测一个视频,并保存为视频 +```bash +# 预测一个视频 +CUDA_VISIBLE_DEVICES=0 python tools/infer_mot.py -c configs/mot/kitticars/fairmot_dla34_30e_1088x608_kitticars.yml -o weights=https://paddledet.bj.bcebos.com/models/mot/fairmot_dla34_30e_1088x608_kitticars.pdparams --video_file={your video name}.mp4 --save_videos +``` +**注意:** + 请先确保已经安装了[ffmpeg](https://ffmpeg.org/ffmpeg.html), Linux(Ubuntu)平台可以直接用以下命令安装:`apt-get update && apt-get install -y ffmpeg`。 + +### 4. 导出预测模型 +```bash +CUDA_VISIBLE_DEVICES=0 python tools/export_model.py -c configs/mot/kitticars/fairmot_dla34_30e_1088x608_kitticars.yml -o weights=https://paddledet.bj.bcebos.com/models/mot/fairmot_dla34_30e_1088x608_kitticars.pdparams +``` + +### 5. 用导出的模型基于Python去预测 +```bash +python deploy/python/mot_jde_infer.py --model_dir=output_inference/fairmot_dla34_30e_1088x608_kitticars --video_file={your video name}.mp4 --device=GPU --save_mot_txts +``` +**注意:** + 跟踪模型是对视频进行预测,不支持单张图的预测,默认保存跟踪结果可视化后的视频,可添加`--save_mot_txts`表示保存跟踪结果的txt文件,或`--save_images`表示保存跟踪结果可视化图片。 + +## 引用 +``` +@article{zhang2020fair, + title={FairMOT: On the Fairness of Detection and Re-Identification in Multiple Object Tracking}, + author={Zhang, Yifu and Wang, Chunyu and Wang, Xinggang and Zeng, Wenjun and Liu, Wenyu}, + journal={arXiv preprint arXiv:2004.01888}, + year={2020} +} +@INPROCEEDINGS{Geiger2012CVPR, + author = {Andreas Geiger and Philip Lenz and Raquel Urtasun}, + title = {Are we ready for Autonomous Driving? 
The KITTI Vision Benchmark Suite}, + booktitle = {Conference on Computer Vision and Pattern Recognition (CVPR)}, + year = {2012} +} +``` diff --git a/configs/mot/kitticars/fairmot_dla34_30e_1088x608_kitticars.yml b/configs/mot/kitticars/fairmot_dla34_30e_1088x608_kitticars.yml new file mode 100755 index 0000000000000000000000000000000000000000..a103cba6c3c17242e939d72cdd92319486577d9e --- /dev/null +++ b/configs/mot/kitticars/fairmot_dla34_30e_1088x608_kitticars.yml @@ -0,0 +1,27 @@ +_BASE_: [ + '../fairmot/fairmot_dla34_30e_1088x608.yml' +] + +metric: KITTI +weights: output/fairmot_dla34_30e_1088x608_kitticars/model_final + +# for MOT training +TrainDataset: + !MOTDataSet + dataset_dir: dataset/mot + image_lists: ['kitticars.train'] + data_fields: ['image', 'gt_bbox', 'gt_class', 'gt_ide'] + +# for MOT evaluation +# If you want to change the MOT evaluation dataset, please modify 'data_root' +EvalMOTDataset: + !MOTImageFolder + dataset_dir: dataset/mot + data_root: kitticars/images/test + keep_ori_im: False # set True if save visualization images or video, or used in DeepSORT + +# for MOT video inference +TestMOTDataset: + !MOTImageFolder + dataset_dir: dataset/mot + keep_ori_im: True # set True if save visualization images or video diff --git a/deploy/python/mot_jde_infer.py b/deploy/python/mot_jde_infer.py index 021e5aea7070e44d8c42eff6f3741caa702c7797..103cf45855ccdf8bebe55f6230c04b9b3b132b5e 100644 --- a/deploy/python/mot_jde_infer.py +++ b/deploy/python/mot_jde_infer.py @@ -179,6 +179,7 @@ def write_mot_results(filename, results, data_type='mot'): def predict_image(detector, image_list): results = [] + image_list.sort() for i, img_file in enumerate(image_list): frame = cv2.imread(img_file) if FLAGS.run_benchmark: diff --git a/deploy/python/mot_keypoint_unite_infer.py b/deploy/python/mot_keypoint_unite_infer.py index 46ecefe78cbbdb3a6502466e0e62b118f7bb1595..2db100c2d0ac6b31c5e0c29f58933eb94e485153 100644 --- a/deploy/python/mot_keypoint_unite_infer.py +++ b/deploy/python/mot_keypoint_unite_infer.py @@ -56,6 +56,7 @@ def mot_keypoint_unite_predict_image(mot_model, keypoint_model, image_list, keypoint_batch_size=1): + image_list.sort() for i, img_file in enumerate(image_list): frame = cv2.imread(img_file) diff --git a/deploy/python/mot_sde_infer.py b/deploy/python/mot_sde_infer.py index 55dbef8c9ebc23ac94d7d440cada852e6361ee7e..5dfb944c67259f9826cc7a4d971c4dc450c0fb14 100644 --- a/deploy/python/mot_sde_infer.py +++ b/deploy/python/mot_sde_infer.py @@ -297,6 +297,7 @@ class SDE_ReID(object): def predict_image(detector, reid_model, image_list): results = [] + image_list.sort() for i, img_file in enumerate(image_list): frame = cv2.imread(img_file) if FLAGS.run_benchmark: diff --git a/docs/images/kitticars_fairmot.gif b/docs/images/kitticars_fairmot.gif new file mode 100644 index 0000000000000000000000000000000000000000..84dd36964aba7068cd756b17c1f191cc14b18537 Binary files /dev/null and b/docs/images/kitticars_fairmot.gif differ diff --git a/ppdet/engine/tracker.py b/ppdet/engine/tracker.py index 5e738370992f84239a20c88801fa2adc14f1e7a5..07fe38800084a0d6e483000c96d7edde60d659f7 100644 --- a/ppdet/engine/tracker.py +++ b/ppdet/engine/tracker.py @@ -28,7 +28,7 @@ from ppdet.modeling.mot.utils import Detection, get_crops, scale_coords, clip_bo from ppdet.modeling.mot.utils import Timer, load_det_results from ppdet.modeling.mot import visualization as mot_vis -from ppdet.metrics import Metric, MOTMetric +from ppdet.metrics import Metric, MOTMetric, KITTIMOTMetric import ppdet.utils.stats as 
stats from .callbacks import Callback, ComposeCallback @@ -74,6 +74,8 @@ class Tracker(object): if self.cfg.metric == 'MOT': self._metrics = [MOTMetric(), ] + elif self.cfg.metric == 'KITTI': + self._metrics = [KITTIMOTMetric(), ] else: logger.warning("Metric not support for metric type {}".format( self.cfg.metric)) @@ -329,7 +331,7 @@ class Tracker(object): if save_videos: output_video_path = os.path.join(save_dir, '..', '{}_vis.mp4'.format(seq)) - cmd_str = 'ffmpeg -f image2 -i {}/%05d.jpg {}'.format( + cmd_str = 'ffmpeg -f image2 -i {}/%05d.jpg -vf "scale=trunc(iw/2)*2:trunc(ih/2)*2" {}'.format( save_dir, output_video_path) os.system(cmd_str) logger.info('Save video in {}.'.format(output_video_path)) @@ -445,7 +447,7 @@ class Tracker(object): if save_videos: output_video_path = os.path.join(save_dir, '..', '{}_vis.mp4'.format(seq)) - cmd_str = 'ffmpeg -f image2 -i {}/%05d.jpg {}'.format( + cmd_str = 'ffmpeg -f image2 -i {}/%05d.jpg "scale=trunc(iw/2)*2:trunc(ih/2)*2" {}'.format( save_dir, output_video_path) os.system(cmd_str) logger.info('Save video in {}'.format(output_video_path)) @@ -454,7 +456,7 @@ class Tracker(object): if data_type in ['mot', 'mcmot', 'lab']: save_format = '{frame},{id},{x1},{y1},{w},{h},{score},-1,-1,-1\n' elif data_type == 'kitti': - save_format = '{frame} {id} pedestrian 0 0 -10 {x1} {y1} {x2} {y2} -10 -10 -10 -1000 -1000 -1000 -10\n' + save_format = '{frame} {id} car 0 0 -10 {x1} {y1} {x2} {y2} -10 -10 -10 -1000 -1000 -1000 -10\n' else: raise ValueError(data_type) diff --git a/ppdet/metrics/mot_metrics.py b/ppdet/metrics/mot_metrics.py index 2b918c5426e6f8898726ac2c4c0e9645cb10f59d..e70c0bd31f03a36576230493747d1d32c54feb10 100644 --- a/ppdet/metrics/mot_metrics.py +++ b/ppdet/metrics/mot_metrics.py @@ -11,23 +11,28 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + from __future__ import absolute_import from __future__ import division from __future__ import print_function import os import copy +import sys +import math +from collections import defaultdict import numpy as np import paddle import paddle.nn.functional as F from ppdet.modeling.bbox_utils import bbox_iou_np_expand from .map_utils import ap_per_class from .metrics import Metric +from .munkres import Munkres from ppdet.utils.logger import setup_logger logger = setup_logger(__name__) -__all__ = ['MOTEvaluator', 'MOTMetric', 'JDEDetMetric'] +__all__ = ['MOTEvaluator', 'MOTMetric', 'JDEDetMetric', 'KITTIMOTMetric'] def read_mot_results(filename, is_gt=False, is_ignore=False): @@ -74,6 +79,7 @@ def read_mot_results(filename, is_gt=False, is_ignore=False): """ +MOT dataset label list, see in https://motchallenge.net labels={'ped', ... % 1 'person_on_vhcl', ... % 2 'car', ... % 3 @@ -302,3 +308,926 @@ class JDEDetMetric(Metric): def get_results(self): return self.map_stat + + +""" +Following code is borrow from https://github.com/xingyizhou/CenterTrack/blob/master/src/tools/eval_kitti_track/evaluate_tracking.py +""" + + +class tData: + """ + Utility class to load data. + """ + def __init__(self,frame=-1,obj_type="unset",truncation=-1,occlusion=-1,\ + obs_angle=-10,x1=-1,y1=-1,x2=-1,y2=-1,w=-1,h=-1,l=-1,\ + X=-1000,Y=-1000,Z=-1000,yaw=-10,score=-1000,track_id=-1): + """ + Constructor, initializes the object given the parameters. 
+ """ + self.frame = frame + self.track_id = track_id + self.obj_type = obj_type + self.truncation = truncation + self.occlusion = occlusion + self.obs_angle = obs_angle + self.x1 = x1 + self.y1 = y1 + self.x2 = x2 + self.y2 = y2 + self.w = w + self.h = h + self.l = l + self.X = X + self.Y = Y + self.Z = Z + self.yaw = yaw + self.score = score + self.ignored = False + self.valid = False + self.tracker = -1 + + def __str__(self): + attrs = vars(self) + return '\n'.join("%s: %s" % item for item in attrs.items()) + + +class KITTIEvaluation(object): + """ KITTI tracking statistics (CLEAR MOT, id-switches, fragments, ML/PT/MT, precision/recall) + MOTA - Multi-object tracking accuracy in [0,100] + MOTP - Multi-object tracking precision in [0,100] (3D) / [td,100] (2D) + MOTAL - Multi-object tracking accuracy in [0,100] with log10(id-switches) + + id-switches - number of id switches + fragments - number of fragmentations + + MT, PT, ML - number of mostly tracked, partially tracked and mostly lost trajectories + + recall - recall = percentage of detected targets + precision - precision = percentage of correctly detected targets + FAR - number of false alarms per frame + falsepositives - number of false positives (FP) + missed - number of missed targets (FN) + """ + def __init__(self, result_path, gt_path, min_overlap=0.5, max_truncation = 0,\ + min_height = 25, max_occlusion = 2, cls="car",\ + n_frames=[], seqs=[], n_sequences=0): + # get number of sequences and + # get number of frames per sequence from test mapping + # (created while extracting the benchmark) + self.gt_path = os.path.join(gt_path, "label_02") + self.n_frames = n_frames + self.sequence_name = seqs + self.n_sequences = n_sequences + + self.cls = cls # class to evaluate, i.e. pedestrian or car + + self.result_path = result_path + + # statistics and numbers for evaluation + self.n_gt = 0 # number of ground truth detections minus ignored false negatives and true positives + self.n_igt = 0 # number of ignored ground truth detections + self.n_gts = [ + ] # number of ground truth detections minus ignored false negatives and true positives PER SEQUENCE + self.n_igts = [ + ] # number of ground ignored truth detections PER SEQUENCE + self.n_gt_trajectories = 0 + self.n_gt_seq = [] + self.n_tr = 0 # number of tracker detections minus ignored tracker detections + self.n_trs = [ + ] # number of tracker detections minus ignored tracker detections PER SEQUENCE + self.n_itr = 0 # number of ignored tracker detections + self.n_itrs = [] # number of ignored tracker detections PER SEQUENCE + self.n_igttr = 0 # number of ignored ground truth detections where the corresponding associated tracker detection is also ignored + self.n_tr_trajectories = 0 + self.n_tr_seq = [] + self.MOTA = 0 + self.MOTP = 0 + self.MOTAL = 0 + self.MODA = 0 + self.MODP = 0 + self.MODP_t = [] + self.recall = 0 + self.precision = 0 + self.F1 = 0 + self.FAR = 0 + self.total_cost = 0 + self.itp = 0 # number of ignored true positives + self.itps = [] # number of ignored true positives PER SEQUENCE + self.tp = 0 # number of true positives including ignored true positives! 
+ self.tps = [ + ] # number of true positives including ignored true positives PER SEQUENCE + self.fn = 0 # number of false negatives WITHOUT ignored false negatives + self.fns = [ + ] # number of false negatives WITHOUT ignored false negatives PER SEQUENCE + self.ifn = 0 # number of ignored false negatives + self.ifns = [] # number of ignored false negatives PER SEQUENCE + self.fp = 0 # number of false positives + # a bit tricky, the number of ignored false negatives and ignored true positives + # is subtracted, but if both tracker detection and ground truth detection + # are ignored this number is added again to avoid double counting + self.fps = [] # above PER SEQUENCE + self.mme = 0 + self.fragments = 0 + self.id_switches = 0 + self.MT = 0 + self.PT = 0 + self.ML = 0 + + self.min_overlap = min_overlap # minimum bounding box overlap for 3rd party metrics + self.max_truncation = max_truncation # maximum truncation of an object for evaluation + self.max_occlusion = max_occlusion # maximum occlusion of an object for evaluation + self.min_height = min_height # minimum height of an object for evaluation + self.n_sample_points = 500 + + # this should be enough to hold all groundtruth trajectories + # is expanded if necessary and reduced in any case + self.gt_trajectories = [[] for x in range(self.n_sequences)] + self.ign_trajectories = [[] for x in range(self.n_sequences)] + + def loadGroundtruth(self): + try: + self._loadData(self.gt_path, cls=self.cls, loading_groundtruth=True) + except IOError: + return False + return True + + def loadTracker(self): + try: + if not self._loadData( + self.result_path, cls=self.cls, loading_groundtruth=False): + return False + except IOError: + return False + return True + + def _loadData(self, + root_dir, + cls, + min_score=-1000, + loading_groundtruth=False): + """ + Generic loader for ground truth and tracking data. + Use loadGroundtruth() or loadTracker() to load this data. + Loads detections in KITTI format from textfiles. + """ + # construct objectDetections object to hold detection data + t_data = tData() + data = [] + eval_2d = True + eval_3d = True + + seq_data = [] + n_trajectories = 0 + n_trajectories_seq = [] + for seq, s_name in enumerate(self.sequence_name): + i = 0 + filename = os.path.join(root_dir, "%s.txt" % s_name) + f = open(filename, "r") + + f_data = [ + [] for x in range(self.n_frames[seq]) + ] # current set has only 1059 entries, sufficient length is checked anyway + ids = [] + n_in_seq = 0 + id_frame_cache = [] + for line in f: + # KITTI tracking benchmark data format: + # (frame,tracklet_id,objectType,truncation,occlusion,alpha,x1,y1,x2,y2,h,w,l,X,Y,Z,ry) + line = line.strip() + fields = line.split(" ") + # classes that should be loaded (ignored neighboring classes) + if "car" in cls.lower(): + classes = ["car", "van"] + elif "pedestrian" in cls.lower(): + classes = ["pedestrian", "person_sitting"] + else: + classes = [cls.lower()] + classes += ["dontcare"] + if not any([s for s in classes if s in fields[2].lower()]): + continue + # get fields from table + t_data.frame = int(float(fields[0])) # frame + t_data.track_id = int(float(fields[1])) # id + t_data.obj_type = fields[ + 2].lower() # object type [car, pedestrian, cyclist, ...] 
+ t_data.truncation = int( + float(fields[3])) # truncation [-1,0,1,2] + t_data.occlusion = int( + float(fields[4])) # occlusion [-1,0,1,2] + t_data.obs_angle = float(fields[5]) # observation angle [rad] + t_data.x1 = float(fields[6]) # left [px] + t_data.y1 = float(fields[7]) # top [px] + t_data.x2 = float(fields[8]) # right [px] + t_data.y2 = float(fields[9]) # bottom [px] + t_data.h = float(fields[10]) # height [m] + t_data.w = float(fields[11]) # width [m] + t_data.l = float(fields[12]) # length [m] + t_data.X = float(fields[13]) # X [m] + t_data.Y = float(fields[14]) # Y [m] + t_data.Z = float(fields[15]) # Z [m] + t_data.yaw = float(fields[16]) # yaw angle [rad] + if not loading_groundtruth: + if len(fields) == 17: + t_data.score = -1 + elif len(fields) == 18: + t_data.score = float(fields[17]) # detection score + else: + logger.info("file is not in KITTI format") + return + + # do not consider objects marked as invalid + if t_data.track_id is -1 and t_data.obj_type != "dontcare": + continue + + idx = t_data.frame + # check if length for frame data is sufficient + if idx >= len(f_data): + print("extend f_data", idx, len(f_data)) + f_data += [[] for x in range(max(500, idx - len(f_data)))] + try: + id_frame = (t_data.frame, t_data.track_id) + if id_frame in id_frame_cache and not loading_groundtruth: + logger.info( + "track ids are not unique for sequence %d: frame %d" + % (seq, t_data.frame)) + logger.info( + "track id %d occured at least twice for this frame" + % t_data.track_id) + logger.info("Exiting...") + #continue # this allows to evaluate non-unique result files + return False + id_frame_cache.append(id_frame) + f_data[t_data.frame].append(copy.copy(t_data)) + except: + print(len(f_data), idx) + raise + + if t_data.track_id not in ids and t_data.obj_type != "dontcare": + ids.append(t_data.track_id) + n_trajectories += 1 + n_in_seq += 1 + + # check if uploaded data provides information for 2D and 3D evaluation + if not loading_groundtruth and eval_2d is True and ( + t_data.x1 == -1 or t_data.x2 == -1 or t_data.y1 == -1 or + t_data.y2 == -1): + eval_2d = False + if not loading_groundtruth and eval_3d is True and ( + t_data.X == -1000 or t_data.Y == -1000 or + t_data.Z == -1000): + eval_3d = False + + # only add existing frames + n_trajectories_seq.append(n_in_seq) + seq_data.append(f_data) + f.close() + + if not loading_groundtruth: + self.tracker = seq_data + self.n_tr_trajectories = n_trajectories + self.eval_2d = eval_2d + self.eval_3d = eval_3d + self.n_tr_seq = n_trajectories_seq + if self.n_tr_trajectories == 0: + return False + else: + # split ground truth and DontCare areas + self.dcareas = [] + self.groundtruth = [] + for seq_idx in range(len(seq_data)): + seq_gt = seq_data[seq_idx] + s_g, s_dc = [], [] + for f in range(len(seq_gt)): + all_gt = seq_gt[f] + g, dc = [], [] + for gg in all_gt: + if gg.obj_type == "dontcare": + dc.append(gg) + else: + g.append(gg) + s_g.append(g) + s_dc.append(dc) + self.dcareas.append(s_dc) + self.groundtruth.append(s_g) + self.n_gt_seq = n_trajectories_seq + self.n_gt_trajectories = n_trajectories + return True + + def boxoverlap(self, a, b, criterion="union"): + """ + boxoverlap computes intersection over union for bbox a and b in KITTI format. + If the criterion is 'union', overlap = (a inter b) / a union b). + If the criterion is 'a', overlap = (a inter b) / a, where b should be a dontcare area. + """ + x1 = max(a.x1, b.x1) + y1 = max(a.y1, b.y1) + x2 = min(a.x2, b.x2) + y2 = min(a.y2, b.y2) + + w = x2 - x1 + h = y2 - y1 + + if w <= 0. 
or h <= 0.: + return 0. + inter = w * h + aarea = (a.x2 - a.x1) * (a.y2 - a.y1) + barea = (b.x2 - b.x1) * (b.y2 - b.y1) + # intersection over union overlap + if criterion.lower() == "union": + o = inter / float(aarea + barea - inter) + elif criterion.lower() == "a": + o = float(inter) / float(aarea) + else: + raise TypeError("Unkown type for criterion") + return o + + def compute3rdPartyMetrics(self): + """ + Computes the metrics defined in + - Stiefelhagen 2008: Evaluating Multiple Object Tracking Performance: The CLEAR MOT Metrics + MOTA, MOTAL, MOTP + - Nevatia 2008: Global Data Association for Multi-Object Tracking Using Network Flows + MT/PT/ML + """ + # construct Munkres object for Hungarian Method association + hm = Munkres() + max_cost = 1e9 + + # go through all frames and associate ground truth and tracker results + # groundtruth and tracker contain lists for every single frame containing lists of KITTI format detections + fr, ids = 0, 0 + for seq_idx in range(len(self.groundtruth)): + seq_gt = self.groundtruth[seq_idx] + seq_dc = self.dcareas[seq_idx] # don't care areas + seq_tracker = self.tracker[seq_idx] + seq_trajectories = defaultdict(list) + seq_ignored = defaultdict(list) + + # statistics over the current sequence, check the corresponding + # variable comments in __init__ to get their meaning + seqtp = 0 + seqitp = 0 + seqfn = 0 + seqifn = 0 + seqfp = 0 + seqigt = 0 + seqitr = 0 + + last_ids = [[], []] + n_gts = 0 + n_trs = 0 + + for f in range(len(seq_gt)): + g = seq_gt[f] + dc = seq_dc[f] + + t = seq_tracker[f] + # counting total number of ground truth and tracker objects + self.n_gt += len(g) + self.n_tr += len(t) + + n_gts += len(g) + n_trs += len(t) + + # use hungarian method to associate, using boxoverlap 0..1 as cost + # build cost matrix + cost_matrix = [] + this_ids = [[], []] + for gg in g: + # save current ids + this_ids[0].append(gg.track_id) + this_ids[1].append(-1) + gg.tracker = -1 + gg.id_switch = 0 + gg.fragmentation = 0 + cost_row = [] + for tt in t: + # overlap == 1 is cost ==0 + c = 1 - self.boxoverlap(gg, tt) + # gating for boxoverlap + if c <= self.min_overlap: + cost_row.append(c) + else: + cost_row.append(max_cost) # = 1e9 + cost_matrix.append(cost_row) + # all ground truth trajectories are initially not associated + # extend groundtruth trajectories lists (merge lists) + seq_trajectories[gg.track_id].append(-1) + seq_ignored[gg.track_id].append(False) + + if len(g) is 0: + cost_matrix = [[]] + # associate + association_matrix = hm.compute(cost_matrix) + + # tmp variables for sanity checks and MODP computation + tmptp = 0 + tmpfp = 0 + tmpfn = 0 + tmpc = 0 # this will sum up the overlaps for all true positives + tmpcs = [0] * len( + g) # this will save the overlaps for all true positives + # the reason is that some true positives might be ignored + # later such that the corrsponding overlaps can + # be subtracted from tmpc for MODP computation + + # mapping for tracker ids and ground truth ids + for row, col in association_matrix: + # apply gating on boxoverlap + c = cost_matrix[row][col] + if c < max_cost: + g[row].tracker = t[col].track_id + this_ids[1][row] = t[col].track_id + t[col].valid = True + g[row].distance = c + self.total_cost += 1 - c + tmpc += 1 - c + tmpcs[row] = 1 - c + seq_trajectories[g[row].track_id][-1] = t[col].track_id + + # true positives are only valid associations + self.tp += 1 + tmptp += 1 + else: + g[row].tracker = -1 + self.fn += 1 + tmpfn += 1 + + # associate tracker and DontCare areas + # ignore tracker in neighboring 
classes + nignoredtracker = 0 # number of ignored tracker detections + ignoredtrackers = dict() # will associate the track_id with -1 + # if it is not ignored and 1 if it is + # ignored; + # this is used to avoid double counting ignored + # cases, see the next loop + + for tt in t: + ignoredtrackers[tt.track_id] = -1 + # ignore detection if it belongs to a neighboring class or is + # smaller or equal to the minimum height + + tt_height = abs(tt.y1 - tt.y2) + if ((self.cls == "car" and tt.obj_type == "van") or + (self.cls == "pedestrian" and + tt.obj_type == "person_sitting") or + tt_height <= self.min_height) and not tt.valid: + nignoredtracker += 1 + tt.ignored = True + ignoredtrackers[tt.track_id] = 1 + continue + for d in dc: + overlap = self.boxoverlap(tt, d, "a") + if overlap > 0.5 and not tt.valid: + tt.ignored = True + nignoredtracker += 1 + ignoredtrackers[tt.track_id] = 1 + break + + # check for ignored FN/TP (truncation or neighboring object class) + ignoredfn = 0 # the number of ignored false negatives + nignoredtp = 0 # the number of ignored true positives + nignoredpairs = 0 # the number of ignored pairs, i.e. a true positive + # which is ignored but where the associated tracker + # detection has already been ignored + + gi = 0 + for gg in g: + if gg.tracker < 0: + if gg.occlusion>self.max_occlusion or gg.truncation>self.max_truncation\ + or (self.cls=="car" and gg.obj_type=="van") or (self.cls=="pedestrian" and gg.obj_type=="person_sitting"): + seq_ignored[gg.track_id][-1] = True + gg.ignored = True + ignoredfn += 1 + + elif gg.tracker >= 0: + if gg.occlusion>self.max_occlusion or gg.truncation>self.max_truncation\ + or (self.cls=="car" and gg.obj_type=="van") or (self.cls=="pedestrian" and gg.obj_type=="person_sitting"): + + seq_ignored[gg.track_id][-1] = True + gg.ignored = True + nignoredtp += 1 + + # if the associated tracker detection is already ignored, + # we want to avoid double counting ignored detections + if ignoredtrackers[gg.tracker] > 0: + nignoredpairs += 1 + + # for computing MODP, the overlaps from ignored detections + # are subtracted + tmpc -= tmpcs[gi] + gi += 1 + + # the below might be confusion, check the comments in __init__ + # to see what the individual statistics represent + + # correct TP by number of ignored TP due to truncation + # ignored TP are shown as tracked in visualization + tmptp -= nignoredtp + + # count the number of ignored true positives + self.itp += nignoredtp + + # adjust the number of ground truth objects considered + self.n_gt -= (ignoredfn + nignoredtp) + + # count the number of ignored ground truth objects + self.n_igt += ignoredfn + nignoredtp + + # count the number of ignored tracker objects + self.n_itr += nignoredtracker + + # count the number of ignored pairs, i.e. 
associated tracker and + # ground truth objects that are both ignored + self.n_igttr += nignoredpairs + + # false negatives = associated gt bboxes exceding association threshold + non-associated gt bboxes + tmpfn += len(g) - len(association_matrix) - ignoredfn + self.fn += len(g) - len(association_matrix) - ignoredfn + self.ifn += ignoredfn + + # false positives = tracker bboxes - associated tracker bboxes + # mismatches (mme_t) + tmpfp += len( + t) - tmptp - nignoredtracker - nignoredtp + nignoredpairs + self.fp += len( + t) - tmptp - nignoredtracker - nignoredtp + nignoredpairs + + # update sequence data + seqtp += tmptp + seqitp += nignoredtp + seqfp += tmpfp + seqfn += tmpfn + seqifn += ignoredfn + seqigt += ignoredfn + nignoredtp + seqitr += nignoredtracker + + # sanity checks + # - the number of true positives minues ignored true positives + # should be greater or equal to 0 + # - the number of false negatives should be greater or equal to 0 + # - the number of false positives needs to be greater or equal to 0 + # otherwise ignored detections might be counted double + # - the number of counted true positives (plus ignored ones) + # and the number of counted false negatives (plus ignored ones) + # should match the total number of ground truth objects + # - the number of counted true positives (plus ignored ones) + # and the number of counted false positives + # plus the number of ignored tracker detections should + # match the total number of tracker detections; note that + # nignoredpairs is subtracted here to avoid double counting + # of ignored detection sin nignoredtp and nignoredtracker + if tmptp < 0: + print(tmptp, nignoredtp) + raise NameError("Something went wrong! TP is negative") + if tmpfn < 0: + print(tmpfn, + len(g), + len(association_matrix), ignoredfn, nignoredpairs) + raise NameError("Something went wrong! FN is negative") + if tmpfp < 0: + print(tmpfp, + len(t), tmptp, nignoredtracker, nignoredtp, + nignoredpairs) + raise NameError("Something went wrong! FP is negative") + if tmptp + tmpfn is not len(g) - ignoredfn - nignoredtp: + print("seqidx", seq_idx) + print("frame ", f) + print("TP ", tmptp) + print("FN ", tmpfn) + print("FP ", tmpfp) + print("nGT ", len(g)) + print("nAss ", len(association_matrix)) + print("ign GT", ignoredfn) + print("ign TP", nignoredtp) + raise NameError( + "Something went wrong! nGroundtruth is not TP+FN") + if tmptp + tmpfp + nignoredtp + nignoredtracker - nignoredpairs is not len( + t): + print(seq_idx, f, len(t), tmptp, tmpfp) + print(len(association_matrix), association_matrix) + raise NameError( + "Something went wrong! nTracker is not TP+FP") + + # check for id switches or fragmentations + for i, tt in enumerate(this_ids[0]): + if tt in last_ids[0]: + idx = last_ids[0].index(tt) + tid = this_ids[1][i] + lid = last_ids[1][idx] + if tid != lid and lid != -1 and tid != -1: + if g[i].truncation < self.max_truncation: + g[i].id_switch = 1 + ids += 1 + if tid != lid and lid != -1: + if g[i].truncation < self.max_truncation: + g[i].fragmentation = 1 + fr += 1 + + # save current index + last_ids = this_ids + # compute MOTP_t + MODP_t = 1 + if tmptp != 0: + MODP_t = tmpc / float(tmptp) + self.MODP_t.append(MODP_t) + + # remove empty lists for current gt trajectories + self.gt_trajectories[seq_idx] = seq_trajectories + self.ign_trajectories[seq_idx] = seq_ignored + + # gather statistics for "per sequence" statistics. 
+ self.n_gts.append(n_gts) + self.n_trs.append(n_trs) + self.tps.append(seqtp) + self.itps.append(seqitp) + self.fps.append(seqfp) + self.fns.append(seqfn) + self.ifns.append(seqifn) + self.n_igts.append(seqigt) + self.n_itrs.append(seqitr) + + # compute MT/PT/ML, fragments, idswitches for all groundtruth trajectories + n_ignored_tr_total = 0 + for seq_idx, ( + seq_trajectories, seq_ignored + ) in enumerate(zip(self.gt_trajectories, self.ign_trajectories)): + if len(seq_trajectories) == 0: + continue + tmpMT, tmpML, tmpPT, tmpId_switches, tmpFragments = [0] * 5 + n_ignored_tr = 0 + for g, ign_g in zip(seq_trajectories.values(), + seq_ignored.values()): + # all frames of this gt trajectory are ignored + if all(ign_g): + n_ignored_tr += 1 + n_ignored_tr_total += 1 + continue + # all frames of this gt trajectory are not assigned to any detections + if all([this == -1 for this in g]): + tmpML += 1 + self.ML += 1 + continue + # compute tracked frames in trajectory + last_id = g[0] + # first detection (necessary to be in gt_trajectories) is always tracked + tracked = 1 if g[0] >= 0 else 0 + lgt = 0 if ign_g[0] else 1 + for f in range(1, len(g)): + if ign_g[f]: + last_id = -1 + continue + lgt += 1 + if last_id != g[f] and last_id != -1 and g[f] != -1 and g[ + f - 1] != -1: + tmpId_switches += 1 + self.id_switches += 1 + if f < len(g) - 1 and g[f - 1] != g[ + f] and last_id != -1 and g[f] != -1 and g[f + + 1] != -1: + tmpFragments += 1 + self.fragments += 1 + if g[f] != -1: + tracked += 1 + last_id = g[f] + # handle last frame; tracked state is handled in for loop (g[f]!=-1) + if len(g) > 1 and g[f - 1] != g[f] and last_id != -1 and g[ + f] != -1 and not ign_g[f]: + tmpFragments += 1 + self.fragments += 1 + + # compute MT/PT/ML + tracking_ratio = tracked / float(len(g) - sum(ign_g)) + if tracking_ratio > 0.8: + tmpMT += 1 + self.MT += 1 + elif tracking_ratio < 0.2: + tmpML += 1 + self.ML += 1 + else: # 0.2 <= tracking_ratio <= 0.8 + tmpPT += 1 + self.PT += 1 + + if (self.n_gt_trajectories - n_ignored_tr_total) == 0: + self.MT = 0. + self.PT = 0. + self.ML = 0. + else: + self.MT /= float(self.n_gt_trajectories - n_ignored_tr_total) + self.PT /= float(self.n_gt_trajectories - n_ignored_tr_total) + self.ML /= float(self.n_gt_trajectories - n_ignored_tr_total) + + # precision/recall etc. + if (self.fp + self.tp) == 0 or (self.tp + self.fn) == 0: + self.recall = 0. + self.precision = 0. + else: + self.recall = self.tp / float(self.tp + self.fn) + self.precision = self.tp / float(self.fp + self.tp) + if (self.recall + self.precision) == 0: + self.F1 = 0. + else: + self.F1 = 2. 
* (self.precision * self.recall) / ( + self.precision + self.recall) + if sum(self.n_frames) == 0: + self.FAR = "n/a" + else: + self.FAR = self.fp / float(sum(self.n_frames)) + + # compute CLEARMOT + if self.n_gt == 0: + self.MOTA = -float("inf") + self.MODA = -float("inf") + else: + self.MOTA = 1 - (self.fn + self.fp + self.id_switches + ) / float(self.n_gt) + self.MODA = 1 - (self.fn + self.fp) / float(self.n_gt) + if self.tp == 0: + self.MOTP = float("inf") + else: + self.MOTP = self.total_cost / float(self.tp) + if self.n_gt != 0: + if self.id_switches == 0: + self.MOTAL = 1 - (self.fn + self.fp + self.id_switches + ) / float(self.n_gt) + else: + self.MOTAL = 1 - (self.fn + self.fp + + math.log10(self.id_switches) + ) / float(self.n_gt) + else: + self.MOTAL = -float("inf") + if sum(self.n_frames) == 0: + self.MODP = "n/a" + else: + self.MODP = sum(self.MODP_t) / float(sum(self.n_frames)) + return True + + def createSummary(self): + summary = "" + summary += "tracking evaluation summary".center(80, "=") + "\n" + summary += self.printEntry("Multiple Object Tracking Accuracy (MOTA)", + self.MOTA) + "\n" + summary += self.printEntry("Multiple Object Tracking Precision (MOTP)", + self.MOTP) + "\n" + summary += self.printEntry("Multiple Object Tracking Accuracy (MOTAL)", + self.MOTAL) + "\n" + summary += self.printEntry("Multiple Object Detection Accuracy (MODA)", + self.MODA) + "\n" + summary += self.printEntry("Multiple Object Detection Precision (MODP)", + self.MODP) + "\n" + summary += "\n" + summary += self.printEntry("Recall", self.recall) + "\n" + summary += self.printEntry("Precision", self.precision) + "\n" + summary += self.printEntry("F1", self.F1) + "\n" + summary += self.printEntry("False Alarm Rate", self.FAR) + "\n" + summary += "\n" + summary += self.printEntry("Mostly Tracked", self.MT) + "\n" + summary += self.printEntry("Partly Tracked", self.PT) + "\n" + summary += self.printEntry("Mostly Lost", self.ML) + "\n" + summary += "\n" + summary += self.printEntry("True Positives", self.tp) + "\n" + #summary += self.printEntry("True Positives per Sequence", self.tps) + "\n" + summary += self.printEntry("Ignored True Positives", self.itp) + "\n" + #summary += self.printEntry("Ignored True Positives per Sequence", self.itps) + "\n" + + summary += self.printEntry("False Positives", self.fp) + "\n" + #summary += self.printEntry("False Positives per Sequence", self.fps) + "\n" + summary += self.printEntry("False Negatives", self.fn) + "\n" + #summary += self.printEntry("False Negatives per Sequence", self.fns) + "\n" + summary += self.printEntry("ID-switches", self.id_switches) + "\n" + self.fp = self.fp / self.n_gt + self.fn = self.fn / self.n_gt + self.id_switches = self.id_switches / self.n_gt + summary += self.printEntry("False Positives Ratio", self.fp) + "\n" + #summary += self.printEntry("False Positives per Sequence", self.fps) + "\n" + summary += self.printEntry("False Negatives Ratio", self.fn) + "\n" + #summary += self.printEntry("False Negatives per Sequence", self.fns) + "\n" + summary += self.printEntry("Ignored False Negatives Ratio", + self.ifn) + "\n" + + #summary += self.printEntry("Ignored False Negatives per Sequence", self.ifns) + "\n" + summary += self.printEntry("Missed Targets", self.fn) + "\n" + summary += self.printEntry("ID-switches", self.id_switches) + "\n" + summary += self.printEntry("Fragmentations", self.fragments) + "\n" + summary += "\n" + summary += self.printEntry("Ground Truth Objects (Total)", self.n_gt + + self.n_igt) + "\n" + #summary += 
self.printEntry("Ground Truth Objects (Total) per Sequence", self.n_gts) + "\n" + summary += self.printEntry("Ignored Ground Truth Objects", + self.n_igt) + "\n" + #summary += self.printEntry("Ignored Ground Truth Objects per Sequence", self.n_igts) + "\n" + summary += self.printEntry("Ground Truth Trajectories", + self.n_gt_trajectories) + "\n" + summary += "\n" + summary += self.printEntry("Tracker Objects (Total)", self.n_tr) + "\n" + #summary += self.printEntry("Tracker Objects (Total) per Sequence", self.n_trs) + "\n" + summary += self.printEntry("Ignored Tracker Objects", self.n_itr) + "\n" + #summary += self.printEntry("Ignored Tracker Objects per Sequence", self.n_itrs) + "\n" + summary += self.printEntry("Tracker Trajectories", + self.n_tr_trajectories) + "\n" + #summary += "\n" + #summary += self.printEntry("Ignored Tracker Objects with Associated Ignored Ground Truth Objects", self.n_igttr) + "\n" + summary += "=" * 80 + return summary + + def printEntry(self, key, val, width=(70, 10)): + """ + Pretty print an entry in a table fashion. + """ + s_out = key.ljust(width[0]) + if type(val) == int: + s = "%%%dd" % width[1] + s_out += s % val + elif type(val) == float: + s = "%%%df" % (width[1]) + s_out += s % val + else: + s_out += ("%s" % val).rjust(width[1]) + return s_out + + def saveToStats(self, save_summary): + """ + Save the statistics in a whitespace separate file. + """ + summary = self.createSummary() + if save_summary: + filename = os.path.join(self.result_path, + "summary_%s.txt" % self.cls) + dump = open(filename, "w+") + dump.write(summary) + dump.close() + return summary + + +class KITTIMOTMetric(Metric): + def __init__(self, save_summary=True): + self.save_summary = save_summary + self.MOTEvaluator = KITTIEvaluation + self.result_root = None + self.reset() + + def reset(self): + self.seqs = [] + self.n_sequences = 0 + self.n_frames = [] + self.strsummary = '' + + def update(self, data_root, seq, data_type, result_root, result_filename): + assert data_type == 'kitti', "data_type should 'kitti'" + self.result_root = result_root + self.gt_path = data_root + gt_path = '{}/label_02/{}.txt'.format(data_root, seq) + gt = open(gt_path, "r") + max_frame = 0 + for line in gt: + line = line.strip() + line_list = line.split(" ") + if int(line_list[0]) > max_frame: + max_frame = int(line_list[0]) + rs = open(result_filename, "r") + for line in rs: + line = line.strip() + line_list = line.split(" ") + if int(line_list[0]) > max_frame: + max_frame = int(line_list[0]) + gt.close() + rs.close() + self.n_frames.append(max_frame + 1) + self.seqs.append(seq) + self.n_sequences += 1 + + def accumulate(self): + logger.info("Processing Result for KITTI Tracking Benchmark") + e = self.MOTEvaluator(result_path=self.result_root, gt_path=self.gt_path,\ + n_frames=self.n_frames, seqs=self.seqs, n_sequences=self.n_sequences) + try: + if not e.loadTracker(): + return + logger.info("Loading Results - Success") + logger.info("Evaluate Object Class: %s" % c.upper()) + except: + logger.info("Caught exception while loading result data.") + if not e.loadGroundtruth(): + raise ValueError("Ground truth not found.") + logger.info("Loading Groundtruth - Success") + # sanity checks + if len(e.groundtruth) is not len(e.tracker): + logger.info( + "The uploaded data does not provide results for every sequence.") + return False + logger.info("Loaded %d Sequences." 
% len(e.groundtruth)) + logger.info("Start Evaluation...") + + if e.compute3rdPartyMetrics(): + self.strsummary = e.saveToStats(self.save_summary) + else: + logger.info( + "There seem to be no true positives or false positives at all in the submitted data." + ) + + def log(self): + print(self.strsummary) + + def get_results(self): + return self.strsummary diff --git a/ppdet/metrics/munkres.py b/ppdet/metrics/munkres.py new file mode 100644 index 0000000000000000000000000000000000000000..fbd4a92d2a793bf130c8a1d253bd45bde8cbb0d1 --- /dev/null +++ b/ppdet/metrics/munkres.py @@ -0,0 +1,428 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +This code is borrow from https://github.com/xingyizhou/CenterTrack/blob/master/src/tools/eval_kitti_track/munkres.py +""" + +import sys + +__all__ = ['Munkres', 'make_cost_matrix'] + + +class Munkres: + """ + Calculate the Munkres solution to the classical assignment problem. + See the module documentation for usage. + """ + + def __init__(self): + """Create a new instance""" + self.C = None + self.row_covered = [] + self.col_covered = [] + self.n = 0 + self.Z0_r = 0 + self.Z0_c = 0 + self.marked = None + self.path = None + + def make_cost_matrix(profit_matrix, inversion_function): + """ + **DEPRECATED** + + Please use the module function ``make_cost_matrix()``. + """ + import munkres + return munkres.make_cost_matrix(profit_matrix, inversion_function) + + make_cost_matrix = staticmethod(make_cost_matrix) + + def pad_matrix(self, matrix, pad_value=0): + """ + Pad a possibly non-square matrix to make it square. + + :Parameters: + matrix : list of lists + matrix to pad + + pad_value : int + value to use to pad the matrix + + :rtype: list of lists + :return: a new, possibly padded, matrix + """ + max_columns = 0 + total_rows = len(matrix) + + for row in matrix: + max_columns = max(max_columns, len(row)) + + total_rows = max(max_columns, total_rows) + + new_matrix = [] + for row in matrix: + row_len = len(row) + new_row = row[:] + if total_rows > row_len: + # Row too short. Pad it. + new_row += [0] * (total_rows - row_len) + new_matrix += [new_row] + + while len(new_matrix) < total_rows: + new_matrix += [[0] * total_rows] + + return new_matrix + + def compute(self, cost_matrix): + """ + Compute the indexes for the lowest-cost pairings between rows and + columns in the database. Returns a list of (row, column) tuples + that can be used to traverse the matrix. + + :Parameters: + cost_matrix : list of lists + The cost matrix. If this cost matrix is not square, it + will be padded with zeros, via a call to ``pad_matrix()``. + (This method does *not* modify the caller's matrix. It + operates on a copy of the matrix.) + + **WARNING**: This code handles square and rectangular + matrices. It does *not* handle irregular matrices. 
+ + :rtype: list + :return: A list of ``(row, column)`` tuples that describe the lowest + cost path through the matrix + + """ + self.C = self.pad_matrix(cost_matrix) + self.n = len(self.C) + self.original_length = len(cost_matrix) + self.original_width = len(cost_matrix[0]) + self.row_covered = [False for i in range(self.n)] + self.col_covered = [False for i in range(self.n)] + self.Z0_r = 0 + self.Z0_c = 0 + self.path = self.__make_matrix(self.n * 2, 0) + self.marked = self.__make_matrix(self.n, 0) + + done = False + step = 1 + + steps = { + 1: self.__step1, + 2: self.__step2, + 3: self.__step3, + 4: self.__step4, + 5: self.__step5, + 6: self.__step6 + } + + while not done: + try: + func = steps[step] + step = func() + except KeyError: + done = True + + # Look for the starred columns + results = [] + for i in range(self.original_length): + for j in range(self.original_width): + if self.marked[i][j] == 1: + results += [(i, j)] + + return results + + def __copy_matrix(self, matrix): + """Return an exact copy of the supplied matrix""" + return copy.deepcopy(matrix) + + def __make_matrix(self, n, val): + """Create an *n*x*n* matrix, populating it with the specific value.""" + matrix = [] + for i in range(n): + matrix += [[val for j in range(n)]] + return matrix + + def __step1(self): + """ + For each row of the matrix, find the smallest element and + subtract it from every element in its row. Go to Step 2. + """ + C = self.C + n = self.n + for i in range(n): + minval = min(self.C[i]) + # Find the minimum value for this row and subtract that minimum + # from every element in the row. + for j in range(n): + self.C[i][j] -= minval + + return 2 + + def __step2(self): + """ + Find a zero (Z) in the resulting matrix. If there is no starred + zero in its row or column, star Z. Repeat for each element in the + matrix. Go to Step 3. + """ + n = self.n + for i in range(n): + for j in range(n): + if (self.C[i][j] == 0) and \ + (not self.col_covered[j]) and \ + (not self.row_covered[i]): + self.marked[i][j] = 1 + self.col_covered[j] = True + self.row_covered[i] = True + + self.__clear_covers() + return 3 + + def __step3(self): + """ + Cover each column containing a starred zero. If K columns are + covered, the starred zeros describe a complete set of unique + assignments. In this case, Go to DONE, otherwise, Go to Step 4. + """ + n = self.n + count = 0 + for i in range(n): + for j in range(n): + if self.marked[i][j] == 1: + self.col_covered[j] = True + count += 1 + + if count >= n: + step = 7 # done + else: + step = 4 + + return step + + def __step4(self): + """ + Find a noncovered zero and prime it. If there is no starred zero + in the row containing this primed zero, Go to Step 5. Otherwise, + cover this row and uncover the column containing the starred + zero. Continue in this manner until there are no uncovered zeros + left. Save the smallest uncovered value and Go to Step 6. + """ + step = 0 + done = False + row = -1 + col = -1 + star_col = -1 + while not done: + (row, col) = self.__find_a_zero() + if row < 0: + done = True + step = 6 + else: + self.marked[row][col] = 2 + star_col = self.__find_star_in_row(row) + if star_col >= 0: + col = star_col + self.row_covered[row] = True + self.col_covered[col] = False + else: + done = True + self.Z0_r = row + self.Z0_c = col + step = 5 + + return step + + def __step5(self): + """ + Construct a series of alternating primed and starred zeros as + follows. Let Z0 represent the uncovered primed zero found in Step 4. 
+        Let Z1 denote the starred zero in the column of Z0 (if any).
+        Let Z2 denote the primed zero in the row of Z1 (there will always
+        be one). Continue until the series terminates at a primed zero
+        that has no starred zero in its column. Unstar each starred zero
+        of the series, star each primed zero of the series, erase all
+        primes and uncover every line in the matrix. Return to Step 3
+        """
+        count = 0
+        path = self.path
+        path[count][0] = self.Z0_r
+        path[count][1] = self.Z0_c
+        done = False
+        while not done:
+            row = self.__find_star_in_col(path[count][1])
+            if row >= 0:
+                count += 1
+                path[count][0] = row
+                path[count][1] = path[count - 1][1]
+            else:
+                done = True
+
+            if not done:
+                col = self.__find_prime_in_row(path[count][0])
+                count += 1
+                path[count][0] = path[count - 1][0]
+                path[count][1] = col
+
+        self.__convert_path(path, count)
+        self.__clear_covers()
+        self.__erase_primes()
+        return 3
+
+    def __step6(self):
+        """
+        Add the value found in Step 4 to every element of each covered
+        row, and subtract it from every element of each uncovered column.
+        Return to Step 4 without altering any stars, primes, or covered
+        lines.
+        """
+        minval = self.__find_smallest()
+        for i in range(self.n):
+            for j in range(self.n):
+                if self.row_covered[i]:
+                    self.C[i][j] += minval
+                if not self.col_covered[j]:
+                    self.C[i][j] -= minval
+        return 4
+
+    def __find_smallest(self):
+        """Find the smallest uncovered value in the matrix."""
+        minval = 2e9  # sys.maxint
+        for i in range(self.n):
+            for j in range(self.n):
+                if (not self.row_covered[i]) and (not self.col_covered[j]):
+                    if minval > self.C[i][j]:
+                        minval = self.C[i][j]
+        return minval
+
+    def __find_a_zero(self):
+        """Find the first uncovered element with value 0"""
+        row = -1
+        col = -1
+        i = 0
+        n = self.n
+        done = False
+
+        while not done:
+            j = 0
+            while True:
+                if (self.C[i][j] == 0) and \
+                        (not self.row_covered[i]) and \
+                        (not self.col_covered[j]):
+                    row = i
+                    col = j
+                    done = True
+                j += 1
+                if j >= n:
+                    break
+            i += 1
+            if i >= n:
+                done = True
+
+        return (row, col)
+
+    def __find_star_in_row(self, row):
+        """
+        Find the first starred element in the specified row. Returns
+        the column index, or -1 if no starred element was found.
+        """
+        col = -1
+        for j in range(self.n):
+            if self.marked[row][j] == 1:
+                col = j
+                break
+
+        return col
+
+    def __find_star_in_col(self, col):
+        """
+        Find the first starred element in the specified column. Returns
+        the row index, or -1 if no starred element was found.
+        """
+        row = -1
+        for i in range(self.n):
+            if self.marked[i][col] == 1:
+                row = i
+                break
+
+        return row
+
+    def __find_prime_in_row(self, row):
+        """
+        Find the first primed element in the specified row. Returns
+        the column index, or -1 if no primed element was found.
+        """
+        col = -1
+        for j in range(self.n):
+            if self.marked[row][j] == 2:
+                col = j
+                break
+
+        return col
+
+    def __convert_path(self, path, count):
+        for i in range(count + 1):
+            if self.marked[path[i][0]][path[i][1]] == 1:
+                self.marked[path[i][0]][path[i][1]] = 0
+            else:
+                self.marked[path[i][0]][path[i][1]] = 1
+
+    def __clear_covers(self):
+        """Clear all covered matrix cells"""
+        for i in range(self.n):
+            self.row_covered[i] = False
+            self.col_covered[i] = False
+
+    def __erase_primes(self):
+        """Erase all prime markings"""
+        for i in range(self.n):
+            for j in range(self.n):
+                if self.marked[i][j] == 2:
+                    self.marked[i][j] = 0
+
+
+def make_cost_matrix(profit_matrix, inversion_function):
+    """
+    Create a cost matrix from a profit matrix by calling
+    'inversion_function' to invert each value. The inversion
+    function must take one numeric argument (of any type) and return
+    another numeric argument which is presumed to be the cost inverse
+    of the original profit.
+
+    This is a module-level function. Call it like this:
+
+    .. python::
+
+        cost_matrix = make_cost_matrix(matrix, inversion_func)
+
+    For example:
+
+    .. python::
+
+        cost_matrix = make_cost_matrix(matrix, lambda x: sys.maxsize - x)
+
+    :Parameters:
+        profit_matrix : list of lists
+            The matrix to convert from a profit to a cost matrix
+
+        inversion_function : function
+            The function to use to invert each entry in the profit matrix
+
+    :rtype: list of lists
+    :return: The converted matrix
+    """
+    cost_matrix = []
+    for row in profit_matrix:
+        cost_matrix.append([inversion_function(value) for value in row])
+    return cost_matrix
diff --git a/tools/eval_mot.py b/tools/eval_mot.py
index 9d4a9b22a0dcda4b8353bebbbec6486f1416061e..4e03eaf1b0f6b61e74a460dabf605daf76932e4e 100644
--- a/tools/eval_mot.py
+++ b/tools/eval_mot.py
@@ -41,11 +41,6 @@ logger = setup_logger('eval')
 
 def parse_args():
     parser = ArgsParser()
-    parser.add_argument(
-        "--data_type",
-        type=str,
-        default='mot',
-        help='Data type of tracking dataset, should be in ["mot", "kitti"]')
     parser.add_argument(
         "--det_results_dir",
         type=str,
@@ -95,7 +90,7 @@ def run(FLAGS, cfg):
     tracker.mot_evaluate(
         data_root=data_root,
         seqs=seqs,
-        data_type=FLAGS.data_type,
+        data_type=cfg.metric.lower(),
         model_type=cfg.architecture,
         output_dir=FLAGS.output_dir,
         save_images=FLAGS.save_images,
diff --git a/tools/infer_mot.py b/tools/infer_mot.py
index 57d7e6dff40ddfb2bc3e952e59722ea1db7d8252..16d1f193688b12d7fac9b9deaa2336cde95a6462 100644
--- a/tools/infer_mot.py
+++ b/tools/infer_mot.py
@@ -48,11 +48,6 @@ def parse_args():
         type=str,
         default=None,
         help="Directory for images to perform inference on.")
-    parser.add_argument(
-        "--data_type",
-        type=str,
-        default='mot',
-        help='Data type of tracking dataset, should be in ["mot", "kitti"]')
     parser.add_argument(
         "--det_results_dir",
         type=str,
@@ -101,7 +96,7 @@ def run(FLAGS, cfg):
     tracker.mot_predict(
         video_file=FLAGS.video_file,
         image_dir=FLAGS.image_dir,
-        data_type=FLAGS.data_type,
+        data_type=cfg.metric.lower(),
         model_type=cfg.architecture,
         output_dir=FLAGS.output_dir,
         save_images=FLAGS.save_images,
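For reference, a minimal usage sketch of the `Munkres` solver added in `ppdet/metrics/munkres.py` above. This snippet is not part of the patch, and the cost values are made up for illustration only:

```python
from ppdet.metrics.munkres import Munkres

# Toy cost matrix: cost[i][j] is the cost of assigning row item i
# (e.g. a groundtruth track) to column item j (e.g. a detection).
cost = [
    [4, 1, 3],
    [2, 0, 5],
    [3, 2, 2],
]

m = Munkres()
# compute() pads non-square matrices with zeros and returns the
# minimum-cost (row, column) assignment pairs.
indexes = m.compute(cost)
total = sum(cost[r][c] for r, c in indexes)
print(indexes)  # e.g. [(0, 1), (1, 0), (2, 2)]
print(total)    # 5 for this toy matrix
```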