diff --git a/README.md b/README.md index 61520664b42e0db44165e6ef3c927a420676ac2e..eed70fa5e04f13d86535833d70aff6096b352814 100755 --- a/README.md +++ b/README.md @@ -17,7 +17,7 @@ DeepSpeed delivers extreme-scale model training for everyone, from data scientis * Extreme scale: Using current generation of GPU clusters with hundreds of devices, 3D parallelism of DeepSpeed can efficiently train deep learning models with trillions of parameters. * Extremely memory efficient: With just a single GPU, ZeRO-Offload of DeepSpeed can train models with over 10B parameters, 10x bigger than the state of arts, democratizing multi-billion-parameter model training such that many deep learning scientists can explore bigger and better models. * Extremely long sequence length: Sparse attention of DeepSpeed powers an order-of-magnitude longer input sequence and obtains up to 6x faster execution comparing with dense transformers. -* Extremely communication efficient: 3D parallelism improves communication efficiency allows users to train multi-billion-parameter models 2–7x faster on clusters with limited network bandwidth. 1-bit Adam reduces communication volume by up to 5x while achieving similar convergence efficiency to Adam, allowing for scaling to different types of GPU clusters and networks. +* Extremely communication efficient: 3D parallelism improves communication efficiency allows users to train multi-billion-parameter models 2–7x faster on clusters with limited network bandwidth. 1-bit Adam/1-bit LAMB reduce communication volume by up to 5x while achieving similar convergence efficiency to Adam/LAMB, allowing for scaling to different types of GPU clusters and networks. Early adopters of DeepSpeed have already produced a language model (LM) with over 17B parameters called @@ -33,6 +33,7 @@ information [here](https://innovation.microsoft.com/en-us/exploring-ai-at-scale) # News +* [2021/04/20] [1-bit LAMB: up to 4.6x less communication and 2.8x faster training, together with LAMB's convergence speed at large batch sizes](https://www.deepspeed.ai/tutorials/onebit-lamb/) * [2021/04/19] [ZeRO-Infinity unlocks unprecedented model scale for deep learning training](https://www.microsoft.com/en-us/research/blog/zero-infinity-and-deepspeed-unlocking-unprecedented-model-scale-for-deep-learning-training/) * [Tutorial on how to use different stages of ZeRO](https://www.deepspeed.ai/tutorials/zero/) * [2021/04/01] [[DeepSpeed on AzureML] Transformers and CIFAR examples are now available on AzureML GitHub](https://github.com/Azure/azureml-examples/tree/main/workflows/train/deepspeed) @@ -119,7 +120,7 @@ overview](https://www.deepspeed.ai/features/) for descriptions and usage. * Memory- and compute-efficient sparse kernels * Support 10x longer sequences than dense * Flexible support to different sparse structures -* [1-bit Adam](https://www.deepspeed.ai/news/2020/09/08/onebit-adam-blog-post.html) +* [1-bit Adam](https://www.deepspeed.ai/news/2020/09/08/onebit-adam-blog-post.html) and [1-bit LAMB](https://www.deepspeed.ai/tutorials/onebit-lamb/) * Custom communication collective * Up to 5x communication volume saving * [Additional Memory and Bandwidth Optimizations](https://www.deepspeed.ai/features/#additional-memory-and-bandwidth-optimizations) @@ -192,7 +193,7 @@ Conduct](https://opensource.microsoft.com/codeofconduct/). For more information 4. Jie Ren, Samyam Rajbhandari, Reza Yazdani Aminabadi, Olatunji Ruwase, Shuangyan Yang, Minjia Zhang, Dong Li, Yuxiong He. (2021) ZeRO-Offload: Democratizing Billion-Scale Model Training. [arXiv:2101.06840](https://arxiv.org/abs/2101.06840). 5. Hanlin Tang, Shaoduo Gan, Ammar Ahmad Awan, Samyam Rajbhandari, Conglong Li, Xiangru Lian, Ji Liu, Ce Zhang, Yuxiong He. (2021) 1-bit Adam: Communication Efficient Large-Scale Training with Adam's Convergence Speed. [arXiv:2102.02888](https://arxiv.org/abs/2102.02888). 6. Samyam Rajbhandari, Olatunji Ruwase, Jeff Rasley, Shaden Smith, Yuxiong He. (2021) ZeRO-Infinity: Breaking the GPU Memory Wall for Extreme Scale Deep Learning. [arXiv:2104.07857](https://arxiv.org/abs/2104.07857). - +7. Conglong Li, Ammar Ahmad Awan, Hanlin Tang, Samyam Rajbhandari, Yuxiong He. (2021) 1-bit LAMB: Communication Efficient Large-Scale Large-Batch Training with LAMB's Convergence Speed. [arXiv:2104.06069](https://arxiv.org/abs/2104.06069). # Videos 1. DeepSpeed KDD 2020 Tutorial diff --git a/deepspeed/runtime/comm/nccl.py b/deepspeed/runtime/comm/nccl.py index 0ac2646bd0d7d05afd7785f27b153925469cbbc4..e8bd03514a1b222c237c7daf890aad5ff5ec8bbd 100644 --- a/deepspeed/runtime/comm/nccl.py +++ b/deepspeed/runtime/comm/nccl.py @@ -12,8 +12,12 @@ from deepspeed.runtime.compression.cupy import CupyBackend class NcclBackend(object): - def __init__(self): - self.world_group = dist.new_group(ranks=range(dist.get_world_size())) + def __init__(self, mpu=None): + if mpu is None: + self.world_group = dist.new_group(ranks=range(dist.get_world_size())) + else: + self.mpu = mpu + self.world_group = self.mpu.get_data_parallel_group() self.rank = dist.get_rank(group=self.world_group) self.size = dist.get_world_size(group=self.world_group) self.compression_backend = CupyBackend() @@ -92,9 +96,11 @@ class NcclBackend(object): # communication phase 1 # gather_start = time.time() # Alltoall for sign - dist.all_to_all_single(recvbuf_sign, torch.stack(sign_list_packed)) + dist.all_to_all_single(recvbuf_sign, + torch.stack(sign_list_packed), + group=self.world_group) # Allgather for scale - dist.all_gather(recvbuf_scale, worker_scale) + dist.all_gather(recvbuf_scale, worker_scale, group=self.world_group) # gather_end = time.time() @@ -151,8 +157,10 @@ class NcclBackend(object): ] # Communication Phase 2 - dist.all_gather(recvbuf_sign_server, server_sign_packed[0]) - dist.all_gather(recvbuf_scale_server, server_scale) + dist.all_gather(recvbuf_sign_server, + server_sign_packed[0], + group=self.world_group) + dist.all_gather(recvbuf_scale_server, server_scale, group=self.world_group) cupy_server_sign_packed = None diff --git a/deepspeed/runtime/config.py b/deepspeed/runtime/config.py index 9e33876994f97aca86d78899899a034a0fe3472f..3fa0b32a6032c5e2adfb8d02bec4df43b4c0ab99 100755 --- a/deepspeed/runtime/config.py +++ b/deepspeed/runtime/config.py @@ -32,11 +32,13 @@ ADAM_OPTIMIZER = 'adam' ADAMW_OPTIMIZER = 'adamw' LAMB_OPTIMIZER = 'lamb' ONEBIT_ADAM_OPTIMIZER = 'onebitadam' +ONEBIT_LAMB_OPTIMIZER = 'onebitlamb' DEEPSPEED_OPTIMIZERS = [ ADAM_OPTIMIZER, ADAMW_OPTIMIZER, LAMB_OPTIMIZER, ONEBIT_ADAM_OPTIMIZER, + ONEBIT_LAMB_OPTIMIZER, ] # extra optimizer parameters for adam/adamw diff --git a/deepspeed/runtime/engine.py b/deepspeed/runtime/engine.py index f71a7324585ada53dbc92d0b00bc1d9b2653e2ad..6a857bca378c28ddc68eb7bc331dddcf80eb3596 100755 --- a/deepspeed/runtime/engine.py +++ b/deepspeed/runtime/engine.py @@ -24,7 +24,7 @@ from deepspeed.runtime.activation_checkpointing import checkpointing as activati from deepspeed.runtime.fp16.fused_optimizer import FP16_Optimizer from deepspeed.runtime.fp16.unfused_optimizer import FP16_UnfusedOptimizer from deepspeed.runtime.config import DeepSpeedConfig, DEEPSPEED_OPTIMIZERS, \ - ADAM_OPTIMIZER, ADAMW_OPTIMIZER, LAMB_OPTIMIZER, ONEBIT_ADAM_OPTIMIZER, \ + ADAM_OPTIMIZER, ADAMW_OPTIMIZER, LAMB_OPTIMIZER, ONEBIT_ADAM_OPTIMIZER, ONEBIT_LAMB_OPTIMIZER, \ TORCH_ADAM_PARAM, ADAM_W_MODE, ADAM_W_MODE_DEFAULT from deepspeed.runtime.dataloader import DeepSpeedDataLoader @@ -553,7 +553,8 @@ class DeepSpeedEngine(Module): assert self._is_supported_optimizer(self.optimizer_name()), \ '{} is not a supported DeepSpeed Optimizer'.format(self.optimizer_name()) - if self.optimizer_name() == LAMB_OPTIMIZER: + if self.optimizer_name() == LAMB_OPTIMIZER or self.optimizer_name( + ) == ONEBIT_LAMB_OPTIMIZER: assert self.dynamic_loss_scale(), \ 'DeepSpeed {} optimizer requires dynamic loss scaling'.format(self.optimizer_name()) @@ -694,6 +695,13 @@ class DeepSpeedEngine(Module): logger.warning( f'Currently the convergence of 1-bit Adam is only verified under FP16' ) + elif self.optimizer_name() == ONEBIT_LAMB_OPTIMIZER: + from deepspeed.runtime.fp16.onebit.lamb import OnebitLamb + optimizer = OnebitLamb(model_parameters, self, **optimizer_parameters) + if not self.fp16_enabled(): + logger.warning( + f'Currently the convergence of 1-bit Lamb is only verified under FP16' + ) else: torch_optimizer = getattr(torch.optim, self.optimizer_name()) optimizer = torch_optimizer(model_parameters, **optimizer_parameters) @@ -710,6 +718,7 @@ class DeepSpeedEngine(Module): timers = self.timers if self.wall_clock_breakdown() else None optimizer = FP16_Optimizer( optimizer, + deepspeed=self, dynamic_loss_scale=True, initial_dynamic_scale=initial_dynamic_scale, dynamic_loss_args=dynamic_loss_args, @@ -723,6 +732,7 @@ class DeepSpeedEngine(Module): ranks=[0]) optimizer = FP16_Optimizer( optimizer, + deepspeed=self, static_loss_scale=self.loss_scale(), mpu=self.mpu, clip_grad=clip_grad, @@ -732,6 +742,7 @@ class DeepSpeedEngine(Module): ranks=[0]) optimizer = FP16_UnfusedOptimizer( optimizer, + deepspeed=self, static_loss_scale=self.loss_scale(), dynamic_loss_scale=self.dynamic_loss_scale(), dynamic_loss_args=dynamic_loss_args, diff --git a/deepspeed/runtime/fp16/fused_optimizer.py b/deepspeed/runtime/fp16/fused_optimizer.py index 5f35c1884a413635e029a8d8701dee3cf0d4fa9b..98275e5bb832df2a0f175e311bebbe566b474279 100755 --- a/deepspeed/runtime/fp16/fused_optimizer.py +++ b/deepspeed/runtime/fp16/fused_optimizer.py @@ -22,6 +22,7 @@ class FP16_Optimizer(object): """ def __init__(self, init_optimizer, + deepspeed=None, static_loss_scale=1.0, dynamic_loss_scale=False, initial_dynamic_scale=2**32, @@ -100,7 +101,9 @@ class FP16_Optimizer(object): self.mpu = mpu self.overflow = False - self.overflow_checker = CheckOverflow(self.fp16_groups, mpu=self.mpu) + self.overflow_checker = CheckOverflow(self.fp16_groups, + mpu=self.mpu, + deepspeed=deepspeed) self.initialize_optimizer_states() def initialize_optimizer_states(self): diff --git a/deepspeed/runtime/fp16/onebit/adam.py b/deepspeed/runtime/fp16/onebit/adam.py index e3417fea9d6f49a2b4221c81e29c642e6cd974d5..35e35411cfde1795e82f2f9910620ee7f2005f9d 100644 --- a/deepspeed/runtime/fp16/onebit/adam.py +++ b/deepspeed/runtime/fp16/onebit/adam.py @@ -82,6 +82,7 @@ class OnebitAdam(torch.optim.Optimizer): self.initialize = False self.freeze_step = freeze_step self.cuda_aware = cuda_aware + self.using_pipeline = False self.comm_backend_name = comm_backend_name @@ -94,7 +95,9 @@ class OnebitAdam(torch.optim.Optimizer): assert TORCH_MAJOR >= 1 and TORCH_MINOR >= 8, "Please use torch 1.8 or greater to enable NCCL backend in 1-bit Adam. Alternatively, please specify 'mpi' as the 'comm_backend_name' in config file to proceed with the MPI backend" assert dist.is_initialized() == True, "Please initialize the torch distributed backend." from deepspeed.runtime.comm.nccl import NcclBackend - self.comm_backend_handle = NcclBackend() + self.using_pipeline = hasattr(self.deepspeed, + 'pipeline_enable_backward_allreduce') + self.comm_backend_handle = NcclBackend(self.deepspeed.mpu) elif self.comm_backend_name == 'mpi': from deepspeed.runtime.comm.mpi import MpiBackend @@ -254,8 +257,12 @@ class OnebitAdam(torch.optim.Optimizer): if self.adam_freeze_key is False: if state['step'] >= self.freeze_step: + print('OnebitAdam - starting compressed communication') self.adam_freeze_key = True - self.deepspeed.enable_backward_allreduce = False + if self.using_pipeline: + self.deepspeed.pipeline_enable_backward_allreduce = False + else: + self.deepspeed.enable_backward_allreduce = False return loss @@ -277,18 +284,24 @@ class OnebitAdam(torch.optim.Optimizer): super().load_state_dict(state_dict) if self.state[self.param_groups[0]['params'][0]]['step'] < self.freeze_step: if torch.distributed.get_rank() == 0: - print("Checkpoint loaded and 1-bit Adam warmup stage starts/continues.") + print("Checkpoint loaded and OnebitAdam warmup stage starts/continues.") if self.adam_freeze_key is True: self.adam_freeze_key = False - self.deepspeed.enable_backward_allreduce = True + if self.using_pipeline: + self.deepspeed.pipeline_enable_backward_allreduce = True + else: + self.deepspeed.enable_backward_allreduce = True else: if torch.distributed.get_rank() == 0: print( - "Checkpoint loaded and 1-bit Adam compression stage starts/continues." + "Checkpoint loaded and OnebitAdam compression stage starts/continues." ) if self.adam_freeze_key is False: self.adam_freeze_key = True - self.deepspeed.enable_backward_allreduce = False + if self.using_pipeline: + self.deepspeed.pipeline_enable_backward_allreduce = False + else: + self.deepspeed.enable_backward_allreduce = False # We reset the compression errors when loading checkpoints for 3 reasons: # 1) The worker and server error at each GPU are distinct, so in current implementation # only rank 0's errors are saved in the checkpoint. Thus we have to reset the errors. diff --git a/deepspeed/runtime/fp16/onebit/lamb.py b/deepspeed/runtime/fp16/onebit/lamb.py new file mode 100644 index 0000000000000000000000000000000000000000..01c6cd878488c73f0a08b6030622982e5fc45284 --- /dev/null +++ b/deepspeed/runtime/fp16/onebit/lamb.py @@ -0,0 +1,471 @@ +''' +Copyright 2021 The Microsoft DeepSpeed Team +''' +import types +import torch +import numpy as np +import torch.distributed as dist +from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors + + +class OnebitLamb(torch.optim.Optimizer): + """Implements the 1-bit Lamb algorithm. Currently GPU-only. + For usage example please see https://www.deepspeed.ai/tutorials/onebit-lamb/ + For technical details please see our paper https://arxiv.org/abs/2104.06069. + + Arguments: + params (iterable): iterable of parameters to optimize or dicts defining + parameter groups. + lr (float, optional): learning rate. (default: 1e-3) + freeze_step (int, optional): Number of steps for warmup (uncompressed) + stage before we start using compressed communication. (default 100000) + betas (Tuple[float, float], optional): coefficients used for computing + running averages of gradient and its square. (default: (0.9, 0.999)) + eps (float, optional): term added to the denominator to improve + numerical stability. (default: 1e-8) + weight_decay (float, optional): weight decay (L2 penalty) (default: 0) + max_coeff(float, optional): maximum value of the lamb coefficient (default: 10.0) + min_coeff(float, optional): minimum value of the lamb coefficient (default: 0.01) + amsgrad (boolean, optional): whether to use the AMSGrad variant of this + algorithm from the paper `On the Convergence of Adam and Beyond`_ + (default: False) NOT SUPPORTED in 1-bit Lamb! + eps_inside_sqrt (boolean, optional): in the 'update parameters' step, + adds eps to the bias-corrected second moment estimate before + evaluating square root instead of adding it to the square root of + second moment estimate as in the original paper. (default: False) + cuda_aware (boolean, required): Set True if the underlying MPI implementation + supports CUDA-Aware communication. (default: False) + comm_backend_name (string, optional): Set to 'mpi' if needed. (default: 'nccl') + coeff_beta (float, optional): coefficient used for computing + running averages of lamb coefficient (default: 0.9) note that you may want to + increase or decrease this beta depending on the freeze_step you choose, as + 1/(1 - coeff_beta) should be smaller than or equal to freeze_step + factor_max (float, optional): maximum value of scaling factor to the frozen lamb + coefficient during compression stage (default: 4.0) + factor_min (float, optional): minimum value of scaling factor to the frozen lamb + coefficient during compression stage (default: 0.5) + factor_threshold (float, optional): threshold of how much the scaling factor can + fluctuate between steps (default: 0.1) + .. _Large Batch Optimization for Deep Learning\: Training BERT in 76 minutes: + https://arxiv.org/abs/1904.00962 + .. _Adam\: A Method for Stochastic Optimization: + https://arxiv.org/abs/1412.6980 + .. _On the Convergence of Adam and Beyond: + https://openreview.net/forum?id=ryQu7f-RZ + """ + def __init__(self, + params, + deepspeed=None, + lr=1e-3, + freeze_step=100000, + bias_correction=True, + betas=(0.9, + 0.999), + eps=1e-8, + eps_inside_sqrt=False, + weight_decay=0., + max_grad_norm=0., + max_coeff=10.0, + min_coeff=0.01, + amsgrad=False, + cuda_aware=False, + comm_backend_name='nccl', + coeff_beta=0.9, + factor_max=4.0, + factor_min=0.5, + factor_threshold=0.1): + + if amsgrad: + raise RuntimeError('1-bit Lamb does not support the AMSGrad variant.') + + defaults = dict(lr=lr, + bias_correction=bias_correction, + betas=betas, + eps=eps, + weight_decay=weight_decay, + max_grad_norm=max_grad_norm, + max_coeff=max_coeff, + min_coeff=min_coeff) + + super(OnebitLamb, self).__init__(params, defaults) + self.eps_mode = 0 if eps_inside_sqrt else 1 + assert (dist.is_initialized()) + + self.deepspeed = deepspeed + self.lamb_freeze_key = False + self.initialize = False + self.freeze_step = freeze_step + self.cuda_aware = cuda_aware + self.coeff_beta = coeff_beta + self.factor_max = factor_max + self.factor_min = factor_min + self.factor_threshold = factor_threshold + self.using_pipeline = False + + self.comm_backend_name = comm_backend_name + + # Empty initializer. Set handle based on the comm backend as follows. + self.comm_backend_handle = None + + if self.comm_backend_name == 'nccl': + TORCH_MAJOR = int(torch.__version__.split('.')[0]) + TORCH_MINOR = int(torch.__version__.split('.')[1]) + assert TORCH_MAJOR >= 1 and TORCH_MINOR >= 8, "Please use torch 1.8 or greater to enable NCCL backend in 1-bit Adam. Alternatively, please specify 'mpi' as the 'comm_backend_name' in config file to proceed with the MPI backend" + assert dist.is_initialized() == True, "Please initialize the torch distributed backend." + from deepspeed.runtime.comm.nccl import NcclBackend + self.using_pipeline = hasattr(self.deepspeed, + 'pipeline_enable_backward_allreduce') + self.comm_backend_handle = NcclBackend(self.deepspeed.mpu) + + elif self.comm_backend_name == 'mpi': + from deepspeed.runtime.comm.mpi import MpiBackend + self.comm_backend_handle = MpiBackend(cuda_aware) + + self.size = self.comm_backend_handle.size + + self.divider = int(self.size * 8 / np.gcd(self.size, 8)) + + self.exp_avg_flat = [] + self.dummy_exp_avg = {} + self.corrected_tensor_sizes = [] + self.server_chunk_sizes = [] + self.worker_errors = [] + self.server_errors = [] + + self.lamb_coeffs = [] + + def step(self, closure=None, grads=None): + """Performs a single optimization step. + Arguments: + closure (callable, optional): A closure that reevaluates the model + and returns the loss. + grads (list of tensors, optional): weight gradient to use for the + optimizer update. If gradients have type torch.half, parameters + are expected to be in type torch.float. (default: None) + """ + loss = None + if closure is not None: + loss = closure() + + if grads is None: + grads_group = [None] * len(self.param_groups) + # backward compatibility + # assuming a list/generator of parameter means single group + elif isinstance(grads, types.GeneratorType): + grads_group = [grads] + elif type(grads[0]) != list: + grads_group = [grads] + else: + grads_group = grads + + #remove the previous stats + del self.lamb_coeffs[:] + + if self.lamb_freeze_key: + exp_avg_last_step = [] + for group in self.param_groups: + exp_avg_last_step.append( + [self.state[p]['exp_avg'].detach().clone() for p in group['params']]) + if 'scaling_coeff' not in self.state[self.param_groups[0]['params'][0]]: + # Compute the scaling_coeff for each momentum at the end of warmup stage. + # This is used to reduce compression error during compression stage. + momentum_scales = [] + for group in self.param_groups: + momentum_scales.append([ + (torch.norm(self.state[p]['exp_avg']) / + np.sqrt(torch.numel(self.state[p]['exp_avg']))).item() + for p in group['params'] + ]) + united_scale = sum([sum(x) for x in momentum_scales]) / sum( + [len(x) for x in momentum_scales]) + for i, group in enumerate(self.param_groups): + for j, p in enumerate(group['params']): + self.state[p][ + 'scaling_coeff'] = united_scale / momentum_scales[i][j] + + for group, grads_this_group in zip(self.param_groups, grads_group): + if grads_this_group is None: + grads_this_group = [None] * len(group['params']) + + bias_correction = 1 if group['bias_correction'] else 0 + + for p, grad in zip(group['params'], grads_this_group): + if p.grad is None and grad is None: + continue + if grad is None: + grad = p.grad.data + if grad.is_sparse: + raise RuntimeError('1-bit Lamb does not support sparse gradients') + + state = self.state[p] + + # State initialization + if len(state) == 0 or (len(state) == 1 + and 'scaling_coeff' in state.keys()): + state['step'] = 0 + state['lamb_coeff_freeze'] = 0.0 + state['last_factor'] = 1.0 + # Exponential moving average of gradient values + state['exp_avg'] = torch.zeros_like(p.data) + # Exponential moving average of squared gradient values + state['exp_avg_sq'] = torch.zeros_like(p.data) + state['exp_avg_sq_fresh'] = torch.zeros_like(p.data) + + if not self.initialize: + self.lamb_freeze_key = True + + exp_avg, exp_avg_sq, exp_avg_sq_fresh = state['exp_avg'], state['exp_avg_sq'], state['exp_avg_sq_fresh'] + beta1, beta2 = group['betas'] + max_coeff = group['max_coeff'] + min_coeff = group['min_coeff'] + + state['step'] += 1 + + if self.lamb_freeze_key is False: + # warmup stage, baseline Lamb optimization + exp_avg.mul_(beta1).add_(1 - beta1, grad) + exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad) + if state['step'] == self.freeze_step: + exp_avg_sq_fresh.data = exp_avg_sq.detach().clone() + grad = None + if self.initialize: + weight_norm = p.data.pow(2).sum().sqrt() + update = exp_avg / (exp_avg_sq.sqrt() + group['eps']) + if group['weight_decay'] > 0.0: + update += group['weight_decay'] * p.data + update_norm = update.pow(2).sum().sqrt() + lamb_coeff = 1.0 + if weight_norm != 0 and update_norm != 0: + lamb_coeff = (weight_norm / update_norm).item() + if lamb_coeff > max_coeff: + lamb_coeff = max_coeff + if lamb_coeff < min_coeff: + lamb_coeff = min_coeff + if lamb_coeff != 1.0: + state['lamb_coeff_freeze'] = self.coeff_beta * state[ + 'lamb_coeff_freeze'] + (1 - self.coeff_beta) * lamb_coeff + self.lamb_coeffs.append(lamb_coeff) + with torch.no_grad(): + p.add_(-group['lr'] * lamb_coeff * update) + else: + # compression stage, update each momentum locally, then + # communicate based on the compressed_allreduce below + if self.initialize: + exp_avg.mul_(beta1).add_(1 - beta1, grad) + exp_avg.mul_(self.state[p]['scaling_coeff']) + grad = None + + # init fused momentum + if len(self.exp_avg_flat) == 0: + momentum_groups = [] + tensor_size = 0 + for group in self.param_groups: + for p in group['params']: + momentum_groups.append(self.state[p]['exp_avg']) + tensor_size += torch.numel(p.data) + corrected_tensor_size = tensor_size + if tensor_size % (self.size * self.divider) != 0: + difference = ((self.size * self.divider) - (tensor_size % + (self.size * self.divider))) + corrected_tensor_size += difference + self.dummy_exp_avg[0] = torch.zeros( + difference, + device=momentum_groups[0].data.device) + momentum_groups.append(self.dummy_exp_avg[0]) + self.corrected_tensor_sizes.append(corrected_tensor_size) + self.server_chunk_sizes.append(corrected_tensor_size // self.size) + + self.exp_avg_flat.append( + _flatten_dense_tensors([p.detach().clone() for p in momentum_groups])) + updated_params = _unflatten_dense_tensors(self.exp_avg_flat[0], + momentum_groups) + for p, q in zip(momentum_groups, updated_params): + p.data = q.data + + if self.initialize and len(self.worker_errors) == 0: + torch.cuda.empty_cache() + for i in range(len(self.exp_avg_flat)): + self.worker_errors.append( + torch.zeros(self.corrected_tensor_sizes[i], + device=self.exp_avg_flat[i].device)) + self.server_errors.append( + torch.zeros(self.server_chunk_sizes[i], + device=self.exp_avg_flat[i].device)) + torch.cuda.empty_cache() + + if self.lamb_freeze_key: + if self.size > 1: + for i in range(len(self.exp_avg_flat)): + if not self.initialize: + torch.cuda.empty_cache() + self.worker_errors.append( + torch.zeros(self.corrected_tensor_sizes[i], + device=self.exp_avg_flat[i].device)) + self.server_errors.append( + torch.zeros(self.server_chunk_sizes[i], + device=self.exp_avg_flat[i].device)) + torch.cuda.empty_cache() + if torch.distributed.get_rank() == 0: + print("Cupy Buffers Initialized Successfully.") + + self.comm_backend_handle.compressed_allreduce( + self.exp_avg_flat[i], + self.worker_errors[0], + self.server_errors[0], + self.deepspeed.local_rank) + + if torch.distributed.get_rank() == 0: + print('Pop out errors', flush=True) + del self.worker_errors[:] + del self.server_errors[:] + else: + self.comm_backend_handle.compressed_allreduce( + self.exp_avg_flat[i], + self.worker_errors[i], + self.server_errors[i], + self.deepspeed.local_rank) + + if self.lamb_freeze_key and self.initialize: + for i, group in enumerate(self.param_groups): + bias_correction = 1 if group['bias_correction'] else 0 + + for j, p in enumerate(group['params']): + state = self.state[p] + exp_avg, exp_avg_sq, exp_avg_sq_fresh = state['exp_avg'], state['exp_avg_sq'], state['exp_avg_sq_fresh'] + beta1, beta2 = group['betas'] + exp_avg.div_(self.state[p]['scaling_coeff']) + # Because 1-bit compression cannot represent exact zero, it is required to + # provide a momentum mask for those params that have constant exact zeros in their + # momentums, otherwise the compression error would keep accumulating. + # For example, for BERT pre-training seq 128, bert.embeddings.position_embeddings.weight + # always have exact zeros in its momentum for row 129 to 512, because it only + # learns up to seq length 128 while the model supports up to 512 seq length. + # (See example in DeepSpeedExamples/bing_bert/deepspeed_train.py about how + # to add this exp_avg_mask for BERT pre-training.) + if 'exp_avg_mask' in group: + if exp_avg.device != group['exp_avg_mask'].device: + group['exp_avg_mask'] = group['exp_avg_mask'].to( + device=exp_avg.device) + exp_avg.mul_(group['exp_avg_mask']) + + grad_reconstruct = ((exp_avg - exp_avg_last_step[i][j] * beta1) / + (1 - beta1)) + exp_avg_sq_fresh.mul_(beta2).addcmul_(1 - beta2, + grad_reconstruct, + grad_reconstruct) + denom = exp_avg_sq.sqrt() + group['eps'] + update_prelim = exp_avg / denom + + if group['weight_decay'] > 0.0: + update = update_prelim + group['weight_decay'] * p.data + else: + update = update_prelim + + lamb_coeff = 1.0 + update_norm = update.pow(2).sum().sqrt() + denom_real = exp_avg_sq_fresh.sqrt() + group['eps'] + factor = (denom / denom_real).max().item() + if group['weight_decay'] > 0.0: + update_ratio = min(1.0, + (update_prelim.pow(2).sum().sqrt() / + update_norm).item()) + factor = factor * update_ratio + (1.0 - update_ratio) + if factor > self.factor_max: + factor = self.factor_max + if factor < self.factor_min: + factor = self.factor_min + if factor > state['last_factor'] * (1.0 + self.factor_threshold): + factor = state['last_factor'] * (1.0 + self.factor_threshold) + if factor < state['last_factor'] * (1.0 - self.factor_threshold): + factor = state['last_factor'] * (1.0 - self.factor_threshold) + state['last_factor'] = factor + lamb_coeff = state['lamb_coeff_freeze'] * factor + self.lamb_coeffs.append(lamb_coeff) + with torch.no_grad(): + p.add_(-group['lr'] * lamb_coeff * update) + del exp_avg_last_step[:] + exp_avg_last_step = None + + if not self.initialize: + self.lamb_freeze_key = False + self.initialize = True + print( + f"Finished the initialization step at rank {torch.distributed.get_rank()}" + ) + return loss + + if self.lamb_freeze_key is False: + if state['step'] >= self.freeze_step: + print('OnebitLamb - starting compressed communication') + self.lamb_freeze_key = True + if self.using_pipeline: + self.deepspeed.pipeline_enable_backward_allreduce = False + else: + self.deepspeed.enable_backward_allreduce = False + + return loss + + def load_state_dict(self, state_dict): + """ + Overrides load_state_dict() to add special handling when loading checkpoints + """ + # Because at different stage exp_avg_mask may change (e.g., + # BERT pre-training seqlen 128 and 512 ), we don't use the exp_avg_mask + # in checkpoints but always use the one user provided in training script. + # (See example in DeepSpeedExamples/bing_bert/deepspeed_train.py.) + # Thus here we keep the exp_avg_mask unchanged when loading checkpoint + for i, group in enumerate(self.param_groups): + if 'exp_avg_mask' in group: + state_dict['param_groups'][i]['exp_avg_mask'] = group['exp_avg_mask'] + elif 'exp_avg_mask' not in group and 'exp_avg_mask' in state_dict[ + 'param_groups'][i]: + state_dict['param_groups'][i].pop('exp_avg_mask') + super().load_state_dict(state_dict) + # need to reset the fused momentum since loading states will break the linking + del self.exp_avg_flat[:] + self.dummy_exp_avg.clear() + del self.corrected_tensor_sizes[:] + del self.server_chunk_sizes[:] + if self.state[self.param_groups[0]['params'][0]]['step'] < self.freeze_step: + if torch.distributed.get_rank() == 0: + print("Checkpoint loaded and OnebitLamb warmup stage starts/continues.") + if self.lamb_freeze_key is True: + self.lamb_freeze_key = False + if self.using_pipeline: + self.deepspeed.pipeline_enable_backward_allreduce = True + else: + self.deepspeed.enable_backward_allreduce = True + for group in self.param_groups: + for p in group['params']: + self.state[p]['lamb_coeff_freeze'] = 0.0 + self.state[p]['last_factor'] = 1.0 + if 'scaling_coeff' in self.state[p]: + self.state[p].pop('scaling_coeff') + else: + if torch.distributed.get_rank() == 0: + print( + "Checkpoint loaded and OnebitLamb compression stage starts/continues." + ) + if self.lamb_freeze_key is False: + self.lamb_freeze_key = True + if self.using_pipeline: + self.deepspeed.pipeline_enable_backward_allreduce = False + else: + self.deepspeed.enable_backward_allreduce = False + # We reset the compression errors when loading checkpoints for 3 reasons: + # 1) The worker and server error at each GPU are distinct, so in current implementation + # only rank 0's errors are saved in the checkpoint. Thus we have to reset the errors. + # If we want to save them correctly we need O(num_gpu*model_size) memory in order to + # gather all the error, which is a very large memory requirement. It's possible to save + # them in a distributed way, but it will make the checkpoint saving/loading much more complicated. + # 2) Even if we are able to save the compression errors correctly, you need to have the + # exact same number of GPUs in order to load them correctly. + # 3) We verified on BERT pre-training that occasionally resetting the compression error + # at checkpoint loading does not affect the convergence. + # However, please avoid frequent checkpoint loading which could break the error + # compensation mechanism thus affect the convergence. + del self.worker_errors[:] + del self.server_errors[:] + + def get_lamb_coeffs(self): + return self.lamb_coeffs diff --git a/deepspeed/runtime/fp16/unfused_optimizer.py b/deepspeed/runtime/fp16/unfused_optimizer.py index c0cef6a56ba7d691e259c3ff56e55950fb345051..441dbd61ccb9bf796caec362d853790de9c386fe 100755 --- a/deepspeed/runtime/fp16/unfused_optimizer.py +++ b/deepspeed/runtime/fp16/unfused_optimizer.py @@ -22,6 +22,7 @@ class FP16_UnfusedOptimizer(object): """ def __init__(self, init_optimizer, + deepspeed=None, static_loss_scale=1.0, dynamic_loss_scale=False, dynamic_loss_args=None, @@ -96,7 +97,9 @@ class FP16_UnfusedOptimizer(object): self.mpu = mpu self.overflow = False - self.overflow_checker = CheckOverflow(self.fp16_groups, mpu=self.mpu) + self.overflow_checker = CheckOverflow(self.fp16_groups, + mpu=self.mpu, + deepspeed=deepspeed) self.initialize_optimizer_states() diff --git a/deepspeed/runtime/pipe/engine.py b/deepspeed/runtime/pipe/engine.py index 1a401a27e36f81a8b69e8613094fe4ca23e13220..d4e5e5edfe7118d82f12c69597fd0ed2e5708799 100644 --- a/deepspeed/runtime/pipe/engine.py +++ b/deepspeed/runtime/pipe/engine.py @@ -56,6 +56,10 @@ class PipelineEngine(DeepSpeedEngine): # We schedule the all-reduces, so disable it in super().backward() self.enable_backward_allreduce = False + + # used to disable the pipeline all-reduce when used with 1-bit Adam/1-bit LAMB + self.pipeline_enable_backward_allreduce = True + assert not self.elasticity_enabled(), "Elasticity is not currently supported" \ " with pipeline parallelism." @@ -222,7 +226,7 @@ class PipelineEngine(DeepSpeedEngine): def _exec_reduce_grads(self): self._force_grad_boundary = True - if self.is_data_parallel: + if self.is_data_parallel and self.pipeline_enable_backward_allreduce: self.buffered_allreduce_fallback( elements_per_buffer=MEMORY_OPT_ALLREDUCE_SIZE) self._force_grad_boundary = False diff --git a/deepspeed/runtime/utils.py b/deepspeed/runtime/utils.py index b1a7a4b0aae197fadb1eb5cc6a0138923c9acff1..d54613e196f5ad7b04e774dd36a25386add6cc44 100755 --- a/deepspeed/runtime/utils.py +++ b/deepspeed/runtime/utils.py @@ -64,10 +64,15 @@ def move_to_device(item, device): class CheckOverflow(object): '''Checks for overflow in gradient across parallel process''' - def __init__(self, param_groups=None, mpu=None, zero_reduce_scatter=False): + def __init__(self, + param_groups=None, + mpu=None, + zero_reduce_scatter=False, + deepspeed=None): self.mpu = mpu self.params = [] if param_groups else None self.zero_reduce_scatter = zero_reduce_scatter + self.deepspeed = deepspeed if param_groups: for group in param_groups: for param in group: @@ -125,9 +130,24 @@ class CheckOverflow(object): op=torch.distributed.ReduceOp.MAX, group=torch.distributed.group.WORLD) elif self.mpu is not None: + if self.deepspeed is not None: + using_pipeline = hasattr(self.deepspeed, + 'pipeline_enable_backward_allreduce') + if (using_pipeline + and self.deepspeed.pipeline_enable_backward_allreduce is False + ) or (not using_pipeline + and self.deepspeed.enable_backward_allreduce is False): + torch.distributed.all_reduce( + overflow_gpu, + op=torch.distributed.ReduceOp.MAX, + group=self.mpu.get_data_parallel_group()) torch.distributed.all_reduce(overflow_gpu, op=torch.distributed.ReduceOp.MAX, group=self.mpu.get_model_parallel_group()) + elif self.deepspeed is not None and self.deepspeed.enable_backward_allreduce is False: + torch.distributed.all_reduce(overflow_gpu, + op=torch.distributed.ReduceOp.MAX, + group=torch.distributed.group.WORLD) overflow = overflow_gpu[0].item() return bool(overflow) diff --git a/docs/_config.yml b/docs/_config.yml index 19d679042b90538b384a852400efb6bf194ea8b2..a39298be04f995107d6b4a3d05a833cbdce05b97 100644 --- a/docs/_config.yml +++ b/docs/_config.yml @@ -33,15 +33,22 @@ collections: - advanced-install.md - getting-started.md - azure.md - - cifar-10.md - - bert-pretraining.md - bert-finetuning.md - - transformer_kernel.md + - bert-pretraining.md + - cifar-10.md + - flops-profiler.md + - gan.md + - lrrt.md - megatron.md - one-cycle.md - - lrrt.md + - onebit-adam.md + - onebit-lamb.md + - pipeline.md + - progressive_layer_dropping.md + - sparse-attention.md + - transformer_kernel.md + - zero-offload.md - zero.md - - flops-profiler.md defaults: - scope: diff --git a/docs/_data/navigation.yml b/docs/_data/navigation.yml index 8b41df6a79f693481d20cc674af0619772aab087..6ab28bb84fd4593d83b81fca1c81aec3fab2483d 100755 --- a/docs/_data/navigation.yml +++ b/docs/_data/navigation.yml @@ -80,6 +80,8 @@ lnav: url: /tutorials/one-cycle/ - title: "One-Bit Adam" url: /tutorials/onebit-adam/ + - title: "One-Bit LAMB" + url: /tutorials/onebit-lamb/ - title: "Pipeline Parallelism" url: /tutorials/pipeline/ - title: "Progressive Layer Dropping" diff --git a/docs/_pages/config-json.md b/docs/_pages/config-json.md index 1f59c29d2202bdf3e55ad232756b78e9660473b1..8d33179862efeaba6c153833829b2422dc6c1188 100755 --- a/docs/_pages/config-json.md +++ b/docs/_pages/config-json.md @@ -34,7 +34,7 @@ title: "DeepSpeed Configuration JSON" | Fields | Value | Example | | ------ | -------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | ---------------------------- | -| type | The optimizer name. DeepSpeed natively supports **Adam**, **AdamW**, **OneBitAdam**, and **Lamb** optimizers (See [here](https://deepspeed.readthedocs.io/en/latest/optimizers.html) for details) and will import other optimizers from [torch](https://pytorch.org/docs/stable/optim.html). | `"Adam"` | +| type | The optimizer name. DeepSpeed natively supports **Adam**, **AdamW**, **OneBitAdam**, **Lamb**, and **OneBitLamb** optimizers (See [here](https://deepspeed.readthedocs.io/en/latest/optimizers.html) for details) and will import other optimizers from [torch](https://pytorch.org/docs/stable/optim.html). | `"Adam"` | | params | Dictionary of parameters to instantiate optimizer. The parameter names must match the optimizer constructor signature (e.g., for [Adam](https://pytorch.org/docs/stable/optim.html#torch.optim.Adam)). | `{"lr": 0.001, "eps": 1e-8}` | Example of **optimizer** with Adam @@ -88,6 +88,42 @@ The 1-bit Adam optimizer supports the following three params keys/values in addi | cuda\_aware | To indicate that the underlying MPI library supports CUDA-Aware communication | false | | comm\_backend\_name | To indicate which backend implementation to use | "nccl" | +Another example of ***optimizer*** with 1-bit LAMB + +```json +"optimizer": { + "type": "OneBitLamb", + "params": { + "lr": 11e-3, + "weight_decay": 0.01, + "bias_correction": false, + "max_coeff": 0.3, + "min_coeff": 0.01, + "freeze_step": 1000, + "cuda_aware": false, + "comm_backend_name": "nccl", + "coeff_beta": 0.9, + "factor_max": 4.0, + "factor_min": 0.5, + "factor_threshold": 0.1 + } + } +``` + +The 1-bit LAMB optimizer supports the following params keys/values in addition to the standard LAMB (learn more in our [tutorial](/tutorials/onebit-lamb/)): + +| "params" key | Description | Default | +| ------------- | --------------------------------------------------------------------------- | ------- | +| max\_coeff | Scaling coefficient upper bound for original LAMB algorithm and 1-bit LAMB's warmup stage | 10.0 | +| min\_coeff | Scaling coefficient lower bound for original LAMB algorithm and 1-bit LAMB's warmup stage | 0.01 | +| freeze\_step | Number of warm up steps before 1-bit compression gets applied to the communication | 100000 | +| cuda\_aware | To indicate that the underlying MPI library supports CUDA-Aware communication | false | +| comm\_backend\_name | To indicate which backend implementation to use | "nccl" | +| coeff\_beta | Coefficient used for computing running averages of lamb coefficient | 0.9 | +| factor\_max | Maximum value of scaling factor to the frozen lamb coefficient during compression stage | 4.0 | +| factor\_min | Minimum value of scaling factor to the frozen lamb coefficient during compression stage | 0.5 | +| factor\_threshold | Threshold of how much the scaling factor can fluctuate between steps | 0.1 | + ### Scheduler Parameters diff --git a/docs/_pages/features.md b/docs/_pages/features.md index ba955fd574db3277612dde3c6b4e0a1a45f58e14..9b0b89d0a64b22778dfd1da19f6754ecad68a3ba 100755 --- a/docs/_pages/features.md +++ b/docs/_pages/features.md @@ -172,15 +172,17 @@ Please see the [core API doc](https://deepspeed.readthedocs.io/) for more detail ## Training Optimizers -### 1-bit Adam optimizer with up to 5x less communication +### 1-bit Adam and 1-bit LAMB optimizers with up to 5x less communication -DeepSpeed has an efficient implementation of a novel algorithm called 1-bit Adam. -It offers the same convergence as Adam, incurs up to 5x less communication that enables +DeepSpeed has two communication-efficient optimizers called 1-bit Adam and 1-bit LAMB. +They offer the same convergence as Adam/LAMB, incur up to 5x less communication that enables up to 3.5x higher throughput for BERT-Large pretraining and up to 2.7x higher throughput for SQuAD fine-tuning on bandwidth-limited clusters. For more details on usage and performance, -please refer to the detailed [tutorial](https://www.deepspeed.ai/tutorials/onebit-adam) and -[blog post](https://www.deepspeed.ai/news/2020/09/09/onebit-adam-blog-post.md), respectively. - +please refer to the [1-bit Adam tutorial](https://www.deepspeed.ai/tutorials/onebit-adam), +[1-bit Adam blog post](https://www.deepspeed.ai/news/2020/09/09/onebit-adam-blog-post.md), +and [1-bit LAMB tutorial](https://www.deepspeed.ai/tutorials/onebit-lamb/). For technical details, +please refer to the [1-bit Adam paper](https://arxiv.org/abs/2102.02888) and +[1-bit LAMB paper](https://arxiv.org/abs/2104.06069). ### Fused Adam optimizer and arbitrary torch.optim.Optimizer With DeepSpeed, the user can choose to use a high performance implementation of ADAM from diff --git a/docs/_tutorials/onebit-adam.md b/docs/_tutorials/onebit-adam.md index 8fba712937f8b71eca454df41d23ee10384921cf..feef716825139bf0945c315258915c71f009b7cb 100644 --- a/docs/_tutorials/onebit-adam.md +++ b/docs/_tutorials/onebit-adam.md @@ -7,7 +7,7 @@ This tutorial is updated on 03/04/2021 to reflect the 1-bit Adam v2. Changes inc {: .notice--info} **Watch out!** -1) The NCCL-based implementation requires PyTorch >= 1.8 (and NCCL >= 2.8.3 when you have 64 or more GPUs). See details below. 2) Although 1-bit Adam is compatible with both FP16 and FP32, currently we only verified the convergence under mixed precision/FP16 training. 3) Currently 1-bit Adam is not compatible with pipeline parallelism. 4) Frequent checkpoint loading could hurt 1-bit Adam's convergence. See details below. +1) The NCCL-based implementation requires PyTorch >= 1.8 (and NCCL >= 2.8.3 when you have 64 or more GPUs). See details below. 2) Although 1-bit Adam is compatible with both FP16 and FP32, currently we only verified the convergence under mixed precision/FP16 training. 3) Currently the MPI-based implementation is not compatible with pipeline parallelism. 4) Frequent checkpoint loading could hurt 1-bit Adam's convergence. See details below. {: .notice--warning} In this tutorial, we are going to introduce the 1-bit Adam optimizer in DeepSpeed. 1-bit Adam can improve model training speed on communication-constrained clusters, especially for communication-intensive large models by reducing the overall communication volume by up to 5x. Detailed description of the 1-bit Adam algorithm, its implementation in DeepSpeed, and performance evaluation is available from our [blog post](https://www.deepspeed.ai/news/2020/09/08/onebit-adam-blog-post.html). We also have a [paper](https://arxiv.org/abs/2102.02888) which provides the most complete details including algorithm, system implementation, theoretical analysis, and more evaluations. @@ -23,7 +23,7 @@ For more details on these tasks, please refer to the tutorial posts on [BingBert ### 1.1 Pre-requisites for installing DeepSpeed -If you don't already have a copy of the DeepSpeed repository, please clone in +If you don't already have a copy of the DeepSpeed repository, please clone it now and checkout the DeepSpeedExamples submodule that contains the BingBertSQuAD and BERT Pre-training examples. ```shell @@ -106,7 +106,7 @@ Please note three new parameters `freeze_step`, `cuda_aware`, and `comm_backend_ Because 1-bit compression cannot represent exact zero, the compression error would keep accumulating in the momentum if a parameter have constant zero gradients during training. For example, for BERT pre-training seq length 128, `bert.embeddings.position_embeddings.weight` has constant zeros in its gradient and momentum for row 129 to 512, because it only learns up to seq length 128 while the model supports up to seq length 512. Thus in 1-bit Adam v2 we added support of a momentum mask for users to specify those params that have constant exact zeros in their gradients. See [example script](https://github.com/microsoft/DeepSpeedExamples/blob/master/bing_bert/deepspeed_train.py) for how to configure this momentum mask. One thing to note is that we don't use momentum mask saved in checkpoints since this mask could change during training (e.g., BERT seqlen 128 and 512 require different masks). So you have to provide this mask every time in your training script. **Watch out!** -1-bit Adam replies on an compression error compensation mechanism to maintain the convergence speed at compression stage. When loading checkpoints, we actually reset the compression errors for 3 reasons: 1) The worker and server error at each GPU are distinct, so in current implementation only rank 0's errors are saved in the checkpoint. Thus we have to reset the errors. If we want to save them correctly we need O(num_gpu*model_size) memory in order to gather all the error, which is a very large memory requirement. It's possible to save them in a distributed way, but it will make the checkpoint saving/loading much more complicated. 2) Even if we are able to save the compression errors correctly, you need to have the exact same number of GPUs in order to load them correctly. 3) We verified on BERT pre-training that occasionally resetting the compression error at checkpoint loading does not affect the convergence. However, please avoid frequent checkpoint loading which could break the error compensation mechanism thus affect the convergence. +1-bit Adam relies on an compression error compensation mechanism to maintain the convergence speed at compression stage. When loading checkpoints, we actually reset the compression errors for 3 reasons: 1) The worker and server error at each GPU are distinct, so in current implementation only rank 0's errors are saved in the checkpoint. Thus we have to reset the errors. If we want to save them correctly we need O(num_gpu*model_size) memory in order to gather all the error, which is a very large memory requirement. It's possible to save them in a distributed way, but it will make the checkpoint saving/loading much more complicated. 2) Even if we are able to save the compression errors correctly, you need to have the exact same number of GPUs in order to load them correctly. 3) We verified on BERT pre-training that occasionally resetting the compression error at checkpoint loading does not affect the convergence. However, please avoid frequent checkpoint loading which could break the error compensation mechanism thus affect the convergence. {: .notice--warning} ## 2. BingBertSQuAD Fine-tuning with 1-bit Adam diff --git a/docs/_tutorials/onebit-lamb.md b/docs/_tutorials/onebit-lamb.md new file mode 100644 index 0000000000000000000000000000000000000000..f6d9341d90954609d7d2bbd3c83807248e4e8db0 --- /dev/null +++ b/docs/_tutorials/onebit-lamb.md @@ -0,0 +1,130 @@ +--- +title: "1-bit LAMB: Communication Efficient Large-Scale Large-Batch Training with LAMB's Convergence Speed" +--- + +**Watch out!** +1) The NCCL-based implementation requires PyTorch >= 1.8 (and NCCL >= 2.8.3 when you have 64 or more GPUs). See details below. 2) Although 1-bit LAMB is compatible with both FP16 and FP32, currently we only verified the convergence under mixed precision/FP16 training. 3) Currently the MPI-based implementation is not compatible with pipeline parallelism. 4) Frequent checkpoint loading could hurt 1-bit LAMB's convergence. See details below. +{: .notice--warning} + +In this tutorial, we introduce DeepSpeed's 1-bit LAMB optimizer which enables communication-efficient large-scale large-batch training with LAMB's convergence speed. 1-bit LAMB can improve model training speed on communication-constrained clusters, especially for communication-intensive large models by reducing the overall communication volume by up to 4.6x. We also have a [paper](https://arxiv.org/abs/2104.06069) which provides the technical details including algorithm, system implementation, and evaluations. + +To illustrate the benefits and usage of 1-bit LAMB optimizer, we use the BERT Pre-training task as example. For more details on this task, please refer to the [tutorial](/tutorials/bert-pretraining/). + +## 1. Overview + +### 1.1 Pre-requisites for installing DeepSpeed + +If you don't already have a copy of the DeepSpeed repository, please clone it +now and checkout the DeepSpeedExamples submodule that contains the BERT Pre-training example. + +```shell +git clone https://github.com/microsoft/DeepSpeed +cd DeepSpeed +git submodule update --init --recursive +cd DeepSpeedExamples/ +``` + +### 1.2 Pre-requisites for 1-bit LAMB + +#### 1.2.1 NCCL-based implementation + +In DeepSpeed, we introduce a system implementation for compressed communication using the NCCL backend of PyTorch distributed. This implementation provides better performance and usability than the MPI-based implementation below. Thus we highly recommend users to choose this implementation. + +**Watch out!** +This NCCL-based implementation requires PyTorch >= 1.8. It also requires NCCL >= 2.8.3 when you have 64 or more GPUs to avoid certain NCCL runtime bugs. Currently (2021/03/16) NCCL 2.8.3 is not officially supported by PyTorch. The solution we used is by hacking in NCCL 2.8.3 via `LD_PRELOAD`: 1) Install NCCL 2.8.3. This works for us on a CUDA 11 system: `apt-get install -y libnccl2=2.8.3-1+cuda11.0 libnccl-dev=2.8.3-1+cuda11.0`. 2) Set `LD_PRELOAD` to the the library path. This works for us: `LD_PRELOAD=/usr/lib/x86_64-linux-gnu/libnccl.so.2.8.3`. To confirm `LD_PRELOAD` is working you can see the version it uses in the NCCL logs if you have `NCCL_DEBUG=INFO`, it should say: NCCL version 2.8.3+cuda11.0. +{: .notice--warning} + +#### 1.2.2 MPI-based implementation + +For this implementation, we rely on Message Passing Interface (MPI) for advanced communication primitives. + +We package the necessary dependencies in the DeepSpeed docker images. However, if you are using a different build system, please install MPI and mpi4py on your system. To install the prerequisites run: + +```shell +pip install deepspeed[1bit_adam] +``` + +We have tested CUDA-Aware MPI communication using the [MVAPICH2-GDR](http://mvapich.cse.ohio-state.edu/userguide/gdr/) library. However, any CUDA-Aware communication library including [OpenMPI](https://www.open-mpi.org/) should work fine with these examples. + +An example launch command for 1-bit LAMB using the `deepspeed` launcher is as follows: + +```shell +deepspeed --launcher=[mvapich|openmpi] script.py +``` + +Please note that for MPI-based implementation of 1-bit LAMB, the `--launcher=[mvapich|openmpi]` flag is required when using the `deepspeed` launcher. + +Alternatively, the standard mpirun launcher can also be used as follows: + +```shell +mpirun -np [num processes] -ppn [num GPUs on each node] -hostfile [hostfile] [MPI flags] python [training_script.py] +``` + +### 1.3 1-bit LAMB Algorithm + +The detailed description of the 1-bit LAMB algorithm can be seen from our [paper](https://arxiv.org/abs/2104.06069). + +### 1.4 Configuration of 1-bit LAMB +The 1-bit LAMB feature can be used by setting the optimizer configuration options as follows. An example json config file is shown below. + +```json +{ + "train_batch_size": 65536, + "train_micro_batch_size_per_gpu": 64, + "optimizer": { + "type": "OneBitLamb", + "params": { + "lr": 11e-3, + "max_coeff": 0.3, + "min_coeff": 0.01, + "freeze_step": 1000, + "cuda_aware": false, + "comm_backend_name": "nccl", + "coeff_beta": 0.9, + "factor_max": 4.0, + "factor_min": 0.5, + "factor_threshold": 0.1 + } + }, + "gradient_clipping": 1.0, + "fp16": { + "enabled": true, + "loss_scale": 0, + "initial_scale_power": 16 + } +} +``` +Please note the new parameters `freeze_step`, `cuda_aware`, `comm_backend_name`, `coeff_beta`, `factor_max`, `factor_min`, and `factor_threshold` that have been added to support the 1-bit LAMB feature: + +`freeze_step` is the number of warm up steps before 1-bit compression gets applied to the communication. In order to determine the number of warm up steps, one strategy is to set 15-25% of the total training steps for a given model (This is related to LAMB's variance/second moment term and scaling coefficient. See detailed analysis in our [paper](https://arxiv.org/abs/2104.06069)). If it provides the desired outcome, one can try to extract more performance by reducing the steps systematically. In future, we plan to introduce a threshold that can automatically search and decide for the number of warm up steps for different models. The examples below have been tuned for the number of warm up steps. The `freeze_step` parameter has already been set to the best number we found in the corresponding run scripts. + +`cuda_aware` is used for MPI-based implementation to indicate that the underlying MPI library supports CUDA-Aware communication. This feature is only supported on systems with InfiniBand interconnect and a CUDA-Aware MPI library like [MVAPICH2-GDR](http://mvapich.cse.ohio-state.edu/userguide/gdr/) or OpenMPI built with CUDA-Aware support. Setting `cuda_aware` to False will allow training on Ethernet based systems. However, the communication will happen using sender as well as receiver side memory copies between CPU and GPU buffers before and after communication. + +`comm_backend_name` is used to indicate which backend implementation to use. You can choose between NCCL and MPI-based implementations by setting `comm_backend_name` to "nccl" or "mpi". When using NCCL-based implementation, there is no need to set `cuda_aware`. + +`coeff_beta` is used when calculating a moving average of the LAMB scaling coefficient during the warmup stage. This moving average is then used as the frozen base scaling coefficient during the compression stage. + +`factor_max`, `factor_min`, and `factor_threshold` are used to regularize the adaptive scaling of the frozen base scaling coefficient during the compression stage. `factor_max` and `factor_min` are the scaling factor upper/lower bound. `factor_threshold` defines the threshold of how much the scaling factor can fluctuate between steps. + +#### 1.4.1 Momentum masks for parameters with constant zero gradients +Because 1-bit compression cannot represent exact zero, the compression error would keep accumulating in the momentum if a parameter have constant zero gradients during training. For example, for BERT pre-training seq length 128, `bert.embeddings.position_embeddings.weight` has constant zeros in its gradient and momentum for row 129 to 512, because it only learns up to seq length 128 while the model supports up to seq length 512. Thus in 1-bit LAMB we added support of a momentum mask for users to specify those params that have constant exact zeros in their gradients. See [example script](https://github.com/microsoft/DeepSpeedExamples/blob/master/bing_bert/deepspeed_train.py) for how to configure this momentum mask. One thing to note is that we don't use momentum mask saved in checkpoints since this mask could change during training (e.g., BERT seqlen 128 and 512 require different masks). So you have to provide this mask every time in your training script. + +**Watch out!** +1-bit LAMB relies on an compression error compensation mechanism to maintain the convergence speed at compression stage. When loading checkpoints, we actually reset the compression errors for 3 reasons: 1) The worker and server error at each GPU are distinct, so in current implementation only rank 0's errors are saved in the checkpoint. Thus we have to reset the errors. If we want to save them correctly we need O(num_gpu*model_size) memory in order to gather all the error, which is a very large memory requirement. It's possible to save them in a distributed way, but it will make the checkpoint saving/loading much more complicated. 2) Even if we are able to save the compression errors correctly, you need to have the exact same number of GPUs in order to load them correctly. 3) We verified on BERT pre-training that occasionally resetting the compression error at checkpoint loading does not affect the convergence. However, please avoid frequent checkpoint loading which could break the error compensation mechanism thus affect the convergence. +{: .notice--warning} + +## 2. BERT Pre-training with 1-bit LAMB +For data downloading and pre-processing, please refer to the [BERT Pre-training tutorial](/tutorials/bert-pretraining/). + +### 2.1 Running Pre-training with DeepSpeed and 1-bit LAMB + +We provide example scripts under [DeepSpeedExamples/bing_bert/1-bit_lamb/](https://github.com/microsoft/DeepSpeedExamples/tree/master/bing_bert/1-bit_lamb). There are 3 sets of scripts corresponding to NCCL-based implementation, MPI-based implementation on Ethernet systems, and MPI-based implementation on InfiniBand systems. For MPI-based implementation, we provide both example scripts when launching with deepspeed or mpirun. + +### 2.2 Configuration for BERT Pre-training with DeepSpeed and 1-bit LAMB enabled + +The `deepspeed_bsz64k_onebitlamb_config_seq128_*.json` and `deepspeed_bsz32k_onebitlamb_config_seq512_*.json` files give the user the ability to specify DeepSpeed +options in terms of batch size, micro batch size, optimizer, learning rate, and other parameters. In these files we include the tuned hyperparameters to reproduce experiments in our [paper](https://arxiv.org/abs/2104.06069). + +### 2.3 Performance Results for BERT Pre-training + +Performance results can be seen in our [paper](https://arxiv.org/abs/2104.06069). diff --git a/docs/code-docs/source/optimizers.rst b/docs/code-docs/source/optimizers.rst index 04416486d9545d51b0fe591bed99ff6cac93765f..fda69e0677eba1066a557311b0006ecf099e3c52 100755 --- a/docs/code-docs/source/optimizers.rst +++ b/docs/code-docs/source/optimizers.rst @@ -1,27 +1,24 @@ -Optimizers -========== - -DeepSpeed offers high-performance implementations of ``Adam`` optimizer on CPU; ``FusedAdam``, ``FusedAdam``, ``OneBitAdam`` optimizers on GPU. - -Adam (CPU) ----------- - -.. autoclass:: deepspeed.ops.adam.DeepSpeedCPUAdam - - -FusedAdam (GPU) ---------------- - -.. autoclass:: deepspeed.ops.adam.FusedAdam - - -FusedLamb (GPU) ---------------- - -.. autoclass:: deepspeed.ops.lamb.FusedLamb - - -OneBitAdam (GPU) ----------------- - -.. autoclass:: deepspeed.runtime.fp16.onebit.adam.OnebitAdam +Optimizers +=================== + +DeepSpeed offers high-performance implementations of ``Adam`` optimizer on CPU; ``FusedAdam``, ``FusedLamb``, ``OnebitAdam``, ``OnebitLamb`` optimizers on GPU. + +Adam (CPU) +---------------------------- +.. autoclass:: deepspeed.ops.adam.DeepSpeedCPUAdam + +FusedAdam (GPU) +---------------------------- +.. autoclass:: deepspeed.ops.adam.FusedAdam + +FusedLamb (GPU) +---------------------------- +.. autoclass:: deepspeed.ops.lamb.FusedLamb + +OneBitAdam (GPU) +---------------------------- +.. autoclass:: deepspeed.runtime.fp16.onebit.adam.OnebitAdam + +OnebitLamb (GPU) +---------------------------- +.. autoclass:: deepspeed.runtime.fp16.onebit.lamb.OnebitLamb diff --git a/docs/index.md b/docs/index.md index ab6b1a0445d86687e98bd7af614b180fb77d8bc8..9d60ed6e129859c1d66bca440f7cc4499e743b27 100755 --- a/docs/index.md +++ b/docs/index.md @@ -17,7 +17,7 @@ DeepSpeed delivers extreme-scale model training for everyone, from data scientis * Extreme scale: Using current generation of GPU clusters with hundreds of devices, 3D parallelism of DeepSpeed can efficiently train deep learning models with trillions of parameters. * Extremely memory efficient: With just a single GPU, ZeRO-Offload of DeepSpeed can train models with over 10B parameters, 10x bigger than the state of arts, democratizing multi-billion-parameter model training such that many deep learning scientists can explore bigger and better models. * Extremely long sequence length: Sparse attention of DeepSpeed powers an order-of-magnitude longer input sequence and obtains up to 6x faster execution comparing with dense transformers. -* Extremely communication efficient: 3D parallelism improves communication efficiency allows users to train multi-billion-parameter models 2–7x faster on clusters with limited network bandwidth. 1-bit Adam reduces communication volume by up to 5x while achieving similar convergence efficiency to Adam, allowing for scaling to different types of GPU clusters and networks. +* Extremely communication efficient: 3D parallelism improves communication efficiency allows users to train multi-billion-parameter models 2–7x faster on clusters with limited network bandwidth. 1-bit Adam/1-bit LAMB reduce communication volume by up to 5x while achieving similar convergence efficiency to Adam/LAMB, allowing for scaling to different types of GPU clusters and networks. Early adopters of DeepSpeed have already produced a language model (LM) with over 17B parameters called @@ -30,6 +30,7 @@ initiative to enable next-generation AI capabilities at scale, where you can fin information [here](https://innovation.microsoft.com/en-us/exploring-ai-at-scale). # What's New? +* [2021/04/20] [1-bit LAMB: up to 4.6x less communication and 2.8x faster training, together with LAMB's convergence speed at large batch sizes](https://www.deepspeed.ai/tutorials/onebit-lamb/) * [2021/04/19] [ZeRO-Infinity unlocks unprecedented model scale for deep learning training](https://www.microsoft.com/en-us/research/blog/zero-infinity-and-deepspeed-unlocking-unprecedented-model-scale-for-deep-learning-training/) * [Tutorial on how to use different stages of ZeRO](https://www.deepspeed.ai/tutorials/zero/) * [2021/04/02] [[DeepSpeed on AzureML] Transformers and CIFAR examples are now available on AzureML GitHub](https://github.com/Azure/azureml-examples/tree/main/workflows/train/deepspeed) @@ -134,7 +135,7 @@ combinations, which we call 3D parallelism. Pipeline parallelism of DeepSpeed reduce communication volume during distributed training, which allows users to train multi-billion-parameter models 2–7x faster on clusters with limited network bandwidth. ![Low-bandwidth GPT-2 Performance](/assets/images/pp-lowbw-gpt2.png) -1-bit Adam reduces communication volume by up to 5x while achieving similar convergence efficiency to Adam, allowing for scaling to different types of GPU clusters and networks. [Read more here](https://www.deepspeed.ai/news/2020/09/08/onebit-adam-blog-post.html). +1-bit Adam and 1-bit LAMB reduce communication volume by up to 5x while achieving similar convergence efficiency to Adam, allowing for scaling to different types of GPU clusters and networks. [1-bit Adam blog post](https://www.deepspeed.ai/news/2020/09/08/onebit-adam-blog-post.html), [1-bit Adam tutorial](https://www.deepspeed.ai/tutorials/onebit-adam/), [1-bit LAMB tutorial](https://www.deepspeed.ai/tutorials/onebit-lamb/). ## Supporting long sequence length DeepSpeed offers sparse attention kernels—an instrumental technology to support long sequences of model inputs, whether for text, image, or sound. Compared with the classic dense Transformers, it powers **an order-of-magnitude longer input sequence** and obtains up to 6x faster execution with comparable accuracy. It also outperforms state-of-the-art sparse implementations with 1.5–3x faster execution. Furthermore, our sparse kernels support efficient execution of flexible sparse format and empower users to innovate on their custom sparse structures. [Read more here](https://www.deepspeed.ai/news/2020/09/08/sparse-attention.html). @@ -178,7 +179,7 @@ Below we provide a brief feature list, see our detailed [feature overview](https * Memory- and compute-efficient sparse kernels * Support 10x long sequences than dense * Flexible support to different sparse structures -* [1-bit Adam](https://www.deepspeed.ai/news/2020/09/08/onebit-adam-blog-post.html) +* [1-bit Adam](https://www.deepspeed.ai/news/2020/09/08/onebit-adam-blog-post.html) and [1-bit LAMB](https://www.deepspeed.ai/tutorials/onebit-lamb/) * Custom communication collective * Up to 5x communication volume saving * [Additional Memory and Bandwidth Optimizations](https://www.deepspeed.ai/features/#additional-memory-and-bandwidth-optimizations) @@ -235,6 +236,7 @@ comments. 4. Jie Ren, Samyam Rajbhandari, Reza Yazdani Aminabadi, Olatunji Ruwase, Shuangyan Yang, Minjia Zhang, Dong Li, Yuxiong He. (2021) ZeRO-Offload: Democratizing Billion-Scale Model Training. [arXiv:2101.06840](https://arxiv.org/abs/2101.06840). 5. Hanlin Tang, Shaoduo Gan, Ammar Ahmad Awan, Samyam Rajbhandari, Conglong Li, Xiangru Lian, Ji Liu, Ce Zhang, Yuxiong He. (2021) 1-bit Adam: Communication Efficient Large-Scale Training with Adam's Convergence Speed. [arXiv:2102.02888](https://arxiv.org/abs/2102.02888). 6. Samyam Rajbhandari, Olatunji Ruwase, Jeff Rasley, Shaden Smith, Yuxiong He. (2021) ZeRO-Infinity: Breaking the GPU Memory Wall for Extreme Scale Deep Learning. [arXiv:2104.07857](https://arxiv.org/abs/2104.07857). +7. Conglong Li, Ammar Ahmad Awan, Hanlin Tang, Samyam Rajbhandari, Yuxiong He. (2021) 1-bit LAMB: Communication Efficient Large-Scale Large-Batch Training with LAMB's Convergence Speed. [arXiv:2104.06069](https://arxiv.org/abs/2104.06069). # Videos 1. DeepSpeed KDD 2020 Tutorial diff --git a/tests/unit/test_onebit.py b/tests/unit/test_onebit.py index 8e0056be0cff9d70871237b555477b4e3244e26b..9796a70953f8bd7b483123e93c82efa8088f6bf5 100644 --- a/tests/unit/test_onebit.py +++ b/tests/unit/test_onebit.py @@ -1,14 +1,22 @@ import torch +import torch.nn as nn +import torch.nn.functional as F import torch.distributed as dist import deepspeed import argparse import pytest +import copy import json import os import numpy as np import time + +from deepspeed.runtime.pipe.topology import PipeDataParallelTopology, PipeModelDataParallelTopology +PipeTopo = PipeDataParallelTopology +from deepspeed.runtime.pipe.module import PipelineModule, LayerSpec from common import distributed_test from simple_model import SimpleModel, SimpleOptimizer, random_dataloader, args_from_dict, create_deepspeed_args +from test_pipe import AlexNetPipe, train_cifar TORCH_MAJOR = int(torch.__version__.split('.')[0]) TORCH_MINOR = int(torch.__version__.split('.')[1]) @@ -241,9 +249,7 @@ def test_onebitadam_checkpointing(tmpdir): mask1 = mask1.to(device=optimizer_1.param_groups[0]['exp_avg_mask'].device) assert torch.allclose(optimizer_1.param_groups[0]['exp_avg_mask'], mask1, atol=1e-07), f"Incorrect momentum mask" save_folder = os.path.join(tmpdir, 'saved_checkpoint') - # optimizer_1.optimizer.gather_compression_errors() model_1.save_checkpoint(save_folder, tag=None) - time.sleep(5) assert torch.allclose(optimizer_1.param_groups[0]['exp_avg_mask'], mask1, atol=1e-07), f"Momentum mask should not change after saving checkpoint" @@ -297,6 +303,552 @@ def test_onebitadam_checkpointing(tmpdir): hidden_dim=hidden_dim) +def test_onebitadam_checkpointing_overflow(tmpdir): + config_dict = { + "train_batch_size": 2, + "steps_per_print": 1, + "optimizer": { + "type": "OneBitAdam", + "params": { + "lr": 0.00015, + "weight_decay": 0.01, + "freeze_step": 2, + "cuda_aware": False, + "comm_backend_name": "nccl" + } + }, + "gradient_clipping": 1.0, + "fp16": { + "enabled": True, + "loss_scale": 0, + "initial_scale_power": 16 + } + } + args = args_from_dict(tmpdir, config_dict) + hidden_dim = 10 + + model = SimpleModel(hidden_dim) + + @distributed_test(world_size=[2]) + def _test_onebitadam_checkpointing_overflow(args, model, hidden_dim): + model, _, _, _ = deepspeed.initialize(args=args, + model=model, + model_parameters=model.parameters()) + data_loader = random_dataloader(model=model, + total_samples=100, + hidden_dim=hidden_dim, + device=model.device) + save_folder = os.path.join(tmpdir, 'saved_checkpoint') + for n, batch in enumerate(data_loader): + loss = model(batch[0], batch[1]) + if dist.get_rank() == 0 and n >= 10: + loss = loss * 1000000.0 + model.backward(loss) + dist.barrier() + model.step() + dist.barrier() + model.save_checkpoint(save_folder, tag=None) + + _test_onebitadam_checkpointing_overflow(args=args, + model=model, + hidden_dim=hidden_dim) + + +@pytest.mark.parametrize('topo', + [ + PipeTopo(num_pp=1, + num_dp=4), + PipeTopo(num_pp=2, + num_dp=2), + PipeTopo(num_pp=4, + num_dp=1), + ]) +def test_onebitadam_fp16_pipeline(topo, tmpdir): + config_dict = { + "train_batch_size": 16, + "train_micro_batch_size_per_gpu": 4, + "steps_per_print": 20, + "optimizer": { + "type": "OneBitAdam", + "params": { + "lr": 0.00001, + "betas": [0.9, + 0.999], + "eps": 1e-8, + "weight_decay": 3e-7, + "freeze_step": 200, + "cuda_aware": False, + "comm_backend_name": "nccl" + } + }, + "gradient_clipping": 1.0, + "zero_optimization": { + "stage": 0 + }, + "fp16": { + "enabled": True, + "loss_scale": 0, + "initial_scale_power": 16 + }, + "pipeline": { + "seed_layers": True, + "activation_checkpoint_interval": 1 + } + } + args = args_from_dict(tmpdir, config_dict) + + # Allocate model for consistent initial weights. + init_net = AlexNetPipe() + + @distributed_test(world_size=4) + def _helper(topo, tmpdir, steps=500): + assert steps >= 100 + + test_net = copy.deepcopy(init_net) + test_model = PipelineModule(layers=test_net.to_layers(), + topology=topo, + loss_fn=nn.CrossEntropyLoss()) + + test_losses = train_cifar(test_model, + args, + num_steps=steps, + fp16=config_dict['fp16']['enabled']) + + _helper(topo, tmpdir) + + +def test_onebitlamb_fp16_basic(tmpdir): + config_dict = { + "train_batch_size": 2, + "steps_per_print": 1, + "optimizer": { + "type": "OneBitLamb", + "params": { + "lr": 0.00015, + "weight_decay": 0.01, + "max_coeff": 0.3, + "min_coeff": 0.01, + "freeze_step": 2, + "cuda_aware": False, + "comm_backend_name": "nccl", + "coeff_beta": 0.9, + "factor_max": 1.0, + "factor_min": 0.5, + "factor_threshold": 0.1 + } + }, + "gradient_clipping": 1.0, + "fp16": { + "enabled": True, + "loss_scale": 0, + "initial_scale_power": 16 + } + } + args = args_from_dict(tmpdir, config_dict) + hidden_dim = 10 + + model = SimpleModel(hidden_dim) + + @distributed_test(world_size=[1, 2]) + def _test_onebitlamb_fp16_basic(args, model, hidden_dim): + model, _, _, _ = deepspeed.initialize(args=args, + model=model, + model_parameters=model.parameters()) + data_loader = random_dataloader(model=model, + total_samples=50, + hidden_dim=hidden_dim, + device=model.device) + for n, batch in enumerate(data_loader): + loss = model(batch[0], batch[1]) + model.backward(loss) + model.step() + + _test_onebitlamb_fp16_basic(args=args, model=model, hidden_dim=hidden_dim) + + +def test_onebitlamb_fp32_basic(tmpdir): + config_dict = { + "train_batch_size": 2, + "steps_per_print": 1, + "optimizer": { + "type": "OneBitLamb", + "params": { + "lr": 0.00015, + "weight_decay": 0.01, + "max_coeff": 0.3, + "min_coeff": 0.01, + "freeze_step": 2, + "cuda_aware": False, + "comm_backend_name": "nccl", + "coeff_beta": 0.9, + "factor_max": 1.0, + "factor_min": 0.5, + "factor_threshold": 0.1 + } + }, + "gradient_clipping": 1.0, + } + args = args_from_dict(tmpdir, config_dict) + hidden_dim = 10 + + model = SimpleModel(hidden_dim) + + @distributed_test(world_size=[1, 2]) + def _test_onebitlamb_fp32_basic(args, model, hidden_dim): + model, _, _, _ = deepspeed.initialize(args=args, + model=model, + model_parameters=model.parameters()) + data_loader = random_dataloader(model=model, + total_samples=50, + hidden_dim=hidden_dim, + device=model.device, + dtype=torch.float) + for n, batch in enumerate(data_loader): + loss = model(batch[0], batch[1]) + model.backward(loss) + model.step() + + _test_onebitlamb_fp32_basic(args=args, model=model, hidden_dim=hidden_dim) + + +def test_onebitlamb_exp_avg_mask(tmpdir): + config_dict = { + "train_batch_size": 2, + "steps_per_print": 1, + "optimizer": { + "type": "OneBitLamb", + "params": { + "lr": 0.00015, + "weight_decay": 0.01, + "max_coeff": 0.3, + "min_coeff": 0.01, + "freeze_step": 2, + "cuda_aware": False, + "comm_backend_name": "nccl", + "coeff_beta": 0.9, + "factor_max": 1.0, + "factor_min": 0.5, + "factor_threshold": 0.1 + } + }, + "gradient_clipping": 1.0, + "fp16": { + "enabled": True, + "loss_scale": 0, + "initial_scale_power": 16 + } + } + args = args_from_dict(tmpdir, config_dict) + hidden_dim = 10 + + model = SimpleModel(hidden_dim) + param_optimizer = list(model.named_parameters()) + mask1 = torch.zeros_like(param_optimizer[0][1].data) + for col in range(mask1.size()[1]): + mask1[0][col] += 1 + optimizer_grouped_parameters = [{ + 'params': [param_optimizer[0][1]], + 'weight_decay': 0.01, + 'exp_avg_mask': mask1 + }, + { + 'params': [param_optimizer[1][1]], + 'weight_decay': 0.01 + }] + + @distributed_test(world_size=[2]) + def _test_onebitlamb_exp_avg_mask(args, model, hidden_dim): + model, optimizer, _, _ = deepspeed.initialize(args=args, + model=model, + model_parameters=optimizer_grouped_parameters) + data_loader = random_dataloader(model=model, + total_samples=50, + hidden_dim=hidden_dim, + device=model.device) + for n, batch in enumerate(data_loader): + loss = model(batch[0], batch[1]) + model.backward(loss) + model.step() + # Test whether the momentum mask works + for v in optimizer.state.values(): + if v['exp_avg'].size() == mask1.size(): + assert torch.allclose(v['exp_avg'], v['exp_avg'].mul_(mask1.to(device=v['exp_avg'].device)), atol=1e-07), f"Momentum mask is not working properly" + + _test_onebitlamb_exp_avg_mask(args=args, model=model, hidden_dim=hidden_dim) + + +def test_onebitlamb_checkpointing(tmpdir): + config_dict = { + "train_batch_size": 2, + "steps_per_print": 1, + "optimizer": { + "type": "OneBitLamb", + "params": { + "lr": 0.00015, + "weight_decay": 0.01, + "max_coeff": 0.3, + "min_coeff": 0.01, + "freeze_step": 2, + "cuda_aware": False, + "comm_backend_name": "nccl", + "coeff_beta": 0.9, + "factor_max": 1.0, + "factor_min": 0.5, + "factor_threshold": 0.1 + } + }, + "gradient_clipping": 1.0, + "fp16": { + "enabled": True, + "loss_scale": 0, + "initial_scale_power": 16 + } + } + args = args_from_dict(tmpdir, config_dict) + hidden_dim = 10 + + model = SimpleModel(hidden_dim) + param_optimizer = list(model.named_parameters()) + mask1 = torch.zeros_like(param_optimizer[0][1].data) + mask2 = torch.zeros_like(param_optimizer[0][1].data) + for col in range(mask1.size()[1]): + mask1[0][col] += 1 + mask2[1][col] += 1 + + optimizer_grouped_parameters_1 = [{ + 'params': [param_optimizer[0][1]], + 'weight_decay': 0.01, + 'exp_avg_mask': mask1 + }, + { + 'params': [param_optimizer[1][1]], + 'weight_decay': 0.01 + }] + + optimizer_grouped_parameters_2 = [{ + 'params': [param_optimizer[0][1]], + 'weight_decay': 0.01, + 'exp_avg_mask': mask2 + }, + { + 'params': [param_optimizer[1][1]], + 'weight_decay': 0.01 + }] + + optimizer_grouped_parameters_3 = [{ + 'params': [param_optimizer[0][1]], + 'weight_decay': 0.01 + }, + { + 'params': [param_optimizer[1][1]], + 'weight_decay': 0.01 + }] + + @distributed_test(world_size=[2]) + def _test_onebitlamb_checkpointing(mask1, mask2, args, model, hidden_dim): + model_1, optimizer_1, _, _ = deepspeed.initialize(args=args, + model=model, + model_parameters=optimizer_grouped_parameters_1) + data_loader = random_dataloader(model=model_1, + total_samples=10, + hidden_dim=hidden_dim, + device=model_1.device) + for n, batch in enumerate(data_loader): + loss = model_1(batch[0], batch[1]) + model_1.backward(loss) + model_1.step() + # Test whether momentum mask still exist after saving checkpoint + assert optimizer_1.optimizer.lamb_freeze_key is True + mask1 = mask1.to(device=optimizer_1.param_groups[0]['exp_avg_mask'].device) + assert torch.allclose(optimizer_1.param_groups[0]['exp_avg_mask'], mask1, atol=1e-07), f"Incorrect momentum mask" + scaling_coeff_1 = [] + for v in optimizer_1.state.values(): + assert 'scaling_coeff' in v, f"Incorrect scaling_coeff" + scaling_coeff_1.append(v['scaling_coeff']) + save_folder = os.path.join(tmpdir, 'saved_checkpoint') + model_1.save_checkpoint(save_folder, tag=None) + assert torch.allclose(optimizer_1.param_groups[0]['exp_avg_mask'], mask1, atol=1e-07), f"Momentum mask should not change after saving checkpoint" + + + model_2, optimizer_2, _, _ = deepspeed.initialize(args=args, + model=model, + model_parameters=optimizer_grouped_parameters_2) + # Test whether momentum mask stays the same after loading checkpoint + mask2 = mask2.to(device=optimizer_2.param_groups[0]['exp_avg_mask'].device) + assert torch.allclose(optimizer_2.param_groups[0]['exp_avg_mask'], mask2, atol=1e-07), f"Incorrect momentum mask" + model_2.load_checkpoint(save_folder, + tag=None, + load_optimizer_states=True, + load_lr_scheduler_states=True) + assert torch.allclose(optimizer_2.param_groups[0]['exp_avg_mask'], mask2, atol=1e-07), f"Momentum mask should not change after loading checkpoint" + # Test whether worker&server error is resetted + assert len(optimizer_2.optimizer.worker_errors) == 0, f"Incorrect worker error" + assert len(optimizer_2.optimizer.server_errors) == 0, f"Incorrect server error" + # Test whether scaling_coeffs is loaded correctly + scaling_coeff_2 = [] + for v in optimizer_2.state.values(): + assert 'scaling_coeff' in v, f"Incorrect scaling_coeff" + scaling_coeff_2.append(v['scaling_coeff']) + assert list(sorted(scaling_coeff_2)) == list(sorted(scaling_coeff_1)), f"Incorrect scaling_coeffs" + assert optimizer_2.optimizer.lamb_freeze_key is True + + model_3, optimizer_3, _, _ = deepspeed.initialize(args=args, + model=model, + model_parameters=optimizer_grouped_parameters_3) + optimizer_3.optimizer.freeze_step = 20 + data_loader = random_dataloader(model=model_3, + total_samples=50, + hidden_dim=hidden_dim, + device=model_3.device) + for n, batch in enumerate(data_loader): + loss = model_3(batch[0], batch[1]) + model_3.backward(loss) + model_3.step() + assert optimizer_3.optimizer.lamb_freeze_key is True + # Test whether momentum mask stays the same after loading checkpoint + assert 'exp_avg_mask' not in optimizer_3.param_groups[0], f"Incorrect momentum mask" + model_3.load_checkpoint(save_folder, + tag=None, + load_optimizer_states=True, + load_lr_scheduler_states=True) + assert 'exp_avg_mask' not in optimizer_3.param_groups[0], f"Momentum mask should not change after loading checkpoint" + # Test whether worker&server error is resetted + assert len(optimizer_3.optimizer.worker_errors) == 0, f"Incorrect worker error" + assert len(optimizer_3.optimizer.server_errors) == 0, f"Incorrect server error" + # Test whether scaling_coeffs, lamb_coeff_freeze, last_factor are resetted + for v in optimizer_3.state.values(): + assert v['lamb_coeff_freeze'] == 0.0, f"Incorrect lamb_coeff_freeze" + assert v['last_factor'] == 1.0, f"Incorrect last_factor" + assert 'scaling_coeff' not in v, f"Incorrect scaling_coeff" + assert optimizer_3.optimizer.lamb_freeze_key is False + + _test_onebitlamb_checkpointing(mask1, + mask2, + args=args, + model=model, + hidden_dim=hidden_dim) + + +def test_onebitlamb_checkpointing_overflow(tmpdir): + config_dict = { + "train_batch_size": 2, + "steps_per_print": 1, + "optimizer": { + "type": "OneBitLamb", + "params": { + "lr": 0.00015, + "weight_decay": 0.01, + "max_coeff": 0.3, + "min_coeff": 0.01, + "freeze_step": 2, + "cuda_aware": False, + "comm_backend_name": "nccl", + "coeff_beta": 0.9, + "factor_max": 1.0, + "factor_min": 0.5, + "factor_threshold": 0.1 + } + }, + "gradient_clipping": 1.0, + "fp16": { + "enabled": True, + "loss_scale": 0, + "initial_scale_power": 16 + } + } + args = args_from_dict(tmpdir, config_dict) + hidden_dim = 10 + + model = SimpleModel(hidden_dim) + + @distributed_test(world_size=[2]) + def _test_onebitlamb_checkpointing_overflow(args, model, hidden_dim): + model, _, _, _ = deepspeed.initialize(args=args, + model=model, + model_parameters=model.parameters()) + data_loader = random_dataloader(model=model, + total_samples=100, + hidden_dim=hidden_dim, + device=model.device) + save_folder = os.path.join(tmpdir, 'saved_checkpoint') + for n, batch in enumerate(data_loader): + loss = model(batch[0], batch[1]) + if dist.get_rank() == 0 and n >= 10: + loss = loss * 1000000.0 + model.backward(loss) + dist.barrier() + model.step() + dist.barrier() + model.save_checkpoint(save_folder, tag=None) + + _test_onebitlamb_checkpointing_overflow(args=args, + model=model, + hidden_dim=hidden_dim) + + +@pytest.mark.parametrize('topo', + [ + PipeTopo(num_pp=1, + num_dp=4), + PipeTopo(num_pp=2, + num_dp=2), + PipeTopo(num_pp=4, + num_dp=1), + ]) +def test_onebitlamb_fp16_pipeline(topo, tmpdir): + config_dict = { + "train_batch_size": 16, + "train_micro_batch_size_per_gpu": 4, + "steps_per_print": 20, + "optimizer": { + "type": "OneBitLamb", + "params": { + "lr": 0.00001, + "betas": [0.9, + 0.999], + "eps": 1e-8, + "weight_decay": 3e-7, + "freeze_step": 200, + "cuda_aware": False, + "comm_backend_name": "nccl" + } + }, + "gradient_clipping": 1.0, + "zero_optimization": { + "stage": 0 + }, + "fp16": { + "enabled": True, + "loss_scale": 0, + "initial_scale_power": 16 + }, + "pipeline": { + "seed_layers": True, + "activation_checkpoint_interval": 1 + } + } + args = args_from_dict(tmpdir, config_dict) + + # Allocate model for consistent initial weights. + init_net = AlexNetPipe() + + @distributed_test(world_size=4) + def _helper(topo, tmpdir, steps=500): + assert steps >= 100 + + test_net = copy.deepcopy(init_net) + test_model = PipelineModule(layers=test_net.to_layers(), + topology=topo, + loss_fn=nn.CrossEntropyLoss()) + + test_losses = train_cifar(test_model, + args, + num_steps=steps, + fp16=config_dict['fp16']['enabled']) + + _helper(topo, tmpdir) + + def test_compressed_allreduce_basic(tmpdir): @distributed_test(world_size=[1, 2]) def _test_compressed_allreduce_basic():