From 426810a254d87d65cbd01a156dc7d90f2a878901 Mon Sep 17 00:00:00 2001 From: Xuehai Pan Date: Wed, 23 Aug 2023 14:56:11 +0800 Subject: [PATCH] Fix ZeRO parameter initialization for tensors with `requires_grad=True` (#4138) * Fix ZeRO parameter initialization for tensors with `requires_grad=True` * Simplify detach logic --------- Co-authored-by: Logan Adams <114770087+loadams@users.noreply.github.com> --- deepspeed/runtime/zero/partition_parameters.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/deepspeed/runtime/zero/partition_parameters.py b/deepspeed/runtime/zero/partition_parameters.py index 093964bc..a0bb37a1 100755 --- a/deepspeed/runtime/zero/partition_parameters.py +++ b/deepspeed/runtime/zero/partition_parameters.py @@ -241,7 +241,7 @@ def zero_wrapper_for_fp_tensor_constructor(fn: Callable, target_fp_dtype: torch. kwargs['device'] = torch.device(get_accelerator().device_name(os.environ["LOCAL_RANK"])) tensor: Tensor = fn(*args, **kwargs) if tensor.is_floating_point(): - tensor = tensor.to(target_fp_dtype) + tensor.data = tensor.data.to(target_fp_dtype) return tensor -- GitLab