Unverified · Commit f6903190 authored by J Joe Mayer, committed by GitHub

Simplify Gradient Attribute Names (#4214)

* name changes

* formatting changes
Parent 9647ea79
@@ -286,9 +286,9 @@ class DeepSpeedZeroOptimizer(ZeROOptimizer):
         else:
             self.use_separate_grad_accum = False
         if self.use_separate_grad_accum and not self.partition_gradients:
-            self.use_grad_accum_for_reduction = True
+            self.use_grad_accum_attribute = True
         else:
-            self.use_grad_accum_for_reduction = False
+            self.use_grad_accum_attribute = False

         self.round_robin_bit16_groups = []
         self.round_robin_bit16_indices = []
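The rename here is purely cosmetic, but it clarifies what the flag encodes. As a rough illustration, a minimal sketch follows, assuming (as set up just above this hunk) that `use_separate_grad_accum` tracks a mismatch between the compute dtype and the gradient-accumulation dtype; the function name and arguments are hypothetical, not DeepSpeed API:

```python
import torch

def needs_grad_accum_attribute(dtype, grad_accum_dtype, partition_gradients):
    # Assumed condition: a separate accumulation buffer is wanted whenever the
    # accumulation dtype differs from the compute dtype ...
    use_separate_grad_accum = dtype != grad_accum_dtype
    # ... and the dedicated .grad_accum attribute is only kept when gradients
    # are NOT partitioned, i.e. ZeRO stage 1.
    return use_separate_grad_accum and not partition_gradients

assert needs_grad_accum_attribute(torch.bfloat16, torch.float32, partition_gradients=False)
assert not needs_grad_accum_attribute(torch.bfloat16, torch.float32, partition_gradients=True)
```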
@@ -828,7 +828,7 @@ class DeepSpeedZeroOptimizer(ZeROOptimizer):
     def overlapping_partition_gradients_reduce_epilogue(self):
         self.independent_gradient_partition_epilogue()

-    def update_separate_grad_accum(self):
+    def fill_grad_accum_attribute(self):
         for group in self.bit16_groups:
             for param in group:
                 if param.grad is not None:
@@ -839,20 +839,18 @@ class DeepSpeedZeroOptimizer(ZeROOptimizer):
                             param.grad.to(self.gradient_accumulation_dtype).view(param.grad_accum.shape))
                     param.grad = None

-    def set_grad_accum_pointer(self):
-        for group in self.bit16_groups:
-            for param in group:
-                param.grad_accum = param.grad
-
     def get_gradient_for_reduction(self, param):
-        if self.use_grad_accum_for_reduction:
+        if self.use_grad_accum_attribute:
             return param.grad_accum.to(self.dtype) if param.grad_accum is not None else None
         else:
             return param.grad

+    def get_param_gradient_attribute(self, param):
+        return param.grad_accum if self.use_grad_accum_attribute else param.grad
+
     # Clear the tensor the reduction gradient attribute is pointing to
-    def clear_grad_reduc_pointer(self, param):
-        if self.use_grad_accum_for_reduction:
+    def clear_grad_attribute(self, param):
+        if self.use_grad_accum_attribute:
             param.grad_accum = None
         else:
             param.grad = None
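The renamed `fill_grad_accum_attribute` (body unchanged above) moves each fresh `param.grad` into a higher-precision `param.grad_accum` buffer and releases the original gradient, and the new `get_param_gradient_attribute` / `clear_grad_attribute` pair then hides which of the two attributes is live. A minimal standalone sketch of that accumulation pattern on bare tensors; the helper name and the ad-hoc `grad_accum` attribute are illustrative, not the optimizer's internals:

```python
import torch

def fill_grad_accum(params, accum_dtype=torch.float32):
    # Accumulate low-precision grads into a separate higher-precision buffer,
    # then release the low-precision .grad, mirroring the pattern above.
    for p in params:
        if p.grad is None:
            continue
        if getattr(p, "grad_accum", None) is None:
            p.grad_accum = p.grad.to(accum_dtype)
        else:
            p.grad_accum.add_(p.grad.to(accum_dtype).view(p.grad_accum.shape))
        p.grad = None  # release the low-precision gradient

# Two micro-steps of accumulation on a bf16 parameter.
p = torch.nn.Parameter(torch.zeros(4, dtype=torch.bfloat16))
for _ in range(2):
    p.grad = torch.ones(4, dtype=torch.bfloat16)
    fill_grad_accum([p])
print(p.grad_accum)  # tensor([2., 2., 2., 2.]) accumulated in float32
```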
@@ -1086,7 +1084,8 @@ class DeepSpeedZeroOptimizer(ZeROOptimizer):
             current_offset += num_elements

     def update_overflow_tracker_for_param_grad(self, param):
-        if param.grad_accum is not None and self._has_inf_or_nan(param.grad_accum.data):
+        grad_accum = self.get_param_gradient_attribute(param)
+        if grad_accum is not None and self._has_inf_or_nan(grad_accum.data):
             self.local_overflow = True

     def _get_offload_gradient_dict(self):
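With the accessor in place, the overflow tracker checks whichever tensor actually holds the gradient. A hedged standalone equivalent of such an inf/NaN scan (not DeepSpeed's `_has_inf_or_nan`):

```python
import torch

def has_inf_or_nan(t: torch.Tensor) -> bool:
    # Overflow check equivalent in spirit to what the tracker needs:
    # flag any non-finite value in the gradient tensor.
    return not torch.isfinite(t).all().item()

assert has_inf_or_nan(torch.tensor([1.0, float("inf"), 3.0]))
assert not has_inf_or_nan(torch.tensor([1.0, 2.0, 3.0]))
```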
@@ -1117,22 +1116,24 @@ class DeepSpeedZeroOptimizer(ZeROOptimizer):
         #accumulate gradients into param.grad_accum or parts of it that belongs to this partition
         def accumulate_gradients():
+            grad_accum = self.get_param_gradient_attribute(param)
             if not self.fp16_master_weights_and_gradients:
                 dest_buffer.copy_(self.accumulated_grads_in_cpu[param_id].view(-1), non_blocking=True)
-                param.grad_accum.data.view(-1).add_(dest_buffer)
+                grad_accum.data.view(-1).add_(dest_buffer)
             else:
                 dest_buffer.narrow(0, source_offset,
                                    num_elements).copy_(self.accumulated_grads_in_cpu[param_id].view(-1),
                                                        non_blocking=True)
-                param.grad_accum.data.view(-1).narrow(0, source_offset, num_elements).add_(
-                    dest_buffer.narrow(0, source_offset, num_elements))
+                grad_accum.data.view(-1).narrow(0, source_offset,
+                                                num_elements).add_(dest_buffer.narrow(0, source_offset, num_elements))

         #move accumulated gradients back to CPU
         def copy_gradients_to_cpu():
+            grad_accum = self.get_param_gradient_attribute(param)
             if not self.fp16_master_weights_and_gradients:
-                self.accumulated_grads_in_cpu[param_id].data.copy_(param.grad_accum.data.view(-1), non_blocking=True)
+                self.accumulated_grads_in_cpu[param_id].data.copy_(grad_accum.data.view(-1), non_blocking=True)
             else:
-                self.accumulated_grads_in_cpu[param_id].data.copy_(param.grad_accum.data.view(-1).narrow(
+                self.accumulated_grads_in_cpu[param_id].data.copy_(grad_accum.data.view(-1).narrow(
                     0, source_offset, num_elements),
                                                                    non_blocking=True)
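In the CPU-offload path the running accumulation for each parameter lives in a pinned host buffer (`accumulated_grads_in_cpu`); the change only swaps direct `param.grad_accum` reads for the accessor before the non-blocking copies in both directions. A simplified sketch of that round trip, assuming a single whole-parameter partition and hypothetical buffer names:

```python
import torch

# Hypothetical pinned host buffer holding the running accumulation for one
# parameter (pinning only makes sense when CUDA is present).
accumulated_in_cpu = torch.zeros(1024, pin_memory=torch.cuda.is_available())

def accumulate_and_offload(grad_accum: torch.Tensor) -> None:
    # Add the previously offloaded accumulation into the live gradient ...
    grad_accum.view(-1).add_(accumulated_in_cpu.to(grad_accum.device, non_blocking=True))
    # ... then copy the updated running total back to the pinned CPU buffer.
    accumulated_in_cpu.copy_(grad_accum.view(-1), non_blocking=True)

grad = torch.ones(1024)
accumulate_and_offload(grad)
print(accumulated_in_cpu.sum())  # tensor(1024.) after the first micro-step
```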
@@ -1148,8 +1149,9 @@ class DeepSpeedZeroOptimizer(ZeROOptimizer):
     def set_norm_for_param_grad(self, param):
         param_id = self.get_param_id(param)
+        grad_accum = self.get_param_gradient_attribute(param)
         accumulated_grad = self.accumulated_grads_in_cpu[
-            param_id] if self.gradient_accumulation_steps > 1 else param.grad_accum
+            param_id] if self.gradient_accumulation_steps > 1 else grad_accum

         [i, source_offset, dest_offset, num_elements] = self.grad_position[param_id]
@@ -1160,10 +1162,11 @@ class DeepSpeedZeroOptimizer(ZeROOptimizer):
     def set_norm_for_param_grad_in_gpu(self, param):
         param_id = self.get_param_id(param)

-        if param.grad_accum is None:
+        grad_accum = self.get_param_gradient_attribute(param)
+        if grad_accum is None:
             accumulated_grad = param.grad
         else:
-            accumulated_grad = param.grad_accum
+            accumulated_grad = grad_accum

         [i, source_offset, dest_offset, num_elements] = self.grad_position[param_id]
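Both `set_norm_for_param_grad` variants reduce to the same step: resolve the gradient tensor once, then take the norm of only the slice owned by this partition, using the offsets stored in `grad_position`. A small illustrative sketch (free function, made-up offsets):

```python
import torch

def partition_grad_norm(grad: torch.Tensor, source_offset: int, num_elements: int) -> float:
    # Norm of just the slice of the flattened gradient that this partition owns;
    # the offsets would come from a grad_position-style bookkeeping record.
    return float(grad.view(-1).narrow(0, source_offset, num_elements).float().norm(2))

g = torch.arange(8, dtype=torch.float32)
print(partition_grad_norm(g, source_offset=4, num_elements=4))  # norm of elements 4..7 only
```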
@@ -1179,10 +1182,11 @@ class DeepSpeedZeroOptimizer(ZeROOptimizer):
         dest_tensor = self.single_partition_of_fp32_groups[i].grad.view(-1).narrow(0, dest_offset, num_elements)

-        if param.grad_accum is None:
-            src_tensor = param.grad.view(-1).narrow(0, source_offset, num_elements)
+        grad_accum = self.get_param_gradient_attribute(param)
+        if grad_accum is None:
+            src_tensor = grad_accum.view(-1).narrow(0, source_offset, num_elements)
         else:
-            src_tensor = param.grad_accum.view(-1).narrow(0, source_offset, num_elements)
+            src_tensor = grad_accum.view(-1).narrow(0, source_offset, num_elements)
         if not self.fp16_master_weights_and_gradients:
             src_tensor = src_tensor.float()
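This hunk copies the partition's slice of the source gradient into the fp32 partition gradient buffer, upcasting unless `fp16_master_weights_and_gradients` is set. A compact sketch of that narrow-and-copy with illustrative arguments:

```python
import torch

def copy_grad_slice_to_fp32_partition(grad_accum: torch.Tensor, fp32_partition_grad: torch.Tensor,
                                       source_offset: int, dest_offset: int, num_elements: int) -> None:
    # Upcast this rank's slice of the (possibly lower-precision) gradient and
    # write it into the fp32 partition's gradient buffer in place.
    src = grad_accum.view(-1).narrow(0, source_offset, num_elements).float()
    fp32_partition_grad.view(-1).narrow(0, dest_offset, num_elements).copy_(src)

grad = torch.ones(16, dtype=torch.bfloat16)
fp32_buf = torch.zeros(8, dtype=torch.float32)
copy_grad_slice_to_fp32_partition(grad, fp32_buf, source_offset=8, dest_offset=0, num_elements=8)
print(fp32_buf)  # the second half of the gradient, upcast to fp32
```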
@@ -1314,7 +1318,7 @@ class DeepSpeedZeroOptimizer(ZeROOptimizer):
                         self.previous_reduced_grads = []
                     self.previous_reduced_grads.append(param)
                 else:
-                    self.clear_grad_reduc_pointer(param)
+                    self.clear_grad_attribute(param)
             elif self.contiguous_gradients:
                 self.copy_grads_in_partition(param)
             else: # zero stage 1 - partition only optimizer state
@@ -1425,7 +1429,7 @@ class DeepSpeedZeroOptimizer(ZeROOptimizer):
     def _clear_previous_reduced_grads(self):
         if self.previous_reduced_grads is not None:
             for param in self.previous_reduced_grads:
-                self.clear_grad_reduc_pointer(param)
+                self.clear_grad_attribute(param)
             self.previous_reduced_grads = None

     # if rank is specified do a reduction instead of an allreduce
@@ -1605,10 +1609,11 @@ class DeepSpeedZeroOptimizer(ZeROOptimizer):
         current_size = 0
         for i, tensor in enumerate(tensor_list):
-            if tensor.grad_accum is None:
-                tensor.grad_accum = torch.zeros_like(tensor, dtype=dtype)
+            grad_accum = self.get_param_gradient_attribute(tensor)
+            if grad_accum is None:
+                grad_accum = torch.zeros_like(tensor, dtype=dtype)

-            tensor = tensor.grad_accum
+            tensor = grad_accum
             num_elements = tensor.numel()
             tensor_offset = 0
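This hunk comes from the code that flattens per-parameter gradients into one contiguous partition; a parameter that has no gradient yet contributes zeros in the accumulation dtype so offsets stay stable. A hedged standalone sketch of that fallback (helper and attribute names are illustrative):

```python
import torch

def flatten_grads(params, dtype=torch.float32):
    # Missing gradients contribute zeros so the flat buffer keeps a stable
    # layout, mirroring the zeros_like fallback above. `grad_accum` is the
    # ad-hoc attribute used for illustration in the earlier sketches.
    chunks = []
    for p in params:
        g = getattr(p, "grad_accum", None)
        if g is None:
            g = p.grad
        if g is None:
            g = torch.zeros_like(p, dtype=dtype)
        chunks.append(g.to(dtype).view(-1))
    return torch.cat(chunks)

a = torch.nn.Parameter(torch.ones(3))
b = torch.nn.Parameter(torch.ones(2))  # never received a gradient
a.grad = torch.full((3,), 2.0)
print(flatten_grads([a, b]))  # tensor([2., 2., 2., 0., 0.])
```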
@@ -1953,10 +1958,8 @@ class DeepSpeedZeroOptimizer(ZeROOptimizer):
         self.loss_scaler.backward(loss.float(), retain_graph=retain_graph)

         # Only for Stage 1, Mode 2
-        if self.use_grad_accum_for_reduction:
-            self.update_separate_grad_accum()
-        else:
-            self.set_grad_accum_pointer()
+        if self.use_grad_accum_attribute:
+            self.fill_grad_accum_attribute()

     def check_overflow(self, partition_gradients=True):
         self._check_overflow(partition_gradients)
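Previously, when the separate-accumulation path was off, `set_grad_accum_pointer` re-aliased `param.grad_accum` to `param.grad` after every backward pass; with the unified accessor that per-step fixup is unnecessary, and only the fill call remains, guarded by the flag. A tiny free-function sketch of the dispatch that makes the aliasing redundant (hypothetical helper, not the optimizer method):

```python
import torch

def get_param_gradient_attribute(param, use_grad_accum_attribute: bool):
    # Return whichever attribute currently holds the gradient; no need to keep
    # .grad_accum pointing at .grad when the attribute path is disabled.
    return getattr(param, "grad_accum", None) if use_grad_accum_attribute else param.grad

p = torch.nn.Parameter(torch.ones(3))
p.grad = torch.full((3,), 0.5)
print(get_param_gradient_attribute(p, use_grad_accum_attribute=False))  # the plain .grad tensor
print(get_param_gradient_attribute(p, use_grad_accum_attribute=True))   # None: .grad_accum was never filled
```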