Commit 6cc6540c authored by HydrogenSulfate

add different seed for workers and replicas

Parent b542416d
@@ -14,8 +14,11 @@
 import inspect
 import copy
+import random
 import paddle
 import numpy as np
+import paddle.distributed as dist
+from functools import partial
 from paddle.io import DistributedBatchSampler, BatchSampler, DataLoader
 from ppcls.utils import logger
@@ -66,6 +69,22 @@ def create_operators(params, class_num=None):
     return ops
 
 
+def worker_init_fn(worker_id: int, num_workers: int, rank: int, seed: int):
+    """callback function on each worker subprocess after seeding and before data loading.
+
+    Args:
+        worker_id (int): Worker id in [0, num_workers - 1]
+        num_workers (int): Number of subprocesses to use for data loading.
+        rank (int): Rank of process in distributed environment. If in non-distributed environment, it is a constant number `0`.
+        seed (int): Random seed
+    """
+    # The seed of each worker equals to
+    # num_worker * rank + worker_id + user_seed
+    worker_seed = num_workers * rank + worker_id + seed
+    np.random.seed(worker_seed)
+    random.seed(worker_seed)
+
+
 def build_dataloader(config, mode, device, use_dali=False, seed=None):
     assert mode in [
         'Train', 'Eval', 'Test', 'Gallery', 'Query', 'UnLabelTrain'
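The added `worker_init_fn` gives every data-loading subprocess its own random seed by combining the replica rank, the worker id, and the user seed. A minimal sketch of how those seeds spread out, assuming a hypothetical 2-replica, 4-worker setup with user seed 42:

```python
# Enumerate the seeds produced by num_workers * rank + worker_id + seed.
num_workers, num_replicas, user_seed = 4, 2, 42  # assumed values for illustration

for rank in range(num_replicas):
    for worker_id in range(num_workers):
        worker_seed = num_workers * rank + worker_id + user_seed
        print(f"rank={rank} worker={worker_id} -> seed={worker_seed}")

# rank 0 gets seeds 42..45, rank 1 gets 46..49: no two workers share a seed,
# so random augmentations no longer repeat across workers or replicas.
```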
@@ -82,6 +101,7 @@ def build_dataloader(config, mode, device, use_dali=False, seed=None):
             seed=seed)
 
     class_num = config.get("class_num", None)
+    epochs = config.get("epochs", None)
     config_dataset = config[mode]['dataset']
     config_dataset = copy.deepcopy(config_dataset)
     dataset_name = config_dataset.pop('name')
@@ -103,6 +123,9 @@ def build_dataloader(config, mode, device, use_dali=False, seed=None):
         shuffle = config_sampler["shuffle"]
     else:
         sampler_name = config_sampler.pop("name")
+        sampler_argspec = inspect.getargspec(eval(sampler_name).__init__).args
+        if "total_epochs" in sampler_argspec:
+            config_sampler.update({"total_epochs": epochs})
         batch_sampler = eval(sampler_name)(dataset, **config_sampler)
         logger.debug("build batch_sampler({}) success...".format(batch_sampler))
@@ -131,6 +154,12 @@ def build_dataloader(config, mode, device, use_dali=False, seed=None):
     num_workers = config_loader["num_workers"]
     use_shared_memory = config_loader["use_shared_memory"]
 
+    init_fn = partial(
+        worker_init_fn,
+        num_workers=num_workers,
+        rank=dist.get_rank(),
+        seed=seed) if seed is not None else None
+
     if batch_sampler is None:
         data_loader = DataLoader(
             dataset=dataset,
@@ -141,7 +170,8 @@ def build_dataloader(config, mode, device, use_dali=False, seed=None):
             batch_size=batch_size,
             shuffle=shuffle,
             drop_last=drop_last,
-            collate_fn=batch_collate_fn)
+            collate_fn=batch_collate_fn,
+            worker_init_fn=init_fn)
     else:
         data_loader = DataLoader(
             dataset=dataset,
@@ -150,7 +180,8 @@ def build_dataloader(config, mode, device, use_dali=False, seed=None):
             return_list=True,
             use_shared_memory=use_shared_memory,
             batch_sampler=batch_sampler,
-            collate_fn=batch_collate_fn)
+            collate_fn=batch_collate_fn,
+            worker_init_fn=init_fn)
 
     logger.debug("build data_loader({}) success...".format(data_loader))
     return data_loader
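The DataLoader's `worker_init_fn` callback is invoked with just the worker id inside each subprocess, so the remaining arguments are pre-bound here with `functools.partial`. A small sketch of that binding, where `rank=1` and `seed=42` are placeholder values standing in for `dist.get_rank()` and the configured seed:

```python
from functools import partial
import random
import numpy as np

def worker_init_fn(worker_id, num_workers, rank, seed):
    worker_seed = num_workers * rank + worker_id + seed
    np.random.seed(worker_seed)
    random.seed(worker_seed)

# Pre-bind everything except worker_id; the loader only supplies worker_id.
init_fn = partial(worker_init_fn, num_workers=4, rank=1, seed=42)
init_fn(0)                     # equivalent to worker_init_fn(0, 4, 1, 42)
print(random.randint(0, 100))  # reproducible: random was reseeded above
```

Because `init_fn` is `None` when no seed is configured, the old fully random behavior is preserved in that case.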
@@ -38,6 +38,7 @@ class PKSampler(DistributedBatchSampler):
         ratio(list): list of (ratio1, ratio2..) the duplication number for ids in id_list.
         drop_last (bool, optional): whether to discard the data at the end. Defaults to True.
         sample_method (str, optional): sample method when generating prob_list. Defaults to "sample_avg_prob".
+        total_epochs (int, optional): total epochs. Defaults to 0.
     """
 
     def __init__(self,
@@ -48,7 +49,8 @@ class PKSampler(DistributedBatchSampler):
                  drop_last=True,
                  id_list=None,
                  ratio=None,
-                 sample_method="sample_avg_prob"):
+                 sample_method="sample_avg_prob",
+                 total_epochs=0):
         super().__init__(
             dataset, batch_size, shuffle=shuffle, drop_last=drop_last)
         assert batch_size % sample_per_id == 0, \
@@ -58,6 +60,7 @@ class PKSampler(DistributedBatchSampler):
         self.sample_per_id = sample_per_id
         self.label_dict = defaultdict(list)
         self.sample_method = sample_method
+        self.total_epochs = total_epochs
         for idx, label in enumerate(self.dataset.labels):
             self.label_dict[label].append(idx)
         self.label_list = list(self.label_dict)
@@ -98,8 +101,9 @@ class PKSampler(DistributedBatchSampler):
     def __iter__(self):
         # shuffle manually, same as DistributedBatchSampler.__iter__
         if self.shuffle:
-            np.random.RandomState(self.epoch + dist.get_rank()).shuffle(
-                self.label_list)
+            rank = dist.get_rank()
+            np.random.RandomState(rank * self.total_epochs +
+                                  self.epoch).shuffle(self.label_list)
             self.epoch += 1
         label_per_batch = self.batch_size // self.sample_per_id
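The new seeding in `__iter__` removes a collision in the old formula: with `self.epoch + rank`, rank 1 at epoch 0 and rank 0 at epoch 1 drew the same shuffle order. Multiplying the rank by `total_epochs` keeps each replica's per-epoch seed range disjoint. A quick sketch under assumed counts (2 replicas, 3 epochs):

```python
# Compare shuffle seeds under the old and new formulas.
total_epochs, num_replicas = 3, 2  # assumed values for illustration

old = [e + r for r in range(num_replicas) for e in range(total_epochs)]
new = [r * total_epochs + e for r in range(num_replicas) for e in range(total_epochs)]

print(sorted(old))  # [0, 1, 1, 2, 2, 3] -> duplicate seeds across replicas
print(sorted(new))  # [0, 1, 2, 3, 4, 5] -> every (rank, epoch) pair is unique
```

This is also why the sampler now needs `total_epochs` at construction time, which the config plumbing below provides.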
@@ -119,6 +119,9 @@ class Engine(object):
         #TODO(gaotingquan): support rec
         class_num = config["Arch"].get("class_num", None)
         self.config["DataLoader"].update({"class_num": class_num})
+        self.config["DataLoader"].update({
+            "epochs": self.config["Global"]["epochs"]
+        })
 
         # build dataloader
         if self.mode == 'train':