Overbill1683 / Stable Diffusion Webui · commit 900419e8
Unverified commit 900419e8, authored Feb 26, 2024 by AUTOMATIC1111; committed by GitHub on Feb 26, 2024.

Merge pull request #14973 from AUTOMATIC1111/Fix-new-oft-boft
Fix the OFT/BOFT bugs when using new LyCORIS implementation
Parents: 18819723, c4afdb78
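
What the fix addresses, in short: new-LyCORIS OFT checkpoints store their weights under the same "oft_blocks" key as kohya-ss checkpoints but may omit the "alpha" entry, so the old unconditional weights.w["alpha"] lookup could fail, and the old is_kohya branching mis-sized the blocks for the new format. A minimal sketch of the alpha part of the fix; the dictionaries below are hypothetical stand-ins for weights.w, not real checkpoint data:

    # Hypothetical stand-ins for weights.w; tensor values elided for brevity.
    w_kohya = {"oft_blocks": ..., "alpha": 4.0}
    w_new_lycoris = {"oft_blocks": ...}  # no "alpha" key

    def read_alpha(w):
        # Pre-patch code did w["alpha"], which raises KeyError on the
        # new-LyCORIS dict; the patched code tolerates a missing key.
        return w.get("alpha", None)

    out_dim = 320  # example module width
    for w in (w_kohya, w_new_lycoris):
        alpha = read_alpha(w)
        # Patched constraint: a missing alpha becomes 0, which later
        # disables the norm clamp via `if self.constraint != 0`.
        constraint = (0 if alpha is None else alpha) * out_dim
        print(alpha, constraint)
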
Showing 1 changed file with 24 additions and 26 deletions:

extensions-builtin/Lora/network_oft.py (+24, -26)
extensions-builtin/Lora/network_oft.py @ 900419e8
 import torch
 import network
-from lyco_helpers import factorization
 from einops import rearrange

...
@@ -22,24 +21,24 @@ class NetworkModuleOFT(network.NetworkModule):
         self.org_module: list[torch.Module] = [self.sd_module]

         self.scale = 1.0
-        self.is_kohya = False
+        self.is_R = False
         self.is_boft = False

-        # kohya-ss
+        # kohya-ss/New LyCORIS OFT/BOFT
         if "oft_blocks" in weights.w.keys():
-            self.is_kohya = True
             self.oft_blocks = weights.w["oft_blocks"]  # (num_blocks, block_size, block_size)
-            self.alpha = weights.w["alpha"]  # alpha is constraint
+            self.alpha = weights.w.get("alpha", None)  # alpha is constraint
             self.dim = self.oft_blocks.shape[0]  # lora dim
-        # LyCORIS OFT
+        # Old LyCORIS OFT
         elif "oft_diag" in weights.w.keys():
+            self.is_R = True
             self.oft_blocks = weights.w["oft_diag"]
             # self.alpha is unused
             self.dim = self.oft_blocks.shape[1]  # (num_blocks, block_size, block_size)

-            # LyCORIS BOFT
-            if weights.w["oft_diag"].dim() == 4:
-                self.is_boft = True
+        # LyCORIS BOFT
+        if self.oft_blocks.dim() == 4:
+            self.is_boft = True
         self.rescale = weights.w.get('rescale', None)
         if self.rescale is not None:
             self.rescale = self.rescale.reshape(-1, *[1]*(self.org_module[0].weight.dim() - 1))

...
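
For orientation (my summary of the shape comments in the hunk above): plain OFT stores a 3-D block tensor of shape (num_blocks, block_size, block_size), while BOFT stores a 4-D tensor with an extra leading dimension (the boft_m read in the next hunk). That is why the patched code can flag BOFT for either weight key with a single self.oft_blocks.dim() == 4 check instead of inspecting weights.w["oft_diag"] only. A quick sketch with made-up dimensions:

    import torch

    # Made-up shapes following the comments in the hunk above.
    oft = torch.zeros(8, 40, 40)      # OFT: (num_blocks, block_size, block_size)
    boft = torch.zeros(4, 8, 40, 40)  # BOFT: extra leading boft_m dimension

    for blocks in (oft, boft):
        is_boft = blocks.dim() == 4   # the patched detection in __init__
        print(tuple(blocks.shape), "BOFT" if is_boft else "OFT")
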
@@ -55,30 +54,29 @@ class NetworkModuleOFT(network.NetworkModule):
         elif is_other_linear:
             self.out_dim = self.sd_module.embed_dim

-        if self.is_kohya:
-            self.constraint = self.alpha * self.out_dim
-            self.num_blocks = self.dim
-            self.block_size = self.out_dim // self.dim
-        elif self.is_boft:
+        self.num_blocks = self.dim
+        self.block_size = self.out_dim // self.dim
+        self.constraint = (0 if self.alpha is None else self.alpha) * self.out_dim
+        if self.is_R:
             self.constraint = None
-            self.boft_m = weights.w["oft_diag"].shape[0]
-            self.block_num = weights.w["oft_diag"].shape[1]
-            self.block_size = weights.w["oft_diag"].shape[2]
+            self.block_size = self.dim
+            self.num_blocks = self.out_dim // self.dim
+        elif self.is_boft:
+            self.boft_m = self.oft_blocks.shape[0]
+            self.num_blocks = self.oft_blocks.shape[1]
+            self.block_size = self.oft_blocks.shape[2]
             self.boft_b = self.block_size
-            #self.block_size, self.block_num = butterfly_factor(self.out_dim, self.dim)
-        else:
-            self.constraint = None
-            self.block_size, self.num_blocks = factorization(self.out_dim, self.dim)

     def calc_updown(self, orig_weight):
         oft_blocks = self.oft_blocks.to(orig_weight.device)
         eye = torch.eye(self.block_size, device=oft_blocks.device)

-        if self.is_kohya:
-            block_Q = oft_blocks - oft_blocks.transpose(1, 2)  # ensure skew-symmetric orthogonal matrix
-            norm_Q = torch.norm(block_Q.flatten())
-            new_norm_Q = torch.clamp(norm_Q, max=self.constraint.to(oft_blocks.device))
-            block_Q = block_Q * ((new_norm_Q + 1e-8) / (norm_Q + 1e-8))
+        if not self.is_R:
+            block_Q = oft_blocks - oft_blocks.transpose(-1, -2)  # ensure skew-symmetric orthogonal matrix
+            if self.constraint != 0:
+                norm_Q = torch.norm(block_Q.flatten())
+                new_norm_Q = torch.clamp(norm_Q, max=self.constraint.to(oft_blocks.device))
+                block_Q = block_Q * ((new_norm_Q + 1e-8) / (norm_Q + 1e-8))
             oft_blocks = torch.matmul(eye + block_Q, (eye - block_Q).float().inverse())

         R = oft_blocks.to(orig_weight.device)

...
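
The unchanged matmul line above is the heart of calc_updown: it orthogonalizes each block with a Cayley transform, R = (I + Q)(I - Q)^-1, which yields an orthogonal matrix whenever Q is skew-symmetric; the clamp that this patch makes conditional first rescales the norm of Q to the COFT constraint. A self-contained check of that property, independent of the webui classes:

    import torch

    torch.manual_seed(0)
    block_size = 4
    eye = torch.eye(block_size)

    # Mirror the diff: antisymmetrize an arbitrary block to get skew-symmetric Q.
    oft_block = torch.randn(block_size, block_size)
    block_Q = oft_block - oft_block.transpose(-1, -2)  # Q.T == -Q

    # COFT-style norm clamp (constraint value chosen arbitrarily here).
    constraint = torch.tensor(1.0)
    norm_Q = torch.norm(block_Q.flatten())
    block_Q = block_Q * ((torch.clamp(norm_Q, max=constraint) + 1e-8) / (norm_Q + 1e-8))

    # Cayley transform: orthogonal for any skew-symmetric Q.
    R = torch.matmul(eye + block_Q, (eye - block_Q).inverse())
    print(torch.allclose(R @ R.T, eye, atol=1e-5))  # True: R is orthogonal
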