Unverified commit 5b642140 authored by Yuang Liu, committed by GitHub

Cherry pick for sharding (#47061)

* [dygraph sharding] Overlap the reduce and the calculation for sharding stage 2. (#46495)

* [dygraph sharding stage 2] sharding broadcast overlap (#46656)

* Multi groups for broadcast of sharding stage 2 (#46894)
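
Taken together, the three picks add two opt-in overlaps to group sharded stage 2: the gradient reduce overlapped with backward computation, and the post-optimizer-step parameter broadcast overlapped with the next forward pass, optionally spread across several broadcast groups. A minimal usage sketch, mirroring the test added in this commit (a two-GPU NCCL group is assumed; `mlp` is a placeholder paddle.nn.Layer):

import paddle
from paddle.distributed.fleet.meta_parallel.sharding.group_sharded_optimizer_stage2 import GroupShardedOptimizerStage2
from paddle.distributed.fleet.meta_parallel.sharding.group_sharded_stage2 import GroupShardedStage2

paddle.distributed.init_parallel_env()
group = paddle.distributed.new_group([0, 1], backend="nccl")

base_opt = paddle.optimizer.AdamW(parameters=mlp.parameters(), learning_rate=0.001)
opt = GroupShardedOptimizerStage2(params=base_opt._parameter_list,
                                  optim=base_opt,
                                  group=group)
model = GroupShardedStage2(mlp, opt, group=group, buffer_max_size=2**21)

model._set_reduce_overlap(True)            # overlap grad reduce with backward
opt._set_broadcast_overlap(True, model)    # overlap param broadcast with next forward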
Parent: b84edd90
@@ -24,6 +24,8 @@
import copy
import logging
import warnings
import numpy as np
from collections import OrderedDict
@@ -86,6 +88,11 @@ class GroupShardedOptimizerStage2(Optimizer):
# Default information
self._optim = optim
# sharding stage 2 comm overlap flag
self._reduce_overlap = False
# record the last task used for comm overlap for sharding stage 2
self._comm_task = None
assert hasattr(self._optim, "_master_weights"
               ), "Must use optimizer with _master_weights attribute"
@@ -103,6 +110,17 @@ class GroupShardedOptimizerStage2(Optimizer):
filter(lambda x: x.trainable and x.dtype == Type.fp16.value,
       self._local_params))) > 0
self._broadcast_overlap = False
self._forward_pre_hook_remove_helper = []
try:
# The fp32 params such as layer_norm_0.w_0 will be at the end of param_list.
# The params have to be sorted so that they follow the order in which they are used in the forward pass.
self._broadcast_order_params = sorted(
self.local_params,
key=lambda x: int(x.name.split('.')[0].split('_')[-1]))
except ValueError:
self._broadcast_order_params = None
self._group = new_group(
    _get_global_group().ranks) if group is None else group
@@ -157,6 +175,60 @@ class GroupShardedOptimizerStage2(Optimizer):
group=self._group,
sync_op=True)
def _update_task(self, task):
if self._reduce_overlap:
assert task is not None
# Only keep track of the last reduce task.
# Since all tasks are issued on the same stream, only the last one needs to be waited on.
# Once the last reduce task has finished, all earlier reduce tasks have finished as well.
self._comm_task = task
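# Hedged sketch, not part of this file: why _update_task only keeps the last task.
# Async collectives issued on one communication stream complete in issue order,
# so waiting on the final task implies every earlier reduce has finished.
import paddle.distributed as dist

def _reduce_all_buffers_sketch(buffers, group):
    # `buffers` is a hypothetical list of gradient tensors, all reduced to the first rank.
    tasks = [
        dist.reduce(buf, dst=group.ranks[0], group=group, sync_op=False)
        for buf in buffers
    ]
    tasks[-1].wait()  # earlier reduce tasks are already complete at this point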
def _set_reduce_overlap(self, reduce_overlap):
# Enable overlapping the gradient reduce with the backward calculation.
self._reduce_overlap = reduce_overlap
def _set_broadcast_overlap(self,
broadcast_overlap,
layers=None,
num_groups=None):
# Enable overlapping the post-optimizer-step broadcasts with the forward calculation of the next batch.
self._broadcast_overlap = broadcast_overlap
if self._broadcast_overlap:
assert layers is not None, \
"To enable broadcast overlap forward, please pass the module to the function."
self._layers = layers
warnings.warn(
"Setting overlap broadcast means the `paddle.device.cuda.synchronize()` "
"must be called manually before calling `paddle.save()` and before and inference."
)
if self._broadcast_order_params is None:
# Param names should follow a pattern like column_linear_32.w_0 to get the best performance.
warnings.warn(
"The param name passed to the optimizer doesn't follow .+_[0-9]+\..+ patter, "
"overlap broadcast may harm the performance.")
self._broadcast_order_params = self._local_params
if num_groups is None or num_groups > len(self._broadcast_order_params):
warnings.warn(
"The num_groups for broadcast is larger than the number of params to be broadcast. "
"It will set to default value: 1 (use the default sharding group)."
)
num_groups = 1
assert isinstance(
num_groups,
int) and num_groups > 0, "num_groups should be a positive integer"
self._number_of_broadcast_groups = num_groups
self._broadcast_groups = [
None for _ in range(self._number_of_broadcast_groups)
]
self._broadcast_groups[0] = self._group
ranks = self._group.ranks
for i in range(1, self._number_of_broadcast_groups):
self._broadcast_groups[i] = new_group(ranks)
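# Hedged usage sketch of the switches defined above (illustrative only); `model`
# is the GroupShardedStage2 wrapper and `optimizer` this stage 2 optimizer, as
# in the test added by this commit. num_groups=2 is an assumed example value.
def _enable_stage2_overlap_sketch(model, optimizer):
    model._set_reduce_overlap(True)
    optimizer._set_broadcast_overlap(True, layers=model, num_groups=2)
    # Before paddle.save() or inference, drain the pending broadcasts with
    # paddle.device.cuda.synchronize(), as the warning above states.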
def _generate_master_params(self, trainable_params):
    if self.offload:
        for param in trainable_params:
@@ -364,6 +436,13 @@ class GroupShardedOptimizerStage2(Optimizer):
""" """
A wrapper for Optimizer's step function to finish the update operation of the optimizer. A wrapper for Optimizer's step function to finish the update operation of the optimizer.
""" """
# This method won't be called directly by opt.step()!
# The _redefine_opt_step() in class GroupShardedStage2 will wrap this function.
if self._broadcast_overlap:
# Remove the registered forward pre-hooks during the optimizer step.
for hook_remove in self._forward_pre_hook_remove_helper:
hook_remove.remove()
self._forward_pre_hook_remove_helper = []
if self.offload:
    params_list = [self.offload_params.buffer]
@@ -408,9 +487,52 @@ class GroupShardedOptimizerStage2(Optimizer):
"""Broadcast the parameters of the current rank to each rank""" """Broadcast the parameters of the current rank to each rank"""
# Exchange all the shards with the other ranks # Exchange all the shards with the other ranks
if self._broadcast_overlap:
    self._broadcast_params_overlap_forward()
else:
    for dtype_per_rank in self.param_storages.values():
        for dst_rank, internal_storage in dtype_per_rank.items():
            broadcast(tensor=internal_storage.buffer,
                      src=self._group.ranks[dst_rank],
                      group=self._group,
                      sync_op=True)
def _forward_pre_hook_function(self, tasks):
# Since layers invoke pre-hooks as `forward_pre_hook(self, inputs)`,
# the helper function needs the x and y parameters to receive those arguments.
def __impl__(x, y):
for task in tasks:
# Wait for broadcast task before using the result of the broadcast.
task.wait()
return __impl__
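# Hedged standalone illustration (not part of this file) of how Paddle calls a
# forward pre-hook with (layer, inputs), which is why __impl__(x, y) above can
# ignore both arguments and only wait on the broadcast tasks.
import paddle

def _pre_hook_demo():
    linear = paddle.nn.Linear(4, 4)
    handle = linear.register_forward_pre_hook(lambda layer, inputs: None)
    linear(paddle.rand([2, 4]))  # the hook fires right before forward runs
    handle.remove()              # mirrors the _forward_pre_hook_remove_helper cleanup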
@paddle.autograd.no_grad()
def _broadcast_params_overlap_forward(self):
# Exchange all the shards with the other ranks,
# but overlap the broadcast with next batch's calculation.
group_idx = 0
param2task = {}
for x in self._broadcast_order_params:
if x.trainable:
group = self._broadcast_groups[group_idx]
group_idx = (group_idx + 1) % self._number_of_broadcast_groups
task = broadcast(tensor=x,
src=group.ranks[self._param2rank[x.name]],
group=group,
sync_op=False)
assert x.name not in param2task
param2task[x.name] = task
for layer in self._layers.sublayers():
if len(layer.sublayers()) == 0:
# Register the forward pre-hook on leaf layers. This gives the best performance.
tasks = []
for param in layer.parameters():
if param.trainable:
if param.name in param2task:
tasks.append(param2task[param.name])
self._forward_pre_hook_remove_helper.append(
layer.register_forward_pre_hook(
self._forward_pre_hook_function(tasks)))
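# Hedged sketch of the layer walk used above (illustrative only): hooks are
# registered only on leaf sublayers, i.e. layers without sublayers of their own,
# so each hook waits exactly on the broadcasts of the params that layer uses.
def _leaf_layers_sketch(model):
    return [l for l in model.sublayers() if len(l.sublayers()) == 0]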
@@ -100,6 +100,9 @@ class GroupShardedStage2(nn.Layer):
for optim in self._sharding_optimizers:
    self._all_params.extend(list(optim.local_params))
# sharding stage 2 comm overlap flag
self._reduce_overlap = False
self._trainable_params = []
self._grad_reduced = []
self._trainable_param2rank = {}
@@ -306,6 +309,18 @@ class GroupShardedStage2(nn.Layer):
for grad_storage in self._grad_storage_list:
    grad_storage.reset_checked_in()
def _set_reduce_overlap(self, reduce_overlap):
# Hacky way to avoid adding an extra parameter to the `group_sharded_parallel` function.
# User should use this like:
# model, optimizer, scaler = group_sharded_parallel(...)
# model._set_reduce_overlap(True)
self._reduce_overlap = reduce_overlap
if self._reduce_overlap:
assert len(
self._sharding_optimizers
) == 1, "Only support comm overlap strategy for single optimizer"
self._sharding_optimizers[0]._set_reduce_overlap(reduce_overlap)
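# Hedged, fuller version of the usage sketched in the comment above; it assumes
# the public helper paddle.distributed.sharding.group_sharded_parallel with
# level "os_g" (optimizer state + gradient sharding, i.e. stage 2).
def _wrap_for_stage2_sketch(model, optimizer, scaler=None):
    from paddle.distributed.sharding import group_sharded_parallel
    model, optimizer, scaler = group_sharded_parallel(model, optimizer, "os_g",
                                                      scaler=scaler)
    model._set_reduce_overlap(True)  # opt into the reduce/backward overlap
    return model, optimizer, scaler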
def _get_reduce_fn(self, index, param, dst_rank):
    """
    There are two ways to reduce gradient.
@@ -337,11 +352,12 @@ class GroupShardedStage2(nn.Layer):
del tmp_grad
param.clear_gradient(False)
# Reduce the parameter gradient (asynchronously when reduce overlap is enabled)
self._sharding_optimizers[0]._update_task(
    collective.reduce(tensor=param.grad,
                      dst=self._group.ranks[dst_rank],
                      group=self._group,
                      sync_op=not self._reduce_overlap))
# Clear the task flow and trigger callback to clear the redundant gradient
# self._clear_task_flow()
@@ -385,12 +401,13 @@ class GroupShardedStage2(nn.Layer):
# Reduce the bucket
grad_storage.sent = True
# Reduce the bucketed parameter gradients (asynchronously when reduce overlap is enabled)
self._sharding_optimizers[0]._update_task(
    collective.reduce(
        tensor=grad_storage.buffer,
        dst=self._group.ranks[grad_storage.destination],
        group=self._group,
        sync_op=not self._reduce_overlap))
cleanup()
@@ -528,6 +545,10 @@ class GroupShardedStage2(nn.Layer):
opt_step = opt.step

def _opt_step(self):
if self._reduce_overlap:
# Wait for the last reduce task. This wait must happen before the grad scale function.
assert self._comm_task is not None
self._comm_task.wait()
grad_func()
opt_step()
...
# -*- coding: UTF-8 -*-
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import shutil
import numpy as np
import argparse
import tempfile
import ast
import time
import paddle
import paddle.fluid as fluid
from paddle.fluid.dygraph.nn import Linear
from paddle.distributed import fleet
from paddle.fluid.dygraph import nn
from paddle.fluid.framework import _test_eager_guard
from paddle.distributed.fleet.meta_parallel.sharding.group_sharded_optimizer_stage2 import GroupShardedOptimizerStage2
from paddle.distributed.fleet.meta_parallel.sharding.group_sharded_stage2 import GroupShardedStage2
seed = 2022
epoch = 2
linear_size = 1000
np.random.seed(seed)
paddle.seed(seed)
class MLP(fluid.Layer):
def __init__(self, linear_size=1000, param_attr=None, bias_attr=None):
super(MLP, self).__init__()
self._linear1 = Linear(linear_size, linear_size)
self._linear2 = Linear(linear_size, linear_size)
self._linear3 = Linear(linear_size, 10)
def forward(self, inputs):
y = self._linear1(inputs)
y = self._linear2(y)
y = self._linear3(y)
return y
def reader_decorator(linear_size=1000):
def __reader__():
for _ in range(100):
img = np.random.rand(linear_size).astype('float32')
label = np.ones(1).astype('int64')
yield img, label
return __reader__
def optimizer_setting(model, use_pure_fp16, opt_group=False):
clip = paddle.nn.ClipGradByGlobalNorm(clip_norm=1.0)
optimizer = paddle.optimizer.AdamW(parameters=[{
"params": model.parameters(),
}] if opt_group else model.parameters(),
learning_rate=0.001,
weight_decay=0.00001,
grad_clip=clip,
multi_precision=use_pure_fp16)
return optimizer
def train_mlp(model,
sharding_stage,
batch_size=100,
use_pure_fp16=False,
accumulate_grad=False,
opt_group=False,
save_model=False,
test_minimize=False):
if sharding_stage != "dp":
group = paddle.distributed.new_group([0, 1], backend="nccl")
if opt_group:
optimizer = optimizer_setting(model=model,
use_pure_fp16=use_pure_fp16,
opt_group=opt_group)
else:
optimizer = optimizer_setting(model=model, use_pure_fp16=use_pure_fp16)
if sharding_stage == 2:
origin_model = model
optimizer = GroupShardedOptimizerStage2(
params=optimizer._parameter_list, optim=optimizer, group=group)
model = GroupShardedStage2(model,
optimizer,
group=group,
buffer_max_size=2**21)
model._set_reduce_overlap(True)
optimizer._set_broadcast_overlap(True, model)
else:
model = paddle.DataParallel(model)
# check optimizer.minimize() error
if test_minimize:
try:
optimizer.minimize()
except:
print(
"====== Find sharding_stage2_optimizer.minimize() error ======")
return
train_reader = paddle.batch(reader_decorator(),
batch_size=batch_size,
drop_last=True)
train_loader = paddle.io.DataLoader.from_generator(capacity=32,
use_double_buffer=True,
iterable=True,
return_list=True,
use_multiprocess=True)
train_loader.set_sample_list_generator(train_reader)
if sharding_stage == 2:
model.to(device="gpu")
for eop in range(epoch):
model.train()
for batch_id, data in enumerate(train_loader()):
img, label = data
label.stop_gradient = True
img.stop_gradient = True
out = model(img)
loss = paddle.nn.functional.cross_entropy(input=out, label=label)
avg_loss = paddle.mean(x=loss.cast(dtype=paddle.float32))
if batch_size == 20:
avg_loss = avg_loss / 5
avg_loss.backward()
if not accumulate_grad:
optimizer.step()
optimizer.clear_grad()
if accumulate_grad:
optimizer.step()
optimizer.clear_grad()
paddle.device.cuda.synchronize()
if save_model:
return model, optimizer
return model.parameters()
def test_dp_stage2():
paddle.distributed.init_parallel_env()
mlp = MLP()
state_dict = mlp.state_dict()
mlp1 = MLP()
mlp2 = MLP()
mlp3 = MLP()
mlp4 = MLP()
mlp5 = MLP()
mlp6 = MLP()
mlp7 = MLP()
mlp1.set_state_dict(state_dict)
mlp2.set_state_dict(state_dict)
mlp3.set_state_dict(state_dict)
mlp4.set_state_dict(state_dict)
mlp5.set_state_dict(state_dict)
mlp6.set_state_dict(state_dict)
mlp7.set_state_dict(state_dict)
# DP VS stage2
dp_params = train_mlp(mlp1,
sharding_stage="dp",
use_pure_fp16=False,
opt_group=False)
stage2_params = train_mlp(mlp2,
sharding_stage=2,
use_pure_fp16=False,
opt_group=False)
for i in range(len(dp_params)):
np.testing.assert_allclose(dp_params[i].numpy(),
stage2_params[i].numpy(),
rtol=1e-6)
# stage2 accumulate grad
stage2_params = train_mlp(mlp3, sharding_stage=2, accumulate_grad=True)
stage2_accumulate_grad = train_mlp(mlp4,
sharding_stage=2,
batch_size=20,
accumulate_grad=True)
for i in range(len(stage2_params)):
np.testing.assert_allclose(stage2_params[i].numpy(),
stage2_accumulate_grad[i].numpy(),
rtol=1e-5,
atol=1e-5)
# stage2 param list VS param group
stage2_params = train_mlp(mlp5,
sharding_stage=2,
use_pure_fp16=False,
opt_group=True)
for i in range(len(dp_params)):
np.testing.assert_allclose(dp_params[i].numpy(),
stage2_params[i].numpy(),
rtol=1e-6)
# save/load model
output_dir = tempfile.mkdtemp()
model_file = os.path.join(output_dir, "model.pdmodel")
optimizer_file = os.path.join(output_dir, "model.pdopt")
model_stage2, optimizer_stage2 = train_mlp(mlp6,
sharding_stage=2,
use_pure_fp16=False,
opt_group=False,
save_model=True)
paddle.save(model_stage2.state_dict(), model_file)
paddle.save(optimizer_stage2.state_dict(), optimizer_file)
m_state_dict = paddle.load(model_file)
opt_state_dict = paddle.load(optimizer_file)
model_stage2.set_state_dict(m_state_dict)
optimizer_stage2.set_state_dict(opt_state_dict)
shutil.rmtree(output_dir)
# check optimizer.minimize() error
train_mlp(mlp7, sharding_stage=2, test_minimize=True)
return
if __name__ == '__main__':
with _test_eager_guard():
test_dp_stage2()
@@ -33,6 +33,9 @@ class TestDygraphShardingStage2(TestMultipleGpus):
self.run_mnist_2gpu('dygraph_sharding_stage2_offload.py',
                    eager_mode=False)
def test_dygraph_sharding_stage2_with_comm_overlap(self):
self.run_mnist_2gpu('dygraph_group_sharded_stage2_comm_overlap.py')
if __name__ == "__main__":
    os.environ["FLAGS_enable_eager_mode"] = "1"
...