Unverified · Commit a5a1c192 · Authored by Wei Shengyu · Committed by GitHub

Merge pull request #1851 from zhangxinyu-xyz/ISE_ReID

Release the inference code of ISE (ReID-CVPR2022)
# ISE
---
## Catalogue
- [1. Introduction](#1)
- [2. Performance on Market1501 and MSMT17](#2)
- [3. Test](#3)
- [4. Reference](#4)
<a name='1'></a>
## 1. Introduction
ISE (Implicit Sample Extension) is a simple, efficient, and effective learning algorithm for unsupervised person re-identification. ISE generates what we call support samples around the cluster boundaries. The sample generation process relies on two key mechanisms: a progressive linear interpolation strategy and a label-preserving loss function. The generated support samples provide complementary information that helps resolve the "sub and mixed" clustering errors. ISE outperforms other unsupervised methods on the Market1501 and MSMT17 datasets.
> [**Implicit Sample Extension for Unsupervised Person Re-Identification**](https://arxiv.org/abs/2204.06892v1)<br>
> Xinyu Zhang, Dongdong Li, Zhigang Wang, Jian Wang, Errui Ding, Javen Qinfeng Shi, Zhaoxiang Zhang, Jingdong Wang<br>
> CVPR2022
![image](../../images/ISE_ReID/ISE_pipeline.png)
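To make the interpolation idea concrete, below is a minimal, illustrative sketch of progressive linear interpolation: a support sample is placed on the line between a cluster centroid and an embedded sample, with the interpolation coefficient growing as training progresses. The function name, the linear schedule, and the coefficient range are assumptions for illustration only, not the authors' exact implementation.
```python
import numpy as np

def support_sample(feat, centroid, epoch, total_epochs, lam_max=0.5):
    """Illustrative sketch of progressive linear interpolation (PIL).

    `feat` and `centroid` are L2-normalized embeddings. The coefficient
    `lam` grows linearly with the epoch (an assumed schedule), so generated
    support samples start near the centroid and gradually move toward the
    real sample, i.e. toward the cluster boundary.
    """
    lam = lam_max * (epoch + 1) / total_epochs          # progressive coefficient
    sample = (1.0 - lam) * centroid + lam * feat        # linear interpolation
    return sample / (np.linalg.norm(sample) + 1e-12)    # keep it L2-normalized
```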
<a name='2'></a>
## 2. Performance on Market1501 and MSMT17
Main results on Market1501 (M) and MSMT17 (MS), reported as mAP (Rank-1). "PIL" denotes the progressive linear interpolation strategy; "LP" denotes the label-preserving loss function.
| Methods | Market1501 mAP (Rank-1) | Download | MSMT17 mAP (Rank-1) | Download |
| --- | --- | --- | --- | --- |
| Baseline | 82.5 (92.5) | - | 30.1 (58.6) | - |
| ISE (+PIL) | 83.9 (93.9) | - | 33.5 (63.9) | - |
| ISE (+LP) | 83.6 (92.7) | - | 31.4 (59.9) | - |
| ISE (Ours) (+PIL+LP) | **84.7 (94.0)** | [ISE_M](https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/ISE_M_model.pdparams) | **35.0 (64.7)** | [ISE_MS](https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/ISE_MS_model.pdparams) |
<a name="3"></a>
## 3. Test
The training code is coming soon; for now we release the evaluation code together with the pretrained models.
**Test:** run the following command to evaluate the model.
```
python tools/eval.py -c ./ppcls/configs/Person/ResNet50_UReID_infer.yaml
```
**Steps:**
1. Download the pretrained model and put it into ```./pd_model_trace/ISE/``` (a quick sanity check on the downloaded weights is sketched below).
2. Set the dataset name in ```./ppcls/configs/Person/ResNet50_UReID_infer.yaml```.
3. Run the evaluation command above.
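Optionally, you can verify that the downloaded weights load before running the evaluation. The path below assumes the Market1501 model from the table above keeps its original file name; adjust it for the MSMT17 model.
```python
import paddle

# Path from step 1; for the MSMT17 model use ISE_MS_model.pdparams instead.
state_dict = paddle.load("./pd_model_trace/ISE/ISE_M_model.pdparams")
print("loaded {} parameter tensors".format(len(state_dict)))
```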
<a name="4"></a>
## 4. Reference
If you find ISE useful in your research, please consider citing our paper:
```
@inproceedings{zhang2022Implicit,
  title={Implicit Sample Extension for Unsupervised Person Re-Identification},
  author={Zhang, Xinyu and Li, Dongdong and Wang, Zhigang and Wang, Jian and Ding, Errui and Shi, Javen Qinfeng and Zhang, Zhaoxiang and Wang, Jingdong},
  booktitle={IEEE Conference on Computer Vision and Pattern Recognition (CVPR)},
  year={2022}
}
```
# ISE
---
## Catalogue
- [1. Introduction](#1)
- [2. Performance on Market1501 and MSMT17](#2)
- [3. Test](#3)
- [4. Reference](#4)
<a name='1'></a>
## 1. Introduction
ISE (Implicit Sample Extension) is a simple, efficient, and effective learning algorithm for unsupervised person re-identification. ISE generates support samples around the cluster boundaries. The sample generation process relies on two key mechanisms: a progressive linear interpolation strategy and a label-preserving loss function. The generated support samples provide complementary information that helps resolve the "sub and mixed" clustering errors. ISE outperforms other unsupervised methods on the Market1501 and MSMT17 datasets.
> [**Implicit Sample Extension for Unsupervised Person Re-Identification**](https://arxiv.org/abs/2204.06892v1)<br>
> Xinyu Zhang, Dongdong Li, Zhigang Wang, Jian Wang, Errui Ding, Javen Qinfeng Shi, Zhaoxiang Zhang, Jingdong Wang<br>
> CVPR2022
![image](../../images/ISE_ReID/ISE_pipeline.png)
<a name='2'></a>
## 2. Performance on Market1501 and MSMT17
Main results on Market1501 and MSMT17, reported as mAP (Rank-1). "PIL" denotes the progressive linear interpolation strategy; "LP" denotes the label-preserving loss function.
| Methods | Market1501 mAP (Rank-1) | Download | MSMT17 mAP (Rank-1) | Download |
| --- | --- | --- | --- | --- |
| Baseline | 82.5 (92.5) | - | 30.1 (58.6) | - |
| ISE (+PIL) | 83.9 (93.9) | - | 33.5 (63.9) | - |
| ISE (+LP) | 83.6 (92.7) | - | 31.4 (59.9) | - |
| ISE (Ours) (+PIL+LP) | **84.7 (94.0)** | [ISE_M](https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/ISE_M_model.pdparams) | **35.0 (64.7)** | [ISE_MS](https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/ISE_MS_model.pdparams) |
<a name="3"></a>
## 3. Test
The training code is coming soon; for now we release the evaluation code together with the pretrained models.
**Test:** run the following command to evaluate the model.
```
python tools/eval.py -c ./ppcls/configs/Person/ResNet50_UReID_infer.yaml
```
**Steps:**
1. Download the pretrained model and put it into ```./pd_model_trace/ISE/```.
2. Set the dataset name in ```./ppcls/configs/Person/ResNet50_UReID_infer.yaml```.
3. Run the evaluation command above.
<a name="4"></a>
## 4. Reference
If you find ISE useful in your research, please consider citing our paper:
```
@inproceedings{zhang2022Implicit,
  title={Implicit Sample Extension for Unsupervised Person Re-Identification},
  author={Zhang, Xinyu and Li, Dongdong and Wang, Zhigang and Wang, Jian and Ding, Errui and Shi, Javen Qinfeng and Zhang, Zhaoxiang and Wang, Jingdong},
  booktitle={IEEE Conference on Computer Vision and Pattern Recognition (CVPR)},
  year={2022}
}
```
@@ -18,13 +18,15 @@ from .circlemargin import CircleMargin
 from .fc import FC
 from .vehicle_neck import VehicleNeck
 from paddle.nn import Tanh
+from .bnneck import BNNeck

 __all__ = ['build_gear']


 def build_gear(config):
     support_dict = [
-        'ArcMargin', 'CosMargin', 'CircleMargin', 'FC', 'VehicleNeck', 'Tanh'
+        'ArcMargin', 'CosMargin', 'CircleMargin', 'FC', 'VehicleNeck', 'Tanh',
+        'BNNeck'
     ]
     module_name = config.pop('name')
     assert module_name in support_dict, Exception(
......
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import, division, print_function

import paddle
import paddle.nn as nn


class BNNeck(nn.Layer):
    """Batch-norm neck: flattens the pooled backbone feature and applies
    BatchNorm1D with a frozen zero bias, as commonly used in ReID models."""

    def __init__(self, num_features):
        super().__init__()
        weight_attr = paddle.ParamAttr(
            initializer=paddle.nn.initializer.Constant(value=1.0))
        bias_attr = paddle.ParamAttr(
            initializer=paddle.nn.initializer.Constant(value=0.0),
            trainable=False)  # bias is fixed at 0
        self.feat_bn = nn.BatchNorm1D(
            num_features,
            momentum=0.9,
            epsilon=1e-05,
            weight_attr=weight_attr,
            bias_attr=bias_attr)
        self.flatten = nn.Flatten()

    def forward(self, x):
        x = self.flatten(x)
        x = self.feat_bn(x)
        return x
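For reference, a minimal sketch of exercising this neck on its own (it assumes PaddleClas is importable and this file lives at `ppcls/arch/gears/bnneck.py`; the shapes are illustrative):
```python
import paddle
from ppcls.arch.gears.bnneck import BNNeck  # the module defined above

neck = BNNeck(num_features=2048)
neck.eval()                              # use running statistics, as in evaluation
feats = paddle.randn([8, 2048, 1, 1])    # e.g. globally pooled backbone output
out = neck(feats)                        # flattened, then batch-normalized
print(out.shape)                         # [8, 2048]
```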
# global configs
Global:
  checkpoints: null
  pretrained_model: null
  # pretrained_model: "./pd_model_trace/ISE/ISE_M_model" # pretrained ISE model for Market1501
  # pretrained_model: "./pd_model_trace/ISE/ISE_MS_model" # pretrained ISE model for MSMT17
  output_dir: "./output/"
  device: "gpu"
  save_interval: 1
  eval_during_train: True
  eval_interval: 1
  epochs: 120
  print_batch_step: 10
  use_visualdl: False
  # used for static mode and model export
  image_shape: [3, 128, 256]
  save_inference_dir: "./inference"
  eval_mode: "retrieval"

# model architecture
Arch:
  name: "RecModel"
  infer_output_key: "features"
  infer_add_softmax: False
  Backbone:
    name: "ResNet50_last_stage_stride1"
    pretrained: True
  BackboneStopLayer:
    name: "avg_pool"
  Neck:
    name: "BNNeck"
    num_features: 2048
  Head:
    name: "FC"
    embedding_size: 2048
    class_num: 751

# loss function config for training/eval process
Loss:
  Train:
    - CELoss:
        weight: 1.0
    - SupConLoss:
        weight: 1.0
        views: 2
  Eval:
    - CELoss:
        weight: 1.0

Optimizer:
  name: Momentum
  momentum: 0.9
  lr:
    name: Cosine
    learning_rate: 0.04
  regularizer:
    name: 'L2'
    coeff: 0.0005

# data loader for train and eval
DataLoader:
  Train:
    dataset:
      name: "Market1501" # ["Market1501", "MSMT17"]
      image_root: "./dataset"
      cls_label_path: "bounding_box_train"
      transform_ops:
        - ResizeImage:
            size: [128, 256]
            interpolation: 'bicubic'
            backend: 'pil'
        - RandFlipImage:
            flip_code: 1
        - Pad:
            padding: 10
            fill: 0
        - RandomCrop:
            size: [128, 256]
            pad_if_needed: False
        - NormalizeImage:
            mean: [0.485, 0.456, 0.406]
            std: [0.229, 0.224, 0.225]
            order: ''
        - RandomErasing:
            EPSILON: 0.5
            sl: 0.02
            sh: 0.4
            r1: 0.3
            mean: [0.485, 0.456, 0.406]
    sampler:
      name: PKSampler
      batch_size: 16
      sample_per_id: 4
      drop_last: True
      shuffle: True
    loader:
      num_workers: 6
      use_shared_memory: True

  Eval:
    Query:
      dataset:
        name: "Market1501" # ["Market1501", "MSMT17"]
        image_root: "./dataset"
        cls_label_path: "query"
        transform_ops:
          - ResizeImage:
              size: [128, 256]
              interpolation: 'bicubic'
              backend: 'pil'
          - NormalizeImage:
              mean: [0.485, 0.456, 0.406]
              std: [0.229, 0.224, 0.225]
              order: ''
      sampler:
        name: DistributedBatchSampler
        batch_size: 64
        drop_last: False
        shuffle: False
      loader:
        num_workers: 6
        use_shared_memory: True

    Gallery:
      dataset:
        name: "Market1501" # ["Market1501", "MSMT17"]
        image_root: "./dataset"
        cls_label_path: "bounding_box_test"
        transform_ops:
          - ResizeImage:
              size: [128, 256]
              interpolation: 'bicubic'
              backend: 'pil'
          - NormalizeImage:
              mean: [0.485, 0.456, 0.406]
              std: [0.229, 0.224, 0.225]
              order: ''
      sampler:
        name: DistributedBatchSampler
        batch_size: 64
        drop_last: False
        shuffle: False
      loader:
        num_workers: 6
        use_shared_memory: True

Metric:
  Eval:
    - Recallk:
        topk: [1, 5]
    - mAP: {}
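To tie this config to the gear code above: the `Neck` block is simply a keyword dictionary for `build_gear`. A small sketch (it assumes PaddleClas and PyYAML are installed and the config is saved at the path shown below):
```python
import yaml
from ppcls.arch.gears import build_gear

with open("ppcls/configs/Person/ResNet50_UReID_infer.yaml") as f:
    cfg = yaml.safe_load(f)

neck_cfg = dict(cfg["Arch"]["Neck"])  # {'name': 'BNNeck', 'num_features': 2048}
neck = build_gear(neck_cfg)           # build_gear pops 'name' and instantiates BNNeck
print(type(neck).__name__)            # BNNeck
```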
@@ -28,6 +28,7 @@ from ppcls.data.dataloader.vehicle_dataset import CompCars, VeriWild
 from ppcls.data.dataloader.logo_dataset import LogoDataset
 from ppcls.data.dataloader.icartoon_dataset import ICartoonDataset
 from ppcls.data.dataloader.mix_dataset import MixDataset
+from ppcls.data.dataloader.person_dataset import Market1501, MSMT17

 # sampler
 from ppcls.data.dataloader.DistributedRandomIdentitySampler import DistributedRandomIdentitySampler
......
@@ -7,3 +7,4 @@ from ppcls.data.dataloader.icartoon_dataset import ICartoonDataset
 from ppcls.data.dataloader.mix_dataset import MixDataset
 from ppcls.data.dataloader.mix_sampler import MixSampler
 from ppcls.data.dataloader.pk_sampler import PKSampler
+from ppcls.data.dataloader.person_dataset import Market1501, MSMT17
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import numpy as np
import paddle
from paddle.io import Dataset
import os
import cv2
from ppcls.data import preprocess
from ppcls.data.preprocess import transform
from ppcls.utils import logger
from .common_dataset import create_operators
import os.path as osp
import glob
import re
from PIL import Image
class Market1501(Dataset):
    """
    Market1501
    Reference:
        Zheng et al. Scalable Person Re-identification: A Benchmark. ICCV 2015.
    URL: http://www.liangzheng.org/Project/project_reid.html

    Dataset statistics:
    # identities: 1501 (+1 for background)
    # images: 12936 (train) + 3368 (query) + 15913 (gallery)
    """
    _dataset_dir = 'market1501/Market-1501-v15.09.15'

    def __init__(self, image_root, cls_label_path, transform_ops=None):
        self._img_root = image_root
        self._cls_path = cls_label_path  # the sub folder in the dataset
        self._dataset_dir = osp.join(image_root, self._dataset_dir,
                                     self._cls_path)
        self._check_before_run()
        self._transform_ops = create_operators(
            transform_ops) if transform_ops else None
        self._dtype = paddle.get_default_dtype()
        self._load_anno(relabel=True if 'train' in self._cls_path else False)

    def _check_before_run(self):
        """Check if the file is available before going deeper"""
        if not osp.exists(self._dataset_dir):
            raise RuntimeError("'{}' is not available".format(
                self._dataset_dir))

    def _load_anno(self, relabel=False):
        img_paths = glob.glob(osp.join(self._dataset_dir, '*.jpg'))
        pattern = re.compile(r'([-\d]+)_c(\d)')

        self.images = []
        self.labels = []
        self.cameras = []
        pid_container = set()

        # first pass: collect person ids to build a contiguous label map
        for img_path in sorted(img_paths):
            pid, _ = map(int, pattern.search(img_path).groups())
            if pid == -1:
                continue  # junk images are just ignored
            pid_container.add(pid)
        pid2label = {pid: label for label, pid in enumerate(pid_container)}

        # second pass: parse person id and camera id for every image
        for img_path in sorted(img_paths):
            pid, camid = map(int, pattern.search(img_path).groups())
            if pid == -1:
                continue  # junk images are just ignored
            assert 0 <= pid <= 1501  # pid == 0 means background
            assert 1 <= camid <= 6
            camid -= 1  # index starts from 0
            if relabel:
                pid = pid2label[pid]
            self.images.append(img_path)
            self.labels.append(pid)
            self.cameras.append(camid)

        self.num_pids, self.num_imgs, self.num_cams = get_imagedata_info(
            self.images, self.labels, self.cameras, subfolder=self._cls_path)

    def __getitem__(self, idx):
        try:
            img = Image.open(self.images[idx]).convert('RGB')
            img = np.array(img, dtype="float32").astype(np.uint8)
            if self._transform_ops:
                img = transform(img, self._transform_ops)
            img = img.transpose((2, 0, 1))
            return (img, self.labels[idx], self.cameras[idx])
        except Exception as ex:
            logger.error("Exception occurred when parsing file: {} with msg: {}".
                         format(self.images[idx], ex))
            rnd_idx = np.random.randint(self.__len__())
            return self.__getitem__(rnd_idx)

    def __len__(self):
        return len(self.images)

    @property
    def class_num(self):
        return len(set(self.labels))
class MSMT17(Dataset):
    """
    MSMT17
    Reference:
        Wei et al. Person Transfer GAN to Bridge Domain Gap for Person Re-Identification. CVPR 2018.
    URL: http://www.pkuvmc.com/publications/msmt17.html

    Dataset statistics:
    # identities: 4101
    # images: 32621 (train) + 11659 (query) + 82161 (gallery)
    # cameras: 15
    """
    _dataset_dir = 'msmt17/MSMT17_V1'

    def __init__(self, image_root, cls_label_path, transform_ops=None):
        self._img_root = image_root
        self._cls_path = cls_label_path  # the sub folder in the dataset
        self._dataset_dir = osp.join(image_root, self._dataset_dir,
                                     self._cls_path)
        self._check_before_run()
        self._transform_ops = create_operators(
            transform_ops) if transform_ops else None
        self._dtype = paddle.get_default_dtype()
        self._load_anno(relabel=True if 'train' in self._cls_path else False)

    def _check_before_run(self):
        """Check if the file is available before going deeper"""
        if not osp.exists(self._dataset_dir):
            raise RuntimeError("'{}' is not available".format(
                self._dataset_dir))

    def _load_anno(self, relabel=False):
        img_paths = glob.glob(osp.join(self._dataset_dir, '*.jpg'))
        pattern = re.compile(r'([-\d]+)_c(\d+)')

        self.images = []
        self.labels = []
        self.cameras = []
        pid_container = set()

        # first pass: collect person ids to build a contiguous label map
        for img_path in img_paths:
            pid, _ = map(int, pattern.search(img_path).groups())
            if pid == -1:
                continue  # junk images are just ignored
            pid_container.add(pid)
        pid2label = {pid: label for label, pid in enumerate(pid_container)}

        # second pass: parse person id and camera id for every image
        for img_path in img_paths:
            pid, camid = map(int, pattern.search(img_path).groups())
            if pid == -1:
                continue  # junk images are just ignored
            assert 1 <= camid <= 15
            camid -= 1  # index starts from 0
            if relabel:
                pid = pid2label[pid]
            self.images.append(img_path)
            self.labels.append(pid)
            self.cameras.append(camid)

        self.num_pids, self.num_imgs, self.num_cams = get_imagedata_info(
            self.images, self.labels, self.cameras, subfolder=self._cls_path)

    def __getitem__(self, idx):
        try:
            img = Image.open(self.images[idx]).convert('RGB')
            img = np.array(img, dtype="float32").astype(np.uint8)
            if self._transform_ops:
                img = transform(img, self._transform_ops)
            img = img.transpose((2, 0, 1))
            return (img, self.labels[idx], self.cameras[idx])
        except Exception as ex:
            logger.error("Exception occurred when parsing file: {} with msg: {}".
                         format(self.images[idx], ex))
            rnd_idx = np.random.randint(self.__len__())
            return self.__getitem__(rnd_idx)

    def __len__(self):
        return len(self.images)

    @property
    def class_num(self):
        return len(set(self.labels))
def get_imagedata_info(data, labels, cameras, subfolder='train'):
    pids, cams = [], []
    for _, pid, camid in zip(data, labels, cameras):
        pids += [pid]
        cams += [camid]
    pids = set(pids)
    cams = set(cams)
    num_pids = len(pids)
    num_cams = len(cams)
    num_imgs = len(data)
    print("Dataset statistics:")
    print("  ----------------------------------------")
    print("  subset   | # ids | # images | # cameras")
    print("  ----------------------------------------")
    print("  {}  | {:5d} | {:8d} | {:9d}".format(subfolder, num_pids,
                                                 num_imgs, num_cams))
    print("  ----------------------------------------")
    return num_pids, num_imgs, num_cams
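As a usage note, these datasets can also be instantiated directly, outside the config-driven pipeline. The sketch below assumes Market-1501 has been extracted under `./dataset` following the `_dataset_dir` convention above (the same layout the YAML config expects):
```python
from ppcls.data.dataloader.person_dataset import Market1501

# Expects ./dataset/market1501/Market-1501-v15.09.15/bounding_box_train/*.jpg
train_set = Market1501(image_root="./dataset",
                       cls_label_path="bounding_box_train")
img, pid, camid = train_set[0]   # CHW uint8 image, person id, camera index
print(len(train_set), train_set.class_num, img.shape)
```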