Unverified commit 06938835, authored by Ma, Guokai, committed by GitHub

Support fp32 gradaccum for bf16 model (#2566)

* allow bf16 model with fp32 gradient accumulation datatype

* allow fp32 gradient accumulation and bfloat16 model in amp mode

* alternative fix for grad accumulation type mismatch: in the case of the ZeRO optimizer we should have grad accum type == model data type
Co-authored-by: Olatunji Ruwase <olruwase@microsoft.com>
Parent: 2d8f3f56
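For context, a minimal usage sketch of how a user might request fp32 gradient accumulation for a bf16 model after this change. The exact config key path ("data_types" / "grad_accum_dtype") is an assumption inferred from self._config.grad_accum_dtype in the diff below; the placeholder model and batch size are for illustration only.

import torch
import deepspeed

# Hypothetical config: the "data_types"/"grad_accum_dtype" key path is an
# assumption; check the DeepSpeed config documentation for the authoritative schema.
ds_config = {
    "train_batch_size": 8,
    "bf16": {"enabled": True},                   # bf16 model weights/compute
    "data_types": {"grad_accum_dtype": "fp32"},  # accumulate gradients in fp32
}

model = torch.nn.Linear(16, 16)  # placeholder model for illustration

engine, optimizer, _, _ = deepspeed.initialize(model=model,
                                               model_parameters=model.parameters(),
                                               config=ds_config)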
@@ -807,7 +807,7 @@ class DeepSpeedEngine(Module):
             model_dtype = torch.bfloat16
         if self._config.grad_accum_dtype == None:
-            if model_dtype == torch.bfloat16:
+            if model_dtype == torch.bfloat16 and not self.zero_optimization():
                 grad_accum_dtype = torch.float32
             else:
                 grad_accum_dtype = model_dtype
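Below is a minimal, self-contained sketch of the default-selection rule implied by this hunk (the function name and arguments are hypothetical; the real logic lives inside DeepSpeedEngine and also consults amp/fp16 modes): when no grad_accum_dtype is configured, a bf16 model defaults to fp32 gradient accumulation unless ZeRO is enabled, in which case accumulation uses the model dtype.

import torch

def default_grad_accum_dtype(model_dtype, configured_dtype=None, zero_enabled=False):
    # Hypothetical stand-alone version of the rule shown in the diff above.
    if configured_dtype is not None:
        return configured_dtype
    if model_dtype == torch.bfloat16 and not zero_enabled:
        return torch.float32   # bf16 model without ZeRO: accumulate in fp32
    return model_dtype         # otherwise the accumulation dtype follows the model

assert default_grad_accum_dtype(torch.bfloat16) == torch.float32
assert default_grad_accum_dtype(torch.bfloat16, zero_enabled=True) == torch.bfloat16
assert default_grad_accum_dtype(torch.float16) == torch.float16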