From 56fded1b8c13175e07e7d2bc4ad976751ea36423 Mon Sep 17 00:00:00 2001
From: ronnywang
Date: Mon, 15 May 2023 19:24:00 +0800
Subject: [PATCH] [CustomDevice] add inference MP support, PART3 (#53703)

---
 paddle/fluid/pybind/custom_device_py.cc        | 13 ++++++++
 .../distributed/auto_parallel/cluster.py       | 27 +++++++++------
 .../auto_parallel/process_group.py             | 13 ++++++++
 .../meta_optimizers/sharding_optimizer.py      |  1 +
 .../fleet/utils/hybrid_parallel_util.py        |  1 +
 .../ps/utils/collective_transpiler.py          | 33 +++++++++++++++++++
 .../distributed/transpiler/collective.py       | 32 ++++++++++++++++++
 python/paddle/hapi/model.py                    | 31 +++++++++++++++++
 .../optimizer/distributed_fused_lamb.py        | 15 +++++++++
 9 files changed, 156 insertions(+), 10 deletions(-)

diff --git a/paddle/fluid/pybind/custom_device_py.cc b/paddle/fluid/pybind/custom_device_py.cc
index d3b4183f2f4..d138115c45e 100644
--- a/paddle/fluid/pybind/custom_device_py.cc
+++ b/paddle/fluid/pybind/custom_device_py.cc
@@ -29,6 +29,19 @@ namespace pybind {
 void BindCustomDevicePy(py::module *m_ptr) {
   auto &m = *m_ptr;
   // Bind Methods
+  m.def(
+      "_get_device_total_memory",
+      [](const std::string &device_type, int device_id) {
+        auto place = paddle::platform::CustomPlace(
+            device_type,
+            device_id == -1 ? phi::DeviceManager::GetDevice(device_type)
+                            : device_id);
+        size_t total = 0, free = 0;
+        phi::DeviceManager::MemoryStats(place, &total, &free);
+        return total;
+      },
+      py::arg("device_type"),
+      py::arg("device_id") = -1);
   m.def(
       "_get_current_custom_device_stream",
       [](const std::string &device_type, int device_id) {
diff --git a/python/paddle/distributed/auto_parallel/cluster.py b/python/paddle/distributed/auto_parallel/cluster.py
index 25324f48833..6a35894900a 100644
--- a/python/paddle/distributed/auto_parallel/cluster.py
+++ b/python/paddle/distributed/auto_parallel/cluster.py
@@ -883,17 +883,24 @@ def get_default_cluster(json_config=None):
     assert global_device_count % local_device_count == 0
     node_count = int(global_device_count) // local_device_count
 
-    gpu_info = paddle.device.cuda.get_device_properties()
-    assert gpu_info, "Auto parallel just runs on gpu now."
-
-    gpu_name = gpu_info.name
-    try:
-        re_result = re.split(r'[ , -]', gpu_name)
-        gpu_model = re_result[1]
-        memory = int(re_result[-1][:-2])
-    except:
-        memory = int(gpu_info.total_memory) // (1000**3)
+    if os.getenv("PADDLE_DISTRI_BACKEND", None) == "xccl":
+        gpu_name = os.getenv("PADDLE_XCCL_BACKEND", None)
         gpu_model = gpu_name
+        memory = int(
+            paddle.fluid.core._get_device_total_memory(gpu_name)
+        ) // (1000**3)
+    else:
+        gpu_info = paddle.device.cuda.get_device_properties()
+        assert gpu_info, "Auto parallel just runs on gpu now."
+
+        gpu_name = gpu_info.name
+        try:
+            re_result = re.split(r'[ , -]', gpu_name)
+            gpu_model = re_result[1]
+            memory = int(re_result[-1][:-2])
+        except:
+            memory = int(gpu_info.total_memory) // (1000**3)
+            gpu_model = gpu_name
 
     logger.info(
         "Node Count: {}, Local Device Size: {}, GPU Model: {}, GPU Memory: {}GB, World size: {}, EndPoint: {}.".format(
diff --git a/python/paddle/distributed/auto_parallel/process_group.py b/python/paddle/distributed/auto_parallel/process_group.py
index 8c300cbcd53..b5652353311 100644
--- a/python/paddle/distributed/auto_parallel/process_group.py
+++ b/python/paddle/distributed/auto_parallel/process_group.py
@@ -148,6 +148,11 @@ class ProcessGroup:
                 core.BKCLParallelContext(strategy, place).init_with_ring_id(
                     ring_id
                 )
+            elif genv.device_type in core.get_all_custom_device_type():
+                place = core.CustomPlace(genv.device_type, genv.device_id)
+                core.XCCLParallelContext(strategy, place).init_with_ring_id(
+                    ring_id
+                )
             else:
                 raise AssertionError('No CUDA device found')
 
@@ -162,6 +167,14 @@ class ProcessGroup:
                 paddle.set_device(
                     'xpu:%d' % paddle.distributed.ParallelEnv().dev_id
                 )
+            elif genv.device_type in core.get_all_custom_device_type():
+                paddle.set_device(
+                    '%s:%d'
+                    % (
+                        paddle.distributed.ParallelEnv().device_type,
+                        paddle.distributed.ParallelEnv().dev_id,
+                    ),
+                )
             tmp = (
                 paddle.to_tensor([1], dtype="int32")
                 if in_dygraph_mode()
diff --git a/python/paddle/distributed/fleet/meta_optimizers/sharding_optimizer.py b/python/paddle/distributed/fleet/meta_optimizers/sharding_optimizer.py
index 29be8a87fa7..59f51108442 100755
--- a/python/paddle/distributed/fleet/meta_optimizers/sharding_optimizer.py
+++ b/python/paddle/distributed/fleet/meta_optimizers/sharding_optimizer.py
@@ -1040,6 +1040,7 @@ class ShardingOptimizer(MetaOptimizerBase):
                 "c_calc_comm_stream",
                 "c_gen_nccl_id",
                 "c_gen_bkcl_id",
+                "c_gen_xccl_id",
                 "c_comm_init",
                 'send_v2',
                 'recv_v2',
diff --git a/python/paddle/distributed/fleet/utils/hybrid_parallel_util.py b/python/paddle/distributed/fleet/utils/hybrid_parallel_util.py
index 04093ebcb35..cd0455c8069 100644
--- a/python/paddle/distributed/fleet/utils/hybrid_parallel_util.py
+++ b/python/paddle/distributed/fleet/utils/hybrid_parallel_util.py
@@ -167,6 +167,7 @@ def broadcast_input_data(hcg, *inputs, **kwargs):
         place = paddle.CUDAPlace(dev_idx)
     elif dev in paddle.device.get_all_custom_device_type():
         place = paddle.CustomPlace(dev, dev_idx)
+        dev = 'custom'
     else:
         place = eval(f"paddle.{dev.upper()}Place")(dev_idx)
 
diff --git a/python/paddle/distributed/ps/utils/collective_transpiler.py b/python/paddle/distributed/ps/utils/collective_transpiler.py
index 5219301d9c7..f5278d05367 100644
--- a/python/paddle/distributed/ps/utils/collective_transpiler.py
+++ b/python/paddle/distributed/ps/utils/collective_transpiler.py
@@ -14,6 +14,7 @@
 
 import os
 
+import paddle
 from paddle.distributed.fleet.base.private_helper_function import (
     wait_server_ready,
 )
@@ -204,6 +205,38 @@ class Collective:
                     self.op_role_key: OpRole.Forward,
                 },
             )
+        elif (
+            paddle.distributed.ParallelEnv().device_type
+            in paddle.device.get_all_custom_device_type()
+        ):
+            xccl_id_var = block.create_var(
+                name=unique_name.generate('xccl_id'),
+                persistable=True,
+                type=core.VarDesc.VarType.RAW,
+            )
+            endpoint_to_index_map = {e: idx for idx, e in enumerate(endpoints)}
+            block.append_op(
+                type='c_gen_xccl_id',
+                inputs={},
+                outputs={'Out': xccl_id_var},
+                attrs={
+                    'rank': rank,
+                    'endpoint': current_endpoint,
+                    'other_endpoints': other_endpoints,
+                    self.op_role_key: OpRole.Forward,
+                },
+            )
+            block.append_op(
+                type='c_comm_init',
+                inputs={'X': xccl_id_var},
+                outputs={},
+                attrs={
+                    'nranks': nranks,
+                    'rank': rank,
+                    'ring_id': ring_id,
+                    self.op_role_key: OpRole.Forward,
+                },
+            )
 
     def _broadcast_params(self):
         block = self.startup_program.global_block()
diff --git a/python/paddle/distributed/transpiler/collective.py b/python/paddle/distributed/transpiler/collective.py
index 173f89b91e2..1fb1cf474a7 100644
--- a/python/paddle/distributed/transpiler/collective.py
+++ b/python/paddle/distributed/transpiler/collective.py
@@ -14,6 +14,7 @@
 
 import os
 
+import paddle
 from paddle.distributed.fleet.base.private_helper_function import (
     wait_server_ready,
 )
@@ -200,6 +201,37 @@ class Collective:
                     self.op_role_key: OpRole.Forward,
                 },
             )
+        elif (
+            paddle.distributed.ParallelEnv().device_type
+            in paddle.device.get_all_custom_device_type()
+        ):
+            xccl_id_var = block.create_var(
+                name=unique_name.generate('xccl_id'),
+                persistable=True,
+                type=core.VarDesc.VarType.RAW,
+            )
+            block.append_op(
+                type='c_gen_xccl_id',
+                inputs={},
+                outputs={'Out': xccl_id_var},
+                attrs={
+                    'rank': rank,
+                    'endpoint': current_endpoint,
+                    'other_endpoints': other_endpoints,
+                    self.op_role_key: OpRole.Forward,
+                },
+            )
+            block.append_op(
+                type='c_comm_init',
+                inputs={'X': xccl_id_var},
+                outputs={},
+                attrs={
+                    'nranks': nranks,
+                    'rank': rank,
+                    'ring_id': ring_id,
+                    self.op_role_key: OpRole.Forward,
+                },
+            )
 
     def _broadcast_params(self):
         block = self.startup_program.global_block()
diff --git a/python/paddle/hapi/model.py b/python/paddle/hapi/model.py
index 3f2db38474e..4169902401f 100644
--- a/python/paddle/hapi/model.py
+++ b/python/paddle/hapi/model.py
@@ -184,6 +184,37 @@ def init_communicator(
                 'ring_id': 0,
             },
         )
+    elif (
+        paddle.distributed.ParallelEnv().device_type
+        in paddle.device.get_all_custom_device_type()
+    ):
+        xccl_id_var = block.create_var(
+            name=fluid.unique_name.generate('xccl_id'),
+            persistable=True,
+            type=fluid.core.VarDesc.VarType.RAW,
+        )
+
+        block.append_op(
+            type='c_gen_xccl_id',
+            inputs={},
+            outputs={'Out': xccl_id_var},
+            attrs={
+                'rank': rank,
+                'endpoint': current_endpoint,
+                'other_endpoints': other_endpoints,
+            },
+        )
+
+        block.append_op(
+            type='c_comm_init',
+            inputs={'X': xccl_id_var},
+            outputs={},
+            attrs={
+                'nranks': nranks,
+                'rank': rank,
+                'ring_id': 0,
+            },
+        )
 
 
 def prepare_distributed_context(place=None):
diff --git a/python/paddle/incubate/optimizer/distributed_fused_lamb.py b/python/paddle/incubate/optimizer/distributed_fused_lamb.py
index 9a76db3be3d..a9d9e941d25 100644
--- a/python/paddle/incubate/optimizer/distributed_fused_lamb.py
+++ b/python/paddle/incubate/optimizer/distributed_fused_lamb.py
@@ -58,6 +58,21 @@ def init_communicator(block, rank, ranks, ring_id):
                 'ring_id': ring_id,
             },
         )
+    elif (
+        paddle.distributed.ParallelEnv().device_type
+        in paddle.device.get_all_custom_device_type()
+    ):
+        block.append_op(
+            type='c_gen_xccl_id',
+            inputs={},
+            outputs={'Out': comm_id_var},
+            attrs={
+                'rank': local_rank,
+                'endpoint': cur_ep,
+                'other_endpoints': other_eps,
+                'ring_id': ring_id,
+            },
+        )
     block.append_op(
         type='c_comm_init',
         inputs={'X': comm_id_var},
-- 
GitLab
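
A minimal usage sketch (not part of the patch): it shows how the new
_get_device_total_memory pybind helper and the xccl branch added to
get_default_cluster() are expected to fit together. The environment
variables PADDLE_DISTRI_BACKEND and PADDLE_XCCL_BACKEND are the ones the
patch reads; the device name "npu" is only an assumed example of a
registered CustomDevice plugin.

    import os

    import paddle

    # The patched cluster.py keys the custom-device path on these two
    # variables, e.g. PADDLE_DISTRI_BACKEND=xccl and PADDLE_XCCL_BACKEND=npu.
    if os.getenv("PADDLE_DISTRI_BACKEND") == "xccl":
        device_type = os.getenv("PADDLE_XCCL_BACKEND")  # assumed: "npu"
        # New binding added above: total device memory in bytes; device_id
        # defaults to -1, meaning the currently selected device.
        total_bytes = paddle.fluid.core._get_device_total_memory(device_type)
        print(f"{device_type} total memory: {int(total_bytes) // (1000**3)} GB")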