Unverified commit e80ae088, authored by Olatunji Ruwase, committed by GitHub

Empty ZeRO3 partition cache (#3060)

Co-authored-by: Jeff Rasley <jerasley@microsoft.com>
Parent 5cdf3593
......@@ -9,6 +9,7 @@ import torch
import hashlib
from collections import defaultdict, OrderedDict, deque
from shutil import copyfile
import gc
from torch.nn.modules import Module
from torch.nn.parameter import Parameter
......@@ -3546,3 +3547,12 @@ class DeepSpeedEngine(Module):
        self.checkpoint_engine.commit(tag)
        return True

    def empty_partition_cache(self):
        """
        Release GPU memory consumed by offloaded model parameters.
        """
        if hasattr(self.optimizer, 'empty_partition_cache'):
            # ZeRO-3 optimizer: repartition any parameters currently cached in GPU memory
            self.optimizer.empty_partition_cache()
        # Drop dangling references, then return the freed blocks to the accelerator
        gc.collect()
        get_accelerator().empty_cache()
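
Because of the hasattr guard above, calling empty_partition_cache on an engine whose optimizer does not define the method (e.g. ZeRO stage 1/2 in this commit) is a harmless no-op; only the stage-3 optimizer adds it (see the stage3.py hunk below). A minimal sketch of that dispatch pattern, with hypothetical class names:

# Hedged sketch (hypothetical classes): only optimizers that define
# empty_partition_cache are asked to repartition; anything else is a no-op.
class _NoCacheOptimizer:             # stands in for a stage-1/2 style optimizer
    pass

class _CachingOptimizer:             # stands in for a stage-3 style optimizer
    def empty_partition_cache(self):
        print("repartitioning cached parameters")

for opt in (_NoCacheOptimizer(), _CachingOptimizer()):
    if hasattr(opt, 'empty_partition_cache'):
        opt.empty_partition_cache()  # prints once, for the caching optimizer only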
......@@ -259,6 +259,9 @@ class DeepSpeedZeRoOffload(object):
        return self.param_coordinators[training]

    def empty_partition_cache(self):
        self.partition_all_parameters()

    def _convert_to_zero_parameters(self, ds_config, module, mpu):
        non_zero_params = [p for p in module.parameters() if not is_zero_param(p)]
        if non_zero_params:
......@@ -321,7 +324,7 @@ class DeepSpeedZeRoOffload(object):
            if param.ds_numel + total_persistent_parameters > model_threshold:
                continue

-           if param.ds_numel < param_threshold:
+           if param.ds_numel <= param_threshold:
                params_count += 1
                param.ds_persist = True
                persistent_params.append(param)
......
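
The only semantic change in this hunk is the comparison operator: a parameter whose ds_numel is exactly equal to stage3_param_persistence_threshold is now marked persistent, whereas before it was not. A minimal sketch of that boundary behavior, with made-up sizes (the new test below sets the threshold to hidden_dim, so bias-sized parameters sit exactly on this boundary):

# Hedged sketch with hypothetical sizes, illustrating only the <= boundary change.
param_threshold = 10                     # stage3_param_persistence_threshold
param_sizes = [4, 10, 11]                # ds_numel values; 10 sits exactly on the threshold
persistent_before = [n for n in param_sizes if n < param_threshold]    # old behavior -> [4]
persistent_after = [n for n in param_sizes if n <= param_threshold]    # new behavior -> [4, 10]
assert persistent_before == [4] and persistent_after == [4, 10]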
......@@ -2467,6 +2467,9 @@ class DeepSpeedZeroOptimizer_Stage3(ZeROOptimizer):
        if len(self.persistent_parameters) > 0:
            self.persistent_parameters[0].all_gather(self.persistent_parameters)

    def empty_partition_cache(self):
        self.parameter_offload.empty_partition_cache()


def _handle_overflow(cpu_sum, x, i):
    import math
......
......@@ -331,3 +331,29 @@ These routines can be used in a training loop as shown in the following snippet.
    [...]
    optimizer.step()

GPU Memory Management
---------------------

By default, at the end of training with ZeRO stage 3, some parameters may remain unpartitioned and continue to occupy GPU memory. This is intentional, as an optimization for the case where you resume training. If you would like to release the GPU memory held by these cached parameters, call the ``empty_partition_cache`` method of the DeepSpeed engine.

.. autofunction:: deepspeed.DeepSpeedEngine.empty_partition_cache

The following code snippet illustrates this functionality.

.. code-block:: python

    with zero.Init():
        model = MyLargeModel()

    ds_engine, _, _, _ = deepspeed.initialize(model, ...)
    for batch in ...:
        loss = ds_engine(batch)
        ds_engine.backward(loss)
        ds_engine.step()

        # Free GPU memory consumed by model parameters
        ds_engine.empty_partition_cache()
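
Not part of this change, but if you want to confirm that memory was actually released, you can compare the accelerator's allocated memory before and after the call. A minimal sketch, assuming a CUDA device and reusing the ``ds_engine`` from the snippet above (``torch.cuda.memory_allocated`` is standard PyTorch):

.. code-block:: python

    import torch

    # Sketch only: assumes the accelerator is a CUDA device and ds_engine exists as above.
    before = torch.cuda.memory_allocated()
    ds_engine.empty_partition_cache()
    after = torch.cuda.memory_allocated()
    print(f"released {before - after} bytes of cached parameter memory")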
......@@ -1422,3 +1422,49 @@ class TestZeroOffloadOptim(DistributedTest):
        model, _, _, _ = deepspeed.initialize(model=model,
                                              optimizer=optimizer,
                                              config=config_dict)


@pytest.mark.parametrize('training', [True, False])
class TestZeroPartitionCache(DistributedTest):
    world_size = 1

    def test_training_partition_cache(self, training):
        hidden_dim = 10
        config_dict = {
            "train_batch_size": 2,
            "fp16": {
                "enabled": True,
                "initial_scale_power": 8
            },
            "zero_optimization": {
                "stage": 3,
                "stage3_param_persistence_threshold": hidden_dim
            }
        }
        if training:
            config_dict["optimizer"] = {"type": "Adam"}

        with deepspeed.zero.Init(config_dict_or_path=config_dict):
            model = SimpleModel(hidden_dim, empty_grad=False)

        model, _, _, _ = deepspeed.initialize(model=model, config=config_dict)

        dtype = torch.half
        data_loader = random_dataloader(model=model,
                                        total_samples=6,
                                        hidden_dim=hidden_dim,
                                        device=model.device,
                                        dtype=dtype)

        for _, batch in enumerate(data_loader):
            loss = model(batch[0], batch[1])
            if training:
                model.backward(loss)
                model.step()

        # Before emptying the cache, only persistent parameters should still hold gathered data
        persist_param_size = sum([p.numel() for p in model.parameters() if p.ds_persist])
        assert persist_param_size >= sum([p.numel() for p in model.parameters()])

        model.empty_partition_cache()
        # Afterwards every parameter, persistent or not, has its gathered data released
        assert sum([p.numel() for p in model.parameters()]) == 0