From 648f7bfa5009484b822064d0c28d377da6dd71a0 Mon Sep 17 00:00:00 2001 From: Rana Ali Amjad Date: Mon, 1 Nov 2021 14:22:09 -0700 Subject: [PATCH] Bfloat16 zero2 (#1398) * Changes for bfloat16 Zero2 * Cleaned up additional comments and debugging code * Adapted fp16_master_weights_and_grads option to cover BF16 * Reverted fp16_master_weights_and_gradients extension to BFloat16 and minor cleanup * Fixed formatting and variable naming errors recognized in testing * Added relevant unit tests for bfloat16 with ZeRO-2 * Updates conditions for skipping BFloat16 unit tests * Added check for NCCL inconsistent version naming convention * Update skip message for Bfloat16 tests to mention additional checks Co-authored-by: Olatunji Ruwase --- csrc/includes/type_shim.h | 10 + deepspeed/__init__.py | 2 +- deepspeed/runtime/config.py | 18 +- deepspeed/runtime/constants.py | 16 ++ deepspeed/runtime/engine.py | 33 ++-- deepspeed/runtime/zero/stage2.py | 138 +++++++------ tests/unit/test_bf16.py | 321 +++++++++++++++++++++++++++++++ tests/unit/test_zero.py | 3 +- tests/unit/util.py | 21 ++ 9 files changed, 475 insertions(+), 87 deletions(-) create mode 100644 tests/unit/test_bf16.py diff --git a/csrc/includes/type_shim.h b/csrc/includes/type_shim.h index ba1e188f..4f4e7a53 100644 --- a/csrc/includes/type_shim.h +++ b/csrc/includes/type_shim.h @@ -26,6 +26,11 @@ __VA_ARGS__; \ break; \ } \ + case at::ScalarType::BFloat16: { \ + using scalar_t_##LEVEL = at::BFloat16; \ + __VA_ARGS__; \ + break; \ + } \ default: AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \ } @@ -46,6 +51,11 @@ __VA_ARGS__; \ break; \ } \ + case at::ScalarType::BFloat16: { \ + using scalar_t_##LEVEL = at::BFloat16; \ + __VA_ARGS__; \ + break; \ + } \ default: AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \ } diff --git a/deepspeed/__init__.py b/deepspeed/__init__.py index 5ff4c541..3eeb6d9d 100755 --- a/deepspeed/__init__.py +++ b/deepspeed/__init__.py @@ -1,6 +1,7 @@ ''' Copyright 2020 The Microsoft DeepSpeed Team ''' + import sys import types from typing import Optional, Union @@ -124,7 +125,6 @@ def initialize(args=None, __git_hash__, __git_branch__), ranks=[0]) - assert model is not None, "deepspeed.initialize requires a model" if not isinstance(model, PipelineModule): diff --git a/deepspeed/runtime/config.py b/deepspeed/runtime/config.py index d214ce74..a23d2aaf 100755 --- a/deepspeed/runtime/config.py +++ b/deepspeed/runtime/config.py @@ -116,6 +116,15 @@ def get_fp16_enabled(param_dict): return False +def get_bfloat16_enabled(param_dict): + if BFLOAT16 in param_dict.keys(): + return get_scalar_param(param_dict[BFLOAT16], + BFLOAT16_ENABLED, + BFLOAT16_ENABLED_DEFAULT) + else: + return False + + def get_fp16_master_weights_and_grads_enabled(param_dict): if get_fp16_enabled(param_dict): return get_scalar_param(param_dict[FP16], @@ -130,6 +139,8 @@ def get_loss_scale(param_dict): return get_scalar_param(param_dict[FP16], FP16_LOSS_SCALE, FP16_LOSS_SCALE_DEFAULT) + elif get_bfloat16_enabled(param_dict): + return 1.0 else: return FP16_LOSS_SCALE_DEFAULT @@ -139,6 +150,8 @@ def get_initial_dynamic_scale(param_dict): initial_scale_power = get_scalar_param(param_dict[FP16], FP16_INITIAL_SCALE_POWER, FP16_INITIAL_SCALE_POWER_DEFAULT) + elif get_bfloat16_enabled(param_dict): + initial_scale_power = 0 else: initial_scale_power = FP16_INITIAL_SCALE_POWER_DEFAULT @@ -791,6 +804,9 @@ class DeepSpeedConfig(object): self.gradient_clipping = get_gradient_clipping(param_dict) self.fp16_enabled = 
get_fp16_enabled(param_dict) + self.bfloat16_enabled = get_bfloat16_enabled(param_dict) + assert not (self.fp16_enabled and self.bfloat16_enabled), 'bfloat16 and fp16 modes cannot be simultaneously enabled' + assert not (self.bfloat16_enabled and (self.zero_optimization_stage != 2)), 'bfloat16 mode is only enabled for Zero2 currently' self.fp16_master_weights_and_gradients = get_fp16_master_weights_and_grads_enabled( param_dict) self.amp_enabled = get_amp_enabled(param_dict) @@ -966,7 +982,7 @@ class DeepSpeedConfig(object): assert self.zero_enabled and self.zero_optimization_stage == ZERO_OPTIMIZATION_GRADIENTS, "Fp16_master_weights_and_grads is only supported with ZeRO Stage 2 for now." def _do_warning_check(self): - fp16_enabled = self.fp16_enabled or self.zero_enabled + fp16_enabled = self.fp16_enabled vocabulary_size = self._param_dict.get(VOCABULARY_SIZE, VOCABULARY_SIZE_DEFAULT) if vocabulary_size and vocabulary_size % TENSOR_CORE_ALIGN_SIZE != 0: diff --git a/deepspeed/runtime/constants.py b/deepspeed/runtime/constants.py index a88d1074..daf9ffea 100755 --- a/deepspeed/runtime/constants.py +++ b/deepspeed/runtime/constants.py @@ -107,6 +107,22 @@ GRADIENT_ACCUMULATION_STEPS_DEFAULT = None SPARSE_GRADIENTS = "sparse_gradients" SPARSE_GRADIENTS_DEFAULT = False +######################################### +# BFLOAT16 support +######################################### +# BFLOAT16 feature. By default, this feature is not enabled. +# Users can configure in ds_config.json as below example: +BFLOAT16_FORMAT = ''' +BFLOAT16 parameters should be of the format: +"bfloat16": { + "enabled": true +} +''' +BFLOAT16 = "bfloat16" + +BFLOAT16_ENABLED = "enabled" +BFLOAT16_ENABLED_DEFAULT = False + ######################################### # FP16 support ######################################### diff --git a/deepspeed/runtime/engine.py b/deepspeed/runtime/engine.py index 92ddfd77..87eb317f 100755 --- a/deepspeed/runtime/engine.py +++ b/deepspeed/runtime/engine.py @@ -83,6 +83,7 @@ def split_half_float_double_sparse(tensors): "torch.cuda.HalfTensor", "torch.cuda.FloatTensor", "torch.cuda.DoubleTensor", + "torch.cuda.BFloat16Tensor", SparseTensor.type() ] @@ -195,7 +196,6 @@ class DeepSpeedEngine(Module): # Configure wall clock timer self.timers = SynchronizedWallClockTimer() - # Throughput timer self.tput_timer = ThroughputTimer( batch_size=self.train_micro_batch_size_per_gpu(), @@ -530,6 +530,9 @@ class DeepSpeedEngine(Module): def fp16_enabled(self): return self._config.fp16_enabled + def bfloat16_enabled(self): + return self._config.bfloat16_enabled + def fp16_master_weights_and_gradients(self): return self._config.fp16_master_weights_and_gradients @@ -762,6 +765,8 @@ class DeepSpeedEngine(Module): f"fp16 is enabled but the following parameters have dtype that is not fp16: {', '.join(names)}" ) self.module.half() + elif self.bfloat16_enabled(): + self.module.bfloat16() else: if not all( [param.dtype == torch.float for param in self.module.parameters()]): @@ -899,7 +904,7 @@ class DeepSpeedEngine(Module): ) self.optimizer = self._configure_zero_optimizer(basic_optimizer) elif self.amp_enabled(): - assert not self.fp16_enabled(), "Cannot enable both amp with (legacy) fp16 mode" + assert not (self.fp16_enabled() or self.bfloat16_enabled()), "Cannot enable both amp with (legacy) fp16 or bfloat16 mode" amp_params = self.amp_params() if self.global_rank == 0: logger.info(f"Initializing AMP with these params: {amp_params}") @@ -1537,9 +1542,13 @@ class DeepSpeedEngine(Module): # Quantize the updated 
parameter if there is no overflow if self.quantizer: + if self.fp16_enabled(): + tensor_to_quantize = self.optimizer.bit16_groups if self.zero_optimization_stage( + ) == 2 else self.optimizer.fp16_groups + else: + tensor_to_quantize = self.optimizer.param_groups self.quantizer.quantize( - (self.optimizer.fp16_groups - if self.fp16_enabled() else self.optimizer.param_groups), + tensor_to_quantize, (self.optimizer.overflow if self.fp16_enabled() else False), self.eigenvalue_enabled(), block_eigenvalue) @@ -2261,7 +2270,6 @@ class DeepSpeedEngine(Module): method will hang waiting to synchronize with other processes if it's called just for the process with rank 0. """ - if self.zero_optimization_partition_weights(): # Prepare for state_dict() by ensuring all parameters are partitioned self.optimizer.save_checkpoint_prologue() @@ -2501,22 +2509,23 @@ class DeepSpeedEngine(Module): will be missing and others unsaved and then it'd be impossible to reconstruct state_dict from the flattened weights. - optimizer.fp16_groups seems to be the easiest to use as it's in all zeroX versions. + optimizer.bit16_groups seems to be the easiest to use as it's in all zeroX versions. """ param_group_shapes = [] cnt = 0 numel = 0 - # zero2 started using a round_robin_fp16_groups which is a shuffled version of fp16_groups - + # zero2 started using a round_robin_bit16_groups which is a shuffled version of bit16_groups - # if we don't use it, we get parameters ordered incorrectly - if hasattr(self.optimizer, "round_robin_fp16_groups"): - fp16_groups = self.optimizer.round_robin_fp16_groups + if hasattr(self.optimizer, "round_robin_bit16_groups"): + bit16_groups = self.optimizer.round_robin_bit16_groups else: - fp16_groups = self.optimizer.fp16_groups + bit16_groups = self.optimizer.bit16_groups if self.zero_optimization_stage( + ) == 2 else self.optimizer.fp16_groups - for fp16_group in fp16_groups: + for bit16_group in bit16_groups: param_shapes = OrderedDict() - for param in fp16_group: + for param in bit16_group: cnt += 1 numel += param.ds_numel if hasattr(param, "ds_numel") else param.numel() shape = param.ds_shape if hasattr(param, "ds_shape") else param.shape diff --git a/deepspeed/runtime/zero/stage2.py b/deepspeed/runtime/zero/stage2.py index 2d37a49e..fdaf8124 100755 --- a/deepspeed/runtime/zero/stage2.py +++ b/deepspeed/runtime/zero/stage2.py @@ -34,7 +34,8 @@ def split_half_float_double(tensors): dtypes = [ "torch.cuda.HalfTensor", "torch.cuda.FloatTensor", - "torch.cuda.DoubleTensor" + "torch.cuda.DoubleTensor", + "torch.cuda.BFloat16Tensor" ] buckets = [] for i, dtype in enumerate(dtypes): @@ -203,13 +204,13 @@ class FP16_DeepSpeedZeroOptimizer(object): assert self.postscale_gradients, "pre-scale gradients is not yet supported with ZeRO-2 with reduce scatter enabled" # param flattened by groups - self.fp16_groups = [] - self.fp16_groups_flat = [] + self.bit16_groups = [] + self.bit16_groups_flat = [] # param partitioned by data parallel degree # this will contain a list of equal sized tensors # each of which will be updated by a different process - self.parallel_partitioned_fp16_groups = [] + self.parallel_partitioned_bit16_groups = [] # a single 32-bit partition of the parallel partitioned parameters # that this process will update @@ -239,8 +240,8 @@ class FP16_DeepSpeedZeroOptimizer(object): self.all_reduce_print = False self.dtype = self.optimizer.param_groups[0]['params'][0].dtype - self.round_robin_fp16_groups = [] - self.round_robin_fp6_indices = [] + self.round_robin_bit16_groups = [] + 
self.round_robin_bit16_indices = [] # padding on each partition for alignment purposes self.groups_padding = [] @@ -250,12 +251,12 @@ class FP16_DeepSpeedZeroOptimizer(object): # push this group to list before modify # TODO: Explore simplification that avoids the extra book-keeping by pushing the reordered group - self.fp16_groups.append(param_group['params']) + self.bit16_groups.append(param_group['params']) # Record padding required to align group to world size if partition_id == dist.get_world_size( group=self.real_dp_process_group[i]) - 1: - padding = get_alignment_padding(self.fp16_groups[i], + padding = get_alignment_padding(self.bit16_groups[i], self.partition_count[i]) else: padding = 0 @@ -266,7 +267,7 @@ class FP16_DeepSpeedZeroOptimizer(object): see_memory_usage(f"Before moving param group {i} to CPU") # move all the parameters to cpu to free up GPU space for creating flat buffer - move_to_cpu(self.fp16_groups[i]) + move_to_cpu(self.bit16_groups[i]) see_memory_usage(f"After moving param group {i} to CPU", force=False) # Reorder group parameters for load balancing of gradient partitioning during backward among ranks. @@ -275,20 +276,20 @@ class FP16_DeepSpeedZeroOptimizer(object): # to the same rank, instead they will belong to 3 ranks (r_m+2, r_m+1, r_m). if self.round_robin_gradients: round_robin_tensors, round_robin_indices = self._round_robin_reorder( - self.fp16_groups[i], + self.bit16_groups[i], dist.get_world_size(group=self.real_dp_process_group[i]) ) else: - round_robin_tensors = self.fp16_groups[i] - round_robin_indices = list(range(len(self.fp16_groups[i]))) + round_robin_tensors = self.bit16_groups[i] + round_robin_indices = list(range(len(self.bit16_groups[i]))) - self.round_robin_fp16_groups.append(round_robin_tensors) - self.round_robin_fp6_indices.append(round_robin_indices) + self.round_robin_bit16_groups.append(round_robin_tensors) + self.round_robin_bit16_indices.append(round_robin_indices) # create flat buffer in CPU and move to GPU - self.fp16_groups_flat.append( + self.bit16_groups_flat.append( self.flatten_dense_tensors_aligned( - self.round_robin_fp16_groups[i], + self.round_robin_bit16_groups[i], self.nccl_start_alignment_factor * dist.get_world_size(group=self.real_dp_process_group[i])).cuda( torch.cuda.current_device())) @@ -300,15 +301,15 @@ class FP16_DeepSpeedZeroOptimizer(object): f"After Flattening and after emptying param group {i} cache", force=False) - # set model fp16 weight to slices of flattened buffer - self._update_model_fp16_weights(i) + # set model bit16 weight to slices of flattened buffer + self._update_model_bit16_weights(i) # divide the flat weights into near equal partition equal to the data parallel degree # each process will compute on a different part of the partition data_parallel_partitions = self.get_data_parallel_partitions( - self.fp16_groups_flat[i], + self.bit16_groups_flat[i], i) - self.parallel_partitioned_fp16_groups.append(data_parallel_partitions) + self.parallel_partitioned_bit16_groups.append(data_parallel_partitions) # verify that data partition start locations are 4-byte aligned for partitioned_data in data_parallel_partitions: @@ -318,11 +319,11 @@ class FP16_DeepSpeedZeroOptimizer(object): # a partition of the fp32 master weights that will be updated by this process if not fp16_master_weights_and_gradients: self.single_partition_of_fp32_groups.append( - self.parallel_partitioned_fp16_groups[i][partition_id].to( + self.parallel_partitioned_bit16_groups[i][partition_id].to( self.device).clone().float().detach()) 
else: self.single_partition_of_fp32_groups.append( - self.parallel_partitioned_fp16_groups[i][partition_id].to( + self.parallel_partitioned_bit16_groups[i][partition_id].to( self.device).clone().half().detach()) # modify optimizer of have flat master weight @@ -330,10 +331,10 @@ class FP16_DeepSpeedZeroOptimizer(object): i].requires_grad = True # keep this in case internal optimizer uses it param_group['params'] = [self.single_partition_of_fp32_groups[i]] - partition_size = len(self.fp16_groups_flat[i]) / dist.get_world_size( + partition_size = len(self.bit16_groups_flat[i]) / dist.get_world_size( group=self.real_dp_process_group[i]) params_in_partition, params_not_in_partition, first_offset = self.get_partition_info( - self.round_robin_fp16_groups[i], + self.round_robin_bit16_groups[i], partition_size, partition_id) @@ -374,9 +375,10 @@ class FP16_DeepSpeedZeroOptimizer(object): # simplified param id self.param_id = {} + #interesting code: unique ids being assigned to individual paramters largest_param_numel = 0 count = 0 - for i, params_group in enumerate(self.fp16_groups): + for i, params_group in enumerate(self.bit16_groups): for param in params_group: unique_id = id(param) self.param_id[unique_id] = count @@ -407,8 +409,7 @@ class FP16_DeepSpeedZeroOptimizer(object): largest_param_numel, device=torch.cuda.current_device(), dtype=self.dtype) - - for i, params_group in enumerate(self.fp16_groups): + for i, params_group in enumerate(self.bit16_groups): self.get_grad_position(i, self.params_in_partition[i], self.first_offset[i], @@ -452,8 +453,10 @@ class FP16_DeepSpeedZeroOptimizer(object): self.create_reduce_and_remove_grad_hooks() # we may have a way of fusing dynamic scale. Do not support for now - if self.dtype == torch.float or not dynamic_loss_scale: - loss_scale_value = 1.0 if self.dtype == torch.float else static_loss_scale + if self.dtype == torch.float or self.dtype == torch.bfloat16 or not dynamic_loss_scale: + loss_scale_value = 1.0 if ( + (self.dtype == torch.float) or + (self.dtype == torch.bfloat16)) else static_loss_scale self.dynamic_loss_scale = False self.loss_scaler = LossScaler(scale=loss_scale_value) @@ -498,16 +501,16 @@ class FP16_DeepSpeedZeroOptimizer(object): assert self.expert_dp_process_group is not None, "Expert data parallel group should be configured with MoE" assert self.ep_process_group is not None, "Expert parallel group should be configured with MoE" - def _update_model_fp16_weights(self, group_index): - updated_params = self.unflatten(self.fp16_groups_flat[group_index], - self.round_robin_fp16_groups[group_index]) - for p, q in zip(self.round_robin_fp16_groups[group_index], updated_params): + def _update_model_bit16_weights(self, group_index): + updated_params = self.unflatten(self.bit16_groups_flat[group_index], + self.round_robin_bit16_groups[group_index]) + for p, q in zip(self.round_robin_bit16_groups[group_index], updated_params): p.data = q.data # set model fp16 weight to slices of reordered flattened buffer - for param_index, param in enumerate(self.fp16_groups[group_index]): - new_index = self.round_robin_fp6_indices[group_index][param_index] - param.data = self.round_robin_fp16_groups[group_index][new_index].data + for param_index, param in enumerate(self.bit16_groups[group_index]): + new_index = self.round_robin_bit16_indices[group_index][param_index] + param.data = self.round_robin_bit16_groups[group_index][new_index].data def _round_robin_reorder(self, tensor_list, num_partitions): @@ -540,7 +543,7 @@ class 
FP16_DeepSpeedZeroOptimizer(object): def initialize_optimizer_states(self): - for i, group in enumerate(self.fp16_groups): + for i, group in enumerate(self.bit16_groups): single_grad_partition = torch.zeros( int(self.partition_size[i]), dtype=self.single_partition_of_fp32_groups[i].dtype, @@ -560,7 +563,6 @@ class FP16_DeepSpeedZeroOptimizer(object): ######################################################################### #################### ZeRO Stage 1 - reduce gradients #################### ######################################################################### - def reduce_gradients(self, pipeline_parallel=False): world_size = dist.get_world_size(self.dp_process_group) my_rank = dist.get_rank(self.dp_process_group) @@ -575,11 +577,10 @@ class FP16_DeepSpeedZeroOptimizer(object): self.ipg_index = 0 if not self.overlap_comm: - for i, group in enumerate(self.fp16_groups): + for i, group in enumerate(self.bit16_groups): for param in group: if param.grad is not None: self.reduce_ready_partitions_and_remove_grads(param, i) - # reduce any pending grads in either hook/non-hook case self.overlapping_partition_gradients_reduce_epilogue() @@ -596,8 +597,7 @@ class FP16_DeepSpeedZeroOptimizer(object): def initialize_gradient_partitioning_data_structures(self): - for i, param_group in enumerate(self.round_robin_fp16_groups): - + for i, param_group in enumerate(self.round_robin_bit16_groups): total_partitions = dist.get_world_size(group=self.real_dp_process_group[i]) self.param_to_partition_ids[i] = {} @@ -638,7 +638,7 @@ class FP16_DeepSpeedZeroOptimizer(object): self._clear_previous_reduced_grads() if self.cpu_offload is False: - for i, _ in enumerate(self.fp16_groups): + for i, _ in enumerate(self.bit16_groups): if not i in self.averaged_gradients or self.averaged_gradients[i] is None: self.averaged_gradients[i] = self.get_flat_partition( @@ -671,7 +671,7 @@ class FP16_DeepSpeedZeroOptimizer(object): # sets remaining grads to the total number of grads in each partition # set is grad computed to false for all grads in partition def reset_partition_gradient_structures(self): - for i, _ in enumerate(self.fp16_groups): + for i, _ in enumerate(self.bit16_groups): total_partitions = dist.get_world_size(group=self.real_dp_process_group[i]) for partition_id in range(total_partitions): self.is_partition_reduced[i][partition_id] = False @@ -741,7 +741,7 @@ class FP16_DeepSpeedZeroOptimizer(object): def create_reduce_and_remove_grad_hooks(self): self.grad_accs = [] - for i, param_group in enumerate(self.fp16_groups): + for i, param_group in enumerate(self.bit16_groups): for param in param_group: if param.requires_grad: @@ -1116,7 +1116,6 @@ class FP16_DeepSpeedZeroOptimizer(object): # Sum across all model parallel GPUs. 
total_norm_cuda = torch.cuda.FloatTensor([float(total_norm)]) - torch.distributed.all_reduce(total_norm_cuda, op=torch.distributed.ReduceOp.SUM, group=self.dp_process_group) @@ -1133,7 +1132,6 @@ class FP16_DeepSpeedZeroOptimizer(object): return total_norm ############################################################################################ - def copy_grads_in_partition(self, param): if self.cpu_offload: @@ -1299,7 +1297,6 @@ class FP16_DeepSpeedZeroOptimizer(object): param.grad = torch.zero_like(param) ######################Reduction Related Methods############################## - def allreduce_bucket(self, bucket, allreduce_always_fp32=False, rank=None, log=None): rank = None tensor = self.flatten(bucket) @@ -1444,7 +1441,7 @@ class FP16_DeepSpeedZeroOptimizer(object): """ # FP32 grad should never exist. # For speed, set model fp16 grad to None by default - for group in self.fp16_groups: + for group in self.bit16_groups: for p in group: if set_grads_to_None: p.grad = None # epilogue and in step @@ -1505,7 +1502,6 @@ class FP16_DeepSpeedZeroOptimizer(object): total_norm += param_norm.item()**2 # Sum across all model parallel GPUs. total_norm_cuda = torch.cuda.FloatTensor([float(total_norm)]) - torch.distributed.all_reduce(total_norm_cuda, op=torch.distributed.ReduceOp.SUM, group=self.dp_process_group) @@ -1612,7 +1608,6 @@ class FP16_DeepSpeedZeroOptimizer(object): # First compute norm for all group so we know if there is overflow self.check_overflow() - OPTIMIZER_ALLGATHER = 'optimizer_allgather' OPTIMIZER_GRADIENTS = 'optimizer_gradients' OPTIMIZER_STEP = 'optimizer_step' @@ -1638,7 +1633,7 @@ class FP16_DeepSpeedZeroOptimizer(object): norm_groups = [] single_partition_grad_groups = [] skip = False - for i, group in enumerate(self.fp16_groups): + for i, group in enumerate(self.bit16_groups): partition_id = dist.get_rank(group=self.real_dp_process_group[i]) if self.cpu_offload: norm_groups.append( @@ -1687,15 +1682,15 @@ class FP16_DeepSpeedZeroOptimizer(object): if self.deepspeed_adam_offload: from deepspeed.ops.adam import DeepSpeedCPUAdam if type(self.optimizer) == DeepSpeedCPUAdam and self.dtype == torch.half: - fp16_param_groups = [ - fp16_partitions[partition_id] - for fp16_partitions in self.parallel_partitioned_fp16_groups + bit16_param_groups = [ + bit16_partitions[partition_id] + for bit16_partitions in self.parallel_partitioned_bit16_groups ] - self.optimizer.step(fp16_param_groups=fp16_param_groups) + self.optimizer.step(fp16_param_groups=bit16_param_groups) else: self.optimizer.step() - for fp16_partitions, fp32_partition in zip(self.parallel_partitioned_fp16_groups, self.single_partition_of_fp32_groups): - fp16_partitions[partition_id].data.copy_(fp32_partition.data) + for bit16_partitions, fp32_partition in zip(self.parallel_partitioned_bit16_groups, self.single_partition_of_fp32_groups): + bit16_partitions[partition_id].data.copy_(fp32_partition.data) else: self.optimizer.step() @@ -1704,8 +1699,8 @@ class FP16_DeepSpeedZeroOptimizer(object): for group in self.single_partition_of_fp32_groups: group.grad = None # in step - for fp16_partitions, fp32_partition in zip(self.parallel_partitioned_fp16_groups, self.single_partition_of_fp32_groups): - fp16_partitions[partition_id].data.copy_(fp32_partition.data) + for bit16_partitions, fp32_partition in zip(self.parallel_partitioned_bit16_groups, self.single_partition_of_fp32_groups): + bit16_partitions[partition_id].data.copy_(fp32_partition.data) self.stop_timers([OPTIMIZER_STEP]) @@ -1714,7 +1709,7 @@ class 
FP16_DeepSpeedZeroOptimizer(object): self.start_timers([OPTIMIZER_ALLGATHER]) # gather the updated weights from everyone - for group_id, partitioned_params in enumerate(self.parallel_partitioned_fp16_groups): + for group_id, partitioned_params in enumerate(self.parallel_partitioned_bit16_groups): # Sequential AllGather Best of both worlds dp_world_size = dist.get_world_size( @@ -1742,7 +1737,6 @@ class FP16_DeepSpeedZeroOptimizer(object): shard_id * shard_size, num_elements).detach() shard_list.append(curr_shard) - dist.all_gather(shard_list, shard_list[partition_id], group=self.real_dp_process_group[group_id]) @@ -1750,7 +1744,7 @@ class FP16_DeepSpeedZeroOptimizer(object): # TODO: we probably don't need this? just to be safe for i in range(len(norm_groups)): - self._update_model_fp16_weights(i) + self._update_model_bit16_weights(i) self.log_timers(timer_names) see_memory_usage('After zero_optimizer step') @@ -1797,7 +1791,7 @@ class FP16_DeepSpeedZeroOptimizer(object): return False def has_overflow_partitioned_grads_serial(self): - for i in range(len(self.fp16_groups)): + for i in range(len(self.bit16_groups)): for j, grad in enumerate(self.averaged_gradients[i]): if grad is not None and self._has_inf_or_nan(grad.data, j): return True @@ -1816,7 +1810,7 @@ class FP16_DeepSpeedZeroOptimizer(object): else: params = [] - for group in self.fp16_groups: + for group in self.bit16_groups: for param in group: params.append(param) @@ -2010,15 +2004,15 @@ class FP16_DeepSpeedZeroOptimizer(object): for current, saved in zip(self.single_partition_of_fp32_groups, merged_single_partition_of_fp32_groups): current.data.copy_(saved.data) - # Restore base optimizer fp32 weights from ZeRO fp16 weights - def _restore_from_fp16_weights(self): - for group_id, (fp16_partitions, fp32_partition) in enumerate(zip(self.parallel_partitioned_fp16_groups, self.single_partition_of_fp32_groups)): + # Restore base optimizer fp32 weights from ZeRO fp16 or bfloat16 weights + def _restore_from_bit16_weights(self): + for group_id, (bit16_partitions, fp32_partition) in enumerate(zip(self.parallel_partitioned_bit16_groups, self.single_partition_of_fp32_groups)): partition_id = dist.get_rank(group=self.real_dp_process_group[group_id]) - fp32_partition.data.copy_(fp16_partitions[partition_id].data) + fp32_partition.data.copy_(bit16_partitions[partition_id].data) - # Refresh the fp32 master params from the fp16 copies. + # Refresh the fp32 master params from the fp16 or bfloat16 copies. 
def refresh_fp32_params(self): - self._restore_from_fp16_weights() + self._restore_from_bit16_weights() # Extract optimizer state for current partition from merged states of all partitions def _partition_base_optimizer_state(self, state_key, all_partition_states, group_id): @@ -2146,7 +2140,7 @@ class FP16_DeepSpeedZeroOptimizer(object): if load_from_fp32_weights: self._restore_from_fp32_weights(state_dict_list) else: - self._restore_from_fp16_weights() + self._restore_from_bit16_weights() def _handle_overflow(cpu_sum, x, i): diff --git a/tests/unit/test_bf16.py b/tests/unit/test_bf16.py new file mode 100644 index 00000000..9220ce7e --- /dev/null +++ b/tests/unit/test_bf16.py @@ -0,0 +1,321 @@ +import math +import torch +import deepspeed +import pytest +from deepspeed.ops.adam import FusedAdam +from common import distributed_test +from deepspeed.ops.op_builder import CPUAdamBuilder +from simple_model import SimpleModel, SimpleOptimizer, random_dataloader, args_from_dict +from util import bf16_required_version_check + + +@pytest.mark.parametrize('zero_stage, use_cpu_offload', [(2, False)]) +def test_adam_bf16_zero_onecycle_compatibility(tmpdir, zero_stage, use_cpu_offload): + if not bf16_required_version_check(): + pytest.skip( + " DeepSpeed BFloat16 tests need torch >= 1.10, NCCL >= 2.10.3, CUDA > =11.0 and HW support for BFloat16 to run correctly" + ) + + if use_cpu_offload and not deepspeed.ops.__compatible_ops__[CPUAdamBuilder.NAME]: + pytest.skip("cpu-adam is not compatible") + + config_dict = { + "train_batch_size": 1, + "steps_per_print": 1, + "optimizer": { + "type": "Adam", + "params": { + "lr": 0.00015 + } + }, + "scheduler": { + "type": "OneCycle", + "params": { + "cycle_first_step_size": 16000, + "cycle_first_stair_count": 8000, + "decay_step_size": 16000, + "cycle_min_lr": 1e-06, + "cycle_max_lr": 3e-05, + "decay_lr_rate": 1e-07, + "cycle_min_mom": 0.85, + "cycle_max_mom": 0.99, + "decay_mom_rate": 0.0 + } + }, + "fp16": { + "enabled": False + }, + "bfloat16": { + "enabled": True + }, + "zero_optimization": { + "stage": zero_stage, + "cpu_offload": use_cpu_offload + } + } + + args = args_from_dict(tmpdir, config_dict) + hidden_dim = 10 + + @distributed_test(world_size=[1]) + def _test_adam_bf16_zero_onecycle_compatibility(args, zero_stage, hidden_dim): + model = SimpleModel(hidden_dim) + + model, _, _, _ = deepspeed.initialize(args=args, + model=model, + model_parameters=model.parameters()) + data_loader = random_dataloader(model=model, + total_samples=50, + hidden_dim=hidden_dim, + device=model.device, + dtype=torch.bfloat16) + for n, batch in enumerate(data_loader): + loss = model(batch[0], batch[1]) + model.backward(loss) + model.step() + + _test_adam_bf16_zero_onecycle_compatibility(args=args, + zero_stage=zero_stage, + hidden_dim=hidden_dim) + + +@pytest.mark.parametrize('zero_stage, use_cpu_offload', [(2, False)]) +def test_zero_allow_untested_optimizer(tmpdir, zero_stage, use_cpu_offload): + if not bf16_required_version_check(): + pytest.skip( + " DeepSpeed BFloat16 tests need torch >= 1.10, NCCL >= 2.10.3, CUDA > =11.0 and HW support for BFloat16 to run correctly" + ) + + if use_cpu_offload and not deepspeed.ops.__compatible_ops__[CPUAdamBuilder.NAME]: + pytest.skip("cpu-adam is not compatible") + + config_dict = { + "train_batch_size": 4, + "steps_per_print": 1, + "fp16": { + "enabled": False, + }, + "bfloat16": { + "enabled": True + }, + "zero_optimization": { + "stage": zero_stage, + "cpu_offload": use_cpu_offload + }, + "zero_allow_untested_optimizer": False + } + 
args = args_from_dict(tmpdir, config_dict) + + @distributed_test(world_size=[1]) + def _test_zero_allow_untested_optimizer(args, zero_stage): + hidden_dim = 10 + model = SimpleModel(hidden_dim) + optimizer = SimpleOptimizer(model.parameters()) + with pytest.raises(AssertionError): + model, optim, _, _ = deepspeed.initialize(args=args, + model=model, + optimizer=optimizer, + model_parameters=model.parameters()) + + _test_zero_allow_untested_optimizer(args, zero_stage) + + +@pytest.mark.parametrize('zero_stage, use_cpu_offload', [(2, False)]) +def test_zero_empty_partition(tmpdir, zero_stage, use_cpu_offload): + if not bf16_required_version_check(): + pytest.skip( + " DeepSpeed BFloat16 tests need torch >= 1.10, NCCL >= 2.10.3, CUDA > =11.0 and HW support for BFloat16 to run correctly" + ) + + if use_cpu_offload and not deepspeed.ops.__compatible_ops__[CPUAdamBuilder.NAME]: + pytest.skip("cpu-adam is not compatible") + + if zero_stage == 3: + pytest.skip("skip for now") + + config_dict = { + "train_micro_batch_size_per_gpu": 1, + "gradient_accumulation_steps": 1, + "fp16": { + "enabled": False + }, + "bfloat16": { + "enabled": True + }, + "optimizer": { + "type": "Adam", + "params": { + "lr": 0.00015 + } + }, + "zero_optimization": { + "stage": zero_stage, + "cpu_offload": use_cpu_offload, + "reduce_bucket_size": 100, + "allgather_bucket_size": 100 + } + } + args = args_from_dict(tmpdir, config_dict) + + @distributed_test(world_size=[3]) + def _test_zero_empty_partition(args, zero_stage): + hidden_dim = 1 + model = SimpleModel(hidden_dim) + + # Ensure model has 2 parameters, to cause empty partition with DP=3 + assert len(list(model.parameters())) == 2 + model, _, _, _ = deepspeed.initialize(args=args, + model=model, + model_parameters=model.parameters()) + + # Now make sure things work.. 
+ data_loader = random_dataloader(model=model, + total_samples=1, + hidden_dim=hidden_dim, + device=model.device, + dtype=torch.bfloat16) + for n, batch in enumerate(data_loader): + loss = model(batch[0], batch[1]) + model.backward(loss) + model.step() + + _test_zero_empty_partition(args=args, zero_stage=zero_stage) + + +@pytest.mark.parametrize('zero_stage, optimizer_constructor', + [(2, + torch.optim.Adam), + (2, + FusedAdam)]) +def test_zero_supported_client_optimizer(tmpdir, zero_stage, optimizer_constructor): + if not bf16_required_version_check(): + pytest.skip( + " DeepSpeed BFloat16 tests need torch >= 1.10, NCCL >= 2.10.3, CUDA > =11.0 and HW support for BFloat16 to run correctly" + ) + + config_dict = { + "train_batch_size": 2, + "steps_per_print": 1, + "fp16": { + "enabled": False + }, + "bfloat16": { + "enabled": True + }, + "zero_optimization": { + "stage": zero_stage + } + } + args = args_from_dict(tmpdir, config_dict) + hidden_dim = 10 + + @distributed_test(world_size=[1]) + def _test_zero_supported_client_optimizer(args, zero_stage, optimizer_constructor): + model = SimpleModel(hidden_dim) + + client_optimizer = optimizer_constructor(params=model.parameters()) + model, _, _, _ = deepspeed.initialize(args=args, + model=model, + optimizer=client_optimizer) + + _test_zero_supported_client_optimizer(args=args, + zero_stage=zero_stage, + optimizer_constructor=optimizer_constructor) + + +def test_zero2_reduce_scatter_off(tmpdir): + if not bf16_required_version_check(): + pytest.skip( + " DeepSpeed BFloat16 tests need torch >= 1.10, NCCL >= 2.10.3, CUDA > =11.0 and HW support for BFloat16 to run correctly" + ) + + config_dict = { + "train_batch_size": 2, + "steps_per_print": 1, + "optimizer": { + "type": "Adam", + "params": { + "lr": 0.00015 + } + }, + "gradient_clipping": 1.0, + "zero_optimization": { + "stage": 2, + "contiguous_gradients": True, + "allgather_bucket_size": 2000000000, + "reduce_bucket_size": 200000000, + "overlap_comm": False, + "reduce_scatter": False + }, + "fp16": { + "enabled": False + }, + "bfloat16": { + "enabled": True + } + } + args = args_from_dict(tmpdir, config_dict) + hidden_dim = 10 + + model = SimpleModel(hidden_dim) + + @distributed_test(world_size=[2]) + def _helper(args, model, hidden_dim): + model, _, _, _ = deepspeed.initialize(args=args, + model=model, + model_parameters=model.parameters()) + data_loader = random_dataloader(model=model, + total_samples=50, + hidden_dim=hidden_dim, + device=model.device, + dtype=torch.bfloat16) + for n, batch in enumerate(data_loader): + loss = model(batch[0], batch[1]) + model.backward(loss) + model.step() + + _helper(args=args, model=model, hidden_dim=hidden_dim) + + +@pytest.mark.parametrize('stage', [2]) +def test_zero_empty_grad(tmpdir, stage): + if not bf16_required_version_check(): + pytest.skip( + " DeepSpeed BFloat16 tests need torch >= 1.10, NCCL >= 2.10.3, CUDA > =11.0 and HW support for BFloat16 to run correctly" + ) + + config_dict = { + "train_batch_size": 1, + "steps_per_print": 1, + "fp16": { + "enabled": False + }, + "bfloat16": { + "enabled": True + }, + "zero_optimization": { + "stage": stage + } + } + args = args_from_dict(tmpdir, config_dict) + hidden_dim = 10 + + model = SimpleModel(hidden_dim) + + @distributed_test(world_size=[1]) + def _go(args, model, hidden_dim): + optimizer = torch.optim.Adam(model.parameters()) + model, _, _, _ = deepspeed.initialize(args=args, + model=model, + optimizer=optimizer) + data_loader = random_dataloader(model=model, + total_samples=50, + 
hidden_dim=hidden_dim, + device=model.device, + dtype=torch.bfloat16) + for n, batch in enumerate(data_loader): + loss = model(batch[0], batch[1]) + model.backward(loss) + model.step() + + _go(args=args, model=model, hidden_dim=hidden_dim) diff --git a/tests/unit/test_zero.py b/tests/unit/test_zero.py index 173e60e2..5aa94f65 100755 --- a/tests/unit/test_zero.py +++ b/tests/unit/test_zero.py @@ -427,7 +427,8 @@ def test_partition_nccl_alignment(tmpdir, zero_stage, world_size): # get nccl all-gather send buffers alignment factor nccl_start_alignment_factor = model.optimizer.nccl_start_alignment_factor - for data_parallel_partitions in model.optimizer.parallel_partitioned_fp16_groups: + parallel_partitioned_bit16_groups = model.optimizer.parallel_partitioned_bit16_groups if zero_stage == 2 else model.optimizer.parallel_partitioned_fp16_groups + for data_parallel_partitions in parallel_partitioned_bit16_groups: for partition_id, partitioned_data in enumerate(data_parallel_partitions): # verify that data partition start locations are 4-byte aligned assert (partitioned_data.data_ptr() % diff --git a/tests/unit/util.py b/tests/unit/util.py index c262e0f5..966733b1 100644 --- a/tests/unit/util.py +++ b/tests/unit/util.py @@ -1,4 +1,5 @@ import torch +from deepspeed.git_version_info import torch_info def required_torch_version(): @@ -9,3 +10,23 @@ def required_torch_version(): return True else: return False + + +def bf16_required_version_check(): + TORCH_MAJOR = int(torch.__version__.split('.')[0]) + TORCH_MINOR = int(torch.__version__.split('.')[1]) + + if type(torch.cuda.nccl.version()) != tuple: + return False + else: + NCCL_MAJOR = torch.cuda.nccl.version()[0] + NCCL_MINOR = torch.cuda.nccl.version()[1] + + CUDA_MAJOR = int(torch_info['cuda_version'].split('.')[0]) + if (TORCH_MAJOR > 1 or + (TORCH_MAJOR == 1 and TORCH_MINOR >= 10)) and (CUDA_MAJOR >= 11) and ( + NCCL_MAJOR > 2 or + (NCCL_MAJOR == 2 and NCCL_MINOR >= 10)) and torch.cuda.is_bf16_supported(): + return True + else: + return False -- GitLab
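
For reference, below is a minimal sketch of how the new "bfloat16" mode introduced by this patch is enabled, mirroring the configs exercised in tests/unit/test_bf16.py: BFloat16 gets its own config block, "fp16" must stay disabled, and only ZeRO stage 2 is currently supported. Passing the config dict directly through deepspeed.initialize's config argument is an assumption about the installed DeepSpeed release; the unit tests above instead write the dict to a JSON file and pass it via args (args_from_dict). MyModel and data_loader are hypothetical placeholders, not part of this patch.

    # Illustrative sketch only. Requires torch >= 1.10, NCCL >= 2.10.3,
    # CUDA >= 11 and BF16-capable hardware (see bf16_required_version_check()).
    import deepspeed

    ds_config = {
        "train_batch_size": 2,
        "optimizer": {
            "type": "Adam",
            "params": {
                "lr": 0.00015
            }
        },
        # fp16 and bfloat16 cannot be enabled simultaneously (asserted in config.py)
        "fp16": {
            "enabled": False
        },
        # new config block added by this patch; loss scaling is fixed to 1.0 for bf16
        "bfloat16": {
            "enabled": True
        },
        # bfloat16 is currently only supported together with ZeRO stage 2
        "zero_optimization": {
            "stage": 2
        }
    }

    model = MyModel()  # hypothetical user model
    engine, _, _, _ = deepspeed.initialize(model=model,
                                           model_parameters=model.parameters(),
                                           config=ds_config)

    for batch in data_loader:  # hypothetical loader yielding torch.bfloat16 inputs
        loss = engine(batch[0], batch[1])
        engine.backward(loss)
        engine.step()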