diff --git a/deepspeed/runtime/zero/stage3.py b/deepspeed/runtime/zero/stage3.py index 20d6168b6b797e22745f5dffc5a289a320a355d2..32f61770a69554be1db7c56b6f9ea79b27bd5b27 100644 --- a/deepspeed/runtime/zero/stage3.py +++ b/deepspeed/runtime/zero/stage3.py @@ -1118,7 +1118,7 @@ class DeepSpeedZeroOptimizer_Stage3(ZeROOptimizer): if safe_mode: assert_ints_same_as_other_ranks([p.ds_id for p in self.params_in_ipg_bucket]) - if self.contiguous_gradients and not self.reduce_scatter: + if self.contiguous_gradients and self.elements_in_ipg_bucket <= self.reduce_bucket_size and not self.reduce_scatter: grad_bucket = self.__ipg_bucket_flat_buffer.narrow(0, 0, self.elements_in_ipg_bucket) grad_partitions = self.__avg_scatter_contiguous_grads(grad_bucket) else: @@ -1164,7 +1164,7 @@ class DeepSpeedZeroOptimizer_Stage3(ZeROOptimizer): partition = buffer_to_reduce[start_offset:end_offset] if param.partition_numel() != partition.numel(): - padded_partition = torch.empty(param.partition_numel(), device=grad.device, dtype=grad.dtype) + padded_partition = torch.zeros(param.partition_numel(), device=grad.device, dtype=grad.dtype) if partition.numel() > 0: padded_partition[:partition.numel()] = partition grad_partitions.append(padded_partition) diff --git a/tests/unit/runtime/zero/test_zero.py b/tests/unit/runtime/zero/test_zero.py index 28576f1f4b742b409b1e04683551ecca480f3229..85ed0cffa7c27e25062b4bb2a3170deda5b51c24 100644 --- a/tests/unit/runtime/zero/test_zero.py +++ b/tests/unit/runtime/zero/test_zero.py @@ -588,6 +588,7 @@ class EltwiseMultiplicationTestNetwork_List(EltwiseMultiplicationTestNetwork_Dic @pytest.mark.parametrize("offload_optimizer", [True, False]) @pytest.mark.parametrize("zero_grad", [True, False]) @pytest.mark.parametrize("prefetching", [True, False]) +@pytest.mark.parametrize("reduce_scatter", [True, False]) @pytest.mark.parametrize("model_class", [ EltwiseMultiplicationTestNetwork_Dict, EltwiseMultiplicationTestNetwork_NamedTuple, EltwiseMultiplicationTestNetwork_namedtuple, EltwiseMultiplicationTestNetwork_Tuple, @@ -604,6 +605,7 @@ class TestZero3ParamPartitioningBase(DistributedTest): offload_optimizer: bool, zero_grad: bool, prefetching: bool, + reduce_scatter: bool, model_class: EltwiseMultiplicationTestNetwork_Dict, ) -> None: if offload_optimizer and not contiguous_gradients: @@ -621,7 +623,8 @@ class TestZero3ParamPartitioningBase(DistributedTest): "stage3_max_reuse_distance": 0, "stage3_param_persistence_threshold": param_persistence_threshold, "contiguous_gradients": contiguous_gradients, - "stage3_prefetch_bucket_size": prefetch_bucket_size if prefetching else 0 + "stage3_prefetch_bucket_size": prefetch_bucket_size if prefetching else 0, + "reduce_scatter": reduce_scatter }, "optimizer": { "type": "Adam", @@ -942,6 +945,7 @@ class TestZero3InitForParentWeightInitialization(DistributedTest): @pytest.mark.parametrize("offload_optimizer", [True, False]) @pytest.mark.parametrize("zero_grad", [True, False]) @pytest.mark.parametrize("prefetching", [True, False]) +@pytest.mark.parametrize("reduce_scatter", [True, False]) @pytest.mark.parametrize("model_class", [ EltwiseMultiplicationTestNetwork_Dict, EltwiseMultiplicationTestNetwork_NamedTuple, EltwiseMultiplicationTestNetwork_namedtuple, EltwiseMultiplicationTestNetwork_Tuple, @@ -951,7 +955,8 @@ class TestZero3ParamPartitioningBaseBF16(DistributedTest): world_size = 2 def test(self, param_persistence_threshold: int, contiguous_gradients: bool, offload_optimizer: bool, - zero_grad: bool, prefetching: bool, model_class: EltwiseMultiplicationTestNetwork_Dict) -> None: + zero_grad: bool, prefetching: bool, reduce_scatter: bool, + model_class: EltwiseMultiplicationTestNetwork_Dict) -> None: if offload_optimizer and not contiguous_gradients: return @@ -967,7 +972,8 @@ class TestZero3ParamPartitioningBaseBF16(DistributedTest): "stage3_max_reuse_distance": 0, "stage3_param_persistence_threshold": param_persistence_threshold, "contiguous_gradients": contiguous_gradients, - "stage3_prefetch_bucket_size": prefetch_bucket_size if prefetching else 0 + "stage3_prefetch_bucket_size": prefetch_bucket_size if prefetching else 0, + "reduce_scatter": reduce_scatter }, "optimizer": { "type": "Adam",