提交 458fd345 编写于 作者: H HydrogenSulfate

polish reid doc

上级 d3c4b46e
......@@ -27,6 +27,7 @@ import cv2
import numpy as np
import importlib
from PIL import Image
from paddle.vision.transforms import ToTensor, Normalize
from python.det_preprocess import DetNormalizeImage, DetPadStride, DetPermute, DetResize
......@@ -53,13 +54,14 @@ def create_operators(params):
class UnifiedResize(object):
def __init__(self, interpolation=None, backend="cv2"):
def __init__(self, interpolation=None, backend="cv2", return_numpy=True):
_cv2_interp_from_str = {
'nearest': cv2.INTER_NEAREST,
'bilinear': cv2.INTER_LINEAR,
'area': cv2.INTER_AREA,
'bicubic': cv2.INTER_CUBIC,
'lanczos': cv2.INTER_LANCZOS4
'lanczos': cv2.INTER_LANCZOS4,
'random': (cv2.INTER_LINEAR, cv2.INTER_CUBIC)
}
_pil_interp_from_str = {
'nearest': Image.NEAREST,
......@@ -67,13 +69,26 @@ class UnifiedResize(object):
'bicubic': Image.BICUBIC,
'box': Image.BOX,
'lanczos': Image.LANCZOS,
'hamming': Image.HAMMING
'hamming': Image.HAMMING,
'random': (Image.BILINEAR, Image.BICUBIC)
}
def _pil_resize(src, size, resample):
def _cv2_resize(src, size, resample):
if isinstance(resample, tuple):
resample = random.choice(resample)
return cv2.resize(src, size, interpolation=resample)
def _pil_resize(src, size, resample, return_numpy=True):
if isinstance(resample, tuple):
resample = random.choice(resample)
if isinstance(src, np.ndarray):
pil_img = Image.fromarray(src)
else:
pil_img = src
pil_img = pil_img.resize(size, resample)
if return_numpy:
return np.asarray(pil_img)
return pil_img
if backend.lower() == "cv2":
if isinstance(interpolation, str):
......@@ -81,11 +96,12 @@ class UnifiedResize(object):
# compatible with opencv < version 4.4.0
elif interpolation is None:
interpolation = cv2.INTER_LINEAR
self.resize_func = partial(cv2.resize, interpolation=interpolation)
self.resize_func = partial(_cv2_resize, resample=interpolation)
elif backend.lower() == "pil":
if isinstance(interpolation, str):
interpolation = _pil_interp_from_str[interpolation.lower()]
self.resize_func = partial(_pil_resize, resample=interpolation)
self.resize_func = partial(
_pil_resize, resample=interpolation, return_numpy=return_numpy)
else:
logger.warning(
f"The backend of Resize only support \"cv2\" or \"PIL\". \"f{backend}\" is unavailable. Use \"cv2\" instead."
......@@ -93,6 +109,8 @@ class UnifiedResize(object):
self.resize_func = cv2.resize
def __call__(self, src, size):
if isinstance(size, list):
size = tuple(size)
return self.resize_func(src, size)
......@@ -137,7 +155,8 @@ class ResizeImage(object):
size=None,
resize_short=None,
interpolation=None,
backend="cv2"):
backend="cv2",
return_numpy=True):
if resize_short is not None and resize_short > 0:
self.resize_short = resize_short
self.w = None
......@@ -151,10 +170,16 @@ class ResizeImage(object):
'both 'size' and 'resize_short' are None")
self._resize_func = UnifiedResize(
interpolation=interpolation, backend=backend)
interpolation=interpolation,
backend=backend,
return_numpy=return_numpy)
def __call__(self, img):
if isinstance(img, np.ndarray):
img_h, img_w = img.shape[:2]
else:
img_w, img_h = img.size
if self.resize_short is not None:
percent = float(self.resize_short) / min(img_w, img_h)
w = int(round(img_w * percent))
......
......@@ -3,14 +3,18 @@
## 目录
- [ReID行人重识别](#reid行人重识别)
- [1. 算法/应用场景简介](#1-算法应用场景简介)
- [2. ReID算法](#2-reid算法)
- [2.1 ReID strong-baseline算法](#21-reid-strong-baseline算法)
- [2.2a 快速体验](#22a-快速体验)
- [2.2b 模型训练/推理等](#22b-模型训练推理等)
- [2.1 ReID strong-baseline](#21-reid-strong-baseline)
- [2.1.1 原理介绍](#211-原理介绍)
- [2.1.1 数据准备](#211-数据准备)
- [2.1.2 模型训练](#212-模型训练)
- [2.1.3 模型评估](#213-模型评估)
- [2.1.3 模型推理](#213-模型推理)
- [2.1.3.1 模型导出](#2131-模型导出)
- [2.1.3.2 模型推理](#2132-模型推理)
- [3. 总结](#3-总结)
- [3.1 方法总结、对比等](#31-方法总结对比等)
- [3.1 方法总结与对比](#31-方法总结与对比)
- [3.2 使用建议/FAQ](#32-使用建议faq)
- [4 参考资料](#4-参考资料)
......@@ -23,13 +27,15 @@
### 2. ReID算法
#### 2.1a ReID strong-baseline算法
#### 2.1 ReID strong-baseline
论文出处:[Bag of Tricks and A Strong Baseline for Deep Person Re-identification](https://openaccess.thecvf.com/content_CVPRW_2019/papers/TRMTMCT/Luo_Bag_of_Tricks_and_a_Strong_Baseline_for_Deep_Person_CVPRW_2019_paper.pdf)
<img src="../../images/reid/strong-baseline.jpg" width="50%">
原理介绍:作者以普遍使用的基于 ResNet50 的行人重识别方法为基础,探索并总结了以下几种有效且适用性较强的优化方法,大幅度提高了在多个行人重识别数据集上的指标。
##### 2.1.1 原理介绍
作者以普遍使用的基于 ResNet50 的行人重识别模型为基础,探索并总结了以下几种有效且适用性较强的优化方法,大幅度提高了在多个行人重识别数据集上的指标。
1. Warmup:在训练一开始让学习率从一个较小值逐渐升高后再开始下降,有利于梯度下降优化时的稳定性,从而找到更优的参数模型。
2. Random erasing augmentation:随机区域擦除,通过数据增强来提升模型的泛化能力。
......@@ -39,11 +45,21 @@
6. Center loss:给每个类别一个可学习的聚类中心,训练时让类内特征靠近聚类中心,减少类内差异,增大类间差异。
7. Reranking:在检索时考虑查询图像的近邻候选对象,根据候选对象的近邻图像的是否也含有查询图像的情况来优化距离矩阵,最终提升检索精度。
#### 2.1b 快速体验
以下表格总结了复现的ReID strong-baseline的3种配置在 Market1501 数据集上的精度指标,
| 配置文件 | recall@1 | mAP | 参考recall@1 | 参考mAP |
| ------------------------ | -------- | ----- | ------------ | ------- |
| baseline.yaml | 88.21 | 74.12 | 87.7 | 74.0 |
| softmax.yaml | 94.18 | 85.76 | 94.1 | 85.7 |
| softmax_with_center.yaml | 94.19 | 85.80 | 94.1 | 85.7 |
注:上述参考指标由使用作者开源的代码在我们的设备上训练多次得到,由于系统环境、torch版本、CUDA版本不同等原因,与作者提供的指标可能存在略微差异。
接下来主要以`softmax_triplet_with_center.yaml`配置和训练好的模型文件为例,展示在 Market1501 数据集上进行训练、测试、推理的过程。
快速体验章节主要以`softmax_triplet_with_center.yaml`配置和训练好的模型文件为例,在 Market1501 数据集上进行测试。
##### 2.1.2 数据准备
1. 下载[Market-1501-v15.09.15.zip](https://pan.baidu.com/s/1ntIi2Op?_at_=1654142245770)数据集,解压到`PaddleClas/dataset/`下,并组织成以下文件结构:
下载 [Market-1501-v15.09.15.zip](https://pan.baidu.com/s/1ntIi2Op?_at_=1654142245770) 数据集,解压到`PaddleClas/dataset/`下,并组织成以下文件结构:
```shell
PaddleClas/dataset/market1501
......@@ -60,97 +76,91 @@
└── readme.txt
```
2. 下载 [reid_strong_baseline_softmax_with_center.epoch_120.pdparams](reid_strong_baseline_softmax_with_center.epoch_120.pdparams)`PaddleClas/pretrained_models` 文件夹中
##### 2.1.3 模型训练
```shell
cd PaddleClas
mkdir pretrained_models
cd pretrained_models
wget https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/rec/reid/pretrain/reid_strong_baseline_softmax_with_center.epoch_120.pdparams
cd ..
```
3. 使用下载好的 `softmax_triplet_with_center.pdparams` 在 Market1501 数据集上进行测试
1. 执行以下命令开始训练
```shell
python3.7 tools/eval.py \
-c ppcls/configs/reid/strong_baseline/softmax_triplet_with_center.yaml \
-o Global.pretrained_model="pretrained_models/reid_strong_baseline_softmax_with_center.epoch_120"
python3.7 tools/train.py -c ./ppcls/configs/reid/strong_baseline/softmax_triplet_with_center.yaml
```
4. 查看输出结果
注:单卡训练大约需要1个小时。
```log
...
[2022/06/02 03:08:07] ppcls INFO: gallery feature calculation process: [0/125]
[2022/06/02 03:08:11] ppcls INFO: gallery feature calculation process: [20/125]
[2022/06/02 03:08:15] ppcls INFO: gallery feature calculation process: [40/125]
[2022/06/02 03:08:19] ppcls INFO: gallery feature calculation process: [60/125]
[2022/06/02 03:08:23] ppcls INFO: gallery feature calculation process: [80/125]
[2022/06/02 03:08:27] ppcls INFO: gallery feature calculation process: [100/125]
[2022/06/02 03:08:31] ppcls INFO: gallery feature calculation process: [120/125]
[2022/06/02 03:08:32] ppcls INFO: Build gallery done, all feat shape: [15913, 2048], begin to eval..
[2022/06/02 03:08:33] ppcls INFO: query feature calculation process: [0/27]
[2022/06/02 03:08:36] ppcls INFO: query feature calculation process: [20/27]
[2022/06/02 03:08:38] ppcls INFO: Build query done, all feat shape: [3368, 2048], begin to eval..
[2022/06/02 03:08:38] ppcls INFO: re_ranking=False
[2022/06/02 03:08:39] ppcls INFO: [Eval][Epoch 0][Avg]recall1: 0.94270, recall5: 0.98189, mAP: 0.85799
```
2. 查看训练日志和保存的模型参数文件
可以看到我们提供的 `reid_strong_baseline_softmax_with_center.epoch_120.pdparams` 模型在 Market1501 数据集上的指标为recall@1=0.94270,recall@5=0.98189,mAP=0.85799
训练过程中会在屏幕上实时打印loss等指标信息,同时会保存日志文件`train.log`、模型参数文件`*.pdparams`、优化器参数文件`*.pdopt`等内容到`Global.output_dir`指定的文件夹下,默认在`PaddleClas/output/RecModel/`文件夹下
#### 2.2c 模型训练/推理等
##### 2.1.4 模型评估
- 模型训练
准备用于评估的`*.pdparams`模型参数文件,可以使用训练好的模型,也可以使用[2.2 模型训练](#22-模型训练)中保存的模型。
1. 下载[Market-1501-v15.09.15.zip](https://pan.baidu.com/s/1ntIi2Op?_at_=1654142245770)数据集,解压到`PaddleClas/dataset/`下,并组织成以下文件结构:
- 以训练过程中保存的`latest.pdparams`为例,执行如下命令即可进行评估。
```shell
PaddleClas/dataset/market1501
└── Market-1501-v15.09.15/
├── bounding_box_test/
├── bounding_box_train/
├── gt_bbox/
├── gt_query/
├── query/
├── generate_anno.py
├── bounding_box_test.txt
├── bounding_box_train.txt
├── query.txt
└── readme.txt
python3.7 tools/eval.py \
-c ./ppcls/configs/reid/strong_baseline/softmax_triplet_with_center.yaml \
-o Global.pretrained_model="./output/RecModel/latest"
```
2. 执行以下命令开始训练
- 以训练好的模型为例,下载 [reid_strong_baseline_softmax_with_center.epoch_120.pdparams](https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/rec/reid/pretrain/reid_strong_baseline_softmax_with_center.epoch_120.pdparams)`PaddleClas/pretrained_models` 文件夹中,执行如下命令即可进行评估。
```shell
python3.7 tools/train.py -c ./ppcls/configs/reid/strong_baseline/softmax_triplet_with_center.yaml
# 下载模型
cd PaddleClas
mkdir pretrained_models
cd pretrained_models
wget https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/rec/reid/pretrain/reid_strong_baseline_softmax_with_center.epoch_120.pdparams
cd ..
# 评估
python3.7 tools/eval.py \
-c ppcls/configs/reid/strong_baseline/softmax_triplet_with_center.yaml \
-o Global.pretrained_model="pretrained_models/reid_strong_baseline_softmax_with_center.epoch_120"
```
注:`pretrained_model` 后填入的地址不需要加 `.pdparams` 后缀,在程序运行时会自动补上。
注:单卡训练大约需要1个小时。
- 模型测试
- 查看输出结果
```log
...
...
ppcls INFO: gallery feature calculation process: [0/125]
ppcls INFO: gallery feature calculation process: [20/125]
ppcls INFO: gallery feature calculation process: [40/125]
ppcls INFO: gallery feature calculation process: [60/125]
ppcls INFO: gallery feature calculation process: [80/125]
ppcls INFO: gallery feature calculation process: [100/125]
ppcls INFO: gallery feature calculation process: [120/125]
ppcls INFO: Build gallery done, all feat shape: [15913, 2048], begin to eval..
ppcls INFO: query feature calculation process: [0/27]
ppcls INFO: query feature calculation process: [20/27]
ppcls INFO: Build query done, all feat shape: [3368, 2048], begin to eval..
ppcls INFO: re_ranking=False
ppcls INFO: [Eval][Epoch 0][Avg]recall1: 0.94270, recall5: 0.98189, mAP: 0.85799
```
默认评估日志保存在`PaddleClas/output/RecModel/eval.log`中,可以看到我们提供的 `reid_strong_baseline_softmax_with_center.epoch_120.pdparams` 模型在 Market1501 数据集上的评估指标为recall@1=0.94270,recall@5=0.98189,mAP=0.85799
假设需要测试的模型文件路径为 `./output/RecModel/latest.pdparams` ,执行下述命令即可进行测试
##### 2.1.5 模型推理
###### 2.1.5.1 模型导出
可以选择使用训练过程中保存的模型文件转换成 inference 模型并推理,或者使用我们提供的转换好的 inference 模型直接进行推理
- 将训练过程中保存的模型文件转换成 inference 模型,同样以`latest.pdparams`为例,执行以下命令进行转换
```shell
python3.7 tools/eval.py \
-c ./ppcls/configs/reid/strong_baseline/softmax_triplet_with_center.yaml \
-o Global.pretrained_model="./output/RecModel/latest"
python3.7 tools/export_model.py \
-c ppcls/configs/reid/strong_baseline/softmax_triplet_with_center.yaml \
-o Global.pretrained_model="output/RecModel/latest" \
-o Global.save_inference_dir="./deploy/reid_srong_baseline_softmax_with_center"
```
注:`pretrained_model` 后填入的地址不需要加 `.pdparams` 后缀,在程序运行时会自动补上。
- 模型推理
1. 下载 inference 模型并解压:[reid_srong_baseline_softmax_with_center.tar](https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/rec/reid/inference/reid_srong_baseline_softmax_with_center.tar)
- 或者下载并解压我们提供的 inference 模型
```shell
cd PaddleClas/deploy
wget https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/rec/reid/inference/reid_srong_baseline_softmax_with_center.tar
tar xf reid_srong_baseline_softmax_with_center.tar
cd ../
```
2. 修改 `PaddleClas/deploy/configs/inference_rec.yaml`。将 `infer_imgs:` 后的字段改为 Market1501 中 query 文件夹下的任意一张图片路径(下方代码使用的是`0294_c1s1_066631_00.jpg`图片路径);将 `rec_inference_model_dir:` 后的字段改为解压出来的 reid_srong_baseline_softmax_with_center 文件夹路径;将 `transform_ops` 字段下的预处理配置改为 `softmax_triplet_with_center.yaml``Eval.Query.dataset` 下的预处理配置。如下所示
###### 2.1.5.2 模型推理
1. 修改 `PaddleClas/deploy/configs/inference_rec.yaml`。将 `infer_imgs:` 后的字段改为 Market1501 中 query 文件夹下的任意一张图片路径(下方代码使用的是`0294_c1s1_066631_00.jpg`图片路径);将 `rec_inference_model_dir:` 后的字段改为解压出来的 reid_srong_baseline_softmax_with_center 文件夹路径;将 `transform_ops` 字段下的预处理配置改为 `softmax_triplet_with_center.yaml``Eval.Query.dataset` 下的预处理配置。如下所示
```yaml
Global:
......@@ -172,7 +182,7 @@
- ResizeImage:
size: [128, 256]
return_numpy: False
interpolation: 'bilinear'
interpolation: "bilinear"
backend: "pil"
- ToTensor:
- Normalize:
......@@ -182,13 +192,14 @@
RecPostProcess: null
```
3. 执行推理命令
2. 执行推理命令
```shell
cd PaddleClas/deploy/
python3.7 python/predict_rec.py -c ./configs/inference_rec.yaml
```
4. 查看输出结果,实际结果为一个长度2048的向量,表示输入图片经过模型转换后得到的特征向量
3. 查看输出结果,实际结果为一个长度2048的向量,表示输入图片经过模型转换后得到的特征向量
```shell
0294_c1s1_066631_00.jpg: [ 0.01806974 0.00476423 -0.00508293 ... 0.03925538 0.00377574
......@@ -198,25 +209,17 @@
### 3. 总结
#### 3.1 方法总结、对比等
以下表格总结了我们提供的ReID strong-baseline的3种配置在 Market1501 数据集上的精度指标,
#### 3.1 方法总结与对比
| 配置文件 | recall@1 | mAP | 参考recall@1 | 参考mAP |
| ------------------------ | -------- | ----- | ------------ | ------- |
| baseline.yaml | 88.21 | 74.12 | 87.7 | 74.0 |
| softmax.yaml | 94.18 | 85.76 | 94.1 | 85.7 |
| softmax_with_center.yaml | 94.19 | 85.80 | 94.1 | 85.7 |
注:上述参考指标由使用作者开源的代码在我们的设备上训练多次得到,由于系统环境、torch版本、CUDA版本不同等原因,与作者提供的指标可能存在略微差异。
上述算法能快速地迁移至多数的ReID模型中,能进一步提升ReID模型的性能。
#### 3.2 使用建议/FAQ
Market1501 数据集比较小,可以尝试训练多次取最高精度。
#### 4 参考资料
### 4 参考资料
1. [Bag of Tricks and A Strong Baseline for Deep Person Re-identification](https://openaccess.thecvf.com/content_CVPRW_2019/papers/TRMTMCT/Luo_Bag_of_Tricks_and_a_Strong_Baseline_for_Deep_Person_CVPRW_2019_paper.pdf)
2. [michuanhaohao/reid-strong-baseline: Bag of Tricks and A Strong Baseline for Deep Person Re-identification (github.com)](https://github.com/michuanhaohao/reid-strong-baseline)
2. [michuanhaohao/reid-strong-baseline](https://github.com/michuanhaohao/reid-strong-baseline)
3. [行人重识别数据集之 Market1501 数据集_star_function的博客-CSDN博客_market1501数据集](https://blog.csdn.net/qq_39220334/article/details/121470106)
4. [Deep Learning for Person Re-identification:A Survey and Outlook](https://arxiv.org/abs/2001.04193)
......@@ -64,7 +64,7 @@ Optimizer:
by_epoch: True
last_epoch: 0
regularizer:
name: 'L2'
name: "L2"
coeff: 0.0005
# data loader for train and eval
......@@ -79,7 +79,7 @@ DataLoader:
- ResizeImage:
size: [128, 256]
return_numpy: False
interpolation: 'bilinear'
interpolation: "bilinear"
backend: "pil"
- RandFlipImage:
flip_code: 1
......@@ -111,7 +111,7 @@ DataLoader:
- ResizeImage:
size: [128, 256]
return_numpy: False
interpolation: 'bilinear'
interpolation: "bilinear"
backend: "pil"
- ToTensor:
- Normalize:
......@@ -136,7 +136,7 @@ DataLoader:
- ResizeImage:
size: [128, 256]
return_numpy: False
interpolation: 'bilinear'
interpolation: "bilinear"
backend: "pil"
- ToTensor:
- Normalize:
......
......@@ -76,7 +76,7 @@ Optimizer:
by_epoch: True
last_epoch: 0
regularizer:
name: 'L2'
name: "L2"
coeff: 0.0005
# data loader for train and eval
......@@ -91,7 +91,7 @@ DataLoader:
- ResizeImage:
size: [128, 256]
return_numpy: False
interpolation: 'bilinear'
interpolation: "bilinear"
backend: "pil"
- RandFlipImage:
flip_code: 1
......@@ -129,7 +129,7 @@ DataLoader:
- ResizeImage:
size: [128, 256]
return_numpy: False
interpolation: 'bilinear'
interpolation: "bilinear"
backend: "pil"
- ToTensor:
- Normalize:
......@@ -154,7 +154,7 @@ DataLoader:
- ResizeImage:
size: [128, 256]
return_numpy: False
interpolation: 'bilinear'
interpolation: "bilinear"
backend: "pil"
- ToTensor:
- Normalize:
......
......@@ -82,7 +82,7 @@ Optimizer:
by_epoch: True
last_epoch: 0
regularizer:
name: 'L2'
name: "L2"
coeff: 0.0005
- SGD:
scope: CenterLoss
......@@ -102,7 +102,7 @@ DataLoader:
- ResizeImage:
size: [128, 256]
return_numpy: False
interpolation: 'bilinear'
interpolation: "bilinear"
backend: "pil"
- RandFlipImage:
flip_code: 1
......@@ -140,7 +140,7 @@ DataLoader:
- ResizeImage:
size: [128, 256]
return_numpy: False
interpolation: 'bilinear'
interpolation: "bilinear"
backend: "pil"
- ToTensor:
- Normalize:
......@@ -165,7 +165,7 @@ DataLoader:
- ResizeImage:
size: [128, 256]
return_numpy: False
interpolation: 'bilinear'
interpolation: "bilinear"
backend: "pil"
- ToTensor:
- Normalize:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册