From 0bdb16af8b9a0f7e7f64ec77a8bd46c911446024 Mon Sep 17 00:00:00 2001 From: weishengyu Date: Fri, 17 Sep 2021 18:08:35 +0800 Subject: [PATCH] add pk_sampler --- ppcls/data/dataloader/pk_sampler.py | 79 +++++++++++++++++++++++++++++ 1 file changed, 79 insertions(+) create mode 100644 ppcls/data/dataloader/pk_sampler.py diff --git a/ppcls/data/dataloader/pk_sampler.py b/ppcls/data/dataloader/pk_sampler.py new file mode 100644 index 00000000..93762ad7 --- /dev/null +++ b/ppcls/data/dataloader/pk_sampler.py @@ -0,0 +1,79 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import absolute_import +from __future__ import division +from collections import defaultdict +import numpy as np +import random +from paddle.io import DistributedBatchSampler + +from ppcls.utils import logger + + +class PKSampler(DistributedBatchSampler): + """ + First, randomly sample P identities. + Then for each identity randomly sample K instances. + Therefore batch size is P*K, and the sampler called PKSampler. + Args: + dataset (paddle.io.Dataset): list of (img_path, pid, cam_id). + sample_per_id(int): number of instances per identity in a batch. + batch_size (int): number of examples in a batch. + shuffle(bool): whether to shuffle indices order before generating + batch indices. Default False. + """ + + def __init__(self, + dataset, + batch_size, + sample_per_id, + shuffle=True, + drop_last=True): + super(PKSampler, self).__init__( + dataset, batch_size, shuffle=shuffle, drop_last=drop_last) + assert batch_size % sample_per_id == 0, \ + "PKSampler configs error, Sample_per_id must be a divisor of batch_size." + assert hasattr(self.dataset, + "labels"), "Dataset must have labels attribute." + self.sample_per_id = sample_per_id + self.label_dict = defaultdict(list) + for idx, label in enumerate(self.dataset.labels): + self.label_dict[label].append(idx) + self.id_list = list(self.label_dict) + + def __iter__(self): + if self.shuffle: + np.random.RandomState(self.epoch).shuffle(self.id_list) + id_list = self.id_list[self.local_rank * len(self):(self.local_rank + 1 + ) * len(self)] + id_per_batch = self.batch_size / self.sample_per_id + for i in range(len(self)): + batch_index = [] + for label_id in id_list[i * id_per_batch:(i + 1) * id_per_batch]: + idx_label_list = self.label_dict[label_id] + if self.sample_per_id <= len(idx_label_list): + batch_index.extend( + np.random.choice( + idx_label_list, + size=self.sample_per_id, + replace=False)) + else: + batch_index.extend( + np.random.choice( + idx_label_list, + size=self.sample_per_id, + replace=True)) + if not self.drop_last or len(batch_index) == self.batch_size: + yield batch_index -- GitLab