未验证 提交 d56268f3 编写于 作者: A Alexander Jipa 提交者: GitHub

fixing default communication_data_type for bfloat16_enabled and docs (#3370)

Co-authored-by: NAlexander Jipa <azzhipa@amazon.com>
Co-authored-by: NOlatunji Ruwase <olruwase@microsoft.com>
Co-authored-by: NLogan Adams <114770087+loadams@users.noreply.github.com>
上级 39825a90
......@@ -778,6 +778,9 @@ class DeepSpeedEngine(Module):
if self.fp16_enabled():
return torch.float16
if self.bfloat16_enabled():
return torch.bfloat16
return torch.float32
def postscale_gradients(self):
......
......@@ -181,7 +181,7 @@ Example of <i>**scheduler**</i>
### Communication options
<i>**communication_data_type**</i>: [boolean]
<i>**communication_data_type**</i>: [string]
| Description | Default |
| ----------------------------------------------------------------------------------------------------------------------------- | ------- |
......
......@@ -287,7 +287,7 @@ class TestZeroEmptyGrad(DistributedTest):
@pytest.mark.parametrize("comp_type", [torch.float16, torch.bfloat16, torch.float], ids=["fp16", "bfp16", "fp32"])
@pytest.mark.parametrize("comm_type", [torch.float16, torch.bfloat16], ids=["fp16", "bfp16"])
@pytest.mark.parametrize("comm_type", [torch.float16, torch.bfloat16, None], ids=["fp16", "bfp16", "default"])
class TestZeroDtypeCocktail(DistributedTest):
world_size = 2
......@@ -312,8 +312,11 @@ class TestZeroDtypeCocktail(DistributedTest):
"zero_optimization": {
"stage": 2
},
"communication_data_type": type_str[comm_type]
}
if comm_type is not None:
config_dict["communication_data_type"] = type_str[comm_type]
else:
comm_type = comp_type
hidden_dim = 10
model = SimpleModel(hidden_dim)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册