未验证 提交 c4485e2f 编写于 作者: S Stas Bekman 提交者: GitHub

improve assert message (#1024)

上级 18a26e86
......@@ -588,7 +588,7 @@ class DeepSpeedEngine(Module):
assert all([param.dtype == torch.half for param in self.module.parameters()]), f"Model must initialized in fp16 mode for ZeRO Stage 3."
self.module.half()
else:
assert all([param.dtype == torch.float for param in self.module.parameters()]), f"The fp16 is not enabled but dtype on parameters not fp16"
assert all([param.dtype == torch.float for param in self.module.parameters()]), f"fp16 is not enabled but one or several model parameters have dtype of fp16"
if not self.dont_change_device:
self.module.to(self.device)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册