Unverified commit d229ff17, authored by hablb, committed by GitHub

Zero3: fix allreduce optimization for extra-large tensors (#3832)

Grad tensors that don't fit in the bucket flat buffer are not copied into it, but they are still added to params_in_ipg_bucket.
If such tensors exist, reduce-scatter params_in_ipg_bucket instead of allreducing, since the allreduce path assumes all grads live in the ipg_bucket_flat_buffer (a standalone sketch of the failure mode follows the first hunk below).

Also:
- Add a test for reduce_scatter=False.
- Pad partitions with zeros instead of undefined values when a gradient partition is shorter than param.partition_numel().
Co-authored-by: Olatunji Ruwase <olruwase@microsoft.com>
Parent 807d1b5d
@@ -1118,7 +1118,7 @@ class DeepSpeedZeroOptimizer_Stage3(ZeROOptimizer):
         if safe_mode:
             assert_ints_same_as_other_ranks([p.ds_id for p in self.params_in_ipg_bucket])
-        if self.contiguous_gradients and not self.reduce_scatter:
+        if self.contiguous_gradients and self.elements_in_ipg_bucket <= self.reduce_bucket_size and not self.reduce_scatter:
             grad_bucket = self.__ipg_bucket_flat_buffer.narrow(0, 0, self.elements_in_ipg_bucket)
             grad_partitions = self.__avg_scatter_contiguous_grads(grad_bucket)
         else:
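A self-contained sketch of the failure mode this guard prevents. Only `elements_in_ipg_bucket`, `reduce_bucket_size`, and the flat-buffer layout mirror the code above; everything else is hypothetical scaffolding for illustration:

```python
# Hypothetical illustration: grads are counted as part of the bucket either
# way, but only copied into the flat buffer if they fit. The old condition
# then narrowed the flat buffer to elements_in_ipg_bucket, which can exceed
# the buffer's actual size.
import torch

reduce_bucket_size = 8
flat_buffer = torch.zeros(reduce_bucket_size)

grads = [torch.ones(3), torch.full((12,), 2.0)]  # second grad is too large
elements_in_ipg_bucket = 0
for g in grads:
    if elements_in_ipg_bucket + g.numel() <= reduce_bucket_size:
        flat_buffer[elements_in_ipg_bucket:elements_in_ipg_bucket + g.numel()].copy_(g)
    # The grad is counted even when it was not copied into the buffer.
    elements_in_ipg_bucket += g.numel()

# Old fused path: narrow(0, 0, 15) on an 8-element buffer raises an error,
# and even if it fit, the oversized grad's data was never copied in.
# The added size guard routes this case to per-tensor reduce-scatter.
if elements_in_ipg_bucket <= reduce_bucket_size:
    grad_bucket = flat_buffer.narrow(0, 0, elements_in_ipg_bucket)  # safe
else:
    print("falling back to reduce_scatter over params_in_ipg_bucket")
```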
@@ -1164,7 +1164,7 @@ class DeepSpeedZeroOptimizer_Stage3(ZeROOptimizer):
             partition = buffer_to_reduce[start_offset:end_offset]
             if param.partition_numel() != partition.numel():
-                padded_partition = torch.empty(param.partition_numel(), device=grad.device, dtype=grad.dtype)
+                padded_partition = torch.zeros(param.partition_numel(), device=grad.device, dtype=grad.dtype)
                 if partition.numel() > 0:
                     padded_partition[:partition.numel()] = partition
                 grad_partitions.append(padded_partition)
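Why the switch from torch.empty to torch.zeros matters, as a minimal standalone demo (the tensor values and sizes are illustrative, not taken from the test):

```python
import torch

partition = torch.tensor([1.0, 2.0, 3.0])
partition_numel = 5  # stands in for param.partition_numel() > partition.numel()

# Before the fix: torch.empty returns uninitialized memory, so the padded
# tail carries arbitrary leftover values into the subsequent reduction.
bad = torch.empty(partition_numel)
bad[:partition.numel()] = partition   # bad[3:] is undefined garbage

# After the fix: the tail is deterministically zero, which is the identity
# element for the gradient sum/average.
good = torch.zeros(partition_numel)
good[:partition.numel()] = partition
print(good)  # tensor([1., 2., 3., 0., 0.])
```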
@@ -588,6 +588,7 @@ class EltwiseMultiplicationTestNetwork_List(EltwiseMultiplicationTestNetwork_Dict):
 @pytest.mark.parametrize("offload_optimizer", [True, False])
 @pytest.mark.parametrize("zero_grad", [True, False])
 @pytest.mark.parametrize("prefetching", [True, False])
+@pytest.mark.parametrize("reduce_scatter", [True, False])
 @pytest.mark.parametrize("model_class", [
     EltwiseMultiplicationTestNetwork_Dict, EltwiseMultiplicationTestNetwork_NamedTuple,
     EltwiseMultiplicationTestNetwork_namedtuple, EltwiseMultiplicationTestNetwork_Tuple,
@@ -604,6 +605,7 @@ class TestZero3ParamPartitioningBase(DistributedTest):
         offload_optimizer: bool,
         zero_grad: bool,
         prefetching: bool,
+        reduce_scatter: bool,
         model_class: EltwiseMultiplicationTestNetwork_Dict,
     ) -> None:
         if offload_optimizer and not contiguous_gradients:
@@ -621,7 +623,8 @@ class TestZero3ParamPartitioningBase(DistributedTest):
                 "stage3_max_reuse_distance": 0,
                 "stage3_param_persistence_threshold": param_persistence_threshold,
                 "contiguous_gradients": contiguous_gradients,
-                "stage3_prefetch_bucket_size": prefetch_bucket_size if prefetching else 0
+                "stage3_prefetch_bucket_size": prefetch_bucket_size if prefetching else 0,
+                "reduce_scatter": reduce_scatter
             },
             "optimizer": {
                 "type": "Adam",
@@ -942,6 +945,7 @@ class TestZero3InitForParentWeightInitialization(DistributedTest):
 @pytest.mark.parametrize("offload_optimizer", [True, False])
 @pytest.mark.parametrize("zero_grad", [True, False])
 @pytest.mark.parametrize("prefetching", [True, False])
+@pytest.mark.parametrize("reduce_scatter", [True, False])
 @pytest.mark.parametrize("model_class", [
     EltwiseMultiplicationTestNetwork_Dict, EltwiseMultiplicationTestNetwork_NamedTuple,
     EltwiseMultiplicationTestNetwork_namedtuple, EltwiseMultiplicationTestNetwork_Tuple,
@@ -951,7 +955,8 @@ class TestZero3ParamPartitioningBaseBF16(DistributedTest):
     world_size = 2
     def test(self, param_persistence_threshold: int, contiguous_gradients: bool, offload_optimizer: bool,
-             zero_grad: bool, prefetching: bool, model_class: EltwiseMultiplicationTestNetwork_Dict) -> None:
+             zero_grad: bool, prefetching: bool, reduce_scatter: bool,
+             model_class: EltwiseMultiplicationTestNetwork_Dict) -> None:
         if offload_optimizer and not contiguous_gradients:
             return
@@ -967,7 +972,8 @@ class TestZero3ParamPartitioningBaseBF16(DistributedTest):
                 "stage3_max_reuse_distance": 0,
                 "stage3_param_persistence_threshold": param_persistence_threshold,
                 "contiguous_gradients": contiguous_gradients,
-                "stage3_prefetch_bucket_size": prefetch_bucket_size if prefetching else 0
+                "stage3_prefetch_bucket_size": prefetch_bucket_size if prefetching else 0,
+                "reduce_scatter": reduce_scatter
             },
             "optimizer": {
                 "type": "Adam",