diff --git a/python/paddle/distributed/fleet/meta_optimizers/sharding_optimizer.py b/python/paddle/distributed/fleet/meta_optimizers/sharding_optimizer.py
index c94bd572f05878b97286536c6bc7b0f09db2f9dc..a76a70cdcab3df7f58c500168abcdab7546cf425 100755
--- a/python/paddle/distributed/fleet/meta_optimizers/sharding_optimizer.py
+++ b/python/paddle/distributed/fleet/meta_optimizers/sharding_optimizer.py
@@ -371,8 +371,11 @@ class ShardingOptimizer(MetaOptimizerBase):
 
         # FIXME(wangxi): mp should prune duplicated param_grads when calc
         # amp inf_var & clip global_norm_var
-        FP16Utils.sync_amp_check_nan_inf(main_block,
-                                         [self.mp_ring_id, self.pp_ring_id])
+        rings = [self.mp_ring_id, self.pp_ring_id]
+        # FIXME(wangxi): some problem with NPU found_finite, need sync with DP
+        if core.is_compiled_with_npu():
+            rings += [self.dp_ring_id]
+        FP16Utils.sync_amp_check_nan_inf(main_block, rings)
 
         gradientclip_helper = GradientClipHelper(None)
         gradientclip_helper.sync_global_norm(
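
For context (not part of the patch): the change builds the list of communication rings explicitly and, when Paddle is compiled for NPU, also includes the data-parallel ring so the AMP found_finite flag is synchronized across DP ranks as well. The sketch below is a minimal, standalone illustration of that idea only; Ring, allreduce_max, and sync_found_inf are hypothetical stand-ins and do not reflect how FP16Utils.sync_amp_check_nan_inf is actually implemented.

```python
# Conceptual sketch (assumptions, not Paddle's implementation): the nan/inf flag
# from the AMP check is local to each rank, so it has to be reduced (logical OR)
# across every ring whose ranks may disagree -- mp and pp, plus dp on NPU.

from dataclasses import dataclass
from typing import List


@dataclass
class Ring:
    """Hypothetical communication ring: the per-rank found_inf flags it connects."""
    name: str
    found_inf_per_rank: List[bool]


def allreduce_max(flags: List[bool]) -> bool:
    """Stand-in for an allreduce(max) over a ring: one overflow taints all ranks."""
    return any(flags)


def sync_found_inf(local_found_inf: bool, rings: List[Ring]) -> bool:
    """Combine the local flag with the reduced flag of every ring."""
    result = local_found_inf
    for ring in rings:
        result = result or allreduce_max(ring.found_inf_per_rank)
    return result


if __name__ == "__main__":
    mp = Ring("mp", [False, False])
    pp = Ring("pp", [False, True])   # one pipeline stage overflowed
    dp = Ring("dp", [False, False])

    compiled_with_npu = True         # assumption: mirrors core.is_compiled_with_npu()
    rings = [mp, pp] + ([dp] if compiled_with_npu else [])

    print(sync_found_inf(False, rings))  # True: the pp overflow reaches every rank
```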