diff --git a/python/paddle/distributed/fleet/meta_parallel/sharding/group_sharded_stage3.py b/python/paddle/distributed/fleet/meta_parallel/sharding/group_sharded_stage3.py index 768953eed072e6dc744d09a5c59aadb25821cd6b..5fb2e9a58d501df3819e90d1eafc2cb1e8a531c6 100644 --- a/python/paddle/distributed/fleet/meta_parallel/sharding/group_sharded_stage3.py +++ b/python/paddle/distributed/fleet/meta_parallel/sharding/group_sharded_stage3.py @@ -45,10 +45,9 @@ def _all_gather(tensor, buffer_size, group): # CUDA alignment 256 bytes -alignment = { - "gpu": 256, -} +alignment = {"gpu": 256, "cpu": 4096, "xpu": 256} align = { + Type.bf16.value: 2, Type.fp16.value: 2, Type.fp32.value: 4, } @@ -251,6 +250,11 @@ class GroupShardedStage3(nn.Layer): and param2dtype[param.name] == Type.fp16.value ): tmp_var = paddle.cast(tmp_var, Type.fp16.value) + elif ( + tmp_var.dtype == Type.fp32.value + and param2dtype[param.name] == Type.bf16.value + ): + tmp_var = paddle.cast(tmp_var, Type.bf16.value) tmp_var._share_buffer_to(param) del tmp_var for grad_storage in self._grad_storages.values(): @@ -312,11 +316,14 @@ class GroupShardedStage3(nn.Layer): def _handle_unslice_params(self): buffer_size = dict() + buffer_size[Type.bf16.value] = 0 buffer_size[Type.fp32.value] = 0 buffer_size[Type.fp16.value] = 0 for param in self._unslice_params: # Updata optimizer master weights - if param.dtype == Type.fp16.value and not self._offload: + if ( + param.dtype == Type.fp16.value or param.dtype == Type.bf16.value + ) and not self._offload: master_tensor = paddle.cast(param, Type.fp32.value) master_tensor.name = param.name self._optim._master_weights[param.name] = master_tensor @@ -419,10 +426,14 @@ class GroupShardedStage3(nn.Layer): assert isinstance(buffer_size, int) value = ( np.zeros(buffer_size, dtype=np.float16) - if Type.fp16.value == param.dtype + if ( + Type.fp16.value == param.dtype or Type.bf16.value == param.dtype + ) else np.zeros(buffer_size, dtype=np.float32) ) buffer = core.eager.Tensor(value=value, place=core.CPUPlace()) + if Type.bf16.value == param.dtype: + buffer = buffer.cast(Type.bf16.value) param_shape = param.shape origin_state = param.stop_gradient @@ -462,7 +473,9 @@ class GroupShardedStage3(nn.Layer): # Updata optimizer master weights if ( param.trainable - and param.dtype == Type.fp16.value + and ( + param.dtype == Type.fp16.value or param.dtype == Type.bf16.value + ) and not self._offload ): master_tensor = paddle.cast(param.fw_storage, Type.fp32.value) @@ -1088,6 +1101,11 @@ def _cpu2device(param): and param2dtype[param.name] == Type.fp16.value ): tmp_p = paddle.cast(tmp_p, Type.fp16.value) + elif ( + tmp_p.dtype == Type.fp32.value + and param2dtype[param.name] == Type.bf16.value + ): + tmp_p = paddle.cast(tmp_p, Type.bf16.value) return tmp_p diff --git a/python/paddle/distributed/fleet/meta_parallel/sharding/group_sharded_utils.py b/python/paddle/distributed/fleet/meta_parallel/sharding/group_sharded_utils.py index 361b421bbae4be7e1bf9c38b7228cfd677f946e6..7a8f78d03334ef64755301b12d066fbf037acbb3 100644 --- a/python/paddle/distributed/fleet/meta_parallel/sharding/group_sharded_utils.py +++ b/python/paddle/distributed/fleet/meta_parallel/sharding/group_sharded_utils.py @@ -54,8 +54,12 @@ class GroupShardedClipGrad: @paddle.autograd.no_grad() def _dygraph_clip(self, params_grads): - sum_square_fp32, sum_square_fp16 = [], [] - unslice_params_fp32, unslice_params_fp16 = [], [] + sum_square_fp32, sum_square_fp16, sum_square_bfp16 = [], [], [] + unslice_params_fp32, unslice_params_fp16, 
unslice_params_bfp16 = ( + [], + [], + [], + ) for p, g in params_grads: p_slice = True # using for slice parameter in sharding stage3 @@ -82,6 +86,11 @@ class GroupShardedClipGrad: sum_square_fp32.append(sum_square) else: unslice_params_fp32.append(sum_square) + elif p.dtype == paddle.bfloat16: + if p_slice: + sum_square_bfp16.append(sum_square) + else: + unslice_params_bfp16.append(sum_square) # global norm of non-distributed FP16 params_and_grads if len(sum_square_fp16) == 0: @@ -93,6 +102,16 @@ class GroupShardedClipGrad: global_norm_fp16, dtype=paddle.float32 ) + # global norm of non-distributed BFP16 params_and_grads + if len(sum_square_bfp16) == 0: + global_norm_bfp16 = paddle.to_tensor([0.0], dtype=paddle.float32) + else: + global_norm_bfp16 = paddle.concat(sum_square_bfp16) + global_norm_bfp16 = paddle.sum(global_norm_bfp16) + global_norm_bfp16 = paddle.cast( + global_norm_bfp16, dtype=paddle.float32 + ) + # global norm of non-distributed FP16 params_and_grads for unslice parameters if len(unslice_params_fp16) == 0: global_unslice_fp16 = paddle.to_tensor([0.0], dtype=paddle.float32) @@ -103,6 +122,16 @@ class GroupShardedClipGrad: global_unslice_fp16, dtype=paddle.float32 ) + # global norm of non-distributed BFP16 params_and_grads for unslice parameters + if len(unslice_params_bfp16) == 0: + global_unslice_bfp16 = paddle.to_tensor([0.0], dtype=paddle.float32) + else: + global_unslice_bfp16 = paddle.concat(unslice_params_bfp16) + global_unslice_bfp16 = paddle.sum(global_unslice_bfp16) + global_unslice_bfp16 = paddle.cast( + global_unslice_bfp16, dtype=paddle.float32 + ) + # global norm of non-distributed FP32 params_and_grads global_norm_fp32 = ( paddle.concat(sum_square_fp32) @@ -118,9 +147,13 @@ class GroupShardedClipGrad: else paddle.to_tensor([0.0], dtype=paddle.float32) ) global_unslice_fp32 = paddle.sum(global_unslice_fp32) - global_unslice_var = global_unslice_fp16 + global_unslice_fp32 + global_unslice_var = ( + global_unslice_fp16 + global_unslice_fp32 + global_unslice_bfp16 + ) - global_norm_var = global_norm_fp16 + global_norm_fp32 + global_norm_var = ( + global_norm_fp16 + global_norm_fp32 + global_norm_bfp16 + ) # add all reduce to get global norm of distributed params_and_grads dev_id = int(self._device.split(":")[1]) @@ -181,6 +214,7 @@ def GroupShardedScaler(scaler): if not self._enable: return param_grads = [] + param_grads_bfp16 = [] param_grads_fp16 = [] param_grads_fp32 = [] if hasattr(optimizer, "update_slice"): @@ -200,6 +234,8 @@ def GroupShardedScaler(scaler): paddle.float16, ]: param_grads_fp16.append(param.grad) + elif param.grad.dtype in [paddle.bfloat16]: + param_grads_bfp16.append(param.grad) else: param_grads_fp32.append(param.grad) else: @@ -211,10 +247,13 @@ def GroupShardedScaler(scaler): paddle.float16, ]: param_grads_fp16.append(param.grad) + elif param.grad.dtype in [paddle.bfloat16]: + param_grads_bfp16.append(param.grad) else: param_grads_fp32.append(param.grad) temp_found_inf_fp16 = to_variable(np.array([0]).astype(np.bool_)) + temp_found_inf_bfp16 = to_variable(np.array([0]).astype(np.bool_)) temp_found_inf_fp32 = to_variable(np.array([0]).astype(np.bool_)) device = paddle.get_device().split(":")[0] @@ -224,6 +263,16 @@ def GroupShardedScaler(scaler): ) with device_guard(dev_id, device): + if len(param_grads_bfp16): + _legacy_C_ops.check_finite_and_unscale( + param_grads_bfp16, + self._scale, + param_grads_bfp16, + temp_found_inf_bfp16, + ) + self._found_inf = _C_ops.bitwise_or( + self._found_inf, temp_found_inf_bfp16 + ) if 
len(param_grads_fp16): _legacy_C_ops.check_finite_and_unscale( param_grads_fp16, diff --git a/python/paddle/fluid/tests/unittests/collective/fleet/dygraph_group_sharded_stage3.py b/python/paddle/fluid/tests/unittests/collective/fleet/dygraph_group_sharded_stage3.py index 1697655c8bad640a4f3b2682c875da1a3c98378a..f9344640e0e8357541b729b8daa914c400e1f491 100644 --- a/python/paddle/fluid/tests/unittests/collective/fleet/dygraph_group_sharded_stage3.py +++ b/python/paddle/fluid/tests/unittests/collective/fleet/dygraph_group_sharded_stage3.py @@ -16,6 +16,7 @@ import os import shutil +import subprocess import tempfile import numpy as np @@ -135,6 +136,7 @@ def train_mlp( model, sharding_stage, use_pure_fp16=False, + use_bfp16=False, accumulate_grad=False, batch_size=100, opt_group=False, @@ -154,7 +156,10 @@ def train_mlp( if use_pure_fp16: model = paddle.amp.decorate( - models=model, level='O2', save_dtype='float32' + models=model, + level='O2', + save_dtype='float32', + dtype='bfloat16' if use_bfp16 else 'float16', ) scaler = paddle.amp.GradScaler(init_loss_scaling=32768) scaler = GroupShardedScaler(scaler) @@ -201,7 +206,11 @@ def train_mlp( img, label = data label.stop_gradient = True img.stop_gradient = True - with paddle.amp.auto_cast(use_pure_fp16, level='O2'): + with paddle.amp.auto_cast( + use_pure_fp16, + level='O2', + dtype='bfloat16' if use_bfp16 else 'float16', + ): out = model(img) loss = paddle.nn.functional.cross_entropy( input=out, label=label @@ -240,7 +249,23 @@ def train_mlp( def test_stage2_stage3(): paddle.distributed.init_parallel_env() - mlp, mlp1, mlp2, mlp3, mlp4, mlp5, mlp6, mlp7, mlp8, mlp9, mlp10 = ( + ( + mlp, + mlp1, + mlp2, + mlp3, + mlp4, + mlp5, + mlp6, + mlp7, + mlp8, + mlp9, + mlp10, + mlp11, + mlp12, + ) = ( + MLP(), + MLP(), MLP(), MLP(), MLP(), @@ -264,6 +289,8 @@ def test_stage2_stage3(): mlp8.set_state_dict(state_dict) mlp9.set_state_dict(state_dict) mlp10.set_state_dict(state_dict) + mlp11.set_state_dict(state_dict) + mlp12.set_state_dict(state_dict) # fp32 stage2_params = train_mlp( @@ -336,6 +363,41 @@ def test_stage2_stage3(): stage3_params[i].numpy(), stage3_params_re[i].numpy(), rtol=1e-6 ) + # bfp16 + # NOTE: this is a hack to get int format nccl version, like 2134 + # if current platform is not linux, version number will be 0 + nccl_version_str = subprocess.check_output( + r"ldconfig -v | grep 'libnccl.so' | tail -n1 | sed -r 's/^.*\.so\.//'", + stderr=subprocess.DEVNULL, + shell=True, + ).decode('utf-8') + nccl_version = ( + int("".join(nccl_version_str.split("."))) if nccl_version_str else 0 + ) + + if nccl_version >= 2100: + stage2_params = train_mlp( + mlp11, + sharding_stage=2, + use_pure_fp16=True, + opt_group=False, + use_bfp16=True, + ) + stage3_params = train_mlp( + mlp12, + sharding_stage=3, + use_pure_fp16=True, + opt_group=False, + use_bfp16=True, + ) + for i in range(len(stage2_params)): + np.testing.assert_allclose( + stage2_params[i].numpy(), + stage3_params[i].numpy(), + rtol=1e-4, + atol=1e-3, + ) + # test for share layer parameters and exclude_layer function. 
sm1, sm2, sm3, sm4 = ( SpecialModel(), diff --git a/python/paddle/fluid/tests/unittests/collective/fleet/dygraph_group_sharded_stage3_offload.py b/python/paddle/fluid/tests/unittests/collective/fleet/dygraph_group_sharded_stage3_offload.py index ad9fb5c86d9e8649ee749f3db44dd0b72dcf73de..d2becd254544cd796e2dedffe024938c63ae9f8f 100644 --- a/python/paddle/fluid/tests/unittests/collective/fleet/dygraph_group_sharded_stage3_offload.py +++ b/python/paddle/fluid/tests/unittests/collective/fleet/dygraph_group_sharded_stage3_offload.py @@ -14,6 +14,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +import subprocess + import numpy as np import paddle @@ -84,6 +86,7 @@ def optimizer_setting(model, use_pure_fp16, opt_group=False): def train_mlp( model, use_pure_fp16=False, + use_bfp16=False, accumulate_grad=False, offload=False, batch_size=100, @@ -94,7 +97,10 @@ def train_mlp( if use_pure_fp16: model = paddle.amp.decorate( - models=model, level='O2', save_dtype='float32' + models=model, + level='O2', + save_dtype='float32', + dtype='bfloat16' if use_bfp16 else 'float16', ) scaler = paddle.amp.GradScaler(init_loss_scaling=32768) scaler = GroupShardedScaler(scaler) @@ -123,7 +129,11 @@ def train_mlp( img, label = data label.stop_gradient = True img.stop_gradient = True - with paddle.amp.auto_cast(use_pure_fp16, level='O2'): + with paddle.amp.auto_cast( + use_pure_fp16, + level='O2', + dtype='bfloat16' if use_bfp16 else 'float16', + ): out = model(img) loss = paddle.nn.functional.cross_entropy( input=out, label=label @@ -161,7 +171,9 @@ def train_mlp( def test_stage3_offload(): paddle.distributed.init_parallel_env() - mlp, mlp1, mlp2, mlp3, mlp4, mlp5, mlp6 = ( + mlp, mlp1, mlp2, mlp3, mlp4, mlp5, mlp6, mlp7, mlp8 = ( + MLP(), + MLP(), MLP(), MLP(), MLP(), @@ -177,6 +189,8 @@ def test_stage3_offload(): mlp4.set_state_dict(state_dict) mlp5.set_state_dict(state_dict) mlp6.set_state_dict(state_dict) + mlp7.set_state_dict(state_dict) + mlp8.set_state_dict(state_dict) # fp32 offload stage3_params = train_mlp(mlp1, use_pure_fp16=False) @@ -200,6 +214,31 @@ def test_stage3_offload(): atol=1e-2, ) + # bfp16 offload + # NOTE: this is a hack to get int format nccl version, like 2134 + # if current platform is not linux, version number will be 0 + nccl_version_str = subprocess.check_output( + r"ldconfig -v | grep 'libnccl.so' | tail -n1 | sed -r 's/^.*\.so\.//'", + stderr=subprocess.DEVNULL, + shell=True, + ).decode('utf-8') + nccl_version = ( + int("".join(nccl_version_str.split("."))) if nccl_version_str else 0 + ) + + if nccl_version >= 2100: + stage3_params = train_mlp(mlp7, use_pure_fp16=True, use_bfp16=True) + stage3_params_offload = train_mlp( + mlp8, use_pure_fp16=True, offload=True, use_bfp16=True + ) + for i in range(len(stage3_params)): + np.testing.assert_allclose( + stage3_params[i].numpy(), + stage3_params_offload[i].numpy(), + rtol=1e-2, + atol=1e-2, + ) + # fp32 accumulate grad offload stage3_params = train_mlp( mlp5, use_pure_fp16=False, batch_size=20, accumulate_grad=True
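
Note on the bf16 gating above: both test files probe the installed NCCL version by parsing `ldconfig -v` output and skip the bfloat16 cases below 2100 (i.e. NCCL 2.10, the release that added bfloat16 collectives). A minimal standalone sketch of that probe, with the helper name `get_nccl_version_int` invented here purely for illustration:

import subprocess


def get_nccl_version_int():
    # Parse the libnccl.so suffix out of `ldconfig -v`, e.g. "2.13.4" -> 2134.
    # On platforms without ldconfig/libnccl (non-Linux) this yields 0, matching
    # the fallback behaviour of the tests above.
    try:
        version_str = subprocess.check_output(
            r"ldconfig -v | grep 'libnccl.so' | tail -n1 | sed -r 's/^.*\.so\.//'",
            stderr=subprocess.DEVNULL,
            shell=True,
        ).decode('utf-8')
    except (subprocess.CalledProcessError, OSError):
        return 0
    return int("".join(version_str.split("."))) if version_str.strip() else 0


if __name__ == "__main__":
    # The tests only run the bf16 stage2/stage3 comparisons when this is >= 2100.
    print(get_nccl_version_int())

Elsewhere in the patch, the added per-dtype partial norms (`global_norm_bfp16`, `global_unslice_bfp16`) are cast to float32 after summation so they can be accumulated into the existing fp32 global-norm totals used for gradient clipping.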