diff --git a/ppcls/data/dataloader/mix_dataset.py b/ppcls/data/dataloader/mix_dataset.py
index c928ca26df0278da378796a70ff5b4df8b658e14..cbf4b4028d27cf3ebfeab9dc89ac5414dbd4786e 100644
--- a/ppcls/data/dataloader/mix_dataset.py
+++ b/ppcls/data/dataloader/mix_dataset.py
@@ -23,7 +23,7 @@ from .. import dataloader
 
 class MixDataset(Dataset):
     def __init__(self, datasets_config):
-        super(MixDataset, self).__init__()
+        super().__init__()
         self.dataset_list = []
         start_idx = 0
         end_idx = 0
diff --git a/ppcls/data/dataloader/mix_sampler.py b/ppcls/data/dataloader/mix_sampler.py
index d4206be750d2f207f3442986159e07f722d8e6f3..2df3109cece3e6532ac54eb8f1d9e6498a1f33a7 100644
--- a/ppcls/data/dataloader/mix_sampler.py
+++ b/ppcls/data/dataloader/mix_sampler.py
@@ -24,8 +24,9 @@ from ppcls.data import dataloader
 
 class MixSampler(DistributedBatchSampler):
     def __init__(self, dataset, batch_size, sample_configs, iter_per_epoch):
-        super(MixSampler, self).__init__(dataset, batch_size)
-        assert isinstance(dataset, MixDataset), "MixSampler only support MixDataset"
+        super().__init__(dataset, batch_size)
+        assert isinstance(dataset,
+                          MixDataset), "MixSampler only support MixDataset"
         self.sampler_list = []
         self.batch_size = batch_size
         self.start_list = []
@@ -45,9 +46,11 @@ class MixSampler(DistributedBatchSampler):
             assert batch_size_i <= len(dataset_list[i][2])
             config_i["batch_size"] = batch_size_i
             if sample_method == "DistributedBatchSampler":
-                sampler_i = DistributedBatchSampler(dataset_list[i][2], **config_i)
+                sampler_i = DistributedBatchSampler(dataset_list[i][2],
+                                                    **config_i)
             else:
-                sampler_i = getattr(dataloader, sample_method)(dataset_list[i][2], **config_i)
+                sampler_i = getattr(dataloader, sample_method)(
+                    dataset_list[i][2], **config_i)
             self.sampler_list.append(sampler_i)
             self.iter_list.append(iter(sampler_i))
             self.length += len(dataset_list[i][2]) * ratio_i
@@ -62,7 +65,8 @@ class MixSampler(DistributedBatchSampler):
                     iter_i = iter(self.sampler_list[i])
                     self.iter_list[i] = iter_i
                     batch_i = next(iter_i, None)
-                    assert batch_i is not None, "dataset {} return None".format(i)
+                    assert batch_i is not None, "dataset {} return None".format(
+                        i)
                 batch += [idx + self.start_list[i] for idx in batch_i]
             if len(batch) == self.batch_size:
                 self.iter_counter += 1
diff --git a/ppcls/data/dataloader/pk_sampler.py b/ppcls/data/dataloader/pk_sampler.py
index ef65be940e2ac5ecb0f34f51b4ccfd42da7325c2..7f718a33350a63323e556185a126c13a07a4674b 100644
--- a/ppcls/data/dataloader/pk_sampler.py
+++ b/ppcls/data/dataloader/pk_sampler.py
@@ -42,7 +42,7 @@ class PKSampler(DistributedBatchSampler):
                  shuffle=True,
                  drop_last=True,
                  sample_method="sample_avg_prob"):
-        super(PKSampler, self).__init__(
+        super().__init__(
             dataset, batch_size, shuffle=shuffle, drop_last=drop_last)
         assert batch_size % sample_per_id == 0, \
             "PKSampler configs error, Sample_per_id must be a divisor of batch_size."
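
Aside from yapf-style line re-wrapping (which changes layout only, not behavior), the one substantive pattern across all three files is replacing the explicit two-argument `super(Cls, self).__init__(...)` call with the zero-argument `super().__init__(...)` form. Below is a minimal standalone sketch of why the two are interchangeable; the `Base`/`ChildOld`/`ChildNew` classes are hypothetical illustrations, not classes from this patch, and the zero-argument form requires Python 3.

```python
# Hypothetical classes for illustration only (not from the patch).
# In Python 3, zero-argument super() resolves the enclosing class and
# the current instance implicitly, so both subclasses below run
# Base.__init__ identically.
class Base:
    def __init__(self):
        self.initialized = True


class ChildOld(Base):
    def __init__(self):
        # Explicit form; also valid in Python 2.
        super(ChildOld, self).__init__()


class ChildNew(Base):
    def __init__(self):
        # Zero-argument form; Python 3 only, equivalent at runtime.
        super().__init__()


assert ChildOld().initialized and ChildNew().initialized
```

Since PaddleClas targets Python 3, the zero-argument form is the idiomatic choice and avoids repeating the class name at each call site.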