未验证 提交 f1d2a15b 编写于 作者: M Mayank Mishra 提交者: GitHub

better eval sampler (#2907)

Co-authored-by: Olatunji Ruwase <olruwase@microsoft.com>
上级 541e423a
......@@ -1673,9 +1673,6 @@ class DeepSpeedEngine(Module):
or self.is_iterable_style_dataset(dataset)):
raise ValueError("Training data must be a torch Dataset")
if data_sampler is None and (route == ROUTE_PREDICT or route == ROUTE_EVAL):
data_sampler = torch.utils.data.SequentialSampler(dataset)
if batch_size is None:
batch_size = self.train_micro_batch_size_per_gpu()
......@@ -1694,6 +1691,14 @@ class DeepSpeedEngine(Module):
data_parallel_world_size = self.mpu.get_data_parallel_world_size()
data_parallel_rank = self.mpu.get_data_parallel_rank()
if data_sampler is None and (route == ROUTE_PREDICT or route == ROUTE_EVAL):
data_sampler = torch.utils.data.DistributedSampler(
dataset,
num_replicas=data_parallel_world_size,
rank=data_parallel_rank,
shuffle=False,
)
deepspeed_dataloader_config = {}
if self.curriculum_learning_enabled():
deepspeed_dataloader_config = {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册