Commit 7ff257ea authored by H HydrogenSulfate

fix random seed bug for pksampler in DDP

Parent 0e5cbd2b
@@ -18,7 +18,9 @@ from __future__ import division
 from collections import defaultdict
 import numpy as np
+import paddle.distributed as dist
 from paddle.io import DistributedBatchSampler
 from ppcls.utils import logger
@@ -94,6 +96,11 @@ class PKSampler(DistributedBatchSampler):
                     format(diff))

     def __iter__(self):
+        # shuffle label_list manually in distributed environment
+        if self.nranks > 1:
+            cur_rank = dist.get_rank()
+            np.random.RandomState(42 + cur_rank).shuffle(self.label_list)
         label_per_batch = self.batch_size // self.sample_per_id
         for _ in range(len(self)):
             batch_index = []
...
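Judging by the commit message, the bug was that every DDP trainer could draw the same sample order; the hunk above addresses it by giving each rank its own deterministic shuffle. A minimal NumPy sketch below (hypothetical label values, no distributed launch required) illustrates the effect of seeding a private `RandomState` with `42 + rank`:

```python
# Minimal sketch of the per-rank shuffle added above: a RandomState seeded
# with 42 + rank gives every trainer a distinct yet reproducible label order.
import numpy as np

label_list = list(range(8))  # stand-in for PKSampler.label_list

for rank in range(2):  # pretend two trainer processes
    labels = list(label_list)
    np.random.RandomState(42 + rank).shuffle(labels)
    print(f"rank {rank}: {labels}")
# Rank 0 and rank 1 print different orders, but each order is identical
# across reruns, so sampling stays reproducible per trainer.
```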
@@ -126,16 +126,18 @@ class Engine(object):
                 self.config["DataLoader"], "Train", self.device, self.use_dali)
             if self.config["DataLoader"].get('UnLabelTrain', None) is not None:
                 self.unlabel_train_dataloader = build_dataloader(
                     self.config["DataLoader"], "UnLabelTrain", self.device,
                     self.use_dali)
             else:
                 self.unlabel_train_dataloader = None
-            self.iter_per_epoch = len(self.train_dataloader) - 1 if platform.system(
-            ) == "Windows" else len(self.train_dataloader)
+            self.iter_per_epoch = len(
+                self.train_dataloader) - 1 if platform.system(
+            ) == "Windows" else len(self.train_dataloader)
             if self.config["Global"].get("iter_per_epoch", None):
                 # set max iteration per epoch manually, when training by iteration(s), such as XBM or FixMatch.
-                self.iter_per_epoch = self.config["Global"].get("iter_per_epoch")
+                self.iter_per_epoch = self.config["Global"].get(
+                    "iter_per_epoch")
             self.iter_per_epoch = self.iter_per_epoch // self.update_freq * self.update_freq
             if self.mode == "eval" or (self.mode == "train" and
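A brief aside on the `update_freq` arithmetic this hunk keeps (the numbers below are hypothetical): flooring `iter_per_epoch` to a multiple of `update_freq` ensures a gradient-accumulation cycle never straddles an epoch boundary.

```python
# Hypothetical numbers illustrating the flooring in the line above.
iter_per_epoch, update_freq = 103, 4
print(iter_per_epoch // update_freq * update_freq)  # -> 100, a multiple of 4
```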
@@ -329,6 +331,20 @@ class Engine(object):
             )) > 0:
                 self.train_loss_func = paddle.DataParallel(
                     self.train_loss_func)
+
+            # set a different seed on each GPU manually in distributed environment
+            if seed is None:
+                logger.warning(
+                    "The random seed cannot be None in a distributed environment. Global.seed has been set to 42 by default"
+                )
+                self.config["Global"]["seed"] = seed = 42
+            logger.info(
+                f"Set random seed to ({seed} + $PADDLE_TRAINER_ID) for different trainer"
+            )
+            paddle.seed(seed + dist.get_rank())
+            np.random.seed(seed + dist.get_rank())
+            random.seed(seed + dist.get_rank())

         # build postprocess for infer
         if self.mode == 'infer':
             self.preprocess_func = create_operators(self.config["Infer"][
...
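For reference, the engine-side seeding above can be exercised in isolation. The sketch below is self-contained and assumes a job started with `paddle.distributed.launch` so each process sees its own trainer id; a standalone run simply gets rank 0, and the value 42 mirrors the diff's fallback default:

```python
# Minimal sketch of per-trainer seeding, matching the pattern in the hunk above.
import random

import numpy as np
import paddle
import paddle.distributed as dist

seed = 42                 # the diff falls back to 42 when Global.seed is None
offset = dist.get_rank()  # distinct per trainer, derived from $PADDLE_TRAINER_ID

paddle.seed(seed + offset)     # Paddle RNG: weight init, dropout, ...
np.random.seed(seed + offset)  # NumPy RNG: augmentations, samplers
random.seed(seed + offset)     # Python RNG: pure-Python shuffles
```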