Unverified commit 55ed1057, authored by Jeff Rasley, committed by GitHub

fix bug related to stitching reduced grads across communication partitions (#318)

Parent commit 91b4a93d
@@ -249,8 +249,7 @@ class FP16_DeepSpeedZeroOptimizer_Stage1(object):
             # RS: divide up the sub-partitions and keep track of offsets for each param
             # partition_size = len(self.fp16_groups_flat[i]) / dist.get_world_size(group=self.dp_process_group)
-            params_in_rank_sub_partition, params_in_rank_sub_partitions_offsets, \
-                params_not_local = self.get_all_sub_partition_info(
+            params_in_rank_sub_partition, params_in_rank_sub_partitions_offsets, params_not_local = self.get_all_sub_partition_info(
                 tensor_list=self.fp16_groups[i],
                 all_element_intervals=element_intervals,
                 local_rank=local_rank,
@@ -591,28 +590,20 @@ class FP16_DeepSpeedZeroOptimizer_Stage1(object):
                 all_comm_partitions.append(single_comm_all_partitions)

-            for p in my_params:
-                partitions = param_partition_map[p]
-                parts = []
-                for part in partitions:
-                    params, offsets = partition_param_map[part]
-                    found = False
-                    for p_idx, _p in enumerate(params):
-                        if p.__hash__() == _p.__hash__():
-                            found = True
-                            if offsets[p_idx][0] is not None:
-                                my_part = part.narrow(0,
-                                                      offsets[p_idx][0],
-                                                      offsets[p_idx][1])
-                                parts.append(my_part)
-                    assert found
-                if p is not None:
-                    updated_grad = _unflatten_dense_tensors(torch.cat(parts), [p])
-                    p.grad.copy_(updated_grad[0])
+            # stitch together all rank sub partitions for each comm idx
+            flat_comm_grads = []
+            for comm_idx, rank_partitions in enumerate(all_comm_partitions):
+                flat_comm_grads.append(torch.cat(rank_partitions))
+
+            flat_all_grads = torch.cat(flat_comm_grads)
+
+            # copy back reduced gradients but only those needed for this local rank
+            for param, updated_grad in zip(
+                    self.fp16_groups[i],
+                    _unflatten_dense_tensors(flat_all_grads, self.fp16_groups[i])):
+                if param in my_params:
+                    param.grad.copy_(updated_grad)

     def step(self, closure=None):
         # First compute norm for all group so we know if there is overflow
         self.overflow = self.overflow_checker.check()
         prev_scale = self.loss_scale
@@ -649,6 +640,7 @@ class FP16_DeepSpeedZeroOptimizer_Stage1(object):
             #)

             #TODO RS: can we safely use dtype of the first sub-partition? i think so
+            # create flat gradient partitions for parameters updated by this process
             local_grad_sub_partitions = self.get_flat_sub_partitions(
                 comm_tensor_list=self.params_in_rank_sub_partitions[i][partition_id],
                 comm_param_offsets=self.params_in_rank_sub_partitions_offsets[i]
......
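For readers unfamiliar with the helpers used in the new code, below is a minimal, self-contained sketch (not the DeepSpeed implementation) of the stitching pattern the fix adopts: concatenate the reduced sub-partitions from every rank for each communication interval into one flat buffer, unflatten that buffer against the original parameter list, and copy gradients back only for parameters owned by the local rank. The tensor sizes and the names params, all_comm_partitions, and my_param_ids are illustrative assumptions that only mirror the diff; membership is checked by object identity to keep the toy example robust.

import torch
from torch._utils import _unflatten_dense_tensors

# Toy parameters standing in for self.fp16_groups[i] (4 + 6 = 10 elements total).
params = [torch.nn.Parameter(torch.zeros(4)), torch.nn.Parameter(torch.zeros(6))]
for p in params:
    p.grad = torch.zeros_like(p)

# Pretend these are the already-reduced gradient sub-partitions, grouped by
# communication interval, one entry per data-parallel rank (2 ranks x 5 elements).
all_comm_partitions = [[torch.ones(5), torch.full((5,), 2.0)]]

# Parameters this (hypothetical) rank is responsible for updating.
my_param_ids = {id(params[0])}

# Stitch together all rank sub-partitions for each comm interval ...
flat_comm_grads = [torch.cat(rank_partitions) for rank_partitions in all_comm_partitions]
flat_all_grads = torch.cat(flat_comm_grads)

# ... then unflatten against the full parameter list and copy back only local grads.
for param, updated_grad in zip(params, _unflatten_dense_tensors(flat_all_grads, params)):
    if id(param) in my_param_ids:
        param.grad.copy_(updated_grad)

print(params[0].grad)  # tensor([1., 1., 1., 1.])
print(params[1].grad)  # still zeros: this parameter is not owned by the toy rank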