# pk_sampler.py
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import absolute_import
from __future__ import division
from collections import defaultdict
import numpy as np
from paddle.io import DistributedBatchSampler
from ppcls.utils import logger


class PKSampler(DistributedBatchSampler):
    """Batch sampler that draws P identities, then K instances per identity.

    Every batch is built by sampling ``P = batch_size // sample_per_id``
    distinct labels and then ``K = sample_per_id`` dataset indices for each of
    them, so a batch always holds ``P * K == batch_size`` indices.

    Args:
        dataset (Dataset): dataset exposing a ``labels`` attribute that yields
            the label of every sample in index order.
        batch_size (int): number of indices per batch. Must be a multiple of
            ``sample_per_id``.
        sample_per_id (int): number of instances (K) drawn for each sampled
            label.
        shuffle (bool, optional): forwarded to ``DistributedBatchSampler``.
            The sampling done by this class is always random. Defaults to True.
        drop_last (bool, optional): forwarded to ``DistributedBatchSampler``;
            when True, only batches holding exactly ``batch_size`` indices are
            yielded. Defaults to True.
        sample_method (str, optional): how labels are weighted when a batch is
            drawn. ``"id_avg_prob"`` gives every label the same probability;
            ``"sample_avg_prob"`` weights each label by its number of samples.
            Defaults to "sample_avg_prob".

    Raises:
        AssertionError: if ``batch_size`` is not a multiple of
            ``sample_per_id``, if the dataset has no ``labels`` attribute, or
            if there are too few labels to fill one batch.
        ValueError: if ``sample_method`` is not a supported method.
    """

    def __init__(self,
                 dataset,
                 batch_size,
                 sample_per_id,
                 shuffle=True,
                 drop_last=True,
                 sample_method="sample_avg_prob"):
        super().__init__(
            dataset, batch_size, shuffle=shuffle, drop_last=drop_last)
        assert batch_size % sample_per_id == 0, \
            f"PKSampler configs error, sample_per_id({sample_per_id}) must be a divisor of batch_size({batch_size})."
        assert hasattr(self.dataset,
                       "labels"), "Dataset must have labels attribute."
        self.sample_per_label = sample_per_id
        self.label_dict = defaultdict(list)
        self.sample_method = sample_method

        # Group dataset indices by label: label -> [idx, idx, ...].
        for idx, label in enumerate(self.dataset.labels):
            self.label_dict[label].append(idx)
        self.label_list = list(self.label_dict)

        # Labels are drawn without replacement, so one batch needs at least
        # batch_size // sample_per_id distinct labels.
        assert len(self.label_list) * self.sample_per_label >= self.batch_size, \
            f"PKSampler configs error, batch_size({self.batch_size}) should not be larger than " \
            f"#classes({len(self.label_list)}) * sample_per_id({self.sample_per_label})."

        num_labels = len(self.label_list)
        if self.sample_method == "id_avg_prob":
            self.prob_list = np.full(num_labels, 1.0 / num_labels)
        elif self.sample_method == "sample_avg_prob":
            counter = [len(self.label_dict[label_i])
                       for label_i in self.label_list]
            self.prob_list = np.array(counter) / sum(counter)
        else:
            msg = ("PKSampler only support id_avg_prob and sample_avg_prob sample method, "
                   "but receive {}.".format(self.sample_method))
            logger.error(msg)
            # Without a prob_list the sampler cannot work; fail here with a
            # clear error instead of an AttributeError further down.
            raise ValueError(msg)

        # np.random.choice requires probabilities that sum to 1; absorb any
        # floating-point drift into the last entry.
        diff = np.abs(self.prob_list.sum() - 1)
        if diff > 0.00000001:
            self.prob_list[-1] = 1 - self.prob_list[:-1].sum()
            if self.prob_list[-1] > 1 or self.prob_list[-1] < 0:
                logger.error("PKSampler prob list error")
            else:
                logger.info(
                    "PKSampler: sum of prob list not equal to 1, diff is {}, change the last prob".format(diff)
                )

    def __iter__(self):
        """Yield ``len(self)`` batches, each a list of dataset indices."""
        label_per_batch = self.batch_size // self.sample_per_label
        num_labels = len(self.label_list)
        for _ in range(len(self)):
            batch_index = []
            # Sample positions rather than the labels themselves so labels are
            # never coerced into a NumPy array (which would stringify
            # mixed-type labels and break the label_dict lookup).
            batch_label_pos = np.random.choice(
                num_labels,
                size=label_per_batch,
                replace=False,
                p=self.prob_list)
            for pos in batch_label_pos:
                label_i_indexes = self.label_dict[self.label_list[pos]]
                batch_index.extend(
                    np.random.choice(
                        label_i_indexes,
                        size=self.sample_per_label,
                        # Fall back to sampling with replacement only when
                        # the label has fewer than K instances.
                        replace=len(label_i_indexes) < self.sample_per_label))
            if not self.drop_last or len(batch_index) == self.batch_size:
                yield batch_index