From 806ec9a1171bf13d1c8426a50cccce72880fcbf9 Mon Sep 17 00:00:00 2001 From: Felix Date: Mon, 31 May 2021 13:01:38 +0800 Subject: [PATCH] Add files via upload --- .../DistributedRandomIdentitySampler.py | 76 +++++++++++++++++++ 1 file changed, 76 insertions(+) create mode 100644 ppcls/data/samplers/DistributedRandomIdentitySampler.py diff --git a/ppcls/data/samplers/DistributedRandomIdentitySampler.py b/ppcls/data/samplers/DistributedRandomIdentitySampler.py new file mode 100644 index 00000000..1c84b746 --- /dev/null +++ b/ppcls/data/samplers/DistributedRandomIdentitySampler.py @@ -0,0 +1,76 @@ +from __future__ import absolute_import +from __future__ import division +from collections import defaultdict +import numpy as np +import copy +import random +from paddle.io import DistributedBatchSampler, Sampler + + +class DistributedRandomIdentitySampler(DistributedBatchSampler): + """ + Randomly sample N identities, then for each identity, + randomly sample K instances, therefore batch size is N*K. + Args: + - data_source (list): list of (img_path, pid, camid). + - num_instances (int): number of instances per identity in a batch. + - batch_size (int): number of examples in a batch. + """ + + def __init__(self, dataset, batch_size, num_instances, drop_last, **args): + self.dataset = dataset + self.batch_size = batch_size + self.num_instances = num_instances + self.drop_last = drop_last + self.num_pids_per_batch = self.batch_size // self.num_instances + self.index_dic = defaultdict(list) + for index, pid in enumerate(self.dataset.labels): + self.index_dic[pid].append(index) + self.pids = list(self.index_dic.keys()) + # estimate number of examples in an epoch + self.length = 0 + for pid in self.pids: + idxs = self.index_dic[pid] + num = len(idxs) + if num < self.num_instances: + num = self.num_instances + self.length += num - num % self.num_instances + + def __iter__(self): + batch_idxs_dict = defaultdict(list) + for pid in self.pids: + idxs = copy.deepcopy(self.index_dic[pid]) + if len(idxs) < self.num_instances: + idxs = np.random.choice( + idxs, size=self.num_instances, replace=True) + random.shuffle(idxs) + batch_idxs = [] + for idx in idxs: + batch_idxs.append(idx) + if len(batch_idxs) == self.num_instances: + batch_idxs_dict[pid].append(batch_idxs) + batch_idxs = [] + avai_pids = copy.deepcopy(self.pids) + final_idxs = [] + while len(avai_pids) >= self.num_pids_per_batch: + selected_pids = random.sample(avai_pids, self.num_pids_per_batch) + for pid in selected_pids: + batch_idxs = batch_idxs_dict[pid].pop(0) + final_idxs.extend(batch_idxs) + if len(batch_idxs_dict[pid]) == 0: + avai_pids.remove(pid) + _sample_iter = iter(final_idxs) + batch_indices = [] + for idx in _sample_iter: + batch_indices.append(idx) + if len(batch_indices) == self.batch_size: + yield batch_indices + batch_indices = [] + if not self.drop_last and len(batch_indices) > 0: + yield batch_indices + + def __len__(self): + if self.drop_last: + return self.length // self.batch_size + else: + return (self.length + self.batch_size - 1) // self.batch_size -- GitLab