未验证 提交 67a48aaa 编写于 作者: C Conglong Li 提交者: GitHub

1-bit LAMB optimizer (#970)

1-bit LAMB: Communication Efficient Large-Scale Large-Batch Training with LAMB's Convergence Speed.
Author: @conglongli, @awan-10, @samyam, Hanlin Tang, Yuxiong He
Paper: https://arxiv.org/abs/2104.06069Co-authored-by: Nsdtblck <46172032+sdtblck@users.noreply.github.com>
Co-authored-by: NJeff Rasley <jerasley@microsoft.com>
上级 eecef309
......@@ -17,7 +17,7 @@ DeepSpeed delivers extreme-scale model training for everyone, from data scientis
* Extreme scale: Using current generation of GPU clusters with hundreds of devices, 3D parallelism of DeepSpeed can efficiently train deep learning models with trillions of parameters.
* Extremely memory efficient: With just a single GPU, ZeRO-Offload of DeepSpeed can train models with over 10B parameters, 10x bigger than the state of arts, democratizing multi-billion-parameter model training such that many deep learning scientists can explore bigger and better models.
* Extremely long sequence length: Sparse attention of DeepSpeed powers an order-of-magnitude longer input sequence and obtains up to 6x faster execution comparing with dense transformers.
* Extremely communication efficient: 3D parallelism improves communication efficiency allows users to train multi-billion-parameter models 2–7x faster on clusters with limited network bandwidth. 1-bit Adam reduces communication volume by up to 5x while achieving similar convergence efficiency to Adam, allowing for scaling to different types of GPU clusters and networks.
* Extremely communication efficient: 3D parallelism improves communication efficiency allows users to train multi-billion-parameter models 2–7x faster on clusters with limited network bandwidth. 1-bit Adam/1-bit LAMB reduce communication volume by up to 5x while achieving similar convergence efficiency to Adam/LAMB, allowing for scaling to different types of GPU clusters and networks.
Early adopters of DeepSpeed have already produced
a language model (LM) with over 17B parameters called
......@@ -33,6 +33,7 @@ information [here](https://innovation.microsoft.com/en-us/exploring-ai-at-scale)
# News
* [2021/04/20] [1-bit LAMB: up to 4.6x less communication and 2.8x faster training, together with LAMB's convergence speed at large batch sizes](https://www.deepspeed.ai/tutorials/onebit-lamb/)
* [2021/04/19] [ZeRO-Infinity unlocks unprecedented model scale for deep learning training](https://www.microsoft.com/en-us/research/blog/zero-infinity-and-deepspeed-unlocking-unprecedented-model-scale-for-deep-learning-training/)
* [Tutorial on how to use different stages of ZeRO](https://www.deepspeed.ai/tutorials/zero/)
* [2021/04/01] [[DeepSpeed on AzureML] Transformers and CIFAR examples are now available on AzureML GitHub](https://github.com/Azure/azureml-examples/tree/main/workflows/train/deepspeed)
......@@ -119,7 +120,7 @@ overview](https://www.deepspeed.ai/features/) for descriptions and usage.
* Memory- and compute-efficient sparse kernels
* Support 10x longer sequences than dense
* Flexible support to different sparse structures
* [1-bit Adam](https://www.deepspeed.ai/news/2020/09/08/onebit-adam-blog-post.html)
* [1-bit Adam](https://www.deepspeed.ai/news/2020/09/08/onebit-adam-blog-post.html) and [1-bit LAMB](https://www.deepspeed.ai/tutorials/onebit-lamb/)
* Custom communication collective
* Up to 5x communication volume saving
* [Additional Memory and Bandwidth Optimizations](https://www.deepspeed.ai/features/#additional-memory-and-bandwidth-optimizations)
......@@ -192,7 +193,7 @@ Conduct](https://opensource.microsoft.com/codeofconduct/). For more information
4. Jie Ren, Samyam Rajbhandari, Reza Yazdani Aminabadi, Olatunji Ruwase, Shuangyan Yang, Minjia Zhang, Dong Li, Yuxiong He. (2021) ZeRO-Offload: Democratizing Billion-Scale Model Training. [arXiv:2101.06840](https://arxiv.org/abs/2101.06840).
5. Hanlin Tang, Shaoduo Gan, Ammar Ahmad Awan, Samyam Rajbhandari, Conglong Li, Xiangru Lian, Ji Liu, Ce Zhang, Yuxiong He. (2021) 1-bit Adam: Communication Efficient Large-Scale Training with Adam's Convergence Speed. [arXiv:2102.02888](https://arxiv.org/abs/2102.02888).
6. Samyam Rajbhandari, Olatunji Ruwase, Jeff Rasley, Shaden Smith, Yuxiong He. (2021) ZeRO-Infinity: Breaking the GPU Memory Wall for Extreme Scale Deep Learning. [arXiv:2104.07857](https://arxiv.org/abs/2104.07857).
7. Conglong Li, Ammar Ahmad Awan, Hanlin Tang, Samyam Rajbhandari, Yuxiong He. (2021) 1-bit LAMB: Communication Efficient Large-Scale Large-Batch Training with LAMB's Convergence Speed. [arXiv:2104.06069](https://arxiv.org/abs/2104.06069).
# Videos
1. DeepSpeed KDD 2020 Tutorial
......
......@@ -12,8 +12,12 @@ from deepspeed.runtime.compression.cupy import CupyBackend
class NcclBackend(object):
def __init__(self):
self.world_group = dist.new_group(ranks=range(dist.get_world_size()))
def __init__(self, mpu=None):
if mpu is None:
self.world_group = dist.new_group(ranks=range(dist.get_world_size()))
else:
self.mpu = mpu
self.world_group = self.mpu.get_data_parallel_group()
self.rank = dist.get_rank(group=self.world_group)
self.size = dist.get_world_size(group=self.world_group)
self.compression_backend = CupyBackend()
......@@ -92,9 +96,11 @@ class NcclBackend(object):
# communication phase 1
# gather_start = time.time()
# Alltoall for sign
dist.all_to_all_single(recvbuf_sign, torch.stack(sign_list_packed))
dist.all_to_all_single(recvbuf_sign,
torch.stack(sign_list_packed),
group=self.world_group)
# Allgather for scale
dist.all_gather(recvbuf_scale, worker_scale)
dist.all_gather(recvbuf_scale, worker_scale, group=self.world_group)
# gather_end = time.time()
......@@ -151,8 +157,10 @@ class NcclBackend(object):
]
# Communication Phase 2
dist.all_gather(recvbuf_sign_server, server_sign_packed[0])
dist.all_gather(recvbuf_scale_server, server_scale)
dist.all_gather(recvbuf_sign_server,
server_sign_packed[0],
group=self.world_group)
dist.all_gather(recvbuf_scale_server, server_scale, group=self.world_group)
cupy_server_sign_packed = None
......
......@@ -32,11 +32,13 @@ ADAM_OPTIMIZER = 'adam'
ADAMW_OPTIMIZER = 'adamw'
LAMB_OPTIMIZER = 'lamb'
ONEBIT_ADAM_OPTIMIZER = 'onebitadam'
ONEBIT_LAMB_OPTIMIZER = 'onebitlamb'
DEEPSPEED_OPTIMIZERS = [
ADAM_OPTIMIZER,
ADAMW_OPTIMIZER,
LAMB_OPTIMIZER,
ONEBIT_ADAM_OPTIMIZER,
ONEBIT_LAMB_OPTIMIZER,
]
# extra optimizer parameters for adam/adamw
......
......@@ -24,7 +24,7 @@ from deepspeed.runtime.activation_checkpointing import checkpointing as activati
from deepspeed.runtime.fp16.fused_optimizer import FP16_Optimizer
from deepspeed.runtime.fp16.unfused_optimizer import FP16_UnfusedOptimizer
from deepspeed.runtime.config import DeepSpeedConfig, DEEPSPEED_OPTIMIZERS, \
ADAM_OPTIMIZER, ADAMW_OPTIMIZER, LAMB_OPTIMIZER, ONEBIT_ADAM_OPTIMIZER, \
ADAM_OPTIMIZER, ADAMW_OPTIMIZER, LAMB_OPTIMIZER, ONEBIT_ADAM_OPTIMIZER, ONEBIT_LAMB_OPTIMIZER, \
TORCH_ADAM_PARAM, ADAM_W_MODE, ADAM_W_MODE_DEFAULT
from deepspeed.runtime.dataloader import DeepSpeedDataLoader
......@@ -553,7 +553,8 @@ class DeepSpeedEngine(Module):
assert self._is_supported_optimizer(self.optimizer_name()), \
'{} is not a supported DeepSpeed Optimizer'.format(self.optimizer_name())
if self.optimizer_name() == LAMB_OPTIMIZER:
if self.optimizer_name() == LAMB_OPTIMIZER or self.optimizer_name(
) == ONEBIT_LAMB_OPTIMIZER:
assert self.dynamic_loss_scale(), \
'DeepSpeed {} optimizer requires dynamic loss scaling'.format(self.optimizer_name())
......@@ -694,6 +695,13 @@ class DeepSpeedEngine(Module):
logger.warning(
f'Currently the convergence of 1-bit Adam is only verified under FP16'
)
elif self.optimizer_name() == ONEBIT_LAMB_OPTIMIZER:
from deepspeed.runtime.fp16.onebit.lamb import OnebitLamb
optimizer = OnebitLamb(model_parameters, self, **optimizer_parameters)
if not self.fp16_enabled():
logger.warning(
f'Currently the convergence of 1-bit Lamb is only verified under FP16'
)
else:
torch_optimizer = getattr(torch.optim, self.optimizer_name())
optimizer = torch_optimizer(model_parameters, **optimizer_parameters)
......@@ -710,6 +718,7 @@ class DeepSpeedEngine(Module):
timers = self.timers if self.wall_clock_breakdown() else None
optimizer = FP16_Optimizer(
optimizer,
deepspeed=self,
dynamic_loss_scale=True,
initial_dynamic_scale=initial_dynamic_scale,
dynamic_loss_args=dynamic_loss_args,
......@@ -723,6 +732,7 @@ class DeepSpeedEngine(Module):
ranks=[0])
optimizer = FP16_Optimizer(
optimizer,
deepspeed=self,
static_loss_scale=self.loss_scale(),
mpu=self.mpu,
clip_grad=clip_grad,
......@@ -732,6 +742,7 @@ class DeepSpeedEngine(Module):
ranks=[0])
optimizer = FP16_UnfusedOptimizer(
optimizer,
deepspeed=self,
static_loss_scale=self.loss_scale(),
dynamic_loss_scale=self.dynamic_loss_scale(),
dynamic_loss_args=dynamic_loss_args,
......
......@@ -22,6 +22,7 @@ class FP16_Optimizer(object):
"""
def __init__(self,
init_optimizer,
deepspeed=None,
static_loss_scale=1.0,
dynamic_loss_scale=False,
initial_dynamic_scale=2**32,
......@@ -100,7 +101,9 @@ class FP16_Optimizer(object):
self.mpu = mpu
self.overflow = False
self.overflow_checker = CheckOverflow(self.fp16_groups, mpu=self.mpu)
self.overflow_checker = CheckOverflow(self.fp16_groups,
mpu=self.mpu,
deepspeed=deepspeed)
self.initialize_optimizer_states()
def initialize_optimizer_states(self):
......
......@@ -82,6 +82,7 @@ class OnebitAdam(torch.optim.Optimizer):
self.initialize = False
self.freeze_step = freeze_step
self.cuda_aware = cuda_aware
self.using_pipeline = False
self.comm_backend_name = comm_backend_name
......@@ -94,7 +95,9 @@ class OnebitAdam(torch.optim.Optimizer):
assert TORCH_MAJOR >= 1 and TORCH_MINOR >= 8, "Please use torch 1.8 or greater to enable NCCL backend in 1-bit Adam. Alternatively, please specify 'mpi' as the 'comm_backend_name' in config file to proceed with the MPI backend"
assert dist.is_initialized() == True, "Please initialize the torch distributed backend."
from deepspeed.runtime.comm.nccl import NcclBackend
self.comm_backend_handle = NcclBackend()
self.using_pipeline = hasattr(self.deepspeed,
'pipeline_enable_backward_allreduce')
self.comm_backend_handle = NcclBackend(self.deepspeed.mpu)
elif self.comm_backend_name == 'mpi':
from deepspeed.runtime.comm.mpi import MpiBackend
......@@ -254,8 +257,12 @@ class OnebitAdam(torch.optim.Optimizer):
if self.adam_freeze_key is False:
if state['step'] >= self.freeze_step:
print('OnebitAdam - starting compressed communication')
self.adam_freeze_key = True
self.deepspeed.enable_backward_allreduce = False
if self.using_pipeline:
self.deepspeed.pipeline_enable_backward_allreduce = False
else:
self.deepspeed.enable_backward_allreduce = False
return loss
......@@ -277,18 +284,24 @@ class OnebitAdam(torch.optim.Optimizer):
super().load_state_dict(state_dict)
if self.state[self.param_groups[0]['params'][0]]['step'] < self.freeze_step:
if torch.distributed.get_rank() == 0:
print("Checkpoint loaded and 1-bit Adam warmup stage starts/continues.")
print("Checkpoint loaded and OnebitAdam warmup stage starts/continues.")
if self.adam_freeze_key is True:
self.adam_freeze_key = False
self.deepspeed.enable_backward_allreduce = True
if self.using_pipeline:
self.deepspeed.pipeline_enable_backward_allreduce = True
else:
self.deepspeed.enable_backward_allreduce = True
else:
if torch.distributed.get_rank() == 0:
print(
"Checkpoint loaded and 1-bit Adam compression stage starts/continues."
"Checkpoint loaded and OnebitAdam compression stage starts/continues."
)
if self.adam_freeze_key is False:
self.adam_freeze_key = True
self.deepspeed.enable_backward_allreduce = False
if self.using_pipeline:
self.deepspeed.pipeline_enable_backward_allreduce = False
else:
self.deepspeed.enable_backward_allreduce = False
# We reset the compression errors when loading checkpoints for 3 reasons:
# 1) The worker and server error at each GPU are distinct, so in current implementation
# only rank 0's errors are saved in the checkpoint. Thus we have to reset the errors.
......
此差异已折叠。
......@@ -22,6 +22,7 @@ class FP16_UnfusedOptimizer(object):
"""
def __init__(self,
init_optimizer,
deepspeed=None,
static_loss_scale=1.0,
dynamic_loss_scale=False,
dynamic_loss_args=None,
......@@ -96,7 +97,9 @@ class FP16_UnfusedOptimizer(object):
self.mpu = mpu
self.overflow = False
self.overflow_checker = CheckOverflow(self.fp16_groups, mpu=self.mpu)
self.overflow_checker = CheckOverflow(self.fp16_groups,
mpu=self.mpu,
deepspeed=deepspeed)
self.initialize_optimizer_states()
......
......@@ -56,6 +56,10 @@ class PipelineEngine(DeepSpeedEngine):
# We schedule the all-reduces, so disable it in super().backward()
self.enable_backward_allreduce = False
# used to disable the pipeline all-reduce when used with 1-bit Adam/1-bit LAMB
self.pipeline_enable_backward_allreduce = True
assert not self.elasticity_enabled(), "Elasticity is not currently supported" \
" with pipeline parallelism."
......@@ -222,7 +226,7 @@ class PipelineEngine(DeepSpeedEngine):
def _exec_reduce_grads(self):
self._force_grad_boundary = True
if self.is_data_parallel:
if self.is_data_parallel and self.pipeline_enable_backward_allreduce:
self.buffered_allreduce_fallback(
elements_per_buffer=MEMORY_OPT_ALLREDUCE_SIZE)
self._force_grad_boundary = False
......
......@@ -64,10 +64,15 @@ def move_to_device(item, device):
class CheckOverflow(object):
'''Checks for overflow in gradient across parallel process'''
def __init__(self, param_groups=None, mpu=None, zero_reduce_scatter=False):
def __init__(self,
param_groups=None,
mpu=None,
zero_reduce_scatter=False,
deepspeed=None):
self.mpu = mpu
self.params = [] if param_groups else None
self.zero_reduce_scatter = zero_reduce_scatter
self.deepspeed = deepspeed
if param_groups:
for group in param_groups:
for param in group:
......@@ -125,9 +130,24 @@ class CheckOverflow(object):
op=torch.distributed.ReduceOp.MAX,
group=torch.distributed.group.WORLD)
elif self.mpu is not None:
if self.deepspeed is not None:
using_pipeline = hasattr(self.deepspeed,
'pipeline_enable_backward_allreduce')
if (using_pipeline
and self.deepspeed.pipeline_enable_backward_allreduce is False
) or (not using_pipeline
and self.deepspeed.enable_backward_allreduce is False):
torch.distributed.all_reduce(
overflow_gpu,
op=torch.distributed.ReduceOp.MAX,
group=self.mpu.get_data_parallel_group())
torch.distributed.all_reduce(overflow_gpu,
op=torch.distributed.ReduceOp.MAX,
group=self.mpu.get_model_parallel_group())
elif self.deepspeed is not None and self.deepspeed.enable_backward_allreduce is False:
torch.distributed.all_reduce(overflow_gpu,
op=torch.distributed.ReduceOp.MAX,
group=torch.distributed.group.WORLD)
overflow = overflow_gpu[0].item()
return bool(overflow)
......
......@@ -33,15 +33,22 @@ collections:
- advanced-install.md
- getting-started.md
- azure.md
- cifar-10.md
- bert-pretraining.md
- bert-finetuning.md
- transformer_kernel.md
- bert-pretraining.md
- cifar-10.md
- flops-profiler.md
- gan.md
- lrrt.md
- megatron.md
- one-cycle.md
- lrrt.md
- onebit-adam.md
- onebit-lamb.md
- pipeline.md
- progressive_layer_dropping.md
- sparse-attention.md
- transformer_kernel.md
- zero-offload.md
- zero.md
- flops-profiler.md
defaults:
- scope:
......
......@@ -80,6 +80,8 @@ lnav:
url: /tutorials/one-cycle/
- title: "One-Bit Adam"
url: /tutorials/onebit-adam/
- title: "One-Bit LAMB"
url: /tutorials/onebit-lamb/
- title: "Pipeline Parallelism"
url: /tutorials/pipeline/
- title: "Progressive Layer Dropping"
......
......@@ -34,7 +34,7 @@ title: "DeepSpeed Configuration JSON"
| Fields | Value | Example |
| ------ | -------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | ---------------------------- |
| type | The optimizer name. DeepSpeed natively supports **Adam**, **AdamW**, **OneBitAdam**, and **Lamb** optimizers (See [here](https://deepspeed.readthedocs.io/en/latest/optimizers.html) for details) and will import other optimizers from [torch](https://pytorch.org/docs/stable/optim.html). | `"Adam"` |
| type | The optimizer name. DeepSpeed natively supports **Adam**, **AdamW**, **OneBitAdam**, **Lamb**, and **OneBitLamb** optimizers (See [here](https://deepspeed.readthedocs.io/en/latest/optimizers.html) for details) and will import other optimizers from [torch](https://pytorch.org/docs/stable/optim.html). | `"Adam"` |
| params | Dictionary of parameters to instantiate optimizer. The parameter names must match the optimizer constructor signature (e.g., for [Adam](https://pytorch.org/docs/stable/optim.html#torch.optim.Adam)). | `{"lr": 0.001, "eps": 1e-8}` |
Example of <i>**optimizer**</i> with Adam
......@@ -88,6 +88,42 @@ The 1-bit Adam optimizer supports the following three params keys/values in addi
| cuda\_aware | To indicate that the underlying MPI library supports CUDA-Aware communication | false |
| comm\_backend\_name | To indicate which backend implementation to use | "nccl" |
Another example of ***optimizer*** with 1-bit LAMB
```json
"optimizer": {
"type": "OneBitLamb",
"params": {
"lr": 11e-3,
"weight_decay": 0.01,
"bias_correction": false,
"max_coeff": 0.3,
"min_coeff": 0.01,
"freeze_step": 1000,
"cuda_aware": false,
"comm_backend_name": "nccl",
"coeff_beta": 0.9,
"factor_max": 4.0,
"factor_min": 0.5,
"factor_threshold": 0.1
}
}
```
The 1-bit LAMB optimizer supports the following params keys/values in addition to the standard LAMB (learn more in our [tutorial](/tutorials/onebit-lamb/)):
| "params" key | Description | Default |
| ------------- | --------------------------------------------------------------------------- | ------- |
| max\_coeff | Scaling coefficient upper bound for original LAMB algorithm and 1-bit LAMB's warmup stage | 10.0 |
| min\_coeff | Scaling coefficient lower bound for original LAMB algorithm and 1-bit LAMB's warmup stage | 0.01 |
| freeze\_step | Number of warm up steps before 1-bit compression gets applied to the communication | 100000 |
| cuda\_aware | To indicate that the underlying MPI library supports CUDA-Aware communication | false |
| comm\_backend\_name | To indicate which backend implementation to use | "nccl" |
| coeff\_beta | Coefficient used for computing running averages of lamb coefficient | 0.9 |
| factor\_max | Maximum value of scaling factor to the frozen lamb coefficient during compression stage | 4.0 |
| factor\_min | Minimum value of scaling factor to the frozen lamb coefficient during compression stage | 0.5 |
| factor\_threshold | Threshold of how much the scaling factor can fluctuate between steps | 0.1 |
### Scheduler Parameters
......
......@@ -172,15 +172,17 @@ Please see the [core API doc](https://deepspeed.readthedocs.io/) for more detail
## Training Optimizers
### 1-bit Adam optimizer with up to 5x less communication
### 1-bit Adam and 1-bit LAMB optimizers with up to 5x less communication
DeepSpeed has an efficient implementation of a novel algorithm called 1-bit Adam.
It offers the same convergence as Adam, incurs up to 5x less communication that enables
DeepSpeed has two communication-efficient optimizers called 1-bit Adam and 1-bit LAMB.
They offer the same convergence as Adam/LAMB, incur up to 5x less communication that enables
up to 3.5x higher throughput for BERT-Large pretraining and up to 2.7x higher throughput
for SQuAD fine-tuning on bandwidth-limited clusters. For more details on usage and performance,
please refer to the detailed [tutorial](https://www.deepspeed.ai/tutorials/onebit-adam) and
[blog post](https://www.deepspeed.ai/news/2020/09/09/onebit-adam-blog-post.md), respectively.
<!-- **TODO: add paper link when it is ready ** -->
please refer to the [1-bit Adam tutorial](https://www.deepspeed.ai/tutorials/onebit-adam),
[1-bit Adam blog post](https://www.deepspeed.ai/news/2020/09/09/onebit-adam-blog-post.md),
and [1-bit LAMB tutorial](https://www.deepspeed.ai/tutorials/onebit-lamb/). For technical details,
please refer to the [1-bit Adam paper](https://arxiv.org/abs/2102.02888) and
[1-bit LAMB paper](https://arxiv.org/abs/2104.06069).
### Fused Adam optimizer and arbitrary torch.optim.Optimizer
With DeepSpeed, the user can choose to use a high performance implementation of ADAM from
......
......@@ -7,7 +7,7 @@ This tutorial is updated on 03/04/2021 to reflect the 1-bit Adam v2. Changes inc
{: .notice--info}
**Watch out!**
1) The NCCL-based implementation requires PyTorch >= 1.8 (and NCCL >= 2.8.3 when you have 64 or more GPUs). See details below. 2) Although 1-bit Adam is compatible with both FP16 and FP32, currently we only verified the convergence under mixed precision/FP16 training. 3) Currently 1-bit Adam is not compatible with pipeline parallelism. 4) Frequent checkpoint loading could hurt 1-bit Adam's convergence. See details below.
1) The NCCL-based implementation requires PyTorch >= 1.8 (and NCCL >= 2.8.3 when you have 64 or more GPUs). See details below. 2) Although 1-bit Adam is compatible with both FP16 and FP32, currently we only verified the convergence under mixed precision/FP16 training. 3) Currently the MPI-based implementation is not compatible with pipeline parallelism. 4) Frequent checkpoint loading could hurt 1-bit Adam's convergence. See details below.
{: .notice--warning}
In this tutorial, we are going to introduce the 1-bit Adam optimizer in DeepSpeed. 1-bit Adam can improve model training speed on communication-constrained clusters, especially for communication-intensive large models by reducing the overall communication volume by up to 5x. Detailed description of the 1-bit Adam algorithm, its implementation in DeepSpeed, and performance evaluation is available from our [blog post](https://www.deepspeed.ai/news/2020/09/08/onebit-adam-blog-post.html). We also have a [paper](https://arxiv.org/abs/2102.02888) which provides the most complete details including algorithm, system implementation, theoretical analysis, and more evaluations.
......@@ -23,7 +23,7 @@ For more details on these tasks, please refer to the tutorial posts on [BingBert
### 1.1 Pre-requisites for installing DeepSpeed
If you don't already have a copy of the DeepSpeed repository, please clone in
If you don't already have a copy of the DeepSpeed repository, please clone it
now and checkout the DeepSpeedExamples submodule that contains the BingBertSQuAD and BERT Pre-training examples.
```shell
......@@ -106,7 +106,7 @@ Please note three new parameters `freeze_step`, `cuda_aware`, and `comm_backend_
Because 1-bit compression cannot represent exact zero, the compression error would keep accumulating in the momentum if a parameter have constant zero gradients during training. For example, for BERT pre-training seq length 128, `bert.embeddings.position_embeddings.weight` has constant zeros in its gradient and momentum for row 129 to 512, because it only learns up to seq length 128 while the model supports up to seq length 512. Thus in 1-bit Adam v2 we added support of a momentum mask for users to specify those params that have constant exact zeros in their gradients. See [example script](https://github.com/microsoft/DeepSpeedExamples/blob/master/bing_bert/deepspeed_train.py) for how to configure this momentum mask. One thing to note is that we don't use momentum mask saved in checkpoints since this mask could change during training (e.g., BERT seqlen 128 and 512 require different masks). So you have to provide this mask every time in your training script.
**Watch out!**
1-bit Adam replies on an compression error compensation mechanism to maintain the convergence speed at compression stage. When loading checkpoints, we actually reset the compression errors for 3 reasons: 1) The worker and server error at each GPU are distinct, so in current implementation only rank 0's errors are saved in the checkpoint. Thus we have to reset the errors. If we want to save them correctly we need O(num_gpu*model_size) memory in order to gather all the error, which is a very large memory requirement. It's possible to save them in a distributed way, but it will make the checkpoint saving/loading much more complicated. 2) Even if we are able to save the compression errors correctly, you need to have the exact same number of GPUs in order to load them correctly. 3) We verified on BERT pre-training that occasionally resetting the compression error at checkpoint loading does not affect the convergence. However, please avoid frequent checkpoint loading which could break the error compensation mechanism thus affect the convergence.
1-bit Adam relies on an compression error compensation mechanism to maintain the convergence speed at compression stage. When loading checkpoints, we actually reset the compression errors for 3 reasons: 1) The worker and server error at each GPU are distinct, so in current implementation only rank 0's errors are saved in the checkpoint. Thus we have to reset the errors. If we want to save them correctly we need O(num_gpu*model_size) memory in order to gather all the error, which is a very large memory requirement. It's possible to save them in a distributed way, but it will make the checkpoint saving/loading much more complicated. 2) Even if we are able to save the compression errors correctly, you need to have the exact same number of GPUs in order to load them correctly. 3) We verified on BERT pre-training that occasionally resetting the compression error at checkpoint loading does not affect the convergence. However, please avoid frequent checkpoint loading which could break the error compensation mechanism thus affect the convergence.
{: .notice--warning}
## 2. BingBertSQuAD Fine-tuning with 1-bit Adam
......
---
title: "1-bit LAMB: Communication Efficient Large-Scale Large-Batch Training with LAMB's Convergence Speed"
---
**Watch out!**
1) The NCCL-based implementation requires PyTorch >= 1.8 (and NCCL >= 2.8.3 when you have 64 or more GPUs). See details below. 2) Although 1-bit LAMB is compatible with both FP16 and FP32, currently we only verified the convergence under mixed precision/FP16 training. 3) Currently the MPI-based implementation is not compatible with pipeline parallelism. 4) Frequent checkpoint loading could hurt 1-bit LAMB's convergence. See details below.
{: .notice--warning}
In this tutorial, we introduce DeepSpeed's 1-bit LAMB optimizer which enables communication-efficient large-scale large-batch training with LAMB's convergence speed. 1-bit LAMB can improve model training speed on communication-constrained clusters, especially for communication-intensive large models by reducing the overall communication volume by up to 4.6x. We also have a [paper](https://arxiv.org/abs/2104.06069) which provides the technical details including algorithm, system implementation, and evaluations.
To illustrate the benefits and usage of 1-bit LAMB optimizer, we use the BERT Pre-training task as example. For more details on this task, please refer to the [tutorial](/tutorials/bert-pretraining/).
## 1. Overview
### 1.1 Pre-requisites for installing DeepSpeed
If you don't already have a copy of the DeepSpeed repository, please clone it
now and checkout the DeepSpeedExamples submodule that contains the BERT Pre-training example.
```shell
git clone https://github.com/microsoft/DeepSpeed
cd DeepSpeed
git submodule update --init --recursive
cd DeepSpeedExamples/
```
### 1.2 Pre-requisites for 1-bit LAMB
#### 1.2.1 NCCL-based implementation
In DeepSpeed, we introduce a system implementation for compressed communication using the NCCL backend of PyTorch distributed. This implementation provides better performance and usability than the MPI-based implementation below. Thus we highly recommend users to choose this implementation.
**Watch out!**
This NCCL-based implementation requires PyTorch >= 1.8. It also requires NCCL >= 2.8.3 when you have 64 or more GPUs to avoid certain NCCL runtime bugs. Currently (2021/03/16) NCCL 2.8.3 is not officially supported by PyTorch. The solution we used is by hacking in NCCL 2.8.3 via `LD_PRELOAD`: 1) Install NCCL 2.8.3. This works for us on a CUDA 11 system: `apt-get install -y libnccl2=2.8.3-1+cuda11.0 libnccl-dev=2.8.3-1+cuda11.0`. 2) Set `LD_PRELOAD` to the the library path. This works for us: `LD_PRELOAD=/usr/lib/x86_64-linux-gnu/libnccl.so.2.8.3`. To confirm `LD_PRELOAD` is working you can see the version it uses in the NCCL logs if you have `NCCL_DEBUG=INFO`, it should say: NCCL version 2.8.3+cuda11.0.
{: .notice--warning}
#### 1.2.2 MPI-based implementation
For this implementation, we rely on Message Passing Interface (MPI) for advanced communication primitives.
We package the necessary dependencies in the DeepSpeed docker images. However, if you are using a different build system, please install MPI and mpi4py on your system. To install the prerequisites run:
```shell
pip install deepspeed[1bit_adam]
```
We have tested CUDA-Aware MPI communication using the [MVAPICH2-GDR](http://mvapich.cse.ohio-state.edu/userguide/gdr/) library. However, any CUDA-Aware communication library including [OpenMPI](https://www.open-mpi.org/) should work fine with these examples.
An example launch command for 1-bit LAMB using the `deepspeed` launcher is as follows:
```shell
deepspeed --launcher=[mvapich|openmpi] script.py
```
Please note that for MPI-based implementation of 1-bit LAMB, the `--launcher=[mvapich|openmpi]` flag is required when using the `deepspeed` launcher.
Alternatively, the standard mpirun launcher can also be used as follows:
```shell
mpirun -np [num processes] -ppn [num GPUs on each node] -hostfile [hostfile] [MPI flags] python [training_script.py]
```
### 1.3 1-bit LAMB Algorithm
The detailed description of the 1-bit LAMB algorithm can be seen from our [paper](https://arxiv.org/abs/2104.06069).
### 1.4 Configuration of 1-bit LAMB
The 1-bit LAMB feature can be used by setting the optimizer configuration options as follows. An example json config file is shown below.
```json
{
"train_batch_size": 65536,
"train_micro_batch_size_per_gpu": 64,
"optimizer": {
"type": "OneBitLamb",
"params": {
"lr": 11e-3,
"max_coeff": 0.3,
"min_coeff": 0.01,
"freeze_step": 1000,
"cuda_aware": false,
"comm_backend_name": "nccl",
"coeff_beta": 0.9,
"factor_max": 4.0,
"factor_min": 0.5,
"factor_threshold": 0.1
}
},
"gradient_clipping": 1.0,
"fp16": {
"enabled": true,
"loss_scale": 0,
"initial_scale_power": 16
}
}
```
Please note the new parameters `freeze_step`, `cuda_aware`, `comm_backend_name`, `coeff_beta`, `factor_max`, `factor_min`, and `factor_threshold` that have been added to support the 1-bit LAMB feature:
`freeze_step` is the number of warm up steps before 1-bit compression gets applied to the communication. In order to determine the number of warm up steps, one strategy is to set 15-25% of the total training steps for a given model (This is related to LAMB's variance/second moment term and scaling coefficient. See detailed analysis in our [paper](https://arxiv.org/abs/2104.06069)). If it provides the desired outcome, one can try to extract more performance by reducing the steps systematically. In future, we plan to introduce a threshold that can automatically search and decide for the number of warm up steps for different models. The examples below have been tuned for the number of warm up steps. The `freeze_step` parameter has already been set to the best number we found in the corresponding run scripts.
`cuda_aware` is used for MPI-based implementation to indicate that the underlying MPI library supports CUDA-Aware communication. This feature is only supported on systems with InfiniBand interconnect and a CUDA-Aware MPI library like [MVAPICH2-GDR](http://mvapich.cse.ohio-state.edu/userguide/gdr/) or OpenMPI built with CUDA-Aware support. Setting `cuda_aware` to False will allow training on Ethernet based systems. However, the communication will happen using sender as well as receiver side memory copies between CPU and GPU buffers before and after communication.
`comm_backend_name` is used to indicate which backend implementation to use. You can choose between NCCL and MPI-based implementations by setting `comm_backend_name` to "nccl" or "mpi". When using NCCL-based implementation, there is no need to set `cuda_aware`.
`coeff_beta` is used when calculating a moving average of the LAMB scaling coefficient during the warmup stage. This moving average is then used as the frozen base scaling coefficient during the compression stage.
`factor_max`, `factor_min`, and `factor_threshold` are used to regularize the adaptive scaling of the frozen base scaling coefficient during the compression stage. `factor_max` and `factor_min` are the scaling factor upper/lower bound. `factor_threshold` defines the threshold of how much the scaling factor can fluctuate between steps.
#### 1.4.1 Momentum masks for parameters with constant zero gradients
Because 1-bit compression cannot represent exact zero, the compression error would keep accumulating in the momentum if a parameter have constant zero gradients during training. For example, for BERT pre-training seq length 128, `bert.embeddings.position_embeddings.weight` has constant zeros in its gradient and momentum for row 129 to 512, because it only learns up to seq length 128 while the model supports up to seq length 512. Thus in 1-bit LAMB we added support of a momentum mask for users to specify those params that have constant exact zeros in their gradients. See [example script](https://github.com/microsoft/DeepSpeedExamples/blob/master/bing_bert/deepspeed_train.py) for how to configure this momentum mask. One thing to note is that we don't use momentum mask saved in checkpoints since this mask could change during training (e.g., BERT seqlen 128 and 512 require different masks). So you have to provide this mask every time in your training script.
**Watch out!**
1-bit LAMB relies on an compression error compensation mechanism to maintain the convergence speed at compression stage. When loading checkpoints, we actually reset the compression errors for 3 reasons: 1) The worker and server error at each GPU are distinct, so in current implementation only rank 0's errors are saved in the checkpoint. Thus we have to reset the errors. If we want to save them correctly we need O(num_gpu*model_size) memory in order to gather all the error, which is a very large memory requirement. It's possible to save them in a distributed way, but it will make the checkpoint saving/loading much more complicated. 2) Even if we are able to save the compression errors correctly, you need to have the exact same number of GPUs in order to load them correctly. 3) We verified on BERT pre-training that occasionally resetting the compression error at checkpoint loading does not affect the convergence. However, please avoid frequent checkpoint loading which could break the error compensation mechanism thus affect the convergence.
{: .notice--warning}
## 2. BERT Pre-training with 1-bit LAMB
For data downloading and pre-processing, please refer to the [BERT Pre-training tutorial](/tutorials/bert-pretraining/).
### 2.1 Running Pre-training with DeepSpeed and 1-bit LAMB
We provide example scripts under [DeepSpeedExamples/bing_bert/1-bit_lamb/](https://github.com/microsoft/DeepSpeedExamples/tree/master/bing_bert/1-bit_lamb). There are 3 sets of scripts corresponding to NCCL-based implementation, MPI-based implementation on Ethernet systems, and MPI-based implementation on InfiniBand systems. For MPI-based implementation, we provide both example scripts when launching with deepspeed or mpirun.
### 2.2 Configuration for BERT Pre-training with DeepSpeed and 1-bit LAMB enabled
The `deepspeed_bsz64k_onebitlamb_config_seq128_*.json` and `deepspeed_bsz32k_onebitlamb_config_seq512_*.json` files give the user the ability to specify DeepSpeed
options in terms of batch size, micro batch size, optimizer, learning rate, and other parameters. In these files we include the tuned hyperparameters to reproduce experiments in our [paper](https://arxiv.org/abs/2104.06069).
### 2.3 Performance Results for BERT Pre-training
Performance results can be seen in our [paper](https://arxiv.org/abs/2104.06069).
Optimizers
==========
DeepSpeed offers high-performance implementations of ``Adam`` optimizer on CPU; ``FusedAdam``, ``FusedAdam``, ``OneBitAdam`` optimizers on GPU.
Adam (CPU)
----------
.. autoclass:: deepspeed.ops.adam.DeepSpeedCPUAdam
FusedAdam (GPU)
---------------
.. autoclass:: deepspeed.ops.adam.FusedAdam
FusedLamb (GPU)
---------------
.. autoclass:: deepspeed.ops.lamb.FusedLamb
OneBitAdam (GPU)
----------------
.. autoclass:: deepspeed.runtime.fp16.onebit.adam.OnebitAdam
Optimizers
===================
DeepSpeed offers high-performance implementations of ``Adam`` optimizer on CPU; ``FusedAdam``, ``FusedLamb``, ``OnebitAdam``, ``OnebitLamb`` optimizers on GPU.
Adam (CPU)
----------------------------
.. autoclass:: deepspeed.ops.adam.DeepSpeedCPUAdam
FusedAdam (GPU)
----------------------------
.. autoclass:: deepspeed.ops.adam.FusedAdam
FusedLamb (GPU)
----------------------------
.. autoclass:: deepspeed.ops.lamb.FusedLamb
OneBitAdam (GPU)
----------------------------
.. autoclass:: deepspeed.runtime.fp16.onebit.adam.OnebitAdam
OnebitLamb (GPU)
----------------------------
.. autoclass:: deepspeed.runtime.fp16.onebit.lamb.OnebitLamb
......@@ -17,7 +17,7 @@ DeepSpeed delivers extreme-scale model training for everyone, from data scientis
* Extreme scale: Using current generation of GPU clusters with hundreds of devices, 3D parallelism of DeepSpeed can efficiently train deep learning models with trillions of parameters.
* Extremely memory efficient: With just a single GPU, ZeRO-Offload of DeepSpeed can train models with over 10B parameters, 10x bigger than the state of arts, democratizing multi-billion-parameter model training such that many deep learning scientists can explore bigger and better models.
* Extremely long sequence length: Sparse attention of DeepSpeed powers an order-of-magnitude longer input sequence and obtains up to 6x faster execution comparing with dense transformers.
* Extremely communication efficient: 3D parallelism improves communication efficiency allows users to train multi-billion-parameter models 2–7x faster on clusters with limited network bandwidth. 1-bit Adam reduces communication volume by up to 5x while achieving similar convergence efficiency to Adam, allowing for scaling to different types of GPU clusters and networks.
* Extremely communication efficient: 3D parallelism improves communication efficiency allows users to train multi-billion-parameter models 2–7x faster on clusters with limited network bandwidth. 1-bit Adam/1-bit LAMB reduce communication volume by up to 5x while achieving similar convergence efficiency to Adam/LAMB, allowing for scaling to different types of GPU clusters and networks.
Early adopters of DeepSpeed have already produced
a language model (LM) with over 17B parameters called
......@@ -30,6 +30,7 @@ initiative to enable next-generation AI capabilities at scale, where you can fin
information [here](https://innovation.microsoft.com/en-us/exploring-ai-at-scale).
# What's New?
* [2021/04/20] [1-bit LAMB: up to 4.6x less communication and 2.8x faster training, together with LAMB's convergence speed at large batch sizes](https://www.deepspeed.ai/tutorials/onebit-lamb/)
* [2021/04/19] [ZeRO-Infinity unlocks unprecedented model scale for deep learning training](https://www.microsoft.com/en-us/research/blog/zero-infinity-and-deepspeed-unlocking-unprecedented-model-scale-for-deep-learning-training/)
* [Tutorial on how to use different stages of ZeRO](https://www.deepspeed.ai/tutorials/zero/)
* [2021/04/02] [[DeepSpeed on AzureML] Transformers and CIFAR examples are now available on AzureML GitHub](https://github.com/Azure/azureml-examples/tree/main/workflows/train/deepspeed)
......@@ -134,7 +135,7 @@ combinations, which we call 3D parallelism.
Pipeline parallelism of DeepSpeed reduce communication volume during distributed training, which allows users to train multi-billion-parameter models 2–7x faster on clusters with limited network bandwidth.
![Low-bandwidth GPT-2 Performance](/assets/images/pp-lowbw-gpt2.png)
1-bit Adam reduces communication volume by up to 5x while achieving similar convergence efficiency to Adam, allowing for scaling to different types of GPU clusters and networks. [Read more here](https://www.deepspeed.ai/news/2020/09/08/onebit-adam-blog-post.html).
1-bit Adam and 1-bit LAMB reduce communication volume by up to 5x while achieving similar convergence efficiency to Adam, allowing for scaling to different types of GPU clusters and networks. [1-bit Adam blog post](https://www.deepspeed.ai/news/2020/09/08/onebit-adam-blog-post.html), [1-bit Adam tutorial](https://www.deepspeed.ai/tutorials/onebit-adam/), [1-bit LAMB tutorial](https://www.deepspeed.ai/tutorials/onebit-lamb/).
## Supporting long sequence length
DeepSpeed offers sparse attention kernels—an instrumental technology to support long sequences of model inputs, whether for text, image, or sound. Compared with the classic dense Transformers, it powers **an order-of-magnitude longer input sequence** and obtains up to 6x faster execution with comparable accuracy. It also outperforms state-of-the-art sparse implementations with 1.5–3x faster execution. Furthermore, our sparse kernels support efficient execution of flexible sparse format and empower users to innovate on their custom sparse structures. [Read more here](https://www.deepspeed.ai/news/2020/09/08/sparse-attention.html).
......@@ -178,7 +179,7 @@ Below we provide a brief feature list, see our detailed [feature overview](https
* Memory- and compute-efficient sparse kernels
* Support 10x long sequences than dense
* Flexible support to different sparse structures
* [1-bit Adam](https://www.deepspeed.ai/news/2020/09/08/onebit-adam-blog-post.html)
* [1-bit Adam](https://www.deepspeed.ai/news/2020/09/08/onebit-adam-blog-post.html) and [1-bit LAMB](https://www.deepspeed.ai/tutorials/onebit-lamb/)
* Custom communication collective
* Up to 5x communication volume saving
* [Additional Memory and Bandwidth Optimizations](https://www.deepspeed.ai/features/#additional-memory-and-bandwidth-optimizations)
......@@ -235,6 +236,7 @@ comments.
4. Jie Ren, Samyam Rajbhandari, Reza Yazdani Aminabadi, Olatunji Ruwase, Shuangyan Yang, Minjia Zhang, Dong Li, Yuxiong He. (2021) ZeRO-Offload: Democratizing Billion-Scale Model Training. [arXiv:2101.06840](https://arxiv.org/abs/2101.06840).
5. Hanlin Tang, Shaoduo Gan, Ammar Ahmad Awan, Samyam Rajbhandari, Conglong Li, Xiangru Lian, Ji Liu, Ce Zhang, Yuxiong He. (2021) 1-bit Adam: Communication Efficient Large-Scale Training with Adam's Convergence Speed. [arXiv:2102.02888](https://arxiv.org/abs/2102.02888).
6. Samyam Rajbhandari, Olatunji Ruwase, Jeff Rasley, Shaden Smith, Yuxiong He. (2021) ZeRO-Infinity: Breaking the GPU Memory Wall for Extreme Scale Deep Learning. [arXiv:2104.07857](https://arxiv.org/abs/2104.07857).
7. Conglong Li, Ammar Ahmad Awan, Hanlin Tang, Samyam Rajbhandari, Yuxiong He. (2021) 1-bit LAMB: Communication Efficient Large-Scale Large-Batch Training with LAMB's Convergence Speed. [arXiv:2104.06069](https://arxiv.org/abs/2104.06069).
# Videos
1. DeepSpeed KDD 2020 Tutorial
......
此差异已折叠。
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册