Unverified commit 9f7126fc, authored by Olatunji Ruwase, committed by GitHub

Refactor moe/non-moe gradient reduction (#1811)

Parent 60fc06c6
@@ -2074,8 +2074,8 @@ class DeepSpeedEngine(Module):
         if len(small_bucket) > 0:
             self.allreduce_and_copy(small_bucket, dp_group)
 
-    def buffered_allreduce_fallback(self, grads=None, elements_per_buffer=500000000):
-        grads = []
+    def _get_gradients_for_reduction(self):
+        non_expert_grads = []
         expert_grads = {}
         if self.has_moe_layers:
             for key in self.expert_data_parallel_group.keys():
@@ -2091,23 +2091,19 @@
                 param.grad = torch.zeros(param.size(),
                                          dtype=param.dtype,
                                          device=param.device)
-                if is_moe_param(param):
-                    expert_grads[param.group_name].append(param.grad.data)
-                else:
-                    grads.append(param.grad.data)
-            else:
-                grad_data = param.grad.data
-                if param_name in self.sparse_tensor_module_names or grad_data.is_sparse:
-                    if is_moe_param(param):
-                        expert_grads[param.group_name].append(SparseTensor(grad_data))
-                    else:
-                        grads.append(SparseTensor(grad_data))
-                else:
-                    if is_moe_param(param):
-                        expert_grads[param.group_name].append(grad_data)
-                    else:
-                        grads.append(grad_data)
-
+
+            grad_data = param.grad.data
+            if param_name in self.sparse_tensor_module_names or grad_data.is_sparse:
+                grad_data = SparseTensor(grad_data)
+
+            if is_moe_param(param):
+                expert_grads[param.group_name].append(grad_data)
+            else:
+                non_expert_grads.append(grad_data)
+
+        return non_expert_grads, expert_grads
+
+    def _reduce_non_expert_gradients(self, grads, elements_per_buffer):
         split_buckets = split_half_float_double_sparse(grads)
         for _, bucket_tuple in enumerate(split_buckets):
             bucket_type, bucket = bucket_tuple
@@ -2124,21 +2120,29 @@
                                          dp_group=dp_group,
                                          numel_per_bucket=elements_per_buffer)
 
+    def _reduce_expert_gradients(self, expert_grads, elements_per_buffer):
+        for ep_name, expert_grads_group in expert_grads.items():
+            expert_split_buckets = split_half_float_double_sparse(expert_grads_group)
+            for i, bucket_tuple in enumerate(expert_split_buckets):
+                bucket_type, bucket = bucket_tuple
+                if bucket_type == SparseTensor.type():
+                    self.sparse_allreduce_no_retain(
+                        bucket,
+                        groups._get_expert_data_parallel_group(ep_name))
+                else:
+                    # Separate between diff groups
+                    self.allreduce_no_retain(
+                        bucket,
+                        dp_group=groups._get_expert_data_parallel_group(ep_name),
+                        numel_per_bucket=elements_per_buffer)
+
+    def buffered_allreduce_fallback(self, grads=None, elements_per_buffer=500000000):
+        non_expert_grads, expert_grads = self._get_gradients_for_reduction()
+
+        self._reduce_non_expert_gradients(non_expert_grads, elements_per_buffer)
+
         if self.has_moe_layers:
-            for ep_name, expert_grads_group in expert_grads.items():
-                expert_split_buckets = split_half_float_double_sparse(expert_grads_group)
-                for i, bucket_tuple in enumerate(expert_split_buckets):
-                    bucket_type, bucket = bucket_tuple
-                    if bucket_type == SparseTensor.type():
-                        self.sparse_allreduce_no_retain(
-                            bucket,
-                            groups._get_expert_data_parallel_group(ep_name))
-                    else:
-                        # Separate between diff groups
-                        self.allreduce_no_retain(
-                            bucket,
-                            dp_group=groups._get_expert_data_parallel_group(ep_name),
-                            numel_per_bucket=elements_per_buffer)
+            self._reduce_expert_gradients(expert_grads, elements_per_buffer)
 
     def sparse_allreduce_no_retain(self, bucket, dp_group):
         allreduced_sparses = self.sparse_allreduce_bucket(bucket, dp_group)
......
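Editor's note on the diff above: the key simplification is in the per-parameter loop of `_get_gradients_for_reduction`. Sparse grads are wrapped once, and each grad then takes a single moe/non-moe branch instead of repeating that branch inside every dense/sparse case; `buffered_allreduce_fallback` becomes a thin driver over the three helpers, so the reduction behavior itself is unchanged. Below is a minimal standalone sketch of that partitioning pattern. `FakeParam` and `partition_grads` are hypothetical stand-ins for illustration, not DeepSpeed APIs: a non-empty `group_name` stands in for `is_moe_param(param)`, and a tagged tuple stands in for `SparseTensor(grad_data)`.

```python
# Illustrative sketch only: mimics the flat split performed by
# _get_gradients_for_reduction -- wrap sparse grads once, then route each
# grad to either the per-expert-group dict or the flat non-expert list.
from dataclasses import dataclass
from typing import Any, Dict, List


@dataclass
class FakeParam:
    grad: Any
    group_name: str = ""      # non-empty => treated as a MoE (expert) parameter
    is_sparse: bool = False


def partition_grads(named_params):
    non_expert_grads: List[Any] = []
    expert_grads: Dict[str, List[Any]] = {}
    for _param_name, param in named_params:
        grad_data = param.grad
        if param.is_sparse:
            grad_data = ("sparse", grad_data)   # stand-in for SparseTensor(grad_data)
        if param.group_name:                    # stand-in for is_moe_param(param)
            expert_grads.setdefault(param.group_name, []).append(grad_data)
        else:
            non_expert_grads.append(grad_data)
    return non_expert_grads, expert_grads


if __name__ == "__main__":
    params = [
        ("dense.weight", FakeParam(grad=[0.1, 0.2])),
        ("moe.expert.weight", FakeParam(grad=[0.3], group_name="ep_group_0")),
        ("embed.weight", FakeParam(grad=[0.4], is_sparse=True)),
    ]
    non_expert, expert = partition_grads(params)
    print(non_expert)   # [[0.1, 0.2], ('sparse', [0.4])]
    print(expert)       # {'ep_group_0': [[0.3]]}
```

In the real engine the two return values are then reduced separately: the non-expert list over the data-parallel group, and each entry of the expert dict over its expert-data-parallel group, as shown in `_reduce_non_expert_gradients` and `_reduce_expert_gradients` above.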