diff --git a/python/paddle/distributed/fleet/meta_optimizers/sharding_optimizer.py b/python/paddle/distributed/fleet/meta_optimizers/sharding_optimizer.py
index 5c2f24054f835cda631d50c278fee9ab7d657777..ed16c2296f1e2d36746012c052e2438b97e6527e 100755
--- a/python/paddle/distributed/fleet/meta_optimizers/sharding_optimizer.py
+++ b/python/paddle/distributed/fleet/meta_optimizers/sharding_optimizer.py
@@ -369,8 +369,11 @@ class ShardingOptimizer(MetaOptimizerBase):
 
         # FIXME(wangxi): mp should prune duplicated param_grads when calc
         # amp inf_var & clip global_norm_var
-        FP16Utils.sync_amp_check_nan_inf(main_block,
-                                         [self.mp_ring_id, self.pp_ring_id])
+        rings = [self.mp_ring_id, self.pp_ring_id]
+        # FIXME(wangxi): some problem with NPU found_finite, need sync with DP
+        if core.is_compiled_with_npu():
+            rings += [self.dp_ring_id]
+        FP16Utils.sync_amp_check_nan_inf(main_block, rings)
 
         gradientclip_helper = GradientClipHelper(None)
         gradientclip_helper.sync_global_norm(
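
Note on the change: the new `rings` list selects the communicator groups over which the AMP found-inf flag is synchronized; MP and PP rings are always included, and the DP ring is added as well on NPU builds per the FIXME. The sketch below restates that selection logic in isolation so it can be read without the surrounding optimizer code. It is illustrative only: `is_npu_build`, `sync_over_ring`, and the ring-id arguments are hypothetical stand-ins, not the actual `FP16Utils.sync_amp_check_nan_inf` implementation (which operates on the static-graph `main_block`).

```python
# Illustrative sketch only -- not the FP16Utils implementation.
# All names below are hypothetical stand-ins for the real ring ids
# and the real in-graph sync primitive.

def select_nan_inf_sync_rings(mp_ring_id, pp_ring_id, dp_ring_id,
                              is_npu_build):
    """Choose the communicator rings the found-inf flag is synced over."""
    rings = [mp_ring_id, pp_ring_id]
    # NPU builds also need the flag synced across the data-parallel ring
    # (see the FIXME(wangxi) comment in the diff).
    if is_npu_build:
        rings.append(dp_ring_id)
    return rings


def sync_found_inf(found_inf, rings, sync_over_ring):
    """Reduce the found-inf flag over every selected ring using a
    caller-supplied primitive (e.g. a max-allreduce per ring)."""
    for ring_id in rings:
        found_inf = sync_over_ring(found_inf, ring_id)
    return found_inf


if __name__ == "__main__":
    # Toy usage: booleans stand in for the flag, logical-or for the reduce.
    rings = select_nan_inf_sync_rings(mp_ring_id=0, pp_ring_id=1,
                                      dp_ring_id=2, is_npu_build=True)
    print(rings)  # [0, 1, 2] on an NPU build, [0, 1] otherwise
```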