pk_sampler.py 5.6 KB
Newer Older
W
weishengyu 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import absolute_import
from __future__ import division
17

W
weishengyu 已提交
18
from collections import defaultdict
19

W
weishengyu 已提交
20
import numpy as np
21
import paddle.distributed as dist
W
weishengyu 已提交
22
from paddle.io import DistributedBatchSampler
23

W
weishengyu 已提交
24 25 26 27
from ppcls.utils import logger


class PKSampler(DistributedBatchSampler):
    """Batch sampler that draws P labels and then K instances per label.

    For every batch, P = batch_size // sample_per_id distinct labels are drawn
    (without replacement) according to ``self.prob_list``. Then
    K = sample_per_id sample indices are drawn for each chosen label, so each
    batch holds P * K indices.

    Args:
        dataset (Dataset): dataset exposing a ``labels`` attribute with one
            label per sample; the position in ``labels`` is used as the
            sample index yielded by this sampler.
        batch_size (int): number of indices per batch; must be a multiple of
            ``sample_per_id``.
        sample_per_id (int): number of instances (K) drawn for each chosen
            label. Labels with fewer samples than this are drawn with
            replacement.
        shuffle (bool, optional): passed to ``DistributedBatchSampler``; not
            consulted by ``__iter__`` below. Defaults to True.
        drop_last (bool, optional): whether to discard a batch whose size is
            not ``batch_size``. Defaults to True.
        id_list (list, optional): flat list ``[start_0, end_0, start_1, end_1,
            ...]`` of inclusive ranges whose sampling probability is scaled by
            the matching entry of ``ratio``. Defaults to None.
        ratio (list, optional): ``[ratio_0, ratio_1, ...]`` multipliers, one
            per range in ``id_list``. Defaults to None.
        sample_method (str, optional): how label probabilities are built:
            "id_avg_prob" gives every label the same probability,
            "sample_avg_prob" weights each label by its number of samples.
            Defaults to "sample_avg_prob".

    Raises:
        ValueError: if ``sample_method`` is not a supported value.
    """

    def __init__(self,
                 dataset,
                 batch_size,
                 sample_per_id,
                 shuffle=True,
                 drop_last=True,
                 id_list=None,
                 ratio=None,
                 sample_method="sample_avg_prob"):
        super().__init__(
            dataset, batch_size, shuffle=shuffle, drop_last=drop_last)
        assert batch_size % sample_per_id == 0, \
            f"PKSampler configs error, sample_per_id({sample_per_id}) must be a divisor of batch_size({batch_size})."
        assert hasattr(self.dataset,
                       "labels"), "Dataset must have labels attribute."
        self.sample_per_id = sample_per_id
        self.sample_method = sample_method

        # Group sample indices by label; label_list keeps first-appearance order.
        self.label_dict = defaultdict(list)
        for idx, label in enumerate(self.dataset.labels):
            self.label_dict[label].append(idx)
        self.label_list = list(self.label_dict)
        assert len(self.label_list) * self.sample_per_id >= self.batch_size, \
            f"batch size({self.batch_size}) should not be bigger than #classes({len(self.label_list)})*sample_per_id({self.sample_per_id})"

        # prob_list[i] is the probability of drawing label_list[i].
        num_labels = len(self.label_list)
        if self.sample_method == "id_avg_prob":
            self.prob_list = np.full(num_labels, 1 / num_labels)
        elif self.sample_method == "sample_avg_prob":
            counts = np.array(
                [len(self.label_dict[label]) for label in self.label_list])
            self.prob_list = counts / counts.sum()
        else:
            msg = ("PKSampler only support id_avg_prob and sample_avg_prob "
                   "sample method, but receive {}.".format(self.sample_method))
            logger.error(msg)
            # Without a valid method prob_list would be undefined and the
            # constructor would fail later with an unrelated AttributeError.
            raise ValueError(msg)

        if id_list and ratio:
            assert len(id_list) % 2 == 0 and len(id_list) == len(ratio) * 2
            ranges = list(zip(id_list[0::2], id_list[1::2], ratio))
            # NOTE(review): the bounds are compared with the label's position
            # in label_list (first-appearance order), which equals the label
            # value only when labels first appear as 0, 1, 2, ... in order —
            # confirm this is the intended meaning of "id".
            for i in range(num_labels):
                for start, end, scale in ranges:
                    if start <= i <= end:
                        self.prob_list[i] = self.prob_list[i] * scale
                        break  # only the first matching range applies
            self.prob_list = self.prob_list / self.prob_list.sum()

        # Absorb floating-point drift into the last entry so the
        # probabilities sum to 1.
        diff = np.abs(self.prob_list.sum() - 1)
        if diff > 0.00000001:
            self.prob_list[-1] = 1 - self.prob_list[:-1].sum()
            if self.prob_list[-1] > 1 or self.prob_list[-1] < 0:
                logger.error("PKSampler prob list error")
            else:
                logger.info(
                    "PKSampler: sum of prob list not equal to 1, diff is {}, change the last prob".
                    format(diff))

    def __iter__(self):
        """Yield ``len(self)`` batches of sample indices (P labels x K each)."""
        if self.nranks > 1:
            # Reorder labels with a rank-specific seed so ranks that share the
            # same global numpy RNG state still draw different labels. The
            # same permutation is applied to prob_list so every label keeps
            # its own sampling probability (shuffling label_list alone would
            # pair labels with other labels' probabilities).
            cur_rank = dist.get_rank()
            perm = np.random.RandomState(42 + cur_rank).permutation(
                len(self.label_list))
            self.label_list[:] = [self.label_list[i] for i in perm]
            self.prob_list[:] = self.prob_list[perm]

        label_per_batch = self.batch_size // self.sample_per_id
        for _ in range(len(self)):
            batch_label_list = np.random.choice(
                self.label_list,
                size=label_per_batch,
                replace=False,
                p=self.prob_list)
            batch_index = []
            for label_i in batch_label_list:
                label_i_indexes = self.label_dict[label_i]
                # Labels with fewer than sample_per_id samples must be drawn
                # with replacement to fill K slots.
                batch_index.extend(
                    np.random.choice(
                        label_i_indexes,
                        size=self.sample_per_id,
                        replace=self.sample_per_id > len(label_i_indexes)))
            if not self.drop_last or len(batch_index) == self.batch_size:
                yield batch_index