未验证 提交 97f8a9eb 编写于 作者: D Du Li 提交者: GitHub

fixing a bf16 support issue (#1760)

上级 dbe8ee16
......@@ -1179,8 +1179,11 @@ class PipelineEngine(DeepSpeedEngine):
Returns:
A tensor from torch.zeros() allocated on self.device.
"""
if "dtype" not in kwargs and self.fp16_enabled():
kwargs["dtype"] = torch.half
if "dtype" not in kwargs:
if self.fp16_enabled():
kwargs["dtype"] = torch.half
if self.bfloat16_enabled():
kwargs["dtype"] = torch.bfloat16
return torch.zeros(shape, device=self.device, **kwargs)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册