Unverified commit 18c6f40b, authored by Baibaifan, committed by GitHub

optimizer sharding parameters (#39581)

Parent 1f7f8561
@@ -65,9 +65,9 @@ class ShardingOptimizerStage2(Optimizer):
                  params,
                  optim,
                  group=None,
-                 broadcast_fp16=False,
                  offload=False,
                  device="gpu",
+                 pertrain_sync_models=True,
                  **kw):

         super().__init__(optim._learning_rate, params, kw)
@@ -98,8 +98,12 @@ class ShardingOptimizerStage2(Optimizer):
         self.world_size = self.group.nranks
         self.rank = self.group.rank
+        self._global_root_rank = 0
+
+        # Synchronous all ranks models
+        if pertrain_sync_models:
+            self._sync_params_and_buffers()

-        self.broadcast_fp16 = broadcast_fp16
         self.param_storages = {}  # {dtype: {rank: InternalStorage}}

         if isinstance(self._optim._grad_clip, ClipGradByGlobalNorm):
@@ -132,6 +136,22 @@ class ShardingOptimizerStage2(Optimizer):
         # Update optimizer parameters and adjust parameter storage and use according to rank.
         self._update_opt_status()

+    @paddle.no_grad()
+    def _sync_params_and_buffers(self):
+        """
+        Sync all model states for all ranks
+        """
+        for p in self._local_params:
+            dist.broadcast(
+                p,
+                src=self._global_root_rank,
+                group=self.group,
+                use_calc_stream=True)
+
+            # Multi stream operation will be supported later
+            dist.wait(tensor=p, group=self.group, use_calc_stream=True)
+
     def _generate_master_params(self, trainable_params):
         if self.offload:
             for param in trainable_params:
......
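The hunk above moves the pre-training parameter sync into the optimizer and broadcasts every local parameter from the global root rank. Below is a minimal sketch of that broadcast pattern, assuming a multi-GPU job started with `paddle.distributed.launch`; the `Linear` layer is only a stand-in for the real sharded model.

```python
# Hedged sketch only: every rank overwrites its parameters with rank 0's copy,
# mirroring the broadcast/wait pair added to ShardingOptimizerStage2 above.
import paddle
import paddle.distributed as dist

dist.init_parallel_env()
group = dist.new_group(list(range(dist.get_world_size())))

model = paddle.nn.Linear(8, 8)  # stand-in model
for p in model.parameters():
    dist.broadcast(p, src=0, group=group, use_calc_stream=True)
    # Multi-stream operation is not used yet, so block on the calc stream.
    dist.wait(tensor=p, group=group, use_calc_stream=True)
```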
@@ -61,12 +61,10 @@ class ShardingStage2(nn.Layer):
                 sharding_optimizer,
                 group=None,
                 sync_buffers=False,
-                pertrain_sync_models=True,
                 buffer_max_size=2**23,  #8MB
                 auto_refresh_trainable=True,
                 device="gpu",
-                use_grad_storage=True,
-                accumulate_grads=False):
+                use_grad_storage=True):
         super().__init__()

         # training options
@@ -81,9 +79,6 @@ class ShardingStage2(nn.Layer):
         self._sync_buffers = sync_buffers
         self._auto_refresh_trainable = auto_refresh_trainable

-        # Gradient accumulation, Gradient flip
-        self._accumulate_grads = accumulate_grads
-
         # Communication related attributes
         self._group = dist.new_group(_get_global_group()
                                      .ranks) if group is None else group
@@ -128,16 +123,11 @@ class ShardingStage2(nn.Layer):
         # Set backward pass hooks
         self._bw_hooks = []

-        # Synchronous all ranks models
-        if pertrain_sync_models:
-            self._sync_params_and_buffers()
-
         # Set tasks flow
         self._tasks_flow = deque()

         # Define optimizer step and clear_grad
-        if self._accumulate_grads:
-            self._redefine_opt_step()
+        self._redefine_opt_step()
         self._redefine_opt_clear()

     def forward(self, *inputs, **kwargs):
@@ -313,9 +303,6 @@ class ShardingStage2(nn.Layer):
                 # Change reduce information
                 self._grad_reduced[index] = False
-                if not self._accumulate_grads:
-                    param.grad.scale_(scale=self._world_size_scaling)
-                    param._reset_grad_inplace_version(True)

                 # Clear the gradient that does not belong to the current rank through the callback function
                 def cleanup():
@@ -362,11 +349,6 @@ class ShardingStage2(nn.Layer):
                 if grad_storage.all_checked_in:
                     assert grad_storage.buffer is not None

-                    # Normalize all ranks grad_storage
-                    if not self._accumulate_grads:
-                        grad_storage.buffer.scale_(
-                            scale=self._world_size_scaling)
-
                     # Clearing up the grad_storage buffer
                     def cleanup():
                         if dst_rank != self._rank:
@@ -432,22 +414,6 @@ class ShardingStage2(nn.Layer):
             self._bw_hooks.append(
                 param._register_backward_hook(reduce_function))

-    @paddle.no_grad()
-    def _sync_params_and_buffers(self):
-        """
-        Sync all model states for all ranks
-        """
-        for t in self._layer.parameters():
-            dist.broadcast(
-                t,
-                src=self._global_root_rank,
-                group=self._group,
-                use_calc_stream=True)
-
-            # Multi stream operation will be supported later
-            dist.wait(tensor=t, group=self._group, use_calc_stream=True)
-
     def _setup_use_grad_storage(self):
         """
         Integrate the parameters gradient into a continuous memory according to rank, and support the update of training parameters.
@@ -555,8 +521,6 @@ class ShardingStage2(nn.Layer):
         return rank_buffer_size

     def _redefine_opt_step(self):
-        if not self._accumulate_grads:
-            return
         grad_func = self._grad_scale
         for opt in self._sharding_optimizers:
             opt_step = opt.step
......
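The last hunk above removes the early return from `_redefine_opt_step`, so the world-size gradient scaling is now always folded into the wrapped optimizer step instead of being applied at every reduce. A hedged sketch of that wrapping pattern with placeholder names (`opt`, `grad_scale_fn`), not the actual Paddle implementation:

```python
def redefine_opt_step(opt, grad_scale_fn):
    # `opt` is any object with a step() method; `grad_scale_fn` stands in
    # for ShardingStage2._grad_scale (scales grads by 1 / world_size).
    original_step = opt.step

    def step():
        grad_scale_fn()   # normalize the accumulated gradients first
        original_step()   # then run the original optimizer update

    opt.step = step
```

Scaling once at step time appears to be what lets the per-reduce `scale_` calls and the `accumulate_grads` flag above be removed.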
@@ -72,7 +72,6 @@ class ShardingStage3(nn.Layer):
                 device="gpu",
                 segment_size=2**15,
                 pertrain_sync_models=True,
-                accumulate_grads=False,
                 offload=False,
                 sync_comm=False):
         super().__init__()
@@ -82,7 +81,6 @@ class ShardingStage3(nn.Layer):
         self._layer = layer
         self._default_device = device
         self.__sync_buffers = sync_buffers
-        self._accumulate_grads = accumulate_grads
         self._offload = offload
         self._sync_comm = sync_comm
         # segmentation size
@@ -190,6 +188,7 @@ class ShardingStage3(nn.Layer):
                 param.fw_storage.clear_gradient(False)
                 param.fw_storage._gradient_set_empty(False)
                 param.bw_storage._clear()
+                param.bw_storage = None
         # 2.Handle unslice param
         if not self._offload:
             for grad_storage in self._grad_storages.values():
@@ -446,13 +445,12 @@ class ShardingStage3(nn.Layer):
                 param,
                 "fw_storage"), "Find {} don't have fw_storage attribute".format(
                     param.name)
-            # Gradient average
-            if self._accumulate_grads:
-                if self._offload:
-                    with device_guard(device="cpu"):
-                        param.bw_storage.scale_(scale=self._world_size_scaling)
-                else:
-                    param.bw_storage.scale_(scale=self._world_size_scaling)
+            if self._offload:
+                with device_guard(device="cpu"):
+                    param.bw_storage.scale_(scale=self._world_size_scaling)
+            else:
+                param.bw_storage.scale_(scale=self._world_size_scaling)
             param.fw_storage = _VarBaseWrapper(param)
             assert param.fw_storage.grad is None
             param.fw_storage._copy_gradient_from(param.bw_storage)
@@ -526,8 +524,6 @@ class ShardingStage3(nn.Layer):
         def reduce(*_):
             if param.name in self._task_flow.full_grad.keys():
                 full_grad = self._task_flow.full_grad[param.name]
-                if not self._accumulate_grads:
-                    full_grad.scale_(scale=self._world_size_scaling)
                 # Only support sync allreduce current rank's layer now
                 dist.all_reduce(
                     tensor=full_grad, group=self._group, use_calc_stream=True)
@@ -535,8 +531,7 @@ class ShardingStage3(nn.Layer):
                     tensor=full_grad, group=self._group, use_calc_stream=True)

                 start, end = self._param2buffer[param.name][self._rank]
-                if not self._accumulate_grads or param.bw_storage is None or not param.bw_storage.value(
-                ).get_tensor()._is_initialized():
+                if param.bw_storage is None:
                     param.bw_storage = core.VarBase(
                         full_grad._slice(start, end)).detach().clone()
                     if self._offload:
......
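In the stage-3 reduce path above, the full gradient is now always allreduce'd and each rank keeps only its `[start, end)` slice, creating `bw_storage` on the first micro-batch. A conceptual sketch using plain tensor slicing instead of the internal `core.VarBase` / `_slice` calls; the accumulate branch is assumed from the surrounding context rather than shown in the hunk:

```python
import paddle.distributed as dist

def reduce_and_shard(full_grad, start, end, bw_storage, group):
    # Sum the full gradient across ranks, then keep only this rank's shard.
    dist.all_reduce(tensor=full_grad, group=group, use_calc_stream=True)
    shard = full_grad.flatten()[start:end].detach().clone()
    if bw_storage is None:
        return shard              # first micro-batch: take the slice as-is
    return bw_storage + shard     # later micro-batches: accumulate (assumed)
```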
@@ -27,7 +27,7 @@ from paddle.fluid.dygraph import nn
 from paddle.distributed.fleet.meta_optimizers.dygraph_optimizer.sharding_optimizer_stage2 import ShardingOptimizerStage2
 from paddle.distributed.fleet.meta_parallel.sharding.sharding_stage2 import ShardingStage2

-seed = 2021
+seed = 2022
 epoch = 2
 linear_size = 1000
@@ -105,11 +105,7 @@ def train_mlp(model,
             params=model.parameters(), optim=optimizer, group=group)
         model = ShardingStage2(
-            model,
-            optimizer,
-            group=group,
-            buffer_max_size=2**21,
-            accumulate_grads=batch_size == 20)
+            model, optimizer, group=group, buffer_max_size=2**21)
     else:
         optimizer = fleet.distributed_optimizer(optimizer)
         model = fleet.distributed_model(model)
@@ -140,6 +136,8 @@ def train_mlp(model,
             loss = paddle.nn.functional.cross_entropy(input=out, label=label)
             avg_loss = paddle.mean(x=loss.cast(dtype=paddle.float32))
+            if batch_size == 20:
+                avg_loss = avg_loss / 5
             avg_loss.backward()

             if not accumulate_grad:
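The two added lines above emulate gradient accumulation by pre-scaling the loss: with `batch_size == 20` the loss is divided by 5, which matches averaging over 5 accumulation steps before a single update. A hedged sketch of that pattern; `model`, `optimizer`, and `batches` are assumed to exist:

```python
import paddle

acc_steps = 5  # matches the avg_loss / 5 used by the test
for step, (img, label) in enumerate(batches):
    out = model(img)
    loss = paddle.nn.functional.cross_entropy(input=out, label=label)
    avg_loss = paddle.mean(loss.cast('float32')) / acc_steps
    avg_loss.backward()                  # gradients keep accumulating
    if (step + 1) % acc_steps == 0:      # one update per 5 micro-batches
        optimizer.step()
        optimizer.clear_grad()
```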
@@ -166,6 +164,7 @@ def test_dp_stage2():
     mlp4.set_state_dict(state_dict)
     mlp5.set_state_dict(state_dict)

+    # DP VS stage2
     dp_params = train_mlp(
         mlp1, sharding_stage="dp", use_pure_fp16=False, opt_group=False)
     stage2_params = train_mlp(
@@ -174,7 +173,8 @@ def test_dp_stage2():
         np.testing.assert_allclose(
             dp_params[i].numpy(), stage2_params[i].numpy(), rtol=1e-6)

-    stage2_params = train_mlp(mlp3, sharding_stage=2)
+    # stage2 accumulate grad
+    stage2_params = train_mlp(mlp3, sharding_stage=2, accumulate_grad=True)
     stage2_accumulate_grad = train_mlp(
         mlp4, sharding_stage=2, batch_size=20, accumulate_grad=True)
     for i in range(len(stage2_params)):
@@ -184,6 +184,7 @@ def test_dp_stage2():
             rtol=1e-5,
             atol=1e-5)

+    # stage2 param list VS param group
     stage2_params = train_mlp(
         mlp2, sharding_stage=2, use_pure_fp16=False, opt_group=True)
     for i in range(len(dp_params)):
......
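After this change the stage-2 wrappers in the tests are built without `accumulate_grads`. A minimal construction sketch, assuming a multi-GPU launch; the layer size and base optimizer are illustrative:

```python
import paddle
import paddle.distributed as dist
from paddle.distributed.fleet.meta_optimizers.dygraph_optimizer.sharding_optimizer_stage2 import ShardingOptimizerStage2
from paddle.distributed.fleet.meta_parallel.sharding.sharding_stage2 import ShardingStage2

dist.init_parallel_env()
group = dist.new_group(list(range(dist.get_world_size())))

model = paddle.nn.Linear(1000, 1000)  # illustrative layer
base_opt = paddle.optimizer.AdamW(
    learning_rate=0.001, parameters=model.parameters())

# The optimizer shards optimizer states across ranks; the layer wrapper
# shards and reduces gradients.
optimizer = ShardingOptimizerStage2(
    params=model.parameters(), optim=base_opt, group=group)
model = ShardingStage2(model, optimizer, group=group, buffer_max_size=2**21)
```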
@@ -43,13 +43,12 @@ def train_mlp(model, offload=False):
     optimizer = optimizer_setting(model=model, use_pure_fp16=True)

     model = paddle.amp.decorate(models=model, level='O2', save_dtype='float32')
-    scaler = paddle.amp.GradScaler(init_loss_scaling=32768)
+    scaler = paddle.amp.GradScaler(init_loss_scaling=1024)
     scaler = ShardingScaler(scaler)

     optimizer = ShardingOptimizerStage2(
         params=model.parameters(), optim=optimizer, offload=offload)
-    model = ShardingStage2(
-        model, optimizer, buffer_max_size=2**21, accumulate_grads=False)
+    model = ShardingStage2(model, optimizer, buffer_max_size=2**21)

     train_reader = paddle.batch(
         reader_decorator(linear_size), batch_size=batch_size, drop_last=True)
......
@@ -101,18 +101,10 @@ def train_mlp(model,
         optimizer = ShardingOptimizerStage2(
             params=model.parameters(), optim=optimizer, group=group)
         model = ShardingStage2(
-            model,
-            optimizer,
-            group=group,
-            buffer_max_size=2**21,
-            accumulate_grads=batch_size == 20)
+            model, optimizer, group=group, buffer_max_size=2**21)
     elif sharding_stage == 3:
         model = ShardingStage3(
-            model,
-            optimizer=optimizer,
-            group=group,
-            accumulate_grads=batch_size == 20,
-            sync_comm=recompute)
+            model, optimizer=optimizer, group=group, sync_comm=recompute)

     # check optimizer.minimize() error
     if test_minimize:
@@ -231,7 +223,7 @@ def test_stage2_stage3():
                 stage2_params[i].numpy(),
                 stage3_params[i].numpy(),
                 rtol=1e-4,
-                atol=1e-4)
+                atol=1e-3)

     # fp16 recompute
     stage3_params = train_mlp(
......
@@ -91,11 +91,7 @@ def train_mlp(model,
         scaler = ShardingScaler(scaler)

     model = ShardingStage3(
-        model,
-        optimizer=optimizer,
-        group=group,
-        offload=offload,
-        accumulate_grads=accumulate_grad)
+        model, optimizer=optimizer, group=group, offload=offload)

     train_reader = paddle.batch(
         reader_decorator(), batch_size=batch_size, drop_last=True)
......
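Likewise, `ShardingStage3` no longer takes `accumulate_grads`. A minimal sketch under the same assumptions; the import path is inferred by analogy with the stage-2 modules above and may differ:

```python
import paddle
import paddle.distributed as dist
# Import path assumed by analogy with the stage-2 imports shown earlier.
from paddle.distributed.fleet.meta_parallel.sharding.sharding_stage3 import ShardingStage3

dist.init_parallel_env()
group = dist.new_group(list(range(dist.get_world_size())))

model = paddle.nn.Linear(1000, 1000)  # illustrative layer
optimizer = paddle.optimizer.AdamW(
    learning_rate=0.001, parameters=model.parameters())

# Stage 3 also shards the parameters themselves; offload/sync_comm keep defaults.
model = ShardingStage3(model, optimizer=optimizer, group=group, offload=False)
```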