diff --git a/deepspeed/runtime/zero/partition_parameters.py b/deepspeed/runtime/zero/partition_parameters.py index 093964bc90fe4ae66c3e98b6fec866cf23cdaab7..a0bb37a15d5a9c9598afd60f78bb87fc58681e31 100755 --- a/deepspeed/runtime/zero/partition_parameters.py +++ b/deepspeed/runtime/zero/partition_parameters.py @@ -241,7 +241,7 @@ def zero_wrapper_for_fp_tensor_constructor(fn: Callable, target_fp_dtype: torch. kwargs['device'] = torch.device(get_accelerator().device_name(os.environ["LOCAL_RANK"])) tensor: Tensor = fn(*args, **kwargs) if tensor.is_floating_point(): - tensor = tensor.to(target_fp_dtype) + tensor.data = tensor.data.to(target_fp_dtype) return tensor