Commit e407d1af authored by AUTOMATIC

add support for loras trained on kohya's scripts 0.4.0 (alphas)

Parent e8c3d03f
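For context beyond the commit message: starting with version 0.4.0, kohya's sd-scripts store a scalar alpha value for each LoRA module, and the low-rank update is meant to be applied scaled by alpha divided by the rank of the decomposition. The sketch below shows that scaling in isolation, under the same convention the patched lora_forward uses (rank read from up.weight.shape[1]); the variable names are illustrative, not the webui's actual API.

# Minimal sketch of kohya-style alpha scaling (illustrative names, not webui API).
import torch

n_in, n_out, rank = 16, 16, 4
down = torch.nn.Linear(n_in, rank, bias=False)  # lora_down.weight has shape (rank, n_in)
up = torch.nn.Linear(rank, n_out, bias=False)   # lora_up.weight has shape (n_out, rank)
alpha = 2.0        # scalar stored per module in checkpoints from kohya >= 0.4.0
multiplier = 1.0   # user-selected lora strength

x = torch.randn(1, n_in)
res = x            # stand-in for the frozen layer's own output
# rank is recoverable as up.weight.shape[1]; alpha/rank rescales the low-rank delta
res = res + up(down(x)) * multiplier * (alpha / up.weight.shape[1])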
@@ -92,6 +92,15 @@ def load_lora(name, filename):
             keys_failed_to_match.append(key_diffusers)
             continue
 
+        lora_module = lora.modules.get(key, None)
+        if lora_module is None:
+            lora_module = LoraUpDownModule()
+            lora.modules[key] = lora_module
+
+        if lora_key == "alpha":
+            lora_module.alpha = weight.item()
+            continue
+
         if type(sd_module) == torch.nn.Linear:
             module = torch.nn.Linear(weight.shape[1], weight.shape[0], bias=False)
         elif type(sd_module) == torch.nn.Conv2d:
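The hunk above is also why the lora_module lookup moves ahead of the layer construction: an "alpha" entry carries only a scalar, not a weight matrix, so it has to be recorded on the module and skipped before any torch.nn.Linear or Conv2d is built. A hypothetical key layout, for illustration only (real prefixes depend on the network):

# Hypothetical kohya-style state dict keys (illustrative; actual prefixes vary).
state_dict_keys = [
    "lora_unet_down_blocks_0_attentions_0_proj_in.lora_down.weight",
    "lora_unet_down_blocks_0_attentions_0_proj_in.lora_up.weight",
    "lora_unet_down_blocks_0_attentions_0_proj_in.alpha",  # new in 0.4.0
]
for key_diffusers in state_dict_keys:
    key, lora_key = key_diffusers.split(".", 1)  # module prefix, entry kind
    print(key, "->", lora_key)

All three entries share the same key prefix, so they accumulate onto a single LoraUpDownModule.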
@@ -104,17 +113,12 @@ def load_lora(name, filename):
 
         module.to(device=devices.device, dtype=devices.dtype)
 
-        lora_module = lora.modules.get(key, None)
-        if lora_module is None:
-            lora_module = LoraUpDownModule()
-            lora.modules[key] = lora_module
-
         if lora_key == "lora_up.weight":
             lora_module.up = module
         elif lora_key == "lora_down.weight":
             lora_module.down = module
         else:
-            assert False, f'Bad Lora layer name: {key_diffusers} - must end in lora_up.weight or lora_down.weight'
+            assert False, f'Bad Lora layer name: {key_diffusers} - must end in lora_up.weight, lora_down.weight or alpha'
 
     if len(keys_failed_to_match) > 0:
         print(f"Failed to match keys when loading Lora {filename}: {keys_failed_to_match}")
@@ -161,7 +165,7 @@ def lora_forward(module, input, res):
 
     for lora in loaded_loras:
         module = lora.modules.get(lora_layer_name, None)
        if module is not None:
-            res = res + module.up(module.down(input)) * lora.multiplier
+            res = res + module.up(module.down(input)) * lora.multiplier * (module.alpha / module.up.weight.shape[1] if module.alpha else 1.0)
 
     return res
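The guard in the new expression matters for backward compatibility: checkpoints written before 0.4.0 contain no alpha entries, so module.alpha stays None and the extra factor must collapse to 1.0. A small sketch of that fallback, with an illustrative helper name:

# Fallback behavior of the alpha factor (helper name is illustrative).
def alpha_scale(alpha, rank):
    return alpha / rank if alpha else 1.0

print(alpha_scale(None, 4))  # 1.0  -> pre-0.4.0 lora, behavior unchanged
print(alpha_scale(4.0, 4))   # 1.0  -> alpha equal to rank is also a no-op
print(alpha_scale(1.0, 4))   # 0.25 -> alpha below rank downscales the delta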