From d56268f3752433d63b1025435d8c597174bf7033 Mon Sep 17 00:00:00 2001 From: Alexander Jipa Date: Tue, 25 Apr 2023 13:25:07 -0400 Subject: [PATCH] fixing default communication_data_type for bfloat16_enabled and docs (#3370) Co-authored-by: Alexander Jipa Co-authored-by: Olatunji Ruwase Co-authored-by: Logan Adams <114770087+loadams@users.noreply.github.com> --- deepspeed/runtime/engine.py | 3 +++ docs/_pages/config-json.md | 2 +- tests/unit/runtime/half_precision/test_bf16.py | 7 +++++-- 3 files changed, 9 insertions(+), 3 deletions(-) diff --git a/deepspeed/runtime/engine.py b/deepspeed/runtime/engine.py index f073a0e6..a95b5a6c 100644 --- a/deepspeed/runtime/engine.py +++ b/deepspeed/runtime/engine.py @@ -778,6 +778,9 @@ class DeepSpeedEngine(Module): if self.fp16_enabled(): return torch.float16 + if self.bfloat16_enabled(): + return torch.bfloat16 + return torch.float32 def postscale_gradients(self): diff --git a/docs/_pages/config-json.md b/docs/_pages/config-json.md index f146ebee..84f2f833 100755 --- a/docs/_pages/config-json.md +++ b/docs/_pages/config-json.md @@ -181,7 +181,7 @@ Example of **scheduler** ### Communication options -**communication_data_type**: [boolean] +**communication_data_type**: [string] | Description | Default | | ----------------------------------------------------------------------------------------------------------------------------- | ------- | diff --git a/tests/unit/runtime/half_precision/test_bf16.py b/tests/unit/runtime/half_precision/test_bf16.py index 916267a6..740fa306 100644 --- a/tests/unit/runtime/half_precision/test_bf16.py +++ b/tests/unit/runtime/half_precision/test_bf16.py @@ -287,7 +287,7 @@ class TestZeroEmptyGrad(DistributedTest): @pytest.mark.parametrize("comp_type", [torch.float16, torch.bfloat16, torch.float], ids=["fp16", "bfp16", "fp32"]) -@pytest.mark.parametrize("comm_type", [torch.float16, torch.bfloat16], ids=["fp16", "bfp16"]) +@pytest.mark.parametrize("comm_type", [torch.float16, torch.bfloat16, None], ids=["fp16", "bfp16", "default"]) class TestZeroDtypeCocktail(DistributedTest): world_size = 2 @@ -312,8 +312,11 @@ class TestZeroDtypeCocktail(DistributedTest): "zero_optimization": { "stage": 2 }, - "communication_data_type": type_str[comm_type] } + if comm_type is not None: + config_dict["communication_data_type"] = type_str[comm_type] + else: + comm_type = comp_type hidden_dim = 10 model = SimpleModel(hidden_dim) -- GitLab