diff --git a/python/paddle/distributed/fleet/meta_parallel/sharding/sharding_stage3.py b/python/paddle/distributed/fleet/meta_parallel/sharding/sharding_stage3.py index 9886ca4e2deace4c625ead51852841e7c761be21..f96273cc84caf46f4f02c62e648ce70445b52d28 100644 --- a/python/paddle/distributed/fleet/meta_parallel/sharding/sharding_stage3.py +++ b/python/paddle/distributed/fleet/meta_parallel/sharding/sharding_stage3.py @@ -912,7 +912,6 @@ def _device2cpu(trans_param, convert_dtype=False): def _cpu2device(param): tmp_p = param.fw_storage.cuda(DEV_ID) - param.fw_storage._clear() if tmp_p.dtype == Type.fp32.value and param2dtype[ param.name] == Type.fp16.value: tmp_p = paddle.cast(tmp_p, Type.fp16.value) diff --git a/python/paddle/distributed/sharding/group_sharded.py b/python/paddle/distributed/sharding/group_sharded.py index 2fdb20600f673b21e7cabd6ffe35c545b045bb5d..6fd4caa7b4a5c41e73fcf95ac50d0253bb3e7c79 100644 --- a/python/paddle/distributed/sharding/group_sharded.py +++ b/python/paddle/distributed/sharding/group_sharded.py @@ -39,19 +39,20 @@ def group_sharded_parallel(model, segment_size=2**20, sync_comm=False): """ - Use this module to configure and wrap up the parameters of the group shared module. + Use group_sharded_parallel can perform group shared configuration on the model, optimizer and GradScaler. Level has three string options, 'os', 'os_g' and 'p_g_os' corresponds to three different usage scenarios: optimizer state segmentation, optimizer state + gradient segmentation, and parameter + gradient + optimizer state segmentation. + Usually, optimizer state + gradient segmentation is actually a re optimization of optimizer state segmentation, so optimizer state + gradient segmentation can be used to realize optimizer state segmentation. Args: model (Layer): The layer to be wrapped with group_sharded_parallel. optimizer (Optimizer): The optimizer to be wrapped with group_sharded_parallel. level (str): The different level of the group sharded. Such as `os`, `os_g`, `p_g_os`. - scaler (GradScaler, optional): The scaler to be wrapped with group_sharded_parallel. Defaults to None. - group (Group, optional): The group instance. Defaults to None.d - offload (bool, optional): Whether to perform optimizer state and gradient transfer CPU. Defaults to False. - sync_buffers (bool, optional): Whether to broadcast model buffers. Defaults to False. - buffer_max_size (int, optional): The max size of the buffer used to integrate gradient in `os_g`. Defaults to 2**23. - segment_size (int, optional): The smallest size of parameter to be sharded in `p_g_os`. Defaults to 2**20. - sync_comm (bool, optional): Whether to use synchronous communication, only in `p_g_os` used. Defaults to False. + scaler (GradScaler, optional): If AMP is used, you need to pass GradScaler. Defaults to None, indicating that GradScaler is not used. + group (Group, optional): The group instance. Defaults to None, indicating that the default environment group is used. + offload (bool, optional): Whether to use the offload function. Defaults to False, which means that the offload function is not used. + sync_buffers (bool, optional): Whether to broadcast model buffers. It is generally used when there are registered model buffers. Defaults to False, indicating that model buffers are not used. + buffer_max_size (int, optional): The max size of the buffer used to integrate gradient in `os_g`. The larger the size, the more GPU memory will be used. Defaults to 2**23, which means that the dimension of the buffer is 2**23. + segment_size (int, optional): The smallest size of parameter to be sharded in `p_g_os`. Defaults to 2**20, indicating that the dimension of the minimum segmented parameter is 2**20. + sync_comm (bool, optional): Whether to use synchronous communication, only in `p_g_os` used. Defaults to False, indicating that asynchronous communication is used. Returns: model: A wrapper for group sharded given model. @@ -101,7 +102,7 @@ def group_sharded_parallel(model, def check_dtype(param): return param.dtype == paddle.float16 - params_fp16 = filter(check_dtype, model.parameters()) + params_fp16 = list(filter(check_dtype, model.parameters())) if scaler is None and len(params_fp16) > 0: raise ValueError("Please enter the correct scaler.") # convert model/optimizer/scaler @@ -146,10 +147,13 @@ def save_group_sharded_model(model, output, optimizer=None): """ Group sharded encapsulated model and optimizer state saving module. + .. note:: + If using save_group_sharded_model saves the model. When loading again, you need to set the model or optimizer state before using group_sharded_parallel. + Args: model (Layer): A wrapper for group sharded given model. output (str): Save directory. - optimizer (Optimizer, optional): Group sharded encapsulated optimizer. Defaults to None. + optimizer (Optimizer, optional): Group sharded encapsulated optimizer. Defaults to None, indicating that the optimizer state is not saved. Examples: .. code-block:: python @@ -182,7 +186,7 @@ def save_group_sharded_model(model, output, optimizer=None): optimizer.clear_grad() # save model and optimizer state_dict - save_group_sharded_model(model, optimizer,output=output_dir) + save_group_sharded_model(model, optimizer, output=output_dir) """ logger_.info( "==========Begin to save group sharded model and optimizer==========")