Unverified commit be1147c0 authored by Olatunji Ruwase, committed by GitHub

PLD release (#513)

* Progressive layer dropping docs (#499)

* test

* Adding tutorial and news page for pld

* updating the tutorial and posts of PLD

* update the finetune tutorial

* Update PLD tutorial (#512)

* Update installation instructions

* Format fix

* ZeRO tutorial

* Format fixes

* ZeRO-Offload

* ZeRO and ZeRO-Offload tutorials

* Update navigation page

* Format fixes

* Add yuxhe feedback

* Fix blog post link

* Fix OneBit-Adam link
Tweak scheduler example

* Fix date link

* Add DeepSpeed_Adam

* Add PLD tutorial to navigation
Co-authored-by: Shaden Smith <Shaden.Smith@microsoft.com>
Co-authored-by: Jeff Rasley <jerasley@microsoft.com>

* updating the pld docs

* DeepSpeed implementation of PLD (#508)

* DeepSpeed implementation of PLD

* Format fixes

* Formatting fixes

* Fix broken url

* Address PR feedback

* Bump DSE
Co-authored-by: Minjia Zhang <33713995+minjiaz@users.noreply.github.com>
Co-authored-by: Shaden Smith <Shaden.Smith@microsoft.com>
Co-authored-by: Jeff Rasley <jerasley@microsoft.com>
Co-authored-by: Minjia Zhang <minjiaz@microsoft.com>
Parent e082d475
Subproject commit a79272cc8b8f0c5b66c803e581a1355341eacb77
Subproject commit fa1d1a71c48623db8a091d9cf636a5fe3b8f43c7
......@@ -155,8 +155,7 @@ all repos using our CLA.
This project has adopted the [Microsoft Open Source Code of
Conduct](https://opensource.microsoft.com/codeofconduct/). For more information see the
[Code of Conduct FAQ](https://opensource.microsoft.com/codeofconduct/faq/) or contact
[opencode@microsoft.com](mailto:opencode@microsoft.com) with any additional questions or
comments.
[opencode@microsoft.com](mailto:opencode@microsoft.com) with any additional questions or comments.
# Publications
1. Samyam Rajbhandari, Jeff Rasley, Olatunji Ruwase, Yuxiong He. (2019) ZeRO: Memory Optimization Towards Training A Trillion Parameter Models. [ArXiv:1910.02054](https://arxiv.org/abs/1910.02054)
......@@ -30,6 +30,24 @@ TORCH_ADAM_PARAM = "torch_adam"
ADAM_W_MODE_PARAM = "adam_w_mode"
def get_pld_enabled(param_dict):
if PROGRESSIVE_LAYER_DROP in param_dict.keys():
return get_scalar_param(param_dict[PROGRESSIVE_LAYER_DROP],
PLD_ENABLED,
PLD_ENABLED_DEFAULT)
else:
return False
def get_pld_params(param_dict):
if PROGRESSIVE_LAYER_DROP in param_dict.keys():
pld_params = copy.copy(param_dict[PROGRESSIVE_LAYER_DROP])
pld_params.pop(PLD_ENABLED)
return pld_params
else:
return False
def get_amp_enabled(param_dict):
if AMP in param_dict.keys():
return get_scalar_param(param_dict[AMP], AMP_ENABLED, AMP_ENABLED_DEFAULT)
......@@ -542,6 +560,9 @@ class DeepSpeedConfig(object):
self.sparse_attention = get_sparse_attention(param_dict)
self.pipeline = get_pipeline_config(param_dict)
self.pld_enabled = get_pld_enabled(param_dict)
self.pld_params = get_pld_params(param_dict)
def _batch_assertion(self):
train_batch = self.train_batch_size
......
......@@ -291,3 +291,16 @@ TENSORBOARD_OUTPUT_PATH_DEFAULT = ""
# Tensorboard job name
TENSORBOARD_JOB_NAME = "job_name"
TENSORBOARD_JOB_NAME_DEFAULT = "DeepSpeedJobName"
# Progressive Layer Drop (PLD)
PROGRESSIVE_LAYER_DROP = "progressive_layer_drop"
# PLD enable signal
PLD_ENABLED = "enabled"
PLD_ENABLED_DEFAULT = False
PLD_THETA = "theta"
PLD_THETA_DEFAULT = 1.0
PLD_GAMMA = "gamma"
PLD_GAMMA_DEFAULT = 0.001
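
For reference, these keys describe a `progressive_layer_drop` section of the DeepSpeed config consumed by the accessors added in `config.py` above. The fragment below is an illustrative sketch only (the values are placeholders, not recommended settings); with such a section, `get_pld_enabled` returns `True` and `get_pld_params` returns the remaining keys with `enabled` removed.

```python
# Illustrative config fragment using the PLD keys defined above (placeholder values).
ds_config = {
    "progressive_layer_drop": {
        "enabled": True,    # PLD_ENABLED
        "theta": 0.5,       # PLD_THETA: keep probability the schedule decays towards
        "gamma": 0.001      # PLD_GAMMA: controls how fast the drop ratio increases
    }
}

# get_pld_enabled(ds_config)  -> True
# get_pld_params(ds_config)   -> {"theta": 0.5, "gamma": 0.001}
```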
......@@ -26,13 +26,14 @@ from deepspeed.runtime.config import DeepSpeedConfig, DEEPSPEED_OPTIMIZERS, \
from deepspeed.runtime.dataloader import DeepSpeedDataLoader
from deepspeed.runtime.constants import \
ROUTE_TRAIN, ROUTE_PREDICT, ROUTE_EVAL, \
TORCH_DISTRIBUTED_DEFAULT_PORT
TORCH_DISTRIBUTED_DEFAULT_PORT, PLD_THETA, PLD_GAMMA
from deepspeed.runtime.zero.constants import \
ZERO_OPTIMIZATION_OPTIMIZER_STATES, ZERO_OPTIMIZATION_GRADIENTS
from deepspeed.runtime.csr_tensor import CSRTensor
import deepspeed.runtime.lr_schedules as lr_schedules
from deepspeed.utils import logger, log_dist
from deepspeed.utils.timer import ThroughputTimer, SynchronizedWallClockTimer
from deepspeed.runtime.progressive_layer_drop import ProgressiveLayerDrop
from .utils import ensure_directory_exists
......@@ -127,6 +128,7 @@ class DeepSpeedEngine(Module):
self.loaded_checkpoint_mp_world_size = None
self.loaded_checkpoint_dp_world_size = None
self.enable_backward_allreduce = True
self.progressive_layer_drop = None
if dist_init_required is None:
dist_init_required = not dist.is_initialized()
......@@ -192,10 +194,13 @@ class DeepSpeedEngine(Module):
self.save_zero_checkpoint = False
self._configure_checkpointing(dist_init_required)
if self.pld_enabled():
self.progressive_layer_drop = self._configure_progressive_layer_drop()
if self.global_rank == 0:
self._config.print('DeepSpeedLight configuration')
self._config.print('DeepSpeedEngine configuration')
if self.dump_state():
print_configuration(self, 'DeepSpeedLight')
print_configuration(self, 'DeepSpeedEngine')
def _mpi_check(self, args, dist_init_required):
if hasattr(args, 'deepspeed_mpi') and args.deepspeed_mpi:
......@@ -236,6 +241,18 @@ class DeepSpeedEngine(Module):
assert dist.get_world_size() == world_size, "MPI world size {} does not match torch world size {}".format(
world_size, dist.get_world_size())
def pld_enabled(self):
return self._config.pld_enabled
def pld_params(self):
return self._config.pld_params
def pld_theta(self):
return self.pld_params()[PLD_THETA]
def pld_gamma(self):
return self.pld_params()[PLD_GAMMA]
def tensorboard_enabled(self):
return self._config.tensorboard_enabled
......@@ -666,6 +683,11 @@ class DeepSpeedEngine(Module):
return optimizer
def _configure_progressive_layer_drop(self):
pld = ProgressiveLayerDrop(theta=self.pld_theta(), gamma=self.pld_gamma())
return pld
def deepspeed_io(self,
dataset,
batch_size=None,
......@@ -751,6 +773,9 @@ class DeepSpeedEngine(Module):
**kwargs: variable length keyword arguments
"""
if self.module.training and self.progressive_layer_drop:
kwargs.update(self.progressive_layer_drop.get_state())
if self.wall_clock_breakdown():
self.timers('forward_microstep').start()
self.timers('forward').start()
......@@ -931,6 +956,9 @@ class DeepSpeedEngine(Module):
# Update the model when we reach gradient accumulation boundaries
if self.is_gradient_accumulation_boundary():
if self.progressive_layer_drop:
self.progressive_layer_drop.update_state(self.global_steps)
self._take_model_step()
self.tput_timer.stop(report_progress)
......@@ -1024,6 +1052,12 @@ class DeepSpeedEngine(Module):
else:
return self._get_optimizer_param('betas')
def get_pld_theta(self):
if self.progressive_layer_drop:
return self.progressive_layer_drop.get_theta()
else:
return None
def _report_progress(self, step):
lr = self.get_lr()
mom = self.get_mom()
......
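
Taken together, the `engine.py` changes hook PLD into the normal DeepSpeed training loop: `forward` injects the current PLD state into the model's kwargs while training, and `step` advances the schedule at gradient accumulation boundaries. A schematic sketch of that flow follows (illustrative only; `args`, `model`, and `data_loader` are assumed to be set up as in the tests further below).

```python
# Schematic of the PLD hooks in DeepSpeedEngine (illustrative only).
import deepspeed

engine, _, _, _ = deepspeed.initialize(args=args,
                                       model=model,
                                       model_parameters=model.parameters())

for x, y in data_loader:
    # forward(): with PLD enabled and module.training True, the engine adds
    # {'progressive_layer_drop': True, 'pld_theta': <current theta>} to kwargs.
    loss = engine(x, y)
    engine.backward(loss)
    # step(): at gradient accumulation boundaries the engine calls
    # progressive_layer_drop.update_state(global_steps) before taking the model step.
    engine.step()
```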
import numpy as np
from deepspeed.utils import log_dist
class ProgressiveLayerDrop(object):
r""" Progressive Layer Dropping (PLD) for model training.
This implements the PLD technique for compressed model training
from this paper: https://arxiv.org/pdf/2010.13369.pdf
Args:
theta (float): a hyper-parameter that controls the trade-off between training time and robustness.
The lower the theta value, the faster the training speed. Default value: 0.5.
gamma (float): a hyper-parameter that controls how fast the drop ratio increases. Default value: 0.001.
"""
def __init__(self, theta=0.5, gamma=0.001):
super().__init__()
self.theta = theta
self.gamma = gamma
self.current_theta = 1.0
log_dist(f'Enabled progressive layer dropping (theta = {self.theta})', ranks=[0])
def get_state(self):
kwargs = {'progressive_layer_drop': True, 'pld_theta': self.get_theta()}
return kwargs
def get_theta(self):
return self.current_theta
def update_state(self, global_step):
def _prob(x, gamma, p):
return (1. - p) * np.exp(-gamma * x) + p
self.current_theta = _prob(global_step, self.gamma, self.theta)
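
As a quick standalone illustration of the schedule (not part of this commit): `current_theta` follows `(1 - theta) * exp(-gamma * step) + theta`, starting at 1.0 (keep every layer) and decaying toward `theta` as the global step grows.

```python
# Standalone sketch of the schedule implemented by ProgressiveLayerDrop above.
from deepspeed.runtime.progressive_layer_drop import ProgressiveLayerDrop

pld = ProgressiveLayerDrop(theta=0.5, gamma=0.001)
for step in (0, 100, 1000, 10000):
    pld.update_state(step)
    # current_theta starts at 1.0 and decays toward theta=0.5 as training progresses.
    print(step, pld.get_theta())
```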
......@@ -7,7 +7,7 @@ new_post: true
date: 2020-10-29 00:00:00
---
We introduce a new technology called progressive layer dropping (PLD) to speed up the pre-training of Transformer-based networks through efficient and robust compressed training. The pre-training step of Transformer networks often suffers from unbearable overall computational expense. We analyze the training dynamics and stability of Transformer networks and propose PLD to sparsely update Transformer blocks following a progressive dropping schedule, which smoothly increases the layer dropping rate for each mini-batch as training evolves along both the temporal and the model depth dimensions. PLD allows the pre-training to reach similar accuracy on downstream tasks **2.5X faster**, and makes training **24% faster** when training on the same number of samples, without requiring excessive hardware resources.
* For a detailed technical deep dive, see our [technical report](https://arxiv.org/pdf/2010.13369.pdf).
* For more information on how to use PLD, see our [Progressive layer dropping tutorial](https://www.deepspeed.ai/tutorials/progressive_layer_dropping/).
......
......@@ -18,7 +18,7 @@ already been modified to use DeepSpeed. The `ds_train_bert_progressive_layer_dr
bash ds_train_bert_progressive_layer_drop_bsz4k_seq128.sh
```
Most of the flags in the above script should be familiar if you have stepped through the BERT pre-training [tutorial](/tutorials/bert-pretraining/). To enable training with PLD, one needs to enable PLD in both the client script and the DeepSpeed engine. In the client script, add the following command-line flag to enable progressive layer dropping on Transformer blocks.
```bash
--progressive_layer_drop
......
......@@ -12,9 +12,9 @@ efficient, and effective.
<p align="center"><i><b>Minimal Code Change</b></i></p>
DeepSpeed delivers extreme-scale model training for everyone, from data scientists training on massive supercomputers to those training on low-end clusters or even on a single GPU:
* Extreme scale: Using the current generation of GPU clusters with hundreds of devices, the 3D parallelism of DeepSpeed can efficiently train deep learning models with trillions of parameters.
* Extremely memory efficient: With just a single GPU, ZeRO-Offload of DeepSpeed can train models with over 10B parameters, 10x bigger than the state of the art, democratizing multi-billion-parameter model training so that many deep learning scientists can explore bigger and better models.
* Extremely long sequence length: Sparse attention of DeepSpeed powers an order-of-magnitude longer input sequence and obtains up to 6x faster execution compared with dense transformers.
* Extremely communication efficient: 3D parallelism improves communication efficiency, allowing users to train multi-billion-parameter models 2–7x faster on clusters with limited network bandwidth. 1-bit Adam reduces communication volume by up to 5x while achieving similar convergence efficiency to Adam, enabling scaling to different types of GPU clusters and networks.
Early adopters of DeepSpeed have already produced
......
......@@ -101,6 +101,17 @@ class SimpleOptimizer(torch.optim.Optimizer):
return loss
class PLD_SimpleModel(SimpleModel):
def __init__(self, hidden_dim, empty_grad=False, rank=0):
super(PLD_SimpleModel, self).__init__(hidden_dim, empty_grad, rank)
def forward(self, x, y, **kwargs):
pld = kwargs.get('progressive_layer_drop', False)
theta = kwargs.get('pld_theta', 1.0)
hidden_dim = super(PLD_SimpleModel, self).forward(x, y)
return hidden_dim
def random_dataloader(model, total_samples, hidden_dim, device, dtype=torch.half):
batch_size = model.train_micro_batch_size_per_gpu()
train_data = torch.randn(total_samples, hidden_dim, device=device, dtype=dtype)
......
import numpy as np
import deepspeed
import pytest
from deepspeed.runtime.progressive_layer_drop import ProgressiveLayerDrop
from common import distributed_test
from simple_model import SimpleModel, PLD_SimpleModel, SimpleOptimizer, random_dataloader, args_from_dict
@pytest.mark.parametrize('theta', [0, 0.1, 0.9, 1.0])
def test_pld_schedule(tmpdir, theta):
gamma = 0.001
pld_scheduler = ProgressiveLayerDrop(theta, gamma)
for i in range(10):
pld_scheduler.update_state(i)
expected_theta = (1. - theta) * np.exp(-gamma * i) + theta
actual_theta = pld_scheduler.get_theta()
assert expected_theta == actual_theta
@pytest.mark.parametrize('theta', [0, 0.1, 0.9, 1.0])
def test_pld_model(tmpdir, theta):
gamma = 0.001
config_dict = {
"train_batch_size": 1,
"steps_per_print": 1,
"optimizer": {
"type": 'Adam',
"params": {
"lr": 0.0001
}
},
"fp16": {
"enabled": True
},
"progressive_layer_drop": {
"enabled": True,
"theta": theta,
"gamma": gamma
}
}
args = args_from_dict(tmpdir, config_dict)
hidden_dim = 10
model = PLD_SimpleModel(hidden_dim, empty_grad=False)
@distributed_test(world_size=[1])
def _test_pld_model(args, model, hidden_dim, theta, gamma):
model, _, _, _ = deepspeed.initialize(args=args,
model=model,
model_parameters=model.parameters())
data_loader = random_dataloader(model=model,
total_samples=50,
hidden_dim=hidden_dim,
device=model.device)
for i, batch in enumerate(data_loader):
loss = model(batch[0], batch[1])
model.backward(loss)
model.step()
expected_theta = (1. - theta) * np.exp(-gamma * i) + theta
actual_theta = model.get_pld_theta()
assert expected_theta == actual_theta
_test_pld_model(args=args,
model=model,
hidden_dim=hidden_dim,
theta=theta,
gamma=gamma)
def test_non_pld_model(tmpdir):
gamma = 0.001
theta = 0.5
config_dict = {
"train_batch_size": 1,
"steps_per_print": 1,
"optimizer": {
"type": 'Adam',
"params": {
"lr": 0.0001
}
},
"fp16": {
"enabled": True
},
"progressive_layer_drop": {
"enabled": True,
"theta": theta,
"gamma": gamma
}
}
args = args_from_dict(tmpdir, config_dict)
hidden_dim = 10
model = SimpleModel(hidden_dim, empty_grad=False)
@distributed_test(world_size=[1])
def _test_non_pld_model(args, model, hidden_dim):
model, _, _, _ = deepspeed.initialize(args=args,
model=model,
model_parameters=model.parameters())
data_loader = random_dataloader(model=model,
total_samples=1,
hidden_dim=hidden_dim,
device=model.device)
for i, batch in enumerate(data_loader):
with pytest.raises(TypeError):
loss = model(batch[0], batch[1])
_test_non_pld_model(args=args, model=model, hidden_dim=hidden_dim)
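
The tests above only check that the PLD kwargs reach the model's `forward` (and that a non-PLD model raises `TypeError` when they are injected). For context, a hypothetical layer stack that actually consumes `pld_theta` might look like the sketch below, dropping deeper blocks with higher probability in the spirit of the PLD paper. This is not the DeepSpeed BERT implementation; the class and its depth-scaling rule are illustrative assumptions.

```python
import torch
import torch.nn as nn

class HypotheticalPLDStack(nn.Module):
    """Illustrative only: drops blocks using the PLD state injected by the engine."""
    def __init__(self, layers):
        super().__init__()
        self.layers = nn.ModuleList(layers)

    def forward(self, x, progressive_layer_drop=False, pld_theta=1.0):
        num_layers = len(self.layers)
        for i, layer in enumerate(self.layers):
            # Depth-scaled keep probability: shallow blocks are kept more often,
            # and the deepest block is kept with probability pld_theta.
            keep_prob = 1.0 - (i + 1) / num_layers * (1.0 - pld_theta)
            if self.training and progressive_layer_drop and torch.rand(1).item() > keep_prob:
                continue  # skip this block for the current mini-batch
            x = layer(x)
        return x
```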