diff --git a/deepspeed/runtime/zero/stage3.py b/deepspeed/runtime/zero/stage3.py
index bc581dc2215e553f2532277e2fcc2498e4399b66..a4d79697a1110ef13f6ec8da488b5539b8cbb344 100644
--- a/deepspeed/runtime/zero/stage3.py
+++ b/deepspeed/runtime/zero/stage3.py
@@ -2101,16 +2101,17 @@ class DeepSpeedZeroOptimizer_Stage3(ZeROOptimizer):
 
         return grad_dict
 
-    def _fp32_state_allgather(self, param, fp32_state):
-        reduce_buffer = torch.zeros(self.partition_count * fp32_state.numel(),
+    def _fp32_state_allgather(self, param, fp32_state_partition):
+        reduce_buffer = torch.zeros(self.partition_count * fp32_state_partition.numel(),
                                     dtype=torch.float32,
                                     device=param.device).flatten()
         my_rank = dist.get_rank(group=self.dp_process_group)
         partitions = [
             reduce_buffer.narrow(0,
-                                 fp32_state.numel() * i, fp32_state.numel()) for i in range(self.partition_count)
+                                 fp32_state_partition.numel() * i, fp32_state_partition.numel())
+            for i in range(self.partition_count)
         ]
-        partitions[my_rank].data.copy_(fp32_state.data, non_blocking=False)
+        partitions[my_rank].data.copy_(fp32_state_partition.data, non_blocking=False)
 
         dist.all_gather(partitions, partitions[my_rank], group=self.dp_process_group)
 
@@ -2125,19 +2126,16 @@ class DeepSpeedZeroOptimizer_Stage3(ZeROOptimizer):
         if self.offload_optimizer:
             group_idx, dest_offset, num_elements = self.grad_position[self.get_param_id(param)]
-            fp32_grad = self.fp32_partitioned_groups_flat[group_idx].grad.narrow(0, dest_offset,
-                                                                                 num_elements).to(device=param.device)
+            fp32_grad = self.fp32_partitioned_groups_flat[group_idx].grad.narrow(0, dest_offset, num_elements)
         else:
             fp32_grad = self.__param_id_to_grad_partition[param.ds_id].float()
 
         return self._fp32_state_allgather(param, fp32_grad)
 
-    def get_full_hp_param(self, param, optim_state_key=None) -> Tensor:
-        if not param.requires_grad:
-            return None
-
+    def _get_fp32_opt_state_partition(self, param, optim_state_key=None):
         if not get_accelerator().is_synchronized_device():
             self.reduce_and_partition_stream.synchronize()
+
         group_idx, dest_offset, num_elements = self.grad_position[self.get_param_id(param)]
 
         if self._swappable_optimizer_subgroup(group_idx):
@@ -2145,16 +2143,41 @@ class DeepSpeedZeroOptimizer_Stage3(ZeROOptimizer):
 
         fp32_param = self.fp32_partitioned_groups_flat[group_idx]
         if optim_state_key is None:
-            fp32_opt_state = fp32_param.narrow(0, dest_offset, num_elements).to(device=param.device)
+            fp32_opt_state = fp32_param.narrow(0, dest_offset, num_elements)
         else:
-            fp32_opt_state = self.optimizer.state[fp32_param][optim_state_key].narrow(
-                0, dest_offset, num_elements).to(device=param.device)
+            fp32_opt_state = self.optimizer.state[fp32_param][optim_state_key].narrow(0, dest_offset, num_elements)
+
+        return fp32_opt_state, group_idx
 
+    def get_full_hp_param(self, param, optim_state_key=None) -> Tensor:
+        if not param.requires_grad:
+            return None
+
+        fp32_opt_state, group_idx = self._get_fp32_opt_state_partition(param, optim_state_key)
         hp_param = self._fp32_state_allgather(param, fp32_opt_state)
+
         if self._swappable_optimizer_subgroup(group_idx):
             self._optimizer_states_and_gradient_swap_out(group_idx)
+
         return hp_param
 
+    def set_full_hp_param(self, value, param, optim_state_key=None):
+        if not param.requires_grad:
+            return
+
+        assert value.numel() == param.ds_numel, \
+            f"Number of elements do not match: {value.numel()} != {param.ds_numel}"
+
+        fp32_opt_state_partition, group_idx = self._get_fp32_opt_state_partition(param, optim_state_key)
+        my_rank = dist.get_rank(group=self.dp_process_group)
+        value_partition = value.flatten().narrow(0,
+                                                 fp32_opt_state_partition.numel() * my_rank,
+                                                 fp32_opt_state_partition.numel())
+        fp32_opt_state_partition.data.copy_(value_partition.data)
+
+        if self._swappable_optimizer_subgroup(group_idx):
+            self._optimizer_states_and_gradient_swap_out(group_idx)
+
     @instrument_w_nvtx
     def _partition_all_parameters(self):
         self.parameter_offload.partition_all_parameters()
diff --git a/deepspeed/utils/__init__.py b/deepspeed/utils/__init__.py
index 4e7f6b61d075c4041f62b9de49a9d53ae42306c1..b6668b5ff5ce682e094ae7142120d061f4e69211 100644
--- a/deepspeed/utils/__init__.py
+++ b/deepspeed/utils/__init__.py
@@ -12,6 +12,8 @@ from .nvtx import instrument_w_nvtx
 # TODO: Move tensor fragment and mixed precision to zero utils
 from .tensor_fragment import tensor_fragment, get_full_hp_param, get_hp_fragment_mapping, fragment_address, get_full_hp_grad
 from .tensor_fragment import safe_get_full_fp32_param, safe_get_full_grad, safe_get_full_optimizer_state
+from .tensor_fragment import set_full_hp_param
+from .tensor_fragment import safe_set_full_fp32_param, safe_set_full_optimizer_state
 from .mixed_precision_linkage import link_hp_params
 from deepspeed.runtime.dataloader import RepeatingLoader
 from .numa import get_numactl_cmd
diff --git a/deepspeed/utils/mixed_precision_linkage.py b/deepspeed/utils/mixed_precision_linkage.py
index ecc29e930954d174772d427f0eeea95867fe200d..b1afa8f00aa38eeed8ca0ae423fb07fda8e37dfc 100644
--- a/deepspeed/utils/mixed_precision_linkage.py
+++ b/deepspeed/utils/mixed_precision_linkage.py
@@ -5,6 +5,7 @@
 import types
 from deepspeed.utils import get_full_hp_param, get_full_hp_grad, get_hp_fragment_mapping
+from deepspeed.utils import set_full_hp_param
 
 
 def link_hp_params(lp_param_list, flat_hp_partition, gradient_dict, offload_gradient_dict, use_offload,
@@ -27,6 +28,7 @@ def _init_lp_to_hp_mapping(lp_param_list, partition_start, partition_size, dp_gr
         lp_param._dp_group = dp_group
         lp_param.get_full_hp_param = types.MethodType(get_full_hp_param, lp_param)
         lp_param.get_full_hp_grad = types.MethodType(get_full_hp_grad, lp_param)
+        lp_param.set_full_hp_param = types.MethodType(set_full_hp_param, lp_param)
 
         # lp_param overlaps with partition if both are true
         # 1) current_offset < partition_end,
diff --git a/deepspeed/utils/tensor_fragment.py b/deepspeed/utils/tensor_fragment.py
index ef09edd3da1e07fec082cc81a99d1bd707d1ad68..18e373799ab73893c43cfbef0511b1febb86a871 100644
--- a/deepspeed/utils/tensor_fragment.py
+++ b/deepspeed/utils/tensor_fragment.py
@@ -45,22 +45,31 @@ class tensor_fragment:
     def get_optim_state_keys(self):
         return list(self.optim_fragment.keys())
 
+    def get_hp_fragment(self, optim_state_key=None):
+        if optim_state_key is None:
+            return self.hp_fragment
+        return self.get_optim_state_fragment(optim_state_key)
+
 
 def get_full_hp_param(self, optim_state_key=None):
     reduce_buffer = torch.zeros_like(self, dtype=torch.float32).flatten()
     if self._hp_mapping is not None:
         lp_frag_address = self._hp_mapping.lp_fragment_address
         reduce_fragment = torch.narrow(reduce_buffer, 0, lp_frag_address.start, lp_frag_address.numel)
-        if optim_state_key is None:
-            hp_fragment = self._hp_mapping.hp_fragment
-        else:
-            hp_fragment = self._hp_mapping.get_optim_state_fragment(optim_state_key)
-
+        hp_fragment = self._hp_mapping.get_hp_fragment(optim_state_key)
         reduce_fragment.data.copy_(hp_fragment.data)
     dist.all_reduce(reduce_buffer, group=self._dp_group)
     return reduce_buffer.reshape_as(self)
 
 
+def set_full_hp_param(self, value, optim_state_key=None):
+    if self._hp_mapping is not None:
+        lp_frag_address = self._hp_mapping.lp_fragment_address
+        value_fragment = torch.narrow(value.flatten(), 0, lp_frag_address.start, lp_frag_address.numel)
+        hp_fragment = self._hp_mapping.get_hp_fragment(optim_state_key)
+        hp_fragment.data.copy_(value_fragment.data)
+
+
 def get_full_hp_grad(self):
     reduce_buffer = torch.zeros_like(self, dtype=torch.float32).flatten()
     if self._hp_mapping is not None:
@@ -105,11 +114,28 @@ def safe_get_full_fp32_param(param):
     return None
 
 
+def safe_set_full_fp32_param(param, value):
+    """Update the partitioned fp32 parameter of a low-precision (e.g., fp16) parameter.
+
+    Args:
+        param (``torch.nn.Parameter``): A model parameter
+        value (``torch.Tensor``): New value
+    """
+    # ZeRO stage 3 param
+    if hasattr(param, 'ds_id'):
+        param._z3_optimizer.set_full_hp_param(value, param)
+
+    # ZeRO stage 1, 2, and bf16_optimizer params
+    if hasattr(param, '_hp_mapping'):
+        param.set_full_hp_param(value)
+
+
 def safe_get_full_optimizer_state(param, optim_state_key):
     """Assemble and return the fp32 optimizer state of a low-precision (e.g., fp16) parameter.
 
     Args:
         param (``torch.nn.Parameter``): A model parameter
+        optim_state_key (``string``): Key value of optimizer state (e.g., `exp_avg` in Adam optimizer)
     """
     # ZeRO stage 3 param
     if hasattr(param, 'ds_id'):
@@ -121,6 +147,23 @@ def safe_get_full_optimizer_state(param, optim_state_key):
     return None
 
 
+def safe_set_full_optimizer_state(param, value, optim_state_key):
+    """Update the partitioned fp32 optimizer state of a low-precision (e.g., fp16) parameter.
+
+    Args:
+        param (``torch.nn.Parameter``): A model parameter
+        value (``torch.Tensor``): New value
+        optim_state_key (``string``): Key value of optimizer state (e.g., `exp_avg` in Adam optimizer)
+    """
+    # ZeRO stage 3 param
+    if hasattr(param, 'ds_id'):
+        param._z3_optimizer.set_full_hp_param(value, param, optim_state_key)
+
+    # ZeRO stage 1, 2, and bf16_optimizer params
+    if hasattr(param, '_hp_mapping'):
+        param.set_full_hp_param(value, optim_state_key)
+
+
 # TODO: Figure out the correct return dtype
 def safe_get_full_grad(param):
     """Assemble and return the fp32 gradient of a low-precision (e.g., fp16) parameter.
@@ -142,6 +185,9 @@ def safe_get_full_grad(param):
     return None
 
 
+# TODO: Implement API for setting ZeRO partitioned gradients
+
+
 def get_hp_fragment_mapping(lp_param, lp_start, flat_hp_partition, gradient_dict, offload_gradient_dict, use_offload,
                             param_group_index, partition_start, partition_size, optimizer_state_dict):
     lp_end = lp_param.numel() + lp_start
diff --git a/docs/code-docs/source/zero3.rst b/docs/code-docs/source/zero3.rst
index 333b29ed98d8e84e6e211dcf74b304762f44fd88..56a7987dc4968e8e7608de88b00cdba17423e849 100644
--- a/docs/code-docs/source/zero3.rst
+++ b/docs/code-docs/source/zero3.rst
@@ -376,6 +376,35 @@ These routines can be used in a training loop as shown in the following snippet.
 
             optimizer.step()
 
+
+Modifying Partitioned States
+----------------------------
+
+Sometimes, a user may want to modify parameters or optimizer states outside of the regular training loop. This is currently difficult in ZeRO training because of partitioning. To overcome that, DeepSpeed provides the following two routines for modifying the fp32 master parameters and the fp32 optimizer states.
+
+.. autofunction:: deepspeed.utils.safe_set_full_fp32_param
+
+.. autofunction:: deepspeed.utils.safe_set_full_optimizer_state
+
+
+These routines can be used at any point after initialization of the DeepSpeed engine (i.e., ``deepspeed.initialize()``) as shown in the following snippet.
+
+.. code-block:: python
+
+    [...]
+    from deepspeed.utils import safe_set_full_fp32_param, safe_set_full_optimizer_state
+    # Here is an example to zero all the fp32 parameters and optimizer states.
+    for n, lp in model.named_parameters():
+        # Assume zero stage 1 or 2, since stage 3 requires a gather to assemble lp
+        zero_tensor = torch.zeros_like(lp)
+
+        safe_set_full_fp32_param(lp, zero_tensor)
+        safe_set_full_optimizer_state(lp, zero_tensor, "exp_avg")
+        safe_set_full_optimizer_state(lp, zero_tensor, "exp_avg_sq")
+
+    [...]
+
+
 GPU Memory Management
 ---------------------
 
diff --git a/tests/unit/runtime/zero/test_zero_tensor_fragment.py b/tests/unit/runtime/zero/test_zero_tensor_fragment.py
index 47550256141832bced13a82f2645c23d828208a0..63d05ab6d352c7c98c4de22ed2d676721d8ea30e 100644
--- a/tests/unit/runtime/zero/test_zero_tensor_fragment.py
+++ b/tests/unit/runtime/zero/test_zero_tensor_fragment.py
@@ -8,14 +8,19 @@ import deepspeed.comm as dist
 import torch
 
 from unit.common import DistributedTest
-from unit.simple_model import random_dataloader
+from unit.simple_model import random_dataloader, SimpleModel
 from unit.util import bf16_required_version_check
 
 import deepspeed
 from deepspeed.utils import safe_get_full_fp32_param, safe_get_full_grad, safe_get_full_optimizer_state
+from deepspeed.utils import safe_set_full_fp32_param, safe_set_full_optimizer_state
 from deepspeed.runtime.zero.offload_config import OffloadDeviceEnum
 from deepspeed.ops.aio import AsyncIOBuilder
 
+WEIGHT_KEY = 'weight'
+FIRST_ORDER_KEY = 'exp_avg'
+SECOND_ORDER_KEY = 'exp_avg_sq'
+
 
 def validate_full_tensors(model):
     for _, lp in model.named_parameters():
@@ -73,7 +78,7 @@ def run_fragmented_model(model, config_dict, hidden_dim, dtype):
 
 
 @pytest.mark.parametrize('frozen_weights', [True, False])
-class TestTensorFragment(DistributedTest):
+class TestTensorFragmentGet(DistributedTest):
     # Need multiple gpus to test possible hanging
     world_size = 2
     reuse_dist_env = True
@@ -150,3 +155,104 @@ class TestTensorFragment(DistributedTest):
         hidden_dim = 128
         model = MyModel(hidden_dim, frozen_weights)
         run_fragmented_model(model, config_dict, hidden_dim, torch.bfloat16)
+
+
+def create_random_values(model, key_list, group):
+    param_values = {}
+    for n, lp in model.named_parameters():
+        param_shape = lp.ds_shape if hasattr(lp, 'ds_id') else lp.shape
+        param_values[n] = {}
+        for key in key_list:
+            rand_value = torch.rand(param_shape, dtype=torch.float32, device=model.device)
+            dist.broadcast(rand_value, src=0, group=group)
+            param_values[n][key] = rand_value
+    return param_values
+
+
+def set_param_values_with_dict(model, value_dict):
+    for n, lp in model.named_parameters():
+        for key, value_tensor in value_dict[n].items():
+            if key == WEIGHT_KEY:
+                safe_set_full_fp32_param(lp, value_tensor)
+            else:
+                safe_set_full_optimizer_state(lp, value_tensor, key)
+
+
+def validate_param_values_with_dict(model, value_dict):
+    for n, lp in model.named_parameters():
+        for key, expected_tensor in value_dict[n].items():
+            if key == WEIGHT_KEY:
+                actual_tensor = safe_get_full_fp32_param(lp)
+            else:
+                actual_tensor = safe_get_full_optimizer_state(lp, key)
+            assert torch.equal(expected_tensor, actual_tensor)
+
+
+@pytest.mark.parametrize('dtype', [torch.bfloat16, torch.float16, torch.float32])
+class TestTensorFragmentUpdate(DistributedTest):
+    # Need multiple gpus to test possible hanging
+    world_size = 2
+    reuse_dist_env = True
+
+    @pytest.mark.parametrize('zero_stage', [1, 2, 3])
+    @pytest.mark.parametrize('offload_device', [OffloadDeviceEnum.none, OffloadDeviceEnum.cpu, OffloadDeviceEnum.nvme])
+    def test_zero_fragments(self, tmpdir, zero_stage, offload_device, dtype):
+
+        if dtype == torch.bfloat16 and not bf16_required_version_check(accelerator_check=False):
+            pytest.skip(
+                "DeepSpeed BFloat16 tests need torch >= 1.10, NCCL >= 2.10.3, CUDA >= 11.0 and HW support for BFloat16 to run correctly"
+            )
+
+        if offload_device == OffloadDeviceEnum.nvme:
+            if zero_stage != 3:
+                pytest.skip(f"Nvme offload not supported for zero stage {zero_stage}")
+            if not deepspeed.ops.__compatible_ops__[AsyncIOBuilder.NAME]:
+                pytest.skip('Skip tests since async-io is not compatible')
+
+        config_dict = {
+            "train_micro_batch_size_per_gpu": 1,
+            "steps_per_print": 1,
+            "optimizer": {
+                "type": "Adam",
+                "params": {
+                    "lr": 1e-6
+                }
+            },
+            "zero_optimization": {
+                "stage": zero_stage,
+            }
+        }
+
+        if offload_device == OffloadDeviceEnum.cpu:
+            config_dict["zero_optimization"]["offload_optimizer"] = {"device": offload_device}
+        elif offload_device == OffloadDeviceEnum.nvme:
+            config_dict["zero_optimization"]["offload_optimizer"] = {
+                "device": offload_device,
+                "nvme_path": str(tmpdir)
+            }
+
+        if dtype == torch.float16:
+            config_dict["fp16"] = {"enabled": True, "initial_scale_power": 8}
+        elif dtype == torch.bfloat16:
+            config_dict["bf16"] = {"enabled": True}
+
+        hidden_dim = 128
+        if zero_stage == 3:
+            config_dict["zero_optimization"]["param_persistence_threshold"] = hidden_dim
+            with deepspeed.zero.Init(config_dict_or_path=config_dict):
+                model = SimpleModel(hidden_dim, nlayers=4)
+        else:
+            model = SimpleModel(hidden_dim, nlayers=4)
+
+        model, _, _, _ = deepspeed.initialize(model=model, model_parameters=model.parameters(), config=config_dict)
+        world = dist.get_world_size()
+        group = dist.new_group(ranks=list(range(world)))
+
+        dist.barrier()
+        optim_keys = [WEIGHT_KEY, FIRST_ORDER_KEY, SECOND_ORDER_KEY]
+        optim_state_values = create_random_values(model, optim_keys, group)
+        set_param_values_with_dict(model, optim_state_values)
+        validate_param_values_with_dict(model, optim_state_values)
+
+        # Needed in ZeRO 3. Not doing so can leak memory.
+        model.destroy()
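
For quick reference, below is a minimal usage sketch of the setter/getter round trip that the new ``TestTensorFragmentUpdate`` test exercises. It is illustrative only and not part of the patch: the toy ``nn.Linear`` model, the config values, and the assumption of a GPU plus a distributed launch (e.g., via the ``deepspeed`` launcher) are stand-ins chosen for the example.

```python
import torch
import torch.nn as nn
import deepspeed
from deepspeed.utils import (safe_get_full_fp32_param, safe_get_full_optimizer_state,
                             safe_set_full_fp32_param, safe_set_full_optimizer_state)

# Assumed setup: a toy model under ZeRO stage 2 with Adam and fp16 (values are illustrative).
model = nn.Linear(8, 8)
config = {
    "train_micro_batch_size_per_gpu": 1,
    "optimizer": {"type": "Adam", "params": {"lr": 1e-3}},
    "fp16": {"enabled": True},
    "zero_optimization": {"stage": 2},
}
engine, _, _, _ = deepspeed.initialize(model=model, model_parameters=model.parameters(), config=config)

for _, lp in engine.module.named_parameters():
    new_value = torch.zeros(lp.shape, dtype=torch.float32, device=lp.device)

    # Write through the new setters; each rank updates only its own fp32 partition.
    safe_set_full_fp32_param(lp, new_value)
    safe_set_full_optimizer_state(lp, new_value, "exp_avg")
    safe_set_full_optimizer_state(lp, new_value, "exp_avg_sq")

    # Read back through the existing getters, which reassemble the full fp32 tensors,
    # so the set/get round trip should match on every rank.
    assert torch.equal(safe_get_full_fp32_param(lp), new_value)
    assert torch.equal(safe_get_full_optimizer_state(lp, "exp_avg"), new_value)
```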