Unverified commit 548d5522, authored by Roc, committed by GitHub

[KUNLUN]fix cast bf16 (#52246)

Parent: d612faf5
@@ -200,6 +200,9 @@ class HybridParallelClipGrad:
             + paddle.to_tensor([1.0e-6], dtype=paddle.float32),
         )
         clip_var_fp16 = paddle.cast(clip_var, paddle.float16)
+        # bf16 is not supported on XPU now
+        if not paddle.is_compiled_with_xpu():
+            clip_var_bf16 = paddle.cast(clip_var, paddle.bfloat16)
         for p, g in params_grads:
             if g is None:
                 continue
@@ -209,6 +212,10 @@ class HybridParallelClipGrad:
             if g.dtype == paddle.float16:
                 g.scale_(clip_var_fp16)
             elif g.dtype == paddle.bfloat16:
+                if paddle.is_compiled_with_xpu():
+                    raise NotImplementedError(
+                        "BF16 is not supported on XPU now"
+                    )
                 g.scale_(clip_var_bf16)
             else:
                 g.scale_(clip_var)
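
For context, the change guards the bfloat16 path behind paddle.is_compiled_with_xpu(), since XPU (KUNLUN) builds do not yet support casting to bf16. Below is a minimal sketch of the same dispatch pattern; the paddle calls are real, but the helper function scaled_clip_var is illustrative and not part of the patched code:

    import paddle

    def scaled_clip_var(clip_var, grad_dtype):
        # Illustrative helper mirroring the patched logic: cast the clip
        # coefficient to the gradient's dtype, but refuse bfloat16 on XPU
        # builds, where the bf16 cast kernel is unavailable.
        if grad_dtype == paddle.float16:
            return paddle.cast(clip_var, paddle.float16)
        if grad_dtype == paddle.bfloat16:
            if paddle.is_compiled_with_xpu():
                raise NotImplementedError("BF16 is not supported on XPU now")
            return paddle.cast(clip_var, paddle.bfloat16)
        return clip_var

Under this sketch, each gradient would be scaled in place via g.scale_(scaled_clip_var(clip_var, g.dtype)), matching the per-dtype dispatch in the hunk above.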