Unverified commit 426810a2 authored by Xuehai Pan, committed by GitHub

Fix ZeRO parameter initialization for tensors with `requires_grad=True` (#4138)

* Fix ZeRO parameter initialization for tensors with `requires_grad=True`

* Simplify detach logic

---------
Co-authored-by: Logan Adams <114770087+loadams@users.noreply.github.com>
Parent 9723a879
@@ -241,7 +241,7 @@ def zero_wrapper_for_fp_tensor_constructor(fn: Callable, target_fp_dtype: torch.
         kwargs['device'] = torch.device(get_accelerator().device_name(os.environ["LOCAL_RANK"]))
     tensor: Tensor = fn(*args, **kwargs)
     if tensor.is_floating_point():
-        tensor = tensor.to(target_fp_dtype)
+        tensor.data = tensor.data.to(target_fp_dtype)
     return tensor
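The change matters because `tensor.to(dtype)` is an out-of-place, autograd-tracked operation: when the input was created with `requires_grad=True`, the result is a new non-leaf tensor, which breaks its use as a parameter. Assigning through `.data` swaps in the casted storage while keeping the original tensor object a leaf with `requires_grad=True`. A minimal sketch of the difference (standalone illustration, not DeepSpeed code):

```python
import torch

# A float tensor created with requires_grad=True is a leaf tensor.
t = torch.ones(2, 2, requires_grad=True)

# Out-of-place cast: .to() is recorded by autograd, so the result
# is a new tensor with a grad_fn, i.e. no longer a leaf.
cast = t.to(torch.float16)
print(cast.requires_grad)  # True
print(cast.is_leaf)        # False

# Swapping .data instead: the original tensor object keeps its
# identity, stays a leaf, keeps requires_grad, and now holds
# fp16 storage.
t.data = t.data.to(torch.float16)
print(t.is_leaf)           # True
print(t.requires_grad)     # True
print(t.dtype)             # torch.float16
```

This is why the fix casts via `tensor.data` rather than rebinding `tensor` to the result of `.to()`.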