pk_sampler.py 5.8 KB
Newer Older
W
weishengyu 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import absolute_import
from __future__ import division
17

W
weishengyu 已提交
18
from collections import defaultdict
19

W
weishengyu 已提交
20
import numpy as np
H
HydrogenSulfate 已提交
21
import paddle.distributed as dist
W
weishengyu 已提交
22
from paddle.io import DistributedBatchSampler
H
HydrogenSulfate 已提交
23

W
weishengyu 已提交
24 25 26 27
from ppcls.utils import logger


class PKSampler(DistributedBatchSampler):
    """Batch sampler that draws P labels, then K instances of each label.

    Every batch contains ``P = batch_size // sample_per_id`` distinct labels,
    and for each chosen label ``K = sample_per_id`` sample indices, so the
    batch size is ``P * K`` (hence the name PKSampler). Labels are drawn
    without replacement according to ``prob_list``. Instances are drawn
    without replacement when the label has at least ``K`` samples, and with
    replacement otherwise.

    Args:
        dataset (Dataset): dataset exposing a ``labels`` attribute holding one
            label per sample.
        batch_size (int): number of sample indices per batch; must be a
            multiple of ``sample_per_id``.
        sample_per_id (int): number of instances drawn for each label within
            a batch.
        shuffle (bool, optional): whether to shuffle the label order (together
            with the label probabilities) at the start of each iteration.
            Defaults to True.
        drop_last (bool, optional): whether to discard a batch whose length
            differs from ``batch_size``. Defaults to True.
        id_list (list, optional): flat list ``[start_0, end_0, start_1, end_1,
            ...]`` of inclusive ranges over label positions (order of first
            appearance in ``dataset.labels``). The probability of a label
            whose position falls in range ``j`` is multiplied by ``ratio[j]``;
            only the first matching range applies. Defaults to None.
        ratio (list, optional): per-range probability multipliers, one per
            ``(start, end)`` pair in ``id_list``. Defaults to None.
        sample_method (str, optional): how initial label probabilities are
            computed: "id_avg_prob" (uniform over labels) or "sample_avg_prob"
            (proportional to each label's sample count). Defaults to
            "sample_avg_prob".
        total_epochs (int, optional): combined with the rank and current epoch
            to seed the per-iteration label shuffle. Defaults to 0.

    Raises:
        ValueError: if ``sample_method`` is not a supported value.
    """

    def __init__(self,
                 dataset,
                 batch_size,
                 sample_per_id,
                 shuffle=True,
                 drop_last=True,
                 id_list=None,
                 ratio=None,
                 sample_method="sample_avg_prob",
                 total_epochs=0):
        super().__init__(
            dataset, batch_size, shuffle=shuffle, drop_last=drop_last)
        assert batch_size % sample_per_id == 0, \
            f"PKSampler configs error, sample_per_id({sample_per_id}) must be a divisor of batch_size({batch_size})."
        assert hasattr(self.dataset,
                       "labels"), "Dataset must have labels attribute."
        self.sample_per_id = sample_per_id
        self.label_dict = defaultdict(list)
        self.sample_method = sample_method
        self.total_epochs = total_epochs

        # Group dataset indices by label; label_list keeps first-appearance
        # order and is index-aligned with prob_list from here on.
        for idx, label in enumerate(self.dataset.labels):
            self.label_dict[label].append(idx)
        self.label_list = list(self.label_dict)
        assert len(self.label_list) * self.sample_per_id >= self.batch_size, \
            f"batch size({self.batch_size}) should not be bigger than #classes({len(self.label_list)})*sample_per_id({self.sample_per_id})"

        self.prob_list = self._init_prob_list()
        if id_list and ratio:
            self._apply_id_ratio(id_list, ratio)

        # Guard against floating-point drift: np.random.choice requires the
        # probabilities to sum to 1, so absorb any residue into the last one.
        diff = np.abs(sum(self.prob_list) - 1)
        if diff > 0.00000001:
            self.prob_list[-1] = 1 - sum(self.prob_list[:-1])
            if self.prob_list[-1] > 1 or self.prob_list[-1] < 0:
                logger.error("PKSampler prob list error")
            else:
                logger.info(
                    "PKSampler: sum of prob list not equal to 1, diff is {}, change the last prob".
                    format(diff))

    def _init_prob_list(self):
        """Return initial label probabilities aligned with ``label_list``.

        Raises:
            ValueError: if ``self.sample_method`` is not supported. Previously
                this only logged and left ``prob_list`` unset, causing an
                unrelated AttributeError later.
        """
        num_labels = len(self.label_list)
        if self.sample_method == "id_avg_prob":
            return np.array([1 / num_labels] * num_labels)
        if self.sample_method == "sample_avg_prob":
            counter = [len(self.label_dict[label]) for label in self.label_list]
            return np.array(counter) / sum(counter)
        msg = ("PKSampler only support id_avg_prob and sample_avg_prob sample method, "
               "but receive {}.".format(self.sample_method))
        logger.error(msg)
        raise ValueError(msg)

    def _apply_id_ratio(self, id_list, ratio):
        """Scale probabilities of labels in ``id_list`` ranges, then renormalise.

        ``id_list`` holds inclusive ``(start, end)`` pairs over label positions;
        the first range containing a position decides its multiplier.
        """
        assert len(id_list) % 2 == 0 and len(id_list) == len(ratio) * 2
        ranges = list(zip(id_list[0::2], id_list[1::2], ratio))
        for i in range(len(self.prob_list)):
            for start, end, scale in ranges:
                if start <= i <= end:
                    self.prob_list[i] = self.prob_list[i] * scale
                    break
        self.prob_list = self.prob_list / sum(self.prob_list)

    def __iter__(self):
        """Yield batches as lists of dataset indices (P labels x K instances)."""
        # shuffle manually, same as DistributedBatchSampler.__iter__
        if self.shuffle:
            rank = dist.get_rank()
            order = np.random.RandomState(
                rank * self.total_epochs + self.epoch).permutation(
                    len(self.label_list))
            # Permute labels and their probabilities together so that
            # prob_list[i] keeps describing label_list[i]; shuffling
            # label_list alone would give each label another label's weight.
            self.label_list[:] = [self.label_list[i] for i in order]
            self.prob_list = self.prob_list[order]
            self.epoch += 1

        label_per_batch = self.batch_size // self.sample_per_id
        num_labels = len(self.label_list)
        for _ in range(len(self)):
            batch_index = []
            # Draw label positions instead of the labels themselves so labels
            # of any hashable type (e.g. tuples) survive the numpy round trip.
            chosen_positions = np.random.choice(
                num_labels,
                size=label_per_batch,
                replace=False,
                p=self.prob_list)
            for pos in chosen_positions:
                label_indexes = self.label_dict[self.label_list[pos]]
                # A label with fewer than sample_per_id instances is sampled
                # with replacement so it still contributes a full K entries.
                need_replace = self.sample_per_id > len(label_indexes)
                batch_index.extend(
                    np.random.choice(
                        label_indexes,
                        size=self.sample_per_id,
                        replace=need_replace))
            if not self.drop_last or len(batch_index) == self.batch_size:
                yield batch_index