Unverified · Commit 548d5522 authored by Roc, committed by GitHub

[KUNLUN]fix cast bf16 (#52246)

Parent d612faf5
@@ -200,7 +200,10 @@ class HybridParallelClipGrad:
             + paddle.to_tensor([1.0e-6], dtype=paddle.float32),
         )
         clip_var_fp16 = paddle.cast(clip_var, paddle.float16)
-        clip_var_bf16 = paddle.cast(clip_var, paddle.bfloat16)
+        # bf16 is not supported on XPU now
+        if not paddle.is_compiled_with_xpu():
+            clip_var_bf16 = paddle.cast(clip_var, paddle.bfloat16)
         for p, g in params_grads:
             if g is None:
                 continue
@@ -209,6 +212,10 @@ class HybridParallelClipGrad:
             if g.dtype == paddle.float16:
                 g.scale_(clip_var_fp16)
             elif g.dtype == paddle.bfloat16:
+                if paddle.is_compiled_with_xpu():
+                    raise NotImplementedError(
+                        "BF16 is not supported on XPU now"
+                    )
                 g.scale_(clip_var_bf16)
             else:
                 g.scale_(clip_var)
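For illustration, a minimal standalone sketch of the pattern this commit introduces, assuming only Paddle's public API (`clip_grads`, `params_grads`, and `clip_var` are hypothetical stand-ins for the surrounding HybridParallelClipGrad code, not names from the original file):

    import paddle

    def clip_grads(params_grads, clip_var):
        # Pre-cast the fp32 clip coefficient once per dtype. The bf16 cast
        # is skipped on XPU (KUNLUN) builds, where bf16 is not supported yet.
        clip_var_fp16 = paddle.cast(clip_var, paddle.float16)
        if not paddle.is_compiled_with_xpu():
            clip_var_bf16 = paddle.cast(clip_var, paddle.bfloat16)
        for _, g in params_grads:
            if g is None:
                continue
            if g.dtype == paddle.float16:
                g.scale_(clip_var_fp16)  # in-place scale by the fp16 coefficient
            elif g.dtype == paddle.bfloat16:
                # Fail fast instead of reading the unbound clip_var_bf16.
                if paddle.is_compiled_with_xpu():
                    raise NotImplementedError("BF16 is not supported on XPU now")
                g.scale_(clip_var_bf16)
            else:
                g.scale_(clip_var)  # fp32 and other dtypes use the fp32 coefficient

Guarding both sites is the point of the patch: an XPU build neither attempts the unsupported bf16 cast up front nor reads an unbound clip_var_bf16 later, and a bf16 gradient on XPU now fails fast with an explicit NotImplementedError rather than failing inside the cast.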