Unverified commit 46823104, authored by Baibaifan, committed by GitHub

Add sharding stage3 offload (#38989)

Parent commit: f4623876
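Review note: a quick usage sketch of what this change enables, pieced together from the new unit test added in this commit (dygraph_sharding_stage3_offload.py). The model, process group, and optimizer below are illustrative placeholders, not part of the diff:

import paddle
from paddle.distributed import fleet
from paddle.distributed.fleet.meta_parallel.sharding.sharding_stage3 import ShardingStage3

fleet.init(is_collective=True)
group = paddle.distributed.new_group([0, 1])

model = MyLayer()                     # any paddle.nn.Layer (placeholder)
optimizer = paddle.optimizer.AdamW(
    parameters=model.parameters(), learning_rate=0.001)

# offload=True keeps each rank's parameter slices on CPU and moves them to the
# GPU only while they are needed for the all-gathers and forward/backward.
model = ShardingStage3(model, optimizer=optimizer, group=group, offload=True)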
Changes to sharding_stage3.py (paddle.distributed.fleet.meta_parallel.sharding):

@@ -33,7 +33,7 @@ from paddle.fluid.framework import ParamBase
 from paddle.fluid.clip import ClipGradByGlobalNorm
 from paddle.distributed.collective import _get_global_group
-from .sharding_utils import Type, ShardingClipGrad
+from .sharding_utils import Type, ShardingClipGrad, device_guard
 from ..pp_utils.utils import _all_gather
 
 # CUDA alignment 256 bytes
@@ -56,6 +56,13 @@ class ShardingStage3(nn.Layer):
     .. ZeRO: https://arxiv.org/pdf/1910.02054.pdf.
     """
 
+    # TODO (Baibaifan)
+    # Feature Notes::
+    # 1. The model supports the segmentation of parameters by global ranks in layers.
+    # 2. Support communication flow and computing flow.
+    # 3. Support offload function.
+    # 4. Support the establishment of independent communication groups.
+
     def __init__(self,
                  layer,
                  optimizer,
@@ -77,6 +84,15 @@ class ShardingStage3(nn.Layer):
         self._offload = offload
         self._sync_comm = sync_comm
 
+        global DEV
+        DEV = "cpu" if paddle.get_device() == "cpu" else paddle.get_device(
+        ).split(":")[0]
+        global DEV_ID
+        DEV_ID = 0 if paddle.get_device() == "cpu" else int(paddle.get_device()
+                                                            .split(":")[1])
+        global param2dtype
+        param2dtype = dict()
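Review note: these globals are read later by the offload helpers added at the bottom of this file; on a worker bound to "gpu:1", DEV is "gpu" and DEV_ID is 1, while a CPU run falls back to "cpu"/0. Illustrative use (assumes the globals above and a GPU build):

cpu_slice = paddle.zeros([8], dtype="float32").cpu()   # a parameter slice kept on CPU
gpu_slice = cpu_slice.cuda(DEV_ID)                     # back onto this worker's card, as _cpu2device does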
         # Communication group establishment
         self._group = dist.new_group(_get_global_group()
                                      .ranks) if group is None else group
@@ -85,6 +101,9 @@ class ShardingStage3(nn.Layer):
         self._rank = self._group.rank
         self._global_root_rank = 0  # picking rank 0 as the reference
         self._global_ranks = self._group.ranks
+
+        # Parameter segmentation for global ranks
+        # After flatten -> self._param2buffer_size, self._param2buffer, self._trainable_params
         self._param2buffer_size = dict()  # {param.name: size}
         self._param2buffer = dict(
         )  # {param.name: [(start0, end0),(start1, end1), ...]}
@@ -116,12 +135,16 @@ class ShardingStage3(nn.Layer):
         self._order_tracer = OrderedDict()
         self._order_tracer["order"] = 0
         self._order_tracer["layer"] = []
 
         # Register task flow
         self._task_flow = TaskFlow()
 
         # Register forward hooks
         self._register_forward_hooks(self._layer)
 
         # Register backward parameter hooks
         self._register_backward_hooks()
 
         # Redefine optimizer step and clear function
         self._redefine_opt_step()
         self._redefine_opt_clear()
@@ -152,7 +175,6 @@ class ShardingStage3(nn.Layer):
                 param, "fw_storage"
             ), "Find {} don't have fw_storage attribute.".format(param.name)
 
-            # param.bw_storage.zero_()
             param.fw_storage.clear_gradient(False)
             param.fw_storage._gradient_set_empty(False)
             param.bw_storage._clear()
@@ -192,6 +214,9 @@ class ShardingStage3(nn.Layer):
         return fw
 
     def _segment_rank_params(self, layer, name="last_layer"):
+        """
+        Flatten parameters according to layer.
+        """
         current_layer_params = _current_layer_params(layer)
         if current_layer_params:
             CHECK_LAYER[id(layer)] = name
@@ -201,6 +226,10 @@ class ShardingStage3(nn.Layer):
             self._segment_rank_params(sub_layer, name)
 
     def _flatten_layer_params(self, layer, current_layer_params):
+        """
+        Parameter segmentation and memory integration.
+        """
+
         def _add_manage_info(trainable_param):
             return _PartitionParam(trainable_param)
@@ -238,8 +267,13 @@ class ShardingStage3(nn.Layer):
             # 3.Flatten layer params and release other rank buffer
             self._param_storage(param, buffer_size)
+            # Record param's dtype
+            param2dtype[param.name] = param.dtype
 
     def _param_storage(self, param, buffer_size):
+        """
+        This is a function to simplify the handling of parameter InternalStorages.
+        """
         assert isinstance(buffer_size, int)
         value = np.zeros(
             buffer_size,
@@ -264,16 +298,31 @@ class ShardingStage3(nn.Layer):
             param._clear()
 
         # Current rank param_storage
-        param.fw_storage = core.VarBase(
-            buffer._slice(start, end), "slice@" + param.name)
+        if self._offload:
+            param.fw_storage = core.VarBase(
+                buffer._slice(start, end),
+                core.CPUPlace(), "slice@" + param.name)
+        else:
+            param.fw_storage = core.VarBase(
+                buffer._slice(start, end), "slice@" + param.name)
         param.status = "part"
 
         # Updata optimizer master weights
-        if param.dtype == Type.fp16.value:
+        if param.dtype == Type.fp16.value and not self._offload:
             self._optim._master_weights[param.fw_storage.name] = paddle.cast(
                 param.fw_storage, Type.fp32.value)
 
     def _register_forward_hooks(self, layer):
+        """
+        Register pylayer to manage memory slices.
+        There are four stages:
+        FW
+        1. Before the forward layers, synchronize the full parameters.
+        2. After the forward layers, release the full parameter and keep the parameter slice.
+        BW
+        3. Before the backward layers, synchronize the full parameters and create param's grad.
+        4. After the gradient accumulation, release the full parameter and keep the parameter slice.
+        """
         current_layer_params = _current_layer_params(layer)
         if current_layer_params:
             self._register_forward_all_hooks(layer, self._task_flow)
@@ -286,13 +335,13 @@ class ShardingStage3(nn.Layer):
             return ForwardPreHooks(layer, self._order_tracer,
                                    self._trainable_params, self._param2buffer,
                                    self._rank, self._group, self._sync_comm,
-                                   task_flow)
+                                   self._offload, task_flow)
 
         def _forward_post_hook(layer, inputs, outputs):
             return ForwardPostHooks.apply(
                 outputs, layer, self._order_tracer, self._trainable_params,
                 self._param2buffer, self._param2buffer_size, self._rank,
-                self._group, self._sync_comm, task_flow)
+                self._group, self._sync_comm, self._offload, task_flow)
 
         # register previous forward hooks
         sub_layer.register_forward_pre_hook(_forward_pre_hook)
@@ -302,6 +351,10 @@ class ShardingStage3(nn.Layer):
     @paddle.no_grad()
     def _sync_buffers(self):
+        """
+        Sync all the param buffers from all ranks (exp: batch norm statistics).
+        """
         for buffer in self._layer.buffers(include_sublayers=True):
             dist.broadcast(
                 buffer,
@@ -319,6 +372,9 @@ class ShardingStage3(nn.Layer):
         return getattr(self._layer, name)
 
     def _update_params(self):
+        """
+        Update parameters to optimizer memory slice.
+        """
         update_list = []
         assert len(self._trainable_params.keys()) > 0
         current_layer_params = self._layer.parameters(include_sublayers=True)
@@ -331,36 +387,35 @@ class ShardingStage3(nn.Layer):
                     param.name)
 
             if self._accumulate_grads:
-                param.bw_storage.scale_(scale=self._world_size_scaling)
+                if self._offload:
+                    with device_guard(device="cpu"):
+                        param.bw_storage.scale_(scale=self._world_size_scaling)
+                else:
+                    param.bw_storage.scale_(scale=self._world_size_scaling)
             param.fw_storage = _VarBaseWrapper(param)
             param.fw_storage._copy_gradient_from(param.bw_storage)
             update_list.append(param)
         return update_list
 
-    def get_all_parameters(self):
+    def get_all_parameters(self, convert2cpu=False):
+        """
+        Get the full parameters and return the corresponding task flows.
+        """
         assert len(self._trainable_params.keys()) > 0
         current_layer_params = self._layer.parameters(include_sublayers=True)
         trainable_params = list(
             filter(lambda x: x.trainable, current_layer_params))
-        for param in trainable_params:
-            if param.use_count > 0:
-                continue
-            assert hasattr(
-                param,
-                "fw_storage"), "Find {} don't have fw_storage attribute".format(
-                    param.name)
-
-            full_param = _all_gather(
-                param.fw_storage, self._group, use_calc_stream=True)
-            dist.wait(
-                tensor=full_param, group=self._group, use_calc_stream=True)
-            core.VarBase(full_param._slice(0, param._numel()))._share_buffer_to(
-                param)
-            param.value().get_tensor()._set_dims(param.shape)
-            param.fw_storage._clear()
-            param.fw_storage = None
-            param.status = "all"
-            param.use_count += 1
+        t_flow = _allgather_buffer(
+            trainable_params,
+            self._group,
+            use_calc_stream=True,
+            task_flow=TaskFlow(),
+            sync_wait=True,
+            offload=self._offload,
+            convert2cpu=convert2cpu)
+        if convert2cpu:
+            for param in current_layer_params:
+                t_flow.full_param[param.name]._share_buffer_to(param)
 
         self._optim._parameter_list = self._ori_parameter_list
         self._optim._param_groups = self._ori_param_groups
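Review note: with the new convert2cpu switch, the full parameters can be re-materialized after training and optionally parked on CPU; the offload unit test further down uses it roughly like this (sketch):

model = ShardingStage3(model, optimizer=optimizer, group=group, offload=True)
# ... training loop ...
model.get_all_parameters(convert2cpu=True)   # rebuild full params, keep them on CPU
params = model.parameters()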
@@ -393,13 +448,28 @@ class ShardingStage3(nn.Layer):
                         use_calc_stream=True)
                 start, end = self._param2buffer[param.name][self._rank]
-                if not self._accumulate_grads or param.bw_storage is None:
+                if not self._accumulate_grads or param.bw_storage is None or not param.bw_storage.value(
+                ).get_tensor()._is_initialized():
                     param.bw_storage = core.VarBase(
                         full_grad._slice(start, end)).detach().clone()
+                    if self._offload:
+                        param.bw_storage = _device2cpu(param.bw_storage,
+                                                       True)
                 else:
-                    param.bw_storage.add_(
-                        core.VarBase(full_grad._slice(start, end)).detach()
-                        .clone())
+                    if self._offload:
+                        cpu_grad = _device2cpu(
+                            core.VarBase(full_grad._slice(start, end))
+                            .detach().clone(), True)
+                        param.bw_storage = paddle.add(param.bw_storage,
+                                                      cpu_grad)
+                    else:
+                        # param.bw_storage.add_(
+                        #     core.VarBase(full_grad._slice(start, end))
+                        #     .detach().clone())
+                        param.bw_storage = paddle.add(
+                            param.bw_storage,
+                            core.VarBase(full_grad._slice(
+                                start, end)).detach().clone())
             param.clear_gradient(False)
             param._gradient_set_empty(False)
             tmp_var = self._task_flow.full_grad.pop(param.name)
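Review note: in the offload path the freshly reduced gradient slice is cast to fp32, copied to CPU, and accumulated out of place with paddle.add instead of an in-place add_. A self-contained sketch of that accumulation pattern (shapes and names are illustrative):

import paddle

running = paddle.zeros([16], dtype="float32").cpu()    # stands in for param.bw_storage on CPU
new_grad = paddle.cast(paddle.rand([16]), "float16")   # stands in for a fresh GPU grad slice
cpu_grad = paddle.cast(new_grad, "float32").cpu()      # what _device2cpu(..., True) amounts to
running = paddle.add(running, cpu_grad)                # out-of-place add keeps the CPU copy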
@@ -410,15 +480,16 @@ class ShardingStage3(nn.Layer):
                 param.use_count = 0
                 param._clear()
                 start, end = self._param2buffer[param.name][self._rank]
-                with paddle.amp.auto_cast(enable=False):
-                    param.fw_storage = core.VarBase(
-                        self._task_flow.full_param[param.name]._slice(start,
-                                                                      end),
-                        param.name + "@slice").detach().clone()
+                param.fw_storage = core.VarBase(
+                    self._task_flow.full_param[param.name]._slice(
+                        start, end), param.name + "@slice").detach().clone()
                 param.status = "part"
                 tmp_var = self._task_flow.full_param.pop(param.name)
                 tmp_var._clear()
 
+                if self._offload:
+                    param.fw_storage = _device2cpu(param.fw_storage, True)
+
         return reduce
 
     def _redefine_opt_step(self):
@@ -429,7 +500,11 @@ class ShardingStage3(nn.Layer):
         def _opt_step(self):
             if not update_scaler:
                 params_slice_func()
-            opt_step()
+            if self.offload:
+                with device_guard(device="cpu"):
+                    opt_step()
+            else:
+                opt_step()
 
         self._optim.step = MethodType(_opt_step, self._optim)
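Review note: device_guard comes from sharding_utils (see the import change at the top of this diff) and is used so that gradient scaling and the optimizer update run against the CPU copies when offload is enabled. A rough sketch of such a guard (assumed implementation, not the actual helper):

import contextlib
import paddle

@contextlib.contextmanager
def device_guard(dev_id=0, device="cpu"):
    # Temporarily switch the default place, then restore it.
    origin_device = paddle.device.get_device()
    if device == "cpu":
        paddle.set_device(device)
    elif device == "gpu":
        paddle.set_device("gpu:{}".format(dev_id))
    try:
        yield
    finally:
        paddle.set_device(origin_device)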
@@ -443,7 +518,7 @@ class ShardingStage3(nn.Layer):
 
 def ForwardPreHooks(layer, order_tracer, trainable_params, param2buffer, rank,
-                    group, sync_comm, task_flow):
+                    group, sync_comm, offload, task_flow):
 
     # Record layer's id
     layer_id = id(layer)
@@ -451,21 +526,28 @@ def ForwardPreHooks(layer, order_tracer, trainable_params, param2buffer, rank,
     if layer_id not in order_tracer.keys() or sync_comm:
         use_calc, sync_wait = True, True
+
+        # Whether to use calc stream
         task_flow.use_calc[layer_id] = use_calc
     else:
+        # Whether to use calc stream
         task_flow.use_calc[layer_id] = use_calc
-        _wait_layer(trainable_params, layer_id, task_flow, group, use_calc)
+
+        # wait current layer params
+        _wait_layer(trainable_params[layer_id], task_flow, group, use_calc,
+                    offload)
 
     if layer_id == order_tracer["layer"][-1]: return
 
     order_ = order_tracer[layer_id]
     layer_id = order_tracer["layer"][order_ + 1]
     _allgather_buffer(
-        layer_id,
-        trainable_params,
+        trainable_params[layer_id],
         group,
         use_calc_stream=use_calc,
         task_flow=task_flow,
-        sync_wait=sync_wait)
+        sync_wait=sync_wait,
+        offload=offload)
     return
 
@@ -473,15 +555,20 @@ class ForwardPostHooks(PyLayer):
     @staticmethod
     def forward(ctx, inputs, layer, order_tracer, trainable_params,
                 param2buffer, param2buffer_size, rank, group, sync_comm,
-                task_flow):
-        _release_param(layer, trainable_params, param2buffer, rank, task_flow)
+                offload, task_flow):
         layer_id = id(layer)
+        # release current layer full params
+        _release_param(trainable_params[layer_id], param2buffer, rank,
+                       task_flow, offload)
+
         if layer_id not in order_tracer.keys():
             order_ = order_tracer["order"]
             order_tracer[layer_id] = order_
             order_tracer["order"] += 1
             order_tracer["layer"].append(layer_id)
+
+        #Record bw info
         ctx.order_tracer = order_tracer
         ctx.task_flow = task_flow
         ctx.group = group
@@ -489,6 +576,7 @@ class ForwardPostHooks(PyLayer):
         ctx.sync_comm = sync_comm
         ctx.trainable_params = trainable_params
         ctx.param2buffer_size = param2buffer_size
+        ctx.offload = offload
 
         return inputs
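Review note: ForwardPostHooks is a PyLayer, which is what gives the BW stages of the docstring a place to run; state stashed on ctx during forward comes back when autograd calls backward. A minimal, unrelated PyLayer for reference (standard paddle.autograd.PyLayer usage, not code from this commit):

import paddle
from paddle.autograd import PyLayer

class Square(PyLayer):
    @staticmethod
    def forward(ctx, x):
        ctx.save_for_backward(x)
        return x * x

    @staticmethod
    def backward(ctx, dy):
        x, = ctx.saved_tensor()
        return 2.0 * x * dy

x = paddle.to_tensor([3.0], stop_gradient=False)
y = Square.apply(x)
y.backward()          # x.grad is now [6.0]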
@@ -502,31 +590,39 @@ class ForwardPostHooks(PyLayer):
         trainable_params = ctx.trainable_params
         param2buffer_size = ctx.param2buffer_size
         sync_comm = ctx.sync_comm
+        offload = ctx.offload
         layer_id = id(layer)
         use_calc, sync_wait = False, False
+
+        # Allgather params synchronization
         if sync_comm:
             use_calc, sync_wait = True, True
             _allgather_buffer(
-                layer_id,
-                trainable_params,
+                trainable_params[layer_id],
                 group,
                 use_calc_stream=use_calc,
                 task_flow=task_flow,
-                sync_wait=sync_wait)
+                sync_wait=sync_wait,
+                offload=offload)
         else:
-            _wait_layer(trainable_params, layer_id, task_flow, group, use_calc)
-        _create_params_grad(layer, trainable_params, param2buffer_size,
+            _wait_layer(trainable_params[layer_id], task_flow, group, use_calc,
+                        offload)
+
+        # Create params's grad
+        _create_params_grad(trainable_params[layer_id], param2buffer_size,
                             task_flow)
+
+        # Whether to use calc stream
         task_flow.use_calc[layer_id] = use_calc
         if layer_id != order_tracer["layer"][0] and not sync_comm:
             layer_next_id = order_tracer["layer"][order_tracer[layer_id] - 1]
             _allgather_buffer(
-                layer_next_id,
-                trainable_params,
+                trainable_params[layer_next_id],
                 group,
                 use_calc_stream=use_calc,
                 task_flow=task_flow,
-                sync_wait=sync_wait)
+                sync_wait=sync_wait,
+                offload=offload)
 
         return args
@@ -547,8 +643,12 @@ class TaskFlow:
         self.callback = callback
 
 
-def _release_param(layer, trainable_params, param2buffer, rank, task_flow):
-    for param in trainable_params[id(layer)]:
+def _release_param(trainable_params,
+                   param2buffer,
+                   rank,
+                   task_flow,
+                   offload=False):
+    for param in trainable_params:
         # async communicate share weight not clear
         param.use_count -= 1
         if param.use_count == 0:
@@ -562,11 +662,18 @@ def _release_param(layer, trainable_params, param2buffer, rank, task_flow):
                 param.status = "part"
                 tmp_var = task_flow.full_param.pop(param.name)
                 tmp_var._clear()
+
+                if offload:
+                    param.fw_storage = _device2cpu(param.fw_storage)
     return
 
 
-def _wait_layer(trainable_params, layer_id, task_flow, group, use_calc_stream):
-    for param in trainable_params[layer_id]:
+def _wait_layer(trainable_params,
+                task_flow,
+                group,
+                use_calc_stream,
+                offload=False):
+    for param in trainable_params:
         if param.status == "all":
             param.use_count += 1
             continue
@@ -576,36 +683,43 @@ def _wait_layer(trainable_params, layer_id, task_flow, group, use_calc_stream):
             paddle.device.cuda.synchronize()
             core.VarBase(full_param._slice(0, param._numel()))._share_buffer_to(
                 param)
-            param.value().get_tensor()._set_dims(param.shape)
             param.fw_storage._clear()
             param.fw_storage = None
             param.status = "all"
             param.use_count += 1
         else:
             _allgather_buffer(
-                layer_id,
                 trainable_params,
                 group,
-                use_calc_stream,
-                task_flow,
-                sync_wait=True)
+                use_calc_stream=True,
+                task_flow=task_flow,
+                sync_wait=True,
+                offload=offload)
             break
     return task_flow
-def _allgather_buffer(layer_id,
-                      trainable_params,
+def _allgather_buffer(trainable_params,
                       group,
                       use_calc_stream,
                       task_flow,
-                      sync_wait=False):
-    for param in trainable_params[layer_id]:
+                      sync_wait=False,
+                      offload=False,
+                      convert2cpu=False):
+    for param in trainable_params:
         if param.status == "all":
             param.use_count += 1
             continue
+
+        if offload:
+            param.fw_storage = _cpu2device(param)
 
         with paddle.amp.auto_cast(enable=False):
             full_param = _all_gather(
                 param.fw_storage, group, use_calc_stream=use_calc_stream)
 
+        # Allgather current layer in the 1st step
         if sync_wait:
             with paddle.amp.auto_cast(enable=False):
                 dist.wait(
@@ -614,18 +728,26 @@ def _allgather_buffer(layer_id,
                     use_calc_stream=use_calc_stream)
             core.VarBase(full_param._slice(0, param._numel()))._share_buffer_to(
                 param)
-            param.value().get_tensor()._set_dims(param.shape)
             param.fw_storage._clear()
             param.fw_storage = None
             param.status = "all"
             param.use_count += 1
         task_flow.full_param[param.name] = full_param
 
+        # parameter converts to cpu
+        if convert2cpu:
+            p_name = param.name
+            param = _device2cpu(param)
+            tmp_var = task_flow.full_param.pop(p_name)
+            tmp_var._clear()
+            task_flow.full_param[p_name] = param
+
     return task_flow
 
 
 @paddle.no_grad()
-def _create_params_grad(layer, trainable_params, param2buffer_size, task_flow):
-    for param in trainable_params[id(layer)]:
+def _create_params_grad(trainable_params, param2buffer_size, task_flow):
+    for param in trainable_params:
         if param.name in task_flow.full_grad.keys():
             continue
         assert isinstance(param2buffer_size[param.name], int)
@@ -668,6 +790,23 @@ def _OptimizerWrapper(optimizer, offload, group, update_params_slice):
     return optimizer
+def _device2cpu(trans_param, convert_dtype=False):
+    if convert_dtype:
+        trans_param = paddle.cast(trans_param, Type.fp32.value)
+    tmp_p = trans_param.cpu()
+    trans_param._clear()
+    return tmp_p
+
+
+def _cpu2device(param):
+    tmp_p = param.fw_storage.cuda(DEV_ID)
+    param.fw_storage._clear()
+    if tmp_p.dtype == Type.fp32.value and param2dtype[
+            param.name] == Type.fp16.value:
+        tmp_p = paddle.cast(tmp_p, Type.fp16.value)
+    return tmp_p
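Review note: _device2cpu stores an fp16 slice on CPU in fp32 (when convert_dtype=True), and _cpu2device uses the dtype recorded in param2dtype to cast it back after moving it onto the card. A self-contained sketch of that round trip (tensor names are illustrative):

import paddle

w = paddle.ones([4], dtype="float16")            # stands in for a GPU fp16 slice
param2dtype = {"w": w.dtype}                     # recorded when params are flattened

cpu_copy = paddle.cast(w, "float32").cpu()       # _device2cpu(w, convert_dtype=True)
restored = cpu_copy.cuda(0)                      # _cpu2device: back onto the card
if param2dtype["w"] == paddle.float16:
    restored = paddle.cast(restored, "float16")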
 
 def _current_layer_params(layer):
     return layer.parameters(
         include_sublayers=False) + list(layer.extra_parameters) if hasattr(
...
Changes to the dygraph_sharding_stage3.py unit test:

@@ -30,7 +30,6 @@ from paddle.distributed.fleet.meta_parallel.sharding.sharding_stage3 import ShardingStage3
 from paddle.distributed.fleet.meta_parallel.sharding.sharding_utils import ShardingScaler
 
 epoch = 10
-batch_size = 32
 paddle.seed(2021)
 np.random.seed(2021)
 base_lr = 0.1
@@ -66,10 +65,10 @@ def reader_decorator(linear_size=1000):
 
 def optimizer_setting(model, use_pure_fp16, opt_group=False):
     clip = paddle.nn.ClipGradByGlobalNorm(clip_norm=1.0)
-    optimizer = paddle.optimizer.AdamW(
+    optimizer = paddle.optimizer.Momentum(
         parameters=[{
-            "params": model.parameters()
-        }] if opt_group else model.parameters(),
+            "params": list(model.parameters())
+        }] if opt_group else list(model.parameters()),
         learning_rate=0.001,
         weight_decay=0.00001,
         grad_clip=clip,
@@ -82,6 +81,7 @@ def train_mlp(model,
              sharding_stage,
              use_pure_fp16=False,
              accumulate_grad=False,
+             batch_size=100,
              opt_group=False,
              recompute=False):
     group = paddle.distributed.new_group([0, 1])
@@ -104,10 +104,14 @@ def train_mlp(model,
             optimizer,
             group=group,
             buffer_max_size=2**21,
-            accumulate_grads=accumulate_grad)
+            accumulate_grads=batch_size == 20)
     elif sharding_stage == 3:
         model = ShardingStage3(
-            model, optimizer=optimizer, group=group, sync_comm=recompute)
+            model,
+            optimizer=optimizer,
+            group=group,
+            accumulate_grads=batch_size == 20,
+            sync_comm=recompute)
 
     train_reader = paddle.batch(
         reader_decorator(), batch_size=batch_size, drop_last=True)
@@ -131,21 +135,22 @@ def train_mlp(model,
                 loss = paddle.nn.functional.cross_entropy(
                     input=out, label=label)
             avg_loss = paddle.mean(x=loss.cast(dtype=paddle.float32))
+            if not use_pure_fp16:
+                avg_loss.backward()
+            else:
+                scaler.scale(avg_loss).backward()
 
             if not accumulate_grad:
                 if not use_pure_fp16:
-                    avg_loss.backward()
                     optimizer.step()
                 else:
-                    scaler.scale(avg_loss).backward()
                     scaler.step(optimizer)
                     scaler.update()
                 optimizer.clear_grad()
 
         if accumulate_grad:
             if not use_pure_fp16:
-                avg_loss.backward()
                 optimizer.step()
             else:
-                scaler.scale(avg_loss).backward()
                 scaler.step(optimizer)
                 scaler.update()
             optimizer.clear_grad()
@@ -168,48 +173,50 @@ def test_stage2_stage3():
     mlp8.set_state_dict(state_dict)
 
     # fp32
     stage2_params = train_mlp(
-        mlp1, sharding_stage=2, use_pure_fp16=False, opt_group=True)
+        mlp1, sharding_stage=2, use_pure_fp16=False, opt_group=False)
     stage3_params = train_mlp(
-        mlp2, sharding_stage=3, use_pure_fp16=False, opt_group=True)
+        mlp2, sharding_stage=3, use_pure_fp16=False, opt_group=False)
     for i in range(len(stage2_params)):
-        for j in range(len(stage3_params)):
-            if stage2_params[i].name == stage3_params[j].name:
-                np.testing.assert_allclose(
-                    stage2_params[i].numpy(),
-                    stage3_params[j].numpy(),
-                    rtol=1e-6)
+        np.testing.assert_allclose(
+            stage2_params[i].numpy(),
+            stage3_params[i].numpy(),
+            rtol=1e-6,
+            atol=1e-6)
 
     # fp32 accumulate grad
-    stage2_params = train_mlp(
+    stage3_params = train_mlp(
         mlp3,
-        sharding_stage=2,
+        sharding_stage=3,
         use_pure_fp16=False,
         accumulate_grad=True,
         opt_group=True)
-    stage3_params = train_mlp(
+    stage3_params_add = train_mlp(
         mlp4,
         sharding_stage=3,
         use_pure_fp16=False,
         accumulate_grad=True,
+        batch_size=20,
         opt_group=True)
-    for i in range(len(stage2_params)):
-        for j in range(len(stage3_params)):
-            if stage2_params[i].name == stage3_params[j].name:
-                np.testing.assert_allclose(
-                    stage2_params[i].numpy(),
-                    stage3_params[j].numpy(),
-                    rtol=1e-6)
+    for i in range(len(stage3_params)):
+        np.testing.assert_allclose(
+            stage3_params[i].numpy(),
+            stage3_params_add[i].numpy(),
+            rtol=1e-6,
+            atol=1e-6)
 
     # fp16
     stage2_params = train_mlp(
         mlp5, sharding_stage=2, use_pure_fp16=True, opt_group=False)
     stage3_params = train_mlp(
         mlp6, sharding_stage=3, use_pure_fp16=True, opt_group=False)
     for i in range(len(stage2_params)):
-        for j in range(len(stage3_params)):
-            if stage2_params[i].name == stage3_params[j].name:
-                np.testing.assert_allclose(
-                    stage2_params[i].numpy(),
-                    stage3_params[j].numpy(),
-                    rtol=1e-6)
+        np.testing.assert_allclose(
+            stage2_params[i].numpy(),
+            stage3_params[i].numpy(),
+            rtol=1e-4,
+            atol=1e-4)
 
     # fp16 recompute
     stage3_params = train_mlp(
         mlp7, sharding_stage=3, use_pure_fp16=True, opt_group=False)
@@ -220,12 +227,8 @@ def test_stage2_stage3():
         opt_group=False,
         recompute=True)
     for i in range(len(stage3_params)):
-        for j in range(len(stage3_params_re)):
-            if stage3_params[i].name == stage3_params_re[j].name:
-                np.testing.assert_allclose(
-                    stage3_params[i].numpy(),
-                    stage3_params_re[j].numpy(),
-                    rtol=1e-6)
+        np.testing.assert_allclose(
+            stage3_params[i].numpy(), stage3_params_re[i].numpy(), rtol=1e-6)
     return
...

New file: dygraph_sharding_stage3_offload.py
# -*- coding: UTF-8 -*-
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
import argparse
import ast
import time
import paddle
import paddle.fluid as fluid
from paddle.fluid.dygraph.nn import Linear
from paddle.distributed import fleet
from paddle.fluid.dygraph import nn
from paddle.distributed.fleet.meta_parallel.sharding.sharding_stage3 import ShardingStage3
from paddle.distributed.fleet.meta_parallel.sharding.sharding_utils import ShardingScaler
epoch = 10
batch_size = 32
paddle.seed(2022)
np.random.seed(2022)
base_lr = 0.1
momentum_rate = 0.9
l2_decay = 1e-4
fleet.init(is_collective=True)
class MLP(fluid.Layer):
    def __init__(self, linear_size=1000, param_attr=None, bias_attr=None):
        super(MLP, self).__init__()

        self._linear1 = Linear(linear_size, linear_size)
        self._linear2 = Linear(linear_size, linear_size)
        self._linear3 = Linear(linear_size, 10)

    def forward(self, inputs):
        y = self._linear1(inputs)
        y = self._linear2(y)
        y = self._linear3(y)
        return y


def reader_decorator(linear_size=1000):
    def __reader__():
        for _ in range(100):
            img = np.random.rand(linear_size).astype('float32')
            label = np.ones(1).astype('int64')
            yield img, label

    return __reader__
def optimizer_setting(model, use_pure_fp16, opt_group=False):
    clip = paddle.nn.ClipGradByGlobalNorm(clip_norm=1.0)
    optimizer = paddle.optimizer.AdamW(
        parameters=[{
            "params": model.parameters()
        }] if opt_group else model.parameters(),
        learning_rate=0.001,
        weight_decay=0.00001,
        grad_clip=clip,
        multi_precision=use_pure_fp16)

    return optimizer
def train_mlp(model,
              use_pure_fp16=False,
              accumulate_grad=False,
              offload=False,
              convert2cpu=False):
    group = paddle.distributed.new_group([0, 1])
    optimizer = optimizer_setting(model=model, use_pure_fp16=use_pure_fp16)

    if use_pure_fp16:
        model = paddle.amp.decorate(
            models=model, level='O2', save_dtype='float32')
        scaler = paddle.amp.GradScaler(init_loss_scaling=32768)
        scaler = ShardingScaler(scaler)

    model = ShardingStage3(
        model, optimizer=optimizer, group=group, offload=offload)

    train_reader = paddle.batch(
        reader_decorator(), batch_size=batch_size, drop_last=True)

    train_loader = paddle.io.DataLoader.from_generator(
        capacity=32,
        use_double_buffer=True,
        iterable=True,
        return_list=True,
        use_multiprocess=True)
    train_loader.set_sample_list_generator(train_reader)

    for eop in range(epoch):
        model.train()
        for batch_id, data in enumerate(train_loader()):
            img, label = data
            label.stop_gradient = True
            img.stop_gradient = True

            with paddle.amp.auto_cast(True, level='O2'):
                out = model(img)
                loss = paddle.nn.functional.cross_entropy(
                    input=out, label=label)

            avg_loss = paddle.mean(x=loss.cast(dtype=paddle.float32))
            if not use_pure_fp16:
                avg_loss.backward()
            else:
                scaler.scale(avg_loss).backward()

            if not accumulate_grad:
                if not use_pure_fp16:
                    optimizer.step()
                else:
                    scaler.step(optimizer)
                    scaler.update()
                optimizer.clear_grad()

        if accumulate_grad:
            if not use_pure_fp16:
                optimizer.step()
            else:
                scaler.step(optimizer)
                scaler.update()
            optimizer.clear_grad()

    if not convert2cpu:
        model.get_all_parameters()
    else:
        model.get_all_parameters(convert2cpu)
    return model.parameters()
def test_stage3_offload():
    mlp, mlp1, mlp2, mlp3, mlp4, mlp5, mlp6 = MLP(), MLP(), MLP(), MLP(), MLP(
    ), MLP(), MLP()
    state_dict = mlp.state_dict()
    mlp1.set_state_dict(state_dict)
    mlp2.set_state_dict(state_dict)
    mlp3.set_state_dict(state_dict)
    mlp4.set_state_dict(state_dict)
    mlp5.set_state_dict(state_dict)
    mlp6.set_state_dict(state_dict)

    # fp32 offload
    stage3_params = train_mlp(mlp1, use_pure_fp16=False)
    stage3_params_offload = train_mlp(mlp2, use_pure_fp16=False, offload=True)
    for i in range(len(stage3_params)):
        np.testing.assert_allclose(
            stage3_params[i].numpy(),
            stage3_params_offload[i].numpy(),
            rtol=1e-6,
            atol=1e-8)

    # fp16 offload
    stage3_params = train_mlp(mlp3, use_pure_fp16=True)
    stage3_params_offload = train_mlp(mlp4, use_pure_fp16=True, offload=True)
    for i in range(len(stage3_params)):
        np.testing.assert_allclose(
            stage3_params[i].numpy(),
            stage3_params_offload[i].numpy(),
            rtol=1e-2,
            atol=1e-2)

    # fp32 accumulate grad offload
    stage3_params = train_mlp(mlp5, use_pure_fp16=False, accumulate_grad=True)
    stage3_params_offload = train_mlp(
        mlp6,
        use_pure_fp16=False,
        accumulate_grad=True,
        offload=True,
        convert2cpu=True)
    for i in range(len(stage3_params)):
        np.testing.assert_allclose(
            stage3_params[i].numpy(),
            stage3_params_offload[i].numpy(),
            rtol=1e-6,
            atol=1e-8)
    return


if __name__ == '__main__':
    test_stage3_offload()
@@ -23,10 +23,10 @@ from test_parallel_dygraph_dataparallel import TestMultipleGpus
 
 class TestDygraphShardingStage2(TestMultipleGpus):
 
     # check sharding logic as well as the accuracy with single mode
-    def test_dygraph_sharding_optimizer_stage2(self):
+    def test_dygraph_sharding_stage2(self):
         self.run_mnist_2gpu('dygraph_sharding_stage2.py')
 
-    def test_dygraph_sharding_optimizer_stage2_offload(self):
+    def test_dygraph_sharding_stage2_offload(self):
         self.run_mnist_2gpu('dygraph_sharding_stage2_offload.py')
...
@@ -23,9 +23,12 @@ from test_parallel_dygraph_dataparallel import TestMultipleGpus
 
 class TestDygraphShardingStage3(TestMultipleGpus):
 
     # check sharding logic as well as the accuracy with single mode
-    def test_dygraph_sharding_optimizer_stage3(self):
+    def test_dygraph_sharding_stage3(self):
         self.run_mnist_2gpu('dygraph_sharding_stage3.py')
 
+    def test_dygraph_sharding_stage3_offload(self):
+        self.run_mnist_2gpu('dygraph_sharding_stage3_offload.py')
+
 
 if __name__ == "__main__":
     unittest.main()