From 25409dccfaf493c218b653a8b860adf7f99420fe Mon Sep 17 00:00:00 2001
From: ronnywang
Date: Thu, 8 Jun 2023 09:52:21 +0800
Subject: [PATCH] [CustomDevice] add sharding support (#54384)

* [CustomDevice] add sharding support

* update
---
 .../collective/process_group_custom.cc        | 37 ++++++++++++
 .../collective/process_group_custom.h         |  6 ++
 paddle/fluid/pybind/custom_device_py.cc       |  4 ++
 .../group_sharded_optimizer_stage2.py         | 57 ++++++++++++++-----
 .../sharding/group_sharded_stage3.py          | 32 ++++++++---
 .../sharding/group_sharded_storage.py         | 24 +++++---
 .../sharding/group_sharded_utils.py           | 10 +++-
 7 files changed, 140 insertions(+), 30 deletions(-)

diff --git a/paddle/fluid/distributed/collective/process_group_custom.cc b/paddle/fluid/distributed/collective/process_group_custom.cc
index 1e4d1df337b..1e80faaff0e 100644
--- a/paddle/fluid/distributed/collective/process_group_custom.cc
+++ b/paddle/fluid/distributed/collective/process_group_custom.cc
@@ -722,6 +722,43 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupCustom::Send(
                       false,
                       false);
 }
+
+std::shared_ptr<ProcessGroup::Task> ProcessGroupCustom::Reduce(
+    phi::DenseTensor* out_tensor,
+    const phi::DenseTensor& in_tensor,
+    const ReduceOptions& opts,
+    bool sync_op,
+    bool use_calc_stream) {
+  phi::distributed::CommStaticCheck::SameShape(*out_tensor,
+                                               in_tensor,
+                                               /*dst_rank*/ opts.root_rank,
+                                               /*cur_rank*/ rank_,
+                                               size_,
+                                               phi::AllocationType::CUSTOM);
+  std::vector<phi::DenseTensor> in_wrapper{in_tensor};
+  std::vector<phi::DenseTensor> out_wrapper{*out_tensor};
+  return Collective(
+      in_wrapper,
+      out_wrapper,
+      [&](phi::DenseTensor& input,
+          phi::DenseTensor& output,
+          phi::ccl::CCLComm comm,
+          const phi::stream::Stream& stream) {
+        phi::DeviceManager::CCLReduce(device_type_,
+                                      input.data(),
+                                      output.data(),
+                                      input.numel(),
+                                      phi::ccl::ToCCLDataType(input.dtype()),
+                                      ToCustomCCLRedType(opts.reduce_op),
+                                      opts.root_rank,
+                                      comm,
+                                      stream);
+      },
+      CommType::REDUCE,
+      sync_op,
+      use_calc_stream);
+}

 std::shared_ptr<ProcessGroupCustom> ProcessGroupCustom::CreateProcessGroupCustom(
     const std::shared_ptr<phi::distributed::Store>& store,
diff --git a/paddle/fluid/distributed/collective/process_group_custom.h b/paddle/fluid/distributed/collective/process_group_custom.h
index 8a0ac6c0a4b..77a83411dd6 100644
--- a/paddle/fluid/distributed/collective/process_group_custom.h
+++ b/paddle/fluid/distributed/collective/process_group_custom.h
@@ -163,6 +163,12 @@ class ProcessGroupCustom : public ProcessGroupWithStream {
   std::shared_ptr<ProcessGroup::Task> Recv(
       std::vector<phi::DenseTensor>& tensors, int src_rank) override;

+  std::shared_ptr<ProcessGroup::Task> Reduce(phi::DenseTensor* out_tensor,
+                                             const phi::DenseTensor& in_tensor,
+                                             const ReduceOptions& opts,
+                                             bool sync_op,
+                                             bool use_calc_stream) override;
+
 protected:
  virtual std::shared_ptr<ProcessGroupCustom::CustomTask> CreateTask(
      std::vector<Place> places,
diff --git a/paddle/fluid/pybind/custom_device_py.cc b/paddle/fluid/pybind/custom_device_py.cc
index 42addb0445c..0f0caa7fcdd 100644
--- a/paddle/fluid/pybind/custom_device_py.cc
+++ b/paddle/fluid/pybind/custom_device_py.cc
@@ -29,6 +29,10 @@ namespace pybind {
 void BindCustomDevicePy(py::module *m_ptr) {
   auto &m = *m_ptr;
   // Bind Methods
+  m.def("_get_device_min_chunk_size", [](const std::string &device_type) {
+    auto place = paddle::platform::CustomPlace(device_type);
+    return phi::DeviceManager::GetMinChunkSize(place);
+  });
   m.def(
       "_get_device_total_memory",
       [](const std::string &device_type, int device_id) {
diff --git a/python/paddle/distributed/fleet/meta_parallel/sharding/group_sharded_optimizer_stage2.py b/python/paddle/distributed/fleet/meta_parallel/sharding/group_sharded_optimizer_stage2.py
index 16af846ebe2..06b94632b74 100644
--- a/python/paddle/distributed/fleet/meta_parallel/sharding/group_sharded_optimizer_stage2.py
+++ b/python/paddle/distributed/fleet/meta_parallel/sharding/group_sharded_optimizer_stage2.py
@@ -82,8 +82,10 @@ class GroupShardedOptimizerStage2(Optimizer):
         super().__init__(learning_rate=optim._learning_rate, parameters=params)

         assert (
-            core.is_compiled_with_cuda() or core.is_compiled_with_xpu()
-        ), "Only GPU and XPU is supported now"
+            core.is_compiled_with_cuda()
+            or core.is_compiled_with_xpu()
+            or (device in core.get_all_custom_device_type())
+        ), "Only GPU and XPU and CustomDevice is supported now"

         # Segmentation information
         self._dtype_rank_params = (
@@ -371,6 +373,13 @@ class GroupShardedOptimizerStage2(Optimizer):
         Count the memory size of the parameters corresponding to rank under the corresponding dtype.
         """
         # CUDA alignment 256 bytes
+        if self._default_device in core.get_all_custom_device_type():
+            device_alignment = core.libpaddle._get_device_min_chunk_size(
+                self._default_device
+            )
+        else:
+            device_alignment = alignment[self._default_device]
+
         if len(self._rank_buffer_size) == 0:
             for dtype in self.dtype_rank_params.keys():
                 if dtype not in self._rank_buffer_size.keys():
@@ -384,11 +393,11 @@ class GroupShardedOptimizerStage2(Optimizer):
                         if not param.trainable:
                             continue
                         size = param._numel() * align[dtype]
-                        remaining = size % alignment[self._default_device]
+                        remaining = size % device_alignment
                         ali = (
                             0
                             if remaining == 0
-                            else alignment[self._default_device] - remaining
+                            else device_alignment - remaining
                         )
                         align_ = ali // align[dtype]
                         self._rank_buffer_size[dtype][dst_rank] += (
@@ -439,14 +448,17 @@ class GroupShardedOptimizerStage2(Optimizer):
         if self.offload:
             self._optim._master_weights = self._master_params
             cpu_master_params = list(self._master_params.values())
+            if self._default_device in core.get_all_custom_device_type():
+                device_alignment = core.libpaddle._get_device_min_chunk_size(
+                    self._default_device
+                )
+            else:
+                device_alignment = alignment[self._default_device]
+
             for param in cpu_master_params:
                 size = param._numel() * align[Type.fp32.value]
-                remaining = size % alignment[self.offload_device]
-                ali = (
-                    0
-                    if remaining == 0
-                    else alignment[self.offload_device] - remaining
-                )
+                remaining = size % device_alignment
+                ali = 0 if remaining == 0 else device_alignment - remaining
                 align_ = ali // align[Type.fp32.value]
                 self.offload_buffer_size += param._numel() + align_
                 self.offload_param2align[param.name] = align_
@@ -528,11 +540,26 @@ class GroupShardedOptimizerStage2(Optimizer):

             for param in self._local_params:
                 if param.name in self._master_params.keys():
-                    param.set_value(
-                        self._master_params[param.name]
-                        .cuda(self.dev_id)
-                        .cast(dtype=param.dtype)
-                    )
+                    if (
+                        self._default_device
+                        in core.get_all_custom_device_type()
+                    ):
+                        param.set_value(
+                            self._master_params[param.name]
+                            ._copy_to(
+                                paddle.CustomPlace(
+                                    self._default_device, self.dev_id
+                                ),
+                                True,
+                            )
+                            .cast(dtype=param.dtype)
+                        )
+                    else:
+                        param.set_value(
+                            self._master_params[param.name]
+                            .cuda(self.dev_id)
+                            .cast(dtype=param.dtype)
+                        )
         else:
             self._optim.step()
diff --git a/python/paddle/distributed/fleet/meta_parallel/sharding/group_sharded_stage3.py b/python/paddle/distributed/fleet/meta_parallel/sharding/group_sharded_stage3.py
index b1c47de593b..f85b737a3e5 100644
--- a/python/paddle/distributed/fleet/meta_parallel/sharding/group_sharded_stage3.py
+++ b/python/paddle/distributed/fleet/meta_parallel/sharding/group_sharded_stage3.py
@@ -89,7 +89,10 @@ class GroupShardedStage3(nn.Layer):
         super().__init__()

         # Default configs
-        assert core.is_compiled_with_cuda(), "Only support CUDA."
+        assert core.is_compiled_with_cuda() or (
+            device in core.get_all_custom_device_type()
+        ), "Only support CUDA / CustomDevice."
+
         self._layer = layer
         self._default_device = device
         self.__sync_buffers = sync_buffers
@@ -243,7 +246,15 @@ class GroupShardedStage3(nn.Layer):
         else:
             for param in list(self._unslice_params):
                 param.clear_gradient(False)
-                tmp_var = param.cuda(DEV_ID)
+                if (
+                    self._default_device
+                    in paddle.device.get_all_custom_device_type()
+                ):
+                    tmp_var = param._copy_to(
+                        paddle.CustomPlace(self._default_device, DEV_ID), True
+                    )
+                else:
+                    tmp_var = param.cuda(DEV_ID)

                 if (
                     tmp_var.dtype == Type.fp32.value
@@ -718,10 +729,14 @@ class GroupShardedStage3(nn.Layer):
     def _param2align(self, param):
         # CUDA alignment 256 bytes
         size = param._numel() * align[param.dtype]
-        remaining = size % alignment[self._default_device]
-        ali = (
-            0 if remaining == 0 else alignment[self._default_device] - remaining
-        )
+        if self._default_device in core.get_all_custom_device_type():
+            device_alignment = core.libpaddle._get_device_min_chunk_size(
+                self._default_device
+            )
+        else:
+            device_alignment = alignment[self._default_device]
+        remaining = size % device_alignment
+        ali = 0 if remaining == 0 else device_alignment - remaining
         align_ = ali // align[param.dtype]
         return align_
@@ -1095,7 +1110,10 @@ def _device2cpu(trans_param, convert_dtype=False):


 def _cpu2device(param):
-    tmp_p = param.fw_storage.cuda(DEV_ID)
+    if DEV in paddle.device.get_all_custom_device_type():
+        tmp_p = param.fw_storage._copy_to(paddle.CustomPlace(DEV, DEV_ID), True)
+    else:
+        tmp_p = param.fw_storage.cuda(DEV_ID)
     if (
         tmp_p.dtype == Type.fp32.value
         and param2dtype[param.name] == Type.fp16.value
diff --git a/python/paddle/distributed/fleet/meta_parallel/sharding/group_sharded_storage.py b/python/paddle/distributed/fleet/meta_parallel/sharding/group_sharded_storage.py
index 44c5995acc7..fb86a27072c 100644
--- a/python/paddle/distributed/fleet/meta_parallel/sharding/group_sharded_storage.py
+++ b/python/paddle/distributed/fleet/meta_parallel/sharding/group_sharded_storage.py
@@ -76,11 +76,16 @@ class InternalStorage:
         ), "Conversion type is not supported now"

         if self._device != device:
-            tmp_buffer = (
-                cvt_to_device(self.buffer, self.dev_id)
-                if device in ["gpu", "xpu"]
-                else self.buffer.cpu()
-            )
+            if device in paddle.device.get_all_custom_device_type():
+                tmp_buffer = self.buffer._copy_to(
+                    paddle.CustomPlace(device, self.dev_id), True
+                )
+            else:
+                tmp_buffer = (
+                    cvt_to_device(self.buffer, self.dev_id)
+                    if device in ["gpu", "xpu"]
+                    else self.buffer.cpu()
+                )
             for param in self._params:
                 param.clear_gradient(False)
@@ -133,8 +138,13 @@ class ParamStorage(InternalStorage):
             cpu_param_shape.append(p_shape)

         if convert_gpu:
-            # buffer convert from cpu to cuda
-            self.buffer = cvt_to_device(self.buffer, self.dev_id)
+            if self._device in paddle.device.get_all_custom_device_type():
+                self.buffer = self.buffer._copy_to(
+                    paddle.CustomPlace(self._device, self.dev_id), True
+                )
+            else:
+                # buffer convert from cpu to cuda
+                self.buffer = cvt_to_device(self.buffer, self.dev_id)

         self._fill = 0
diff --git a/python/paddle/distributed/fleet/meta_parallel/sharding/group_sharded_utils.py b/python/paddle/distributed/fleet/meta_parallel/sharding/group_sharded_utils.py
index 4c47cbfcc1d..8d4bc45a954 100644
--- a/python/paddle/distributed/fleet/meta_parallel/sharding/group_sharded_utils.py
+++ b/python/paddle/distributed/fleet/meta_parallel/sharding/group_sharded_utils.py
@@ -162,8 +162,14 @@ class GroupShardedClipGrad:

         # add all reduce to get global norm of distributed params_and_grads
         dev_id = int(self._device.split(":")[1])
+        dev_type = self._device.split(':')[0]
         if paddle.device.get_device() == "cpu":
-            global_norm_var = global_norm_var.cuda(dev_id)
+            if dev_type in paddle.device.get_all_custom_device_type():
+                global_norm_var = global_norm_var._copy_to(
+                    paddle.CustomPlace(dev_type, dev_id), True
+                )
+            else:
+                global_norm_var = global_norm_var.cuda(dev_id)

         with device_guard(dev_id, self._device.split(":")[0]):
             paddle.distributed.all_reduce(global_norm_var, group=self._group)
@@ -207,6 +213,8 @@ def device_guard(dev_id=0, device="cpu"):
         paddle.set_device(device)
     elif device in ["gpu", "xpu"]:
         paddle.set_device(f"{device}:{dev_id}")
+    elif device in paddle.device.get_all_custom_device_type():
+        paddle.set_device(f"{device}:{dev_id}")
     try:
         yield
-- 
GitLab
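
Usage note: the patch above routes the group-sharded (stage 2 / stage 3) paths onto a CustomDevice by adding ProcessGroupCustom::Reduce, reading the plugin's minimum chunk size for buffer alignment, and moving tensors with _copy_to(paddle.CustomPlace(...), True) instead of .cuda(). The following is a minimal sketch of how the feature might be exercised end to end; the plugin name "custom_cpu", the toy model, and the launch setup are illustrative assumptions, not something the patch prescribes.

    # Minimal sketch: group sharding on a CustomDevice. Assumes a custom device
    # plugin named "custom_cpu" is installed; substitute the name reported by
    # paddle.device.get_all_custom_device_type() for your build.
    import paddle
    from paddle.distributed import init_parallel_env
    from paddle.distributed.sharding import group_sharded_parallel

    paddle.set_device("custom_cpu:0")  # assumed plugin name and card index
    init_parallel_env()                # run under paddle.distributed.launch

    model = paddle.nn.Linear(1024, 1024)
    opt = paddle.optimizer.AdamW(
        learning_rate=1e-3, parameters=model.parameters()
    )

    # level="os_g" shards optimizer state and gradients (the stage-2 path in
    # GroupShardedOptimizerStage2 above); level="p_g_os" also shards the
    # parameters (the stage-3 path in GroupShardedStage3).
    model, opt, scaler = group_sharded_parallel(model, opt, level="os_g")

    loss = model(paddle.rand([8, 1024])).mean()
    loss.backward()
    opt.step()
    opt.clear_grad()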