diff --git a/paddle/fluid/pybind/pybind.cc b/paddle/fluid/pybind/pybind.cc index b3d9e22dba8d244f5e57267a527de0bbdc534996..d812203c041ac1f2caa5b4685bdf48b8766fe1a0 100644 --- a/paddle/fluid/pybind/pybind.cc +++ b/paddle/fluid/pybind/pybind.cc @@ -1454,7 +1454,7 @@ All parameter, weight, gradient are variables in Paddle. "number on your machine is %d", dev_id, platform::GetCUDADeviceCount(), platform::GetCUDADeviceCount()); - std::exit(-1); + // std::exit(-1); } } diff --git a/python/paddle/distributed/fleet/meta_optimizers/common.py b/python/paddle/distributed/fleet/meta_optimizers/common.py index 9befdfff04bef14e756f41681f838ea0959d3759..3df1f127b31649c6fdf7feac858ab37161ca1977 100644 --- a/python/paddle/distributed/fleet/meta_optimizers/common.py +++ b/python/paddle/distributed/fleet/meta_optimizers/common.py @@ -73,9 +73,30 @@ class CollectiveHelper(object): other_endpoints.remove(current_endpoint) block = program.global_block() if core.is_compiled_with_cuda(): - if rank == 0 and wait_port: - wait_server_ready(other_endpoints) - nccl_id_var = block.create_var( + if not wait_port: + temp_var = block.create_var( + name=unique_name.generate('temp_var'), + dtype=core.VarDesc.VarType.INT32, + persistable=False, + stop_gradient=True) + block.append_op( + type='fill_constant', + inputs={}, + outputs={'Out': [temp_var]}, + attrs={ + 'shape': [1], + 'dtype': temp_var.dtype, + 'value': 1, + 'force_cpu': False, + OP_ROLE_KEY: OpRole.Forward + }) + block.append_op( + type='c_allreduce_sum', + inputs={'X': [temp_var]}, + outputs={'Out': [temp_var]}, + attrs={'ring_id': 3, + OP_ROLE_KEY: OpRole.Forward}) + comm_id_var = block.create_var( name=unique_name.generate('nccl_id'), persistable=True, type=core.VarDesc.VarType.RAW) @@ -100,9 +121,7 @@ class CollectiveHelper(object): OP_ROLE_KEY: OpRole.Forward }) elif core.is_compiled_with_npu(): - endpoint_to_index_map = { - e: idx for idx, e in enumerate(endpoints) - } + endpoint_to_index_map = {e: idx for idx, e in enumerate(endpoints)} block.append_op( type='c_comm_init_hcom', inputs={},