From 8d87c89e4246f50a3ed5c8fadf9269bd348fa83a Mon Sep 17 00:00:00 2001
From: Joe Mayer <114769929+jomayeri@users.noreply.github.com>
Date: Wed, 18 Jan 2023 12:00:25 -0800
Subject: [PATCH] BF16 optimizer for BF16+ZeRO Stage 1 (#2706)

* BF16 optimizer only with ZeRO stage 1.

* Updating to grad accum of fp32 for BF16 ZeRO1 case.

Co-authored-by: Olatunji Ruwase
---
 deepspeed/runtime/engine.py              | 12 ++++++++++-
 tests/unit/runtime/test_ds_initialize.py | 26 ++++++++++++++++--------
 2 files changed, 29 insertions(+), 9 deletions(-)

diff --git a/deepspeed/runtime/engine.py b/deepspeed/runtime/engine.py
index 0681f2ff..e91918e1 100644
--- a/deepspeed/runtime/engine.py
+++ b/deepspeed/runtime/engine.py
@@ -888,7 +888,7 @@ class DeepSpeedEngine(Module):
             model_dtype = torch.bfloat16
 
         if self._config.grad_accum_dtype == None:
-            if model_dtype == torch.bfloat16 and not self.zero_optimization():
+            if model_dtype == torch.bfloat16:
                 grad_accum_dtype = torch.float32
             else:
                 grad_accum_dtype = model_dtype
@@ -1223,6 +1223,16 @@ class DeepSpeedEngine(Module):
                     logger.warning(
                         "**** You are using ZeRO with an untested optimizer, proceed with caution *****"
                     )
+            # BF16 optimizer supports stage 1 optimizations
+            if model_dtype == torch.bfloat16:
+                if grad_accum_dtype != torch.float32:
+                    raise NotImplementedError(
+                        "BF16 optimizer for ZeRO requires fp32 gradient accumulation")
+                if self.zero_optimization_stage() == 1:
+                    return BFLOAT16
+                else:
+                    raise NotImplementedError(
+                        "ZeRO stages 2 and 3 are not supported with the BF16 optimizer")
             return ZERO_OPTIMIZATION
         elif amp_enabled:
             if model_dtype != grad_accum_dtype:
diff --git a/tests/unit/runtime/test_ds_initialize.py b/tests/unit/runtime/test_ds_initialize.py
index a2b3886d..6fece7dd 100644
--- a/tests/unit/runtime/test_ds_initialize.py
+++ b/tests/unit/runtime/test_ds_initialize.py
@@ -116,14 +116,19 @@ class TestConfigOptimizer(DistributedTest):
         assert isinstance(ds_optimizer, FusedAdam)
 
 
-@pytest.mark.parametrize('optimizer_extension', ['zero', 'amp', None])
+@pytest.mark.parametrize('optimizer_extension', ['zero1', 'zero2', 'amp', None])
 @pytest.mark.parametrize('model_dtype', ['fp16', 'bf16', 'fp32'])
 @pytest.mark.parametrize('grad_accum_dtype', [None, 'fp16', 'bf16', 'fp32'])
 class TestOptimizerImplementation(DistributedTest):
     world_size = 1
 
     def test(self, optimizer_extension, model_dtype, grad_accum_dtype):
-        zero_stage = 1 if optimizer_extension == 'zero' else 0
+        if optimizer_extension == 'zero1':
+            zero_stage = 1
+        elif optimizer_extension == 'zero2':
+            zero_stage = 2
+        else:
+            zero_stage = 0
         amp = True if optimizer_extension == 'amp' else False
         fp16 = True if model_dtype == 'fp16' else False
         bf16 = True if model_dtype == 'bf16' else False
@@ -164,12 +169,17 @@ class TestOptimizerImplementation(DistributedTest):
 
         # Enumerate supported configurations
         is_supported = {}
-        # Zero Wrapper
-        is_supported[('zero', 'fp16', None)] = True
-        is_supported[('zero', 'fp16', 'fp16')] = True
-        is_supported[('zero', 'bf16', 'bf16')] = True
-        is_supported[('zero', 'fp32', None)] = True
-        is_supported[('zero', 'fp32', 'fp32')] = True
+        # ZeRO 1 Wrapper
+        is_supported[('zero1', 'fp16', None)] = True
+        is_supported[('zero1', 'fp16', 'fp16')] = True
+        is_supported[('zero1', 'bf16', 'fp32')] = True
+        is_supported[('zero1', 'fp32', None)] = True
+        is_supported[('zero1', 'fp32', 'fp32')] = True
+        # ZeRO 2 Wrapper
+        is_supported[('zero2', 'fp16', None)] = True
+        is_supported[('zero2', 'fp16', 'fp16')] = True
+        is_supported[('zero2', 'fp32', None)] = True
+        is_supported[('zero2', 'fp32', 'fp32')] = True
         # Amp Wrapper
         is_supported[('amp', 'fp32', None)] = True
         is_supported[('amp', 'fp32', 'fp32')] = True
--
GitLab
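
For reference, a minimal client-side sketch (not part of the patch) of the configuration this change targets: bf16 training with ZeRO stage 1, leaving grad_accum_dtype unset so the engine defaults it to fp32, which is the combination the patched sanity check accepts before returning the BF16 optimizer wrapper. The model, batch size, and learning rate below are placeholder assumptions.

    # Sketch only, assuming a trivial placeholder model and hyperparameters.
    import torch
    import deepspeed

    ds_config = {
        "train_batch_size": 8,
        "bf16": {"enabled": True},           # model dtype is bfloat16
        "zero_optimization": {"stage": 1},   # stages 2/3 + bf16 now raise NotImplementedError
        "optimizer": {"type": "Adam", "params": {"lr": 1e-4}},
        # grad_accum_dtype left unset: with this patch it defaults to fp32 for bf16
    }

    model = torch.nn.Linear(16, 16)
    engine, optimizer, _, _ = deepspeed.initialize(model=model,
                                                   model_parameters=model.parameters(),
                                                   config=ds_config)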