From 9e97569953318a738c598cc682b4c0f145d8ebf4 Mon Sep 17 00:00:00 2001 From: weishengyu Date: Wed, 22 Sep 2021 14:29:33 +0800 Subject: [PATCH] update sample method --- ppcls/data/dataloader/pk_sampler.py | 32 ++++++++++++++++++++++------- 1 file changed, 25 insertions(+), 7 deletions(-) diff --git a/ppcls/data/dataloader/pk_sampler.py b/ppcls/data/dataloader/pk_sampler.py index 93762ad7..f78bdbd4 100644 --- a/ppcls/data/dataloader/pk_sampler.py +++ b/ppcls/data/dataloader/pk_sampler.py @@ -40,7 +40,8 @@ class PKSampler(DistributedBatchSampler): batch_size, sample_per_id, shuffle=True, - drop_last=True): + drop_last=True, + sample_method="sample_avg_prob"): super(PKSampler, self).__init__( dataset, batch_size, shuffle=shuffle, drop_last=drop_last) assert batch_size % sample_per_id == 0, \ @@ -49,16 +50,33 @@ class PKSampler(DistributedBatchSampler): "labels"), "Dataset must have labels attribute." self.sample_per_id = sample_per_id self.label_dict = defaultdict(list) - for idx, label in enumerate(self.dataset.labels): - self.label_dict[label].append(idx) - self.id_list = list(self.label_dict) + self.sample_method = sample_method + if self.sample_method == "id_avg_prob": + for idx, label in enumerate(self.dataset.labels): + self.label_dict[label].append(idx) + self.id_list = list(self.label_dict) + elif self.sample_method == "sample_avg_prob": + self.id_list = [] + for idx, label in enumerate(self.dataset.labels): + self.label_dict[label].append(idx) + else: + logger.error( + "PKSampler only support id_avg_prob and sample_avg_prob sample method, " + "but receive {}.".format(self.sample_method)) def __iter__(self): if self.shuffle: np.random.RandomState(self.epoch).shuffle(self.id_list) - id_list = self.id_list[self.local_rank * len(self):(self.local_rank + 1 - ) * len(self)] - id_per_batch = self.batch_size / self.sample_per_id + id_list = self.id_list[self.local_rank * len(self.id_list) // + self.nranks:(self.local_rank + 1) * len( + self.id_list) // self.nranks] + if self.sample_method == "id_avg_prob": + id_batch_num = len(id_list) * self.sample_per_id // self.batch_size + if id_batch_num < len(self): + id_list = id_list * (len(self) // id_batch_num + 1) + id_list = id_list[0:len(self)] + + id_per_batch = self.batch_size // self.sample_per_id for i in range(len(self)): batch_index = [] for label_id in id_list[i * id_per_batch:(i + 1) * id_per_batch]: -- GitLab