diff --git a/ppcls/data/__init__.py b/ppcls/data/__init__.py
index 27f1cbaf03a6f1e9f947113d8db4552bae744f0c..e8d2a1fae196b7de5ebbea11d7cf2a62289d610c 100644
--- a/ppcls/data/__init__.py
+++ b/ppcls/data/__init__.py
@@ -14,8 +14,11 @@
 import inspect
 import copy
+import random
 import paddle
 import numpy as np
+import paddle.distributed as dist
+from functools import partial
 from paddle.io import DistributedBatchSampler, BatchSampler, DataLoader
 
 from ppcls.utils import logger
 
@@ -66,6 +69,22 @@ def create_operators(params, class_num=None):
     return ops
 
+def worker_init_fn(worker_id: int, num_workers: int, rank: int, seed: int):
+    """Callback invoked in each worker subprocess after seeding and before data loading.
+
+    Args:
+        worker_id (int): Worker id in [0, num_workers - 1].
+        num_workers (int): Number of subprocesses to use for data loading.
+        rank (int): Rank of the process in the distributed environment; a constant `0` in a non-distributed environment.
+        seed (int): Random seed.
+    """
+    # The seed of each worker equals
+    # num_workers * rank + worker_id + user_seed
+    worker_seed = num_workers * rank + worker_id + seed
+    np.random.seed(worker_seed)
+    random.seed(worker_seed)
+
+
 def build_dataloader(config, mode, device, use_dali=False, seed=None):
     assert mode in [
         'Train', 'Eval', 'Test', 'Gallery', 'Query', 'UnLabelTrain'
     ]
@@ -82,6 +101,7 @@ def build_dataloader(config, mode, device, use_dali=False, seed=None):
             seed=seed)
 
     class_num = config.get("class_num", None)
+    epochs = config.get("epochs", None)
     config_dataset = config[mode]['dataset']
     config_dataset = copy.deepcopy(config_dataset)
     dataset_name = config_dataset.pop('name')
@@ -103,6 +123,9 @@ def build_dataloader(config, mode, device, use_dali=False, seed=None):
         shuffle = config_sampler["shuffle"]
     else:
         sampler_name = config_sampler.pop("name")
+        sampler_argspec = inspect.getfullargspec(eval(sampler_name).__init__).args
+        if "total_epochs" in sampler_argspec:
+            config_sampler.update({"total_epochs": epochs})
         batch_sampler = eval(sampler_name)(dataset, **config_sampler)
         logger.debug("build batch_sampler({}) success...".format(
             batch_sampler))
@@ -131,6 +154,12 @@ def build_dataloader(config, mode, device, use_dali=False, seed=None):
     num_workers = config_loader["num_workers"]
     use_shared_memory = config_loader["use_shared_memory"]
 
+    init_fn = partial(
+        worker_init_fn,
+        num_workers=num_workers,
+        rank=dist.get_rank(),
+        seed=seed) if seed is not None else None
+
     if batch_sampler is None:
         data_loader = DataLoader(
             dataset=dataset,
@@ -141,7 +170,8 @@ def build_dataloader(config, mode, device, use_dali=False, seed=None):
             batch_size=batch_size,
             shuffle=shuffle,
             drop_last=drop_last,
-            collate_fn=batch_collate_fn)
+            collate_fn=batch_collate_fn,
+            worker_init_fn=init_fn)
     else:
         data_loader = DataLoader(
             dataset=dataset,
@@ -150,7 +180,8 @@ def build_dataloader(config, mode, device, use_dali=False, seed=None):
             return_list=True,
             use_shared_memory=use_shared_memory,
             batch_sampler=batch_sampler,
-            collate_fn=batch_collate_fn)
+            collate_fn=batch_collate_fn,
+            worker_init_fn=init_fn)
 
     logger.debug("build data_loader({}) success...".format(data_loader))
     return data_loader
diff --git a/ppcls/data/dataloader/pk_sampler.py b/ppcls/data/dataloader/pk_sampler.py
index aa69caa77282d4fb46524dfd65093fc6b281c9e7..cbbc0919f16a493fdf5342dba4b509d4996e4e9f 100644
--- a/ppcls/data/dataloader/pk_sampler.py
+++ b/ppcls/data/dataloader/pk_sampler.py
@@ -38,6 +38,7 @@ class PKSampler(DistributedBatchSampler):
         ratio(list): list of (ratio1, ratio2..) the duplication number for ids in id_list.
         drop_last (bool, optional): whether to discard the data at the end. Defaults to True.
         sample_method (str, optional): sample method when generating prob_list. Defaults to "sample_avg_prob".
+        total_epochs (int, optional): Total number of training epochs. Defaults to 0.
     """
 
     def __init__(self,
@@ -48,7 +49,8 @@ class PKSampler(DistributedBatchSampler):
                  drop_last=True,
                  id_list=None,
                  ratio=None,
-                 sample_method="sample_avg_prob"):
+                 sample_method="sample_avg_prob",
+                 total_epochs=0):
         super().__init__(
             dataset, batch_size, shuffle=shuffle, drop_last=drop_last)
         assert batch_size % sample_per_id == 0, \
@@ -58,6 +60,7 @@ class PKSampler(DistributedBatchSampler):
         self.sample_per_id = sample_per_id
         self.label_dict = defaultdict(list)
         self.sample_method = sample_method
+        self.total_epochs = total_epochs
         for idx, label in enumerate(self.dataset.labels):
             self.label_dict[label].append(idx)
         self.label_list = list(self.label_dict)
@@ -98,8 +101,9 @@ class PKSampler(DistributedBatchSampler):
     def __iter__(self):
         # shuffle manually, same as DistributedBatchSampler.__iter__
         if self.shuffle:
-            np.random.RandomState(self.epoch + dist.get_rank()).shuffle(
-                self.label_list)
+            rank = dist.get_rank()
+            np.random.RandomState(rank * self.total_epochs +
+                                  self.epoch).shuffle(self.label_list)
             self.epoch += 1
 
         label_per_batch = self.batch_size // self.sample_per_id
diff --git a/ppcls/engine/engine.py b/ppcls/engine/engine.py
index 08ad02e133b18dfbda7f1c1af5d7409fbad90638..3d255a3ab781442541b2f5df6d8af44a8bbe8079 100755
--- a/ppcls/engine/engine.py
+++ b/ppcls/engine/engine.py
@@ -119,6 +119,9 @@ class Engine(object):
         #TODO(gaotingquan): support rec
         class_num = config["Arch"].get("class_num", None)
         self.config["DataLoader"].update({"class_num": class_num})
+        self.config["DataLoader"].update({
+            "epochs": self.config["Global"]["epochs"]
+        })
 
         # build dataloader
         if self.mode == 'train':
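
A note on the seeding scheme above: `worker_init_fn` gives every DataLoader worker a distinct seed of the form `num_workers * rank + worker_id + seed`. Below is a minimal standalone sketch of why the seeds cannot collide; the 2-rank x 4-worker grid and the `worker_seed` helper are invented for illustration and are not part of the patch:

```python
def worker_seed(worker_id: int, num_workers: int, rank: int, seed: int) -> int:
    # Same formula as worker_init_fn: rank r's workers occupy the contiguous
    # block [r * num_workers, (r + 1) * num_workers), shifted by the user seed.
    return num_workers * rank + worker_id + seed


# Hypothetical setup: 2 ranks, 4 workers each, user seed 42.
seeds = [
    worker_seed(w, num_workers=4, rank=r, seed=42)
    for r in range(2)
    for w in range(4)
]
assert len(set(seeds)) == len(seeds)  # no two workers share a seed
print(seeds)  # [42, 43, 44, 45, 46, 47, 48, 49]
```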
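The `inspect.getfullargspec` probe in `build_dataloader` forwards `total_epochs` only to samplers whose `__init__` declares it, so samplers that never take the argument keep working unchanged. A sketch of that pattern, with a made-up `ToySampler` standing in for the real sampler classes:

```python
import inspect


class ToySampler:  # stand-in for the samplers in ppcls.data.dataloader
    def __init__(self, dataset, batch_size, total_epochs=0):
        self.total_epochs = total_epochs


config_sampler = {"batch_size": 8}
# Only inject total_epochs if the sampler's __init__ actually accepts it.
if "total_epochs" in inspect.getfullargspec(ToySampler.__init__).args:
    config_sampler.update({"total_epochs": 120})

sampler = ToySampler(dataset=None, **config_sampler)
print(sampler.total_epochs)  # 120
```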
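The `PKSampler` change replaces the per-epoch shuffle seed `epoch + rank` with `rank * total_epochs + epoch`. The old formula collides across ranks, e.g. `(rank=0, epoch=1)` and `(rank=1, epoch=0)` both seed with 1, so one rank replays another rank's shuffle one epoch later; the new formula is collision-free whenever `epoch < total_epochs`. A quick check with hypothetical numbers:

```python
total_epochs, num_ranks = 100, 4

# Old scheme: seed = epoch + rank; new scheme: seed = rank * total_epochs + epoch.
old_seeds = {e + r for r in range(num_ranks) for e in range(total_epochs)}
new_seeds = {r * total_epochs + e
             for r in range(num_ranks) for e in range(total_epochs)}

print(len(old_seeds))  # 103 distinct seeds for 400 (rank, epoch) pairs
print(len(new_seeds))  # 400 distinct seeds, one per (rank, epoch) pair
```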