Commit 6cc6540c authored by HydrogenSulfate

Add different random seeds for dataloader workers and distributed replicas

Parent b542416d
......@@ -14,8 +14,11 @@
import inspect
import copy
import random
import paddle
import numpy as np
import paddle.distributed as dist
from functools import partial
from paddle.io import DistributedBatchSampler, BatchSampler, DataLoader
from ppcls.utils import logger
......@@ -66,6 +69,22 @@ def create_operators(params, class_num=None):
return ops
def worker_init_fn(worker_id: int, num_workers: int, rank: int, seed: int):
"""callback function on each worker subprocess after seeding and before data loading.
Args:
worker_id (int): Worker id in [0, num_workers - 1]
num_workers (int): Number of subprocesses to use for data loading.
rank (int): Rank of the process in a distributed environment; a constant 0 in a non-distributed environment.
seed (int): Random seed
"""
# The seed of each worker is computed as
#   num_workers * rank + worker_id + seed
worker_seed = num_workers * rank + worker_id + seed
np.random.seed(worker_seed)
random.seed(worker_seed)
def build_dataloader(config, mode, device, use_dali=False, seed=None):
assert mode in [
'Train', 'Eval', 'Test', 'Gallery', 'Query', 'UnLabelTrain'
......@@ -82,6 +101,7 @@ def build_dataloader(config, mode, device, use_dali=False, seed=None):
seed=seed)
class_num = config.get("class_num", None)
epochs = config.get("epochs", None)
config_dataset = config[mode]['dataset']
config_dataset = copy.deepcopy(config_dataset)
dataset_name = config_dataset.pop('name')
......@@ -103,6 +123,9 @@ def build_dataloader(config, mode, device, use_dali=False, seed=None):
shuffle = config_sampler["shuffle"]
else:
sampler_name = config_sampler.pop("name")
sampler_argspec = inspect.getfullargspec(eval(sampler_name).__init__).args
if "total_epochs" in sampler_argspec:
config_sampler.update({"total_epochs": epochs})
batch_sampler = eval(sampler_name)(dataset, **config_sampler)
logger.debug("build batch_sampler({}) success...".format(batch_sampler))
......@@ -131,6 +154,12 @@ def build_dataloader(config, mode, device, use_dali=False, seed=None):
num_workers = config_loader["num_workers"]
use_shared_memory = config_loader["use_shared_memory"]
init_fn = partial(
worker_init_fn,
num_workers=num_workers,
rank=dist.get_rank(),
seed=seed) if seed is not None else None
if batch_sampler is None:
data_loader = DataLoader(
dataset=dataset,
......@@ -141,7 +170,8 @@ def build_dataloader(config, mode, device, use_dali=False, seed=None):
batch_size=batch_size,
shuffle=shuffle,
drop_last=drop_last,
collate_fn=batch_collate_fn)
collate_fn=batch_collate_fn,
worker_init_fn=init_fn)
else:
data_loader = DataLoader(
dataset=dataset,
......@@ -150,7 +180,8 @@ def build_dataloader(config, mode, device, use_dali=False, seed=None):
return_list=True,
use_shared_memory=use_shared_memory,
batch_sampler=batch_sampler,
collate_fn=batch_collate_fn)
collate_fn=batch_collate_fn,
worker_init_fn=init_fn)
logger.debug("build data_loader({}) success...".format(data_loader))
return data_loader
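Illustration (not part of the commit): a minimal sketch of how the worker_init_fn above combines rank, worker id and the user seed, and how build_dataloader pre-binds every argument except worker_id with functools.partial. The num_workers, rank and seed values below are made up.

    import random
    from functools import partial

    import numpy as np

    def worker_init_fn(worker_id, num_workers, rank, seed):
        # Same formula as the patch: each (rank, worker_id) pair gets a unique seed.
        worker_seed = num_workers * rank + worker_id + seed
        np.random.seed(worker_seed)
        random.seed(worker_seed)

    # build_dataloader binds all arguments except worker_id; paddle.io.DataLoader
    # then calls init_fn(worker_id) inside each worker subprocess.
    init_fn = partial(worker_init_fn, num_workers=4, rank=1, seed=42)

    # Seeds for rank 1 with 4 workers and seed 42 are 46..49, disjoint from
    # rank 0 (42..45) and rank 2 (50..53), so no two workers share a seed.
    for worker_id in range(4):
        init_fn(worker_id)             # what DataLoader does per worker
        print(4 * 1 + worker_id + 42)  # 46, 47, 48, 49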
......@@ -38,6 +38,7 @@ class PKSampler(DistributedBatchSampler):
ratio(list): list of (ratio1, ratio2..) the duplication number for ids in id_list.
drop_last (bool, optional): whether to discard the data at the end. Defaults to True.
sample_method (str, optional): sample method when generating prob_list. Defaults to "sample_avg_prob".
total_epochs (int, optional): total number of training epochs, used to keep the shuffle seed distinct across epochs and replicas. Defaults to 0.
"""
def __init__(self,
......@@ -48,7 +49,8 @@ class PKSampler(DistributedBatchSampler):
drop_last=True,
id_list=None,
ratio=None,
sample_method="sample_avg_prob"):
sample_method="sample_avg_prob",
total_epochs=0):
super().__init__(
dataset, batch_size, shuffle=shuffle, drop_last=drop_last)
assert batch_size % sample_per_id == 0, \
......@@ -58,6 +60,7 @@ class PKSampler(DistributedBatchSampler):
self.sample_per_id = sample_per_id
self.label_dict = defaultdict(list)
self.sample_method = sample_method
self.total_epochs = total_epochs
for idx, label in enumerate(self.dataset.labels):
self.label_dict[label].append(idx)
self.label_list = list(self.label_dict)
......@@ -98,8 +101,9 @@ class PKSampler(DistributedBatchSampler):
def __iter__(self):
# shuffle manually, same as DistributedBatchSampler.__iter__
if self.shuffle:
np.random.RandomState(self.epoch + dist.get_rank()).shuffle(
self.label_list)
rank = dist.get_rank()
np.random.RandomState(rank * self.total_epochs +
self.epoch).shuffle(self.label_list)
self.epoch += 1
label_per_batch = self.batch_size // self.sample_per_id
......
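Side note (not from the commit itself): the old expression self.epoch + dist.get_rank() can repeat across replicas (rank 0 at epoch 1 and rank 1 at epoch 0 both shuffle with seed 1), while rank * total_epochs + epoch is unique for every (rank, epoch) pair as long as epoch < total_epochs. A quick check with made-up numbers:

    total_epochs = 100

    def old_seed(rank, epoch):
        return epoch + rank

    def new_seed(rank, epoch):
        return rank * total_epochs + epoch

    # Old scheme: two different (rank, epoch) pairs collide on the same seed.
    assert old_seed(0, 1) == old_seed(1, 0) == 1

    # New scheme: 4 ranks x 100 epochs yield 400 distinct seeds.
    seeds = {new_seed(r, e) for r in range(4) for e in range(total_epochs)}
    assert len(seeds) == 4 * total_epochs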
......@@ -119,6 +119,9 @@ class Engine(object):
#TODO(gaotingquan): support rec
class_num = config["Arch"].get("class_num", None)
self.config["DataLoader"].update({"class_num": class_num})
self.config["DataLoader"].update({
"epochs": self.config["Global"]["epochs"]
})
# build dataloader
if self.mode == 'train':
......
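Rough sketch (keys abbreviated, values made up, not part of the commit) of how the Global epochs value injected above reaches PKSampler.total_epochs through build_dataloader's argspec check:

    config = {
        "Global": {"epochs": 120},
        "DataLoader": {
            "Train": {
                "dataset": {"name": "ImageNetDataset"},
                "sampler": {"name": "PKSampler", "batch_size": 64,
                            "sample_per_id": 4, "shuffle": True,
                            "drop_last": True},
                "loader": {"num_workers": 4, "use_shared_memory": True},
            },
        },
    }

    # Engine copies Global.epochs into the DataLoader config ...
    config["DataLoader"].update({"epochs": config["Global"]["epochs"]})

    # ... and build_dataloader forwards it only to samplers whose __init__
    # accepts a total_epochs argument (PKSampler after this commit).
    config_sampler = config["DataLoader"]["Train"]["sampler"]
    config_sampler.update({"total_epochs": config["DataLoader"]["epochs"]})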