# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import absolute_import
from __future__ import division
from collections import defaultdict
import numpy as np
from paddle.io import DistributedBatchSampler
from ppcls.utils import logger


class PKSampler(DistributedBatchSampler):
    """Batch sampler that draws P labels per batch and K samples per label.

    Each batch is built by first drawing P = batch_size // sample_per_id
    distinct labels (without replacement, weighted by ``self.prob_list``),
    then drawing K = sample_per_id dataset indices for every chosen label.
    Indices of a label are drawn without replacement when the label has at
    least K samples, otherwise with replacement. The batch size is therefore
    P * K.

    Args:
        dataset (Dataset): dataset exposing a ``labels`` attribute that holds
            one (hashable) label per sample.
        batch_size (int): number of indices per batch; must be a multiple of
            ``sample_per_id``.
        sample_per_id (int): number of instances drawn for each sampled label.
        shuffle (bool, optional): forwarded to ``DistributedBatchSampler``.
            ``__iter__`` below does not read it; batches are always drawn at
            random. Defaults to True.
        drop_last (bool, optional): forwarded to ``DistributedBatchSampler``
            and used to skip incomplete batches. Defaults to True.
        id_list (list, optional): flat list ``[start_0, end_0, start_1,
            end_1, ...]`` of inclusive ranges of label *positions* (labels
            in order of first appearance in ``dataset.labels``) whose
            sampling weight is scaled. Defaults to None.
        ratio (list, optional): ``[ratio_0, ratio_1, ...]``, the factor
            applied to each range of ``id_list``; a position inside several
            ranges uses the first matching one. Defaults to None.
        sample_method (str, optional): how per-label probabilities are built:
            ``"id_avg_prob"`` (uniform over labels) or ``"sample_avg_prob"``
            (proportional to each label's sample count).
            Defaults to "sample_avg_prob".

    Raises:
        ValueError: if ``sample_method`` is not supported.
    """

    def __init__(self,
                 dataset,
                 batch_size,
                 sample_per_id,
                 shuffle=True,
                 drop_last=True,
                 id_list=None,
                 ratio=None,
                 sample_method="sample_avg_prob"):
        super().__init__(
            dataset, batch_size, shuffle=shuffle, drop_last=drop_last)
        assert batch_size % sample_per_id == 0, \
            f"PKSampler configs error, sample_per_id({sample_per_id}) must be a divisor of batch_size({batch_size})."
        assert hasattr(self.dataset,
                       "labels"), "Dataset must have labels attribute."
        self.sample_per_id = sample_per_id
        self.label_dict = defaultdict(list)
        self.sample_method = sample_method
        # Group dataset indices by label; label_list keeps first-appearance order.
        for idx, label in enumerate(self.dataset.labels):
            self.label_dict[label].append(idx)
        self.label_list = list(self.label_dict)
        assert len(self.label_list) * self.sample_per_id >= self.batch_size, \
            f"batch size({self.batch_size}) should not be bigger than #classes({len(self.label_list)})*sample_per_id({self.sample_per_id})"

        self.prob_list = self._build_prob_list()
        if id_list and ratio:
            self._apply_id_ratio(id_list, ratio)
        self._fix_prob_sum()

    def _build_prob_list(self):
        """Return the per-label sampling probabilities for ``sample_method``."""
        num_labels = len(self.label_list)
        if self.sample_method == "id_avg_prob":
            return np.array([1 / num_labels] * num_labels)
        if self.sample_method == "sample_avg_prob":
            counts = np.array(
                [len(self.label_dict[label]) for label in self.label_list])
            return counts / sum(counts)
        logger.error(
            "PKSampler only support id_avg_prob and sample_avg_prob sample method, "
            "but receive {}.".format(self.sample_method))
        # Without a prob list the sampler cannot work; fail here instead of
        # with an AttributeError further down.
        raise ValueError(
            "PKSampler only support id_avg_prob and sample_avg_prob sample method, "
            "but receive {}.".format(self.sample_method))

    def _apply_id_ratio(self, id_list, ratio):
        """Scale the weight of label positions in ``id_list`` ranges, then renormalize."""
        assert len(id_list) % 2 == 0 and len(id_list) == len(ratio) * 2
        positions = np.arange(len(self.prob_list))
        scale = np.ones(len(self.prob_list))
        matched = np.zeros(len(self.prob_list), dtype=bool)
        for start, end, factor in zip(id_list[0::2], id_list[1::2], ratio):
            # Only the first range that contains a position applies to it.
            hit = (positions >= start) & (positions <= end) & ~matched
            scale[hit] = factor
            matched |= hit
        self.prob_list = self.prob_list * scale
        self.prob_list = self.prob_list / sum(self.prob_list)

    def _fix_prob_sum(self):
        """Force ``prob_list`` to sum to 1 by adjusting its last entry if needed."""
        diff = np.abs(sum(self.prob_list) - 1)
        if diff > 0.00000001:
            self.prob_list[-1] = 1 - sum(self.prob_list[:-1])
            if self.prob_list[-1] > 1 or self.prob_list[-1] < 0:
                logger.error("PKSampler prob list error")
            else:
                logger.info(
                    "PKSampler: sum of prob list not equal to 1, diff is {}, change the last prob".
                    format(diff))

    def __iter__(self):
        label_per_batch = self.batch_size // self.sample_per_id
        num_labels = len(self.label_list)
        for _ in range(len(self)):
            batch_index = []
            # Sample label positions rather than the labels themselves: NumPy
            # would coerce the labels into an array, which fails for tuple
            # labels and stringifies mixed-type labels so the label_dict
            # lookup below misses. The random draws are the same either way.
            batch_positions = np.random.choice(
                num_labels,
                size=label_per_batch,
                replace=False,
                p=self.prob_list)
            for pos in batch_positions:
                label_i_indexes = self.label_dict[self.label_list[pos]]
                # Fall back to sampling with replacement only when the label
                # has fewer than sample_per_id instances.
                batch_index.extend(
                    np.random.choice(
                        label_i_indexes,
                        size=self.sample_per_id,
                        replace=self.sample_per_id > len(label_i_indexes)))
            if not self.drop_last or len(batch_index) == self.batch_size:
                yield batch_index