未验证 提交 5a8dd0c5 编写于 作者: K Kohaku-Blueleaf 提交者: GitHub

Fix rescale

上级 90441294
......@@ -40,7 +40,9 @@ class NetworkModuleOFT(network.NetworkModule):
self.is_boft = False
if weights.w["oft_diag"].dim() == 4:
self.is_boft = True
self.rescale = weight.w.get('rescale', None)
self.rescale = weights.w.get('rescale', None)
if self.rescale is not None:
self.rescale = self.rescale.reshape(-1, *[1]*(self.org_module[0].weight.dim() - 1))
is_linear = type(self.sd_module) in [torch.nn.Linear, torch.nn.modules.linear.NonDynamicallyQuantizableLinear]
is_conv = type(self.sd_module) in [torch.nn.Conv2d]
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册