Unverified commit 4a3234e0, authored by Olatunji Ruwase, committed by GitHub

ZeRO-2: Handle gradients of empty partitions (#275)

* Load non-DeepSpeed checkpoints into ZeRO optimizer

* Handle parameters smaller than DP

* Formatting fixes

* Handle empty partitions

* Fix perf bug
Co-authored-by: Jeff Rasley <jerasley@microsoft.com>
Parent 97787881
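For context: ZeRO-2 splits each flattened parameter group evenly across the data-parallel (DP) ranks, so a group with fewer elements than the DP world size leaves some ranks owning an empty partition. Below is a minimal sketch of that effect; partition_sizes is a hypothetical helper for illustration, not DeepSpeed's actual partitioning code.

def partition_sizes(total_numel, dp_world_size):
    """Split total_numel elements as evenly as possible across DP ranks."""
    base, remainder = divmod(total_numel, dp_world_size)
    return [base + (1 if rank < remainder else 0) for rank in range(dp_world_size)]

print(partition_sizes(3, 8))
# [1, 1, 1, 0, 0, 0, 0, 0]  -> ranks 3..7 own empty partitions

Such empty partitions previously broke the gradient bookkeeping and flat-buffer construction that the hunks below touch.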
@@ -403,6 +403,7 @@ class FP16_DeepSpeedZeroOptimizer(object):
                 self.is_grad_computed[i][partition_id] = {}
                 self.grad_partition_insertion_offset[i][partition_id] = {}
                 self.grad_start_offset[i][partition_id] = {}
+                self.total_grads_in_partition[i][partition_id] = 0
                 self.initialize_gradient_partition(i, param_group, partition_id)
                 self.is_partition_reduced[i][partition_id] = False
                 self.first_param_index_in_partition[i][
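The hunk above seeds total_grads_in_partition with 0 before initialize_gradient_partition runs. A plausible reading, sketched below with plain dicts rather than the optimizer's real state: an empty partition has no parameters to iterate, so the counter would otherwise never be created for that partition, and later lookups would fail instead of reporting zero expected gradients.

# Illustrative only: mimics the bookkeeping with plain dicts.
total_grads_in_partition = {0: {}}            # group index -> partition_id -> count
empty_partition_id = 3                        # a rank that owns no parameters
total_grads_in_partition[0][empty_partition_id] = 0   # the new default
# Without the default, this lookup would raise KeyError instead of returning 0:
assert total_grads_in_partition[0][empty_partition_id] == 0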
@@ -429,6 +430,8 @@ class FP16_DeepSpeedZeroOptimizer(object):
                     self.params_in_partition[i],
                     self.first_offset[i],
                     self.partition_size[i],
+                    dtype=torch.half,
+                    device=torch.cuda.current_device(),
                     return_tensor_list=True)
 
         self._release_ipg_buffers()
@@ -1014,6 +1017,8 @@ class FP16_DeepSpeedZeroOptimizer(object):
                           tensor_list,
                           first_offset,
                           partition_size,
+                          dtype,
+                          device,
                           return_tensor_list=False):
         flat_tensor_list = []
         current_size = 0
@@ -1050,8 +1055,8 @@ class FP16_DeepSpeedZeroOptimizer(object):
         if current_size < partition_size:
             flat_tensor_list.append(
                 torch.zeros(int(partition_size - current_size),
-                            dtype=tensor_list[0].dtype,
-                            device=tensor_list[0].device))
+                            dtype=dtype,
+                            device=device))
 
         if return_tensor_list:
             return flat_tensor_list
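This hunk is the heart of the empty-partition fix: the padding tensor's dtype and device were previously inferred from tensor_list[0], which raises IndexError when the partition owns no tensors at all. Passing them in explicitly (see the signature and call-site hunks above) removes that dependency. A standalone repro sketch follows, with hypothetical helper names rather than DeepSpeed's API:

import torch

def pad_partition_old(tensor_list, partition_size, current_size):
    # Old behavior: infer dtype/device from the first tensor; fails on an
    # empty partition because tensor_list is empty.
    return torch.zeros(int(partition_size - current_size),
                       dtype=tensor_list[0].dtype,
                       device=tensor_list[0].device)

def pad_partition_new(partition_size, current_size, dtype, device):
    # New behavior: dtype/device come from the caller, so an empty
    # tensor_list no longer matters.
    return torch.zeros(int(partition_size - current_size),
                       dtype=dtype,
                       device=device)

pad_partition_new(4, 0, torch.half, 'cpu')   # fine even for an empty partition
# pad_partition_old([], 4, 0)                # IndexError: list index out of range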
@@ -1135,9 +1140,8 @@ class FP16_DeepSpeedZeroOptimizer(object):
         for group in self.single_partition_of_fp32_groups:
             group.grad = None
 
-        for i in range(len(norm_groups)):
-            for fp16_partitions, fp32_partition in zip(self.parallel_partitioned_fp16_groups, self.single_partition_of_fp32_groups):
-                fp16_partitions[partition_id].data.copy_(fp32_partition.data)
+        for fp16_partitions, fp32_partition in zip(self.parallel_partitioned_fp16_groups, self.single_partition_of_fp32_groups):
+            fp16_partitions[partition_id].data.copy_(fp32_partition.data)
 
         timers('optimizer_step').stop()
         timers('optimizer_allgather').start()
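This is the "Fix perf bug" part of the commit: the fp16-from-fp32 partition copies do not depend on the loop index i, so running them inside for i in range(len(norm_groups)) repeated identical device copies once per norm group. Hoisting them out performs each copy exactly once. A back-of-the-envelope sketch, using a hypothetical helper simplified from the diff:

def num_partition_copies(num_param_groups, num_norm_groups, hoisted):
    # Each zip() pass copies one fp32 partition into its fp16 twin per group.
    return num_param_groups if hoisted else num_param_groups * num_norm_groups

print(num_partition_copies(4, 4, hoisted=False))  # 16 copies before the fix
print(num_partition_copies(4, 4, hoisted=True))   # 4 copies after the fix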
......
@@ -353,7 +353,7 @@ def test_zero_allow_untested_optimizer(tmpdir, zero_stage):
     _test_zero_allow_untested_optimizer(args)
 
 
-@pytest.mark.parametrize("zero_stage", [1])
+@pytest.mark.parametrize("zero_stage", [1, 2])
 def test_zero_empty_partition(tmpdir, zero_stage):
     config_dict = {
         "train_micro_batch_size_per_gpu": 1,
......
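The test hunk extends test_zero_empty_partition from ZeRO stage 1 to stage 2; the rest of the config dict is truncated above. As an illustration of the scenario such a test must construct (hypothetical, since the real test body is not shown here): any model whose total parameter count is below the DP world size forces at least one empty partition.

import torch

# Hypothetical illustration, not the actual test: a one-element weight with
# two DP ranks guarantees rank 1 owns an empty ZeRO partition.
model = torch.nn.Linear(1, 1, bias=False)
total_numel = sum(p.numel() for p in model.parameters())
dp_world_size = 2
assert total_numel < dp_world_size  # at least one rank gets nothing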