未验证 提交 669028f0 编写于 作者: S sdtblck 提交者: GitHub

Fix all Pipeline Module Parameters being sent to cuda:0 (#687)

上级 0b80ad06
......@@ -151,6 +151,8 @@ class PipelineModule(nn.Module):
self.world_group = dist.new_group(ranks=range(dist.get_world_size()))
self.global_rank = dist.get_rank(group=self.world_group)
self.world_size = dist.get_world_size(group=self.world_group)
+ self.local_rank = int(os.environ.get("LOCAL_RANK", None))
+ assert self.local_rank != None
if topology:
self._topo = topology
......@@ -189,7 +191,7 @@ class PipelineModule(nn.Module):
#with torch.random.fork_rng(devices=[torch.cuda.current_device()]):
self._build()
- self.to('cuda')
+ self.to(f'cuda:{self.local_rank}')
self.tied_comms = self._index_tied_modules()
self._synchronize_tied_weights()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册