Overbill1683 / Stable Diffusion Webui
Commit f6c8201e
Authored Nov 03, 2023 by v0xie

refactor: move factorization to lyco_helpers, separate calc_updown for kohya and kb

Parent: fe1967a4

Showing 2 changed files with 77 additions and 101 deletions (+77, -101)
extensions-builtin/Lora/lyco_helpers.py (+47, -0)
extensions-builtin/Lora/network_oft.py (+30, -101)
extensions-builtin/Lora/lyco_helpers.py

@@ -19,3 +19,50 @@ def rebuild_cp_decomposition(up, down, mid):
    up = up.reshape(up.size(0), -1)
    down = down.reshape(down.size(0), -1)
    return torch.einsum('n m k l, i n, m j -> i j k l', mid, up, down)
+
+
+# copied from https://github.com/KohakuBlueleaf/LyCORIS/blob/dev/lycoris/modules/lokr.py
+def factorization(dimension: int, factor: int = -1) -> tuple[int, int]:
+    '''
+    return a tuple of two values of the input dimension decomposed by the number closest to factor;
+    the second value is greater than or equal to the first value.
+
+    In LoRA with Kronecker Product, the first value is the weight scale and
+    the second value is the weight.
+
+    Because of the non-commutative property, A⊗B ≠ B⊗A, the meaning of the two matrices is slightly different.
+
+    examples)
+    factor
+        -1               2               4               8               16
+    127 -> 1, 127   127 -> 1, 127   127 -> 1, 127   127 -> 1, 127   127 -> 1, 127
+    128 -> 8, 16    128 -> 2, 64    128 -> 4, 32    128 -> 8, 16    128 -> 8, 16
+    250 -> 10, 25   250 -> 2, 125   250 -> 2, 125   250 -> 5, 50    250 -> 10, 25
+    360 -> 8, 45    360 -> 2, 180   360 -> 4, 90    360 -> 8, 45    360 -> 12, 30
+    512 -> 16, 32   512 -> 2, 256   512 -> 4, 128   512 -> 8, 64    512 -> 16, 32
+    1024 -> 32, 32  1024 -> 2, 512  1024 -> 4, 256  1024 -> 8, 128  1024 -> 16, 64
+    '''
+    if factor > 0 and (dimension % factor) == 0:
+        m = factor
+        n = dimension // factor
+        if m > n:
+            n, m = m, n
+        return m, n
+    if factor < 0:
+        factor = dimension
+    m, n = 1, dimension
+    length = m + n
+    while m < n:
+        new_m = m + 1
+        while dimension % new_m != 0:
+            new_m += 1
+        new_n = dimension // new_m
+        if new_m + new_n > length or new_m > factor:
+            break
+        else:
+            m, n = new_m, new_n
+    if m > n:
+        n, m = m, n
+    return m, n
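As a quick sanity check on the moved helper, the following is a minimal standalone sketch (it assumes lyco_helpers.py is importable, e.g. when run from inside extensions-builtin/Lora/) exercising a few pairs from the docstring table:

# Sanity-check a few pairs from the factorization docstring above.
from lyco_helpers import factorization

expected = {127: (1, 127), 128: (8, 16), 250: (10, 25), 512: (16, 32), 1024: (32, 32)}
for dimension, pair in expected.items():
    m, n = factorization(dimension)           # default factor=-1
    assert (m, n) == pair, f"{dimension}: got {(m, n)}"
    assert m * n == dimension and m <= n      # invariant: exact factor pair with m <= n

assert factorization(128, 4) == (4, 32)       # explicit-factor column from the table
print("factorization matches the docstring examples")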
extensions-builtin/Lora/network_oft.py

 import torch
 import network
+from lyco_helpers import factorization
 from einops import rearrange
 from modules import devices


 class ModuleTypeOFT(network.ModuleType):
@@ -11,7 +11,8 @@ class ModuleTypeOFT(network.ModuleType):
        return None

-# adapted from kohya's implementation https://github.com/kohya-ss/sd-scripts/blob/main/networks/oft.py
+# adapted from kohya-ss' implementation https://github.com/kohya-ss/sd-scripts/blob/main/networks/oft.py
+# and KohakuBlueleaf's implementation https://github.com/KohakuBlueleaf/LyCORIS/blob/dev/lycoris/modules/diag_oft.py
 class NetworkModuleOFT(network.NetworkModule):
    def __init__(self,  net: network.Network, weights: network.NetworkWeights):
@@ -19,6 +20,7 @@ class NetworkModuleOFT(network.NetworkModule):
        self.lin_module = None
        self.org_module: list[torch.Module] = [self.sd_module]

+        # kohya-ss
        if "oft_blocks" in weights.w.keys():
            self.is_kohya = True
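The kohya-ss / LyCORIS split above turns entirely on which key the network's state dict carries. Below is a minimal sketch of that dispatch; detect_oft_variant is a hypothetical name for illustration only, and the "oft_diag" key is inferred from the diag_oft reference and the commented einsum further down. The real logic lives in __init__ as shown in the hunk above.

# Hypothetical standalone mirror of the key-based dispatch in __init__.
def detect_oft_variant(weight_keys):
    if "oft_blocks" in weight_keys:
        return "kohya"   # kohya-ss OFT checkpoints store oft_blocks (+ alpha)
    if "oft_diag" in weight_keys:
        return "kb"      # KohakuBlueleaf / LyCORIS diag-oft checkpoints store oft_diag
    raise ValueError("not an OFT network")

assert detect_oft_variant({"oft_blocks", "alpha"}) == "kohya"
assert detect_oft_variant({"oft_diag"}) == "kb"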
@@ -37,61 +39,31 @@ class NetworkModuleOFT(network.NetworkModule):
        is_linear = type(self.sd_module) in [torch.nn.Linear, torch.nn.modules.linear.NonDynamicallyQuantizableLinear]
        is_conv = type(self.sd_module) in [torch.nn.Conv2d]
+        is_other_linear = type(self.sd_module) in [torch.nn.MultiheadAttention]
-        #if "Linear" in self.sd_module.__class__.__name__ or is_linear:
-        is_other_linear = type(self.sd_module) in [torch.nn.MultiheadAttention]

        if is_linear:
            self.out_dim = self.sd_module.out_features
-        #elif hasattr(self.sd_module, "embed_dim"):
-        #    self.out_dim = self.sd_module.embed_dim
-        #else:
-        #    raise ValueError("Linear sd_module must have out_features or embed_dim")
+        elif is_other_linear:
+            self.out_dim = self.sd_module.embed_dim
-            #self.org_weight = self.org_module[0].weight
-        # if hasattr(self.sd_module, "in_proj_weight"):
-        #     self.in_proj_dim = self.sd_module.in_proj_weight.shape[1]
-        # if hasattr(self.sd_module, "out_proj_weight"):
-        #     self.out_proj_dim = self.sd_module.out_proj_weight.shape[0]
-        # self.in_proj_dim = self.sd_module.in_proj_weight.shape[1]
        elif is_conv:
            self.out_dim = self.sd_module.out_channels
        else:
            raise ValueError("sd_module must be Linear or Conv")

        if self.is_kohya:
            self.num_blocks = self.dim
            self.block_size = self.out_dim // self.num_blocks
            self.constraint = self.alpha * self.out_dim
-        #elif is_linear or is_conv:
        else:
            self.block_size, self.num_blocks = factorization(self.out_dim, self.dim)
            self.constraint = None

-        # if is_other_linear:
-        #     weight = self.oft_blocks.reshape(self.oft_blocks.shape[0], -1)
-        #     module = torch.nn.Linear(weight.shape[1], weight.shape[0], bias=False)
-        #     with torch.no_grad():
-        #         if weight.shape != module.weight.shape:
-        #             weight = weight.reshape(module.weight.shape)
-        #         module.weight.copy_(weight)
-        #     module.to(device=devices.cpu, dtype=devices.dtype)
-        #     module.weight.requires_grad_(False)
-        #     self.lin_module = module
-        #return module

    def merge_weight(self, R_weight, org_weight):
        R_weight = R_weight.to(org_weight.device, dtype=org_weight.dtype)
        if org_weight.dim() == 4:
            weight = torch.einsum("oihw, op -> pihw", org_weight, R_weight)
        else:
            weight = torch.einsum("oi, op -> pi", org_weight, R_weight)
-        #weight = torch.einsum(
-        #    "k n m, k n ... -> k m ...",
-        #    self.oft_diag * scale + torch.eye(self.block_size, device=device),
-        #    org_weight
-        #)
        return weight

    def get_weight(self, oft_blocks, multiplier=None):
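To make merge_weight's einsum concrete, here is a self-contained toy check (shapes made up, independent of the webui classes). The subscripts contract the output-channel axis o of the original weight against R:

import torch

# merge_weight rotates the output-channel axis: for a conv weight (o, i, h, w),
# "oihw, op -> pihw" contracts o against R, i.e. new[p] = sum_o W[o] * R[o, p].
out_dim = 6
conv_w = torch.randn(out_dim, 3, 3, 3)
lin_w = torch.randn(out_dim, 10)
identity = torch.eye(out_dim)

# With R = I the weight is unchanged, so the resulting updown delta would be zero.
assert torch.allclose(torch.einsum("oihw, op -> pihw", conv_w, identity), conv_w)
assert torch.allclose(torch.einsum("oi, op -> pi", lin_w, identity), lin_w)

# A permutation R merely reorders output channels, confirming which axis is acted on.
perm = torch.eye(out_dim)[torch.randperm(out_dim)]
rotated = torch.einsum("oi, op -> pi", lin_w, perm)
assert torch.allclose(rotated.sum(dim=0), lin_w.sum(dim=0))  # rows permuted, not mixed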
@@ -111,48 +83,51 @@ class NetworkModuleOFT(network.NetworkModule):
            block_R_weighted = multiplier * block_R + (1 - multiplier) * m_I
            R = torch.block_diag(*block_R_weighted)

        return R
-        #return self.oft_blocks

+    def calc_updown_kohya(self, orig_weight, multiplier):
+        R = self.get_weight(self.oft_blocks, multiplier)
+        merged_weight = self.merge_weight(R, orig_weight)
-    def calc_updown(self, orig_weight):
-        multiplier = self.multiplier() * self.calc_scale()
-        is_other_linear = type(self.sd_module) in [torch.nn.MultiheadAttention]
-        if self.is_kohya and not is_other_linear:
-            R = self.get_weight(self.oft_blocks, multiplier)
-            #R = self.oft_blocks.to(orig_weight.device, dtype=orig_weight.dtype)
-            merged_weight = self.merge_weight(R, orig_weight)
-        elif not self.is_kohya and not is_other_linear:

+        updown = merged_weight.to(orig_weight.device, dtype=orig_weight.dtype) - orig_weight
+        output_shape = orig_weight.shape
+        orig_weight = orig_weight
+        return self.finalize_updown(updown, orig_weight, output_shape)

+    def calc_updown_kb(self, orig_weight, multiplier):
+        is_other_linear = type(self.sd_module) in [torch.nn.MultiheadAttention]

+        if not is_other_linear:
            if is_other_linear and orig_weight.shape[0] != orig_weight.shape[1]:
                orig_weight = orig_weight.permute(1, 0)

            R = self.oft_blocks.to(orig_weight.device, dtype=orig_weight.dtype)

            merged_weight = rearrange(orig_weight, '(k n) ... -> k n ...', k=self.num_blocks, n=self.block_size)
-            #orig_weight = rearrange(orig_weight, '(k n) ... -> k n ...', k=self.block_size, n=self.num_blocks)

            merged_weight = torch.einsum(
                'k n m, k n ... -> k m ...',
                R * multiplier + torch.eye(self.block_size, device=orig_weight.device),
                merged_weight
            )
            merged_weight = rearrange(merged_weight, 'k m ... -> (k m) ...')

            if is_other_linear and orig_weight.shape[0] != orig_weight.shape[1]:
                orig_weight = orig_weight.permute(1, 0)
-            #merged_weight=merged_weight.permute(1, 0)

            updown = merged_weight.to(orig_weight.device, dtype=orig_weight.dtype) - orig_weight
-            #updown = weight.to(orig_weight.device, dtype=orig_weight.dtype) - orig_weight
            output_shape = orig_weight.shape
        else:
-            # skip for now
+            # FIXME: skip MultiheadAttention for now
            updown = torch.zeros([orig_weight.shape[1], orig_weight.shape[1]], device=orig_weight.device, dtype=orig_weight.dtype)
            output_shape = (orig_weight.shape[1], orig_weight.shape[1])

-        #if self.lin_module is not None:
-        #    R = self.lin_module.weight.to(orig_weight.device, dtype=orig_weight.dtype)
-        #    weight = torch.mul(torch.mul(R, multiplier), orig_weight)
-        #else:
        orig_weight = orig_weight

        return self.finalize_updown(updown, orig_weight, output_shape)

+    def calc_updown(self, orig_weight):
+        multiplier = self.multiplier() * self.calc_scale()
+        if self.is_kohya:
+            return self.calc_updown_kohya(orig_weight, multiplier)
+        else:
+            return self.calc_updown_kb(orig_weight, multiplier)

    # override to remove the multiplier/scale factor; it's already multiplied in get_weight
    def finalize_updown(self, updown, orig_weight, output_shape, ex_bias=None):
        #return super().finalize_updown(updown, orig_weight, output_shape, ex_bias)
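The LyCORIS-style path never materializes the full block-diagonal rotation: the rearrange / einsum / rearrange chain applies each block separately. Below is a standalone toy check against the dense torch.block_diag form; note that 'k n m, k n ... -> k m ...' sums over n, the first block axis, so each block effectively multiplies as its transpose:

import torch
from einops import rearrange

num_blocks, block_size, in_dim = 4, 8, 16
orig_weight = torch.randn(num_blocks * block_size, in_dim)
R = torch.randn(num_blocks, block_size, block_size) * 0.01
multiplier = 1.0
blocks = R * multiplier + torch.eye(block_size)   # per-block (I + scaled R)

# Blockwise path, as in calc_updown_kb: split rows into k blocks of n,
# rotate each block with a batched einsum, then flatten back.
w = rearrange(orig_weight, '(k n) ... -> k n ...', k=num_blocks, n=block_size)
w = torch.einsum('k n m, k n ... -> k m ...', blocks, w)
merged = rearrange(w, 'k m ... -> (k m) ...')

# Dense reference: block-diagonal of the transposed blocks, since the einsum
# contracts the first block index n.
dense = torch.block_diag(*[b.t() for b in blocks])
assert torch.allclose(merged, dense @ orig_weight, atol=1e-5)
print("blockwise einsum matches the dense block-diagonal product")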
@@ -172,49 +147,3 @@ class NetworkModuleOFT(network.NetworkModule):
            ex_bias = ex_bias * self.multiplier()

        return updown, ex_bias
-
-# copied from https://github.com/KohakuBlueleaf/LyCORIS/blob/dev/lycoris/modules/lokr.py
-def factorization(dimension: int, factor: int = -1) -> tuple[int, int]:

(The remaining removed lines of this hunk are the factorization docstring and body, verbatim identical to the copy added to lyco_helpers.py above.)