diff --git a/deepspeed/runtime/engine.py b/deepspeed/runtime/engine.py index 7233aecfc3c3d04cead09910a188db712e58f69b..080998e742d7ebd0f8e49a76e5b586bcfce9d01a 100644 --- a/deepspeed/runtime/engine.py +++ b/deepspeed/runtime/engine.py @@ -1044,7 +1044,7 @@ class DeepSpeedEngine(Module): self.__check_params(self.module, torch.float) # zero.Init() handles device placement of model - if not (self.dont_change_device or is_zero3_model): + if not self.dont_change_device: self.module.to(self.device) # MoE related initialization @@ -1080,7 +1080,7 @@ class DeepSpeedEngine(Module): self.expert_parallel_group = groups._get_expert_parallel_group_dict() self.expert_data_parallel_group = groups._get_expert_data_parallel_group_dict() - if not (self.amp_enabled() or is_zero3_model): + if not self.amp_enabled(): self._broadcast_model() # check if parameters are duplicated in optimizer param_groups