Unverified commit d39c311f, authored by Olatunji Ruwase, committed by GitHub

DS init should not broadcast or move zero.Init models (#3611)

Parent 736bf185
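
Summary of the change: a model constructed inside `deepspeed.zero.Init()` already has its parameters partitioned by ZeRO stage 3, with device placement handled at construction time, so engine initialization should neither move the module to `self.device` nor broadcast rank 0's weights to the other ranks. The diff below introduces an `is_zero3_model` flag, detected via the `ds_id` attribute that stage-3 partitioning attaches to each parameter, and gates both operations on it.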
@@ -1027,20 +1027,22 @@ class DeepSpeedEngine(Module):
 
     def _configure_distributed_model(self, model):
         self._set_client_model(model)
+        is_zero3_model = self.zero_optimization_partition_weights() and any(
+            [hasattr(param, "ds_id") for param in self.module.parameters()])
+
         if self.fp16_enabled():
-            if self.zero_optimization_partition_weights() and any(
-                    [hasattr(param, "ds_id") for param in self.module.parameters()]):
+            if is_zero3_model:
                 self.__check_params(self.module, torch.half)
             self.module.half()
         elif self.bfloat16_enabled():
-            if self.zero_optimization_partition_weights() and any(
-                    hasattr(param, 'ds_id') for param in self.module.parameters()):
+            if is_zero3_model:
                 self.__check_params(self.module, torch.bfloat16)
             self.module.bfloat16()
         else:
             self.__check_params(self.module, torch.float)
 
-        if not self.dont_change_device:
+        # zero.Init() handles device placement of model
+        if not (self.dont_change_device or is_zero3_model):
             self.module.to(self.device)
 
         # MoE related initialization
@@ -1076,7 +1078,7 @@ class DeepSpeedEngine(Module):
             self.expert_parallel_group = groups._get_expert_parallel_group_dict()
             self.expert_data_parallel_group = groups._get_expert_data_parallel_group_dict()
 
-        if not self.amp_enabled():
+        if not (self.amp_enabled() or is_zero3_model):
            self._broadcast_model()
 
         # check if parameters are duplicated in optimizer param_groups
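
For illustration, a minimal sketch (not part of this commit) of the check the new `is_zero3_model` flag performs. It assumes a working DeepSpeed installation and a distributed environment, e.g. started via the `deepspeed` launcher; the config values and the toy model are placeholders.

import torch
import deepspeed

# Illustrative ZeRO stage-3 config; only the stage matters for this sketch.
ds_config = {
    "train_micro_batch_size_per_gpu": 1,
    "zero_optimization": {"stage": 3},
}

# Parameters created inside zero.Init() are partitioned across the
# data-parallel ranks at construction time and tagged with a ds_id attribute.
with deepspeed.zero.Init(config_dict_or_path=ds_config):
    model = torch.nn.Linear(8, 8)

# The same detection the engine now uses: any ds_id-tagged parameter means
# the model came from zero.Init(), so module.to(device) and the rank-0
# weight broadcast must both be skipped during engine initialization.
is_zero3_model = any(hasattr(p, "ds_id") for p in model.parameters())
print(is_zero3_model)  # expected: True

Passing such a pre-partitioned model to `deepspeed.initialize()` after this commit leaves the partitioned weights untouched instead of overwriting them with rank 0's copy.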