Overbill1683 / Stable Diffusion Webui · Commit 3b1b1444
Unverified commit 3b1b1444
Authored on Sep 13, 2022 by C43H66N12O12S2; committed by GitHub on Sep 13, 2022
Complete cross attention update
Parent: c84e3336

Showing 1 changed file with 73 additions and 1 deletion.

modules/sd_hijack.py  (+73 -1)
@@ -11,7 +11,7 @@ from modules.shared import opts, device, cmd_opts
from ldm.util import default
from einops import rearrange
import ldm.modules.attention
import ldm.modules.diffusionmodules.model

# see https://github.com/basujindal/stable-diffusion/pull/117 for discussion
@@ -100,6 +100,76 @@ def split_cross_attention_forward(self, x, context=None, mask=None):
    return self.to_out(r2)


def nonlinearity_hijack(x):
    # swish
    t = torch.sigmoid(x)
    x *= t
    del t

    return x


def cross_attention_attnblock_forward(self, x):
    h_ = x
    h_ = self.norm(h_)
    q1 = self.q(h_)
    k1 = self.k(h_)
    v = self.v(h_)

    # compute attention
    b, c, h, w = q1.shape

    q2 = q1.reshape(b, c, h*w)
    del q1

    q = q2.permute(0, 2, 1)   # b,hw,c
    del q2

    k = k1.reshape(b, c, h*w)  # b,c,hw
    del k1

    h_ = torch.zeros_like(k, device=q.device)

    stats = torch.cuda.memory_stats(q.device)
    mem_active = stats['active_bytes.all.current']
    mem_reserved = stats['reserved_bytes.all.current']
    mem_free_cuda, _ = torch.cuda.mem_get_info(torch.cuda.current_device())
    mem_free_torch = mem_reserved - mem_active
    mem_free_total = mem_free_cuda + mem_free_torch

    tensor_size = q.shape[0] * q.shape[1] * k.shape[2] * q.element_size()
    mem_required = tensor_size * 2.5
    steps = 1

    if mem_required > mem_free_total:
        steps = 2**(math.ceil(math.log(mem_required / mem_free_total, 2)))

    slice_size = q.shape[1] // steps if (q.shape[1] % steps) == 0 else q.shape[1]
    for i in range(0, q.shape[1], slice_size):
        end = i + slice_size

        w1 = torch.bmm(q[:, i:end], k)  # b,hw,hw    w[b,i,j] = sum_c q[b,i,c] k[b,c,j]
        w2 = w1 * (int(c)**(-0.5))
        del w1
        w3 = torch.nn.functional.softmax(w2, dim=2, dtype=q.dtype)
        del w2

        # attend to values
        v1 = v.reshape(b, c, h*w)
        w4 = w3.permute(0, 2, 1)  # b,hw,hw (first hw of k, second of q)
        del w3

        h_[:, :, i:end] = torch.bmm(v1, w4)  # b,c,hw (hw of q)    h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j]
        del v1, w4

    h2 = h_.reshape(b, c, h, w)
    del h_

    h3 = self.proj_out(h2)
    del h2

    h3 += x

    return h3


class StableDiffusionModelHijack:
    ids_lookup = {}
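The slicing logic above sizes the full b,hw,hw attention matrix against free VRAM: it budgets roughly 2.5x the matrix's byte size, and if that does not fit, it processes the query rows in a power-of-two number of chunks. A minimal standalone sketch of that arithmetic, not part of the commit; the helper name choose_slice_size and the byte counts in the example are hypothetical:

    import math

    def choose_slice_size(q_rows, tensor_size_bytes, mem_free_total_bytes):
        # Mirrors the step computation in cross_attention_attnblock_forward:
        # require ~2.5x the size of one full attention matrix, and if that does
        # not fit, split the query rows into 2**ceil(log2(ratio)) chunks.
        mem_required = tensor_size_bytes * 2.5
        steps = 1
        if mem_required > mem_free_total_bytes:
            steps = 2 ** math.ceil(math.log(mem_required / mem_free_total_bytes, 2))
        # Fall back to a single full-width pass when q_rows is not divisible by steps.
        return q_rows // steps if q_rows % steps == 0 else q_rows

    # Hypothetical example: b=1, hw = 64*64 = 4096 query rows, fp16 elements.
    q_rows = 4096
    tensor_size = 1 * 4096 * 4096 * 2  # bytes for one b,hw,hw attention matrix
    print(choose_slice_size(q_rows, tensor_size, mem_free_total_bytes=48 * 1024**2))  # -> 2048 (two slices)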
@@ -175,6 +245,8 @@ class StableDiffusionModelHijack:
        if cmd_opts.opt_split_attention:
            ldm.modules.attention.CrossAttention.forward = split_cross_attention_forward
            ldm.modules.diffusionmodules.model.nonlinearity = nonlinearity_hijack
            ldm.modules.diffusionmodules.model.AttnBlock.forward = cross_attention_attnblock_forward
        elif cmd_opts.opt_split_attention_v1:
            ldm.modules.attention.CrossAttention.forward = split_cross_attention_forward_v1
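This hunk wires the new functions in by plain attribute assignment on the ldm classes and module: once CrossAttention.forward and AttnBlock.forward point at the replacements, every existing and future instance uses them. A minimal sketch of that monkey-patching pattern with a toy class rather than the real ldm modules:

    class AttnBlock:
        def forward(self, x):
            return "original forward"

    def patched_forward(self, x):
        # Replacement defined as a plain function; `self` is bound when the
        # class attribute is looked up on an instance.
        return "patched forward"

    # Assigning to the class attribute swaps the method for all instances,
    # mirroring ldm.modules.diffusionmodules.model.AttnBlock.forward = ...
    AttnBlock.forward = patched_forward
    print(AttnBlock().forward(None))  # -> "patched forward"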