From 6cc6540ca5bcde78ee34592b0059bd81fc86189a Mon Sep 17 00:00:00 2001
From: HydrogenSulfate <490868991@qq.com>
Date: Thu, 24 Nov 2022 00:28:59 +0800
Subject: [PATCH] add different seed for workers and replicas

---
 ppcls/data/__init__.py              | 35 +++++++++++++++++++++++++++--
 ppcls/data/dataloader/pk_sampler.py | 10 ++++++---
 ppcls/engine/engine.py              |  3 +++
 3 files changed, 43 insertions(+), 5 deletions(-)

diff --git a/ppcls/data/__init__.py b/ppcls/data/__init__.py
index 27f1cbaf..e8d2a1fa 100644
--- a/ppcls/data/__init__.py
+++ b/ppcls/data/__init__.py
@@ -14,8 +14,11 @@
 
 import inspect
 import copy
+import random
 import paddle
 import numpy as np
+import paddle.distributed as dist
+from functools import partial
 from paddle.io import DistributedBatchSampler, BatchSampler, DataLoader
 
 from ppcls.utils import logger
@@ -66,6 +69,22 @@ def create_operators(params, class_num=None):
     return ops
 
 
+def worker_init_fn(worker_id: int, num_workers: int, rank: int, seed: int):
+    """callback function on each worker subprocess after seeding and before data loading.
+
+    Args:
+        worker_id (int): Worker id in [0, num_workers - 1]
+        num_workers (int): Number of subprocesses to use for data loading.
+        rank (int): Rank of process in distributed environment. If in non-distributed environment, it is a constant number `0`.
+        seed (int): Random seed
+    """
+    # The seed of each worker equals to
+    # num_worker * rank + worker_id + user_seed
+    worker_seed = num_workers * rank + worker_id + seed
+    np.random.seed(worker_seed)
+    random.seed(worker_seed)
+
+
 def build_dataloader(config, mode, device, use_dali=False, seed=None):
     assert mode in [
         'Train', 'Eval', 'Test', 'Gallery', 'Query', 'UnLabelTrain'
@@ -82,6 +101,7 @@ def build_dataloader(config, mode, device, use_dali=False, seed=None):
             seed=seed)
 
     class_num = config.get("class_num", None)
+    epochs = config.get("epochs", None)
     config_dataset = config[mode]['dataset']
     config_dataset = copy.deepcopy(config_dataset)
     dataset_name = config_dataset.pop('name')
@@ -103,6 +123,9 @@ def build_dataloader(config, mode, device, use_dali=False, seed=None):
             shuffle = config_sampler["shuffle"]
     else:
         sampler_name = config_sampler.pop("name")
+        sampler_argspec = inspect.getargspec(eval(sampler_name).__init__).args
+        if "total_epochs" in sampler_argspec:
+            config_sampler.update({"total_epochs": epochs})
         batch_sampler = eval(sampler_name)(dataset, **config_sampler)
         logger.debug("build batch_sampler({}) success...".format(
             batch_sampler))
@@ -131,6 +154,12 @@ def build_dataloader(config, mode, device, use_dali=False, seed=None):
     num_workers = config_loader["num_workers"]
     use_shared_memory = config_loader["use_shared_memory"]
 
+    init_fn = partial(
+        worker_init_fn,
+        num_workers=num_workers,
+        rank=dist.get_rank(),
+        seed=seed) if seed is not None else None
+
     if batch_sampler is None:
         data_loader = DataLoader(
             dataset=dataset,
@@ -141,7 +170,8 @@ def build_dataloader(config, mode, device, use_dali=False, seed=None):
             batch_size=batch_size,
             shuffle=shuffle,
             drop_last=drop_last,
-            collate_fn=batch_collate_fn)
+            collate_fn=batch_collate_fn,
+            worker_init_fn=init_fn)
     else:
         data_loader = DataLoader(
             dataset=dataset,
@@ -150,7 +180,8 @@ def build_dataloader(config, mode, device, use_dali=False, seed=None):
             return_list=True,
             use_shared_memory=use_shared_memory,
             batch_sampler=batch_sampler,
-            collate_fn=batch_collate_fn)
+            collate_fn=batch_collate_fn,
+            worker_init_fn=init_fn)
 
     logger.debug("build data_loader({}) success...".format(data_loader))
     return data_loader
diff --git a/ppcls/data/dataloader/pk_sampler.py b/ppcls/data/dataloader/pk_sampler.py
index aa69caa7..cbbc0919 100644
--- a/ppcls/data/dataloader/pk_sampler.py
+++ b/ppcls/data/dataloader/pk_sampler.py
@@ -38,6 +38,7 @@ class PKSampler(DistributedBatchSampler):
         ratio(list): list of (ratio1, ratio2..) the duplication number for ids in id_list.
         drop_last (bool, optional): whether to discard the data at the end. Defaults to True.
         sample_method (str, optional): sample method when generating prob_list. Defaults to "sample_avg_prob".
+        total_epochs (int, optional): total epochs. Defaults to 0.
     """
 
     def __init__(self,
@@ -48,7 +49,8 @@ class PKSampler(DistributedBatchSampler):
                  drop_last=True,
                  id_list=None,
                  ratio=None,
-                 sample_method="sample_avg_prob"):
+                 sample_method="sample_avg_prob",
+                 total_epochs=0):
         super().__init__(
             dataset, batch_size, shuffle=shuffle, drop_last=drop_last)
         assert batch_size % sample_per_id == 0, \
@@ -58,6 +60,7 @@ class PKSampler(DistributedBatchSampler):
         self.sample_per_id = sample_per_id
         self.label_dict = defaultdict(list)
         self.sample_method = sample_method
+        self.total_epochs = total_epochs
         for idx, label in enumerate(self.dataset.labels):
             self.label_dict[label].append(idx)
         self.label_list = list(self.label_dict)
@@ -98,8 +101,9 @@ class PKSampler(DistributedBatchSampler):
     def __iter__(self):
         # shuffle manually, same as DistributedBatchSampler.__iter__
         if self.shuffle:
-            np.random.RandomState(self.epoch + dist.get_rank()).shuffle(
-                self.label_list)
+            rank = dist.get_rank()
+            np.random.RandomState(rank * self.total_epochs +
+                                  self.epoch).shuffle(self.label_list)
             self.epoch += 1
 
         label_per_batch = self.batch_size // self.sample_per_id
diff --git a/ppcls/engine/engine.py b/ppcls/engine/engine.py
index 08ad02e1..3d255a3a 100755
--- a/ppcls/engine/engine.py
+++ b/ppcls/engine/engine.py
@@ -119,6 +119,9 @@ class Engine(object):
         #TODO(gaotingquan): support rec
         class_num = config["Arch"].get("class_num", None)
         self.config["DataLoader"].update({"class_num": class_num})
+        self.config["DataLoader"].update({
+            "epochs": self.config["Global"]["epochs"]
+        })
 
         # build dataloader
         if self.mode == 'train':
--
GitLab
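Note on the DataLoader worker seeding added in ppcls/data/__init__.py: each worker derives its seed as num_workers * rank + worker_id + seed, so every (rank, worker) pair gets its own NumPy/random stream. Below is a minimal, self-contained sketch of that scheme outside Paddle; the helper name and the use of functools.partial mirror the patch, while the rank, worker, and seed values are made up purely for illustration:

    import random
    from functools import partial

    import numpy as np

    def worker_init_fn(worker_id: int, num_workers: int, rank: int, seed: int):
        # Same formula as the patch: num_workers * rank + worker_id + seed,
        # which is unique for every (rank, worker_id) pair.
        worker_seed = num_workers * rank + worker_id + seed
        np.random.seed(worker_seed)
        random.seed(worker_seed)

    # Example: rank 1 of a 2-GPU job, 4 loader workers, user seed 42.
    init_fn = partial(worker_init_fn, num_workers=4, rank=1, seed=42)
    init_fn(0)  # worker 0 on rank 1 seeds with 4 * 1 + 0 + 42 = 46
    init_fn(3)  # worker 3 on rank 1 seeds with 4 * 1 + 3 + 42 = 49

As in the patch, such a partially applied init_fn is handed to paddle.io.DataLoader through its worker_init_fn argument, and is left as None when no global seed is configured.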
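Note on the PKSampler shuffle seed: the old expression self.epoch + dist.get_rank() lets different (rank, epoch) pairs collide (rank 0 at epoch 1 and rank 1 at epoch 0 both seed with 1), so label orders can repeat across replicas during training; the new expression rank * total_epochs + epoch gives each rank the disjoint seed block [rank * total_epochs, rank * total_epochs + total_epochs - 1]. A small standalone check of that seed layout, with toy values that are not from the patch:

    import numpy as np

    total_epochs = 3           # e.g. Global.epochs propagated by the engine
    labels = list(range(6))    # stand-in for PKSampler.label_list

    for rank in range(2):      # two replicas
        for epoch in range(total_epochs):
            seed = rank * total_epochs + epoch   # unique per (rank, epoch)
            order = list(labels)
            np.random.RandomState(seed).shuffle(order)
            print(f"rank={rank} epoch={epoch} seed={seed} order={order}")

Printing the table makes the property easy to see: the six (rank, epoch) pairs use the distinct seeds 0 through 5, so no two shuffles of the label list are forced to coincide.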