Unverified commit 8d87c89e, authored by Joe Mayer, committed by GitHub

BF16 optimizer for BF16+ZeRO Stage 1 (#2706)

* BF16 optimizer only with ZeRO stage 1.

* Updating to grad accum of fp32 for BF16 ZeRO1 case.
Co-authored-by: Olatunji Ruwase <olruwase@microsoft.com>
Parent 23e5133c
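For context, a minimal DeepSpeed config sketch that would exercise the new path (the batch size and optimizer settings are illustrative, not taken from this commit): with bf16 weights and ZeRO stage 1, gradient accumulation now defaults to fp32, so no explicit grad_accum_dtype is needed.

# Hypothetical config sketch: bf16 weights + ZeRO stage 1.
# After this commit, gradient accumulation defaults to fp32 for bf16.
ds_config = {
    "train_batch_size": 8,
    "bf16": {"enabled": True},
    "zero_optimization": {"stage": 1},
    "optimizer": {"type": "Adam", "params": {"lr": 1e-4}},
}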
......
@@ -888,7 +888,7 @@ class DeepSpeedEngine(Module):
             model_dtype = torch.bfloat16

         if self._config.grad_accum_dtype == None:
-            if model_dtype == torch.bfloat16 and not self.zero_optimization():
+            if model_dtype == torch.bfloat16:
                 grad_accum_dtype = torch.float32
             else:
                 grad_accum_dtype = model_dtype
......
@@ -1223,6 +1223,16 @@ class DeepSpeedEngine(Module):
                     logger.warning(
                         "**** You are using ZeRO with an untested optimizer, proceed with caution *****"
                     )
+            # BF16 optimizer supports stage 1 optimizations
+            if model_dtype == torch.bfloat16:
+                if grad_accum_dtype != torch.float32:
+                    raise NotImplementedError(
+                        "BF16 optimizer for ZeRO requires fp32 gradient accumulation")
+                if self.zero_optimization_stage() == 1:
+                    return BFLOAT16
+                else:
+                    raise NotImplementedError(
+                        "ZeRO stages 2 and 3 are not supported with the BF16 optimizer")
             return ZERO_OPTIMIZATION
         elif amp_enabled:
             if model_dtype != grad_accum_dtype:
......
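A condensed sketch of the selection logic added above (the function name and the string return values are illustrative; the engine returns its own constants such as BFLOAT16 and ZERO_OPTIMIZATION): bf16 weights require fp32 gradient accumulation, and the dedicated BF16 optimizer is only chosen for ZeRO stage 1.

def _select_wrapper_sketch(model_dtype, grad_accum_dtype, zero_stage):
    # Mirrors the branch added in the diff above; not the engine code itself.
    if model_dtype == "bf16":
        if grad_accum_dtype != "fp32":
            raise NotImplementedError(
                "BF16 optimizer for ZeRO requires fp32 gradient accumulation")
        if zero_stage == 1:
            return "bf16_optimizer"
        raise NotImplementedError(
            "ZeRO stages 2 and 3 are not supported with the BF16 optimizer")
    return "zero_optimizer"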
......
@@ -116,14 +116,19 @@ class TestConfigOptimizer(DistributedTest):
         assert isinstance(ds_optimizer, FusedAdam)


-@pytest.mark.parametrize('optimizer_extension', ['zero', 'amp', None])
+@pytest.mark.parametrize('optimizer_extension', ['zero1', 'zero2', 'amp', None])
 @pytest.mark.parametrize('model_dtype', ['fp16', 'bf16', 'fp32'])
 @pytest.mark.parametrize('grad_accum_dtype', [None, 'fp16', 'bf16', 'fp32'])
 class TestOptimizerImplementation(DistributedTest):
     world_size = 1

     def test(self, optimizer_extension, model_dtype, grad_accum_dtype):
-        zero_stage = 1 if optimizer_extension == 'zero' else 0
+        if optimizer_extension == 'zero1':
+            zero_stage = 1
+        elif optimizer_extension == 'zero2':
+            zero_stage = 2
+        else:
+            zero_stage = 0
         amp = True if optimizer_extension == 'amp' else False
         fp16 = True if model_dtype == 'fp16' else False
         bf16 = True if model_dtype == 'bf16' else False
......
@@ -164,12 +169,17 @@ class TestOptimizerImplementation(DistributedTest):
         # Enumerate supported configurations
         is_supported = {}

-        # Zero Wrapper
-        is_supported[('zero', 'fp16', None)] = True
-        is_supported[('zero', 'fp16', 'fp16')] = True
-        is_supported[('zero', 'bf16', 'bf16')] = True
-        is_supported[('zero', 'fp32', None)] = True
-        is_supported[('zero', 'fp32', 'fp32')] = True
+        # ZeRO 1 Wrapper
+        is_supported[('zero1', 'fp16', None)] = True
+        is_supported[('zero1', 'fp16', 'fp16')] = True
+        is_supported[('zero1', 'bf16', 'fp32')] = True
+        is_supported[('zero1', 'fp32', None)] = True
+        is_supported[('zero1', 'fp32', 'fp32')] = True
+        # ZeRO 2 Wrapper
+        is_supported[('zero2', 'fp16', None)] = True
+        is_supported[('zero2', 'fp16', 'fp16')] = True
+        is_supported[('zero2', 'fp32', None)] = True
+        is_supported[('zero2', 'fp32', 'fp32')] = True
         # Amp Wrapper
         is_supported[('amp', 'fp32', None)] = True
         is_supported[('amp', 'fp32', 'fp32')] = True
......
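A minimal sketch (the helper and the trimmed dict below are illustrative, mirroring the matrix above) of how the support matrix gates the test's expectation: combinations marked supported should initialize a DeepSpeed engine, everything else should raise NotImplementedError.

# Missing keys default to unsupported, so only listed tuples are expected
# to initialize an engine.
is_supported = {
    ('zero1', 'bf16', 'fp32'): True,   # BF16 optimizer path added by this commit
    ('zero2', 'fp16', 'fp16'): True,
}

def expect_supported(optimizer_extension, model_dtype, grad_accum_dtype):
    return is_supported.get((optimizer_extension, model_dtype, grad_accum_dtype), False)

assert expect_supported('zero1', 'bf16', 'fp32')
assert not expect_supported('zero2', 'bf16', 'fp32')  # rejected after this commit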