Commit 7ff257ea authored by: H HydrogenSulfate

fix random seed bug for pksampler in DDP

Parent 0e5cbd2b
......@@ -18,7 +18,9 @@ from __future__ import division
from collections import defaultdict
import numpy as np
import paddle.distributed as dist
from paddle.io import DistributedBatchSampler
from ppcls.utils import logger
......@@ -94,6 +96,11 @@ class PKSampler(DistributedBatchSampler):
                format(diff))

    def __iter__(self):
        # shuffling label_list manually in distributed environment
        if self.nranks > 1:
            cur_rank = dist.get_rank()
            np.random.RandomState(42 + cur_rank).shuffle(self.label_list)

        label_per_batch = self.batch_size // self.sample_per_id
        for _ in range(len(self)):
            batch_index = []
......
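The added __iter__ logic above gives every trainer process its own deterministic shuffle of label_list: the shuffle seed is a fixed base (42) plus the process rank, so each rank draws a stable but rank-specific label order for its P x K batches. A minimal, self-contained sketch of that idea follows; the helper name and the base_seed default are illustrative and not part of the commit.

# Illustrative sketch of rank-dependent deterministic shuffling (not the PaddleClas source).
import numpy as np


def shuffle_labels_for_rank(label_list, cur_rank, base_seed=42):
    """Return a copy of label_list shuffled with seed base_seed + cur_rank.

    Every rank gets the same order on every run, but no two ranks share an order.
    """
    labels = list(label_list)
    np.random.RandomState(base_seed + cur_rank).shuffle(labels)
    return labels


if __name__ == "__main__":
    labels = ["cat", "dog", "bird", "fish", "horse", "sheep"]
    for rank in range(2):
        print(rank, shuffle_labels_for_rank(labels, rank))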
......@@ -126,16 +126,18 @@ class Engine(object):
                self.config["DataLoader"], "Train", self.device, self.use_dali)
            if self.config["DataLoader"].get('UnLabelTrain', None) is not None:
                self.unlabel_train_dataloader = build_dataloader(
                    self.config["DataLoader"], "UnLabelTrain", self.device,
                    self.use_dali)
            else:
                self.unlabel_train_dataloader = None
            self.iter_per_epoch = len(
                self.train_dataloader) - 1 if platform.system(
                ) == "Windows" else len(self.train_dataloader)
            if self.config["Global"].get("iter_per_epoch", None):
                # set max iteration per epoch manually, when training by iteration(s), such as XBM, FixMatch.
                self.iter_per_epoch = self.config["Global"].get(
                    "iter_per_epoch")
            self.iter_per_epoch = self.iter_per_epoch // self.update_freq * self.update_freq
        if self.mode == "eval" or (self.mode == "train" and
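The last iter_per_epoch assignment in the hunk above floors the value to a multiple of the configured update_freq, so an epoch never ends partway through an update interval. A quick illustration with assumed numbers:

# Assumed values, only to show the flooring arithmetic.
iter_per_epoch = 1003
update_freq = 4
iter_per_epoch = iter_per_epoch // update_freq * update_freq
print(iter_per_epoch)  # 1000, the largest multiple of 4 not exceeding 1003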
......@@ -329,6 +331,20 @@ class Engine(object):
            )) > 0:
                self.train_loss_func = paddle.DataParallel(
                    self.train_loss_func)
            # set different seed in different GPU manually in distributed environment
            if seed is None:
                logger.warning(
                    "The random seed cannot be None in a distributed environment. Global.seed has been set to 42 by default"
                )
                self.config["Global"]["seed"] = seed = 42
            logger.info(
                f"Set random seed to ({seed} + $PADDLE_TRAINER_ID) for different trainer"
            )
            paddle.seed(seed + dist.get_rank())
            np.random.seed(seed + dist.get_rank())
            random.seed(seed + dist.get_rank())
        # build postprocess for infer
        if self.mode == 'infer':
            self.preprocess_func = create_operators(self.config["Infer"][
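The seed block above gives every trainer a distinct global seed by offsetting Global.seed with the trainer rank, and applies it to Paddle, NumPy, and Python's random module. The sketch below restates the same idea as a standalone helper; the function name is illustrative, and it assumes dist.get_rank() returns the trainer id (0 in a non-distributed run).

# Illustrative helper mirroring the per-rank seeding above (not part of the commit).
import random

import numpy as np
import paddle
import paddle.distributed as dist


def seed_per_rank(seed=None, default_seed=42):
    """Seed paddle, numpy and random with seed + rank; fall back to default_seed."""
    if seed is None:
        seed = default_seed
    rank = dist.get_rank()
    paddle.seed(seed + rank)
    np.random.seed(seed + rank)
    random.seed(seed + rank)
    return seed + rank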