From 4354c3cc675c9c59a4f1edd608371322f1986c38 Mon Sep 17 00:00:00 2001
From: Olatunji Ruwase
Date: Wed, 5 Jan 2022 11:38:40 -0800
Subject: [PATCH] Fix largest param numel calculation (#1623)

Co-authored-by: Jeff Rasley
---
 deepspeed/runtime/zero/stage3.py | 76 ++------------------------------
 1 file changed, 4 insertions(+), 72 deletions(-)

diff --git a/deepspeed/runtime/zero/stage3.py b/deepspeed/runtime/zero/stage3.py
index 9ea092ae..aa1b9eba 100755
--- a/deepspeed/runtime/zero/stage3.py
+++ b/deepspeed/runtime/zero/stage3.py
@@ -848,8 +848,10 @@ class DeepSpeedZeroOptimizer_Stage3(object):
 
         #Largest partitioned param
         largest_partitioned_param_numel = max([
-            max([tensor.numel() for tensor in fp16_partitioned_group])
-            for fp16_partitioned_group in self.fp16_partitioned_groups
+            max([
+                max(tensor.numel(),
+                    tensor.ds_numel) for tensor in fp16_partitioned_group
+            ]) for fp16_partitioned_group in self.fp16_partitioned_groups
         ])
         print_rank_0(
             f'Largest partitioned param numel = {largest_partitioned_param_numel}',
@@ -982,76 +984,6 @@ class DeepSpeedZeroOptimizer_Stage3(object):
             dtype=torch.float32,
             timers=self.timers)
 
-    def _create_fp16_partitions(self):
-        dist.barrier()
-        partition_id = dist.get_rank(group=self.dp_process_group)
-
-        # loop to deal with groups
-        for j, param_group in enumerate(self.optimizer.param_groups):
-
-            sub_groups = self._create_fp16_sub_groups(param_group['params'])
-            for sub_group in sub_groups:
-                i = len(self.fp16_groups)
-
-                # push this group to list before modify
-                self.fp16_groups.append(sub_group)
-                self.sub_group_to_group_id[i] = j
-
-                #These are the list of the partitioned parameters
-                self.fp16_partitioned_groups.append(
-                    [param.ds_tensor for param in self.fp16_groups[i]])
-
-                print_rank_0(
-                    f"fp16 group {i} partitioned_param norms : {[param.ds_tensor.norm().item() for param in self.fp16_groups[i]]}"
-                )
-
-                # Record padding required to align group to world size (only applies to last rank)
-                if partition_id == dist.get_world_size(group=self.dp_process_group) - 1:
-                    padding = [p.padding_size() for p in self.fp16_groups[i]]
-                else:
-                    padding = [0] * len(self.fp16_groups[i])
-                self.groups_padding.append(padding)
-
-                #not sure why apex was cloning the weights before flattening
-                #removing cloning here
-                see_memory_usage(f"Before Flattening param group {i}", force=False)
-
-                if not self.offload_param:
-                    see_memory_usage(f"Before moving param group {i} to CPU",
-                                     force=False)
-                    #move all the parameters to cpu to free up GPU space for creating flat buffer
-                    move_to_cpu(self.fp16_partitioned_groups[i])
-                    see_memory_usage(f"After moving param group {i} to CPU", force=False)
-
-                    #create flat buffer in CPU and move to GPU
-                    self.fp16_partitioned_groups_flat.append(
-                        self.flatten_dense_tensors_aligned(
-                            self.fp16_partitioned_groups[i],
-                            dist.get_world_size(group=self.dp_process_group)).cuda(
-                                torch.cuda.current_device()))
-                    see_memory_usage(
-                        f"After flattening and moving param group {i} to GPU",
-                        force=False)
-                else:
-                    #Without the detach, seems like the flattening becomes part of the
-                    #model graph causing errors downstream
-                    self.fp16_partitioned_groups_flat.append(
-                        self.flatten_dense_tensors_aligned(
-                            self.fp16_partitioned_groups[i],
-                            dist.get_world_size(
-                                group=self.dp_process_group)).detach().pin_memory())
-
-                    see_memory_usage(f"After Flattening param group {i}", force=False)
-
-                see_memory_usage(f"After Flattening param group {i}", force=False)
-
-                #set model fp16 weight to slices of flattened buffer
-                updated_params = self.unflatten(self.fp16_partitioned_groups_flat[i],
-                                                self.fp16_partitioned_groups[i])
-
-                for partitioned_param, q in zip(self.fp16_partitioned_groups[i], updated_params):
-                    partitioned_param.data = q.data
-
     def _move_to_flat_buffer(self, param_list, flat_buffer, avoid_copy=False):
         '''If flat buffer is None then the parameters in the param_list are not
         copied to the flat buffer. This is because they excede the number of max_params_in_cpu
-- 
GitLab
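Editor's note, outside the patch itself: the sketch below is not DeepSpeed code; it only mirrors the expression rewritten by the first hunk. The toy make_partitioned_tensor helper and its sizes are hypothetical, standing in for ZeRO-3 partitioned fp16 tensors that expose numel() (local slice size) and, per the diff, a ds_numel attribute. It illustrates how taking max(tensor.numel(), tensor.ds_numel) can yield a larger "largest partitioned param numel" when a tensor's local slice is smaller than its full size.

# Illustrative sketch only -- mirrors the expression the first hunk changes.
from types import SimpleNamespace


def make_partitioned_tensor(local_numel, ds_numel):
    # Hypothetical stand-in for a partitioned parameter tensor.
    return SimpleNamespace(numel=lambda: local_numel, ds_numel=ds_numel)


fp16_partitioned_groups = [
    [make_partitioned_tensor(local_numel=1024, ds_numel=4096)],
    [make_partitioned_tensor(local_numel=512, ds_numel=512)],
]

# Removed expression: considers only the local slice sizes.
old_largest = max(
    max(tensor.numel() for tensor in group) for group in fp16_partitioned_groups)

# Added expression: also considers ds_numel for each tensor.
new_largest = max(
    max(max(tensor.numel(), tensor.ds_numel) for tensor in group)
    for group in fp16_partitioned_groups)

print(old_largest, new_largest)  # prints "1024 4096" for these toy values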