Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
Overbill1683
Stable Diffusion Webui
提交
fce86ab7
S
Stable Diffusion Webui
项目概览
Overbill1683
/
Stable Diffusion Webui
10 个月 前同步成功
通知
1748
Star
81
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
分析
仓库
DevOps
项目成员
Pages
S
Stable Diffusion Webui
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Pages
分析
分析
仓库分析
DevOps
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
提交
体验新版 GitCode,发现更多精彩内容 >>
提交
fce86ab7
编写于
10月 21, 2023
作者:
V
v0xie
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
fix: support multiplier, no forward pass hook
上级
76835477
变更
1
显示空白变更内容
内联
并排
Showing
1 changed file
with
33 addition
and
10 deletion
+33
-10
extensions-builtin/Lora/network_oft.py
extensions-builtin/Lora/network_oft.py
+33
-10
未找到文件。
extensions-builtin/Lora/network_oft.py
浏览文件 @
fce86ab7
...
...
@@ -32,21 +32,27 @@ class NetworkModuleOFT(network.NetworkModule):
self
.
org_module
:
list
[
torch
.
Module
]
=
[
self
.
sd_module
]
self
.
org_weight
=
self
.
org_module
[
0
].
weight
.
to
(
self
.
org_module
[
0
].
weight
.
device
,
copy
=
True
)
#self.org_weight = self.org_module[0].weight.to(devices.cpu, copy=True)
self
.
R
=
self
.
get_weight
(
self
.
oft_blocks
)
init_multiplier
=
self
.
multiplier
()
*
self
.
calc_scale
()
self
.
last_multiplier
=
init_multiplier
self
.
R
=
self
.
get_weight
(
self
.
oft_blocks
,
init_multiplier
)
self
.
merged_weight
=
self
.
merge_weight
()
self
.
apply_to
()
self
.
merged
=
False
# weights_backup = getattr(self.org_module[0], 'network_weights_backup', None)
# if weights_backup is None:
# self.org_module[0].network_weights_backup = self.org_weight
def
merge_weight
(
self
):
org_sd
=
self
.
org_module
[
0
].
state_dict
()
#
org_sd = self.org_module[0].state_dict()
R
=
self
.
R
.
to
(
self
.
org_weight
.
device
,
dtype
=
self
.
org_weight
.
dtype
)
if
self
.
org_weight
.
dim
()
==
4
:
weight
=
torch
.
einsum
(
"oihw, op -> pihw"
,
self
.
org_weight
,
R
)
else
:
weight
=
torch
.
einsum
(
"oi, op -> pi"
,
self
.
org_weight
,
R
)
org_sd
[
'weight'
]
=
weight
#
org_sd['weight'] = weight
# replace weight
#self.org_module[0].load_state_dict(org_sd)
return
weight
...
...
@@ -74,6 +80,7 @@ class NetworkModuleOFT(network.NetworkModule):
self
.
org_module
[
0
].
register_forward_hook
(
self
.
forward_hook
)
def
get_weight
(
self
,
oft_blocks
,
multiplier
=
None
):
multiplier
=
multiplier
.
to
(
oft_blocks
.
device
,
dtype
=
oft_blocks
.
dtype
)
constraint
=
self
.
constraint
.
to
(
oft_blocks
.
device
,
dtype
=
oft_blocks
.
dtype
)
block_Q
=
oft_blocks
-
oft_blocks
.
transpose
(
1
,
2
)
norm_Q
=
torch
.
norm
(
block_Q
.
flatten
())
...
...
@@ -81,9 +88,9 @@ class NetworkModuleOFT(network.NetworkModule):
block_Q
=
block_Q
*
((
new_norm_Q
+
1e-8
)
/
(
norm_Q
+
1e-8
))
m_I
=
torch
.
eye
(
self
.
block_size
,
device
=
oft_blocks
.
device
).
unsqueeze
(
0
).
repeat
(
self
.
num_blocks
,
1
,
1
)
block_R
=
torch
.
matmul
(
m_I
+
block_Q
,
(
m_I
-
block_Q
).
inverse
())
#block_R_weighted = multiplier * block_R + (1 - multiplier) *
I
#
R = torch.block_diag(*block_R_weighted)
R
=
torch
.
block_diag
(
*
block_R
)
block_R_weighted
=
multiplier
*
block_R
+
(
1
-
multiplier
)
*
m_
I
R
=
torch
.
block_diag
(
*
block_R_weighted
)
#
R = torch.block_diag(*block_R)
return
R
...
...
@@ -93,6 +100,8 @@ class NetworkModuleOFT(network.NetworkModule):
#R = self.R.to(orig_weight.device, dtype=orig_weight.dtype)
##self.R = R
#R = self.R.to(orig_weight.device, dtype=orig_weight.dtype)
##self.R = R
#if orig_weight.dim() == 4:
# weight = torch.einsum("oihw, op -> pihw", orig_weight, R)
#else:
...
...
@@ -103,19 +112,33 @@ class NetworkModuleOFT(network.NetworkModule):
updown
=
torch
.
zeros_like
(
orig_weight
,
device
=
orig_weight
.
device
,
dtype
=
orig_weight
.
dtype
)
#updown = orig_weight
output_shape
=
orig_weight
.
shape
#
orig_weight = self.merged_weight.to(orig_weight.device, dtype=orig_weight.dtype)
orig_weight
=
self
.
merged_weight
.
to
(
orig_weight
.
device
,
dtype
=
orig_weight
.
dtype
)
#output_shape = self.oft_blocks.shape
return
self
.
finalize_updown
(
updown
,
orig_weight
,
output_shape
)
def
pre_forward_hook
(
self
,
module
,
input
):
if
not
self
.
merged
:
multiplier
=
self
.
multiplier
()
*
self
.
calc_scale
()
if
not
multiplier
==
self
.
last_multiplier
or
not
self
.
merged
:
#if multiplier != self.last_multiplier or not self.merged:
self
.
R
=
self
.
get_weight
(
self
.
oft_blocks
,
multiplier
)
self
.
last_multiplier
=
multiplier
self
.
merged_weight
=
self
.
merge_weight
()
self
.
replace_weight
(
self
.
merged_weight
)
#elif not self.merged:
# self.replace_weight(self.merged_weight)
def
forward_hook
(
self
,
module
,
args
,
output
):
if
self
.
merged
:
pass
#output = output * self.multiplier() * self.calc_scale()
#if len(args) > 0:
# y = args[0]
# output = output + y
#return output
#if self.merged:
# pass
#self.restore_weight()
#print(f'Forward hook in {self.network_key} called')
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录