diff --git a/python/paddle/distributed/fleet/meta_parallel/sharding/sharding_stage3.py b/python/paddle/distributed/fleet/meta_parallel/sharding/sharding_stage3.py
index 484cd223949c6cb461ce760f6240300f46f06885..8bbf42b72f2d6d8cb263d1e099044d8baf657a8c 100644
--- a/python/paddle/distributed/fleet/meta_parallel/sharding/sharding_stage3.py
+++ b/python/paddle/distributed/fleet/meta_parallel/sharding/sharding_stage3.py
@@ -35,6 +35,7 @@ from paddle.distributed.collective import _get_global_group
 
 from .sharding_utils import Type, ShardingClipGrad, device_guard
 from ..pp_utils.utils import _all_gather
+from ...utils.internal_storage import GradStorage
 
 # CUDA alignment 256 bytes
 alignment = {"gpu": 256, }
@@ -69,6 +70,7 @@ class ShardingStage3(nn.Layer):
                  group=None,
                  sync_buffers=False,
                  device="gpu",
+                 segment_size=2**15,
                  pertrain_sync_models=True,
                  accumulate_grads=False,
                  offload=False,
@@ -83,6 +85,8 @@ class ShardingStage3(nn.Layer):
         self._accumulate_grads = accumulate_grads
         self._offload = offload
         self._sync_comm = sync_comm
+        # segmentation size
+        self._segment_size = segment_size if not offload else 0
 
         global DEV
         DEV = "cpu" if paddle.get_device() == "cpu" else paddle.get_device(
@@ -107,7 +111,10 @@ class ShardingStage3(nn.Layer):
         self._param2buffer_size = dict()  # {param.name: size}
         self._param2buffer = dict(
         )  # {param.name: [(start0, end0),(start1, end1), ...]}
-        self._trainable_params = dict()  # {layer.name: [trainable_params]}
+        self._trainable_params = dict()  # {id(layer): [trainable_params]}
+        self._unslice_params = set()  # param's numel <= segment_size
+        self._unslice_params2align = dict()  # {param.name: param's align}
+        self._grad_storages = dict()  # {param.dtype: GradStorage}
 
         assert not isinstance(
             optimizer, list), "Multiple optimizers are not supported now."
@@ -131,10 +138,13 @@ class ShardingStage3(nn.Layer):
 
         self._segment_rank_params(self._layer)
 
+        # Add unslice params to master_weight in fp16
+        self._handle_unslice_params()
+
         # In the first step, record the execution order of the layer
         self._order_tracer = OrderedDict()
         self._order_tracer["order"] = 0
-        self._order_tracer["layer"] = []
+        self._order_tracer["layer"] = list()
 
         # Register task flow
         self._task_flow = TaskFlow()
@@ -168,8 +178,10 @@ class ShardingStage3(nn.Layer):
     def _clear_gradients(self):
         assert len(self._trainable_params.keys()) > 0
         current_layer_params = self._layer.parameters(include_sublayers=True)
+        # 1.Handle param's slice
         trainable_params = list(
-            filter(lambda x: x.trainable, current_layer_params))
+            filter(lambda p: p.trainable and p not in self._unslice_params,
+                   current_layer_params))
         for param in trainable_params:
             assert hasattr(
                 param, "fw_storage"
@@ -178,6 +190,9 @@ class ShardingStage3(nn.Layer):
             param.fw_storage.clear_gradient(False)
             param.fw_storage._gradient_set_empty(False)
             param.bw_storage._clear()
+        # 2.Handle unslice param
+        for grad_storage in self._grad_storages.values():
+            grad_storage.buffer.zero_()
 
     # Update param memery slice
     def _update_params_slice(self):
@@ -185,20 +200,25 @@ class ShardingStage3(nn.Layer):
 
         if not isinstance(self._optim._param_groups[0], dict):
             slice_params = [param.fw_storage for param in update_list]
-            self._optim._parameter_list = slice_params
-            self._optim._param_groups = slice_params
+            self._optim._parameter_list = slice_params + list(
+                self._unslice_params)
+            self._optim._param_groups = slice_params + list(
+                self._unslice_params)
         else:
             params_name_list = list(map(lambda p: p.name, update_list))
+            fw_storage_name_list = list(
+                map(lambda p: p.fw_storage.name, update_list))
             for param_group in self._optim._param_groups:
-                slice_p = []
+                p_group = []
                 for p in param_group['params']:
                     if p.name in params_name_list:
-                        assert hasattr(
-                            p, "fw_storage"
-                        ), "Find {} don't have fw_storage attribute.".format(
-                            p.name)
-                        slice_p.append(p.fw_storage)
-                param_group['params'] = slice_p
+                        p_group.append(p.fw_storage)
+                    elif p.name in fw_storage_name_list:
+                        p_group.append(update_list[fw_storage_name_list.index(
+                            p.name)].fw_storage)
+                    elif p in self._unslice_params:
+                        p_group.append(p)
+                param_group['params'] = p_group
 
     def forward(self, *inputs, **kwargs):
         """
@@ -213,6 +233,32 @@ class ShardingStage3(nn.Layer):
 
         return fw
 
+    def _handle_unslice_params(self):
+        buffer_size = dict()
+        buffer_size[Type.fp32.value] = 0
+        buffer_size[Type.fp16.value] = 0
+        for param in self._unslice_params:
+            # Updata optimizer master weights
+            if param.dtype == Type.fp16.value and not self._offload:
+                self._optim._master_weights[param.name] = paddle.cast(
+                    param, Type.fp32.value)
+            param2dtype[param.name] = param.dtype
+            p_align = self._param2align(param)
+            self._unslice_params2align[param.name] = p_align
+            buffer_size[param.dtype] += param._numel() + p_align
+
+        # Create unslice_params'grad
+        for param in sorted(list(self._unslice_params), key=lambda p: p.name):
+            if param.dtype not in self._grad_storages.keys():
+                self._grad_storages[param.dtype] = GradStorage(
+                    buffer_size[param.dtype],
+                    dtype=param.dtype,
+                    device=self._default_device,
+                    destination=self._rank,
+                    parm2align=self._unslice_params2align)
+            self._grad_storages[param.dtype].add_grad(
+                param, self._unslice_params2align[param.name])
+
     def _segment_rank_params(self, layer, name="last_layer"):
         """
         Flatten parameters according to layer.
@@ -233,24 +279,22 @@ class ShardingStage3(nn.Layer):
 
         def _add_manage_info(trainable_param):
             return _PartitionParam(trainable_param)
 
-        trainable_params = list(
-            filter(lambda x: x.trainable, current_layer_params))
+        current_params = list()
+        for p in current_layer_params:
+            if p.trainable and p._numel() > self._segment_size:
+                current_params.append(_add_manage_info(p))
+            elif p.trainable:
+                self._unslice_params.add(_UnsliceParam(p))
+
         assert id(layer) not in self._trainable_params.keys()
-        self._trainable_params[id(layer)] = list(
-            map(_add_manage_info, trainable_params))
+        self._trainable_params[id(layer)] = current_params
 
         for param in self._trainable_params[id(layer)]:
             if param.name in self._param2buffer.keys():
                 continue
             self._param2buffer[param.name] = []
             # 1.Params alignment
-            offset = 0
-            # CUDA alignment 256 bytes
-            size = param._numel() * align[param.dtype]
-            remaining = size % alignment[self._default_device]
-            ali = 0 if remaining == 0 else alignment[
-                self._default_device] - remaining
-            align_ = ali // align[param.dtype]
+            align_ = self._param2align(param)
             offset = align_ + param._numel()
             buffer_size = offset if offset % self._group.nranks == 0 else offset + self._group.nranks - (
@@ -379,7 +423,9 @@ class ShardingStage3(nn.Layer):
         assert len(self._trainable_params.keys()) > 0
         current_layer_params = self._layer.parameters(include_sublayers=True)
         trainable_params = list(
-            filter(lambda x: x.trainable, current_layer_params))
+            filter(lambda p: p.trainable and p not in self._unslice_params,
+                   current_layer_params))
+        # 1.Handle param's slice
         for param in trainable_params:
             assert hasattr(
                 param,
@@ -396,6 +442,19 @@ class ShardingStage3(nn.Layer):
                 assert param.fw_storage.grad is None
                 param.fw_storage._copy_gradient_from(param.bw_storage)
                 update_list.append(param)
+
+        # 2.Handle unslice param
+        for grad_storage in self._grad_storages.values():
+            grad_storage.buffer.scale_(scale=self._world_size_scaling)
+            dist.all_reduce(
+                tensor=grad_storage.buffer,
+                group=self._group,
+                use_calc_stream=True)
+            dist.wait(
+                tensor=grad_storage.buffer,
+                group=self._group,
+                use_calc_stream=True)
+
         return update_list
 
     def get_all_parameters(self, convert2cpu=False):
@@ -405,7 +464,8 @@ class ShardingStage3(nn.Layer):
         assert len(self._trainable_params.keys()) > 0
         current_layer_params = self._layer.parameters(include_sublayers=True)
         trainable_params = list(
-            filter(lambda x: x.trainable, current_layer_params))
+            filter(lambda p: p.trainable and p not in self._unslice_params,
+                   current_layer_params))
         t_flow = _allgather_buffer(
             trainable_params,
             self._group,
@@ -415,7 +475,7 @@ class ShardingStage3(nn.Layer):
             offload=self._offload,
             convert2cpu=convert2cpu)
         if convert2cpu:
-            for param in current_layer_params:
+            for param in trainable_params:
                 t_flow.full_param[param.name]._share_buffer_to(param)
 
         self._optim._parameter_list = self._ori_parameter_list
@@ -424,7 +484,8 @@ class ShardingStage3(nn.Layer):
     def _register_backward_hooks(self):
         current_layer_params = self._layer.parameters(include_sublayers=True)
         trainable_params = list(
-            filter(lambda x: x.trainable, current_layer_params))
+            filter(lambda p: p.trainable and p not in self._unslice_params,
+                   current_layer_params))
 
         for param in trainable_params:
             allreduce_function = self._get_allreduce_fn(param)
@@ -435,42 +496,36 @@ class ShardingStage3(nn.Layer):
         def reduce(*_):
             if param.name in self._task_flow.full_grad.keys():
                 full_grad = self._task_flow.full_grad[param.name]
-                with paddle.amp.auto_cast(enable=False):
-                    if not self._accumulate_grads:
-                        full_grad.scale_(scale=self._world_size_scaling)
-                    # Only support sync allreduce current rank's layer now
-                    dist.all_reduce(
-                        tensor=full_grad,
-                        group=self._group,
-                        use_calc_stream=True)
-                    dist.wait(
-                        tensor=full_grad,
-                        group=self._group,
-                        use_calc_stream=True)
+                if not self._accumulate_grads:
+                    full_grad.scale_(scale=self._world_size_scaling)
+                # Only support sync allreduce current rank's layer now
+                dist.all_reduce(
+                    tensor=full_grad, group=self._group, use_calc_stream=True)
+                dist.wait(
+                    tensor=full_grad, group=self._group, use_calc_stream=True)
 
-                    start, end = self._param2buffer[param.name][self._rank]
-                    if not self._accumulate_grads or param.bw_storage is None or not param.bw_storage.value(
-                    ).get_tensor()._is_initialized():
-                        param.bw_storage = core.VarBase(
-                            full_grad._slice(start, end)).detach().clone()
-                        if self._offload:
-                            param.bw_storage = _device2cpu(param.bw_storage,
-                                                           True)
+                start, end = self._param2buffer[param.name][self._rank]
+                if not self._accumulate_grads or param.bw_storage is None or not param.bw_storage.value(
+                ).get_tensor()._is_initialized():
+                    param.bw_storage = core.VarBase(
+                        full_grad._slice(start, end)).detach().clone()
+                    if self._offload:
+                        param.bw_storage = _device2cpu(param.bw_storage, True)
+                else:
+                    if self._offload:
+                        cpu_grad = _device2cpu(
+                            core.VarBase(full_grad._slice(start, end))
+                            .detach().clone(), True)
+                        param.bw_storage = paddle.add(param.bw_storage,
+                                                      cpu_grad)
                     else:
-                        if self._offload:
-                            cpu_grad = _device2cpu(
-                                core.VarBase(full_grad._slice(start, end))
-                                .detach().clone(), True)
-                            param.bw_storage = paddle.add(param.bw_storage,
-                                                          cpu_grad)
-                        else:
-                            # param.bw_storage.add_(
-                            #     core.VarBase(full_grad._slice(start, end))
-                            #     .detach().clone())
-                            param.bw_storage = paddle.add(
-                                param.bw_storage,
-                                core.VarBase(full_grad._slice(
-                                    start, end)).detach().clone())
+                        # param.bw_storage.add_(
+                        #     core.VarBase(full_grad._slice(start, end))
+                        #     .detach().clone())
+                        param.bw_storage = paddle.add(
+                            param.bw_storage,
+                            core.VarBase(full_grad._slice(start, end)).detach(
+                            ).clone())
                 param.clear_gradient(False)
                 param._gradient_set_empty(False)
                 tmp_var = self._task_flow.full_grad.pop(param.name)
@@ -493,6 +548,15 @@ class ShardingStage3(nn.Layer):
 
         return reduce
 
+    def _param2align(self, param):
+        # CUDA alignment 256 bytes
+        size = param._numel() * align[param.dtype]
+        remaining = size % alignment[self._default_device]
+        ali = 0 if remaining == 0 else alignment[
+            self._default_device] - remaining
+        align_ = ali // align[param.dtype]
+        return align_
+
     def _redefine_opt_step(self):
         params_slice_func = self._update_params_slice
         opt_step = self._optim.step
@@ -679,14 +743,13 @@ def _wait_layer(trainable_params,
                 group,
                 use_calc_stream,
                 offload=False):
+    paddle.device.cuda.synchronize()
     for param in trainable_params:
         if param.status == "all":
             param.use_count += 1
             continue
         if param.name in task_flow.full_param.keys():
             full_param = task_flow.full_param[param.name]
-            with paddle.amp.auto_cast(enable=False):
-                paddle.device.cuda.synchronize()
             core.VarBase(full_param._slice(0, param._numel()))._share_buffer_to(
                 param)
             param.fw_storage._clear()
@@ -725,7 +788,7 @@ def _allgather_buffer(trainable_params,
         full_param = _all_gather(
             param.fw_storage, group, use_calc_stream=use_calc_stream)
 
-        # Allgather current layer in the 1st step
+        # Allgather current layer in the 1st step synchronously
         if sync_wait:
             with paddle.amp.auto_cast(enable=False):
                 dist.wait(
@@ -774,6 +837,12 @@ def _PartitionParam(param):
     return param
 
 
+def _UnsliceParam(param):
+    if not hasattr(param, "unslice"):
+        setattr(param, "unslice", True)
+    return param
+
+
 def _VarBaseWrapper(param):
     varbase = param.fw_storage
     tmp_param = ParamBase(
diff --git a/python/paddle/distributed/fleet/meta_parallel/sharding/sharding_utils.py b/python/paddle/distributed/fleet/meta_parallel/sharding/sharding_utils.py
index 9c30ff5a45075ae423d6a46ef328e3b6523fbd5b..ee281a0a044f4d9e371d81b00df933f0f22ac26e 100644
--- a/python/paddle/distributed/fleet/meta_parallel/sharding/sharding_utils.py
+++ b/python/paddle/distributed/fleet/meta_parallel/sharding/sharding_utils.py
@@ -57,12 +57,15 @@ class ShardingClipGrad:
 
     @imperative_base.no_grad
     def _dygraph_clip(self, params_grads):
-        sum_square_fp16 = []
-        sum_square_fp32 = []
+        sum_square_fp32, sum_square_fp16 = [], []
+        unslice_params_fp32, unslice_params_fp16 = [], []
 
         for p, g in params_grads:
+            p_slice = True  # using for slice parameter in sharding stage3
             if g is None or getattr(p, 'need_clip', True) is False:
                 continue
+            if hasattr(p, "unslice"):
+                p_slice = False
 
             merge_grad = g
             if g.type == core.VarDesc.VarType.SELECTED_ROWS:
@@ -72,9 +75,11 @@ class ShardingClipGrad:
             sum_square = layers.reduce_sum(square)
 
             if p.dtype == paddle.float16:
-                sum_square_fp16.append(sum_square)
+                if p_slice: sum_square_fp16.append(sum_square)
+                else: unslice_params_fp16.append(sum_square)
             elif p.dtype == paddle.float32:
-                sum_square_fp32.append(sum_square)
+                if p_slice: sum_square_fp32.append(sum_square)
+                else: unslice_params_fp32.append(sum_square)
 
         # global norm of non-distributed FP16 params_and_grads
         if len(sum_square_fp16) == 0:
@@ -85,12 +90,28 @@ class ShardingClipGrad:
             global_norm_fp16 = paddle.cast(
                 global_norm_fp16, dtype=paddle.float32)
 
+        # global norm of non-distributed FP16 params_and_grads for slice parameter
+        if len(unslice_params_fp16) == 0:
+            global_unslice_fp16 = paddle.to_tensor([0.], dtype=paddle.float32)
+        else:
+            global_unslice_fp16 = layers.concat(unslice_params_fp16)
+            global_unslice_fp16 = layers.reduce_sum(global_unslice_fp16)
+            global_unslice_fp16 = paddle.cast(
+                global_unslice_fp16, dtype=paddle.float32)
+
         # global norm of non-distributed FP32 params_and_grads
         global_norm_fp32 = layers.concat(sum_square_fp32) if len(
             sum_square_fp32) != 0 else paddle.to_tensor(
                 [0.], dtype=paddle.float32)
         global_norm_fp32 = layers.reduce_sum(global_norm_fp32)
 
+        # global norm of non-distributed FP32 params_and_grads for slice parameter
+        global_unslice_fp32 = layers.concat(unslice_params_fp32) if len(
+            unslice_params_fp32) != 0 else paddle.to_tensor(
+                [0.], dtype=paddle.float32)
+        global_unslice_fp32 = layers.reduce_sum(global_unslice_fp32)
+        global_unslice_var = global_unslice_fp16 + global_unslice_fp32
+
         global_norm_var = global_norm_fp16 + global_norm_fp32
 
         # add all reduce to get global norm of distributed params_and_grads
@@ -98,6 +119,7 @@ class ShardingClipGrad:
         with device_guard(dev_id, "gpu"):
             paddle.distributed.all_reduce(global_norm_var, group=self._group)
 
+        global_norm_var += global_unslice_var
         global_norm_var = layers.sqrt(global_norm_var)
         max_global_norm = layers.fill_constant(
             shape=[1], dtype=global_norm_var.dtype, value=self.clip_norm)
diff --git a/python/paddle/fluid/tests/unittests/dygraph_sharding_stage3.py b/python/paddle/fluid/tests/unittests/dygraph_sharding_stage3.py
index 9b218bf13027a0bac7e55e4b146351bafdedfb7a..9bb1f85f327c3b248f7d050484a45736a61e4f77 100644
--- a/python/paddle/fluid/tests/unittests/dygraph_sharding_stage3.py
+++ b/python/paddle/fluid/tests/unittests/dygraph_sharding_stage3.py
@@ -145,6 +145,10 @@ def train_mlp(model,
             loss = paddle.nn.functional.cross_entropy(
                 input=out, label=label)
             avg_loss = paddle.mean(x=loss.cast(dtype=paddle.float32))
+
+            if batch_size == 20:
+                avg_loss = avg_loss / 5
+
             if not use_pure_fp16:
                 avg_loss.backward()
             else:
@@ -215,7 +219,7 @@ def test_stage2_stage3():
             stage3_params[i].numpy(),
             stage3_params_add[i].numpy(),
             rtol=1e-6,
-            atol=1e-6)
+            atol=1e-4)
 
     # fp16
     stage2_params = train_mlp(
diff --git a/python/paddle/fluid/tests/unittests/dygraph_sharding_stage3_offload.py b/python/paddle/fluid/tests/unittests/dygraph_sharding_stage3_offload.py
index 4d4b4c02068aa71c51f02dd8df74fac14e65fafc..aa440549cf1474f40e2498add6eb59830606baba 100644
--- a/python/paddle/fluid/tests/unittests/dygraph_sharding_stage3_offload.py
+++ b/python/paddle/fluid/tests/unittests/dygraph_sharding_stage3_offload.py
@@ -28,7 +28,6 @@ from paddle.distributed.fleet.meta_parallel.sharding.sharding_stage3 import Shar
 from paddle.distributed.fleet.meta_parallel.sharding.sharding_utils import ShardingScaler
 
 epoch = 10
-batch_size = 32
 paddle.seed(2022)
 np.random.seed(2022)
 base_lr = 0.1
@@ -80,6 +79,7 @@ def train_mlp(model,
               use_pure_fp16=False,
               accumulate_grad=False,
               offload=False,
+              batch_size=100,
               convert2cpu=False):
     group = paddle.distributed.new_group([0, 1])
     optimizer = optimizer_setting(model=model, use_pure_fp16=use_pure_fp16)
@@ -91,7 +91,11 @@ def train_mlp(model,
         scaler = ShardingScaler(scaler)
 
     model = ShardingStage3(
-        model, optimizer=optimizer, group=group, offload=offload)
+        model,
+        optimizer=optimizer,
+        group=group,
+        offload=offload,
+        accumulate_grads=accumulate_grad)
 
     train_reader = paddle.batch(
         reader_decorator(), batch_size=batch_size, drop_last=True)
@@ -115,10 +119,15 @@ def train_mlp(model,
             loss = paddle.nn.functional.cross_entropy(
                 input=out, label=label)
             avg_loss = paddle.mean(x=loss.cast(dtype=paddle.float32))
+
+            if accumulate_grad:
+                avg_loss = avg_loss / 5
+
             if not use_pure_fp16:
                 avg_loss.backward()
             else:
                 scaler.scale(avg_loss).backward()
+
             if not accumulate_grad:
                 if not use_pure_fp16:
                     optimizer.step()
@@ -172,12 +181,14 @@ def test_stage3_offload():
         atol=1e-2)
 
     # fp32 accumulate grad offload
-    stage3_params = train_mlp(mlp5, use_pure_fp16=False, accumulate_grad=True)
+    stage3_params = train_mlp(
+        mlp5, use_pure_fp16=False, batch_size=20, accumulate_grad=True)
    stage3_params_offload = train_mlp(
         mlp6,
         use_pure_fp16=False,
         accumulate_grad=True,
         offload=True,
+        batch_size=20,
         convert2cpu=True)
     for i in range(len(stage3_params)):
         np.testing.assert_allclose(
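
The partitioning rule and the alignment arithmetic introduced by this patch are easier to check in isolation. The snippet below is a minimal standalone sketch in plain Python (not the Paddle API; `is_sliced` and `param2align` are hypothetical helper names) of what the new `segment_size` threshold in `_segment_rank_params` and the `_param2align` method compute, assuming the module's fp16/fp32 element sizes and the 256-byte GPU alignment table shown above.

# Standalone sketch of the slicing threshold and alignment padding added by this patch.
# Not Paddle code: `is_sliced` and `param2align` are illustrative names only.
ALIGNMENT = 256                               # bytes, mirrors `alignment = {"gpu": 256, }`
ELEMENT_SIZE = {"float16": 2, "float32": 4}   # mirrors the module's `align` table
SEGMENT_SIZE = 2**15                          # default of the new `segment_size` argument


def is_sliced(numel, segment_size=SEGMENT_SIZE):
    # Parameters with more elements than segment_size are partitioned across ranks;
    # the rest go into the new `_unslice_params` set and keep a full copy per rank.
    return numel > segment_size


def param2align(numel, dtype):
    # Same arithmetic as ShardingStage3._param2align: pad the parameter so its byte
    # size becomes a multiple of the 256-byte CUDA alignment, and return the padding
    # expressed in elements of the parameter's dtype.
    size = numel * ELEMENT_SIZE[dtype]
    remaining = size % ALIGNMENT
    pad_bytes = 0 if remaining == 0 else ALIGNMENT - remaining
    return pad_bytes // ELEMENT_SIZE[dtype]


assert is_sliced(4096 * 4096) and not is_sliced(1000)
assert param2align(1000, "float32") == 24  # 4000 B -> padded to 4096 B -> 24 fp32 slots
assert param2align(1000, "float16") == 24  # 2000 B -> padded to 2048 B -> 24 fp16 slots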
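
A note on the `ShardingClipGrad` change: `global_unslice_var` is added only after the `all_reduce` of `global_norm_var`, which is consistent with every rank holding the complete, already all-reduced gradient of an unsliced parameter (see the `_grad_storages` all-reduce in `_update_params`), while each rank holds only its own shard of a sliced one. A small NumPy simulation of that accounting, with made-up data and two simulated ranks, is sketched below.

# Sketch only: simulates the squared-norm bookkeeping of ShardingClipGrad._dygraph_clip
# with plain NumPy instead of paddle tensors and a real process group.
import numpy as np

np.random.seed(0)
nranks = 2
sliced_grad = np.random.randn(8)    # gradient of a parameter sliced across ranks
unsliced_grad = np.random.randn(3)  # small parameter kept whole on every rank

# Each rank sees one shard of the sliced gradient plus a full copy of the unsliced one.
shards = np.split(sliced_grad, nranks)
partial_sq = [float(np.sum(s * s)) for s in shards]

# all_reduce(SUM) over the group recovers the sliced contribution ...
global_norm_sq = sum(partial_sq)
# ... and the unsliced contribution is added once, after the all-reduce,
# because it is already identical on every rank.
global_norm_sq += float(np.sum(unsliced_grad * unsliced_grad))

reference = np.sqrt(np.sum(sliced_grad**2) + np.sum(unsliced_grad**2))
assert np.isclose(np.sqrt(global_norm_sq), reference)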