From 851c3d51eda1857a14d21ff3eb2e07c44ba262bc Mon Sep 17 00:00:00 2001
From: Kohaku-Blueleaf <59680068+KohakuBlueleaf@users.noreply.github.com>
Date: Sat, 9 Mar 2024 12:31:32 +0800
Subject: [PATCH] Fix bugs for torch.nn.MultiheadAttention

---
 extensions-builtin/Lora/network.py  | 8 +++++++-
 extensions-builtin/Lora/networks.py | 9 ++++++---
 2 files changed, 13 insertions(+), 4 deletions(-)

diff --git a/extensions-builtin/Lora/network.py b/extensions-builtin/Lora/network.py
index 2268b0f7e..183f8bd7c 100644
--- a/extensions-builtin/Lora/network.py
+++ b/extensions-builtin/Lora/network.py
@@ -117,6 +117,12 @@ class NetworkModule:
 
         if hasattr(self.sd_module, 'weight'):
             self.shape = self.sd_module.weight.shape
+        elif isinstance(self.sd_module, nn.MultiheadAttention):
+            # For now, only self-attention uses PyTorch's MHA,
+            # so assume all q/k/v/out projections have the same shape
+            self.shape = self.sd_module.out_proj.weight.shape
+        else:
+            self.shape = None
 
         self.ops = None
         self.extra_kwargs = {}
@@ -146,7 +152,7 @@ class NetworkModule:
         self.alpha = weights.w["alpha"].item() if "alpha" in weights.w else None
         self.scale = weights.w["scale"].item() if "scale" in weights.w else None
 
-        self.dora_scale = weights.w["dora_scale"] if "dora_scale" in weights.w else None
+        self.dora_scale = weights.w.get("dora_scale", None)
         self.dora_mean_dim = tuple(i for i in range(len(self.shape)) if i != 1)
 
     def multiplier(self):
diff --git a/extensions-builtin/Lora/networks.py b/extensions-builtin/Lora/networks.py
index 04bd19117..42b14dc23 100644
--- a/extensions-builtin/Lora/networks.py
+++ b/extensions-builtin/Lora/networks.py
@@ -429,9 +429,12 @@ def network_apply_weights(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn
     if isinstance(self, torch.nn.MultiheadAttention) and module_q and module_k and module_v and module_out:
         try:
             with torch.no_grad():
-                updown_q, _ = module_q.calc_updown(self.in_proj_weight)
-                updown_k, _ = module_k.calc_updown(self.in_proj_weight)
-                updown_v, _ = module_v.calc_updown(self.in_proj_weight)
+                # Send the "real" per-projection orig_weight into MHA's lora modules
+                qw, kw, vw = self.in_proj_weight.chunk(3, 0)
+                updown_q, _ = module_q.calc_updown(qw)
+                updown_k, _ = module_k.calc_updown(kw)
+                updown_v, _ = module_v.calc_updown(vw)
+                del qw, kw, vw
                 updown_qkv = torch.vstack([updown_q, updown_k, updown_v])
 
                 updown_out, ex_bias = module_out.calc_updown(self.out_proj.weight)
-- 
GitLab
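
Note (not part of the patch): PyTorch's nn.MultiheadAttention stores the query/key/value projections concatenated in a single in_proj_weight tensor of shape (3*embed_dim, embed_dim), while each of module_q/module_k/module_v computes an update for a single projection, which is why handing the whole in_proj_weight to every calc_updown call was wrong and why the patch splits it with chunk(3, 0). A minimal standalone sketch of the shapes involved, using plain PyTorch with toy sizes chosen only for illustration:

import torch
import torch.nn as nn

embed_dim, num_heads = 8, 2
mha = nn.MultiheadAttention(embed_dim, num_heads)

# q/k/v projections are packed into one (3*embed_dim, embed_dim) tensor
print(mha.in_proj_weight.shape)   # torch.Size([24, 8])
print(mha.out_proj.weight.shape)  # torch.Size([8, 8])

# chunk(3, 0) splits it back into the per-projection weights that each
# per-projection LoRA update is shaped to match
qw, kw, vw = mha.in_proj_weight.chunk(3, 0)
assert qw.shape == kw.shape == vw.shape == mha.out_proj.weight.shape

# Stacking the three per-projection updates restores the in_proj_weight layout,
# mirroring torch.vstack([updown_q, updown_k, updown_v]) in the patch
assert torch.vstack([qw, kw, vw]).shape == mha.in_proj_weight.shape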