Unverified · Commit 1a4a1520 · Authored by: W wuhuachaocoding · Committed by: GitHub

support bfp16 for stage3 and offload. (#49931)

Parent: 05c9c0a5
@@ -45,10 +45,9 @@ def _all_gather(tensor, buffer_size, group):
 # CUDA alignment 256 bytes
-alignment = {
-    "gpu": 256,
-}
+alignment = {"gpu": 256, "cpu": 4096, "xpu": 256}
 align = {
+    Type.bf16.value: 2,
     Type.fp16.value: 2,
     Type.fp32.value: 4,
 }
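For readers outside the sharding code: the two tables above give a per-device byte alignment and a per-dtype element size (bf16 and fp16 are both 2 bytes). Below is a minimal sketch of how such tables are typically combined to pad a parameter's storage to an aligned boundary; aligned_numel is a hypothetical helper, not part of the Paddle API.

# Illustrative only: pad an element count so the backing buffer lands on the
# device alignment boundary, mirroring the alignment/align tables above.
alignment = {"gpu": 256, "cpu": 4096, "xpu": 256}   # bytes
align = {"bfloat16": 2, "float16": 2, "float32": 4}  # bytes per element


def aligned_numel(numel, dtype, device="gpu"):
    size = numel * align[dtype]
    remainder = size % alignment[device]
    if remainder:
        size += alignment[device] - remainder
    return size // align[dtype]


print(aligned_numel(1000, "bfloat16"))  # 1024 elements -> 2048 bytes, a multiple of 256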
@@ -251,6 +250,11 @@ class GroupShardedStage3(nn.Layer):
                     and param2dtype[param.name] == Type.fp16.value
                 ):
                     tmp_var = paddle.cast(tmp_var, Type.fp16.value)
+                elif (
+                    tmp_var.dtype == Type.fp32.value
+                    and param2dtype[param.name] == Type.bf16.value
+                ):
+                    tmp_var = paddle.cast(tmp_var, Type.bf16.value)
                 tmp_var._share_buffer_to(param)
                 del tmp_var
             for grad_storage in self._grad_storages.values():
@@ -312,11 +316,14 @@ class GroupShardedStage3(nn.Layer):
     def _handle_unslice_params(self):
         buffer_size = dict()
+        buffer_size[Type.bf16.value] = 0
         buffer_size[Type.fp32.value] = 0
         buffer_size[Type.fp16.value] = 0
         for param in self._unslice_params:
             # Updata optimizer master weights
-            if param.dtype == Type.fp16.value and not self._offload:
+            if (
+                param.dtype == Type.fp16.value or param.dtype == Type.bf16.value
+            ) and not self._offload:
                 master_tensor = paddle.cast(param, Type.fp32.value)
                 master_tensor.name = param.name
                 self._optim._master_weights[param.name] = master_tensor
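Context for the fp16/bf16 branch above: AMP-O2 parameters are stored in low precision, so the optimizer keeps an fp32 master weight per parameter and updates that copy. A minimal sketch of the round trip, assuming a Paddle build with bfloat16 support; the names are illustrative, not the GroupShardedStage3 internals.

import paddle

param = paddle.randn([4, 4]).cast('bfloat16')  # low-precision storage
master = paddle.cast(param, 'float32')         # fp32 master weight kept by the optimizer

grad = paddle.full_like(master, 0.01)          # pretend (already unscaled) gradient
lr = 0.1
master = master - lr * grad                    # the update itself runs in fp32
param = paddle.cast(master, 'bfloat16')        # cast back for the next forward pass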
@@ -419,10 +426,14 @@ class GroupShardedStage3(nn.Layer):
         assert isinstance(buffer_size, int)
         value = (
             np.zeros(buffer_size, dtype=np.float16)
-            if Type.fp16.value == param.dtype
+            if (
+                Type.fp16.value == param.dtype or Type.bf16.value == param.dtype
+            )
             else np.zeros(buffer_size, dtype=np.float32)
         )
         buffer = core.eager.Tensor(value=value, place=core.CPUPlace())
+        if Type.bf16.value == param.dtype:
+            buffer = buffer.cast(Type.bf16.value)
         param_shape = param.shape
         origin_state = param.stop_gradient
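The extra cast above exists because NumPy has no bfloat16 dtype: the zero buffer is materialised as 2-byte float16 on the host and only converted to bf16 on the Paddle side. A small sketch of the same pattern using the public API (paddle.to_tensor instead of core.eager.Tensor), assuming bf16 support in the build.

import numpy as np
import paddle

buffer_size = 1024
# NumPy cannot express bfloat16, so allocate float16 zeros (same 2-byte footprint) ...
value = np.zeros(buffer_size, dtype=np.float16)
buffer = paddle.to_tensor(value, place=paddle.CPUPlace())
# ... then cast to bfloat16 once the data lives in a Paddle tensor.
buffer = buffer.cast('bfloat16')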
@@ -462,7 +473,9 @@ class GroupShardedStage3(nn.Layer):
         # Updata optimizer master weights
         if (
             param.trainable
-            and param.dtype == Type.fp16.value
+            and (
+                param.dtype == Type.fp16.value or param.dtype == Type.bf16.value
+            )
             and not self._offload
         ):
             master_tensor = paddle.cast(param.fw_storage, Type.fp32.value)
@@ -1088,6 +1101,11 @@ def _cpu2device(param):
         and param2dtype[param.name] == Type.fp16.value
     ):
         tmp_p = paddle.cast(tmp_p, Type.fp16.value)
+    elif (
+        tmp_p.dtype == Type.fp32.value
+        and param2dtype[param.name] == Type.bf16.value
+    ):
+        tmp_p = paddle.cast(tmp_p, Type.bf16.value)
     return tmp_p
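With offload enabled, parameters live on the CPU as fp32 and _cpu2device restores the dtype recorded in param2dtype when they return to the device, now including bf16. A toy round trip under that assumption; the actual .cuda(DEV_ID) move and fw_storage bookkeeping are omitted, and the helper names are hypothetical.

import paddle

param2dtype = {}

def offload_to_cpu(param, name):
    # Remember the compute dtype, keep an fp32 copy on the host.
    param2dtype[name] = param.dtype
    return paddle.cast(param, 'float32').cpu()

def back_to_device(tmp_p, name):
    # Mirror of _cpu2device: cast back to fp16/bf16 if that is the recorded dtype.
    if tmp_p.dtype == paddle.float32 and param2dtype[name] != paddle.float32:
        tmp_p = paddle.cast(tmp_p, param2dtype[name])
    return tmp_p

w = paddle.randn([8]).cast('bfloat16')
restored = back_to_device(offload_to_cpu(w, 'w'), 'w')  # bfloat16 again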
......
@@ -54,8 +54,12 @@ class GroupShardedClipGrad:
     @paddle.autograd.no_grad()
     def _dygraph_clip(self, params_grads):
-        sum_square_fp32, sum_square_fp16 = [], []
-        unslice_params_fp32, unslice_params_fp16 = [], []
+        sum_square_fp32, sum_square_fp16, sum_square_bfp16 = [], [], []
+        unslice_params_fp32, unslice_params_fp16, unslice_params_bfp16 = (
+            [],
+            [],
+            [],
+        )
         for p, g in params_grads:
             p_slice = True  # using for slice parameter in sharding stage3
@@ -82,6 +86,11 @@ class GroupShardedClipGrad:
                     sum_square_fp32.append(sum_square)
                 else:
                     unslice_params_fp32.append(sum_square)
+            elif p.dtype == paddle.bfloat16:
+                if p_slice:
+                    sum_square_bfp16.append(sum_square)
+                else:
+                    unslice_params_bfp16.append(sum_square)
         # global norm of non-distributed FP16 params_and_grads
         if len(sum_square_fp16) == 0:
@@ -93,6 +102,16 @@ class GroupShardedClipGrad:
                 global_norm_fp16, dtype=paddle.float32
             )
+
+        # global norm of non-distributed BFP16 params_and_grads
+        if len(sum_square_bfp16) == 0:
+            global_norm_bfp16 = paddle.to_tensor([0.0], dtype=paddle.float32)
+        else:
+            global_norm_bfp16 = paddle.concat(sum_square_bfp16)
+            global_norm_bfp16 = paddle.sum(global_norm_bfp16)
+            global_norm_bfp16 = paddle.cast(
+                global_norm_bfp16, dtype=paddle.float32
+            )
         # global norm of non-distributed FP16 params_and_grads for unslice parameters
         if len(unslice_params_fp16) == 0:
             global_unslice_fp16 = paddle.to_tensor([0.0], dtype=paddle.float32)
@@ -103,6 +122,16 @@ class GroupShardedClipGrad:
                 global_unslice_fp16, dtype=paddle.float32
            )
+
+        # global norm of non-distributed BFP16 params_and_grads for unslice parameters
+        if len(unslice_params_bfp16) == 0:
+            global_unslice_bfp16 = paddle.to_tensor([0.0], dtype=paddle.float32)
+        else:
+            global_unslice_bfp16 = paddle.concat(unslice_params_bfp16)
+            global_unslice_bfp16 = paddle.sum(global_unslice_bfp16)
+            global_unslice_bfp16 = paddle.cast(
+                global_unslice_bfp16, dtype=paddle.float32
+            )
         # global norm of non-distributed FP32 params_and_grads
         global_norm_fp32 = (
             paddle.concat(sum_square_fp32)
@@ -118,9 +147,13 @@ class GroupShardedClipGrad:
             else paddle.to_tensor([0.0], dtype=paddle.float32)
         )
         global_unslice_fp32 = paddle.sum(global_unslice_fp32)
-        global_unslice_var = global_unslice_fp16 + global_unslice_fp32
+        global_unslice_var = (
+            global_unslice_fp16 + global_unslice_fp32 + global_unslice_bfp16
+        )
-        global_norm_var = global_norm_fp16 + global_norm_fp32
+        global_norm_var = (
+            global_norm_fp16 + global_norm_fp32 + global_norm_bfp16
+        )
         # add all reduce to get global norm of distributed params_and_grads
         dev_id = int(self._device.split(":")[1])
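Taken together, the clipping changes keep one partial sum of squared gradients per dtype, cast each partial result to fp32, and add them into a single global norm before computing the clip ratio. A condensed sketch of that reduction outside the sharding machinery, assuming bf16 kernels are available for the build in use:

import paddle

def global_grad_norm(grads, clip_norm=1.0):
    # One fp32 partial sum of squares per gradient, regardless of its dtype.
    partial = [paddle.cast(paddle.sum(paddle.square(g)), 'float32') for g in grads]
    global_norm = paddle.sqrt(paddle.add_n(partial))
    # Same shape as ClipGradByGlobalNorm: scale only when the norm exceeds the limit.
    max_norm = paddle.to_tensor(clip_norm, dtype='float32')
    clip_ratio = max_norm / paddle.maximum(global_norm, max_norm)
    return global_norm, clip_ratio

grads = [paddle.randn([16]).cast('bfloat16'), paddle.randn([16])]  # bf16 + fp32 mix
norm, ratio = global_grad_norm(grads)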
@@ -181,6 +214,7 @@ def GroupShardedScaler(scaler):
         if not self._enable:
             return
         param_grads = []
+        param_grads_bfp16 = []
         param_grads_fp16 = []
         param_grads_fp32 = []
         if hasattr(optimizer, "update_slice"):
@@ -200,6 +234,8 @@ def GroupShardedScaler(scaler):
                            paddle.float16,
                        ]:
                            param_grads_fp16.append(param.grad)
+                        elif param.grad.dtype in [paddle.bfloat16]:
+                            param_grads_bfp16.append(param.grad)
                        else:
                            param_grads_fp32.append(param.grad)
         else:
@@ -211,10 +247,13 @@ def GroupShardedScaler(scaler):
                        paddle.float16,
                    ]:
                        param_grads_fp16.append(param.grad)
+                    elif param.grad.dtype in [paddle.bfloat16]:
+                        param_grads_bfp16.append(param.grad)
                    else:
                        param_grads_fp32.append(param.grad)
         temp_found_inf_fp16 = to_variable(np.array([0]).astype(np.bool_))
+        temp_found_inf_bfp16 = to_variable(np.array([0]).astype(np.bool_))
         temp_found_inf_fp32 = to_variable(np.array([0]).astype(np.bool_))
         device = paddle.get_device().split(":")[0]
@@ -224,6 +263,16 @@ def GroupShardedScaler(scaler):
         )
         with device_guard(dev_id, device):
+            if len(param_grads_bfp16):
+                _legacy_C_ops.check_finite_and_unscale(
+                    param_grads_bfp16,
+                    self._scale,
+                    param_grads_bfp16,
+                    temp_found_inf_bfp16,
+                )
+                self._found_inf = _C_ops.bitwise_or(
+                    self._found_inf, temp_found_inf_bfp16
+                )
             if len(param_grads_fp16):
                 _legacy_C_ops.check_finite_and_unscale(
                     param_grads_fp16,
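The scaler change only adds a bf16 bucket: gradients are grouped by dtype, each bucket is unscaled in place, and a per-bucket found-inf flag is OR-ed into the shared one. Below is a simplified stand-in for the check_finite_and_unscale loop, written with plain public ops; it is only a sketch of the bookkeeping, not the fused kernel the scaler actually calls.

import paddle

def unscale_and_check(grads_by_dtype, scale):
    """Hypothetical helper mirroring the per-dtype unscale loop above."""
    found_inf = paddle.to_tensor([False])
    inv_scale = 1.0 / scale
    for dtype, grads in grads_by_dtype.items():
        for g in grads:
            g_fp32 = paddle.cast(g, 'float32') * inv_scale
            has_inf = paddle.any(paddle.logical_not(paddle.isfinite(g_fp32)))
            found_inf = paddle.logical_or(found_inf, has_inf)
            paddle.assign(paddle.cast(g_fp32, dtype), g)  # write the unscaled grad back
    return found_inf

grads = {
    'bfloat16': [paddle.randn([8]).cast('bfloat16') * 32768.0],
    'float32': [paddle.randn([8]) * 32768.0],
}
print(unscale_and_check(grads, scale=32768.0))  # False unless a gradient overflowed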
......
@@ -16,6 +16,7 @@
 import os
 import shutil
+import subprocess
 import tempfile
 import numpy as np
@@ -135,6 +136,7 @@ def train_mlp(
     model,
     sharding_stage,
     use_pure_fp16=False,
+    use_bfp16=False,
     accumulate_grad=False,
     batch_size=100,
     opt_group=False,
@@ -154,7 +156,10 @@ def train_mlp(
     if use_pure_fp16:
         model = paddle.amp.decorate(
-            models=model, level='O2', save_dtype='float32'
+            models=model,
+            level='O2',
+            save_dtype='float32',
+            dtype='bfloat16' if use_bfp16 else 'float16',
         )
         scaler = paddle.amp.GradScaler(init_loss_scaling=32768)
         scaler = GroupShardedScaler(scaler)
@@ -201,7 +206,11 @@ def train_mlp(
             img, label = data
             label.stop_gradient = True
             img.stop_gradient = True
-            with paddle.amp.auto_cast(use_pure_fp16, level='O2'):
+            with paddle.amp.auto_cast(
+                use_pure_fp16,
+                level='O2',
+                dtype='bfloat16' if use_bfp16 else 'float16',
+            ):
                 out = model(img)
                 loss = paddle.nn.functional.cross_entropy(
                     input=out, label=label
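For readability, here is the bf16 path of the test pulled out on a toy model: with use_bfp16 the model is decorated at AMP level O2 with dtype='bfloat16' and the forward pass runs under auto_cast with the same dtype. GroupShardedScaler and the sharded optimizer wiring are omitted, and bf16 hardware support is assumed.

import paddle

model = paddle.nn.Linear(16, 16)
opt = paddle.optimizer.AdamW(parameters=model.parameters())

# Same decoration train_mlp applies when use_bfp16=True.
model = paddle.amp.decorate(
    models=model, level='O2', save_dtype='float32', dtype='bfloat16'
)
scaler = paddle.amp.GradScaler(init_loss_scaling=32768)

x = paddle.randn([4, 16])
with paddle.amp.auto_cast(True, level='O2', dtype='bfloat16'):
    loss = model(x).mean()
scaler.scale(loss).backward()
scaler.step(opt)
scaler.update()
opt.clear_grad()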
@@ -240,7 +249,23 @@ def train_mlp(
 def test_stage2_stage3():
     paddle.distributed.init_parallel_env()
-    mlp, mlp1, mlp2, mlp3, mlp4, mlp5, mlp6, mlp7, mlp8, mlp9, mlp10 = (
+    (
+        mlp,
+        mlp1,
+        mlp2,
+        mlp3,
+        mlp4,
+        mlp5,
+        mlp6,
+        mlp7,
+        mlp8,
+        mlp9,
+        mlp10,
+        mlp11,
+        mlp12,
+    ) = (
+        MLP(),
+        MLP(),
         MLP(),
         MLP(),
         MLP(),
@@ -264,6 +289,8 @@ def test_stage2_stage3():
     mlp8.set_state_dict(state_dict)
     mlp9.set_state_dict(state_dict)
     mlp10.set_state_dict(state_dict)
+    mlp11.set_state_dict(state_dict)
+    mlp12.set_state_dict(state_dict)
     # fp32
     stage2_params = train_mlp(
@@ -336,6 +363,41 @@ def test_stage2_stage3():
             stage3_params[i].numpy(), stage3_params_re[i].numpy(), rtol=1e-6
         )
+
+    # bfp16
+    # NOTE: this is a hack to get int format nccl version, like 2134
+    # if current platform is not linux, version number will be 0
+    nccl_version_str = subprocess.check_output(
+        r"ldconfig -v | grep 'libnccl.so' | tail -n1 | sed -r 's/^.*\.so\.//'",
+        stderr=subprocess.DEVNULL,
+        shell=True,
+    ).decode('utf-8')
+    nccl_version = (
+        int("".join(nccl_version_str.split("."))) if nccl_version_str else 0
+    )
+    if nccl_version >= 2100:
+        stage2_params = train_mlp(
+            mlp11,
+            sharding_stage=2,
+            use_pure_fp16=True,
+            opt_group=False,
+            use_bfp16=True,
+        )
+        stage3_params = train_mlp(
+            mlp12,
+            sharding_stage=3,
+            use_pure_fp16=True,
+            opt_group=False,
+            use_bfp16=True,
+        )
+        for i in range(len(stage2_params)):
+            np.testing.assert_allclose(
+                stage2_params[i].numpy(),
+                stage3_params[i].numpy(),
+                rtol=1e-4,
+                atol=1e-3,
+            )
+
     # test for share layer parameters and exclude_layer function.
     sm1, sm2, sm3, sm4 = (
         SpecialModel(),
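How the gate above works: the shell pipeline strips everything up to the trailing '.so.' from the libnccl entry, leaving a version string such as '2.13.4'; removing the dots yields the integer 2134, and the bf16 comparison only runs when this is at least 2100, since NCCL gained bfloat16 collectives around 2.10. A worked example of just the parsing step:

# Parsing step only; the subprocess call is exactly as in the test above.
nccl_version_str = "2.13.4"  # e.g. what the ldconfig pipeline prints on Linux
nccl_version = int("".join(nccl_version_str.split("."))) if nccl_version_str else 0
print(nccl_version)          # 2134
print(nccl_version >= 2100)  # True -> the stage2 vs. stage3 bf16 check runs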
......
@@ -14,6 +14,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import subprocess
+
 import numpy as np
 import paddle
@@ -84,6 +86,7 @@ def optimizer_setting(model, use_pure_fp16, opt_group=False):
 def train_mlp(
     model,
     use_pure_fp16=False,
+    use_bfp16=False,
     accumulate_grad=False,
     offload=False,
     batch_size=100,
@@ -94,7 +97,10 @@ def train_mlp(
     if use_pure_fp16:
         model = paddle.amp.decorate(
-            models=model, level='O2', save_dtype='float32'
+            models=model,
+            level='O2',
+            save_dtype='float32',
+            dtype='bfloat16' if use_bfp16 else 'float16',
         )
         scaler = paddle.amp.GradScaler(init_loss_scaling=32768)
         scaler = GroupShardedScaler(scaler)
@@ -123,7 +129,11 @@ def train_mlp(
             img, label = data
             label.stop_gradient = True
             img.stop_gradient = True
-            with paddle.amp.auto_cast(use_pure_fp16, level='O2'):
+            with paddle.amp.auto_cast(
+                use_pure_fp16,
+                level='O2',
+                dtype='bfloat16' if use_bfp16 else 'float16',
+            ):
                 out = model(img)
                 loss = paddle.nn.functional.cross_entropy(
                     input=out, label=label
@@ -161,7 +171,9 @@ def train_mlp(
 def test_stage3_offload():
     paddle.distributed.init_parallel_env()
-    mlp, mlp1, mlp2, mlp3, mlp4, mlp5, mlp6 = (
+    mlp, mlp1, mlp2, mlp3, mlp4, mlp5, mlp6, mlp7, mlp8 = (
+        MLP(),
+        MLP(),
         MLP(),
         MLP(),
         MLP(),
@@ -177,6 +189,8 @@ def test_stage3_offload():
     mlp4.set_state_dict(state_dict)
     mlp5.set_state_dict(state_dict)
     mlp6.set_state_dict(state_dict)
+    mlp7.set_state_dict(state_dict)
+    mlp8.set_state_dict(state_dict)
     # fp32 offload
     stage3_params = train_mlp(mlp1, use_pure_fp16=False)
@@ -200,6 +214,31 @@ def test_stage3_offload():
             atol=1e-2,
         )
+
+    # bfp16 offload
+    # NOTE: this is a hack to get int format nccl version, like 2134
+    # if current platform is not linux, version number will be 0
+    nccl_version_str = subprocess.check_output(
+        r"ldconfig -v | grep 'libnccl.so' | tail -n1 | sed -r 's/^.*\.so\.//'",
+        stderr=subprocess.DEVNULL,
+        shell=True,
+    ).decode('utf-8')
+    nccl_version = (
+        int("".join(nccl_version_str.split("."))) if nccl_version_str else 0
+    )
+    if nccl_version >= 2100:
+        stage3_params = train_mlp(mlp7, use_pure_fp16=True, use_bfp16=True)
+        stage3_params_offload = train_mlp(
+            mlp8, use_pure_fp16=True, offload=True, use_bfp16=True
+        )
+        for i in range(len(stage3_params)):
+            np.testing.assert_allclose(
+                stage3_params[i].numpy(),
+                stage3_params_offload[i].numpy(),
+                rtol=1e-2,
+                atol=1e-2,
+            )
+
     # fp32 accumulate grad offload
     stage3_params = train_mlp(
         mlp5, use_pure_fp16=False, batch_size=20, accumulate_grad=True
......