Unverified commit a23cda6c authored by Olatunji Ruwase, committed by GitHub

Allow modification of zero partitioned parameters (#4192)

* Modify zero parameters

* Docs

* py3.6 compatibility

* Update docs

* Update deepspeed/runtime/zero/stage3.py
Co-authored-by: Michael Wyatt <michaelwyatt@microsoft.com>

* Add TODO

* Formatting

---------
Co-authored-by: Logan Adams <114770087+loadams@users.noreply.github.com>
Co-authored-by: Michael Wyatt <michaelwyatt@microsoft.com>
Parent f96c1c0a
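In short, this change adds safe_set_full_fp32_param and safe_set_full_optimizer_state so that ZeRO-partitioned fp32 master weights and optimizer states can be overwritten outside the training loop. A minimal usage sketch, assuming model is the module returned by deepspeed.initialize; the 0.5 rescale factor and the choice of the "exp_avg" key are purely illustrative:

import torch
from deepspeed.utils import (safe_get_full_fp32_param, safe_set_full_fp32_param,
                             safe_set_full_optimizer_state)

# Illustrative sketch: rescale every fp32 master weight and zero Adam's first moment
# outside the training loop. Every rank passes the full tensor; DeepSpeed writes its slice.
for name, lp in model.named_parameters():
    full_fp32 = safe_get_full_fp32_param(lp)  # None for frozen (requires_grad=False) params
    if full_fp32 is None:
        continue
    safe_set_full_fp32_param(lp, full_fp32 * 0.5)                              # illustrative rescale
    safe_set_full_optimizer_state(lp, torch.zeros_like(full_fp32), "exp_avg")  # reset first moment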
@@ -2101,16 +2101,17 @@ class DeepSpeedZeroOptimizer_Stage3(ZeROOptimizer):
return grad_dict
- def _fp32_state_allgather(self, param, fp32_state):
- reduce_buffer = torch.zeros(self.partition_count * fp32_state.numel(),
+ def _fp32_state_allgather(self, param, fp32_state_partition):
+ reduce_buffer = torch.zeros(self.partition_count * fp32_state_partition.numel(),
dtype=torch.float32,
device=param.device).flatten()
my_rank = dist.get_rank(group=self.dp_process_group)
partitions = [
reduce_buffer.narrow(0,
- fp32_state.numel() * i, fp32_state.numel()) for i in range(self.partition_count)
+ fp32_state_partition.numel() * i, fp32_state_partition.numel())
+ for i in range(self.partition_count)
]
- partitions[my_rank].data.copy_(fp32_state.data, non_blocking=False)
+ partitions[my_rank].data.copy_(fp32_state_partition.data, non_blocking=False)
dist.all_gather(partitions, partitions[my_rank], group=self.dp_process_group)
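For context, the helper being renamed above is a plain partition all-gather. A self-contained sketch of that pattern with illustrative names (not the DeepSpeed internals), assuming torch.distributed is initialized and every rank holds an equally sized fp32 partition:

import torch
import torch.distributed as dist

def allgather_fp32_partition(partition: torch.Tensor, group=None) -> torch.Tensor:
    """Gather equally sized fp32 partitions from all ranks into one flat buffer (sketch)."""
    world_size = dist.get_world_size(group=group)
    my_rank = dist.get_rank(group=group)
    buffer = torch.zeros(world_size * partition.numel(), dtype=torch.float32, device=partition.device)
    # One non-overlapping view into the buffer per rank.
    slots = [buffer.narrow(0, i * partition.numel(), partition.numel()) for i in range(world_size)]
    slots[my_rank].copy_(partition)
    dist.all_gather(slots, slots[my_rank], group=group)
    return buffer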
@@ -2125,19 +2126,16 @@ class DeepSpeedZeroOptimizer_Stage3(ZeROOptimizer):
if self.offload_optimizer:
group_idx, dest_offset, num_elements = self.grad_position[self.get_param_id(param)]
- fp32_grad = self.fp32_partitioned_groups_flat[group_idx].grad.narrow(0, dest_offset,
- num_elements).to(device=param.device)
+ fp32_grad = self.fp32_partitioned_groups_flat[group_idx].grad.narrow(0, dest_offset, num_elements)
else:
fp32_grad = self.__param_id_to_grad_partition[param.ds_id].float()
return self._fp32_state_allgather(param, fp32_grad)
- def get_full_hp_param(self, param, optim_state_key=None) -> Tensor:
- if not param.requires_grad:
- return None
+ def _get_fp32_opt_state_partition(self, param, optim_state_key=None):
if not get_accelerator().is_synchronized_device():
self.reduce_and_partition_stream.synchronize()
group_idx, dest_offset, num_elements = self.grad_position[self.get_param_id(param)]
if self._swappable_optimizer_subgroup(group_idx):
@@ -2145,16 +2143,41 @@ class DeepSpeedZeroOptimizer_Stage3(ZeROOptimizer):
fp32_param = self.fp32_partitioned_groups_flat[group_idx]
if optim_state_key is None:
- fp32_opt_state = fp32_param.narrow(0, dest_offset, num_elements).to(device=param.device)
+ fp32_opt_state = fp32_param.narrow(0, dest_offset, num_elements)
else:
- fp32_opt_state = self.optimizer.state[fp32_param][optim_state_key].narrow(
- 0, dest_offset, num_elements).to(device=param.device)
+ fp32_opt_state = self.optimizer.state[fp32_param][optim_state_key].narrow(0, dest_offset, num_elements)
return fp32_opt_state, group_idx
def get_full_hp_param(self, param, optim_state_key=None) -> Tensor:
if not param.requires_grad:
return None
fp32_opt_state, group_idx = self._get_fp32_opt_state_partition(param, optim_state_key)
hp_param = self._fp32_state_allgather(param, fp32_opt_state)
if self._swappable_optimizer_subgroup(group_idx):
self._optimizer_states_and_gradient_swap_out(group_idx)
return hp_param
def set_full_hp_param(self, value, param, optim_state_key=None):
if not param.requires_grad:
return
assert value.numel(
) == param.ds_numel, f" Number of elements do not match: {value.numel()} != {param.ds_numel}"
fp32_opt_state_partition, group_idx = self._get_fp32_opt_state_partition(param, optim_state_key)
my_rank = dist.get_rank(group=self.dp_process_group)
value_partition = value.flatten().narrow(0,
fp32_opt_state_partition.numel() * my_rank,
fp32_opt_state_partition.numel())
fp32_opt_state_partition.data.copy_(value_partition.data)
if self._swappable_optimizer_subgroup(group_idx):
self._optimizer_states_and_gradient_swap_out(group_idx)
@instrument_w_nvtx
def _partition_all_parameters(self):
self.parameter_offload.partition_all_parameters()
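The new set_full_hp_param above is essentially the inverse of the gather path: each rank carves its own slice out of the caller's full fp32 tensor and copies it into the local partition in place. A minimal sketch of that slicing step, with illustrative names and assuming the flattened value covers world_size * partition.numel() elements (padding included):

import torch

def copy_rank_slice_(full_value: torch.Tensor, local_partition: torch.Tensor, my_rank: int) -> None:
    """Copy this rank's slice of a full flat tensor into its local partition, in place (sketch)."""
    numel = local_partition.numel()
    # Slice [my_rank * numel, (my_rank + 1) * numel) out of the flattened full value.
    value_slice = full_value.flatten().narrow(0, my_rank * numel, numel)
    local_partition.copy_(value_slice)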
@@ -12,6 +12,8 @@ from .nvtx import instrument_w_nvtx
# TODO: Move tensor fragment and mixed precision to zero utils
from .tensor_fragment import tensor_fragment, get_full_hp_param, get_hp_fragment_mapping, fragment_address, get_full_hp_grad
from .tensor_fragment import safe_get_full_fp32_param, safe_get_full_grad, safe_get_full_optimizer_state
from .tensor_fragment import set_full_hp_param
from .tensor_fragment import safe_set_full_fp32_param, safe_set_full_optimizer_state
from .mixed_precision_linkage import link_hp_params
from deepspeed.runtime.dataloader import RepeatingLoader
from .numa import get_numactl_cmd
@@ -5,6 +5,7 @@
import types
from deepspeed.utils import get_full_hp_param, get_full_hp_grad, get_hp_fragment_mapping
from deepspeed.utils import set_full_hp_param
def link_hp_params(lp_param_list, flat_hp_partition, gradient_dict, offload_gradient_dict, use_offload,
@@ -27,6 +28,7 @@ def _init_lp_to_hp_mapping(lp_param_list, partition_start, partition_size, dp_gr
lp_param._dp_group = dp_group
lp_param.get_full_hp_param = types.MethodType(get_full_hp_param, lp_param)
lp_param.get_full_hp_grad = types.MethodType(get_full_hp_grad, lp_param)
lp_param.set_full_hp_param = types.MethodType(set_full_hp_param, lp_param)
# lp_param overlaps with partition if both are true
# 1) current_offset < partition_end,
@@ -45,22 +45,31 @@ class tensor_fragment:
def get_optim_state_keys(self):
return list(self.optim_fragment.keys())
def get_hp_fragment(self, optim_state_key=None):
if optim_state_key is None:
return self.hp_fragment
return self.get_optim_state_fragment(optim_state_key)
def get_full_hp_param(self, optim_state_key=None):
reduce_buffer = torch.zeros_like(self, dtype=torch.float32).flatten()
if self._hp_mapping is not None:
lp_frag_address = self._hp_mapping.lp_fragment_address
reduce_fragment = torch.narrow(reduce_buffer, 0, lp_frag_address.start, lp_frag_address.numel)
- if optim_state_key is None:
- hp_fragment = self._hp_mapping.hp_fragment
- else:
- hp_fragment = self._hp_mapping.get_optim_state_fragment(optim_state_key)
+ hp_fragment = self._hp_mapping.get_hp_fragment(optim_state_key)
reduce_fragment.data.copy_(hp_fragment.data)
dist.all_reduce(reduce_buffer, group=self._dp_group)
return reduce_buffer.reshape_as(self)
def set_full_hp_param(self, value, optim_state_key=None):
if self._hp_mapping is not None:
lp_frag_address = self._hp_mapping.lp_fragment_address
value_fragment = torch.narrow(value.flatten(), 0, lp_frag_address.start, lp_frag_address.numel)
hp_fragment = self._hp_mapping.get_hp_fragment(optim_state_key)
hp_fragment.data.copy_(value_fragment.data)
def get_full_hp_grad(self):
reduce_buffer = torch.zeros_like(self, dtype=torch.float32).flatten()
if self._hp_mapping is not None:
@@ -105,11 +114,28 @@ def safe_get_full_fp32_param(param):
return None
def safe_set_full_fp32_param(param, value):
"""Update the partitioned fp32 parameter of a low-precision (e.g., fp16) parameter.
Args:
param (``torch.nn.Parameter``): A model parameter
value (``torch.Tensor``): New value
"""
# ZeRO stage 3 param
if hasattr(param, 'ds_id'):
param._z3_optimizer.set_full_hp_param(value, param)
# ZeRO stage 1, 2, and bf16_optimizer params
if hasattr(param, '_hp_mapping'):
param.set_full_hp_param(value)
def safe_get_full_optimizer_state(param, optim_state_key):
"""Assemble and return the fp32 optimizer state of a low-precision (e.g., fp16) parameter.
Args:
param (``torch.nn.Parameter``): A model parameter
optim_state_key (``string``): Key value of optimizer state (e.g., `exp_avg` in Adam optimizer)
"""
# ZeRO stage 3 param
if hasattr(param, 'ds_id'):
@@ -121,6 +147,23 @@ def safe_get_full_optimizer_state(param, optim_state_key):
return None
def safe_set_full_optimizer_state(param, value, optim_state_key):
"""Update the partitioned fp32 optimizer state of a low-precision (e.g., fp16) parameter.
Args:
param (``torch.nn.Parameter``): A model parameter
value (``torch.Tensor``): New value
optim_state_key (``string``): Key value of optimizer state (e.g., `exp_avg` in Adam optimizer)
"""
# ZeRO stage 3 param
if hasattr(param, 'ds_id'):
param._z3_optimizer.set_full_hp_param(value, param, optim_state_key)
# ZeRO stage 1, 2, and bf16_optimizer params
if hasattr(param, '_hp_mapping'):
param.set_full_hp_param(value, optim_state_key)
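A round-trip check for the two setters defined above, as a sketch: it assumes a model wrapped by a DeepSpeed engine with an Adam-style optimizer so that the "exp_avg" state exists, and works for either the ds_id (ZeRO-3) or _hp_mapping (ZeRO-1/2, bf16_optimizer) path:

import torch
from deepspeed.utils import safe_get_full_optimizer_state, safe_set_full_optimizer_state

# Sketch: overwrite the first moment of every trainable parameter, then read it back
# to confirm the partitioned state was updated consistently on all ranks.
for name, lp in model.named_parameters():
    if not lp.requires_grad:
        continue  # setters are no-ops for frozen parameters
    full_shape = lp.ds_shape if hasattr(lp, 'ds_id') else lp.shape  # full shape under ZeRO-3
    new_state = torch.full(full_shape, 0.01, dtype=torch.float32, device=lp.device)
    safe_set_full_optimizer_state(lp, new_state, "exp_avg")
    assert torch.equal(safe_get_full_optimizer_state(lp, "exp_avg"), new_state)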
# TODO: Figure out the correct return dtype
def safe_get_full_grad(param):
"""Assemble and return the fp32 gradient of a low-precision (e.g., fp16) parameter.
@@ -142,6 +185,9 @@ def safe_get_full_grad(param):
return None
# TODO: Implement API for setting ZeRO partitioned gradients
def get_hp_fragment_mapping(lp_param, lp_start, flat_hp_partition, gradient_dict, offload_gradient_dict, use_offload,
param_group_index, partition_start, partition_size, optimizer_state_dict):
lp_end = lp_param.numel() + lp_start
@@ -376,6 +376,35 @@ These routines can be used in a training loop as shown in the following snippet.
optimizer.step()
Modifying Partitioned States
----------------------------
Sometimes, a user may want to modify parameters or optimizer states outside of the regular training loop. This is difficult in ZeRO training because parameters and optimizer states are partitioned across data-parallel ranks. To overcome this, DeepSpeed provides the following two routines for modifying the fp32 master parameters and the fp32 optimizer states.
.. autofunction:: deepspeed.utils.safe_set_full_fp32_param
.. autofunction:: deepspeed.utils.safe_set_full_optimizer_state
These routines can be used at any point after initialization of the DeepSpeed engine (i.e., ``deepspeed.initialize()``) as shown in the following snippet.
.. code-block:: python

    [...]
    from deepspeed.utils import safe_set_full_fp32_param, safe_set_full_optimizer_state
    # Here is an example to zero all the fp32 parameters and optimizer states.
    for n, lp in model.named_parameters():
        # Assume ZeRO stage 1 or 2: under stage 3, lp is partitioned, so a full-shaped
        # tensor would have to be built from lp.ds_shape instead of zeros_like(lp).
        zero_tensor = torch.zeros_like(lp)
        safe_set_full_fp32_param(lp, zero_tensor)
        safe_set_full_optimizer_state(lp, zero_tensor, "exp_avg")
        safe_set_full_optimizer_state(lp, zero_tensor, "exp_avg_sq")
    [...]
GPU Memory Management
---------------------
@@ -8,14 +8,19 @@ import deepspeed.comm as dist
import torch
from unit.common import DistributedTest
- from unit.simple_model import random_dataloader
+ from unit.simple_model import random_dataloader, SimpleModel
from unit.util import bf16_required_version_check
import deepspeed
from deepspeed.utils import safe_get_full_fp32_param, safe_get_full_grad, safe_get_full_optimizer_state
from deepspeed.utils import safe_set_full_fp32_param, safe_set_full_optimizer_state
from deepspeed.runtime.zero.offload_config import OffloadDeviceEnum
from deepspeed.ops.aio import AsyncIOBuilder
WEIGHT_KEY = 'weight'
FIRST_ORDER_KEY = 'exp_avg'
SECOND_ORDER_KEY = 'exp_avg_sq'
def validate_full_tensors(model):
for _, lp in model.named_parameters():
@@ -73,7 +78,7 @@ def run_fragmented_model(model, config_dict, hidden_dim, dtype):
@pytest.mark.parametrize('frozen_weights', [True, False])
- class TestTensorFragment(DistributedTest):
+ class TestTensorFragmentGet(DistributedTest):
# Need multiple gpus to test possible hanging
world_size = 2
reuse_dist_env = True
@@ -150,3 +155,104 @@ class TestTensorFragment(DistributedTest):
hidden_dim = 128
model = MyModel(hidden_dim, frozen_weights)
run_fragmented_model(model, config_dict, hidden_dim, torch.bfloat16)
def create_random_values(model, key_list, group):
param_values = {}
for n, lp in model.named_parameters():
param_shape = lp.ds_shape if hasattr(lp, 'ds_id') else lp.shape
param_values[n] = {}
for key in key_list:
rand_value = torch.rand(param_shape, dtype=torch.float32, device=model.device)
dist.broadcast(rand_value, src=0, group=group)
param_values[n][key] = rand_value
return param_values
def set_param_values_with_dict(model, value_dict):
for n, lp in model.named_parameters():
for key, value_tensor in value_dict[n].items():
if key == WEIGHT_KEY:
safe_set_full_fp32_param(lp, value_tensor)
else:
safe_set_full_optimizer_state(lp, value_tensor, key)
def validate_param_values_with_dict(model, value_dict):
for n, lp in model.named_parameters():
for key, expected_tensor in value_dict[n].items():
if key == WEIGHT_KEY:
actual_tensor = safe_get_full_fp32_param(lp)
else:
actual_tensor = safe_get_full_optimizer_state(lp, key)
assert torch.equal(expected_tensor, actual_tensor)
@pytest.mark.parametrize('dtype', [torch.bfloat16, torch.float16, torch.float32])
class TestTensorFragmentUpdate(DistributedTest):
# Need multiple gpus to test possible hanging
world_size = 2
reuse_dist_env = True
@pytest.mark.parametrize('zero_stage', [1, 2, 3])
@pytest.mark.parametrize('offload_device', [OffloadDeviceEnum.none, OffloadDeviceEnum.cpu, OffloadDeviceEnum.nvme])
def test_zero_fragments(self, tmpdir, zero_stage, offload_device, dtype):
if dtype == torch.bfloat16 and not bf16_required_version_check(accelerator_check=False):
pytest.skip(
" DeepSpeed BFloat16 tests need torch >= 1.10, NCCL >= 2.10.3, CUDA > =11.0 and HW support for BFloat16 to run correctly"
)
if offload_device == OffloadDeviceEnum.nvme:
if zero_stage != 3:
pytest.skip(f"Nvme offload not supported for zero stage {zero_stage}")
if not deepspeed.ops.__compatible_ops__[AsyncIOBuilder.NAME]:
pytest.skip('Skip tests since async-io is not compatible')
config_dict = {
"train_micro_batch_size_per_gpu": 1,
"steps_per_print": 1,
"optimizer": {
"type": "Adam",
"params": {
"lr": 1e-6
}
},
"zero_optimization": {
"stage": zero_stage,
}
}
if offload_device == OffloadDeviceEnum.cpu:
config_dict["zero_optimization"]["offload_optimizer"] = {"device": offload_device}
elif offload_device == OffloadDeviceEnum.nvme:
config_dict["zero_optimization"]["offload_optimizer"] = {
"device": offload_device,
"nvme_path": str(tmpdir)
}
if dtype == torch.float16:
config_dict["fp16"] = {"enabled": True, "initial_scale_power": 8}
elif dtype == torch.bfloat16:
config_dict["bf16"] = {"enabled": True}
hidden_dim = 128
if zero_stage == 3:
config_dict["zero_optimization"]["param_persistence_threshold"] = hidden_dim
with deepspeed.zero.Init(config_dict_or_path=config_dict):
model = SimpleModel(hidden_dim, nlayers=4)
else:
model = SimpleModel(hidden_dim, nlayers=4)
model, _, _, _ = deepspeed.initialize(model=model, model_parameters=model.parameters(), config=config_dict)
world = dist.get_world_size()
group = dist.new_group(ranks=list(range(world)))
dist.barrier()
optim_keys = [WEIGHT_KEY, FIRST_ORDER_KEY, SECOND_ORDER_KEY]
optim_state_values = create_random_values(model, optim_keys, group)
set_param_values_with_dict(model, optim_state_values)
validate_param_values_with_dict(model, optim_state_values)
# Needed in ZeRO 3; skipping the explicit destroy can leak memory.
model.destroy()