Unverified commit 23d559dd, authored by Baibaifan, committed by GitHub

optimize sharding stage3 (#39334)

Parent 41eb2595
@@ -35,6 +35,7 @@ from paddle.distributed.collective import _get_global_group
 from .sharding_utils import Type, ShardingClipGrad, device_guard
 from ..pp_utils.utils import _all_gather
+from ...utils.internal_storage import GradStorage
 
 # CUDA alignment 256 bytes
 alignment = {"gpu": 256, }
@@ -69,6 +70,7 @@ class ShardingStage3(nn.Layer):
                  group=None,
                  sync_buffers=False,
                  device="gpu",
+                 segment_size=2**15,
                  pertrain_sync_models=True,
                  accumulate_grads=False,
                  offload=False,
@@ -83,6 +85,8 @@ class ShardingStage3(nn.Layer):
         self._accumulate_grads = accumulate_grads
         self._offload = offload
         self._sync_comm = sync_comm
+        # segmentation size
+        self._segment_size = segment_size if not offload else 0
 
         global DEV
         DEV = "cpu" if paddle.get_device() == "cpu" else paddle.get_device(
@@ -107,7 +111,10 @@ class ShardingStage3(nn.Layer):
         self._param2buffer_size = dict()  # {param.name: size}
         self._param2buffer = dict(
         )  # {param.name: [(start0, end0),(start1, end1), ...]}
-        self._trainable_params = dict()  # {layer.name: [trainable_params]}
+        self._trainable_params = dict()  # {id(layer): [trainable_params]}
+        self._unslice_params = set()  # param's numel <= segment_size
+        self._unslice_params2align = dict()  # {param.name: param's align}
+        self._grad_storages = dict()  # {param.dtype: GradStorage}
 
         assert not isinstance(
             optimizer, list), "Multiple optimizers are not supported now."
@@ -131,10 +138,13 @@ class ShardingStage3(nn.Layer):
         self._segment_rank_params(self._layer)
 
+        # Add unslice params to master_weight in fp16
+        self._handle_unslice_params()
+
         # In the first step, record the execution order of the layer
         self._order_tracer = OrderedDict()
         self._order_tracer["order"] = 0
-        self._order_tracer["layer"] = []
+        self._order_tracer["layer"] = list()
 
         # Register task flow
         self._task_flow = TaskFlow()
@@ -168,8 +178,10 @@ class ShardingStage3(nn.Layer):
     def _clear_gradients(self):
         assert len(self._trainable_params.keys()) > 0
         current_layer_params = self._layer.parameters(include_sublayers=True)
+        # 1.Handle param's slice
         trainable_params = list(
-            filter(lambda x: x.trainable, current_layer_params))
+            filter(lambda p: p.trainable and p not in self._unslice_params,
+                   current_layer_params))
         for param in trainable_params:
             assert hasattr(
                 param, "fw_storage"
@@ -178,6 +190,9 @@ class ShardingStage3(nn.Layer):
             param.fw_storage.clear_gradient(False)
             param.fw_storage._gradient_set_empty(False)
             param.bw_storage._clear()
+        # 2.Handle unslice param
+        for grad_storage in self._grad_storages.values():
+            grad_storage.buffer.zero_()
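Unslice parameters keep their gradients as views into one shared GradStorage buffer per dtype, so _clear_gradients can drop all of them with a single zero_() on the buffer instead of clearing each parameter. A small numpy illustration of that aliasing (numpy views stand in for the tensor views; names are hypothetical):

import numpy as np

buffer = np.zeros(8, dtype=np.float32)   # stand-in for the shared grad buffer
grad_a = buffer[0:3]                     # param a's grad is a view into the buffer
grad_b = buffer[4:8]                     # param b's grad view (element 3 is padding)

grad_a += 1.0
grad_b += 2.0
buffer[:] = 0.0                          # one zero_()-style call clears every grad view
print(grad_a, grad_b)                    # both are all zeros again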
 
     # Update param memery slice
     def _update_params_slice(self):
@@ -185,20 +200,25 @@ class ShardingStage3(nn.Layer):
         if not isinstance(self._optim._param_groups[0], dict):
             slice_params = [param.fw_storage for param in update_list]
-            self._optim._parameter_list = slice_params
-            self._optim._param_groups = slice_params
+            self._optim._parameter_list = slice_params + list(
+                self._unslice_params)
+            self._optim._param_groups = slice_params + list(
+                self._unslice_params)
         else:
             params_name_list = list(map(lambda p: p.name, update_list))
+            fw_storage_name_list = list(
+                map(lambda p: p.fw_storage.name, update_list))
             for param_group in self._optim._param_groups:
-                slice_p = []
+                p_group = []
                 for p in param_group['params']:
                     if p.name in params_name_list:
-                        assert hasattr(
-                            p, "fw_storage"
-                        ), "Find {} don't have fw_storage attribute.".format(
-                            p.name)
-                        slice_p.append(p.fw_storage)
-                param_group['params'] = slice_p
+                        p_group.append(p.fw_storage)
+                    elif p.name in fw_storage_name_list:
+                        p_group.append(update_list[fw_storage_name_list.index(
+                            p.name)].fw_storage)
+                    elif p in self._unslice_params:
+                        p_group.append(p)
+                param_group['params'] = p_group
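_update_params_slice now rebuilds every optimizer param group so that each sliced parameter is replaced by its local fw_storage shard while unslice parameters are passed through unchanged. A rough sketch of that remapping with stand-in objects (not Paddle's real Parameter or optimizer types):

class FakeParam:
    """Stand-in for a sliced parameter that carries a local fw_storage shard."""
    def __init__(self, name):
        self.name = name
        self.fw_storage = f"{name}@local_shard"

def rebuild_group(param_group, update_list, unslice_params):
    sliced_names = {p.name for p in update_list}
    new_params = []
    for p in param_group["params"]:
        if getattr(p, "name", None) in sliced_names:
            # Sliced params are optimized through their local fw_storage shard.
            new_params.append(next(q.fw_storage for q in update_list
                                   if q.name == p.name))
        elif p in unslice_params:
            new_params.append(p)   # small params stay whole on every rank
    param_group["params"] = new_params
    return param_group

w = FakeParam("linear.weight")
group = {"params": [w, "tiny_bias"]}
print(rebuild_group(group, [w], {"tiny_bias"}))
# {'params': ['linear.weight@local_shard', 'tiny_bias']}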
 
     def forward(self, *inputs, **kwargs):
         """
@@ -213,6 +233,32 @@ class ShardingStage3(nn.Layer):
         return fw
 
+    def _handle_unslice_params(self):
+        buffer_size = dict()
+        buffer_size[Type.fp32.value] = 0
+        buffer_size[Type.fp16.value] = 0
+        for param in self._unslice_params:
+            # Updata optimizer master weights
+            if param.dtype == Type.fp16.value and not self._offload:
+                self._optim._master_weights[param.name] = paddle.cast(
+                    param, Type.fp32.value)
+            param2dtype[param.name] = param.dtype
+            p_align = self._param2align(param)
+            self._unslice_params2align[param.name] = p_align
+            buffer_size[param.dtype] += param._numel() + p_align
+
+        # Create unslice_params'grad
+        for param in sorted(list(self._unslice_params), key=lambda p: p.name):
+            if param.dtype not in self._grad_storages.keys():
+                self._grad_storages[param.dtype] = GradStorage(
+                    buffer_size[param.dtype],
+                    dtype=param.dtype,
+                    device=self._default_device,
+                    destination=self._rank,
+                    parm2align=self._unslice_params2align)
+            self._grad_storages[param.dtype].add_grad(
+                param, self._unslice_params2align[param.name])
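_handle_unslice_params sizes one gradient buffer per dtype by summing each unslice parameter's element count plus its alignment padding, then registers every parameter's grad slot inside that buffer. A plain-Python sketch of the sizing and slot assignment, assuming the padding values were already computed as in the _param2align helper added later in this diff (the pack helper is a stand-in, not Paddle's GradStorage):

def pack(params):
    # params: list of (name, numel, pad_elems, dtype); padding precomputed.
    total, slots = {}, {}
    for name, numel, pad, dtype in sorted(params):
        start = total.get(dtype, 0)
        slots[name] = (dtype, start, start + numel)   # where this grad lives
        total[dtype] = start + numel + pad            # grow buffer incl. padding
    return total, slots

total, slots = pack([("norm.bias", 30, 98, "float16"),
                     ("head.bias", 100, 28, "float32")])
print(total)   # elements needed per dtype buffer (numel + padding): 128 each here
print(slots)   # each param's (dtype, start, end) slot inside its buffer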
 
     def _segment_rank_params(self, layer, name="last_layer"):
         """
         Flatten parameters according to layer.
@@ -233,24 +279,22 @@ class ShardingStage3(nn.Layer):
         def _add_manage_info(trainable_param):
             return _PartitionParam(trainable_param)
 
-        trainable_params = list(
-            filter(lambda x: x.trainable, current_layer_params))
+        current_params = list()
+        for p in current_layer_params:
+            if p.trainable and p._numel() > self._segment_size:
+                current_params.append(_add_manage_info(p))
+            elif p.trainable:
+                self._unslice_params.add(_UnsliceParam(p))
 
         assert id(layer) not in self._trainable_params.keys()
-        self._trainable_params[id(layer)] = list(
-            map(_add_manage_info, trainable_params))
+        self._trainable_params[id(layer)] = current_params
 
         for param in self._trainable_params[id(layer)]:
             if param.name in self._param2buffer.keys():
                 continue
             self._param2buffer[param.name] = []
             # 1.Params alignment
-            offset = 0
-            # CUDA alignment 256 bytes
-            size = param._numel() * align[param.dtype]
-            remaining = size % alignment[self._default_device]
-            ali = 0 if remaining == 0 else alignment[
-                self._default_device] - remaining
-            align_ = ali // align[param.dtype]
+            align_ = self._param2align(param)
             offset = align_ + param._numel()
             buffer_size = offset if offset % self._group.nranks == 0 else offset + self._group.nranks - (
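The line above is cut off at the hunk boundary; it appears to round offset (numel plus alignment) up to the next multiple of group.nranks so the flattened buffer splits into equal per-rank shards, and those (start, end) pairs are what _param2buffer records per the comment earlier in this diff. A small sketch of that rounding and sharding under that assumption (the helper name is hypothetical):

def shard_ranges(numel_plus_align, nranks):
    # Round the padded size up to a multiple of nranks, then cut equal shards.
    buffer_size = -(-numel_plus_align // nranks) * nranks
    shard = buffer_size // nranks
    return [(r * shard, (r + 1) * shard) for r in range(nranks)]

print(shard_ranges(1030, 4))
# [(0, 258), (258, 516), (516, 774), (774, 1032)] -- rank r owns the r-th slice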
@@ -379,7 +423,9 @@ class ShardingStage3(nn.Layer):
         assert len(self._trainable_params.keys()) > 0
         current_layer_params = self._layer.parameters(include_sublayers=True)
         trainable_params = list(
-            filter(lambda x: x.trainable, current_layer_params))
+            filter(lambda p: p.trainable and p not in self._unslice_params,
+                   current_layer_params))
+        # 1.Handle param's slice
         for param in trainable_params:
             assert hasattr(
                 param,
@@ -396,6 +442,19 @@ class ShardingStage3(nn.Layer):
            assert param.fw_storage.grad is None
            param.fw_storage._copy_gradient_from(param.bw_storage)
            update_list.append(param)
+
+        # 2.Handle unslice param
+        for grad_storage in self._grad_storages.values():
+            grad_storage.buffer.scale_(scale=self._world_size_scaling)
+            dist.all_reduce(
+                tensor=grad_storage.buffer,
+                group=self._group,
+                use_calc_stream=True)
+            dist.wait(
+                tensor=grad_storage.buffer,
+                group=self._group,
+                use_calc_stream=True)
+
         return update_list
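For unslice parameters the whole shared grad buffer is scaled by 1/world_size and then summed with all_reduce across the group, which leaves the averaged gradient on every rank. A toy simulation of that scale-then-sum pattern, with plain lists standing in for each rank's copy of the buffer (no real communication):

world_size = 4
rank_buffers = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0]]

scaled = [[g / world_size for g in buf] for buf in rank_buffers]   # buffer.scale_(...)
reduced = [sum(col) for col in zip(*scaled)]                       # all_reduce(SUM)
print(reduced)   # [4.0, 5.0] == element-wise mean of the four rank buffers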
 
     def get_all_parameters(self, convert2cpu=False):
@@ -405,7 +464,8 @@ class ShardingStage3(nn.Layer):
         assert len(self._trainable_params.keys()) > 0
         current_layer_params = self._layer.parameters(include_sublayers=True)
         trainable_params = list(
-            filter(lambda x: x.trainable, current_layer_params))
+            filter(lambda p: p.trainable and p not in self._unslice_params,
+                   current_layer_params))
 
         t_flow = _allgather_buffer(
             trainable_params,
             self._group,
@@ -415,7 +475,7 @@ class ShardingStage3(nn.Layer):
             offload=self._offload,
             convert2cpu=convert2cpu)
         if convert2cpu:
-            for param in current_layer_params:
+            for param in trainable_params:
                 t_flow.full_param[param.name]._share_buffer_to(param)
 
         self._optim._parameter_list = self._ori_parameter_list
@@ -424,7 +484,8 @@ class ShardingStage3(nn.Layer):
     def _register_backward_hooks(self):
         current_layer_params = self._layer.parameters(include_sublayers=True)
         trainable_params = list(
-            filter(lambda x: x.trainable, current_layer_params))
+            filter(lambda p: p.trainable and p not in self._unslice_params,
+                   current_layer_params))
 
         for param in trainable_params:
             allreduce_function = self._get_allreduce_fn(param)
@@ -435,42 +496,36 @@ class ShardingStage3(nn.Layer):
         def reduce(*_):
             if param.name in self._task_flow.full_grad.keys():
                 full_grad = self._task_flow.full_grad[param.name]
-                with paddle.amp.auto_cast(enable=False):
-                    if not self._accumulate_grads:
-                        full_grad.scale_(scale=self._world_size_scaling)
-                    # Only support sync allreduce current rank's layer now
-                    dist.all_reduce(
-                        tensor=full_grad,
-                        group=self._group,
-                        use_calc_stream=True)
-                    dist.wait(
-                        tensor=full_grad,
-                        group=self._group,
-                        use_calc_stream=True)
-
-                    start, end = self._param2buffer[param.name][self._rank]
-                    if not self._accumulate_grads or param.bw_storage is None or not param.bw_storage.value(
-                    ).get_tensor()._is_initialized():
-                        param.bw_storage = core.VarBase(
-                            full_grad._slice(start, end)).detach().clone()
-                        if self._offload:
-                            param.bw_storage = _device2cpu(param.bw_storage,
-                                                           True)
-                    else:
-                        if self._offload:
-                            cpu_grad = _device2cpu(
-                                core.VarBase(full_grad._slice(start, end))
-                                .detach().clone(), True)
-                            param.bw_storage = paddle.add(param.bw_storage,
-                                                          cpu_grad)
-                        else:
-                            # param.bw_storage.add_(
-                            #     core.VarBase(full_grad._slice(start, end))
-                            #     .detach().clone())
-                            param.bw_storage = paddle.add(
-                                param.bw_storage,
-                                core.VarBase(full_grad._slice(
-                                    start, end)).detach().clone())
+                if not self._accumulate_grads:
+                    full_grad.scale_(scale=self._world_size_scaling)
+                # Only support sync allreduce current rank's layer now
+                dist.all_reduce(
+                    tensor=full_grad, group=self._group, use_calc_stream=True)
+                dist.wait(
+                    tensor=full_grad, group=self._group, use_calc_stream=True)
+
+                start, end = self._param2buffer[param.name][self._rank]
+                if not self._accumulate_grads or param.bw_storage is None or not param.bw_storage.value(
+                ).get_tensor()._is_initialized():
+                    param.bw_storage = core.VarBase(
+                        full_grad._slice(start, end)).detach().clone()
+                    if self._offload:
+                        param.bw_storage = _device2cpu(param.bw_storage, True)
+                else:
+                    if self._offload:
+                        cpu_grad = _device2cpu(
+                            core.VarBase(full_grad._slice(start, end))
+                            .detach().clone(), True)
+                        param.bw_storage = paddle.add(param.bw_storage,
+                                                      cpu_grad)
+                    else:
+                        # param.bw_storage.add_(
+                        #     core.VarBase(full_grad._slice(start, end))
+                        #     .detach().clone())
+                        param.bw_storage = paddle.add(
+                            param.bw_storage,
+                            core.VarBase(full_grad._slice(start, end)).detach(
+                            ).clone())
                 param.clear_gradient(False)
                 param._gradient_set_empty(False)
                 tmp_var = self._task_flow.full_grad.pop(param.name)
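After the all-reduce in this backward hook, each rank keeps only its own (start, end) slice of the full gradient as bw_storage, overwriting it on the first micro-step and adding to it when gradients are accumulated (the offload branch additionally moves the slice to CPU). A condensed sketch of that bookkeeping, with numpy arrays standing in for VarBase tensors and the offload path omitted:

import numpy as np

def reduce_hook(full_grad, bw_storage, start, end, accumulate):
    local = full_grad[start:end].copy()   # this rank's slice of the reduced grad
    if not accumulate or bw_storage is None:
        bw_storage = local                # overwrite on the first micro-step
    else:
        bw_storage = bw_storage + local   # accumulate across micro-steps
    return bw_storage                     # caller then frees the full gradient

bw = None
for step in range(2):
    full = np.arange(8, dtype=np.float32) + step   # pretend all-reduced gradient
    bw = reduce_hook(full, bw, start=2, end=4, accumulate=True)
print(bw)   # [5. 7.]: slice [2:4] summed over the two micro-steps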
@@ -493,6 +548,15 @@ class ShardingStage3(nn.Layer):
         return reduce
 
+    def _param2align(self, param):
+        # CUDA alignment 256 bytes
+        size = param._numel() * align[param.dtype]
+        remaining = size % alignment[self._default_device]
+        ali = 0 if remaining == 0 else alignment[
+            self._default_device] - remaining
+        align_ = ali // align[param.dtype]
+        return align_
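_param2align converts the 256-byte CUDA alignment into per-parameter padding measured in elements. Worked numbers (chosen for illustration, not taken from the diff): an fp16 parameter with 1000 elements occupies 2000 bytes, 2000 % 256 = 208, so 48 bytes of padding are needed, i.e. 24 fp16 elements. A standalone version of the same arithmetic:

def param2align(numel, elem_bytes, alignment=256):
    size = numel * elem_bytes
    remaining = size % alignment
    pad_bytes = 0 if remaining == 0 else alignment - remaining
    return pad_bytes // elem_bytes   # padding expressed in elements

print(param2align(1000, 2))   # fp16: 2000 bytes -> 208 over -> pad 48 bytes = 24 elements
print(param2align(1024, 4))   # fp32: 4096 bytes is already a multiple of 256 -> 0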
 
     def _redefine_opt_step(self):
         params_slice_func = self._update_params_slice
         opt_step = self._optim.step
@@ -679,14 +743,13 @@ def _wait_layer(trainable_params,
                 group,
                 use_calc_stream,
                 offload=False):
+    paddle.device.cuda.synchronize()
     for param in trainable_params:
         if param.status == "all":
             param.use_count += 1
             continue
         if param.name in task_flow.full_param.keys():
             full_param = task_flow.full_param[param.name]
-            with paddle.amp.auto_cast(enable=False):
-                paddle.device.cuda.synchronize()
             core.VarBase(full_param._slice(0, param._numel()))._share_buffer_to(
                 param)
             param.fw_storage._clear()
@@ -725,7 +788,7 @@ def _allgather_buffer(trainable_params,
         full_param = _all_gather(
             param.fw_storage, group, use_calc_stream=use_calc_stream)
 
-        # Allgather current layer in the 1st step
+        # Allgather current layer in the 1st step synchronously
         if sync_wait:
             with paddle.amp.auto_cast(enable=False):
                 dist.wait(
@@ -774,6 +837,12 @@ def _PartitionParam(param):
     return param
 
+
+def _UnsliceParam(param):
+    if not hasattr(param, "unslice"):
+        setattr(param, "unslice", True)
+    return param
+
+
 def _VarBaseWrapper(param):
     varbase = param.fw_storage
     tmp_param = ParamBase(
......
@@ -57,12 +57,15 @@ class ShardingClipGrad:
     @imperative_base.no_grad
     def _dygraph_clip(self, params_grads):
-        sum_square_fp16 = []
-        sum_square_fp32 = []
+        sum_square_fp32, sum_square_fp16 = [], []
+        unslice_params_fp32, unslice_params_fp16 = [], []
 
         for p, g in params_grads:
+            p_slice = True  # using for slice parameter in sharding stage3
             if g is None or getattr(p, 'need_clip', True) is False:
                 continue
+            if hasattr(p, "unslice"):
+                p_slice = False
 
             merge_grad = g
             if g.type == core.VarDesc.VarType.SELECTED_ROWS:
@@ -72,9 +75,11 @@ class ShardingClipGrad:
             sum_square = layers.reduce_sum(square)
             if p.dtype == paddle.float16:
-                sum_square_fp16.append(sum_square)
+                if p_slice: sum_square_fp16.append(sum_square)
+                else: unslice_params_fp16.append(sum_square)
             elif p.dtype == paddle.float32:
-                sum_square_fp32.append(sum_square)
+                if p_slice: sum_square_fp32.append(sum_square)
+                else: unslice_params_fp32.append(sum_square)
 
         # global norm of non-distributed FP16 params_and_grads
         if len(sum_square_fp16) == 0:
@@ -85,12 +90,28 @@ class ShardingClipGrad:
             global_norm_fp16 = paddle.cast(
                 global_norm_fp16, dtype=paddle.float32)
 
+        # global norm of non-distributed FP16 params_and_grads for slice parameter
+        if len(unslice_params_fp16) == 0:
+            global_unslice_fp16 = paddle.to_tensor([0.], dtype=paddle.float32)
+        else:
+            global_unslice_fp16 = layers.concat(unslice_params_fp16)
+            global_unslice_fp16 = layers.reduce_sum(global_unslice_fp16)
+            global_unslice_fp16 = paddle.cast(
+                global_unslice_fp16, dtype=paddle.float32)
+
         # global norm of non-distributed FP32 params_and_grads
         global_norm_fp32 = layers.concat(sum_square_fp32) if len(
             sum_square_fp32) != 0 else paddle.to_tensor(
                 [0.], dtype=paddle.float32)
         global_norm_fp32 = layers.reduce_sum(global_norm_fp32)
 
+        # global norm of non-distributed FP32 params_and_grads for slice parameter
+        global_unslice_fp32 = layers.concat(unslice_params_fp32) if len(
+            unslice_params_fp32) != 0 else paddle.to_tensor(
+                [0.], dtype=paddle.float32)
+        global_unslice_fp32 = layers.reduce_sum(global_unslice_fp32)
+        global_unslice_var = global_unslice_fp16 + global_unslice_fp32
+
         global_norm_var = global_norm_fp16 + global_norm_fp32
 
         # add all reduce to get global norm of distributed params_and_grads
@@ -98,6 +119,7 @@ class ShardingClipGrad:
         with device_guard(dev_id, "gpu"):
             paddle.distributed.all_reduce(global_norm_var, group=self._group)
 
+        global_norm_var += global_unslice_var
         global_norm_var = layers.sqrt(global_norm_var)
         max_global_norm = layers.fill_constant(
             shape=[1], dtype=global_norm_var.dtype, value=self.clip_norm)
......
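ShardingClipGrad now tracks two kinds of squared-gradient sums: those from sliced parameters, whose shards are disjoint across ranks and therefore must be summed with all_reduce, and those from unslice parameters, which are identical on every rank and are added only after the all-reduce so they are not counted world_size times. A toy numeric check of that bookkeeping (plain floats simulating two ranks, not Paddle tensors):

import math

# Each rank holds a disjoint shard of the sliced grads ...
slice_sq = {0: 9.0, 1: 16.0}        # sum of squares of the local shard, per rank
# ... and an identical copy of the unslice grads.
unslice_sq = 4.0

def global_norm(rank):
    local = slice_sq[rank]
    reduced = sum(slice_sq.values())        # all_reduce(SUM) over the group
    return math.sqrt(reduced + unslice_sq)  # unslice added once, after the reduce

print(global_norm(0), global_norm(1))       # same value on both ranks: sqrt(29)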
@@ -145,6 +145,10 @@ def train_mlp(model,
             loss = paddle.nn.functional.cross_entropy(
                 input=out, label=label)
             avg_loss = paddle.mean(x=loss.cast(dtype=paddle.float32))
+
+            if batch_size == 20:
+                avg_loss = avg_loss / 5
+
             if not use_pure_fp16:
                 avg_loss.backward()
             else:
@@ -215,7 +219,7 @@ def test_stage2_stage3():
             stage3_params[i].numpy(),
             stage3_params_add[i].numpy(),
             rtol=1e-6,
-            atol=1e-6)
+            atol=1e-4)
 
     # fp16
     stage2_params = train_mlp(
......
@@ -28,7 +28,6 @@ from paddle.distributed.fleet.meta_parallel.sharding.sharding_stage3 import ShardingStage3
 from paddle.distributed.fleet.meta_parallel.sharding.sharding_utils import ShardingScaler
 
 epoch = 10
-batch_size = 32
 paddle.seed(2022)
 np.random.seed(2022)
 base_lr = 0.1
@@ -80,6 +79,7 @@ def train_mlp(model,
               use_pure_fp16=False,
               accumulate_grad=False,
               offload=False,
+              batch_size=100,
               convert2cpu=False):
     group = paddle.distributed.new_group([0, 1])
     optimizer = optimizer_setting(model=model, use_pure_fp16=use_pure_fp16)
@@ -91,7 +91,11 @@ def train_mlp(model,
         scaler = ShardingScaler(scaler)
     model = ShardingStage3(
-        model, optimizer=optimizer, group=group, offload=offload)
+        model,
+        optimizer=optimizer,
+        group=group,
+        offload=offload,
+        accumulate_grads=accumulate_grad)
 
     train_reader = paddle.batch(
         reader_decorator(), batch_size=batch_size, drop_last=True)
@@ -115,10 +119,15 @@ def train_mlp(model,
             loss = paddle.nn.functional.cross_entropy(
                 input=out, label=label)
             avg_loss = paddle.mean(x=loss.cast(dtype=paddle.float32))
+
+            if accumulate_grad:
+                avg_loss = avg_loss / 5
+
             if not use_pure_fp16:
                 avg_loss.backward()
             else:
                 scaler.scale(avg_loss).backward()
+
             if not accumulate_grad:
                 if not use_pure_fp16:
                     optimizer.step()
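Both tests divide avg_loss by 5 when gradients are accumulated (in the stage2/stage3 comparison the batch_size == 20 check plays the same role) and skip optimizer.step() while accumulating; with the default batch size of 100 and an accumulation batch size of 20, five micro-steps presumably stand in for one full-batch update. A short sketch of why the division keeps the accumulated gradient equivalent (toy numbers, not from the tests):

micro_grads = [2.0, 4.0, 6.0, 8.0, 10.0]   # gradient of the unscaled loss per micro-step
steps = len(micro_grads)

accumulated = sum(g / steps for g in micro_grads)   # loss / 5 before each backward
large_batch = sum(micro_grads) / steps              # one step over the full batch
print(accumulated, large_batch)                      # both 6.0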
@@ -172,12 +181,14 @@ def test_stage3_offload():
             atol=1e-2)
 
     # fp32 accumulate grad offload
-    stage3_params = train_mlp(mlp5, use_pure_fp16=False, accumulate_grad=True)
+    stage3_params = train_mlp(
+        mlp5, use_pure_fp16=False, batch_size=20, accumulate_grad=True)
     stage3_params_offload = train_mlp(
         mlp6,
         use_pure_fp16=False,
         accumulate_grad=True,
         offload=True,
+        batch_size=20,
         convert2cpu=True)
     for i in range(len(stage3_params)):
         np.testing.assert_allclose(
......