Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
Overbill1683
Stable Diffusion Webui
提交
853e21d9
S
Stable Diffusion Webui
项目概览
Overbill1683
/
Stable Diffusion Webui
10 个月 前同步成功
通知
1748
Star
81
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
分析
仓库
DevOps
项目成员
Pages
S
Stable Diffusion Webui
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Pages
分析
分析
仓库分析
DevOps
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
提交
体验新版 GitCode,发现更多精彩内容 >>
提交
853e21d9
编写于
10月 18, 2023
作者:
V
v0xie
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
faster by using cached R in forward
上级
1c6efdbb
变更
1
隐藏空白更改
内联
并排
Showing
1 changed file
with
14 addition
and
3 deletion
+14
-3
extensions-builtin/Lora/network_oft.py
extensions-builtin/Lora/network_oft.py
+14
-3
未找到文件。
extensions-builtin/Lora/network_oft.py
浏览文件 @
853e21d9
...
@@ -57,21 +57,32 @@ class NetworkModuleOFT(network.NetworkModule):
...
@@ -57,21 +57,32 @@ class NetworkModuleOFT(network.NetworkModule):
return
R
return
R
def
calc_updown
(
self
,
orig_weight
):
def
calc_updown
(
self
,
orig_weight
):
# this works
R
=
self
.
R
R
=
self
.
R
# this causes major deepfrying i.e. just doesn't work
# R = self.R.to(orig_weight.device, dtype=orig_weight.dtype)
if
orig_weight
.
dim
()
==
4
:
if
orig_weight
.
dim
()
==
4
:
weight
=
torch
.
einsum
(
"oihw, op -> pihw"
,
orig_weight
,
R
)
weight
=
torch
.
einsum
(
"oihw, op -> pihw"
,
orig_weight
,
R
)
else
:
else
:
weight
=
torch
.
einsum
(
"oi, op -> pi"
,
orig_weight
,
R
)
weight
=
torch
.
einsum
(
"oi, op -> pi"
,
orig_weight
,
R
)
updown
=
orig_weight
@
R
updown
=
orig_weight
@
R
output_shape
=
[
orig_weight
.
size
(
0
),
R
.
size
(
1
)]
output_shape
=
self
.
oft_blocks
.
shape
#output_shape = [R.size(0), orig_weight.size(1)]
## this works
# updown = orig_weight @ R
# output_shape = [orig_weight.size(0), R.size(1)]
return
self
.
finalize_updown
(
updown
,
orig_weight
,
output_shape
)
return
self
.
finalize_updown
(
updown
,
orig_weight
,
output_shape
)
def
forward
(
self
,
x
,
y
=
None
):
def
forward
(
self
,
x
,
y
=
None
):
x
=
self
.
org_forward
(
x
)
x
=
self
.
org_forward
(
x
)
if
self
.
multiplier
()
==
0.0
:
if
self
.
multiplier
()
==
0.0
:
return
x
return
x
R
=
self
.
get_weight
().
to
(
x
.
device
,
dtype
=
x
.
dtype
)
#R = self.get_weight().to(x.device, dtype=x.dtype)
R
=
self
.
R
.
to
(
x
.
device
,
dtype
=
x
.
dtype
)
if
x
.
dim
()
==
4
:
if
x
.
dim
()
==
4
:
x
=
x
.
permute
(
0
,
2
,
3
,
1
)
x
=
x
.
permute
(
0
,
2
,
3
,
1
)
x
=
torch
.
matmul
(
x
,
R
)
x
=
torch
.
matmul
(
x
,
R
)
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录