diff --git a/python/paddle/distributed/fleet/meta_optimizers/sharding_optimizer.py b/python/paddle/distributed/fleet/meta_optimizers/sharding_optimizer.py index df775247c8c9e53bdc5c6314a81f3ea62d6148a5..1f1960b17007fdab5dc21397e92b87607898af72 100755 --- a/python/paddle/distributed/fleet/meta_optimizers/sharding_optimizer.py +++ b/python/paddle/distributed/fleet/meta_optimizers/sharding_optimizer.py @@ -379,8 +379,9 @@ class ShardingOptimizer(MetaOptimizerBase): 'w') as f: f.writelines(str(main_block.program)) - # GPU and NPU need to wait server ready - self._wait() + # GPU need to wait server ready, GPU and NPU is Layered connection + if not core.is_compiled_with_npu(): + self._wait() return optimize_ops, params_grads def _init_pair_comm(self, pair, ring_id):