From cbb144194cd5c659c3c2055e20e9d9fb5e69824c Mon Sep 17 00:00:00 2001
From: Roc <30228238+sljlp@users.noreply.github.com>
Date: Mon, 12 Dec 2022 10:13:35 +0800
Subject: [PATCH] support sharding in fp16 on xpu, (#48897)

* support sharding in fp16 on xpu, change reduce_max to reduce_sum for found nan or inf

* update
---
 .../fleet/meta_parallel/sharding/group_sharded_utils.py     | 6 ++++--
 .../fleet/test_imperative_auto_mixed_precision.py           | 2 +-
 .../fleet/test_imperative_auto_mixed_precision_for_eager.py | 2 +-
 3 files changed, 6 insertions(+), 4 deletions(-)

diff --git a/python/paddle/distributed/fleet/meta_parallel/sharding/group_sharded_utils.py b/python/paddle/distributed/fleet/meta_parallel/sharding/group_sharded_utils.py
index c12381c894e..d845f3b78c6 100644
--- a/python/paddle/distributed/fleet/meta_parallel/sharding/group_sharded_utils.py
+++ b/python/paddle/distributed/fleet/meta_parallel/sharding/group_sharded_utils.py
@@ -220,7 +220,8 @@ def GroupShardedScaler(scaler):
         temp_found_inf_fp16 = to_variable(np.array([0]).astype(np.bool_))
         temp_found_inf_fp32 = to_variable(np.array([0]).astype(np.bool_))

-        device = "cpu" if optimizer.offload else "gpu"
+        device = paddle.get_device().split(":")[0]
+        device = "cpu" if optimizer.offload else device
         dev_id = (
             0 if device == "cpu" else int(paddle.get_device().split(":")[1])
         )
@@ -245,8 +246,9 @@ def GroupShardedScaler(scaler):
         is_found_inf = paddle.to_tensor([self._found_inf], dtype="int32")

         paddle.distributed.all_reduce(
-            is_found_inf, op=paddle.distributed.ReduceOp.MAX, group=None
+            is_found_inf, op=paddle.distributed.ReduceOp.SUM, group=None
         )
+        self._found_inf = is_found_inf.numpy()[0]

     scaler._unscale = MethodType(unscale_method, scaler)

diff --git a/python/paddle/fluid/tests/unittests/collective/fleet/test_imperative_auto_mixed_precision.py b/python/paddle/fluid/tests/unittests/collective/fleet/test_imperative_auto_mixed_precision.py
index d30466d9fc9..2af1f4adec9 100644
--- a/python/paddle/fluid/tests/unittests/collective/fleet/test_imperative_auto_mixed_precision.py
+++ b/python/paddle/fluid/tests/unittests/collective/fleet/test_imperative_auto_mixed_precision.py
@@ -344,7 +344,7 @@ class TestAmpScaler(unittest.TestCase):
         scaled_loss = scaler.scale(loss)
         scaled_loss.backward()
         optimize_ops, params_grads = scaler.minimize(optimizer, scaled_loss)
-        self.assertEqual(scaler._found_inf.numpy() == 1, True)
+        self.assertEqual(scaler._found_inf.numpy() >= 1, True)

         for param in model.parameters():
             # param not update when tensor contains nan or inf
diff --git a/python/paddle/fluid/tests/unittests/collective/fleet/test_imperative_auto_mixed_precision_for_eager.py b/python/paddle/fluid/tests/unittests/collective/fleet/test_imperative_auto_mixed_precision_for_eager.py
index f688d28b856..b5d36dfebaa 100644
--- a/python/paddle/fluid/tests/unittests/collective/fleet/test_imperative_auto_mixed_precision_for_eager.py
+++ b/python/paddle/fluid/tests/unittests/collective/fleet/test_imperative_auto_mixed_precision_for_eager.py
@@ -343,7 +343,7 @@ class TestAmpScaler(unittest.TestCase):
         scaled_loss = scaler.scale(loss)
         scaled_loss.backward()
         optimize_ops, params_grads = scaler.minimize(optimizer, scaled_loss)
-        self.assertEqual(scaler._found_inf.numpy() == 1, True)
+        self.assertEqual(scaler._found_inf.numpy() >= 1, True)

         for param in model.parameters():
             # param not update when tensor contains nan or inf
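
Note on the change above: the first hunk derives the device prefix from paddle.get_device() (e.g. "gpu:0" -> "gpu", "xpu:0" -> "xpu") instead of hard-coding "gpu", so the sharded scaler also runs on XPU; the second hunk aggregates the per-rank overflow flag with ReduceOp.SUM instead of ReduceOp.MAX and writes the reduced value back to self._found_inf. The snippet below is a minimal sketch of that aggregation pattern, not the patched GroupShardedScaler itself; it assumes a paddle.distributed environment has been launched (e.g. via paddle.distributed.launch), and `local_found_inf` is an illustrative stand-in for the scaler's internal flag, not a name from the patch.

# Sketch of the found-inf aggregation this patch switches to: every rank
# encodes its local overflow flag as an int32, the flags are summed across
# ranks, and any nonzero total means at least one rank saw nan/inf.
import paddle
import paddle.distributed as dist

dist.init_parallel_env()

local_found_inf = False  # whatever this rank's unscale pass detected (illustrative)
is_found_inf = paddle.to_tensor([int(local_found_inf)], dtype="int32")

# The patch uses SUM rather than MAX, presumably because int32 summation is
# well supported across cpu/gpu/xpu; the global decision is simply "total > 0".
dist.all_reduce(is_found_inf, op=dist.ReduceOp.SUM)
global_found_inf = bool(is_found_inf.numpy()[0] > 0)

# Device-prefix derivation from the first hunk: paddle.get_device() returns
# strings like "gpu:0", "xpu:0", or "cpu", so splitting on ":" yields the
# placement without assuming CUDA.
device = paddle.get_device().split(":")[0]

The relaxed unit-test assertions (checking scaler._found_inf.numpy() >= 1 instead of == 1) are consistent with this summed flag, which can exceed 1 when more than one rank overflows.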