提交 1070d9be 编写于 作者: T tianyi1997 提交者: HydrogenSulfate

Create dataloader for MetaBIN

上级 adb99303
......@@ -32,10 +32,11 @@ from ppcls.data.dataloader.logo_dataset import LogoDataset
from ppcls.data.dataloader.icartoon_dataset import ICartoonDataset
from ppcls.data.dataloader.mix_dataset import MixDataset
from ppcls.data.dataloader.multi_scale_dataset import MultiScaleDataset
from ppcls.data.dataloader.person_dataset import Market1501, MSMT17
from ppcls.data.dataloader.person_dataset import Market1501, MSMT17, DukeMTMC
from ppcls.data.dataloader.face_dataset import FiveValidationDataset, AdaFaceDataset
from ppcls.data.dataloader.custom_label_dataset import CustomLabelDataset
from ppcls.data.dataloader.cifar import Cifar10, Cifar100
from ppcls.data.dataloader.metabin_sampler import DomainShuffleBatchSampler, NaiveIdentityBatchSampler
# sampler
from ppcls.data.dataloader.DistributedRandomIdentitySampler import DistributedRandomIdentitySampler
......
......@@ -9,7 +9,8 @@ from ppcls.data.dataloader.multi_scale_dataset import MultiScaleDataset
from ppcls.data.dataloader.mix_sampler import MixSampler
from ppcls.data.dataloader.multi_scale_sampler import MultiScaleSampler
from ppcls.data.dataloader.pk_sampler import PKSampler
from ppcls.data.dataloader.person_dataset import Market1501, MSMT17
from ppcls.data.dataloader.person_dataset import Market1501, MSMT17, DukeMTMC
from ppcls.data.dataloader.face_dataset import AdaFaceDataset, FiveValidationDataset
from ppcls.data.dataloader.custom_label_dataset import CustomLabelDataset
from ppcls.data.dataloader.cifar import Cifar10, Cifar100
from ppcls.data.dataloader.metabin_sampler import DomainShuffleBatchSampler, NaiveIdentityBatchSampler
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import copy
import itertools
from collections import defaultdict
import numpy as np
from paddle.io import Sampler, BatchSampler
class DomainShuffleSampler(Sampler):
    """Domain-shuffle identity sampler for MetaBIN.

    Yields a stream of sample indices such that every batch holds
    ``batch_size // num_instances`` identities evenly split across domains,
    with ``num_instances`` images per identity. When a domain runs out of
    identities, that domain's identity pool is "revived" (rebuilt) so the
    stream is effectively infinite.

    Code was heavily based on https://github.com/bismex/MetaBIN
    reference: https://arxiv.org/abs/2011.14670v2

    Args:
        dataset: dataset object exposing equal-length ``images``, ``labels``
            and ``cameras`` lists (e.g. ``DukeMTMC``).
        batch_size (int): total number of samples per batch.
        num_instances (int): number of samples per identity in a batch.
        camera_to_domain (bool): if True, treat each camera id as a domain id
            (single-source training); the pid key then encodes both.
    """

    def __init__(self,
                 dataset,
                 batch_size: int,
                 num_instances: int,
                 camera_to_domain=True):
        self.dataset = dataset
        self.batch_size = batch_size
        self.num_instances = num_instances
        self.num_pids_per_batch = batch_size // self.num_instances

        self.index_pid = defaultdict(list)
        self.pid_domain = defaultdict(list)
        self.pid_index = defaultdict(list)
        # data_source: [(img_path, pid, camera, domain), ...] (camera_to_domain = True)
        data_source = zip(dataset.images, dataset.labels, dataset.cameras,
                          dataset.cameras)
        for index, info in enumerate(data_source):
            domainid = info[3]
            if camera_to_domain:
                # Make pids unique per domain so the same person id seen by
                # two cameras counts as two identities.
                pid = 'p' + str(info[1]) + '_d' + str(domainid)
            else:
                pid = 'p' + str(info[1])
            self.index_pid[index] = pid
            self.pid_domain[pid] = domainid
            self.pid_index[pid].append(index)

        self.pids = list(self.pid_index.keys())
        self.domains = list(self.pid_domain.values())

        self.num_identities = len(self.pids)
        self.num_domains = len(set(self.domains))

        # Per-domain quota: each domain contributes an equal share of a batch.
        self.batch_size //= self.num_domains
        self.num_pids_per_batch //= self.num_domains

        val_pid_index = [len(x) for x in self.pid_index.values()]

        # Round each identity's image count up to a multiple of num_instances
        # (short identities are later padded by sampling with replacement).
        val_pid_index_upper = []
        for x in val_pid_index:
            v_remain = x % self.num_instances
            if v_remain == 0:
                val_pid_index_upper.append(x)
            else:
                val_pid_index_upper.append(x - v_remain + self.num_instances)

        cnt_domains = [0 for x in range(self.num_domains)]
        for val, index in zip(val_pid_index_upper, self.domains):
            cnt_domains[index] += val
        self.max_cnt_domains = max(cnt_domains)
        # Approximate epoch length, sized by the largest domain and truncated
        # to whole (per-domain) batches.
        self.total_images = self.num_domains * (
            self.max_cnt_domains -
            (self.max_cnt_domains % self.batch_size) - self.batch_size)

    def _get_epoch_indices(self):
        """Build one epoch's worth of indices, domain-balanced per batch."""

        def _get_batch_idxs(pids, pid_index, num_instances):
            # Map pid -> list of chunks, each chunk num_instances indices.
            batch_idxs_dict = defaultdict(list)
            for pid in pids:
                idxs = copy.deepcopy(pid_index[pid])
                if len(
                        idxs
                ) < self.num_instances:  # if idxs is smaller than num_instance, choice redundantly
                    idxs = np.random.choice(
                        idxs, size=self.num_instances, replace=True)
                elif (len(idxs) % self.num_instances) != 0:
                    # Pad to a multiple of num_instances without repetition.
                    idxs.extend(
                        np.random.choice(
                            idxs,
                            size=self.num_instances - len(idxs) %
                            self.num_instances,
                            replace=False))
                np.random.shuffle(idxs)
                batch_idxs = []
                for idx in idxs:
                    batch_idxs.append(int(idx))
                    if len(batch_idxs) == num_instances:
                        batch_idxs_dict[pid].append(batch_idxs)
                        batch_idxs = []
            return batch_idxs_dict

        batch_idxs_dict = _get_batch_idxs(self.pids, self.pid_index,
                                          self.num_instances)

        # batch_idxs_dict: dictionary, len(batch_idxs_dict) is len(pidx), each pidx, num_instance x k samples
        avai_pids = copy.deepcopy(self.pids)
        local_avai_pids = \
            [[pids for pids, idx in zip(avai_pids, self.domains) if idx == i]
             for i in list(set(self.domains))]
        local_avai_pids_save = copy.deepcopy(local_avai_pids)

        revive_idx = [False for i in range(self.num_domains)]
        final_idxs = []
        # Stop once every domain has been revived at least once and too few
        # identities remain for a full batch.
        while len(avai_pids) >= self.num_pids_per_batch and not all(
                revive_idx):
            for i in range(self.num_domains):
                selected_pids = np.random.choice(
                    local_avai_pids[i], self.num_pids_per_batch, replace=False)
                for pid in selected_pids:
                    batch_idxs = batch_idxs_dict[pid].pop(0)
                    final_idxs.extend(batch_idxs)
                    if len(batch_idxs_dict[pid]) == 0:
                        avai_pids.remove(pid)
                        local_avai_pids[i].remove(pid)
            for i in range(self.num_domains):
                if len(local_avai_pids[i]) < self.num_pids_per_batch:
                    # Domain i exhausted: rebuild its chunks and restore its
                    # identity pool from the saved snapshot.
                    batch_idxs_dict_new = _get_batch_idxs(
                        self.pids, self.pid_index, self.num_instances)
                    revive_idx[i] = True
                    cnt = 0
                    for pid, val in batch_idxs_dict_new.items():
                        if self.domains[cnt] == i:
                            batch_idxs_dict[pid] = copy.deepcopy(
                                batch_idxs_dict_new[pid])
                        cnt += 1
                    local_avai_pids[i] = copy.deepcopy(local_avai_pids_save[i])
                    avai_pids.extend(local_avai_pids_save[i])
                    avai_pids = list(set(avai_pids))
        return final_idxs

    def __iter__(self):
        yield from itertools.islice(self._infinite_indices(), 0, None, 1)

    def _infinite_indices(self):
        # Keep regenerating epoch indices forever; the consumer decides when
        # to stop (e.g. via the BatchSampler's length).
        while True:
            indices = self._get_epoch_indices()
            yield from indices

    def __len__(self):
        # Added for parity with NaiveIdentitySampler.__len__ so that
        # len(sampler) — and the length of a wrapping BatchSampler — works.
        return self.total_images
class DomainShuffleBatchSampler(BatchSampler):
    """Batch sampler that groups indices drawn from a ``DomainShuffleSampler``.

    Thin adapter: builds the domain-shuffling index sampler from the given
    dataset and hands it to ``paddle.io.BatchSampler`` for batching.
    """

    def __init__(self, dataset, batch_size, num_instances, camera_to_domain,
                 drop_last):
        domain_sampler = DomainShuffleSampler(
            dataset=dataset,
            batch_size=batch_size,
            num_instances=num_instances,
            camera_to_domain=camera_to_domain)
        super().__init__(
            sampler=domain_sampler,
            batch_size=batch_size,
            drop_last=drop_last)
class NaiveIdentitySampler(Sampler):
    """
    Randomly sample N identities, then for each identity,
    randomly sample K instances, therefore batch size is N*K.

    Args:
    - dataset: dataset object exposing equal-length ``images``, ``labels``
      and ``cameras`` lists.
    - batch_size (int): number of examples in a batch.
    - num_instances (int): number of instances per identity in a batch.

    Code was heavily based on https://github.com/bismex/MetaBIN
    reference: https://arxiv.org/abs/2011.14670v2
    """

    def __init__(self, dataset, batch_size, num_instances):
        self.dataset = dataset
        self.batch_size = batch_size
        self.num_instances = num_instances
        # N: number of distinct identities drawn per batch.
        self.num_pids_per_batch = batch_size // self.num_instances

        self.index_pid = defaultdict(list)   # sample index -> pid
        self.pid_cam = defaultdict(list)     # pid -> camera ids of its samples
        self.pid_index = defaultdict(list)   # pid -> sample indices

        # data_source: [(img_path, pid, camera, domain), ...] (camera_to_domain = True)
        data_source = zip(dataset.images, dataset.labels, dataset.cameras,
                          dataset.cameras)
        for index, info in enumerate(data_source):
            pid = info[1]
            camid = info[2]
            self.index_pid[index] = pid
            self.pid_cam[pid].append(camid)
            self.pid_index[pid].append(index)

        self.pids = list(self.pid_index.keys())
        self.num_identities = len(self.pids)

        # Round each identity's image count up to a multiple of num_instances
        # (short identities get padded by resampling in _get_epoch_indices).
        val_pid_index = [len(x) for x in self.pid_index.values()]

        val_pid_index_upper = []
        for x in val_pid_index:
            v_remain = x % self.num_instances
            if v_remain == 0:
                val_pid_index_upper.append(x)
            else:
                val_pid_index_upper.append(x - v_remain + self.num_instances)

        # Truncate to whole batches; approximate epoch length.
        total_images = sum(val_pid_index_upper)
        total_images = total_images - (total_images % self.batch_size
                                       ) - self.batch_size  # approximate
        self.total_images = total_images

    def _get_epoch_indices(self):
        """Build one epoch of indices: repeatedly pick N pids, emit K each."""
        # pid -> list of chunks, each chunk holding num_instances indices.
        batch_idxs_dict = defaultdict(list)

        for pid in self.pids:
            idxs = copy.deepcopy(
                self.pid_index[pid])  # whole index for each ID
            if len(
                    idxs
            ) < self.num_instances:  # if idxs is smaller than num_instance, choice redundantly
                idxs = np.random.choice(
                    idxs, size=self.num_instances, replace=True)
            elif (len(idxs) % self.num_instances) != 0:
                # Pad to a multiple of num_instances without repetition.
                idxs.extend(
                    np.random.choice(
                        idxs,
                        size=self.num_instances - len(idxs) %
                        self.num_instances,
                        replace=False))

            np.random.shuffle(idxs)
            batch_idxs = []
            for idx in idxs:
                batch_idxs.append(int(idx))
                if len(batch_idxs) == self.num_instances:
                    batch_idxs_dict[pid].append(batch_idxs)
                    batch_idxs = []

        # batch_idxs_dict: dictionary, len(batch_idxs_dict) is len(pidx), each pidx, num_instance x k samples
        avai_pids = copy.deepcopy(self.pids)
        final_idxs = []

        # Keep drawing while enough identities remain for a full batch;
        # identities are retired once all their chunks are consumed.
        while len(avai_pids) >= self.num_pids_per_batch:
            selected_pids = np.random.choice(
                avai_pids, self.num_pids_per_batch, replace=False)
            for pid in selected_pids:
                batch_idxs = batch_idxs_dict[pid].pop(0)
                final_idxs.extend(batch_idxs)
                if len(batch_idxs_dict[pid]) == 0: avai_pids.remove(pid)
        return final_idxs

    def __iter__(self):
        yield from itertools.islice(self._infinite_indices(), 0, None, 1)

    def _infinite_indices(self):
        # Regenerate epoch indices forever; the wrapping BatchSampler's
        # length bounds consumption.
        while True:
            indices = self._get_epoch_indices()
            yield from indices

    def __len__(self):
        return self.total_images
class NaiveIdentityBatchSampler(BatchSampler):
    """Batch sampler that groups indices drawn from a ``NaiveIdentitySampler``.

    Thin adapter: builds the identity-balanced index sampler from the given
    dataset and hands it to ``paddle.io.BatchSampler`` for batching.
    """

    def __init__(self, dataset, batch_size, num_instances, drop_last):
        identity_sampler = NaiveIdentitySampler(
            dataset=dataset,
            batch_size=batch_size,
            num_instances=num_instances)
        super().__init__(
            sampler=identity_sampler,
            batch_size=batch_size,
            drop_last=drop_last)
......@@ -204,6 +204,96 @@ class MSMT17(Dataset):
return len(set(self.labels))
class DukeMTMC(Dataset):
    """
    DukeMTMC-reID person re-identification dataset.

    Reference:
    Ristani et al. Performance Measures and a Data Set for Multi-Target, Multi-Camera Tracking. ECCVW 2016.
    Zheng et al. Unlabeled Samples Generated by GAN Improve the Person Re-identification Baseline in vitro. ICCV 2017.
    URL: https://github.com/layumi/DukeMTMC-reID_evaluation

    Dataset statistics:
    # identities: 1404 (train + query)
    # images: 16522 (train) + 2228 (query) + 17661 (gallery)
    # cameras: 8

    Args:
        image_root (str): root directory that contains the dataset folder.
        cls_label_path (str): sub folder inside the dataset directory
            (e.g. a train/query/gallery split).
        transform_ops: optional list of transform-op configs, built via
            ``create_operators``.
        backend (str): "cv2" converts images to uint8 ndarrays and returns
            CHW arrays; otherwise PIL images are passed to the transforms.
    """
    # Fixed sub-path of the dataset relative to ``image_root``.
    _dataset_dir = 'dukemtmc/DukeMTMC-reID'

    def __init__(self,
                 image_root,
                 cls_label_path,
                 transform_ops=None,
                 backend="cv2"):
        self._img_root = image_root
        self._cls_path = cls_label_path  # the sub folder in the dataset
        self._dataset_dir = osp.join(image_root, self._dataset_dir,
                                     self._cls_path)
        self._check_before_run()
        if transform_ops:
            self._transform_ops = create_operators(transform_ops)
        self.backend = backend
        self._dtype = paddle.get_default_dtype()
        # Only the training split gets contiguous relabeled pids; query and
        # gallery keep their original ids for evaluation.
        self._load_anno(relabel=True if 'train' in self._cls_path else False)

    def _check_before_run(self):
        """Check if the file is available before going deeper"""
        if not osp.exists(self._dataset_dir):
            raise RuntimeError("'{}' is not available".format(
                self._dataset_dir))

    def _load_anno(self, relabel=False):
        """Scan the split directory and fill images/labels/cameras lists.

        File names encode the annotation as ``<pid>_c<camid>...`` and are
        parsed with a regex; when ``relabel`` is True, pids are remapped to
        contiguous labels starting at 0.
        """
        img_paths = glob.glob(osp.join(self._dataset_dir, '*.jpg'))
        pattern = re.compile(r'([-\d]+)_c(\d+)')

        self.images = []
        self.labels = []
        self.cameras = []
        pid_container = set()

        # First pass: collect the distinct pids to build the relabel map.
        for img_path in img_paths:
            pid, _ = map(int, pattern.search(img_path).groups())
            pid_container.add(pid)
        pid2label = {pid: label for label, pid in enumerate(pid_container)}

        # Second pass: record path, (possibly relabeled) pid, 0-based camid.
        for img_path in img_paths:
            pid, camid = map(int, pattern.search(img_path).groups())
            assert 1 <= camid <= 8
            camid -= 1  # index starts from 0
            if relabel:
                pid = pid2label[pid]
            self.images.append(img_path)
            self.labels.append(pid)
            self.cameras.append(camid)

        self.num_pids, self.num_imgs, self.num_cams = get_imagedata_info(
            self.images, self.labels, self.cameras, subfolder=self._cls_path)

    def __getitem__(self, idx):
        """Return ``(image, pid, camid)``; on failure retry a random sample."""
        try:
            img = Image.open(self.images[idx]).convert('RGB')
            if self.backend == "cv2":
                # Convert PIL image to a uint8 ndarray for cv2-style
                # transforms (float32 round-trip kept as-is — NOTE(review):
                # presumably mirrors the sibling person datasets; confirm).
                img = np.array(img, dtype="float32").astype(np.uint8)
            if self._transform_ops:
                img = transform(img, self._transform_ops)
            if self.backend == "cv2":
                # HWC -> CHW for the cv2 backend.
                img = img.transpose((2, 0, 1))
            return (img, self.labels[idx], self.cameras[idx])
        except Exception as ex:
            logger.error("Exception occured when parse line: {} with msg: {}".
                         format(self.images[idx], ex))
            # Fall back to a random valid sample rather than crashing the
            # dataloader worker.
            rnd_idx = np.random.randint(self.__len__())
            return self.__getitem__(rnd_idx)

    def __len__(self):
        return len(self.images)

    @property
    def class_num(self):
        # Number of distinct person ids in this split.
        return len(set(self.labels))
def get_imagedata_info(data, labels, cameras, subfolder='train'):
pids, cams = [], []
for _, pid, camid in zip(data, labels, cameras):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册