diff --git a/paddle/phi/infermeta/multiary.cc b/paddle/phi/infermeta/multiary.cc index 9248f4699a4fe730823033ef9c809fc7840cfced..956432bb3076c685ddb3b339d19f7ec72ed7e503 100644 --- a/paddle/phi/infermeta/multiary.cc +++ b/paddle/phi/infermeta/multiary.cc @@ -3096,7 +3096,7 @@ void StackInferMeta(const std::vector& x, rank, axis)); if (axis < 0) axis += (rank + 1); - auto vec = phi::vectorize(out_dim); + auto vec = phi::vectorize(out_dim); vec.insert(vec.begin() + axis, input_dims.size()); out->set_dims(phi::make_ddim(vec)); out->set_dtype(x.at(0)->dtype()); diff --git a/python/paddle/amp/auto_cast.py b/python/paddle/amp/auto_cast.py index c31b0bbef8e12ed2b63ffeb2514b11353c7206ab..a8873b137f3d08ea19cfdda1f65f715b71bfcff7 100644 --- a/python/paddle/amp/auto_cast.py +++ b/python/paddle/amp/auto_cast.py @@ -477,7 +477,7 @@ def _set_multi_precision(optimizer, multi_precision): ) optimizer = ( - optimizer._inner_optimizer + optimizer._inner_opt if isinstance(optimizer, DygraphShardingOptimizer) else optimizer ) diff --git a/python/paddle/distributed/fleet/base/topology.py b/python/paddle/distributed/fleet/base/topology.py index 7def3ec4795aa58b9a628d42c3ae11fba231df07..0272fdd086d0d040b6126c64a34374ffcfd11b57 100644 --- a/python/paddle/distributed/fleet/base/topology.py +++ b/python/paddle/distributed/fleet/base/topology.py @@ -187,6 +187,11 @@ class HybridCommunicateGroup: "data" ) + ( + self.sharding_check_group, + self.sharding_check_comm_group, + ) = self._set_check_group("sharding") + # create p2p group self.is_first_stage = self.stage_id == 0 self.is_last_stage = self.stage_id == (self._pp_degree - 1) @@ -428,8 +433,11 @@ class HybridCommunicateGroup: return self._sharding_comm_group.ranks[0] # check parallel group - def get_check_parallel_group(self): - return self._check_comm_group + def get_check_parallel_group(self, sharding=False): + if sharding: + return self.sharding_check_comm_group + else: + return self._check_comm_group def get_rank_from_stage(self, stage_id, **kwargs): return self._topo.get_rank_from_stage( diff --git a/python/paddle/distributed/fleet/meta_optimizers/dygraph_optimizer/dygraph_sharding_optimizer.py b/python/paddle/distributed/fleet/meta_optimizers/dygraph_optimizer/dygraph_sharding_optimizer.py index e5db34d4b237e4c642ddadd89fa96d1e3cb2720f..a26d56b2dc1936310258b353d7203d94f1a1548e 100755 --- a/python/paddle/distributed/fleet/meta_optimizers/dygraph_optimizer/dygraph_sharding_optimizer.py +++ b/python/paddle/distributed/fleet/meta_optimizers/dygraph_optimizer/dygraph_sharding_optimizer.py @@ -43,53 +43,49 @@ class DygraphShardingOptimizer: # 3. dynamic trainable params, which is the case bewteen pretraining and finetuning # 4. 
option to choose fuse comm (more GPU MEM need) or un-fuse comm - def __init__( - self, - hcg, - user_defined_strategy, - params, - inner_optimizer_class, - **inner_optimizer_kargs - ): - if not isinstance(params, list): + def __init__(self, optimizer, hcg): + # TODO(pangengzheng): support param_groups + if isinstance(optimizer._parameter_list[0], dict): raise TypeError( - "`parameters` argument given to the DygraphShardingOptimizer should be " - "an iterable of paddle Tensors, but got argument type is `{}`.".format( - type(params) - ) + "Do not support param_groups now, please set optimizer._parameter_list as a list of Parameter" ) - self._parameter_list = params - self._reference_is_trainable_params = list( - map(_is_trainable, self._parameter_list) - ) - - self._inner_optimizer_class = inner_optimizer_class - self._inner_optimizer_kargs = inner_optimizer_kargs - - # sharding parallel information - # TODO better way to get the hcg & user_defined_strategy + if not hasattr(optimizer, '_apply_optimize') or not callable( + optimizer._apply_optimize + ): + raise ValueError( + "the optimzier object should have _apply_optimize function" + ) + # the self._parameter_list holds the whole model paramters + self._parameter_list = optimizer._parameter_list + self._inner_opt = optimizer self._hcg = hcg - self._user_defined_strategy = user_defined_strategy self._sharding_world_size = self._hcg.get_sharding_parallel_world_size() self._sharding_rank = self._hcg.get_sharding_parallel_rank() - # logic partitioning - self._build_sharding_mapping() + self._rank2params = self._partition_parameters() + self._param2rank = self._map_param_to_rank() - # actually create opt ops - self._buid_inner_optimizer() + self._set_inner_opt_attr( + '_parameter_list', self._rank2params[self._sharding_rank] + ) + self._set_inner_opt_attr( + '_param_groups', self._rank2params[self._sharding_rank] + ) - def clear_grad(self): + def clear_grad(self, set_to_zero=True): """ should clear grad for all parameters in model """ for p in self._parameter_list: - if not p.stop_gradient: - p.clear_gradient() - - def _build_sharding_mapping(self): - self._rank2params = self._partition_parameters() - self._param2rank = self._map_param_to_rank() + if hasattr(p, "main_grad") and p.main_grad is not None: + assert p._grad_ivar() is None + if set_to_zero: + p.main_grad.zero_() + else: + p.main_grad._clear() + p.main_grad = None + elif not hasattr(p, "main_grad"): + p.clear_gradient(set_to_zero) def _partition_parameters(self): """ @@ -132,14 +128,35 @@ class DygraphShardingOptimizer: mapping[param.name] = rank return mapping - def _buid_inner_optimizer(self): - # we rely on the inner opt to determine whether a parameter is stop_gradient or not: - # create moment - # update related ops: clip, regular, opt - self._inner_optimizer = self._inner_optimizer_class( - parameters=self._rank2params[self._sharding_rank], - **self._inner_optimizer_kargs - ) + def reduce_gradients(self, parameter_list, hcg): + # TODO merge grad / nrank with dp + logger.debug("sharding start gradients sync") + with framework.no_grad(): + sharding_nrank = hcg.get_sharding_parallel_group().nranks + for param in parameter_list: + g_var = None + if param.trainable and (param._grad_ivar() is not None): + g_var = param._grad_ivar() + if param.trainable and hasattr(param, "main_grad"): + assert ( + param._grad_ivar() is None + ), "param.grad should be None when using main_grad" + g_var = param.main_grad + if g_var is not None: + g_var.scale_(1.0 / sharding_nrank) + param_rank = 
self._param2rank[param.name] + paddle.distributed.all_reduce( + g_var, + group=hcg.get_sharding_parallel_group(), + sync_op=True, + ) + # TODO(pangengzheng): change to reduce operation when there is no diff in calculating global norm values in HybridParallelClipGrad compared to dp. + # paddle.distributed.reduce( + # g_var, + # dst=hcg.get_sharding_parallel_group().ranks[param_rank], + # group=hcg.get_sharding_parallel_group(), + # sync_op=True, + # ) def _sharding_sync_parameters(self): """ @@ -180,7 +197,7 @@ class DygraphShardingOptimizer: self._rank2params[self._sharding_rank], ) ) - result = self._inner_optimizer.minimize( + result = self._inner_opt.minimize( loss, startup_program, parameters, no_grad_set ) @@ -192,19 +209,92 @@ class DygraphShardingOptimizer: def step(self): # TODO Check whether the model trainable param changed and update state accordingly - # actually updating - self._inner_optimizer.step() + # hack to grad_clip all parameters, + # otherwise the self._inner_opt will only grad_clip the self._rank2params[self._sharding_rank] params + # TODO(pangengzheng): remove the hacked grad_clip codes here when there is no diff in calculating global norm values in HybridParallelClipGrad compared to dp. + origin_clip = self._inner_opt._grad_clip + if not isinstance(self._parameter_list[0], dict): + params_grads = [] + for param in self._parameter_list: + if ( + hasattr(param, "regularizer") + and param.regularizer is not None + ): + raise ValueError( + "param {} should not has the regularizer attribute".format( + param.name + ) + ) + if param.stop_gradient: + continue + grad_var = param._grad_ivar() + if hasattr(param, "main_grad") and param.main_grad is not None: + grad_var = param.main_grad + params_grads.append((param, grad_var)) + if hasattr(self._inner_opt._grad_clip, 'not_sharding_stage1'): + self._inner_opt._grad_clip.not_sharding_stage1 = False + params_grads = self._inner_opt._grad_clip(params_grads) + # set inner_opt._grad_clip None to avoid repeatedly grad_clip gradients inside inner_opt._apply_optimize + self._set_inner_opt_attr('_grad_clip', None) + update_param_names = [ + p.name for p in self._rank2params[self._sharding_rank] + ] + update_params_grads = [ + (p, g) for p, g in params_grads if p.name in update_param_names + ] + self._apply_optimize( + loss=None, + startup_program=None, + params_grads=update_params_grads, + ) + # restore the grad clip + self._set_inner_opt_attr('_grad_clip', origin_clip) # sync parameters across sharding ranks self._sharding_sync_parameters() - # TODO is it a good way to make _grad_clip a property - @property - def _grad_clip(self): - assert ( - self._inner_optimizer is not None - ), "inner opt of sharding is not initiliazed." 
- return self._inner_optimizer._grad_clip + @framework.dygraph_only + def set_state_dict(self, state_dict): + inner_state = {} + parameters = self._rank2params[self._sharding_rank] + + if "LR_Scheduler" in state_dict: + inner_state["LR_Scheduler"] = state_dict.pop("LR_Scheduler") + + if "master_weights" in state_dict: + master = state_dict.pop("master_weights") + inner_state["master_weights"] = {} + for p in parameters: + for k, v in master.items(): + if p.name == k: + v.name = self._inner_opt._gen_master_weight_var_name(p) + inner_state["master_weights"][k] = v + + for p in parameters: + for k, v in state_dict.items(): + if p.name in k: + inner_state[k] = v + + self._inner_opt.set_state_dict(inner_state) + + def _set_inner_opt_attr(self, attr_name, value): + inner_opt = self._inner_opt + inner_opt_name = '_inner_opt' + if not isinstance(attr_name, str): + raise TypeError( + "attr_name should be str type, but is {}".format( + type(attr_name) + ) + ) + while hasattr(inner_opt, attr_name): + setattr(inner_opt, attr_name, value) + if ( + hasattr(inner_opt, inner_opt_name) + and getattr(inner_opt, inner_opt_name, None) is not None + ): + inner_opt = getattr(inner_opt, inner_opt_name, None) + else: + break def __getattr__(self, item): - return getattr(self._inner_optimizer, item) + return getattr(self._inner_opt, item) diff --git a/python/paddle/distributed/fleet/meta_optimizers/dygraph_optimizer/hybrid_parallel_optimizer.py b/python/paddle/distributed/fleet/meta_optimizers/dygraph_optimizer/hybrid_parallel_optimizer.py index bb24209aa1e28d614a8c9893ef97fe4d43c82d04..d8bf0510712debb82039b9003e958921dfb9c3f8 100755 --- a/python/paddle/distributed/fleet/meta_optimizers/dygraph_optimizer/hybrid_parallel_optimizer.py +++ b/python/paddle/distributed/fleet/meta_optimizers/dygraph_optimizer/hybrid_parallel_optimizer.py @@ -18,13 +18,19 @@ import paddle from paddle import framework from paddle.autograd import no_grad from paddle.distributed import fleet +from paddle.distributed.fleet.meta_optimizers.dygraph_optimizer.dygraph_sharding_optimizer import ( + DygraphShardingOptimizer, +) +from paddle.distributed.fleet.utils.hybrid_parallel_util import ( + obtain_optimizer_parameters_list, +) from paddle.framework import core from paddle.nn import ClipGradByGlobalNorm, clip from ...base.topology import ParallelMode from ...utils.hybrid_parallel_util import ( fused_allreduce_gradients, - sharding_reduce_gradients, + unwrap_optimizer, ) from ...utils.log_util import logger from ...utils.mix_precision_utils import MixPrecisionOptimizer @@ -32,24 +38,11 @@ from ...utils.mix_precision_utils import MixPrecisionOptimizer __all__ = [] -def _obtain_optimizer_parameters_list(optimizer): - if getattr(optimizer, '_param_groups', None) and isinstance( - optimizer._param_groups[0], dict - ): - parameters_list = [] - for group in optimizer._param_groups: - for param in group['params']: - parameters_list.append(param) - else: - parameters_list = list(optimizer._parameter_list) - - return parameters_list - - class HybridParallelClipGrad: def __init__(self, clip, hcg): self._clip = clip self._hcg = hcg + self.not_sharding_stage1 = True @no_grad() def _dygraph_clip(self, params_grads): @@ -166,8 +159,15 @@ class HybridParallelClipGrad: # add all reduce to get global norm of distributed params_and_grads if self._hcg.get_model_parallel_world_size() > 1: + sharding_flag = False + if ( + self._hcg.get_sharding_parallel_world_size() > 1 + and self._hcg.get_data_parallel_world_size() == 1 + ): + sharding_flag = True 
paddle.distributed.all_reduce( - global_norm_var_dist, group=self._hcg.get_check_parallel_group() + global_norm_var_dist, + group=self._hcg.get_check_parallel_group(sharding_flag), ) # add all reduce to get global norm of non-distributed params_and_grads in groups of pp @@ -179,7 +179,11 @@ class HybridParallelClipGrad: # In Sharding mode, param and grad is mapping different rank in optimizer. # ClipGradByGlobalNorm need allreduce to get globol norm - if self._hcg.get_sharding_parallel_world_size() > 1: + # TODO(pangengzheng): remove the self.not_sharding_stage1 flag when there is no diff in calculating global norm values in HybridParallelClipGrad compared to dp. + if ( + self._hcg.get_sharding_parallel_world_size() > 1 + and self.not_sharding_stage1 + ): paddle.distributed.all_reduce( global_norm_var_not_dist, group=self._hcg.get_sharding_parallel_group(), @@ -238,6 +242,10 @@ class HybridParallelClipGrad: class HybridParallelOptimizer: # adapter wrapper for optimizer def __init__(self, optimizer, hcg, strategy): + # Note: Only sharding stage 1 is considered in HybridParallelOptimizer. + # The sharding stage2 and stage3 optimizers are invoked in other api. + if hcg.get_sharding_parallel_world_size() > 1: + optimizer = DygraphShardingOptimizer(optimizer, hcg) self._inner_opt = optimizer self._strategy = strategy self._hcg = hcg @@ -263,15 +271,11 @@ class HybridParallelOptimizer: "or Sharding, the grad clip of original optimizer will be changed." ) - inner_opt = ( - self._inner_opt._inner_optimizer - if self._sharding_enable - else self._inner_opt + inner_opt = unwrap_optimizer( + self._inner_opt, + (MixPrecisionOptimizer, DygraphShardingOptimizer), ) - if isinstance(inner_opt, MixPrecisionOptimizer): - inner_opt = inner_opt._inner_opt - if ( inner_opt._parameter_list and not isinstance(inner_opt._parameter_list[0], dict) @@ -415,9 +419,10 @@ class HybridParallelOptimizer: @no_grad() @framework.dygraph_only def step(self): - parameters_list = _obtain_optimizer_parameters_list(self._inner_opt) + parameters_list = obtain_optimizer_parameters_list(self._inner_opt) if self._sharding_enable: - sharding_reduce_gradients(list(parameters_list), self._hcg) + assert isinstance(self._inner_opt, DygraphShardingOptimizer) + self._inner_opt.reduce_gradients(list(parameters_list), self._hcg) if self._dp_enable: fused_allreduce_gradients(list(parameters_list), self._hcg) @@ -433,12 +438,13 @@ class HybridParallelOptimizer: parameter_list = ( parameters if parameters - else _obtain_optimizer_parameters_list(self._inner_opt) + else obtain_optimizer_parameters_list(self._inner_opt) ) # Here sharding should use global parameter list if self._sharding_enable: - sharding_reduce_gradients(list(parameter_list), self._hcg) + assert isinstance(self._inner_opt, DygraphShardingOptimizer) + self._inner_opt.reduce_gradients(list(parameter_list), self._hcg) if self._dp_enable: fused_allreduce_gradients(list(parameter_list), self._hcg) diff --git a/python/paddle/distributed/fleet/meta_parallel/pp_utils/utils.py b/python/paddle/distributed/fleet/meta_parallel/pp_utils/utils.py index c7252de68a6fdaf10fd47a83a0ce6685183bd759..760fbc1d72af1b49c7b4f4a3f880f73796ff6ed2 100644 --- a/python/paddle/distributed/fleet/meta_parallel/pp_utils/utils.py +++ b/python/paddle/distributed/fleet/meta_parallel/pp_utils/utils.py @@ -246,17 +246,21 @@ class FusedCommBuffer: def _comm_grads(self): assert self._all_params_checked_in - if self._act == HOOK_ACTION.ALL_REDUCE: - task = paddle.distributed.all_reduce( - self.grad_storage, 
group=self._comm_group, sync_op=False - ) - elif self._act == HOOK_ACTION.REDUCE: - task = paddle.distributed.reduce( - self.grad_storage, - dst=self._dst, - group=self._comm_group, - sync_op=False, - ) + # Note: after sharding change to reduce operation here also need to be updated + # if self._act == HOOK_ACTION.ALL_REDUCE: + # task = paddle.distributed.all_reduce( + # self.grad_storage, group=self._comm_group, sync_op=False + # ) + # elif self._act == HOOK_ACTION.REDUCE: + # task = paddle.distributed.reduce( + # self.grad_storage, + # dst=self._dst, + # group=self._comm_group, + # sync_op=False, + # ) + task = paddle.distributed.all_reduce( + self.grad_storage, group=self._comm_group, sync_op=False + ) self._task = task @imperative_base.no_grad diff --git a/python/paddle/distributed/fleet/utils/hybrid_parallel_util.py b/python/paddle/distributed/fleet/utils/hybrid_parallel_util.py index 647f7f92673258ef37ca758920e31851c0dff7aa..340ace6ed7b800db99d26e4fb1bf8ee05c9785b1 100644 --- a/python/paddle/distributed/fleet/utils/hybrid_parallel_util.py +++ b/python/paddle/distributed/fleet/utils/hybrid_parallel_util.py @@ -29,6 +29,20 @@ from .log_util import logger __all__ = [] +def obtain_optimizer_parameters_list(optimizer): + if getattr(optimizer, '_param_groups', None) and isinstance( + optimizer._param_groups[0], dict + ): + parameters_list = [] + for group in optimizer._param_groups: + for param in group['params']: + parameters_list.append(param) + else: + parameters_list = list(optimizer._parameter_list) + + return parameters_list + + def _apply_collective_grads(parameters, comm_group, bucket_size, scale=None): grad_var_set = set() grad_vars = [] @@ -230,30 +244,6 @@ def fused_allreduce_gradients(parameter_list, hcg): fused_allreduce_gradients_with_group(parameter_list, data_parallel_group) -def sharding_reduce_gradients(parameter_list, hcg): - # TODO allreduce --> reduce - # TODO merge grad / nrank with dp - logger.debug("sharding start gradients sync") - with framework.no_grad(): - sharding_nrank = hcg.get_sharding_parallel_group().nranks - for param in parameter_list: - g_var = None - if param.trainable and (param._grad_ivar() is not None): - g_var = param._grad_ivar() - if param.trainable and hasattr(param, "main_grad"): - assert ( - param._grad_ivar() is None - ), "param.grad should be None when using main_grad" - g_var = param.main_grad - if g_var is not None: - g_var.scale_(1.0 / sharding_nrank) - paddle.distributed.all_reduce( - g_var, - group=hcg.get_sharding_parallel_group(), - sync_op=True, - ) - - def broadcast_sharding_parameters(model, hcg): # TODO TO save memory, use un-fused broadcast to avoid potentional OOM logger.debug("sharding start init parameters sync") @@ -262,3 +252,10 @@ def broadcast_sharding_parameters(model, hcg): sync_params_buffers( model, sharding_parallel_group, src_rank, is_model_parallel=False ) + + +def unwrap_optimizer(optimizer, optimizer_instances=()): + _inner_opt = optimizer + while isinstance(_inner_opt, optimizer_instances): + _inner_opt = _inner_opt._inner_opt + return _inner_opt diff --git a/python/paddle/distributed/fleet/utils/mix_precision_utils.py b/python/paddle/distributed/fleet/utils/mix_precision_utils.py index 2d84d39e8ffee03341d034c6b687b309695eb433..1e25879b820783231623efd70a13b0e69d0a7ae4 100644 --- a/python/paddle/distributed/fleet/utils/mix_precision_utils.py +++ b/python/paddle/distributed/fleet/utils/mix_precision_utils.py @@ -21,6 +21,9 @@ import numpy as np import paddle from paddle import _legacy_C_ops, nn from 
paddle.distributed import fleet +from paddle.distributed.fleet.utils.hybrid_parallel_util import ( + obtain_optimizer_parameters_list, +) from paddle.fluid import framework from paddle.fluid.dygraph import base as imperative_base from paddle.fluid.dygraph import to_variable @@ -93,20 +96,7 @@ class MixPrecisionLayer(nn.Layer): class MixPrecisionOptimizer: def __init__(self, optimizer): self._inner_opt = optimizer - self._parameter_list = self._obtain_optimizer_parameters_list() - - def _obtain_optimizer_parameters_list(self): - if getattr(self._inner_opt, '_param_groups', None) and isinstance( - self._inner_opt._param_groups[0], dict - ): - parameters_list = [] - for group in self._inner_opt._param_groups: - for param in group['params']: - parameters_list.append(param) - else: - parameters_list = list(self._inner_opt._parameter_list) - - return parameters_list + self._parameter_list = obtain_optimizer_parameters_list(optimizer) @imperative_base.no_grad @framework.dygraph_only diff --git a/python/paddle/optimizer/optimizer.py b/python/paddle/optimizer/optimizer.py index 12fcd90c67dcea1b3c7fd0d030b0b0e6a6b74e71..5becbc8cec22c71e98d297841326cb301c53a986 100644 --- a/python/paddle/optimizer/optimizer.py +++ b/python/paddle/optimizer/optimizer.py @@ -700,8 +700,7 @@ class Optimizer: else: assert isinstance(self.helper, LayerHelper) - var_name = param.name + "_fp32_master" - var_name = unique_name.generate(var_name) + var_name = self._gen_master_weight_var_name(param) var = paddle.static.create_global_var( name=var_name, shape=param.shape, @@ -722,6 +721,10 @@ class Optimizer: self._master_weights[param.name] = var return var + def _gen_master_weight_var_name(self, param): + var_name = param.name + "_fp32_master" + return unique_name.generate(var_name) + def _create_master_grad(self, grad): assert self._is_dtype_fp16_or_bf16(grad.dtype) if grad.name in self._master_grads: diff --git a/test/collective/fleet/hybrid_parallel_sharding_model.py b/test/collective/fleet/hybrid_parallel_sharding_model.py index 435876002db26d589c985859bb5bd25a4b5c35b8..bb1ace6a6a47653f272174dc456cdea3e06bb768 100644 --- a/test/collective/fleet/hybrid_parallel_sharding_model.py +++ b/test/collective/fleet/hybrid_parallel_sharding_model.py @@ -23,6 +23,10 @@ from paddle.distributed import fleet from paddle.distributed.fleet.meta_optimizers.dygraph_optimizer.dygraph_sharding_optimizer import ( DygraphShardingOptimizer, ) +from paddle.distributed.fleet.utils.mix_precision_utils import ( + MixPrecisionLayer, + MixPrecisionOptimizer, +) vocab_size = 20 hidden_size = 10 @@ -210,47 +214,24 @@ class TestDistMPTraning(unittest.TestCase): optimizer.clear_grad() return loss - def build_optimizer( - self, model, strategy=None, is_sharding=True, Optimizer="adam" - ): + def build_optimizer(self, model, strategy=None, Optimizer="adam"): clip = paddle.nn.ClipGradByGlobalNorm(0.5) if Optimizer == "adam": - if is_sharding: - optimizer = DygraphShardingOptimizer( - hcg=fleet.get_hybrid_communicate_group(), - user_defined_strategy=strategy, - params=model.parameters(), - inner_optimizer_class=paddle.optimizer.AdamW, - learning_rate=0.001, - weight_decay=0.00001, - grad_clip=clip, - ) - else: - optimizer = paddle.optimizer.AdamW( - parameters=model.parameters(), - learning_rate=0.001, - weight_decay=0.00001, - grad_clip=clip, - ) + optimizer = paddle.optimizer.AdamW( + parameters=model.parameters(), + learning_rate=0.001, + weight_decay=0.00001, + grad_clip=clip, + ) else: - if is_sharding: - optimizer = DygraphShardingOptimizer( - 
hcg=fleet.get_hybrid_communicate_group(), - user_defined_strategy=strategy, - params=model.parameters(), - inner_optimizer_class=paddle.optimizer.Momentum, - learning_rate=0.001, - grad_clip=clip, - ) - else: - optimizer = paddle.optimizer.Momentum( - learning_rate=0.001, - parameters=model.parameters(), - grad_clip=clip, - ) + optimizer = paddle.optimizer.Momentum( + learning_rate=0.001, + parameters=model.parameters(), + grad_clip=clip, + ) return optimizer - def build_model_optimizer(self, Optimizer="adam"): + def build_model_optimizer(self, Optimizer="adam", amp_level=None): hcg = fleet.get_hybrid_communicate_group() word_size = hcg.get_model_parallel_world_size() sharding_id = hcg.get_sharding_parallel_rank() @@ -266,11 +247,8 @@ class TestDistMPTraning(unittest.TestCase): optimizer_a = self.build_optimizer( model_a, strategy=self.strategy, - is_sharding=True, Optimizer=Optimizer, ) - model_a = fleet.distributed_model(model_a) - optimizer_a = fleet.distributed_optimizer(optimizer_a) model_b = SimpleDPNet( vocab_size, hidden_size, inner_size, output_size, np_fc1, np_fc2 @@ -278,15 +256,23 @@ class TestDistMPTraning(unittest.TestCase): optimizer_b = self.build_optimizer( model_b, strategy=self.strategy, - is_sharding=False, Optimizer=Optimizer, ) + if amp_level is not None and amp_level == "O2": + model_a = MixPrecisionLayer(model_a) + optimizer_a = MixPrecisionOptimizer(optimizer_a) + model_b = MixPrecisionLayer(model_b) + optimizer_b = MixPrecisionOptimizer(optimizer_b) + + model_a = fleet.distributed_model(model_a) + optimizer_a = fleet.distributed_optimizer(optimizer_a) + return model_a, optimizer_a, model_b, optimizer_b - def sharding_model(self, Optimizer, sharded_accumulators): + def sharding_model(self, Optimizer, sharded_accumulators, amp_level=None): model_a, optimizer_a, model_b, optimizer_b = self.build_model_optimizer( - Optimizer=Optimizer + Optimizer=Optimizer, amp_level=amp_level ) self.assertTrue( @@ -296,9 +282,7 @@ class TestDistMPTraning(unittest.TestCase): for idx in range(STEPS): if idx == 2 and paddle.distributed.get_rank() == 0: self.assertTrue( - set( - optimizer_a._inner_opt._inner_optimizer.state_dict().keys() - ) + set(optimizer_a._inner_opt._inner_opt.state_dict().keys()) == sharded_accumulators ) @@ -352,6 +336,19 @@ class TestDistMPTraning(unittest.TestCase): Optimizer="Momentum", sharded_accumulators=sharded_accumulators ) + def test_sharding_momentum_amp(self): + sharded_accumulators = { + 'linear_12.w_0_velocity_0', + 'linear_13.b_0_velocity_0', + 'linear_14.b_0_velocity_0', + 'embedding_4.w_0_velocity_0', + } + self.sharding_model( + Optimizer="Momentum", + sharded_accumulators=sharded_accumulators, + amp_level="O2", + ) + if __name__ == "__main__": unittest.main() diff --git a/test/collective/fleet/hybrid_parallel_sharding_state_dict.py b/test/collective/fleet/hybrid_parallel_sharding_state_dict.py new file mode 100644 index 0000000000000000000000000000000000000000..b9a5f55dc188ab40d1217a2690f5d2718eb37ed8 --- /dev/null +++ b/test/collective/fleet/hybrid_parallel_sharding_state_dict.py @@ -0,0 +1,276 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import copy +import random +import unittest + +import numpy as np + +import paddle +from paddle.distributed import fleet +from paddle.distributed.fleet.utils.mix_precision_utils import ( + MixPrecisionOptimizer, +) + +vocab_size = 20 +hidden_size = 10 +inner_size = 8 +output_size = 10 +seq_length = 2 +batch_size = 4 +STEPS = 10 + + +class SimpleDPNet(paddle.nn.Layer): + def __init__( + self, vocab_size, hidden_size, inner_size, output_size, np_fc1, np_fc2 + ): + super().__init__() + self.linear1 = paddle.nn.Linear( + hidden_size, + inner_size, + weight_attr=paddle.framework.ParamAttr( + initializer=paddle.nn.initializer.Assign(np_fc1) + ), + bias_attr=paddle.framework.ParamAttr( + initializer=paddle.nn.initializer.Constant(0.0) + ), + ) + + self.linear2 = paddle.nn.Linear( + inner_size, + hidden_size, + weight_attr=paddle.framework.ParamAttr( + initializer=paddle.nn.initializer.Assign(np_fc2) + ), + bias_attr=paddle.framework.ParamAttr( + initializer=paddle.nn.initializer.Constant(0.0) + ), + ) + + self.linear3 = paddle.nn.Linear( + hidden_size, + output_size, + weight_attr=paddle.framework.ParamAttr( + initializer=paddle.nn.initializer.Constant(0.0) + ), + bias_attr=paddle.framework.ParamAttr( + initializer=paddle.nn.initializer.Constant(0.0) + ), + ) + + self.embedding = paddle.nn.Embedding( + vocab_size, + hidden_size, + weight_attr=paddle.nn.initializer.Constant(value=0.5), + ) + + def forward(self, x): + x = self.embedding(x) + x = self.linear1(x) + x = self.linear2(x) + x = self.linear3(x) + x = paddle.matmul(x, self.embedding.weight, transpose_y=True) + return x + + +class TestDistShardingTraining(unittest.TestCase): + def setUp(self): + random.seed(2021) + np.random.seed(2021) + paddle.seed(2021) + + self.strategy = fleet.DistributedStrategy() + self.strategy.hybrid_configs = { + "sharding_degree": 2, + "dp_degree": 1, + "mp_degree": 1, + "pp_degree": 1, + } + fleet.init(is_collective=True, strategy=self.strategy) + self.data = [ + np.random.randint( + 0, + vocab_size, + ( + batch_size, + seq_length, + ), + ) + for _ in range(STEPS) + ] + + def build_adam_optimizer(self, model, lr=0.001): + clip = paddle.nn.ClipGradByGlobalNorm(0.5) + optimizer = paddle.optimizer.AdamW( + parameters=model.parameters(), + learning_rate=lr, + weight_decay=0.00001, + grad_clip=clip, + ) + return optimizer + + def test_set_state_dict(self): + np_fc1 = np.random.random_sample((hidden_size, inner_size)) + np_fc2 = np.random.random_sample((inner_size, hidden_size)) + + model = SimpleDPNet( + vocab_size, hidden_size, inner_size, output_size, np_fc1, np_fc2 + ) + init_lr = 0.001 + init_lr_scheduler = paddle.optimizer.lr.CosineAnnealingDecay( + learning_rate=init_lr, T_max=1 + ) + local_optimizer = self.build_adam_optimizer(model, init_lr_scheduler) + dist_optimizer = fleet.distributed_optimizer(local_optimizer) + # prepare state_dict + state_dict = {} + # lr_scheduler + base_lr = 0.1 + lr_scheduler = paddle.optimizer.lr.CosineAnnealingDecay( + learning_rate=base_lr, T_max=1 + ) + state_dict["LR_Scheduler"] = lr_scheduler.state_dict() + # master_weights and accumulators + 
state_dict["master_weights"] = {} + all_param_names = [] + accumulator_names = ["moment1", "moment2"] + # + local_params = dist_optimizer._rank2params[ + dist_optimizer._sharding_rank + ] + local_param_names = [p.name for p in local_params] + local_acc_names = [] + other_acc_names = [] + for p in model.parameters(): + var_name = dist_optimizer._gen_master_weight_var_name(p) + var = paddle.static.create_global_var( + name=var_name, + shape=p.shape, + value=0, + dtype='float32', + persistable=True, + ) + var = paddle.randn(shape=var.shape, dtype=var.dtype, name=var.name) + state_dict["master_weights"][p.name] = var + # accumulator + for name in accumulator_names: + acc_name = p.name + '_' + name + state_dict[acc_name] = paddle.randn( + shape=var.shape, dtype=var.dtype, name=acc_name + ) + if p.name in local_param_names: + local_acc_names.append(acc_name) + else: + other_acc_names.append(acc_name) + all_param_names.append(p.name) + # test api + tmp_state_dict = copy.deepcopy(state_dict) + dist_optimizer.set_state_dict(state_dict) + # check result + other_param_names = [ + p_name + for p_name in all_param_names + if p_name not in local_param_names + ] + inner_opt = dist_optimizer._inner_opt + self.assertEqual(inner_opt._learning_rate.last_lr, base_lr) + assert hasattr(inner_opt, "_master_weights") + for p_name, weight in inner_opt._master_weights.items(): + assert p_name in local_param_names + assert p_name not in other_param_names + assert p_name in tmp_state_dict["master_weights"] + np.testing.assert_array_almost_equal( + weight.numpy(), tmp_state_dict["master_weights"][p_name].numpy() + ) + for acc_name, val in inner_opt._accumulators_holder.items(): + assert acc_name in local_acc_names + assert acc_name not in other_acc_names + assert acc_name in tmp_state_dict + np.testing.assert_array_almost_equal( + val.numpy(), tmp_state_dict[acc_name].numpy() + ) + + def test_clear_grad(self): + np_fc1 = np.random.random_sample((hidden_size, inner_size)) + np_fc2 = np.random.random_sample((inner_size, hidden_size)) + + model = SimpleDPNet( + vocab_size, hidden_size, inner_size, output_size, np_fc1, np_fc2 + ) + + local_optimizer = self.build_adam_optimizer(model) + dist_optimizer = fleet.distributed_optimizer(local_optimizer) + + tmp_parameter_list = [] + for p in dist_optimizer._inner_opt._parameter_list: + main_grad = paddle.randn(shape=p.shape, dtype=p.dtype, name=p.name) + p.main_grad = main_grad + tmp_parameter_list.append(p) + + assert hasattr( + dist_optimizer._inner_opt._parameter_list[0], "main_grad" + ) + # test set_to_zero True + dist_optimizer._inner_opt.clear_grad(set_to_zero=True) + for p in dist_optimizer._inner_opt._parameter_list: + np.testing.assert_array_almost_equal( + p.main_grad.numpy(), np.zeros(p.main_grad.numpy().shape) + ) + # test set_to_zero False + dist_optimizer._inner_opt.clear_grad(set_to_zero=False) + for p in dist_optimizer._inner_opt._parameter_list: + self.assertTrue(p.main_grad is None) + + def test_set_inner_opt_attr(self): + np_fc1 = np.random.random_sample((hidden_size, inner_size)) + np_fc2 = np.random.random_sample((inner_size, hidden_size)) + + model = SimpleDPNet( + vocab_size, hidden_size, inner_size, output_size, np_fc1, np_fc2 + ) + + local_optimizer = self.build_adam_optimizer(model) + local_optimizer = MixPrecisionOptimizer(local_optimizer) + dist_optimizer = fleet.distributed_optimizer(local_optimizer) + sharding_opt = dist_optimizer._inner_opt + sharding_opt._set_inner_opt_attr('_parameter_list', 123) + self.assertTrue(hasattr(sharding_opt._inner_opt, 
'_parameter_list')) + self.assertTrue( + hasattr(sharding_opt._inner_opt._inner_opt, '_parameter_list') + ) + self.assertEqual(sharding_opt._inner_opt._parameter_list, 123) + self.assertEqual( + sharding_opt._inner_opt._inner_opt._parameter_list, 123 + ) + + sharding_opt._set_inner_opt_attr('_param_groups', 123) + self.assertTrue(hasattr(sharding_opt._inner_opt, '_param_groups')) + self.assertTrue( + hasattr(sharding_opt._inner_opt._inner_opt, '_param_groups') + ) + self.assertEqual(sharding_opt._inner_opt._param_groups, 123) + self.assertEqual(sharding_opt._inner_opt._inner_opt._param_groups, 123) + + # test bad case + try: + sharding_opt._set_inner_opt_attr(123, 123) + self.assertTrue(False) + except: + pass + + +if __name__ == "__main__": + unittest.main() diff --git a/test/collective/fleet/test_parallel_dygraph_sharding_parallel.py b/test/collective/fleet/test_parallel_dygraph_sharding_parallel.py index e2a552087bde6cf603fe01bfe188905a6391dfa6..857093ee7b44c31f6b03ad95053d764a04d4dd8d 100644 --- a/test/collective/fleet/test_parallel_dygraph_sharding_parallel.py +++ b/test/collective/fleet/test_parallel_dygraph_sharding_parallel.py @@ -22,6 +22,9 @@ class TestHybridParallel(TestMultipleGpus): def test_hybrid_parallel_sharding_logic(self): self.run_mnist_2gpu('hybrid_parallel_sharding_model.py') + def test_hybrid_parallel_sharding_state_dict(self): + self.run_mnist_2gpu('hybrid_parallel_sharding_state_dict.py') + if __name__ == "__main__": unittest.main()
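
The refactor above changes how sharding stage 1 is driven: `DygraphShardingOptimizer` is no longer constructed by the user but is created inside `HybridParallelOptimizer` whenever the sharding parallel world size is greater than 1. Pieced together from the updated `hybrid_parallel_sharding_model.py` test, the end-to-end usage now looks roughly like the sketch below. This is a minimal illustration only; the `Linear` placeholder model, the random input, and the 2-GPU collective launch are assumptions mirroring the test setup, not code from this patch.

```python
import paddle
from paddle.distributed import fleet

# Hybrid config with sharding only (2 cards), mirroring the test setup.
strategy = fleet.DistributedStrategy()
strategy.hybrid_configs = {
    "sharding_degree": 2,
    "dp_degree": 1,
    "mp_degree": 1,
    "pp_degree": 1,
}
fleet.init(is_collective=True, strategy=strategy)

model = paddle.nn.Linear(10, 10)  # placeholder model, for illustration only
clip = paddle.nn.ClipGradByGlobalNorm(0.5)
optimizer = paddle.optimizer.AdamW(
    parameters=model.parameters(),
    learning_rate=0.001,
    weight_decay=0.00001,
    grad_clip=clip,
)

# fleet.distributed_optimizer returns a HybridParallelOptimizer whose _inner_opt is a
# DygraphShardingOptimizer wrapping the plain AdamW; the old manual construction with
# hcg / user_defined_strategy / inner_optimizer_class is gone.
model = fleet.distributed_model(model)
optimizer = fleet.distributed_optimizer(optimizer)

# One step: reduce_gradients, the temporary full-parameter grad clip and the sharded
# update all run inside optimizer.step(); parameters are then re-synced across
# sharding ranks by _sharding_sync_parameters.
loss = model(paddle.randn([4, 10])).mean()
loss.backward()
optimizer.step()
optimizer.clear_grad()
```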
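
With AMP O2 the test wraps the model and optimizer in the mix-precision helpers before handing them to fleet, so the final chain becomes `HybridParallelOptimizer -> DygraphShardingOptimizer -> MixPrecisionOptimizer -> AdamW`, and the new `unwrap_optimizer` helper peels that chain back to the base optimizer. A self-contained sketch under the same 2-card, `sharding_degree=2` assumptions as above (the `Linear` model and the final assertion are illustrative, not from this patch):

```python
import paddle
from paddle.distributed import fleet
from paddle.distributed.fleet.meta_optimizers.dygraph_optimizer.dygraph_sharding_optimizer import (
    DygraphShardingOptimizer,
)
from paddle.distributed.fleet.utils.hybrid_parallel_util import unwrap_optimizer
from paddle.distributed.fleet.utils.mix_precision_utils import (
    MixPrecisionLayer,
    MixPrecisionOptimizer,
)

strategy = fleet.DistributedStrategy()
strategy.hybrid_configs = {
    "sharding_degree": 2,
    "dp_degree": 1,
    "mp_degree": 1,
    "pp_degree": 1,
}
fleet.init(is_collective=True, strategy=strategy)

model = paddle.nn.Linear(10, 10)  # placeholder model, for illustration only
optimizer = paddle.optimizer.AdamW(
    parameters=model.parameters(), learning_rate=0.001
)

# AMP O2 path from the test: apply the mix-precision wrappers first ...
model = MixPrecisionLayer(model)
optimizer = MixPrecisionOptimizer(optimizer)
# ... then let fleet stack HybridParallelOptimizer + DygraphShardingOptimizer on top.
model = fleet.distributed_model(model)
optimizer = fleet.distributed_optimizer(optimizer)

# Peel the wrappers back to the underlying AdamW, the same way
# hybrid_parallel_optimizer.py now locates the real parameter list and grad clip.
base_opt = unwrap_optimizer(
    optimizer._inner_opt, (MixPrecisionOptimizer, DygraphShardingOptimizer)
)
assert isinstance(base_opt, paddle.optimizer.AdamW)
```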