# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os

import paddle
from paddle.fluid import core, unique_name
from paddle.fluid.executor import global_scope
from paddle.fluid.framework import Variable, name_scope
from paddle.fluid.layer_helper import LayerHelper
from paddle.nn import ClipGradByGlobalNorm
from paddle.optimizer import Optimizer


def init_communicator(block, rank, ranks, ring_id):
    """Append ops to ``block`` that create and exercise one collective ring.

    The endpoint list is read from the comma-separated
    ``PADDLE_TRAINER_ENDPOINTS`` environment variable. A backend-specific
    ``c_gen_*_id`` op produces the communicator id, ``c_comm_init``
    initialises the ring, and a dummy ``c_allreduce_sum`` followed by
    ``c_sync_calc_stream`` runs one collective on it.

    Args:
        block: Program block to append the ops to (the startup block in
            this module).
        rank (int): Global rank of this process. It indexes the endpoint
            list and must be an element of ``ranks``.
        ranks (list[int]): Global ranks participating in the ring.
        ring_id (int): Identifier to assign to the ring.

    Returns:
        int: ``ring_id``, unchanged.

    Raises:
        KeyError: If ``PADDLE_TRAINER_ENDPOINTS`` is not set.
        ValueError: If ``rank`` is not contained in ``ranks``.
    """
    eps = os.environ['PADDLE_TRAINER_ENDPOINTS']
    eps = [ep.strip() for ep in eps.split(",") if ep.strip()]
    cur_ep = eps[rank]
    other_eps = [eps[r] for r in ranks if r != rank]
    # Position of this process inside the ring, not its global rank.
    local_rank = ranks.index(rank)

    comm_var_name = unique_name.generate('comm_id')
    comm_id_var = block.create_var(
        name=comm_var_name, persistable=True, type=core.VarDesc.VarType.RAW
    )

    # The id-generation op depends on the device backend Paddle was built for;
    # the branches are checked in the same priority order as before.
    if core.is_compiled_with_cuda():
        gen_id_op_type = 'c_gen_nccl_id'
    elif core.is_compiled_with_xpu():
        gen_id_op_type = 'c_gen_bkcl_id'
    elif (
        paddle.distributed.ParallelEnv().device_type
        in paddle.device.get_all_custom_device_type()
    ):
        gen_id_op_type = 'c_gen_xccl_id'
    else:
        # NOTE(review): in this case no op writes ``comm_id_var`` before
        # ``c_comm_init`` reads it -- confirm this configuration cannot occur.
        gen_id_op_type = None

    if gen_id_op_type is not None:
        block.append_op(
            type=gen_id_op_type,
            inputs={},
            outputs={'Out': comm_id_var},
            attrs={
                'rank': local_rank,
                'endpoint': cur_ep,
                'other_endpoints': other_eps,
                'ring_id': ring_id,
            },
        )

    block.append_op(
        type='c_comm_init',
        inputs={'X': comm_id_var},
        outputs={},
        attrs={'nranks': len(ranks), 'rank': local_rank, 'ring_id': ring_id},
    )

    # Run one dummy all-reduce on the new ring and synchronise the calc
    # stream -- presumably a warm-up/sanity check of the ring.
    tmp_var = block.create_var(name=unique_name.generate('tmp'))
    block.append_op(
        type='fill_constant', outputs={'Out': tmp_var}, attrs={'value': 1}
    )
    block.append_op(
        type='c_allreduce_sum',
        inputs={'X': tmp_var},
        outputs={'Out': tmp_var},
        attrs={'ring_id': ring_id, 'use_calc_stream': True},
    )
    block.append_op(
        type='c_sync_calc_stream',
        inputs={'X': tmp_var},
        outputs={'Out': tmp_var},
    )
    return ring_id


def broadcast_parameters(block, parameters, ring_id):
    """Append one in-place ``c_broadcast`` op per parameter to ``block``.

    Args:
        block: Program block to append the ops to.
        parameters: Iterable of parameter variables; each one is both the
            input and the output of its broadcast op.
        ring_id (int): Ring the broadcasts run on.
    """
    for p in parameters:
        block.append_op(
            type='c_broadcast',
            inputs={'X': p},
            outputs={'Out': p},
            attrs={'ring_id': ring_id, 'use_calc_stream': True},
        )


class DistributedFusedLamb(Optimizer):
    """LAMB optimizer implemented by one fused static-graph op.

    Only static-graph mode is supported. A ``distributed_fused_lamb_init``
    op in the startup program builds flat FP32/FP16 parameter and gradient
    buffers plus the optimizer state. Each training step is then a single
    ``distributed_fused_lamb`` op in the main program.

    Args:
        learning_rate: Learning rate, forwarded to the base ``Optimizer``.
        lamb_weight_decay (float|None): Weight-decay coefficient; ``None`` is
            treated as ``0.0``.
        beta1 (float): Exponential decay rate of the first moment.
        beta2 (float): Exponential decay rate of the second moment.
        epsilon (float): Numerical-stability term.
        parameters: NOTE(review): accepted but not forwarded to the base
            class or otherwise used in this class -- confirm intent.
        grad_clip (ClipGradByGlobalNorm|None): Only global-norm clipping is
            supported. Its ``clip_norm`` becomes the fused op's
            ``max_global_grad_norm``; ``None`` maps to ``-1.0``.
        exclude_from_weight_decay_fn (callable|None): Called with each
            parameter; a truthy result disables weight decay for it.
        clip_after_allreduce (bool): Forwarded as an op attribute.
        is_grad_scaled_by_nranks (bool): Forwarded as an op attribute.
        alignment (int|None): Fusion alignment for the init op; ``None``
            maps to ``-1``.
        use_master_param_norm (bool): Forwarded as an op attribute.
        gradient_accumulation_steps (int): Must be ``>= 1``. Values ``> 1``
            create accumulation buffers and a ``stop_update`` flag.
        use_master_acc_grad (bool): Forwarded as an op attribute.
        nproc_per_node (int|None): Processes per node; ``None`` means the
            whole world size.
        use_hierarchical_allreduce (bool): Only takes effect when there is
            more than one node.
        name (str|None): Forwarded to the base ``Optimizer``.
    """

    def __init__(
        self,
        learning_rate=0.001,
        lamb_weight_decay=0.01,
        beta1=0.9,
        beta2=0.999,
        epsilon=1e-6,
        parameters=None,
        grad_clip=None,
        exclude_from_weight_decay_fn=None,
        clip_after_allreduce=True,
        is_grad_scaled_by_nranks=True,
        alignment=128,
        use_master_param_norm=True,
        gradient_accumulation_steps=1,
        use_master_acc_grad=True,
        nproc_per_node=None,
        use_hierarchical_allreduce=False,
        name=None,
    ):
        assert (
            not paddle.in_dynamic_mode()
        ), "DistributedFusedLamb does not support dygraph mode"
        # Clipping is delegated to the fused op through the
        # ``max_global_grad_norm`` attribute, so the base class gets no clip.
        super().__init__(learning_rate=learning_rate, grad_clip=None, name=name)

        self._beta1 = beta1
        self._beta2 = beta2
        self._epsilon = epsilon
        self._weight_decay = (
            lamb_weight_decay if lamb_weight_decay is not None else 0.0
        )
        if grad_clip is not None:
            assert isinstance(
                grad_clip, ClipGradByGlobalNorm
            ), "Only ClipGradByGlobalNorm is supported in DistributedFusedLamb"
            max_global_grad_norm = grad_clip.clip_norm
        else:
            max_global_grad_norm = -1.0
        self._max_global_grad_norm = max_global_grad_norm
        self._alignment = alignment if alignment is not None else -1
        self._clip_after_allreduce = clip_after_allreduce
        self._is_grad_scaled_by_nranks = is_grad_scaled_by_nranks
        self._exclude_from_weight_decay_fn = exclude_from_weight_decay_fn
        self._scale = None
        self._use_master_param_norm = use_master_param_norm
        self._gradient_accumulation_steps = gradient_accumulation_steps
        self._use_master_acc_grad = use_master_acc_grad
        self._nproc_per_node = nproc_per_node
        self._use_hierarchical_allreduce = use_hierarchical_allreduce
        assert self._gradient_accumulation_steps >= 1

        self.helper = LayerHelper('distributed_fused_lamb')
        self._supports_check_nan_inf = True  # very important flag for AMP

        main_block = self.helper.main_program.global_block()
        # Written by the fused op's ``FoundInf`` output.
        self._found_inf = main_block.create_var(
            name=unique_name.generate('found_inf'),
            shape=[1],
            dtype=core.VarDesc.VarType.BOOL,
        )
        self._step = None

        if self._gradient_accumulation_steps > 1:
            # Written by the fused op's ``StopUpdate`` output.
            self._stop_update = main_block.create_var(
                name=unique_name.generate('stop_update'),
                shape=[1],
                dtype=core.VarDesc.VarType.BOOL,
            )
        else:
            self._stop_update = None

        # Parameter name -> name of its FP32 master-weight variable; filled
        # in by ``_apply_gradients_impl``.
        self._param_to_master_param = {}

    def _get_stop_update_var(self):
        """Return the ``stop_update`` variable, or ``False`` if none exists."""
        return self._stop_update if self._stop_update is not None else False

    def _set_step(self, step):
        """Use ``step`` as the step-counter variable."""
        self._step = step

    def _get_or_create_step(self):
        """Return the step-counter variable, creating an int64 one if unset."""
        if self._step is None:
            self._step = self._create_persistable_var('step', dtype='int64')
        return self._step

    def _set_scale(self, scale):
        """Set the global scale variable.

        Args:
            scale (Variable|float|int): A variable is used as-is; any other
                value is wrapped in a new float32 global variable.
        """
        assert scale is not None
        if not isinstance(scale, Variable):
            scale = self._create_scale_from_constant(scale)
        self._scale = scale

    def _create_scale_from_constant(self, value):
        """Create a persistable float32 global variable of shape ``[1]``.

        The variable is initialised to ``float(value)``.
        """
        name = unique_name.generate('global_scale')
        return paddle.static.create_global_var(
            name=name,
            shape=[1],
            dtype='float32',
            value=float(value),
            persistable=True,
        )

    def _get_or_create_scale(self):
        """Return the global scale variable, creating one equal to 1.0 if unset."""
        if self._scale is None:
            self._scale = self._create_scale_from_constant(1.0)
        return self._scale

    def _create_persistable_var(self, name=None, shape=None, dtype='float32'):
        """Declare a persistable, gradient-free variable in both programs.

        The variable is created in the startup program's global block. A
        variable with the same name, shape and dtype is then declared in the
        main program's global block.

        Args:
            name (str|None): Name prefix, made unique with
                ``unique_name.generate``. ``None`` leaves naming to
                ``create_var``.
            shape (list[int]|None): Shape of the variable; ``None`` means
                ``[-1]``. A ``None`` sentinel is used instead of a shared
                mutable default list.
            dtype (str): Data type of the variable.

        Returns:
            Variable: The main-program variable.
        """
        if shape is None:
            shape = [-1]
        startup_block = self.helper.startup_program.global_block()
        if name is not None:
            name = unique_name.generate(name)
        startup_var = startup_block.create_var(
            name=name,
            shape=shape,
            dtype=dtype,
            persistable=True,
            stop_gradient=True,
        )
        main_block = self.helper.main_program.global_block()
        main_var = main_block.create_var(
            name=startup_var.name,
            shape=startup_var.shape,
            dtype=startup_var.dtype,
            persistable=True,
            stop_gradient=True,
        )
        return main_var

    def _get_parameter(self, name, scope=None):
        """Look up a parameter tensor and its FP32 master copy in ``scope``.

        This only works after the optimizer has been applied, because the
        parameter-to-master mapping is filled in by
        ``_apply_gradients_impl``.

        Args:
            name (str): Parameter variable name.
            scope: Scope to search; defaults to ``global_scope()``.

        Returns:
            tuple: ``(param_t, None)`` when the parameter is FP32. In that
            case it must share storage with its master tensor. Otherwise
            ``(param_t, master_param_t)``, where ``param_t`` must be FP16
            with the same shape as the master tensor.
        """
        if scope is None:
            scope = global_scope()

        master_param = self._param_to_master_param.get(name)
        assert master_param is not None

        master_param_t = scope.find_var(master_param).get_tensor()
        assert master_param_t._dtype() == core.VarDesc.VarType.FP32

        param_t = scope.find_var(name).get_tensor()
        if param_t._dtype() == core.VarDesc.VarType.FP32:
            assert param_t._ptr() == master_param_t._ptr()
            return param_t, None
        else:
            assert param_t._dtype() == core.VarDesc.VarType.FP16
            assert param_t.shape() == master_param_t.shape()
            return param_t, master_param_t

    def apply_optimize(self, params_grads):
        """Same as :meth:`apply_gradients`; returns the appended ops."""
        return self.apply_gradients(params_grads)

    def apply_gradients(self, params_grads):
        """Append the fused LAMB update for ``params_grads`` to the program.

        The ops are created inside the program's ``_optimized_guard`` and the
        ``"optimizer"`` name scope.

        Args:
            params_grads (list[tuple]): ``(param, grad)`` pairs. Every
                gradient must be a dense ``LOD_TENSOR`` variable.

        Returns:
            list: The appended optimizer ops (``[lamb_op]``), matching the
            base ``Optimizer.apply_gradients`` contract.

        Raises:
            ValueError: If ``params_grads`` is empty.
        """
        flattened = []
        for p, g in params_grads:
            flattened.extend([p, g])
        if not flattened:
            raise ValueError(
                "DistributedFusedLamb.apply_gradients received no "
                "(param, grad) pairs"
            )
        with flattened[0].block.program._optimized_guard(
            flattened
        ), name_scope("optimizer"):
            return self._apply_gradients_impl(params_grads)

    def _init_communicators(
        self, startup_block, rank, nranks, nproc_per_node, shard_inside_node
    ):
        """Append the communicator setup ops this rank needs to ``startup_block``.

        Ring 0 covers all ranks and is created only when ``nranks > 1``.
        When there is more than one node and ``shard_inside_node`` holds,
        ring 1 covers the ranks of the current node. If hierarchical
        all-reduce was requested, one more ring covers the ranks that share
        this rank's intra-node index across nodes.

        Returns:
            tuple: ``(ring_ids, use_hierarchical_allreduce)``. ``ring_ids``
            holds the ids in creation order.
        """
        node_id = rank // nproc_per_node
        node_num = nranks // nproc_per_node

        ring_ids = []
        if nranks > 1:
            ring_ids.append(
                init_communicator(startup_block, rank, list(range(nranks)), 0)
            )

        use_hierarchical_allreduce = False
        if node_num > 1 and len(ring_ids) <= 1 and shard_inside_node:
            local_group_ranks = list(
                range(node_id * nproc_per_node, (node_id + 1) * nproc_per_node)
            )
            ring_ids.append(
                init_communicator(startup_block, rank, local_group_ranks, 1)
            )

            if self._use_hierarchical_allreduce and nranks > nproc_per_node:
                use_hierarchical_allreduce = True
                outer_group_ranks = list(
                    range(rank % nproc_per_node, nranks, nproc_per_node)
                )
                ring_ids.append(
                    init_communicator(
                        startup_block,
                        rank,
                        outer_group_ranks,
                        ring_ids[-1] + 1,
                    )
                )
        return ring_ids, use_hierarchical_allreduce

    def _apply_gradients_impl(self, params_grads):
        """Build the startup-program and main-program ops of the fused update.

        Returns:
            list: ``[lamb_op]``, the ``distributed_fused_lamb`` op appended to
            the main program's global block.
        """
        for p, g in params_grads:
            assert (
                g.type == core.VarDesc.VarType.LOD_TENSOR
            ), "Only support dense gradient"
            g.persistable = True  # the gradient must be persistable for fusion

        # Variable creation order is kept stable because it drives
        # ``unique_name`` generation.
        fp32_fused_param = self._create_persistable_var('fp32_fused_param')
        fp32_fused_grad = self._create_persistable_var('fp32_fused_grad')
        fp16_fused_param = self._create_persistable_var(
            'fp16_fused_param', dtype='float16'
        )
        fp16_fused_grad = self._create_persistable_var(
            'fp16_fused_grad', dtype='float16'
        )

        master_params = []
        for p, _ in params_grads:
            master_p = self._create_persistable_var('master_weight')
            self._param_to_master_param[p.name] = master_p.name
            master_params.append(master_p)

        # Variables flagged ``is_distributed`` presumably hold per-rank shards
        # rather than replicated data -- NOTE(review): confirm.
        moment1 = self._create_persistable_var('moment1')
        moment1.is_distributed = True
        moment2 = self._create_persistable_var('moment2')
        moment2.is_distributed = True
        beta1pow = self._create_persistable_var('beta1pow')
        beta2pow = self._create_persistable_var('beta2pow')

        param_info = self._create_persistable_var('param_info', dtype='int32')
        param_info.is_distributed = True

        fused_offsets = self._create_persistable_var(
            'fused_offsets', dtype='int32'
        )

        fp32_partial_fused_offsets = self._create_persistable_var(
            'fp32_partial_fused_offsets', dtype='int32'
        )
        fp32_partial_fused_offsets.is_distributed = True

        fp16_partial_fused_offsets = self._create_persistable_var(
            'fp16_partial_fused_offsets', dtype='int32'
        )
        fp16_partial_fused_offsets.is_distributed = True

        param_order = self._create_persistable_var('param_order', dtype='int32')
        param_order.is_distributed = True

        # Accumulation buffers exist only with gradient accumulation; empty
        # lists leave the corresponding op outputs unset.
        if self._gradient_accumulation_steps > 1:
            fp32_acc_fused_grad = [
                self._create_persistable_var('fp32_acc_fused_grad')
            ]
            fp16_acc_fused_grad = [
                self._create_persistable_var(
                    'fp16_acc_fused_grad', dtype='float16'
                )
            ]
            acc_step = [self._create_persistable_var('acc_step', dtype='int64')]
        else:
            fp32_acc_fused_grad = []
            fp16_acc_fused_grad = []
            acc_step = []

        step = self._get_or_create_step()

        rank = paddle.distributed.get_rank()
        nranks = paddle.distributed.get_world_size()
        if self._nproc_per_node is None:
            nproc_per_node = nranks
        else:
            nproc_per_node = self._nproc_per_node
        assert (
            nranks % nproc_per_node == 0
        ), "nranks should be exactly divided by nproc_per_node"

        shard_inside_node = nranks > nproc_per_node
        local_rank = rank % nproc_per_node

        startup_block = self.helper.startup_program.global_block()
        ring_ids, use_hierarchical_allreduce = self._init_communicators(
            startup_block, rank, nranks, nproc_per_node, shard_inside_node
        )

        scale = self._get_or_create_scale()

        params = [p for p, _ in params_grads]
        grads = [g for _, g in params_grads]
        exclude_fn = self._exclude_from_weight_decay_fn
        if exclude_fn is None:
            apply_weight_decay = [1] * len(params)
        else:
            apply_weight_decay = [0 if exclude_fn(p) else 1 for p in params]

        # The init op outputs to the gradients, so they must also be declared
        # in the startup program.
        for g in grads:
            startup_block.create_var(
                name=g.name,
                type=g.type,
                dtype=g.dtype,
                persistable=g.persistable,
                shape=g.shape,
            )

        if nranks > 1:
            broadcast_parameters(startup_block, params, ring_ids[0])

        startup_block.append_op(
            type='distributed_fused_lamb_init',
            inputs={
                'Param': params,
                'Grad': grads,
            },
            outputs={
                'FP32FusedParam': [fp32_fused_param],
                'FP32FusedGrad': [fp32_fused_grad],
                'FP16FusedParam': [fp16_fused_param],
                'FP16FusedGrad': [fp16_fused_grad],
                'Moment1': [moment1],
                'Moment2': [moment2],
                'Beta1Pow': [beta1pow],
                'Beta2Pow': [beta2pow],
                'GlobalScale': [scale],
                'ParamInfo': [param_info],
                'ParamOut': params,
                'MasterParamOut': master_params,
                'GradOut': grads,
                'FP32ShardFusedParamOffsets': [fp32_partial_fused_offsets],
                'FP16ShardFusedParamOffsets': [fp16_partial_fused_offsets],
                'FusedParamOffsets': [fused_offsets],
                'ParamOrder': [param_order],
                'Step': [step],
            },
            attrs={
                'alignment': self._alignment,
                # The init op gets the intra-node rank/size when sharding
                # inside a node, and the global ones otherwise.
                'rank': local_rank if shard_inside_node else rank,
                'nranks': nproc_per_node if shard_inside_node else nranks,
                'apply_weight_decay': apply_weight_decay,
                'moment1': 0.0,
                'moment2': 0.0,
                'beta1': self._beta1,
                'beta2': self._beta2,
            },
        )

        main_block = self.helper.main_program.global_block()
        self._create_global_learning_rate()
        # The fused op takes one learning-rate variable, so every parameter
        # must resolve to the very same LR object.
        lr = None
        for p_g in params_grads:
            new_lr = self._create_param_lr(p_g)
            if lr is None:
                lr = new_lr
            else:
                assert (
                    lr is new_lr
                ), "The learning rate for each parameter should be the same"
        assert lr is not None

        lamb_op = main_block.append_op(
            type='distributed_fused_lamb',
            inputs={
                'FP32FusedParam': [fp32_fused_param],
                'FP32FusedGrad': [fp32_fused_grad],
                'FP16FusedParam': [fp16_fused_param],
                'FP16FusedGrad': [fp16_fused_grad],
                'LearningRate': [lr],
                'Moment1': [moment1],
                'Moment2': [moment2],
                'Beta1Pow': [beta1pow],
                'Beta2Pow': [beta2pow],
                'GlobalScale': [scale],
                'ParamInfo': [param_info],
                'Param': params,
                'Grad': grads,
                'FusedParamOffsets': [fused_offsets],
                'FP32ShardFusedParamOffsets': [fp32_partial_fused_offsets],
                'FP16ShardFusedParamOffsets': [fp16_partial_fused_offsets],
                'ParamOrder': [param_order],
            },
            outputs={
                'FP32FusedParamOut': [fp32_fused_param],
                'FP16FusedParamOut': [fp16_fused_param],
                'Moment1Out': [moment1],
                'Moment2Out': [moment2],
                'Beta1PowOut': [beta1pow],
                'Beta2PowOut': [beta2pow],
                'ParamOut': params,
                'GradOut': grads,
                'FoundInf': [self._found_inf],
                'FP32AccFusedGrad': fp32_acc_fused_grad,
                'FP16AccFusedGrad': fp16_acc_fused_grad,
                'AccStep': acc_step,
                'StopUpdate': self._stop_update
                if self._stop_update is not None
                else [],
                'Step': [step],
            },
            attrs={
                'weight_decay': self._weight_decay,
                'beta1': self._beta1,
                'beta2': self._beta2,
                'epsilon': self._epsilon,
                'max_global_grad_norm': self._max_global_grad_norm,
                'clip_after_allreduce': self._clip_after_allreduce,
                'rank': rank,
                'nranks': nranks,
                'ring_ids': ring_ids,
                'use_master_param_norm': self._use_master_param_norm,
                'is_grad_scaled_by_nranks': self._is_grad_scaled_by_nranks,
                'acc_steps': self._gradient_accumulation_steps,
                'use_master_acc_grad': self._use_master_acc_grad,
                'use_hierarchical_allreduce': use_hierarchical_allreduce,
            },
        )
        return [lamb_op]