Unverified commit 23d559dd, authored by Baibaifan, committed by GitHub

optimize sharding stage3 (#39334)

Parent 41eb2595
......@@ -35,6 +35,7 @@ from paddle.distributed.collective import _get_global_group
from .sharding_utils import Type, ShardingClipGrad, device_guard
from ..pp_utils.utils import _all_gather
from ...utils.internal_storage import GradStorage
# CUDA alignment 256 bytes
alignment = {"gpu": 256, }
......@@ -69,6 +70,7 @@ class ShardingStage3(nn.Layer):
group=None,
sync_buffers=False,
device="gpu",
segment_size=2**15,
pertrain_sync_models=True,
accumulate_grads=False,
offload=False,
......@@ -83,6 +85,8 @@ class ShardingStage3(nn.Layer):
self._accumulate_grads = accumulate_grads
self._offload = offload
self._sync_comm = sync_comm
# segmentation size
self._segment_size = segment_size if not offload else 0
global DEV
DEV = "cpu" if paddle.get_device() == "cpu" else paddle.get_device(
......@@ -107,7 +111,10 @@ class ShardingStage3(nn.Layer):
self._param2buffer_size = dict() # {param.name: size}
self._param2buffer = dict(
) # {param.name: [(start0, end0),(start1, end1), ...]}
self._trainable_params = dict() # {layer.name: [trainable_params]}
self._trainable_params = dict() # {id(layer): [trainable_params]}
self._unslice_params = set() # param's numel <= segment_size
self._unslice_params2align = dict() # {param.name: param's align}
self._grad_storages = dict() # {param.dtype: GradStorage}
assert not isinstance(
optimizer, list), "Multiple optimizers are not supported now."
......@@ -131,10 +138,13 @@ class ShardingStage3(nn.Layer):
self._segment_rank_params(self._layer)
# Add unslice params to master_weight in fp16
self._handle_unslice_params()
# In the first step, record the execution order of the layer
self._order_tracer = OrderedDict()
self._order_tracer["order"] = 0
self._order_tracer["layer"] = []
self._order_tracer["layer"] = list()
# Register task flow
self._task_flow = TaskFlow()
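The block above initializes the first-step order tracer. A rough sketch of the idea, with illustrative names only (the real tracer is filled in by forward hooks on each sub-layer): record the order in which layers first execute so later steps can prefetch each layer's parameters in that order.
order_tracer = {"order": 0, "layer": []}

def on_layer_enter(layer_id):
    # Record only the first time a layer runs; re-entries keep their index.
    if layer_id not in order_tracer:
        order_tracer[layer_id] = order_tracer["order"]
        order_tracer["layer"].append(layer_id)
        order_tracer["order"] += 1

for lid in ["fc1", "fc2", "fc1"]:
    on_layer_enter(lid)
assert order_tracer["layer"] == ["fc1", "fc2"]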
......@@ -168,8 +178,10 @@ class ShardingStage3(nn.Layer):
def _clear_gradients(self):
assert len(self._trainable_params.keys()) > 0
current_layer_params = self._layer.parameters(include_sublayers=True)
# 1.Handle param's slice
trainable_params = list(
filter(lambda x: x.trainable, current_layer_params))
filter(lambda p: p.trainable and p not in self._unslice_params,
current_layer_params))
for param in trainable_params:
assert hasattr(
param, "fw_storage"
......@@ -178,6 +190,9 @@ class ShardingStage3(nn.Layer):
param.fw_storage.clear_gradient(False)
param.fw_storage._gradient_set_empty(False)
param.bw_storage._clear()
# 2.Handle unslice param
for grad_storage in self._grad_storages.values():
grad_storage.buffer.zero_()
    # Update param memory slice
def _update_params_slice(self):
......@@ -185,20 +200,25 @@ class ShardingStage3(nn.Layer):
if not isinstance(self._optim._param_groups[0], dict):
slice_params = [param.fw_storage for param in update_list]
self._optim._parameter_list = slice_params
self._optim._param_groups = slice_params
self._optim._parameter_list = slice_params + list(
self._unslice_params)
self._optim._param_groups = slice_params + list(
self._unslice_params)
else:
params_name_list = list(map(lambda p: p.name, update_list))
fw_storage_name_list = list(
map(lambda p: p.fw_storage.name, update_list))
for param_group in self._optim._param_groups:
slice_p = []
p_group = []
for p in param_group['params']:
if p.name in params_name_list:
assert hasattr(
p, "fw_storage"
), "Find {} don't have fw_storage attribute.".format(
p.name)
slice_p.append(p.fw_storage)
param_group['params'] = slice_p
p_group.append(p.fw_storage)
elif p.name in fw_storage_name_list:
p_group.append(update_list[fw_storage_name_list.index(
p.name)].fw_storage)
elif p in self._unslice_params:
p_group.append(p)
param_group['params'] = p_group
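A minimal sketch of what the rebuilt param group above ends up containing, with illustrative names (not the Paddle API): sliced parameters are swapped for their local fw_storage shard, while unsliced parameters are kept whole, so the optimizer only holds state for what this rank actually updates.
def rebuild_param_group(params, fw_storage_of, unsliced):
    rebuilt = []
    for p in params:
        if p in fw_storage_of:        # sliced param: update the local shard
            rebuilt.append(fw_storage_of[p])
        elif p in unsliced:           # small param: updated in full
            rebuilt.append(p)
    return rebuilt

group = rebuild_param_group(
    params=["w_big", "b_small"],
    fw_storage_of={"w_big": "w_big@rank0_slice"},
    unsliced={"b_small"})
assert group == ["w_big@rank0_slice", "b_small"]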
def forward(self, *inputs, **kwargs):
"""
......@@ -213,6 +233,32 @@ class ShardingStage3(nn.Layer):
return fw
def _handle_unslice_params(self):
buffer_size = dict()
buffer_size[Type.fp32.value] = 0
buffer_size[Type.fp16.value] = 0
for param in self._unslice_params:
            # Update optimizer master weights
if param.dtype == Type.fp16.value and not self._offload:
self._optim._master_weights[param.name] = paddle.cast(
param, Type.fp32.value)
param2dtype[param.name] = param.dtype
p_align = self._param2align(param)
self._unslice_params2align[param.name] = p_align
buffer_size[param.dtype] += param._numel() + p_align
        # Create unslice_params' grad
for param in sorted(list(self._unslice_params), key=lambda p: p.name):
if param.dtype not in self._grad_storages.keys():
self._grad_storages[param.dtype] = GradStorage(
buffer_size[param.dtype],
dtype=param.dtype,
device=self._default_device,
destination=self._rank,
parm2align=self._unslice_params2align)
self._grad_storages[param.dtype].add_grad(
param, self._unslice_params2align[param.name])
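The new segment_size threshold (default 2**15 elements) is what separates sliced from unsliced parameters; a rough sketch of the routing rule under that assumption, with illustrative names:
SEGMENT_SIZE = 2 ** 15  # default threshold

def route_param(numel, trainable=True, segment_size=SEGMENT_SIZE):
    if not trainable:
        return "skipped"
    # Large params are sliced across ranks; small ones stay whole and
    # have their gradients packed into a per-dtype GradStorage buffer.
    return "sliced" if numel > segment_size else "unsliced"

assert route_param(4096 * 4096) == "sliced"   # e.g. a [4096, 4096] weight
assert route_param(4096) == "unsliced"        # e.g. a [4096] bias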
def _segment_rank_params(self, layer, name="last_layer"):
"""
Flatten parameters according to layer.
......@@ -233,24 +279,22 @@ class ShardingStage3(nn.Layer):
def _add_manage_info(trainable_param):
return _PartitionParam(trainable_param)
trainable_params = list(
filter(lambda x: x.trainable, current_layer_params))
current_params = list()
for p in current_layer_params:
if p.trainable and p._numel() > self._segment_size:
current_params.append(_add_manage_info(p))
elif p.trainable:
self._unslice_params.add(_UnsliceParam(p))
assert id(layer) not in self._trainable_params.keys()
self._trainable_params[id(layer)] = list(
map(_add_manage_info, trainable_params))
self._trainable_params[id(layer)] = current_params
for param in self._trainable_params[id(layer)]:
if param.name in self._param2buffer.keys():
continue
self._param2buffer[param.name] = []
# 1.Params alignment
offset = 0
# CUDA alignment 256 bytes
size = param._numel() * align[param.dtype]
remaining = size % alignment[self._default_device]
ali = 0 if remaining == 0 else alignment[
self._default_device] - remaining
align_ = ali // align[param.dtype]
align_ = self._param2align(param)
offset = align_ + param._numel()
buffer_size = offset if offset % self._group.nranks == 0 else offset + self._group.nranks - (
......@@ -379,7 +423,9 @@ class ShardingStage3(nn.Layer):
assert len(self._trainable_params.keys()) > 0
current_layer_params = self._layer.parameters(include_sublayers=True)
trainable_params = list(
filter(lambda x: x.trainable, current_layer_params))
filter(lambda p: p.trainable and p not in self._unslice_params,
current_layer_params))
# 1.Handle param's slice
for param in trainable_params:
assert hasattr(
param,
......@@ -396,6 +442,19 @@ class ShardingStage3(nn.Layer):
assert param.fw_storage.grad is None
param.fw_storage._copy_gradient_from(param.bw_storage)
update_list.append(param)
# 2.Handle unslice param
for grad_storage in self._grad_storages.values():
grad_storage.buffer.scale_(scale=self._world_size_scaling)
dist.all_reduce(
tensor=grad_storage.buffer,
group=self._group,
use_calc_stream=True)
dist.wait(
tensor=grad_storage.buffer,
group=self._group,
use_calc_stream=True)
return update_list
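Because all unsliced gradients share one GradStorage buffer per dtype, a single scale_ plus all_reduce averages every one of them at once. A pure-numpy sketch (hypothetical 2-rank values) of why scale-then-sum equals the mean:
import numpy as np

world_size = 2
rank_buffers = [np.array([2.0, 4.0]), np.array([6.0, 8.0])]   # per-rank grads
scaled = [b * (1.0 / world_size) for b in rank_buffers]       # buffer.scale_()
reduced = sum(scaled)                                         # sum all-reduce
assert np.allclose(reduced, np.mean(rank_buffers, axis=0))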
def get_all_parameters(self, convert2cpu=False):
......@@ -405,7 +464,8 @@ class ShardingStage3(nn.Layer):
assert len(self._trainable_params.keys()) > 0
current_layer_params = self._layer.parameters(include_sublayers=True)
trainable_params = list(
filter(lambda x: x.trainable, current_layer_params))
filter(lambda p: p.trainable and p not in self._unslice_params,
current_layer_params))
t_flow = _allgather_buffer(
trainable_params,
self._group,
......@@ -415,7 +475,7 @@ class ShardingStage3(nn.Layer):
offload=self._offload,
convert2cpu=convert2cpu)
if convert2cpu:
for param in current_layer_params:
for param in trainable_params:
t_flow.full_param[param.name]._share_buffer_to(param)
self._optim._parameter_list = self._ori_parameter_list
......@@ -424,7 +484,8 @@ class ShardingStage3(nn.Layer):
def _register_backward_hooks(self):
current_layer_params = self._layer.parameters(include_sublayers=True)
trainable_params = list(
filter(lambda x: x.trainable, current_layer_params))
filter(lambda p: p.trainable and p not in self._unslice_params,
current_layer_params))
for param in trainable_params:
allreduce_function = self._get_allreduce_fn(param)
......@@ -435,42 +496,36 @@ class ShardingStage3(nn.Layer):
def reduce(*_):
if param.name in self._task_flow.full_grad.keys():
full_grad = self._task_flow.full_grad[param.name]
with paddle.amp.auto_cast(enable=False):
if not self._accumulate_grads:
full_grad.scale_(scale=self._world_size_scaling)
# Only support sync allreduce current rank's layer now
dist.all_reduce(
tensor=full_grad,
group=self._group,
use_calc_stream=True)
dist.wait(
tensor=full_grad,
group=self._group,
use_calc_stream=True)
if not self._accumulate_grads:
full_grad.scale_(scale=self._world_size_scaling)
# Only support sync allreduce current rank's layer now
dist.all_reduce(
tensor=full_grad, group=self._group, use_calc_stream=True)
dist.wait(
tensor=full_grad, group=self._group, use_calc_stream=True)
start, end = self._param2buffer[param.name][self._rank]
if not self._accumulate_grads or param.bw_storage is None or not param.bw_storage.value(
).get_tensor()._is_initialized():
param.bw_storage = core.VarBase(
full_grad._slice(start, end)).detach().clone()
if self._offload:
param.bw_storage = _device2cpu(param.bw_storage,
True)
start, end = self._param2buffer[param.name][self._rank]
if not self._accumulate_grads or param.bw_storage is None or not param.bw_storage.value(
).get_tensor()._is_initialized():
param.bw_storage = core.VarBase(
full_grad._slice(start, end)).detach().clone()
if self._offload:
param.bw_storage = _device2cpu(param.bw_storage, True)
else:
if self._offload:
cpu_grad = _device2cpu(
core.VarBase(full_grad._slice(start, end))
.detach().clone(), True)
param.bw_storage = paddle.add(param.bw_storage,
cpu_grad)
else:
if self._offload:
cpu_grad = _device2cpu(
core.VarBase(full_grad._slice(start, end))
.detach().clone(), True)
param.bw_storage = paddle.add(param.bw_storage,
cpu_grad)
else:
# param.bw_storage.add_(
# core.VarBase(full_grad._slice(start, end))
# .detach().clone())
param.bw_storage = paddle.add(
param.bw_storage,
core.VarBase(full_grad._slice(
start, end)).detach().clone())
# param.bw_storage.add_(
# core.VarBase(full_grad._slice(start, end))
# .detach().clone())
param.bw_storage = paddle.add(
param.bw_storage,
core.VarBase(full_grad._slice(start, end)).detach(
).clone())
param.clear_gradient(False)
param._gradient_set_empty(False)
tmp_var = self._task_flow.full_grad.pop(param.name)
......@@ -493,6 +548,15 @@ class ShardingStage3(nn.Layer):
return reduce
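Inside the hook above, each rank keeps only its [start, end) shard of the all-reduced gradient in param.bw_storage. An illustrative numpy sketch of that slicing (the real code slices a Paddle VarBase via full_grad._slice):
import numpy as np

def rank_slice(full_grad, param2buffer, name, rank):
    start, end = param2buffer[name][rank]   # precomputed in _param2buffer
    return full_grad[start:end].copy()      # this rank's gradient shard

full_grad = np.arange(8, dtype=np.float32)   # already all-reduced
param2buffer = {"w": [(0, 4), (4, 8)]}       # 2 ranks, 4 elements each
assert rank_slice(full_grad, param2buffer, "w", 1).tolist() == [4.0, 5.0, 6.0, 7.0]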
def _param2align(self, param):
# CUDA alignment 256 bytes
size = param._numel() * align[param.dtype]
remaining = size % alignment[self._default_device]
ali = 0 if remaining == 0 else alignment[
self._default_device] - remaining
align_ = ali // align[param.dtype]
return align_
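A worked example of _param2align's arithmetic, assuming align maps fp16 to 2 bytes per element (an assumption about the align table, which is not shown in this diff):
element_bytes = 2                     # assumed: align[fp16] == 2 bytes
numel = 1000
size = numel * element_bytes          # 2000 bytes
remaining = size % 256                # 208 bytes past the last 256-byte boundary
pad_bytes = 0 if remaining == 0 else 256 - remaining   # 48 bytes of padding
pad_elements = pad_bytes // element_bytes              # 24 fp16 elements
assert pad_elements == 24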
def _redefine_opt_step(self):
params_slice_func = self._update_params_slice
opt_step = self._optim.step
......@@ -679,14 +743,13 @@ def _wait_layer(trainable_params,
group,
use_calc_stream,
offload=False):
paddle.device.cuda.synchronize()
for param in trainable_params:
if param.status == "all":
param.use_count += 1
continue
if param.name in task_flow.full_param.keys():
full_param = task_flow.full_param[param.name]
with paddle.amp.auto_cast(enable=False):
paddle.device.cuda.synchronize()
core.VarBase(full_param._slice(0, param._numel()))._share_buffer_to(
param)
param.fw_storage._clear()
......@@ -725,7 +788,7 @@ def _allgather_buffer(trainable_params,
full_param = _all_gather(
param.fw_storage, group, use_calc_stream=use_calc_stream)
# Allgather current layer in the 1st step
# Allgather current layer in the 1st step synchronously
if sync_wait:
with paddle.amp.auto_cast(enable=False):
dist.wait(
......@@ -774,6 +837,12 @@ def _PartitionParam(param):
return param
def _UnsliceParam(param):
if not hasattr(param, "unslice"):
setattr(param, "unslice", True)
return param
def _VarBaseWrapper(param):
varbase = param.fw_storage
tmp_param = ParamBase(
......
......@@ -57,12 +57,15 @@ class ShardingClipGrad:
@imperative_base.no_grad
def _dygraph_clip(self, params_grads):
sum_square_fp16 = []
sum_square_fp32 = []
sum_square_fp32, sum_square_fp16 = [], []
unslice_params_fp32, unslice_params_fp16 = [], []
for p, g in params_grads:
p_slice = True # using for slice parameter in sharding stage3
if g is None or getattr(p, 'need_clip', True) is False:
continue
if hasattr(p, "unslice"):
p_slice = False
merge_grad = g
if g.type == core.VarDesc.VarType.SELECTED_ROWS:
......@@ -72,9 +75,11 @@ class ShardingClipGrad:
sum_square = layers.reduce_sum(square)
if p.dtype == paddle.float16:
sum_square_fp16.append(sum_square)
if p_slice: sum_square_fp16.append(sum_square)
else: unslice_params_fp16.append(sum_square)
elif p.dtype == paddle.float32:
sum_square_fp32.append(sum_square)
if p_slice: sum_square_fp32.append(sum_square)
else: unslice_params_fp32.append(sum_square)
# global norm of non-distributed FP16 params_and_grads
if len(sum_square_fp16) == 0:
......@@ -85,12 +90,28 @@ class ShardingClipGrad:
global_norm_fp16 = paddle.cast(
global_norm_fp16, dtype=paddle.float32)
# global norm of non-distributed FP16 params_and_grads for slice parameter
if len(unslice_params_fp16) == 0:
global_unslice_fp16 = paddle.to_tensor([0.], dtype=paddle.float32)
else:
global_unslice_fp16 = layers.concat(unslice_params_fp16)
global_unslice_fp16 = layers.reduce_sum(global_unslice_fp16)
global_unslice_fp16 = paddle.cast(
global_unslice_fp16, dtype=paddle.float32)
# global norm of non-distributed FP32 params_and_grads
global_norm_fp32 = layers.concat(sum_square_fp32) if len(
sum_square_fp32) != 0 else paddle.to_tensor(
[0.], dtype=paddle.float32)
global_norm_fp32 = layers.reduce_sum(global_norm_fp32)
# global norm of non-distributed FP32 params_and_grads for slice parameter
global_unslice_fp32 = layers.concat(unslice_params_fp32) if len(
unslice_params_fp32) != 0 else paddle.to_tensor(
[0.], dtype=paddle.float32)
global_unslice_fp32 = layers.reduce_sum(global_unslice_fp32)
global_unslice_var = global_unslice_fp16 + global_unslice_fp32
global_norm_var = global_norm_fp16 + global_norm_fp32
# add all reduce to get global norm of distributed params_and_grads
......@@ -98,6 +119,7 @@ class ShardingClipGrad:
with device_guard(dev_id, "gpu"):
paddle.distributed.all_reduce(global_norm_var, group=self._group)
global_norm_var += global_unslice_var
global_norm_var = layers.sqrt(global_norm_var)
max_global_norm = layers.fill_constant(
shape=[1], dtype=global_norm_var.dtype, value=self.clip_norm)
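The clip now folds the squared norms of unsliced parameters into the global norm after the cross-rank all_reduce of the sliced part. A hedged sketch of the resulting clip coefficient, assuming standard global-norm clipping (illustrative names, not the Paddle API):
import math

def combined_clip_coeff(slice_sq_sum, unslice_sq_sum, clip_norm):
    # slice_sq_sum has already been summed across ranks; unslice_sq_sum is
    # replicated on every rank, so it is added after the all_reduce.
    global_norm = math.sqrt(slice_sq_sum + unslice_sq_sum)
    return clip_norm / max(global_norm, clip_norm)

# e.g. sliced grads contribute 9.0 and unsliced grads 16.0 -> norm 5.0
assert combined_clip_coeff(9.0, 16.0, clip_norm=1.0) == 1.0 / 5.0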
......
......@@ -145,6 +145,10 @@ def train_mlp(model,
loss = paddle.nn.functional.cross_entropy(
input=out, label=label)
avg_loss = paddle.mean(x=loss.cast(dtype=paddle.float32))
if batch_size == 20:
avg_loss = avg_loss / 5
if not use_pure_fp16:
avg_loss.backward()
else:
......@@ -215,7 +219,7 @@ def test_stage2_stage3():
stage3_params[i].numpy(),
stage3_params_add[i].numpy(),
rtol=1e-6,
atol=1e-6)
atol=1e-4)
# fp16
stage2_params = train_mlp(
......
......@@ -28,7 +28,6 @@ from paddle.distributed.fleet.meta_parallel.sharding.sharding_stage3 import Shar
from paddle.distributed.fleet.meta_parallel.sharding.sharding_utils import ShardingScaler
epoch = 10
batch_size = 32
paddle.seed(2022)
np.random.seed(2022)
base_lr = 0.1
......@@ -80,6 +79,7 @@ def train_mlp(model,
use_pure_fp16=False,
accumulate_grad=False,
offload=False,
batch_size=100,
convert2cpu=False):
group = paddle.distributed.new_group([0, 1])
optimizer = optimizer_setting(model=model, use_pure_fp16=use_pure_fp16)
......@@ -91,7 +91,11 @@ def train_mlp(model,
scaler = ShardingScaler(scaler)
model = ShardingStage3(
model, optimizer=optimizer, group=group, offload=offload)
model,
optimizer=optimizer,
group=group,
offload=offload,
accumulate_grads=accumulate_grad)
train_reader = paddle.batch(
reader_decorator(), batch_size=batch_size, drop_last=True)
......@@ -115,10 +119,15 @@ def train_mlp(model,
loss = paddle.nn.functional.cross_entropy(
input=out, label=label)
avg_loss = paddle.mean(x=loss.cast(dtype=paddle.float32))
if accumulate_grad:
avg_loss = avg_loss / 5
if not use_pure_fp16:
avg_loss.backward()
else:
scaler.scale(avg_loss).backward()
if not accumulate_grad:
if not use_pure_fp16:
optimizer.step()
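The division of avg_loss by 5 under accumulate_grad mirrors what the tests expect: accumulating the gradients of 5 scaled micro-batches matches one large-batch update. A pure-numpy sketch of that equivalence, with illustrative values:
import numpy as np

micro_losses = np.array([1.0, 2.0, 3.0, 4.0, 5.0])
accumulated = np.sum(micro_losses / 5.0)   # what the accumulate_grad path sums
large_batch = np.mean(micro_losses)        # single large-batch equivalent
assert np.isclose(accumulated, large_batch)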
......@@ -172,12 +181,14 @@ def test_stage3_offload():
atol=1e-2)
# fp32 accumulate grad offload
stage3_params = train_mlp(mlp5, use_pure_fp16=False, accumulate_grad=True)
stage3_params = train_mlp(
mlp5, use_pure_fp16=False, batch_size=20, accumulate_grad=True)
stage3_params_offload = train_mlp(
mlp6,
use_pure_fp16=False,
accumulate_grad=True,
offload=True,
batch_size=20,
convert2cpu=True)
for i in range(len(stage3_params)):
np.testing.assert_allclose(
......