From 548d5522f388c5ce89b1a653358f2afe1671b312 Mon Sep 17 00:00:00 2001
From: Roc <30228238+sljlp@users.noreply.github.com>
Date: Wed, 29 Mar 2023 18:53:19 +0800
Subject: [PATCH] [KUNLUN]fix cast bf16 (#52246)

---
 .../dygraph_optimizer/hybrid_parallel_optimizer.py | 9 ++++++++-
 1 file changed, 8 insertions(+), 1 deletion(-)

diff --git a/python/paddle/distributed/fleet/meta_optimizers/dygraph_optimizer/hybrid_parallel_optimizer.py b/python/paddle/distributed/fleet/meta_optimizers/dygraph_optimizer/hybrid_parallel_optimizer.py
index 61aa3d894f0..e609346b29e 100755
--- a/python/paddle/distributed/fleet/meta_optimizers/dygraph_optimizer/hybrid_parallel_optimizer.py
+++ b/python/paddle/distributed/fleet/meta_optimizers/dygraph_optimizer/hybrid_parallel_optimizer.py
@@ -200,7 +200,10 @@ class HybridParallelClipGrad:
             + paddle.to_tensor([1.0e-6], dtype=paddle.float32),
         )
         clip_var_fp16 = paddle.cast(clip_var, paddle.float16)
-        clip_var_bf16 = paddle.cast(clip_var, paddle.bfloat16)
+
+        # bf16 is not supported on XPU now
+        if not paddle.is_compiled_with_xpu():
+            clip_var_bf16 = paddle.cast(clip_var, paddle.bfloat16)
         for p, g in params_grads:
             if g is None:
                 continue
@@ -209,6 +212,10 @@ class HybridParallelClipGrad:
             if g.dtype == paddle.float16:
                 g.scale_(clip_var_fp16)
             elif g.dtype == paddle.bfloat16:
+                if paddle.is_compiled_with_xpu():
+                    raise NotImplementedError(
+                        "BF16 is not supported on XPU now"
+                    )
                 g.scale_(clip_var_bf16)
             else:
                 g.scale_(clip_var)
-- 
GitLab
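
For reference, below is a minimal standalone sketch of the pattern this patch introduces: the bf16 copy of the clip variable is only prepared when the build is not an XPU (KUNLUN) build, and encountering a bf16 gradient on XPU raises NotImplementedError. The helper name scale_grads_by_clip_var and its signature are hypothetical; the actual logic lives in the _dygraph_clip path of HybridParallelClipGrad shown in the hunks above.

import paddle


def scale_grads_by_clip_var(params_grads, clip_var):
    """Scale each gradient in place by clip_var, cast to the gradient's dtype.

    Illustrative sketch only; mirrors the guarded bf16 handling from the patch.
    """
    clip_var_fp16 = paddle.cast(clip_var, paddle.float16)

    # bf16 is not supported on XPU now, so only build the bf16 copy of
    # clip_var on non-XPU builds.
    if not paddle.is_compiled_with_xpu():
        clip_var_bf16 = paddle.cast(clip_var, paddle.bfloat16)

    for _p, g in params_grads:
        if g is None:
            continue
        if g.dtype == paddle.float16:
            g.scale_(clip_var_fp16)
        elif g.dtype == paddle.bfloat16:
            # On XPU a bf16 gradient cannot be handled, so fail loudly
            # instead of using an undefined clip_var_bf16.
            if paddle.is_compiled_with_xpu():
                raise NotImplementedError("BF16 is not supported on XPU now")
            g.scale_(clip_var_bf16)
        else:
            g.scale_(clip_var)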