diff --git a/paddle/fluid/pybind/custom_device_py.cc b/paddle/fluid/pybind/custom_device_py.cc
index d3b4183f2f4f0376526dc64268898cb141fd5dbc..d138115c45ee0278306950a4be0abdd106a6ad61 100644
--- a/paddle/fluid/pybind/custom_device_py.cc
+++ b/paddle/fluid/pybind/custom_device_py.cc
@@ -29,6 +29,19 @@ namespace pybind {
 void BindCustomDevicePy(py::module *m_ptr) {
   auto &m = *m_ptr;
   // Bind Methods
+  m.def(
+      "_get_device_total_memory",
+      [](const std::string &device_type, int device_id) {
+        auto place = paddle::platform::CustomPlace(
+            device_type,
+            device_id == -1 ? phi::DeviceManager::GetDevice(device_type)
+                            : device_id);
+        size_t total = 0, free = 0;
+        phi::DeviceManager::MemoryStats(place, &total, &free);
+        return total;
+      },
+      py::arg("device_type"),
+      py::arg("device_id") = -1);
   m.def(
       "_get_current_custom_device_stream",
       [](const std::string &device_type, int device_id) {
diff --git a/python/paddle/distributed/auto_parallel/cluster.py b/python/paddle/distributed/auto_parallel/cluster.py
index 25324f4883334e74dd7a9893ca6afdad3cb0ed5d..6a35894900afe24db33a7457e23a1b772308c961 100644
--- a/python/paddle/distributed/auto_parallel/cluster.py
+++ b/python/paddle/distributed/auto_parallel/cluster.py
@@ -883,17 +883,24 @@ def get_default_cluster(json_config=None):
         assert global_device_count % local_device_count == 0
         node_count = int(global_device_count) // local_device_count
 
-    gpu_info = paddle.device.cuda.get_device_properties()
-    assert gpu_info, "Auto parallel just runs on gpu now."
-
-    gpu_name = gpu_info.name
-    try:
-        re_result = re.split(r'[ , -]', gpu_name)
-        gpu_model = re_result[1]
-        memory = int(re_result[-1][:-2])
-    except:
-        memory = int(gpu_info.total_memory) // (1000**3)
+    if os.getenv("PADDLE_DISTRI_BACKEND", None) == "xccl":
+        gpu_name = os.getenv("PADDLE_XCCL_BACKEND", None)
         gpu_model = gpu_name
+        memory = int(
+            paddle.fluid.core._get_device_total_memory(gpu_name)
+        ) // (1000**3)
+    else:
+        gpu_info = paddle.device.cuda.get_device_properties()
+        assert gpu_info, "Auto parallel just runs on gpu now."
+
+        gpu_name = gpu_info.name
+        try:
+            re_result = re.split(r'[ , -]', gpu_name)
+            gpu_model = re_result[1]
+            memory = int(re_result[-1][:-2])
+        except:
+            memory = int(gpu_info.total_memory) // (1000**3)
+            gpu_model = gpu_name
 
     logger.info(
         "Node Count: {}, Local Device Size: {}, GPU Model: {}, GPU Memory: {}GB, World size: {}, EndPoint: {}.".format(
diff --git a/python/paddle/distributed/auto_parallel/process_group.py b/python/paddle/distributed/auto_parallel/process_group.py
index 8c300cbcd53b63a0d3d5c579e130f66933c4b9c5..b5652353311a275bf295036f5a59b4071cf6f297 100644
--- a/python/paddle/distributed/auto_parallel/process_group.py
+++ b/python/paddle/distributed/auto_parallel/process_group.py
@@ -148,6 +148,11 @@ class ProcessGroup:
                 core.BKCLParallelContext(strategy, place).init_with_ring_id(
                     ring_id
                 )
+            elif genv.device_type in core.get_all_custom_device_type():
+                place = core.CustomPlace(genv.device_type, genv.device_id)
+                core.XCCLParallelContext(strategy, place).init_with_ring_id(
+                    ring_id
+                )
             else:
                 raise AssertionError('No CUDA device found')
 
@@ -162,6 +167,14 @@ class ProcessGroup:
                 paddle.set_device(
                     'xpu:%d' % paddle.distributed.ParallelEnv().dev_id
                 )
+            elif genv.device_type in core.get_all_custom_device_type():
+                paddle.set_device(
+                    '%s:%d'
+                    % (
+                        paddle.distributed.ParallelEnv().device_type,
+                        paddle.distributed.ParallelEnv().dev_id,
+                    ),
+                )
             tmp = (
                 paddle.to_tensor([1], dtype="int32")
                 if in_dygraph_mode()
diff --git a/python/paddle/distributed/fleet/meta_optimizers/sharding_optimizer.py b/python/paddle/distributed/fleet/meta_optimizers/sharding_optimizer.py
index 29be8a87fa7d8ddaefb2d33ae1aac7a805e4dbe2..59f511084424f37e639c044bd0ccc297a75b3a01 100755
--- a/python/paddle/distributed/fleet/meta_optimizers/sharding_optimizer.py
+++ b/python/paddle/distributed/fleet/meta_optimizers/sharding_optimizer.py
@@ -1040,6 +1040,7 @@ class ShardingOptimizer(MetaOptimizerBase):
                 "c_calc_comm_stream",
                 "c_gen_nccl_id",
                 "c_gen_bkcl_id",
+                "c_gen_xccl_id",
                 "c_comm_init",
                 'send_v2',
                 'recv_v2',
diff --git a/python/paddle/distributed/fleet/utils/hybrid_parallel_util.py b/python/paddle/distributed/fleet/utils/hybrid_parallel_util.py
index 04093ebcb35fd9a52d9d167a564163875ce61b5b..cd0455c80697ba793e658892bbdbe14d50a4b0d9 100644
--- a/python/paddle/distributed/fleet/utils/hybrid_parallel_util.py
+++ b/python/paddle/distributed/fleet/utils/hybrid_parallel_util.py
@@ -167,6 +167,7 @@ def broadcast_input_data(hcg, *inputs, **kwargs):
         place = paddle.CUDAPlace(dev_idx)
     elif dev in paddle.device.get_all_custom_device_type():
         place = paddle.CustomPlace(dev, dev_idx)
+        dev = 'custom'
     else:
         place = eval(f"paddle.{dev.upper()}Place")(dev_idx)
 
diff --git a/python/paddle/distributed/ps/utils/collective_transpiler.py b/python/paddle/distributed/ps/utils/collective_transpiler.py
index 5219301d9c73829bd2d0b12a474ec43ad5beab78..f5278d05367fd4bcfa254d67c2479c84f3cdbce1 100644
--- a/python/paddle/distributed/ps/utils/collective_transpiler.py
+++ b/python/paddle/distributed/ps/utils/collective_transpiler.py
@@ -14,6 +14,7 @@
 
 import os
 
+import paddle
 from paddle.distributed.fleet.base.private_helper_function import (
     wait_server_ready,
 )
@@ -204,6 +205,38 @@ class Collective:
                     self.op_role_key: OpRole.Forward,
                 },
             )
+        elif (
+            paddle.distributed.ParallelEnv().device_type
+            in paddle.device.get_all_custom_device_type()
+        ):
+            xccl_id_var = block.create_var(
+                name=unique_name.generate('xccl_id'),
+                persistable=True,
+                type=core.VarDesc.VarType.RAW,
+            )
+            endpoint_to_index_map = {e: idx for idx, e in enumerate(endpoints)}
+            block.append_op(
+                type='c_gen_xccl_id',
+                inputs={},
+                outputs={'Out': xccl_id_var},
+                attrs={
+                    'rank': rank,
+                    'endpoint': current_endpoint,
+                    'other_endpoints': other_endpoints,
+                    self.op_role_key: OpRole.Forward,
+                },
+            )
+            block.append_op(
+                type='c_comm_init',
+                inputs={'X': xccl_id_var},
+                outputs={},
+                attrs={
+                    'nranks': nranks,
+                    'rank': rank,
+                    'ring_id': ring_id,
+                    self.op_role_key: OpRole.Forward,
+                },
+            )
 
     def _broadcast_params(self):
         block = self.startup_program.global_block()
diff --git a/python/paddle/distributed/transpiler/collective.py b/python/paddle/distributed/transpiler/collective.py
index 173f89b91e24b70620cd0bd56c062b0781c12c6a..1fb1cf474a70173c17b3ec2d8ac2c1dd10492908 100644
--- a/python/paddle/distributed/transpiler/collective.py
+++ b/python/paddle/distributed/transpiler/collective.py
@@ -14,6 +14,7 @@
 
 import os
 
+import paddle
 from paddle.distributed.fleet.base.private_helper_function import (
     wait_server_ready,
 )
@@ -200,6 +201,37 @@ class Collective:
                     self.op_role_key: OpRole.Forward,
                 },
             )
+        elif (
+            paddle.distributed.ParallelEnv().device_type
+            in paddle.device.get_all_custom_device_type()
+        ):
+            xccl_id_var = block.create_var(
+                name=unique_name.generate('xccl_id'),
+                persistable=True,
+                type=core.VarDesc.VarType.RAW,
+            )
+            block.append_op(
+                type='c_gen_xccl_id',
+                inputs={},
+                outputs={'Out': xccl_id_var},
+                attrs={
+                    'rank': rank,
+                    'endpoint': current_endpoint,
+                    'other_endpoints': other_endpoints,
+                    self.op_role_key: OpRole.Forward,
+                },
+            )
+            block.append_op(
+                type='c_comm_init',
+                inputs={'X': xccl_id_var},
+                outputs={},
+                attrs={
+                    'nranks': nranks,
+                    'rank': rank,
+                    'ring_id': ring_id,
+                    self.op_role_key: OpRole.Forward,
+                },
+            )
 
     def _broadcast_params(self):
         block = self.startup_program.global_block()
diff --git a/python/paddle/hapi/model.py b/python/paddle/hapi/model.py
index 3f2db38474eb3c257093fc727fb56e2df8f00dab..4169902401fcf5ac22bdbb3b68cf2b85e1e05410 100644
--- a/python/paddle/hapi/model.py
+++ b/python/paddle/hapi/model.py
@@ -184,6 +184,37 @@ def init_communicator(
                 'ring_id': 0,
             },
         )
+    elif (
+        paddle.distributed.ParallelEnv().device_type
+        in paddle.device.get_all_custom_device_type()
+    ):
+        xccl_id_var = block.create_var(
+            name=fluid.unique_name.generate('xccl_id'),
+            persistable=True,
+            type=fluid.core.VarDesc.VarType.RAW,
+        )
+
+        block.append_op(
+            type='c_gen_xccl_id',
+            inputs={},
+            outputs={'Out': xccl_id_var},
+            attrs={
+                'rank': rank,
+                'endpoint': current_endpoint,
+                'other_endpoints': other_endpoints,
+            },
+        )
+
+        block.append_op(
+            type='c_comm_init',
+            inputs={'X': xccl_id_var},
+            outputs={},
+            attrs={
+                'nranks': nranks,
+                'rank': rank,
+                'ring_id': 0,
+            },
+        )
 
 
 def prepare_distributed_context(place=None):
diff --git a/python/paddle/incubate/optimizer/distributed_fused_lamb.py b/python/paddle/incubate/optimizer/distributed_fused_lamb.py
index 9a76db3be3d9256df5de98d0bc199966cfcc65ce..a9d9e941d25cc943708ff91c8cd979d50327e6bc 100644
--- a/python/paddle/incubate/optimizer/distributed_fused_lamb.py
+++ b/python/paddle/incubate/optimizer/distributed_fused_lamb.py
@@ -58,6 +58,21 @@ def init_communicator(block, rank, ranks, ring_id):
                 'ring_id': ring_id,
             },
         )
+    elif (
+        paddle.distributed.ParallelEnv().device_type
+        in paddle.device.get_all_custom_device_type()
+    ):
+        block.append_op(
+            type='c_gen_xccl_id',
+            inputs={},
+            outputs={'Out': comm_id_var},
+            attrs={
+                'rank': local_rank,
+                'endpoint': cur_ep,
+                'other_endpoints': other_eps,
+                'ring_id': ring_id,
+            },
+        )
     block.append_op(
         type='c_comm_init',
         inputs={'X': comm_id_var},
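
Usage sketch (not part of the patch): a minimal way to exercise the new _get_device_total_memory binding from Python. It assumes a CustomDevice plugin is already registered; the variable names below are illustrative only.

    import paddle

    # Empty when no CustomDevice plugin is loaded; guard against None as well.
    custom_types = paddle.device.get_all_custom_device_type() or []
    for dev_type in custom_types:
        # device_id defaults to -1, which resolves to the currently selected device.
        total_bytes = paddle.fluid.core._get_device_total_memory(dev_type)
        total_bytes_dev0 = paddle.fluid.core._get_device_total_memory(dev_type, 0)
        print(f"{dev_type}: {total_bytes // (1000 ** 3)} GB total")

With PADDLE_DISTRI_BACKEND=xccl and PADDLE_XCCL_BACKEND set to the registered device type, get_default_cluster() takes the new branch in cluster.py and derives the per-device memory from this binding instead of paddle.device.cuda.get_device_properties().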