未验证 提交 d8b4ca92 编写于 作者: Y Yuang Liu 提交者: GitHub

[dygraph sharding stage 2] sharding broadcast overlap (#46656)

上级 9a849a37
......@@ -24,6 +24,8 @@
import copy
import logging
import warnings
import numpy as np
from collections import OrderedDict
......@@ -87,7 +89,7 @@ class GroupShardedOptimizerStage2(Optimizer):
self._optim = optim
# sharing stage 2 comm overlap flag
self._comm_overlap = False
self._reduce_overlap = False
# record the last task used for comm overlap for sharding stage 2
self._comm_task = None
......@@ -108,6 +110,17 @@ class GroupShardedOptimizerStage2(Optimizer):
filter(lambda x: x.trainable and x.dtype == Type.fp16.value,
self._local_params))) > 0
self._broadcast_overlap = False
self._forward_pre_hook_remove_helper = []
try:
# The fp32 params such as layer_norm_0.w_0 will be at the end of param_list.
# Have to sort the params to make sure all params are in the forward using order.
self._broadcast_order_params = sorted(
self.local_params,
key=lambda x: int(x.name.split('.')[0].split('_')[-1]))
except ValueError:
self._broadcast_order_params = None
self._group = new_group(
_get_global_group().ranks) if group is None else group
......@@ -163,15 +176,34 @@ class GroupShardedOptimizerStage2(Optimizer):
sync_op=True)
def _update_task(self, task):
if self._comm_overlap:
if self._reduce_overlap:
assert task is not None
# Only track of the last reduce task.
# Since all tasks are on the same stream, only need to wait the last one.
# After waiting for the last reduce task, all reduce tasks before have already finished.
self._comm_task = task
def _set_comm_overlap(self, comm_overlap):
self._comm_overlap = comm_overlap
def _set_reduce_overlap(self, reduce_overlap):
# Enable gradients' reduces overlap with backward calculation.
self._reduce_overlap = reduce_overlap
def _set_broadcast_overlap(self, broadcast_overlap, layers=None):
# Enable post optimizer broadcasts overlap with the forward calculation of next batch.
self._broadcast_overlap = broadcast_overlap
if self._broadcast_overlap:
assert layers is not None, \
"To enable broadcast overlap forward, please pass the module to the function."
self._layers = layers
warnings.warn(
"Setting overlap broadcast means the `paddle.device.cuda.synchronize()` "
"must be called manually before calling `paddle.save()` and before and inference."
)
if self._broadcast_order_params is None:
# Params' names should be like column_linear_32.w_0 patter to get the best performance.
warnings.warn(
"The param name passed to the optimizer doesn't follow .+_[0-9]+\..+ patter, "
"overlap broadcast may harm the performance.")
self._broadcast_order_params = self._local_params
def _generate_master_params(self, trainable_params):
if self.offload:
......@@ -382,6 +414,12 @@ class GroupShardedOptimizerStage2(Optimizer):
"""
# This method won't be called directly by opt.step()!
# The _redefine_opt_step() in class GroupShardedStage2 will wrap this function.
if self._broadcast_overlap:
# Clear the pre forward hook in the optimizer step.
for hook_remove in self._forward_pre_hook_remove_helper:
hook_remove.remove()
self._forward_pre_hook_remove_helper = []
if self.offload:
params_list = [self.offload_params.buffer]
......@@ -425,9 +463,49 @@ class GroupShardedOptimizerStage2(Optimizer):
"""Broadcast the parameters of the current rank to each rank"""
# Exchange all the shards with the other ranks
if self._broadcast_overlap:
self._broadcast_params_overlap_forward()
else:
for dtype_per_rank in self.param_storages.values():
for dst_rank, internal_storage in dtype_per_rank.items():
broadcast(tensor=internal_storage.buffer,
src=self._group.ranks[dst_rank],
group=self._group,
sync_op=True)
def _forward_pre_hook_function(self, tasks):
# Since the layers will call pre hook by `forward_pre_hook(self, inputs)`,
# the helper functions needs the x and y to take those params.
def __impl__(x, y):
for task in tasks:
# Wait for broadcast task before using the result of the broadcast.
task.wait()
return __impl__
@paddle.autograd.no_grad()
def _broadcast_params_overlap_forward(self):
# Exchange all the shards with the other ranks,
# but overlap the broadcast with next batch's calculation.
param2task = {}
for x in self._broadcast_order_params:
if x.trainable:
task = broadcast(
tensor=x,
src=self._group.ranks[self._param2rank[x.name]],
group=self._group,
sync_op=False)
assert x.name not in param2task
param2task[x.name] = task
for layer in self._layers.sublayers():
if len(layer.sublayers()) == 0:
# Register forward pre hood for leaf layers. This will get the best performance.
tasks = []
for param in layer.parameters():
if param.trainable:
if param.name in param2task:
tasks.append(param2task[param.name])
self._forward_pre_hook_remove_helper.append(
layer.register_forward_pre_hook(
self._forward_pre_hook_function(tasks)))
......@@ -101,7 +101,7 @@ class GroupShardedStage2(nn.Layer):
self._all_params.extend(list(optim.local_params))
# sharing stage 2 comm overlap flag
self._comm_overlap = False
self._reduce_overlap = False
self._trainable_params = []
self._grad_reduced = []
......@@ -309,17 +309,17 @@ class GroupShardedStage2(nn.Layer):
for grad_storage in self._grad_storage_list:
grad_storage.reset_checked_in()
def _set_comm_overlap(self, comm_overlap):
def _set_reduce_overlap(self, reduce_overlap):
# Hacky way to not add an extra parameter to the `group_sharded_parallel` funct.
# User should use this like:
# model, optimizer, scaler = group_sharded_parallel(...)
# model._set_comm_overlap(True)
self._comm_overlap = comm_overlap
if self._comm_overlap:
# model._set_reduce_overlap(True)
self._reduce_overlap = reduce_overlap
if self._reduce_overlap:
assert len(
self._sharding_optimizers
) == 1, "Only support comm overlap strategy for single optimizer"
self._sharding_optimizers[0]._set_comm_overlap(comm_overlap)
self._sharding_optimizers[0]._set_reduce_overlap(reduce_overlap)
def _get_reduce_fn(self, index, param, dst_rank):
"""
......@@ -357,7 +357,7 @@ class GroupShardedStage2(nn.Layer):
collective.reduce(tensor=param.grad,
dst=self._group.ranks[dst_rank],
group=self._group,
sync_op=not self._comm_overlap))
sync_op=not self._reduce_overlap))
# Clear the task flow and trigger callback to clear the redundant gradient
# self._clear_task_flow()
......@@ -407,7 +407,7 @@ class GroupShardedStage2(nn.Layer):
tensor=grad_storage.buffer,
dst=self._group.ranks[grad_storage.destination],
group=self._group,
sync_op=not self._comm_overlap))
sync_op=not self._reduce_overlap))
cleanup()
......@@ -545,7 +545,7 @@ class GroupShardedStage2(nn.Layer):
opt_step = opt.step
def _opt_step(self):
if self._comm_overlap:
if self._reduce_overlap:
# Wait for the last reduce task. This wait must before grad scale function.
assert self._comm_task is not None
self._comm_task.wait()
......
......@@ -92,13 +92,15 @@ def train_mlp(model,
optimizer = optimizer_setting(model=model, use_pure_fp16=use_pure_fp16)
if sharding_stage == 2:
origin_model = model
optimizer = GroupShardedOptimizerStage2(
params=optimizer._parameter_list, optim=optimizer, group=group)
model = GroupShardedStage2(model,
optimizer,
group=group,
buffer_max_size=2**21)
model._set_comm_overlap(True)
model._set_reduce_overlap(True)
optimizer._set_broadcast_overlap(True, model)
else:
model = paddle.DataParallel(model)
......@@ -149,6 +151,8 @@ def train_mlp(model,
optimizer.step()
optimizer.clear_grad()
paddle.device.cuda.synchronize()
if save_model:
return model, optimizer
return model.parameters()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册