From 669028f0fd5067c9247120cb21fd6e9bea4820a9 Mon Sep 17 00:00:00 2001 From: sdtblck <46172032+sdtblck@users.noreply.github.com> Date: Wed, 21 Apr 2021 23:44:50 +0200 Subject: [PATCH] Fix all Pipeline Module Parameters being sent to cuda:0 (#687) --- deepspeed/runtime/pipe/module.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/deepspeed/runtime/pipe/module.py b/deepspeed/runtime/pipe/module.py index dcd4be0e..abf90b03 100644 --- a/deepspeed/runtime/pipe/module.py +++ b/deepspeed/runtime/pipe/module.py @@ -151,6 +151,8 @@ class PipelineModule(nn.Module): self.world_group = dist.new_group(ranks=range(dist.get_world_size())) self.global_rank = dist.get_rank(group=self.world_group) self.world_size = dist.get_world_size(group=self.world_group) + self.local_rank = int(os.environ.get("LOCAL_RANK", None)) + assert self.local_rank != None if topology: self._topo = topology @@ -189,7 +191,7 @@ class PipelineModule(nn.Module): #with torch.random.fork_rng(devices=[torch.cuda.current_device()]): self._build() - self.to('cuda') + self.to(f'cuda:{self.local_rank}') self.tied_comms = self._index_tied_modules() self._synchronize_tied_weights() -- GitLab