diff --git a/ppcls/data/dataloader/pk_sampler.py b/ppcls/data/dataloader/pk_sampler.py index 7f718a33350a63323e556185a126c13a07a4674b..b3632bc72582243a8b041a4fd7b6b3fa1f0e7b88 100644 --- a/ppcls/data/dataloader/pk_sampler.py +++ b/ppcls/data/dataloader/pk_sampler.py @@ -68,18 +68,20 @@ class PKSampler(DistributedBatchSampler): logger.error( "PKSampler only support id_avg_prob and sample_avg_prob sample method, " "but receive {}.".format(self.sample_method)) - if sum(np.abs(self.prob_list - 1) > 0.00000001): + diff = np.abs(sum(self.prob_list) - 1) + if diff > 0.00000001: self.prob_list[-1] = 1 - sum(self.prob_list[:-1]) if self.prob_list[-1] > 1 or self.prob_list[-1] < 0: logger.error("PKSampler prob list error") else: logger.info( - "PKSampler: sum of prob list not equal to 1, change the last prob" + "PKSampler: sum of prob list not equal to 1, diff is {}, change the last prob".format(diff) ) def __iter__(self): label_per_batch = self.batch_size // self.sample_per_label if self.shuffle: + # It's not accurate literally, but it helps in some dataset. np.random.RandomState(self.epoch).shuffle(self.label_list) for i in range(len(self)): batch_index = []