diff --git a/deepspeed/runtime/engine.py b/deepspeed/runtime/engine.py
index 0681f2ffd74f3fc06c7bcd241dbefdee4c6b0253..e91918e1bc18d9113828978f406357bef53f2690 100644
--- a/deepspeed/runtime/engine.py
+++ b/deepspeed/runtime/engine.py
@@ -888,7 +888,7 @@ class DeepSpeedEngine(Module):
             model_dtype = torch.bfloat16
 
         if self._config.grad_accum_dtype == None:
-            if model_dtype == torch.bfloat16 and not self.zero_optimization():
+            if model_dtype == torch.bfloat16:
                 grad_accum_dtype = torch.float32
             else:
                 grad_accum_dtype = model_dtype
@@ -1223,6 +1223,16 @@ class DeepSpeedEngine(Module):
                     logger.warning(
                         "**** You are using ZeRO with an untested optimizer, proceed with caution *****"
                     )
+            # BF16 optimizer supports stage 1 optimizations
+            if model_dtype == torch.bfloat16:
+                if grad_accum_dtype != torch.float32:
+                    raise NotImplementedError(
+                        "BF16 optimizer for ZeRO requires fp32 gradient accumulation")
+                if self.zero_optimization_stage() == 1:
+                    return BFLOAT16
+                else:
+                    raise NotImplementedError(
+                        "ZeRO stages 2 and 3 are not supported with the BF16 optimizer")
             return ZERO_OPTIMIZATION
         elif amp_enabled:
             if model_dtype != grad_accum_dtype:
diff --git a/tests/unit/runtime/test_ds_initialize.py b/tests/unit/runtime/test_ds_initialize.py
index a2b3886d373a2a1b3dd00261e63942d4a7d79b04..6fece7dd5ebc1cbf1b4f90b0aefa18d6fd547eff 100644
--- a/tests/unit/runtime/test_ds_initialize.py
+++ b/tests/unit/runtime/test_ds_initialize.py
@@ -116,14 +116,19 @@ class TestConfigOptimizer(DistributedTest):
         assert isinstance(ds_optimizer, FusedAdam)
 
 
-@pytest.mark.parametrize('optimizer_extension', ['zero', 'amp', None])
+@pytest.mark.parametrize('optimizer_extension', ['zero1', 'zero2', 'amp', None])
 @pytest.mark.parametrize('model_dtype', ['fp16', 'bf16', 'fp32'])
 @pytest.mark.parametrize('grad_accum_dtype', [None, 'fp16', 'bf16', 'fp32'])
 class TestOptimizerImplementation(DistributedTest):
     world_size = 1
 
     def test(self, optimizer_extension, model_dtype, grad_accum_dtype):
-        zero_stage = 1 if optimizer_extension == 'zero' else 0
+        if optimizer_extension == 'zero1':
+            zero_stage = 1
+        elif optimizer_extension == 'zero2':
+            zero_stage = 2
+        else:
+            zero_stage = 0
         amp = True if optimizer_extension == 'amp' else False
         fp16 = True if model_dtype == 'fp16' else False
         bf16 = True if model_dtype == 'bf16' else False
@@ -164,12 +169,17 @@ class TestOptimizerImplementation(DistributedTest):
 
         # Enumerate supported configurations
         is_supported = {}
-        # Zero Wrapper
-        is_supported[('zero', 'fp16', None)] = True
-        is_supported[('zero', 'fp16', 'fp16')] = True
-        is_supported[('zero', 'bf16', 'bf16')] = True
-        is_supported[('zero', 'fp32', None)] = True
-        is_supported[('zero', 'fp32', 'fp32')] = True
+        # ZeRO 1 Wrapper
+        is_supported[('zero1', 'fp16', None)] = True
+        is_supported[('zero1', 'fp16', 'fp16')] = True
+        is_supported[('zero1', 'bf16', 'fp32')] = True
+        is_supported[('zero1', 'fp32', None)] = True
+        is_supported[('zero1', 'fp32', 'fp32')] = True
+        # ZeRO 2 Wrapper
+        is_supported[('zero2', 'fp16', None)] = True
+        is_supported[('zero2', 'fp16', 'fp16')] = True
+        is_supported[('zero2', 'fp32', None)] = True
+        is_supported[('zero2', 'fp32', 'fp32')] = True
         # Amp Wrapper
         is_supported[('amp', 'fp32', None)] = True
         is_supported[('amp', 'fp32', 'fp32')] = True