'''
Copyright 2019 The Microsoft DeepSpeed Team
'''

import logging

import torch
from torch.utils.data import DataLoader, RandomSampler
from torch.utils.data.distributed import DistributedSampler

try:
    from tqdm import tqdm
except ImportError:
    # tqdm is only used to display read progress; fall back to a no-op
    # wrapper so the module still works when tqdm is not installed.
    def tqdm(iterable, *args, **kwargs):
        return iterable


class DeepSpeedDataSource(object):
    """In-memory text data source built from one or more line-based files.

    Every line of every file is read eagerly, stripped of surrounding
    whitespace, and stored in order. The object supports ``len()`` and
    integer indexing, so it can be used as a map-style torch Dataset.

    Args:
        filenames: iterable of paths to plain-text files to load.
    """
    def __init__(self, filenames):
        all_lines = []
        for filename in filenames:
            logging.info("Start reading file %s" % filename)
            with open(filename, "r") as f:
                # tqdm only reports progress; the loop just collects lines.
                # (The original enumerate() index was unused and is dropped.)
                for line in tqdm(f):
                    all_lines.append(line.strip())

        self.all_lines = all_lines
        self.len = len(self.all_lines)

    def __len__(self):
        return self.len

    def __getitem__(self, index):
        # Map-style access so torch samplers / DataLoader can index into us.
        # Backward-compatible addition: previously only __len__ existed.
        return self.all_lines[index]


class DeepSpeedDataLoader(object):
    """Thin wrapper around ``torch.utils.data.DataLoader``.

    Selects a ``DistributedSampler`` when running under a distributed
    launcher (``local_rank >= 0``) and a ``RandomSampler`` otherwise,
    scaling the batch size by the local device count in the
    non-distributed case. Optionally starts a throughput timer on every
    ``__next__`` call.

    Args:
        dataset: any map-style dataset accepted by torch's DataLoader.
        batch_size: per-device batch size.
        pin_memory: forwarded to the underlying DataLoader.
        local_rank: >= 0 means distributed training; < 0 means single node.
        tput_timer: object with a ``start()`` method, or a falsy value to
            disable timing; ``start()`` is invoked before each batch fetch.
        collate_fn: optional custom collate function.
        num_local_io_workers: DataLoader worker count; defaults to
            ``2 * device_count``.
        data_sampler: optional pre-built sampler overriding the default.
    """
    def __init__(self,
                 dataset,
                 batch_size,
                 pin_memory,
                 local_rank,
                 tput_timer,
                 collate_fn=None,
                 num_local_io_workers=None,
                 data_sampler=None):
        self.tput_timer = tput_timer

        if local_rank >= 0:
            # Distributed: each rank reads only its shard; batch size is
            # already per-rank, so device_count stays 1.
            if data_sampler is None:
                data_sampler = DistributedSampler(dataset)
            device_count = 1
        else:
            if data_sampler is None:
                data_sampler = RandomSampler(dataset)
            # Bug fix: torch.cuda.device_count() returns 0 on CPU-only
            # machines, which made batch_size *= device_count produce a
            # batch_size of 0 and crash the DataLoader. Clamp to >= 1.
            device_count = max(1, torch.cuda.device_count())
        batch_size *= device_count

        if num_local_io_workers is None:
            num_local_io_workers = 2 * device_count

        self.num_local_io_workers = num_local_io_workers
        self.data_sampler = data_sampler
        self.dataset = dataset
        self.collate_fn = collate_fn
        self.device_count = device_count
        self.batch_size = batch_size
        self.pin_memory = pin_memory
        # NOTE: this is the number of *samples* the sampler yields, not the
        # number of batches -- kept as-is for backward compatibility.
        self.len = len(self.data_sampler)
        self.data = None

    def __iter__(self):
        self._create_dataloader()
        return self

    def __len__(self):
        return self.len

    def __next__(self):
        if self.tput_timer:
            self.tput_timer.start()
        # Propagates StopIteration from the underlying iterator.
        return next(self.data)

    def _create_dataloader(self):
        """(Re)build the underlying torch DataLoader and its iterator."""
        # Build kwargs once instead of duplicating the DataLoader call for
        # the collate/no-collate cases.
        kwargs = dict(batch_size=self.batch_size,
                      pin_memory=self.pin_memory,
                      sampler=self.data_sampler,
                      num_workers=self.num_local_io_workers)
        if self.collate_fn is not None:
            kwargs['collate_fn'] = self.collate_fn
        self.dataloader = DataLoader(self.dataset, **kwargs)
        self.data = iter(self.dataloader)

        return self.dataloader