From 327e5050eb8bdf2548979c39e2a6637beaecadac Mon Sep 17 00:00:00 2001 From: Baibaifan <39549453+Baibaifan@users.noreply.github.com> Date: Sun, 19 Dec 2021 20:34:18 +0800 Subject: [PATCH] Integration sharding stage2 function (#38151) --- .../sharding_optimizer_stage2.py | 76 +++++++++---------- .../meta_parallel/sharding/sharding_stage2.py | 55 +++++++++----- .../meta_parallel/sharding/sharding_utils.py | 2 +- .../dygraph_sharding_optimizer_stage2.py | 2 +- .../unittests/dygraph_sharding_stage2.py | 75 ++++++++---------- .../dygraph_sharding_stage2_offload.py | 12 ++- 6 files changed, 119 insertions(+), 103 deletions(-) diff --git a/python/paddle/distributed/fleet/meta_optimizers/dygraph_optimizer/sharding_optimizer_stage2.py b/python/paddle/distributed/fleet/meta_optimizers/dygraph_optimizer/sharding_optimizer_stage2.py index dc313c33ee3..663b2293b45 100644 --- a/python/paddle/distributed/fleet/meta_optimizers/dygraph_optimizer/sharding_optimizer_stage2.py +++ b/python/paddle/distributed/fleet/meta_optimizers/dygraph_optimizer/sharding_optimizer_stage2.py @@ -16,21 +16,19 @@ #Commit: 8acbec718f3c70a6b9785470bb9e05cd84fc3f8e import copy -import time import logging import numpy as np -from math import inf from itertools import chain from functools import reduce from collections import OrderedDict import paddle import paddle.fluid as fluid -from paddle import framework from paddle.fluid import core import paddle.distributed as dist from paddle.optimizer import Optimizer from paddle.fluid.clip import ClipGradByGlobalNorm +from paddle.distributed.collective import _get_global_group from ...utils.internal_storage import ParamStorage from ...meta_parallel.sharding.sharding_utils import Type, device_guard, ShardingClipGrad @@ -59,14 +57,14 @@ class ShardingOptimizerStage2(Optimizer): # Feature Notes: # 1. Unified memory for parameters and parameters.grad to InternalStorage. # 2. Support the segmentation of optimizer parameters and partial updating of parameters. - # 3. Dynamically adjust training parameters and models。 + # 3. Dynamically adjust training parameters and models. # 4. Support offload function. # 5. Support the establishment of independent communication groups. # 6. Broadcast_fp16 is not supported now. def __init__(self, params, optim, - group, + group=None, broadcast_fp16=False, offload=False, device="gpu", @@ -78,13 +76,16 @@ class ShardingOptimizerStage2(Optimizer): self._dtype_rank_params = OrderedDict( ) # {dtype:[param1,param2]} device, rank, params self._param2rank = {} - self._segment_params = [] + self.__segment_params = [] self._rank_buffer_size = {} # {dtype: {rank: numel+alignment}} self._param2align = {} # {param.name: align} # Default information self._optim_defaults = kw self._optim = optim + self._ori_parameter_list = self._optim._parameter_list + self._ori_param_groups = self._optim._param_groups + assert hasattr(self._optim, "_master_weights" ), "Must use optimizer with _master_weights attribute" self._local_params = params @@ -94,8 +95,8 @@ class ShardingOptimizerStage2(Optimizer): filter(lambda x: x.trainable and x.dtype == Type.fp16.value, self._local_params))) > 0 - assert group is not None, "Distributed communication group is must be gived" self.group = group + group = _get_global_group() if group is None else group self.world_size = group.nranks self.rank = group.rank @@ -119,7 +120,7 @@ class ShardingOptimizerStage2(Optimizer): self._master_params = {} # Update optimizer parameters and adjust parameter storage and use according to rank. 
- self.update_opt_status() + self._update_opt_status() def _generate_master_params(self, trainable_params): if self.offload: @@ -137,7 +138,7 @@ class ShardingOptimizerStage2(Optimizer): self._optim._master_weights[param.name] = paddle.cast( param, Type.fp32.value) - def update_opt_status(self): + def _update_opt_status(self): """Update optimizer status and parameter storage information, and special functions to be developed. """ # func 1 @@ -147,12 +148,12 @@ class ShardingOptimizerStage2(Optimizer): # Segement helpers - def segment_params(self): + def _segment_params(self): """ Divide all optimizer parameters equally into rank. """ - if len(self._segment_params) == 0: - self._segment_params, param_lists = [ + if len(self.__segment_params) == 0: + self.__segment_params, param_lists = [ [] for _ in range(self.world_size) ], [[] for _ in range(self.world_size)] sizes = [0] * self.world_size @@ -165,9 +166,8 @@ class ShardingOptimizerStage2(Optimizer): sizes[rank] += np.prod(param.shape) if param.trainable else 0 for rank, params in enumerate(param_lists): - # param_group_rank = copy.copy(params) - self._segment_params[rank].extend(params) - return self._segment_params + self.__segment_params[rank].extend(params) + return self.__segment_params @property def local_params(self): @@ -177,7 +177,7 @@ class ShardingOptimizerStage2(Optimizer): def param2rank(self): """Map the params to the rank which owns them""" if len(self._param2rank) == 0: - for rank, params in enumerate(self.segment_params()): + for rank, params in enumerate(self._segment_params()): for param in params: self._param2rank[param.name] = rank return self._param2rank @@ -271,32 +271,31 @@ class ShardingOptimizerStage2(Optimizer): """ if self.offload: - self._optim._parameter_list = [ - param for name, param in self._master_params.items() - ] + params_list = list(self._master_params.values()) else: # Synchronize optimizer parameters for the current rank - if len(self.dtype_rank_params.keys( - )) == 1 and Type.fp32.value in self.dtype_rank_params.keys(): - self._optim._parameter_list = self.dtype_rank_params[ - Type.fp32.value][self.rank] - elif len(self.dtype_rank_params.keys( - )) == 1 and Type.fp16.value in self.dtype_rank_params.keys(): - self._optim._parameter_list = self.dtype_rank_params[ - Type.fp16.value][self.rank] - else: - self._optim._parameter_list = self.dtype_rank_params[ - Type.fp16.value][self.rank] + self.dtype_rank_params[ - Type.fp32.value][self.rank] + params_list = [] + for dtype in self.dtype_rank_params.keys(): + params_list.extend(self.dtype_rank_params[dtype][self.rank]) + + params_name_list = list(map(lambda p: p.name, params_list)) + if not isinstance(self._optim._param_groups[0], dict): + self._optim._parameter_list = params_list + self._optim._param_groups = params_list + else: + for param_group in self._optim._param_groups: + p_group = [] + for param in param_group['params']: + if param.name in params_name_list: + p_group.append(params_list[params_name_list.index( + param.name)]) + param_group['params'] = p_group # Run the optimizer of the current rank step if self.offload: - with device_guard(self.rank, self.offload_device): + with device_guard(device=self.offload_device): self._optim.step() - for param in self._optim._parameter_list: - self._master_params[param.name].set_value(param) - dev_id = 0 if paddle.get_device() == "cpu" else int( paddle.get_device().split(":")[1]) @@ -312,10 +311,11 @@ class ShardingOptimizerStage2(Optimizer): self._broadcast_params() # Return full parameters to optimizer 
parameters - self._optim._parameter_list = self._local_params + self._optim._parameter_list = self._ori_parameter_list + self._optim._param_groups = self._ori_param_groups - def clear_cache(self): - self._segment_params.clear() + def _clear_cache(self): + self.__segment_params.clear() self._dtype_rank_params.clear() self._param2rank.clear() diff --git a/python/paddle/distributed/fleet/meta_parallel/sharding/sharding_stage2.py b/python/paddle/distributed/fleet/meta_parallel/sharding/sharding_stage2.py index fd49c2a7d65..1a381385a89 100644 --- a/python/paddle/distributed/fleet/meta_parallel/sharding/sharding_stage2.py +++ b/python/paddle/distributed/fleet/meta_parallel/sharding/sharding_stage2.py @@ -24,10 +24,12 @@ import numpy as np from itertools import chain from functools import reduce from collections import deque +from types import MethodType import paddle from paddle import nn import paddle.distributed as dist +from paddle.distributed.collective import _get_global_group from ...utils.internal_storage import GradStorage from ...meta_optimizers.dygraph_optimizer.sharding_optimizer_stage2 import ShardingOptimizerStage2 @@ -57,7 +59,7 @@ class ShardingStage2(nn.Layer): self, layer, sharding_optimizer, - group, + group=None, sync_buffers=False, pertrain_sync_models=True, buffer_max_size=2**23, #8MB @@ -83,13 +85,12 @@ class ShardingStage2(nn.Layer): self._accumulate_grads = accumulate_grads # Communication related attributes - assert group is not None, "Distributed communication group is must be gived" self._group = group - self._world_size_scaling = 1.0 / self._group.nranks - assert self._group.nranks > 1, "Training must be distributed, ranks must be greater than 1" - self._rank = self._group.rank + group = _get_global_group() if group is None else group + self._world_size_scaling = 1.0 / group.nranks + assert group.nranks > 1, "Training must be distributed, ranks must be greater than 1" + self._rank = group.rank self._global_root_rank = 0 # picking rank 0 as the reference - self._global_ranks = self._group.ranks self._default_device = device # Global statistical parameters @@ -112,8 +113,8 @@ class ShardingStage2(nn.Layer): self._has_grad_storage = [] self._grad_storage_list = [] - # offload - # TODO(haohongxiang): Now it's not supported for multi-optimizers using Offload strategy + # Offload + # TODO(haohongxiang): Now it's not be supported for multi-optimizers using Offload strategy self._offload_optims = list( filter(lambda optim: optim.offload, self._sharding_optimizers)) if len(self._offload_optims) > 0: @@ -134,6 +135,11 @@ class ShardingStage2(nn.Layer): # Set tasks flow self._tasks_flow = deque() + # Define optimizer step and clear_grad + if self._accumulate_grads: + self._redefine_opt_step() + self._redefine_opt_clear() + def forward(self, *inputs, **kwargs): """ A wrapper for Sharding Stage2 layer. @@ -161,7 +167,7 @@ class ShardingStage2(nn.Layer): return fw - def clear_gradients(self): + def _clear_gradients(self): """ Set zero to the gradient of the optimizer's current rank trainable parameters. """ @@ -176,7 +182,7 @@ class ShardingStage2(nn.Layer): if param.name in self._param_grads and param.grad is not None: param.clear_gradient() - def grad_scale(self): + def _grad_scale(self): """ Before the gradient accumulation, scale the gradient. 
""" @@ -287,9 +293,6 @@ class ShardingStage2(nn.Layer): for grad_storage in self._grad_storage_list: grad_storage.reset_checked_in() - if not self._accumulate_grads: - self._grads_flipped = False - def _get_reduce_fn(self, index, param, dst_rank): """ There are two ways to reduce gradient. @@ -412,7 +415,6 @@ class ShardingStage2(nn.Layer): self._bw_hooks.pop().remove() # Go through the parameters, attach the hook - self._grad_accs = [] if not self.training: return @@ -500,9 +502,6 @@ class ShardingStage2(nn.Layer): # Whether parameters trainability changed trainability_changed = trainable_mask != self._trainable_mask - # The whole model is not trainable but we still have grad hooks - trainability_changed |= not self.training and len(self._bw_hooks) > 0 - if trainability_changed: logging.warning( "Trainable params changed, because of eval/train mode or parameter freezing/unfreeze." @@ -548,3 +547,25 @@ class ShardingStage2(nn.Layer): format(rank_buffer_size[Type.fp32.value] / 2**18, model_size / 2 **18)) return rank_buffer_size + + def _redefine_opt_step(self): + if not self._accumulate_grads: + return + grad_func = self._grad_scale + for opt in self._sharding_optimizers: + opt_step = opt.step + + def _opt_step(self): + grad_func() + opt_step() + + opt.step = MethodType(_opt_step, opt) + + def _redefine_opt_clear(self): + clear_func = self._clear_gradients + + def _opt_clear(self): + clear_func() + + for opt in self._sharding_optimizers: + opt.clear_grad = MethodType(_opt_clear, opt) diff --git a/python/paddle/distributed/fleet/meta_parallel/sharding/sharding_utils.py b/python/paddle/distributed/fleet/meta_parallel/sharding/sharding_utils.py index 651bed82396..b080035a116 100644 --- a/python/paddle/distributed/fleet/meta_parallel/sharding/sharding_utils.py +++ b/python/paddle/distributed/fleet/meta_parallel/sharding/sharding_utils.py @@ -131,7 +131,7 @@ class ShardingClipGrad: @contextlib.contextmanager -def device_guard(dev_id, device="cpu"): +def device_guard(dev_id=0, device="cpu"): origin_device = paddle.device.get_device() if device == "cpu": paddle.set_device(device) diff --git a/python/paddle/fluid/tests/unittests/dygraph_sharding_optimizer_stage2.py b/python/paddle/fluid/tests/unittests/dygraph_sharding_optimizer_stage2.py index 571e41b2c4f..6a9005b8ce6 100644 --- a/python/paddle/fluid/tests/unittests/dygraph_sharding_optimizer_stage2.py +++ b/python/paddle/fluid/tests/unittests/dygraph_sharding_optimizer_stage2.py @@ -125,7 +125,7 @@ def train_mlp(): oss_optimizer.step() # oss_optimizer clear cache - oss_optimizer.clear_cache() + oss_optimizer._clear_cache() if __name__ == '__main__': diff --git a/python/paddle/fluid/tests/unittests/dygraph_sharding_stage2.py b/python/paddle/fluid/tests/unittests/dygraph_sharding_stage2.py index 2b4002ab9c9..e08b4db1e98 100644 --- a/python/paddle/fluid/tests/unittests/dygraph_sharding_stage2.py +++ b/python/paddle/fluid/tests/unittests/dygraph_sharding_stage2.py @@ -30,7 +30,7 @@ from paddle.distributed.fleet.meta_parallel.sharding.sharding_stage2 import Shar seed = 2021 epoch = 2 batch_size = 32 -linear_size = 10000 +linear_size = 1000 strategy = fleet.DistributedStrategy() strategy.hybrid_configs = { @@ -46,7 +46,7 @@ paddle.seed(seed) class MLP(fluid.Layer): - def __init__(self, linear_size=10000, param_attr=None, bias_attr=None): + def __init__(self, linear_size=1000, param_attr=None, bias_attr=None): super(MLP, self).__init__() self._linear1 = Linear(linear_size, linear_size) @@ -60,7 +60,7 @@ class MLP(fluid.Layer): return y -def 
reader_decorator(linear_size=10000): +def reader_decorator(linear_size=1000): def __reader__(): for _ in range(100): img = np.random.rand(linear_size).astype('float32') @@ -70,10 +70,12 @@ def reader_decorator(linear_size=10000): return __reader__ -def optimizer_setting(model, use_pure_fp16): +def optimizer_setting(model, use_pure_fp16, opt_group=False): clip = paddle.nn.ClipGradByGlobalNorm(clip_norm=1.0) optimizer = paddle.optimizer.AdamW( - parameters=model.parameters(), + parameters=[{ + "params": model.parameters() + }] if opt_group else model.parameters(), learning_rate=0.001, weight_decay=0.00001, grad_clip=clip, @@ -85,27 +87,32 @@ def optimizer_setting(model, use_pure_fp16): def train_mlp(model, sharding_stage, use_pure_fp16=False, - all_test=False, - accumulate_grad=False): + accumulate_grad=False, + opt_group=False): if sharding_stage == "dp": hcg = fleet.get_hybrid_communicate_group() group = hcg.get_check_parallel_group() else: group = paddle.distributed.new_group([0, 1]) - optimizer = optimizer_setting(model=model, use_pure_fp16=use_pure_fp16) - - if use_pure_fp16: - model = paddle.amp.decorate( - models=model, level='O2', save_dtype='float32') + if opt_group: + optimizer = optimizer_setting( + model=model, use_pure_fp16=use_pure_fp16, opt_group=opt_group) + else: + optimizer = optimizer_setting(model=model, use_pure_fp16=use_pure_fp16) if sharding_stage == 2: optimizer = ShardingOptimizerStage2( params=model.parameters(), optim=optimizer, group=group) - if all_test: + if accumulate_grad: model = ShardingStage2( - model, optimizer, group=group, accumulate_grads=accumulate_grad) + model, + optimizer, + group=group, + buffer_max_size=2**21, + accumulate_grads=accumulate_grad) else: - model = ShardingStage2(model, optimizer, group=group) + model = ShardingStage2( + model, optimizer, group=group, buffer_max_size=2**21) else: optimizer = fleet.distributed_optimizer(optimizer) model = fleet.distributed_model(model) @@ -132,29 +139,16 @@ def train_mlp(model, label.stop_gradient = True img.stop_gradient = True - with paddle.amp.auto_cast(enable=use_pure_fp16, level='O2'): - out = model(img) - loss = paddle.nn.functional.cross_entropy( - input=out, label=label) + out = model(img) + loss = paddle.nn.functional.cross_entropy(input=out, label=label) avg_loss = paddle.mean(x=loss.cast(dtype=paddle.float32)) avg_loss.backward() - if accumulate_grad and batch_id == 2: - model.grad_scale() - optimizer.step() - model.clear_gradients() - return model.parameters() - - if not accumulate_grad: - optimizer.step() - - if sharding_stage == 2: - model.clear_gradients() - else: - optimizer.clear_grad() + optimizer.step() + optimizer.clear_grad() - if all_test and batch_id == 2: + if accumulate_grad and batch_id == 2: return model.parameters() return model.parameters() @@ -171,22 +165,19 @@ def test_dp_stage2(): mlp2.set_state_dict(state_dict) mlp3.set_state_dict(state_dict) mlp4.set_state_dict(state_dict) - dp_params = train_mlp(mlp1, sharding_stage="dp", use_pure_fp16=False) - stage2_params = train_mlp(mlp2, sharding_stage=2, use_pure_fp16=False) + dp_params = train_mlp( + mlp1, sharding_stage="dp", use_pure_fp16=False, opt_group=True) + stage2_params = train_mlp( + mlp2, sharding_stage=2, use_pure_fp16=False, opt_group=True) for i in range(len(dp_params)): for j in range(len(stage2_params)): if dp_params[i].name == stage2_params[j].name: np.testing.assert_allclose( dp_params[i].numpy(), stage2_params[j].numpy(), rtol=1e-6) - stage2_params = train_mlp( - mlp3, sharding_stage=2, use_pure_fp16=True, 
all_test=True) + stage2_params = train_mlp(mlp3, sharding_stage=2) stage2_accumulate_grad = train_mlp( - mlp4, - sharding_stage=2, - use_pure_fp16=True, - all_test=True, - accumulate_grad=True) + mlp4, sharding_stage=2, accumulate_grad=True) for i in range(len(stage2_params)): for j in range(len(stage2_accumulate_grad)): if stage2_params[i].name == stage2_accumulate_grad[j].name: diff --git a/python/paddle/fluid/tests/unittests/dygraph_sharding_stage2_offload.py b/python/paddle/fluid/tests/unittests/dygraph_sharding_stage2_offload.py index 8adcda9d24e..37537019c0a 100644 --- a/python/paddle/fluid/tests/unittests/dygraph_sharding_stage2_offload.py +++ b/python/paddle/fluid/tests/unittests/dygraph_sharding_stage2_offload.py @@ -33,7 +33,7 @@ from dygraph_sharding_stage2 import MLP, reader_decorator, optimizer_setting seed = 2021 epoch = 2 batch_size = 32 -linear_size = 8000 +linear_size = 1000 np.random.seed(seed) paddle.seed(seed) @@ -52,7 +52,12 @@ def train_mlp(model, offload=False): optim=optimizer, group=group, offload=offload) - model = ShardingStage2(model, optimizer, group=group, accumulate_grads=True) + model = ShardingStage2( + model, + optimizer, + group=group, + buffer_max_size=2**21, + accumulate_grads=True) train_reader = paddle.batch( reader_decorator(linear_size), batch_size=batch_size, drop_last=True) @@ -81,10 +86,9 @@ def train_mlp(model, offload=False): avg_loss = paddle.mean(x=loss.cast(dtype=paddle.float32)) scaler.scale(avg_loss).backward() - model.grad_scale() scaler.step(optimizer) scaler.update() - model.clear_gradients() + optimizer.clear_grad() for dtype in optimizer.param_storages: for dst_rank, param_storage in optimizer.param_storages[dtype].items(): -- GitLab
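
Usage note (not part of the patch): the updated dygraph_sharding_stage2.py test above shows the calling pattern this commit enables — the `group` argument of both `ShardingOptimizerStage2` and `ShardingStage2` may now be omitted (it falls back to the global communication group via `_get_global_group()`), and training is driven through `optimizer.step()` / `optimizer.clear_grad()` instead of the former public `model.grad_scale()` / `model.clear_gradients()` helpers, which became private. The sketch below is an illustrative reconstruction of that pattern only; the toy `Sequential` model, random data, batch size, and hyperparameters are assumptions for the example and are not taken from the patch. It presumes a multi-process launch, e.g. `python -m paddle.distributed.launch --gpus 0,1 demo.py`.

    # Illustrative sketch of the post-patch stage2 API (assumptions noted above).
    import paddle
    from paddle.distributed import fleet
    from paddle.distributed.fleet.meta_optimizers.dygraph_optimizer.sharding_optimizer_stage2 import ShardingOptimizerStage2
    from paddle.distributed.fleet.meta_parallel.sharding.sharding_stage2 import ShardingStage2

    fleet.init(is_collective=True)

    # Toy model standing in for the MLP used in the unit tests.
    model = paddle.nn.Sequential(
        paddle.nn.Linear(1000, 1000), paddle.nn.Linear(1000, 10))
    optimizer = paddle.optimizer.AdamW(
        parameters=model.parameters(),
        learning_rate=0.001,
        grad_clip=paddle.nn.ClipGradByGlobalNorm(clip_norm=1.0))

    # Shard the optimizer state and wrap the layer; with this patch no explicit
    # communication group has to be passed (the global group is used by default).
    optimizer = ShardingOptimizerStage2(params=model.parameters(), optim=optimizer)
    model = ShardingStage2(model, optimizer, buffer_max_size=2**21)

    for _ in range(10):
        img = paddle.rand([32, 1000])                       # random stand-in data
        label = paddle.randint(0, 10, shape=[32, 1])
        loss = paddle.nn.functional.cross_entropy(input=model(img), label=label)
        paddle.mean(loss).backward()
        optimizer.step()        # with accumulate_grads=True the patch re-binds this
                                # to also run the internal gradient-scaling hook
        optimizer.clear_grad()  # replaces the old model.clear_gradients() call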