未验证 提交 adac058a 编写于 作者: S Stas Bekman 提交者: GitHub

fix double linear override; spelling (#954)

上级 e6999ebd
......@@ -191,7 +191,7 @@ class InsertPostInitMethodToModuleSubClasses(object):
if self.mem_efficient_linear:
print_rank_0(
f"Your linear layers are being patched with more memory efficient version. This will persit unless manually reset.",
"nn.functional.linear has been overridden with a more memory efficient version. This will persist unless manually reset.",
force=True)
self.linear_bk = torch.nn.functional.linear
torch.nn.functional.linear = LinearFunctionForZeroStage3.apply
......@@ -361,13 +361,6 @@ class Init(InsertPostInitMethodToModuleSubClasses):
self._convert_to_deepspeed_param(param)
param.partition()
if mem_efficient_linear:
print_rank_0(
f"Your linear layers are being patched with more memory efficient version. This will persit unless manually turned reset.",
force=True)
self.linear_bk = torch.nn.functional.linear
torch.nn.functional.linear = LinearFunctionForZeroStage3.apply
def _post_init_method(self, module):
#see_memory_usage(f"Before converting params in {module.__class__.__name__}", force=False)
print_rank_0(f'Converting Params in {module.__class__.__name__}', force=False)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册