Unverified commit 327e5050, authored by: B Baibaifan, committed by: GitHub

Integration sharding stage2 function (#38151)

Parent 9e42fe9a
@@ -16,21 +16,19 @@
#Commit: 8acbec718f3c70a6b9785470bb9e05cd84fc3f8e
import copy
import time
import logging
import numpy as np
from math import inf
from itertools import chain
from functools import reduce
from collections import OrderedDict
import paddle
import paddle.fluid as fluid
from paddle import framework
from paddle.fluid import core
import paddle.distributed as dist
from paddle.optimizer import Optimizer
from paddle.fluid.clip import ClipGradByGlobalNorm
from paddle.distributed.collective import _get_global_group
from ...utils.internal_storage import ParamStorage
from ...meta_parallel.sharding.sharding_utils import Type, device_guard, ShardingClipGrad
@@ -59,14 +57,14 @@ class ShardingOptimizerStage2(Optimizer):
# Feature Notes:
# 1. Unified memory for parameters and parameters.grad to InternalStorage.
# 2. Support the segmentation of optimizer parameters and partial updating of parameters.
# 3. Dynamically adjust training parameters and models
# 3. Dynamically adjust training parameters and models.
# 4. Support offload function.
# 5. Support the establishment of independent communication groups.
# 6. Broadcast_fp16 is not supported now.
def __init__(self,
params,
optim,
group,
group=None,
broadcast_fp16=False,
offload=False,
device="gpu",
@@ -78,13 +76,16 @@ class ShardingOptimizerStage2(Optimizer):
self._dtype_rank_params = OrderedDict(
) # {dtype:[param1,param2]} device, rank, params
self._param2rank = {}
self._segment_params = []
self.__segment_params = []
self._rank_buffer_size = {} # {dtype: {rank: numel+alignment}}
self._param2align = {} # {param.name: align}
# Default information
self._optim_defaults = kw
self._optim = optim
self._ori_parameter_list = self._optim._parameter_list
self._ori_param_groups = self._optim._param_groups
assert hasattr(self._optim, "_master_weights"
), "Must use optimizer with _master_weights attribute"
self._local_params = params
@@ -94,8 +95,8 @@ class ShardingOptimizerStage2(Optimizer):
filter(lambda x: x.trainable and x.dtype == Type.fp16.value,
self._local_params))) > 0
assert group is not None, "Distributed communication group is must be gived"
self.group = group
group = _get_global_group() if group is None else group
self.world_size = group.nranks
self.rank = group.rank
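The change above makes the communication group optional: when group is None, the optimizer falls back to the global group and reads rank and world size from it. A minimal sketch of that fallback, assuming the distributed environment has already been initialized:

from paddle.distributed.collective import _get_global_group

def resolve_group(group=None):
    # Fall back to the global communication group when none is given,
    # mirroring the change above; rank and world size then come from it.
    group = _get_global_group() if group is None else group
    return group.rank, group.nranks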
@@ -119,7 +120,7 @@ class ShardingOptimizerStage2(Optimizer):
self._master_params = {}
# Update optimizer parameters and adjust parameter storage and use according to rank.
self.update_opt_status()
self._update_opt_status()
def _generate_master_params(self, trainable_params):
if self.offload:
@@ -137,7 +138,7 @@ class ShardingOptimizerStage2(Optimizer):
self._optim._master_weights[param.name] = paddle.cast(
param, Type.fp32.value)
def update_opt_status(self):
def _update_opt_status(self):
"""Update optimizer status and parameter storage information, and special functions to be developed.
"""
# func 1
@@ -147,12 +148,12 @@ class ShardingOptimizerStage2(Optimizer):
# Segement helpers
def segment_params(self):
def _segment_params(self):
"""
Divide all optimizer parameters equally into rank.
"""
if len(self._segment_params) == 0:
self._segment_params, param_lists = [
if len(self.__segment_params) == 0:
self.__segment_params, param_lists = [
[] for _ in range(self.world_size)
], [[] for _ in range(self.world_size)]
sizes = [0] * self.world_size
@@ -165,9 +166,8 @@ class ShardingOptimizerStage2(Optimizer):
sizes[rank] += np.prod(param.shape) if param.trainable else 0
for rank, params in enumerate(param_lists):
# param_group_rank = copy.copy(params)
self._segment_params[rank].extend(params)
return self._segment_params
self.__segment_params[rank].extend(params)
return self.__segment_params
@property
def local_params(self):
@@ -177,7 +177,7 @@ class ShardingOptimizerStage2(Optimizer):
def param2rank(self):
"""Map the params to the rank which owns them"""
if len(self._param2rank) == 0:
for rank, params in enumerate(self.segment_params()):
for rank, params in enumerate(self._segment_params()):
for param in params:
self._param2rank[param.name] = rank
return self._param2rank
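The two hunks above replace the public segment_params with a private _segment_params that greedily balances parameters across ranks by element count, and param2rank then records which rank owns each parameter. A standalone sketch of that greedy split, assuming Paddle-style parameters that expose .shape and .trainable:

import numpy as np

def segment_params(trainable_params, world_size):
    # Greedy, size-balanced split: each parameter is assigned to the
    # rank that currently holds the fewest elements.
    per_rank = [[] for _ in range(world_size)]
    sizes = [0] * world_size
    for param in trainable_params:
        rank = sizes.index(min(sizes))
        per_rank[rank].append(param)
        sizes[rank] += int(np.prod(param.shape)) if param.trainable else 0
    return per_rank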
@@ -271,32 +271,31 @@ class ShardingOptimizerStage2(Optimizer):
"""
if self.offload:
self._optim._parameter_list = [
param for name, param in self._master_params.items()
]
params_list = list(self._master_params.values())
else:
# Synchronize optimizer parameters for the current rank
if len(self.dtype_rank_params.keys(
)) == 1 and Type.fp32.value in self.dtype_rank_params.keys():
self._optim._parameter_list = self.dtype_rank_params[
Type.fp32.value][self.rank]
elif len(self.dtype_rank_params.keys(
)) == 1 and Type.fp16.value in self.dtype_rank_params.keys():
self._optim._parameter_list = self.dtype_rank_params[
Type.fp16.value][self.rank]
else:
self._optim._parameter_list = self.dtype_rank_params[
Type.fp16.value][self.rank] + self.dtype_rank_params[
Type.fp32.value][self.rank]
params_list = []
for dtype in self.dtype_rank_params.keys():
params_list.extend(self.dtype_rank_params[dtype][self.rank])
params_name_list = list(map(lambda p: p.name, params_list))
if not isinstance(self._optim._param_groups[0], dict):
self._optim._parameter_list = params_list
self._optim._param_groups = params_list
else:
for param_group in self._optim._param_groups:
p_group = []
for param in param_group['params']:
if param.name in params_name_list:
p_group.append(params_list[params_name_list.index(
param.name)])
param_group['params'] = p_group
# Run the optimizer of the current rank step
if self.offload:
with device_guard(self.rank, self.offload_device):
with device_guard(device=self.offload_device):
self._optim.step()
for param in self._optim._parameter_list:
self._master_params[param.name].set_value(param)
dev_id = 0 if paddle.get_device() == "cpu" else int(
paddle.get_device().split(":")[1])
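The rewritten step() above no longer special-cases fp16/fp32 lists; it collects this rank's parameters into params_list and then shrinks the inner optimizer's view to just those, handling both a flat parameter list and dict-style param groups. A condensed restatement of that filtering, assuming optim exposes the same private _parameter_list/_param_groups attributes used in the hunk:

def keep_rank_params(optim, params_list):
    # The inner optimizer only sees the parameters owned by the current
    # rank, whether it was built from a flat list or param-group dicts.
    names = [p.name for p in params_list]
    if not isinstance(optim._param_groups[0], dict):
        optim._parameter_list = params_list
        optim._param_groups = params_list
    else:
        for param_group in optim._param_groups:
            param_group['params'] = [
                params_list[names.index(p.name)]
                for p in param_group['params'] if p.name in names
            ]

Once the sharded step and broadcast finish, the original lists are restored from _ori_parameter_list/_ori_param_groups, as the next hunk shows.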
@@ -312,10 +311,11 @@ class ShardingOptimizerStage2(Optimizer):
self._broadcast_params()
# Return full parameters to optimizer parameters
self._optim._parameter_list = self._local_params
self._optim._parameter_list = self._ori_parameter_list
self._optim._param_groups = self._ori_param_groups
def clear_cache(self):
self._segment_params.clear()
def _clear_cache(self):
self.__segment_params.clear()
self._dtype_rank_params.clear()
self._param2rank.clear()
......
@@ -24,10 +24,12 @@ import numpy as np
from itertools import chain
from functools import reduce
from collections import deque
from types import MethodType
import paddle
from paddle import nn
import paddle.distributed as dist
from paddle.distributed.collective import _get_global_group
from ...utils.internal_storage import GradStorage
from ...meta_optimizers.dygraph_optimizer.sharding_optimizer_stage2 import ShardingOptimizerStage2
@@ -57,7 +59,7 @@ class ShardingStage2(nn.Layer):
self,
layer,
sharding_optimizer,
group,
group=None,
sync_buffers=False,
pertrain_sync_models=True,
buffer_max_size=2**23, #8MB
@@ -83,13 +85,12 @@ class ShardingStage2(nn.Layer):
self._accumulate_grads = accumulate_grads
# Communication related attributes
assert group is not None, "Distributed communication group is must be gived"
self._group = group
self._world_size_scaling = 1.0 / self._group.nranks
assert self._group.nranks > 1, "Training must be distributed, ranks must be greater than 1"
self._rank = self._group.rank
group = _get_global_group() if group is None else group
self._world_size_scaling = 1.0 / group.nranks
assert group.nranks > 1, "Training must be distributed, ranks must be greater than 1"
self._rank = group.rank
self._global_root_rank = 0 # picking rank 0 as the reference
self._global_ranks = self._group.ranks
self._default_device = device
# Global statistical parameters
@@ -112,8 +113,8 @@ class ShardingStage2(nn.Layer):
self._has_grad_storage = []
self._grad_storage_list = []
# offload
# TODO(haohongxiang): Now it's not supported for multi-optimizers using Offload strategy
# Offload
# TODO(haohongxiang): Now it's not be supported for multi-optimizers using Offload strategy
self._offload_optims = list(
filter(lambda optim: optim.offload, self._sharding_optimizers))
if len(self._offload_optims) > 0:
@@ -134,6 +135,11 @@ class ShardingStage2(nn.Layer):
# Set tasks flow
self._tasks_flow = deque()
# Define optimizer step and clear_grad
if self._accumulate_grads:
self._redefine_opt_step()
self._redefine_opt_clear()
def forward(self, *inputs, **kwargs):
"""
A wrapper for Sharding Stage2 layer.
@@ -161,7 +167,7 @@ class ShardingStage2(nn.Layer):
return fw
def clear_gradients(self):
def _clear_gradients(self):
"""
Set zero to the gradient of the optimizer's current rank trainable parameters.
"""
@@ -176,7 +182,7 @@ class ShardingStage2(nn.Layer):
if param.name in self._param_grads and param.grad is not None:
param.clear_gradient()
def grad_scale(self):
def _grad_scale(self):
"""
Before the gradient accumulation, scale the gradient.
"""
@@ -287,9 +293,6 @@ class ShardingStage2(nn.Layer):
for grad_storage in self._grad_storage_list:
grad_storage.reset_checked_in()
if not self._accumulate_grads:
self._grads_flipped = False
def _get_reduce_fn(self, index, param, dst_rank):
"""
There are two ways to reduce gradient.
@@ -412,7 +415,6 @@ class ShardingStage2(nn.Layer):
self._bw_hooks.pop().remove()
# Go through the parameters, attach the hook
self._grad_accs = []
if not self.training:
return
@@ -500,9 +502,6 @@ class ShardingStage2(nn.Layer):
# Whether parameters trainability changed
trainability_changed = trainable_mask != self._trainable_mask
# The whole model is not trainable but we still have grad hooks
trainability_changed |= not self.training and len(self._bw_hooks) > 0
if trainability_changed:
logging.warning(
"Trainable params changed, because of eval/train mode or parameter freezing/unfreeze."
@@ -548,3 +547,25 @@ class ShardingStage2(nn.Layer):
format(rank_buffer_size[Type.fp32.value] / 2**18, model_size / 2
**18))
return rank_buffer_size
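An inferred reading of the 2**18 divisor in the log line above (not stated in the diff): rank_buffer_size holds fp32 element counts, and 4 bytes per element over 2**20 bytes per MB gives count / 2**18 MB.

def fp32_numel_to_mb(numel):
    # 4 bytes per fp32 element, 2**20 bytes per MB -> numel / 2**18 MB.
    return numel / 2**18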
def _redefine_opt_step(self):
if not self._accumulate_grads:
return
grad_func = self._grad_scale
for opt in self._sharding_optimizers:
opt_step = opt.step
def _opt_step(self):
grad_func()
opt_step()
opt.step = MethodType(_opt_step, opt)
def _redefine_opt_clear(self):
clear_func = self._clear_gradients
def _opt_clear(self):
clear_func()
for opt in self._sharding_optimizers:
opt.clear_grad = MethodType(_opt_clear, opt)
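_redefine_opt_step and _redefine_opt_clear above rebind the wrapped optimizers' step/clear_grad with types.MethodType, so gradient scaling and sharded gradient clearing run transparently when accumulate_grads is enabled. A self-contained toy illustration of the rebinding pattern (the _ToyOptimizer class is hypothetical, not part of Paddle):

from types import MethodType

class _ToyOptimizer:
    # Hypothetical stand-in for an optimizer; only step() matters here.
    def step(self):
        print("inner step")

def patch_step(optimizer, pre_hook):
    # Re-bind optimizer.step so pre_hook (e.g. gradient scaling) always
    # runs right before the original step(), as the hunk above does.
    original_step = optimizer.step

    def _patched_step(self):
        pre_hook()
        original_step()

    optimizer.step = MethodType(_patched_step, optimizer)

opt = _ToyOptimizer()
patch_step(opt, lambda: print("scale grads"))
opt.step()  # prints "scale grads" then "inner step"

Callers therefore keep writing optimizer.step() and optimizer.clear_grad() regardless of whether accumulation is on.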
@@ -131,7 +131,7 @@ class ShardingClipGrad:
@contextlib.contextmanager
def device_guard(dev_id, device="cpu"):
def device_guard(dev_id=0, device="cpu"):
origin_device = paddle.device.get_device()
if device == "cpu":
paddle.set_device(device)
......
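The device_guard hunk above is cut off by the collapsed diff. One plausible shape for the full context manager, assuming it restores the original device on exit (the gpu branch and the try/finally are guesses, not taken from the diff):

import contextlib
import paddle

@contextlib.contextmanager
def device_guard(dev_id=0, device="cpu"):
    # Temporarily switch the active device and restore it afterwards.
    origin_device = paddle.device.get_device()
    if device == "cpu":
        paddle.set_device(device)
    elif device == "gpu":
        paddle.set_device("gpu:{}".format(dev_id))  # assumed branch
    try:
        yield
    finally:
        paddle.set_device(origin_device)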
@@ -125,7 +125,7 @@ def train_mlp():
oss_optimizer.step()
# oss_optimizer clear cache
oss_optimizer.clear_cache()
oss_optimizer._clear_cache()
if __name__ == '__main__':
......
@@ -30,7 +30,7 @@ from paddle.distributed.fleet.meta_parallel.sharding.sharding_stage2 import ShardingStage2
seed = 2021
epoch = 2
batch_size = 32
linear_size = 10000
linear_size = 1000
strategy = fleet.DistributedStrategy()
strategy.hybrid_configs = {
@@ -46,7 +46,7 @@ paddle.seed(seed)
class MLP(fluid.Layer):
def __init__(self, linear_size=10000, param_attr=None, bias_attr=None):
def __init__(self, linear_size=1000, param_attr=None, bias_attr=None):
super(MLP, self).__init__()
self._linear1 = Linear(linear_size, linear_size)
@@ -60,7 +60,7 @@ class MLP(fluid.Layer):
return y
def reader_decorator(linear_size=10000):
def reader_decorator(linear_size=1000):
def __reader__():
for _ in range(100):
img = np.random.rand(linear_size).astype('float32')
@@ -70,10 +70,12 @@ def reader_decorator(linear_size=10000):
return __reader__
def optimizer_setting(model, use_pure_fp16):
def optimizer_setting(model, use_pure_fp16, opt_group=False):
clip = paddle.nn.ClipGradByGlobalNorm(clip_norm=1.0)
optimizer = paddle.optimizer.AdamW(
parameters=model.parameters(),
parameters=[{
"params": model.parameters()
}] if opt_group else model.parameters(),
learning_rate=0.001,
weight_decay=0.00001,
grad_clip=clip,
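The new opt_group flag above builds the AdamW optimizer from dict-style parameter groups instead of a flat list, which is exactly the path ShardingOptimizerStage2.step() now filters per rank. A tiny standalone illustration (the Linear layer here is a hypothetical toy, not the test's MLP):

import paddle

linear = paddle.nn.Linear(8, 8)  # hypothetical toy layer
clip = paddle.nn.ClipGradByGlobalNorm(clip_norm=1.0)

# Dict-style parameter groups: each dict carries its own "params" list
# (and optionally per-group options).
opt = paddle.optimizer.AdamW(
    parameters=[{"params": linear.parameters()}],
    learning_rate=0.001,
    weight_decay=0.00001,
    grad_clip=clip)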
@@ -85,27 +87,32 @@ def optimizer_setting(model, use_pure_fp16):
def train_mlp(model,
sharding_stage,
use_pure_fp16=False,
all_test=False,
accumulate_grad=False):
accumulate_grad=False,
opt_group=False):
if sharding_stage == "dp":
hcg = fleet.get_hybrid_communicate_group()
group = hcg.get_check_parallel_group()
else:
group = paddle.distributed.new_group([0, 1])
optimizer = optimizer_setting(model=model, use_pure_fp16=use_pure_fp16)
if use_pure_fp16:
model = paddle.amp.decorate(
models=model, level='O2', save_dtype='float32')
if opt_group:
optimizer = optimizer_setting(
model=model, use_pure_fp16=use_pure_fp16, opt_group=opt_group)
else:
optimizer = optimizer_setting(model=model, use_pure_fp16=use_pure_fp16)
if sharding_stage == 2:
optimizer = ShardingOptimizerStage2(
params=model.parameters(), optim=optimizer, group=group)
if all_test:
if accumulate_grad:
model = ShardingStage2(
model, optimizer, group=group, accumulate_grads=accumulate_grad)
model,
optimizer,
group=group,
buffer_max_size=2**21,
accumulate_grads=accumulate_grad)
else:
model = ShardingStage2(model, optimizer, group=group)
model = ShardingStage2(
model, optimizer, group=group, buffer_max_size=2**21)
else:
optimizer = fleet.distributed_optimizer(optimizer)
model = fleet.distributed_model(model)
@@ -132,29 +139,16 @@ def train_mlp(model,
label.stop_gradient = True
img.stop_gradient = True
with paddle.amp.auto_cast(enable=use_pure_fp16, level='O2'):
out = model(img)
loss = paddle.nn.functional.cross_entropy(
input=out, label=label)
out = model(img)
loss = paddle.nn.functional.cross_entropy(input=out, label=label)
avg_loss = paddle.mean(x=loss.cast(dtype=paddle.float32))
avg_loss.backward()
if accumulate_grad and batch_id == 2:
model.grad_scale()
optimizer.step()
model.clear_gradients()
return model.parameters()
if not accumulate_grad:
optimizer.step()
if sharding_stage == 2:
model.clear_gradients()
else:
optimizer.clear_grad()
optimizer.step()
optimizer.clear_grad()
if all_test and batch_id == 2:
if accumulate_grad and batch_id == 2:
return model.parameters()
return model.parameters()
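The simplified loop above drops the explicit model.grad_scale()/model.clear_gradients() calls because ShardingStage2 now patches optimizer.step() and clear_grad() when accumulate_grads=True. A compact sketch of the resulting per-step flow, with model, optimizer, and data_iter taken as arguments:

import paddle

def run_steps(model, optimizer, data_iter):
    # Per-step flow after this change: the caller always ends with
    # optimizer.step()/clear_grad(); when accumulation is on, step()
    # scales gradients first and clear_grad() delegates to
    # ShardingStage2._clear_gradients().
    for img, label in data_iter:
        out = model(img)
        loss = paddle.nn.functional.cross_entropy(input=out, label=label)
        avg_loss = paddle.mean(x=loss.cast(dtype=paddle.float32))
        avg_loss.backward()
        optimizer.step()
        optimizer.clear_grad()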
@@ -171,22 +165,19 @@ def test_dp_stage2():
mlp2.set_state_dict(state_dict)
mlp3.set_state_dict(state_dict)
mlp4.set_state_dict(state_dict)
dp_params = train_mlp(mlp1, sharding_stage="dp", use_pure_fp16=False)
stage2_params = train_mlp(mlp2, sharding_stage=2, use_pure_fp16=False)
dp_params = train_mlp(
mlp1, sharding_stage="dp", use_pure_fp16=False, opt_group=True)
stage2_params = train_mlp(
mlp2, sharding_stage=2, use_pure_fp16=False, opt_group=True)
for i in range(len(dp_params)):
for j in range(len(stage2_params)):
if dp_params[i].name == stage2_params[j].name:
np.testing.assert_allclose(
dp_params[i].numpy(), stage2_params[j].numpy(), rtol=1e-6)
stage2_params = train_mlp(
mlp3, sharding_stage=2, use_pure_fp16=True, all_test=True)
stage2_params = train_mlp(mlp3, sharding_stage=2)
stage2_accumulate_grad = train_mlp(
mlp4,
sharding_stage=2,
use_pure_fp16=True,
all_test=True,
accumulate_grad=True)
mlp4, sharding_stage=2, accumulate_grad=True)
for i in range(len(stage2_params)):
for j in range(len(stage2_accumulate_grad)):
if stage2_params[i].name == stage2_accumulate_grad[j].name:
......
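The comparison above matches DP and stage2 parameters by name before asserting closeness. The same check as a small helper, assuming Paddle tensors that expose .name and .numpy():

import numpy as np

def assert_params_allclose(params_a, params_b, rtol=1e-6):
    # Match parameters by name before comparing, as the loops above do.
    by_name = {p.name: p for p in params_b}
    for p in params_a:
        if p.name in by_name:
            np.testing.assert_allclose(
                p.numpy(), by_name[p.name].numpy(), rtol=rtol)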
@@ -33,7 +33,7 @@ from dygraph_sharding_stage2 import MLP, reader_decorator, optimizer_setting
seed = 2021
epoch = 2
batch_size = 32
linear_size = 8000
linear_size = 1000
np.random.seed(seed)
paddle.seed(seed)
@@ -52,7 +52,12 @@ def train_mlp(model, offload=False):
optim=optimizer,
group=group,
offload=offload)
model = ShardingStage2(model, optimizer, group=group, accumulate_grads=True)
model = ShardingStage2(
model,
optimizer,
group=group,
buffer_max_size=2**21,
accumulate_grads=True)
train_reader = paddle.batch(
reader_decorator(linear_size), batch_size=batch_size, drop_last=True)
@@ -81,10 +86,9 @@ def train_mlp(model, offload=False):
avg_loss = paddle.mean(x=loss.cast(dtype=paddle.float32))
scaler.scale(avg_loss).backward()
model.grad_scale()
scaler.step(optimizer)
scaler.update()
model.clear_gradients()
optimizer.clear_grad()
for dtype in optimizer.param_storages:
for dst_rank, param_storage in optimizer.param_storages[dtype].items():
......