From 8595d189454ad153bc62e8e6e6dba8273959b11c Mon Sep 17 00:00:00 2001 From: weishengyu Date: Sun, 26 Sep 2021 16:07:05 +0800 Subject: [PATCH] update format --- ppcls/data/dataloader/mix_dataset.py | 2 +- ppcls/data/dataloader/mix_sampler.py | 14 +++++++++----- ppcls/data/dataloader/pk_sampler.py | 2 +- 3 files changed, 11 insertions(+), 7 deletions(-) diff --git a/ppcls/data/dataloader/mix_dataset.py b/ppcls/data/dataloader/mix_dataset.py index c928ca26..cbf4b402 100644 --- a/ppcls/data/dataloader/mix_dataset.py +++ b/ppcls/data/dataloader/mix_dataset.py @@ -23,7 +23,7 @@ from .. import dataloader class MixDataset(Dataset): def __init__(self, datasets_config): - super(MixDataset, self).__init__() + super().__init__() self.dataset_list = [] start_idx = 0 end_idx = 0 diff --git a/ppcls/data/dataloader/mix_sampler.py b/ppcls/data/dataloader/mix_sampler.py index d4206be7..2df3109c 100644 --- a/ppcls/data/dataloader/mix_sampler.py +++ b/ppcls/data/dataloader/mix_sampler.py @@ -24,8 +24,9 @@ from ppcls.data import dataloader class MixSampler(DistributedBatchSampler): def __init__(self, dataset, batch_size, sample_configs, iter_per_epoch): - super(MixSampler, self).__init__(dataset, batch_size) - assert isinstance(dataset, MixDataset), "MixSampler only support MixDataset" + super().__init__(dataset, batch_size) + assert isinstance(dataset, + MixDataset), "MixSampler only support MixDataset" self.sampler_list = [] self.batch_size = batch_size self.start_list = [] @@ -45,9 +46,11 @@ class MixSampler(DistributedBatchSampler): assert batch_size_i <= len(dataset_list[i][2]) config_i["batch_size"] = batch_size_i if sample_method == "DistributedBatchSampler": - sampler_i = DistributedBatchSampler(dataset_list[i][2], **config_i) + sampler_i = DistributedBatchSampler(dataset_list[i][2], + **config_i) else: - sampler_i = getattr(dataloader, sample_method)(dataset_list[i][2], **config_i) + sampler_i = getattr(dataloader, sample_method)( + dataset_list[i][2], **config_i) self.sampler_list.append(sampler_i) self.iter_list.append(iter(sampler_i)) self.length += len(dataset_list[i][2]) * ratio_i @@ -62,7 +65,8 @@ class MixSampler(DistributedBatchSampler): iter_i = iter(self.sampler_list[i]) self.iter_list[i] = iter_i batch_i = next(iter_i, None) - assert batch_i is not None, "dataset {} return None".format(i) + assert batch_i is not None, "dataset {} return None".format( + i) batch += [idx + self.start_list[i] for idx in batch_i] if len(batch) == self.batch_size: self.iter_counter += 1 diff --git a/ppcls/data/dataloader/pk_sampler.py b/ppcls/data/dataloader/pk_sampler.py index ef65be94..7f718a33 100644 --- a/ppcls/data/dataloader/pk_sampler.py +++ b/ppcls/data/dataloader/pk_sampler.py @@ -42,7 +42,7 @@ class PKSampler(DistributedBatchSampler): shuffle=True, drop_last=True, sample_method="sample_avg_prob"): - super(PKSampler, self).__init__( + super().__init__( dataset, batch_size, shuffle=shuffle, drop_last=drop_last) assert batch_size % sample_per_id == 0, \ "PKSampler configs error, Sample_per_id must be a divisor of batch_size." -- GitLab