未验证 提交 18c6f40b 编写于 作者: B Baibaifan 提交者: GitHub

optimizer sharding paramters (#39581)

上级 1f7f8561
......@@ -65,9 +65,9 @@ class ShardingOptimizerStage2(Optimizer):
params,
optim,
group=None,
broadcast_fp16=False,
offload=False,
device="gpu",
pertrain_sync_models=True,
**kw):
super().__init__(optim._learning_rate, params, kw)
......@@ -98,8 +98,12 @@ class ShardingOptimizerStage2(Optimizer):
self.world_size = self.group.nranks
self.rank = self.group.rank
self._global_root_rank = 0
# Synchronous all ranks models
if pertrain_sync_models:
self._sync_params_and_buffers()
self.broadcast_fp16 = broadcast_fp16
self.param_storages = {} # {dtype: {rank: InternalStorage}}
if isinstance(self._optim._grad_clip, ClipGradByGlobalNorm):
......@@ -132,6 +136,22 @@ class ShardingOptimizerStage2(Optimizer):
# Update optimizer parameters and adjust parameter storage and use according to rank.
self._update_opt_status()
@paddle.no_grad()
def _sync_params_and_buffers(self):
"""
Sync all model states for all ranks
"""
for p in self._local_params:
dist.broadcast(
p,
src=self._global_root_rank,
group=self.group,
use_calc_stream=True)
# Multi stream operation will be supported later
dist.wait(tensor=p, group=self.group, use_calc_stream=True)
def _generate_master_params(self, trainable_params):
if self.offload:
for param in trainable_params:
......
......@@ -61,12 +61,10 @@ class ShardingStage2(nn.Layer):
sharding_optimizer,
group=None,
sync_buffers=False,
pertrain_sync_models=True,
buffer_max_size=2**23, #8MB
auto_refresh_trainable=True,
device="gpu",
use_grad_storage=True,
accumulate_grads=False):
use_grad_storage=True):
super().__init__()
# training options
......@@ -81,9 +79,6 @@ class ShardingStage2(nn.Layer):
self._sync_buffers = sync_buffers
self._auto_refresh_trainable = auto_refresh_trainable
# Gradient accumulation, Gradient flip
self._accumulate_grads = accumulate_grads
# Communication related attributes
self._group = dist.new_group(_get_global_group()
.ranks) if group is None else group
......@@ -128,16 +123,11 @@ class ShardingStage2(nn.Layer):
# Set backward pass hooks
self._bw_hooks = []
# Synchronous all ranks models
if pertrain_sync_models:
self._sync_params_and_buffers()
# Set tasks flow
self._tasks_flow = deque()
# Define optimizer step and clear_grad
if self._accumulate_grads:
self._redefine_opt_step()
self._redefine_opt_step()
self._redefine_opt_clear()
def forward(self, *inputs, **kwargs):
......@@ -313,9 +303,6 @@ class ShardingStage2(nn.Layer):
# Change reduce information
self._grad_reduced[index] = False
if not self._accumulate_grads:
param.grad.scale_(scale=self._world_size_scaling)
param._reset_grad_inplace_version(True)
# Clear the gradient that does not belong to the current rank through the callback function
def cleanup():
......@@ -362,11 +349,6 @@ class ShardingStage2(nn.Layer):
if grad_storage.all_checked_in:
assert grad_storage.buffer is not None
# Normalize all ranks grad_storage
if not self._accumulate_grads:
grad_storage.buffer.scale_(
scale=self._world_size_scaling)
# Clearing up the grad_storage buffer
def cleanup():
if dst_rank != self._rank:
......@@ -432,22 +414,6 @@ class ShardingStage2(nn.Layer):
self._bw_hooks.append(
param._register_backward_hook(reduce_function))
@paddle.no_grad()
def _sync_params_and_buffers(self):
"""
Sync all model states for all ranks
"""
for t in self._layer.parameters():
dist.broadcast(
t,
src=self._global_root_rank,
group=self._group,
use_calc_stream=True)
# Multi stream operation will be supported later
dist.wait(tensor=t, group=self._group, use_calc_stream=True)
def _setup_use_grad_storage(self):
"""
Integrate the parameters gradient into a continuous memory according to rank, and support the update of training parameters.
......@@ -555,8 +521,6 @@ class ShardingStage2(nn.Layer):
return rank_buffer_size
def _redefine_opt_step(self):
if not self._accumulate_grads:
return
grad_func = self._grad_scale
for opt in self._sharding_optimizers:
opt_step = opt.step
......
......@@ -72,7 +72,6 @@ class ShardingStage3(nn.Layer):
device="gpu",
segment_size=2**15,
pertrain_sync_models=True,
accumulate_grads=False,
offload=False,
sync_comm=False):
super().__init__()
......@@ -82,7 +81,6 @@ class ShardingStage3(nn.Layer):
self._layer = layer
self._default_device = device
self.__sync_buffers = sync_buffers
self._accumulate_grads = accumulate_grads
self._offload = offload
self._sync_comm = sync_comm
# segmentation size
......@@ -190,6 +188,7 @@ class ShardingStage3(nn.Layer):
param.fw_storage.clear_gradient(False)
param.fw_storage._gradient_set_empty(False)
param.bw_storage._clear()
param.bw_storage = None
# 2.Handle unslice param
if not self._offload:
for grad_storage in self._grad_storages.values():
......@@ -446,13 +445,12 @@ class ShardingStage3(nn.Layer):
param,
"fw_storage"), "Find {} don't have fw_storage attribute".format(
param.name)
if self._accumulate_grads:
if self._offload:
with device_guard(device="cpu"):
param.bw_storage.scale_(scale=self._world_size_scaling)
else:
# Gradient average
if self._offload:
with device_guard(device="cpu"):
param.bw_storage.scale_(scale=self._world_size_scaling)
else:
param.bw_storage.scale_(scale=self._world_size_scaling)
param.fw_storage = _VarBaseWrapper(param)
assert param.fw_storage.grad is None
param.fw_storage._copy_gradient_from(param.bw_storage)
......@@ -526,8 +524,6 @@ class ShardingStage3(nn.Layer):
def reduce(*_):
if param.name in self._task_flow.full_grad.keys():
full_grad = self._task_flow.full_grad[param.name]
if not self._accumulate_grads:
full_grad.scale_(scale=self._world_size_scaling)
# Only support sync allreduce current rank's layer now
dist.all_reduce(
tensor=full_grad, group=self._group, use_calc_stream=True)
......@@ -535,8 +531,7 @@ class ShardingStage3(nn.Layer):
tensor=full_grad, group=self._group, use_calc_stream=True)
start, end = self._param2buffer[param.name][self._rank]
if not self._accumulate_grads or param.bw_storage is None or not param.bw_storage.value(
).get_tensor()._is_initialized():
if param.bw_storage is None:
param.bw_storage = core.VarBase(
full_grad._slice(start, end)).detach().clone()
if self._offload:
......
......@@ -27,7 +27,7 @@ from paddle.fluid.dygraph import nn
from paddle.distributed.fleet.meta_optimizers.dygraph_optimizer.sharding_optimizer_stage2 import ShardingOptimizerStage2
from paddle.distributed.fleet.meta_parallel.sharding.sharding_stage2 import ShardingStage2
seed = 2021
seed = 2022
epoch = 2
linear_size = 1000
......@@ -105,11 +105,7 @@ def train_mlp(model,
params=model.parameters(), optim=optimizer, group=group)
model = ShardingStage2(
model,
optimizer,
group=group,
buffer_max_size=2**21,
accumulate_grads=batch_size == 20)
model, optimizer, group=group, buffer_max_size=2**21)
else:
optimizer = fleet.distributed_optimizer(optimizer)
model = fleet.distributed_model(model)
......@@ -140,6 +136,8 @@ def train_mlp(model,
loss = paddle.nn.functional.cross_entropy(input=out, label=label)
avg_loss = paddle.mean(x=loss.cast(dtype=paddle.float32))
if batch_size == 20:
avg_loss = avg_loss / 5
avg_loss.backward()
if not accumulate_grad:
......@@ -166,6 +164,7 @@ def test_dp_stage2():
mlp4.set_state_dict(state_dict)
mlp5.set_state_dict(state_dict)
# DP VS stage2
dp_params = train_mlp(
mlp1, sharding_stage="dp", use_pure_fp16=False, opt_group=False)
stage2_params = train_mlp(
......@@ -174,7 +173,8 @@ def test_dp_stage2():
np.testing.assert_allclose(
dp_params[i].numpy(), stage2_params[i].numpy(), rtol=1e-6)
stage2_params = train_mlp(mlp3, sharding_stage=2)
# stage2 accumulate grad
stage2_params = train_mlp(mlp3, sharding_stage=2, accumulate_grad=True)
stage2_accumulate_grad = train_mlp(
mlp4, sharding_stage=2, batch_size=20, accumulate_grad=True)
for i in range(len(stage2_params)):
......@@ -184,6 +184,7 @@ def test_dp_stage2():
rtol=1e-5,
atol=1e-5)
# stage2 param list VS param group
stage2_params = train_mlp(
mlp2, sharding_stage=2, use_pure_fp16=False, opt_group=True)
for i in range(len(dp_params)):
......
......@@ -43,13 +43,12 @@ def train_mlp(model, offload=False):
optimizer = optimizer_setting(model=model, use_pure_fp16=True)
model = paddle.amp.decorate(models=model, level='O2', save_dtype='float32')
scaler = paddle.amp.GradScaler(init_loss_scaling=32768)
scaler = paddle.amp.GradScaler(init_loss_scaling=1024)
scaler = ShardingScaler(scaler)
optimizer = ShardingOptimizerStage2(
params=model.parameters(), optim=optimizer, offload=offload)
model = ShardingStage2(
model, optimizer, buffer_max_size=2**21, accumulate_grads=False)
model = ShardingStage2(model, optimizer, buffer_max_size=2**21)
train_reader = paddle.batch(
reader_decorator(linear_size), batch_size=batch_size, drop_last=True)
......
......@@ -101,18 +101,10 @@ def train_mlp(model,
optimizer = ShardingOptimizerStage2(
params=model.parameters(), optim=optimizer, group=group)
model = ShardingStage2(
model,
optimizer,
group=group,
buffer_max_size=2**21,
accumulate_grads=batch_size == 20)
model, optimizer, group=group, buffer_max_size=2**21)
elif sharding_stage == 3:
model = ShardingStage3(
model,
optimizer=optimizer,
group=group,
accumulate_grads=batch_size == 20,
sync_comm=recompute)
model, optimizer=optimizer, group=group, sync_comm=recompute)
# check optimizer.minimize() error
if test_minimize:
......@@ -231,7 +223,7 @@ def test_stage2_stage3():
stage2_params[i].numpy(),
stage3_params[i].numpy(),
rtol=1e-4,
atol=1e-4)
atol=1e-3)
# fp16 recompute
stage3_params = train_mlp(
......
......@@ -91,11 +91,7 @@ def train_mlp(model,
scaler = ShardingScaler(scaler)
model = ShardingStage3(
model,
optimizer=optimizer,
group=group,
offload=offload,
accumulate_grads=accumulate_grad)
model, optimizer=optimizer, group=group, offload=offload)
train_reader = paddle.batch(
reader_decorator(), batch_size=batch_size, drop_last=True)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册