From 3a608ead87778fb8a0b3520b95b381c49b549de7 Mon Sep 17 00:00:00 2001
From: a2569875
Date: Tue, 25 Jul 2023 23:56:55 +0800
Subject: [PATCH] fix weights

---
 composable_lora.py | 21 ++++++++++++++++++---
 1 file changed, 18 insertions(+), 3 deletions(-)

diff --git a/composable_lora.py b/composable_lora.py
index 1b6678b..02382a6 100644
--- a/composable_lora.py
+++ b/composable_lora.py
@@ -449,7 +449,12 @@ def lora_Linear_forward(self, input):
             torch.nn.Linear_forward_before_lora = backup_Linear_forward
             return result
         return lycoris.lyco_Linear_forward(self, input)
-    clear_cache_lora(self, False)
+    if lora_ext.is_sd_1_5:
+        import networks
+        networks.network_restore_weights_from_backup(self)
+        networks.network_reset_cached_weight(self)
+    else:
+        clear_cache_lora(self, False)
     if (not self.weight.is_cuda) and input.is_cuda: #if variables not on the same device (between cpu and gpu)
         self_weight_cuda = self.weight.to(device=devices.device) #pass to GPU
         to_del = self.weight
@@ -488,7 +493,12 @@ def lora_Conv2d_forward(self, input):
             return result
         return lycoris.lyco_Conv2d_forward(self, input)

-    clear_cache_lora(self, False)
+    if lora_ext.is_sd_1_5:
+        import networks
+        networks.network_restore_weights_from_backup(self)
+        networks.network_reset_cached_weight(self)
+    else:
+        clear_cache_lora(self, False)
     if (not self.weight.is_cuda) and input.is_cuda:
         self_weight_cuda = self.weight.to(device=devices.device)
         to_del = self.weight
@@ -528,7 +538,12 @@ def lora_MultiheadAttention_forward(self, input):
             return result
         return lycoris.lyco_MultiheadAttention_forward(self, input)

-    clear_cache_lora(self, False)
+    if lora_ext.is_sd_1_5:
+        import networks
+        networks.network_restore_weights_from_backup(self)
+        networks.network_reset_cached_weight(self)
+    else:
+        clear_cache_lora(self, False)
     if (not self.weight.is_cuda) and input.is_cuda:
         self_weight_cuda = self.weight.to(device=devices.device)
         to_del = self.weight
--
GitLab
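
Note for reviewers: all three hunks add the same branch to the three forward hooks. A minimal sketch of that shared logic, factored into one helper, follows; `restore_base_weights` is a hypothetical name not in the patch, `lora_ext.is_sd_1_5` and `clear_cache_lora` are the extension's own symbols from composable_lora.py, and the two `networks` calls are the built-in Lora module functions of newer AUTOMATIC1111 WebUI builds, as the diff itself uses them.

def restore_base_weights(module):
    # Hypothetical helper (not part of the patch) consolidating the branch
    # duplicated in lora_Linear_forward, lora_Conv2d_forward, and
    # lora_MultiheadAttention_forward above.
    if lora_ext.is_sd_1_5:
        # Newer WebUI: the built-in networks module owns the weight backups,
        # so ask it to restore the original weights and drop its cached
        # merged weights before Composable LoRA applies its own.
        import networks
        networks.network_restore_weights_from_backup(module)
        networks.network_reset_cached_weight(module)
    else:
        # Older builds: fall back to the extension's own cache clearing.
        clear_cache_lora(module, False)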