Unverified commit 1a4a1520 authored by W wuhuachaocoding, committed by GitHub

support bfp16 for stage3 and offload. (#49931)

Parent 05c9c0a5
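This commit extends GroupShardedStage3 (and its CPU-offload path) to bfloat16; the tests below only exercise it when NCCL >= 2.10 is detected. A minimal sketch of how the test changes drive the feature, assuming a multi-GPU launch via `paddle.distributed.launch`; the import paths, constructor arguments, and the toy `Linear` model are assumptions based on the test code, not part of the commit itself:

```python
import paddle
from paddle.distributed.fleet.meta_parallel.sharding.group_sharded_stage3 import (
    GroupShardedStage3,
)
from paddle.distributed.fleet.meta_parallel.sharding.group_sharded_utils import (
    GroupShardedScaler,
)

paddle.distributed.init_parallel_env()
model = paddle.nn.Linear(64, 64)  # toy model; the tests use an MLP
optimizer = paddle.optimizer.Momentum(
    learning_rate=0.001, parameters=model.parameters(), multi_precision=True
)

# Pure bf16 (O2) with fp32 master weights kept via save_dtype, as in the tests.
model = paddle.amp.decorate(
    models=model, level='O2', save_dtype='float32', dtype='bfloat16'
)
scaler = GroupShardedScaler(paddle.amp.GradScaler(init_loss_scaling=32768))
model = GroupShardedStage3(model, optimizer=optimizer)  # pass offload=True for CPU offload

for _ in range(2):
    x = paddle.randn([8, 64])
    with paddle.amp.auto_cast(True, level='O2', dtype='bfloat16'):
        loss = model(x).mean()
    scaler.scale(loss).backward()
    scaler.step(optimizer)
    scaler.update()
    optimizer.clear_grad()
```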
......@@ -45,10 +45,9 @@ def _all_gather(tensor, buffer_size, group):
# CUDA alignment 256 bytes
alignment = {
"gpu": 256,
}
alignment = {"gpu": 256, "cpu": 4096, "xpu": 256}
align = {
Type.bf16.value: 2,
Type.fp16.value: 2,
Type.fp32.value: 4,
}
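For context, a hedged sketch of how such alignment tables are typically applied: the element count of a parameter buffer is padded so that its byte size becomes a multiple of the device alignment. The helper name and exact rounding below are illustrative, not the code in this file:

```python
# Illustrative padding helper (not the exact function in this file).
alignment = {"gpu": 256, "cpu": 4096, "xpu": 256}
align = {"bfloat16": 2, "float16": 2, "float32": 4}

def aligned_numel(numel, dtype="bfloat16", device="gpu"):
    size = numel * align[dtype]              # raw buffer size in bytes
    remaining = size % alignment[device]     # bytes over the last aligned boundary
    if remaining == 0:
        return numel
    # pad up to the next multiple of the device alignment, counted in elements
    return numel + (alignment[device] - remaining) // align[dtype]

assert aligned_numel(1000, "bfloat16", "gpu") == 1024  # 2048 bytes, multiple of 256
```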
......@@ -251,6 +250,11 @@ class GroupShardedStage3(nn.Layer):
and param2dtype[param.name] == Type.fp16.value
):
tmp_var = paddle.cast(tmp_var, Type.fp16.value)
elif (
tmp_var.dtype == Type.fp32.value
and param2dtype[param.name] == Type.bf16.value
):
tmp_var = paddle.cast(tmp_var, Type.bf16.value)
tmp_var._share_buffer_to(param)
del tmp_var
for grad_storage in self._grad_storages.values():
......@@ -312,11 +316,14 @@ class GroupShardedStage3(nn.Layer):
def _handle_unslice_params(self):
buffer_size = dict()
buffer_size[Type.bf16.value] = 0
buffer_size[Type.fp32.value] = 0
buffer_size[Type.fp16.value] = 0
for param in self._unslice_params:
# Update optimizer master weights
if param.dtype == Type.fp16.value and not self._offload:
if (
param.dtype == Type.fp16.value or param.dtype == Type.bf16.value
) and not self._offload:
master_tensor = paddle.cast(param, Type.fp32.value)
master_tensor.name = param.name
self._optim._master_weights[param.name] = master_tensor
......@@ -419,10 +426,14 @@ class GroupShardedStage3(nn.Layer):
assert isinstance(buffer_size, int)
value = (
np.zeros(buffer_size, dtype=np.float16)
if Type.fp16.value == param.dtype
if (
Type.fp16.value == param.dtype or Type.bf16.value == param.dtype
)
else np.zeros(buffer_size, dtype=np.float32)
)
buffer = core.eager.Tensor(value=value, place=core.CPUPlace())
if Type.bf16.value == param.dtype:
buffer = buffer.cast(Type.bf16.value)
param_shape = param.shape
origin_state = param.stop_gradient
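NumPy has no native bfloat16 dtype, which is why the offload buffer above is first allocated as fp16 zeros (same 2-byte element width) and then cast to bf16 on the Paddle side. A small standalone illustration of that workaround, using `paddle.to_tensor` in place of the internal `core.eager.Tensor`:

```python
import numpy as np
import paddle

buffer_size = 1024
# NumPy cannot allocate bfloat16 directly, so allocate a 2-byte-per-element
# fp16 buffer of the desired length and cast it once it is a Paddle tensor.
value = np.zeros(buffer_size, dtype=np.float16)
buffer = paddle.to_tensor(value, place=paddle.CPUPlace())
buffer = buffer.cast(paddle.bfloat16)
assert buffer.dtype == paddle.bfloat16
assert buffer.shape == [buffer_size]
```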
......@@ -462,7 +473,9 @@ class GroupShardedStage3(nn.Layer):
# Update optimizer master weights
if (
param.trainable
and param.dtype == Type.fp16.value
and (
param.dtype == Type.fp16.value or param.dtype == Type.bf16.value
)
and not self._offload
):
master_tensor = paddle.cast(param.fw_storage, Type.fp32.value)
......@@ -1088,6 +1101,11 @@ def _cpu2device(param):
and param2dtype[param.name] == Type.fp16.value
):
tmp_p = paddle.cast(tmp_p, Type.fp16.value)
elif (
tmp_p.dtype == Type.fp32.value
and param2dtype[param.name] == Type.bf16.value
):
tmp_p = paddle.cast(tmp_p, Type.bf16.value)
return tmp_p
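Both this `_cpu2device` path and the earlier parameter-restore path follow the same rule: the fp32 copy used while offloaded is cast back to whichever low-precision dtype was recorded in `param2dtype`. A condensed, illustrative form (the helper name is hypothetical):

```python
import paddle

_LOW_PRECISION = (paddle.float16, paddle.bfloat16)

def restore_param_dtype(tensor, target_dtype):
    # Cast the fp32 offload copy back to the recorded training dtype, if any.
    if tensor.dtype == paddle.float32 and target_dtype in _LOW_PRECISION:
        tensor = paddle.cast(tensor, target_dtype)
    return tensor
```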
......
......@@ -54,8 +54,12 @@ class GroupShardedClipGrad:
@paddle.autograd.no_grad()
def _dygraph_clip(self, params_grads):
sum_square_fp32, sum_square_fp16 = [], []
unslice_params_fp32, unslice_params_fp16 = [], []
sum_square_fp32, sum_square_fp16, sum_square_bfp16 = [], [], []
unslice_params_fp32, unslice_params_fp16, unslice_params_bfp16 = (
[],
[],
[],
)
for p, g in params_grads:
p_slice = True # using for slice parameter in sharding stage3
......@@ -82,6 +86,11 @@ class GroupShardedClipGrad:
sum_square_fp32.append(sum_square)
else:
unslice_params_fp32.append(sum_square)
elif p.dtype == paddle.bfloat16:
if p_slice:
sum_square_bfp16.append(sum_square)
else:
unslice_params_bfp16.append(sum_square)
# global norm of non-distributed FP16 params_and_grads
if len(sum_square_fp16) == 0:
......@@ -93,6 +102,16 @@ class GroupShardedClipGrad:
global_norm_fp16, dtype=paddle.float32
)
# global norm of non-distributed BFP16 params_and_grads
if len(sum_square_bfp16) == 0:
global_norm_bfp16 = paddle.to_tensor([0.0], dtype=paddle.float32)
else:
global_norm_bfp16 = paddle.concat(sum_square_bfp16)
global_norm_bfp16 = paddle.sum(global_norm_bfp16)
global_norm_bfp16 = paddle.cast(
global_norm_bfp16, dtype=paddle.float32
)
# global norm of non-distributed FP16 params_and_grads for unslice parameters
if len(unslice_params_fp16) == 0:
global_unslice_fp16 = paddle.to_tensor([0.0], dtype=paddle.float32)
......@@ -103,6 +122,16 @@ class GroupShardedClipGrad:
global_unslice_fp16, dtype=paddle.float32
)
# global norm of non-distributed BFP16 params_and_grads for unslice parameters
if len(unslice_params_bfp16) == 0:
global_unslice_bfp16 = paddle.to_tensor([0.0], dtype=paddle.float32)
else:
global_unslice_bfp16 = paddle.concat(unslice_params_bfp16)
global_unslice_bfp16 = paddle.sum(global_unslice_bfp16)
global_unslice_bfp16 = paddle.cast(
global_unslice_bfp16, dtype=paddle.float32
)
# global norm of non-distributed FP32 params_and_grads
global_norm_fp32 = (
paddle.concat(sum_square_fp32)
......@@ -118,9 +147,13 @@ class GroupShardedClipGrad:
else paddle.to_tensor([0.0], dtype=paddle.float32)
)
global_unslice_fp32 = paddle.sum(global_unslice_fp32)
global_unslice_var = global_unslice_fp16 + global_unslice_fp32
global_unslice_var = (
global_unslice_fp16 + global_unslice_fp32 + global_unslice_bfp16
)
global_norm_var = global_norm_fp16 + global_norm_fp32
global_norm_var = (
global_norm_fp16 + global_norm_fp32 + global_norm_bfp16
)
# add all reduce to get global norm of distributed params_and_grads
dev_id = int(self._device.split(":")[1])
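The clipping code above keeps one list of squared-gradient sums per dtype (fp16, bf16, fp32, plus the unsliced variants), reduces each bucket separately, promotes the partial result to fp32, and only then adds the buckets together, so the low-precision partial sums never accumulate in their own dtype. A toy illustration of that reduction order, with made-up values:

```python
import paddle

# Per-dtype lists of sum(g * g) terms, as _dygraph_clip builds them above.
sum_square_fp16 = [paddle.to_tensor([4.0]).cast(paddle.float16)]
sum_square_bf16 = [paddle.to_tensor([9.0]).cast(paddle.bfloat16)]
sum_square_fp32 = [paddle.to_tensor([16.0])]

def bucket_norm_sq(chunks):
    # Reduce one dtype bucket, then promote to fp32 before mixing with the others.
    if not chunks:
        return paddle.to_tensor([0.0], dtype=paddle.float32)
    return paddle.cast(paddle.sum(paddle.concat(chunks)), paddle.float32)

global_norm_sq = (
    bucket_norm_sq(sum_square_fp16)
    + bucket_norm_sq(sum_square_bf16)
    + bucket_norm_sq(sum_square_fp32)
)
global_norm = paddle.sqrt(global_norm_sq)  # sqrt(4 + 9 + 16) ~= 5.39
```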
......@@ -181,6 +214,7 @@ def GroupShardedScaler(scaler):
if not self._enable:
return
param_grads = []
param_grads_bfp16 = []
param_grads_fp16 = []
param_grads_fp32 = []
if hasattr(optimizer, "update_slice"):
......@@ -200,6 +234,8 @@ def GroupShardedScaler(scaler):
paddle.float16,
]:
param_grads_fp16.append(param.grad)
elif param.grad.dtype in [paddle.bfloat16]:
param_grads_bfp16.append(param.grad)
else:
param_grads_fp32.append(param.grad)
else:
......@@ -211,10 +247,13 @@ def GroupShardedScaler(scaler):
paddle.float16,
]:
param_grads_fp16.append(param.grad)
elif param.grad.dtype in [paddle.bfloat16]:
param_grads_bfp16.append(param.grad)
else:
param_grads_fp32.append(param.grad)
temp_found_inf_fp16 = to_variable(np.array([0]).astype(np.bool_))
temp_found_inf_bfp16 = to_variable(np.array([0]).astype(np.bool_))
temp_found_inf_fp32 = to_variable(np.array([0]).astype(np.bool_))
device = paddle.get_device().split(":")[0]
......@@ -224,6 +263,16 @@ def GroupShardedScaler(scaler):
)
with device_guard(dev_id, device):
if len(param_grads_bfp16):
_legacy_C_ops.check_finite_and_unscale(
param_grads_bfp16,
self._scale,
param_grads_bfp16,
temp_found_inf_bfp16,
)
self._found_inf = _C_ops.bitwise_or(
self._found_inf, temp_found_inf_bfp16
)
if len(param_grads_fp16):
_legacy_C_ops.check_finite_and_unscale(
param_grads_fp16,
......
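In the scaler patch above, gradients are bucketed by dtype because `check_finite_and_unscale` operates on a single dtype per call; each bucket is unscaled and its overflow flag is OR-ed into `self._found_inf`. A hedged pure-Python stand-in for that flow (the helper below is illustrative, not the Paddle op):

```python
import paddle

def unscale_and_check(grads, scale):
    # Illustrative stand-in for check_finite_and_unscale: unscale one
    # single-dtype bucket and report whether any gradient is inf/nan.
    found_inf = paddle.to_tensor([False])
    for i, g in enumerate(grads):
        unscaled = g.cast('float32') / scale
        finite = paddle.all(paddle.isfinite(unscaled))
        found_inf = paddle.logical_or(found_inf, paddle.logical_not(finite))
        grads[i] = unscaled.cast(g.dtype)
    return found_inf

scale = 32768.0
param_grads_fp16 = [(paddle.ones([4]) * scale).cast(paddle.float16)]
param_grads_bf16 = [(paddle.ones([4]) * scale).cast(paddle.bfloat16)]
param_grads_fp32 = [paddle.ones([4]) * scale]

# One call per dtype bucket, then OR the per-bucket overflow flags together.
found_inf = paddle.to_tensor([False])
for bucket in (param_grads_fp16, param_grads_bf16, param_grads_fp32):
    if bucket:
        found_inf = paddle.logical_or(found_inf, unscale_and_check(bucket, scale))
```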
......@@ -16,6 +16,7 @@
import os
import shutil
import subprocess
import tempfile
import numpy as np
......@@ -135,6 +136,7 @@ def train_mlp(
model,
sharding_stage,
use_pure_fp16=False,
use_bfp16=False,
accumulate_grad=False,
batch_size=100,
opt_group=False,
......@@ -154,7 +156,10 @@ def train_mlp(
if use_pure_fp16:
model = paddle.amp.decorate(
models=model, level='O2', save_dtype='float32'
models=model,
level='O2',
save_dtype='float32',
dtype='bfloat16' if use_bfp16 else 'float16',
)
scaler = paddle.amp.GradScaler(init_loss_scaling=32768)
scaler = GroupShardedScaler(scaler)
......@@ -201,7 +206,11 @@ def train_mlp(
img, label = data
label.stop_gradient = True
img.stop_gradient = True
with paddle.amp.auto_cast(use_pure_fp16, level='O2'):
with paddle.amp.auto_cast(
use_pure_fp16,
level='O2',
dtype='bfloat16' if use_bfp16 else 'float16',
):
out = model(img)
loss = paddle.nn.functional.cross_entropy(
input=out, label=label
......@@ -240,7 +249,23 @@ def train_mlp(
def test_stage2_stage3():
paddle.distributed.init_parallel_env()
mlp, mlp1, mlp2, mlp3, mlp4, mlp5, mlp6, mlp7, mlp8, mlp9, mlp10 = (
(
mlp,
mlp1,
mlp2,
mlp3,
mlp4,
mlp5,
mlp6,
mlp7,
mlp8,
mlp9,
mlp10,
mlp11,
mlp12,
) = (
MLP(),
MLP(),
MLP(),
MLP(),
MLP(),
......@@ -264,6 +289,8 @@ def test_stage2_stage3():
mlp8.set_state_dict(state_dict)
mlp9.set_state_dict(state_dict)
mlp10.set_state_dict(state_dict)
mlp11.set_state_dict(state_dict)
mlp12.set_state_dict(state_dict)
# fp32
stage2_params = train_mlp(
......@@ -336,6 +363,41 @@ def test_stage2_stage3():
stage3_params[i].numpy(), stage3_params_re[i].numpy(), rtol=1e-6
)
# bfp16
# NOTE: this is a hack to get the NCCL version as an integer, e.g. 2134.
# If the current platform is not Linux, the version number will be 0.
nccl_version_str = subprocess.check_output(
r"ldconfig -v | grep 'libnccl.so' | tail -n1 | sed -r 's/^.*\.so\.//'",
stderr=subprocess.DEVNULL,
shell=True,
).decode('utf-8')
nccl_version = (
int("".join(nccl_version_str.split("."))) if nccl_version_str else 0
)
if nccl_version >= 2100:
stage2_params = train_mlp(
mlp11,
sharding_stage=2,
use_pure_fp16=True,
opt_group=False,
use_bfp16=True,
)
stage3_params = train_mlp(
mlp12,
sharding_stage=3,
use_pure_fp16=True,
opt_group=False,
use_bfp16=True,
)
for i in range(len(stage2_params)):
np.testing.assert_allclose(
stage2_params[i].numpy(),
stage3_params[i].numpy(),
rtol=1e-4,
atol=1e-3,
)
# test for share layer parameters and exclude_layer function.
sm1, sm2, sm3, sm4 = (
SpecialModel(),
......
......@@ -14,6 +14,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import subprocess
import numpy as np
import paddle
......@@ -84,6 +86,7 @@ def optimizer_setting(model, use_pure_fp16, opt_group=False):
def train_mlp(
model,
use_pure_fp16=False,
use_bfp16=False,
accumulate_grad=False,
offload=False,
batch_size=100,
......@@ -94,7 +97,10 @@ def train_mlp(
if use_pure_fp16:
model = paddle.amp.decorate(
models=model, level='O2', save_dtype='float32'
models=model,
level='O2',
save_dtype='float32',
dtype='bfloat16' if use_bfp16 else 'float16',
)
scaler = paddle.amp.GradScaler(init_loss_scaling=32768)
scaler = GroupShardedScaler(scaler)
......@@ -123,7 +129,11 @@ def train_mlp(
img, label = data
label.stop_gradient = True
img.stop_gradient = True
with paddle.amp.auto_cast(use_pure_fp16, level='O2'):
with paddle.amp.auto_cast(
use_pure_fp16,
level='O2',
dtype='bfloat16' if use_bfp16 else 'float16',
):
out = model(img)
loss = paddle.nn.functional.cross_entropy(
input=out, label=label
......@@ -161,7 +171,9 @@ def train_mlp(
def test_stage3_offload():
paddle.distributed.init_parallel_env()
mlp, mlp1, mlp2, mlp3, mlp4, mlp5, mlp6 = (
mlp, mlp1, mlp2, mlp3, mlp4, mlp5, mlp6, mlp7, mlp8 = (
MLP(),
MLP(),
MLP(),
MLP(),
MLP(),
......@@ -177,6 +189,8 @@ def test_stage3_offload():
mlp4.set_state_dict(state_dict)
mlp5.set_state_dict(state_dict)
mlp6.set_state_dict(state_dict)
mlp7.set_state_dict(state_dict)
mlp8.set_state_dict(state_dict)
# fp32 offload
stage3_params = train_mlp(mlp1, use_pure_fp16=False)
......@@ -200,6 +214,31 @@ def test_stage3_offload():
atol=1e-2,
)
# bfp16 offload
# NOTE: this is a hack to get the NCCL version as an integer, e.g. 2134.
# If the current platform is not Linux, the version number will be 0.
nccl_version_str = subprocess.check_output(
r"ldconfig -v | grep 'libnccl.so' | tail -n1 | sed -r 's/^.*\.so\.//'",
stderr=subprocess.DEVNULL,
shell=True,
).decode('utf-8')
nccl_version = (
int("".join(nccl_version_str.split("."))) if nccl_version_str else 0
)
if nccl_version >= 2100:
stage3_params = train_mlp(mlp7, use_pure_fp16=True, use_bfp16=True)
stage3_params_offload = train_mlp(
mlp8, use_pure_fp16=True, offload=True, use_bfp16=True
)
for i in range(len(stage3_params)):
np.testing.assert_allclose(
stage3_params[i].numpy(),
stage3_params_offload[i].numpy(),
rtol=1e-2,
atol=1e-2,
)
# fp32 accumulate grad offload
stage3_params = train_mlp(
mlp5, use_pure_fp16=False, batch_size=20, accumulate_grad=True
......