Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
Overbill1683
Stable Diffusion Webui
提交
ddbf4a73
S
Stable Diffusion Webui
项目概览
Overbill1683
/
Stable Diffusion Webui
12 个月 前同步成功
通知
1771
Star
81
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
分析
仓库
DevOps
项目成员
Pages
S
Stable Diffusion Webui
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Pages
分析
分析
仓库分析
DevOps
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
提交
体验新版 GitCode,发现更多精彩内容 >>
提交
ddbf4a73
编写于
7月 20, 2023
作者:
L
lambertae
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
restart-sampler with correct steps
上级
7bb0fbed
变更
1
隐藏空白更改
内联
并排
Showing
1 changed file
with
14 addition
and
8 deletion
+14
-8
modules/sd_samplers_kdiffusion.py
modules/sd_samplers_kdiffusion.py
+14
-8
未找到文件。
modules/sd_samplers_kdiffusion.py
浏览文件 @
ddbf4a73
...
...
@@ -38,20 +38,19 @@ samplers_k_diffusion = [
def
restart_sampler
(
model
,
x
,
sigmas
,
extra_args
=
None
,
callback
=
None
,
disable
=
None
,
s_noise
=
1.
):
"""Implements restart sampling in Restart Sampling for Improving Generative Processes (2023)"""
'''Restart_list format: {min_sigma: [ restart_steps, restart_times, max_sigma]}'''
restart_list
=
{
0.1
:
[
10
,
2
,
2
]}
from
tqdm.auto
import
trange
extra_args
=
{}
if
extra_args
is
None
else
extra_args
s_in
=
x
.
new_ones
([
x
.
shape
[
0
]])
step_id
=
0
from
k_diffusion.sampling
import
to_d
,
append_zero
def
heun_step
(
x
,
old_sigma
,
new_sigma
):
def
heun_step
(
x
,
old_sigma
,
new_sigma
,
second_order
=
True
):
nonlocal
step_id
denoised
=
model
(
x
,
old_sigma
*
s_in
,
**
extra_args
)
d
=
to_d
(
x
,
old_sigma
,
denoised
)
if
callback
is
not
None
:
callback
({
'x'
:
x
,
'i'
:
step_id
,
'sigma'
:
new_sigma
,
'sigma_hat'
:
old_sigma
,
'denoised'
:
denoised
})
dt
=
new_sigma
-
old_sigma
if
new_sigma
==
0
:
if
new_sigma
==
0
or
not
second_order
:
# Euler method
x
=
x
+
d
*
dt
else
:
...
...
@@ -63,11 +62,6 @@ def restart_sampler(model, x, sigmas, extra_args=None, callback=None, disable=No
x
=
x
+
d_prime
*
dt
step_id
+=
1
return
x
# print(sigmas)
temp_list
=
dict
()
for
key
,
value
in
restart_list
.
items
():
temp_list
[
int
(
torch
.
argmin
(
abs
(
sigmas
-
key
),
dim
=
0
))]
=
value
restart_list
=
temp_list
def
get_sigmas_karras
(
n
,
sigma_min
,
sigma_max
,
rho
=
7.
,
device
=
'cpu'
):
ramp
=
torch
.
linspace
(
0
,
1
,
n
).
to
(
device
)
min_inv_rho
=
(
sigma_min
**
(
1
/
rho
))
...
...
@@ -78,6 +72,18 @@ def restart_sampler(model, x, sigmas, extra_args=None, callback=None, disable=No
max_inv_rho
=
max_inv_rho
.
to
(
device
)
sigmas
=
(
max_inv_rho
+
ramp
*
(
min_inv_rho
-
max_inv_rho
))
**
rho
return
append_zero
(
sigmas
).
to
(
device
)
steps
=
sigmas
.
shape
[
0
]
-
1
if
steps
>=
20
:
restart_steps
=
9
restart_times
=
2
if
steps
>=
36
else
1
sigmas
=
get_sigmas_karras
(
steps
-
restart_steps
*
restart_times
,
sigmas
[
-
2
],
sigmas
[
0
],
device
=
sigmas
.
device
)
restart_list
=
{
0.1
:
[
restart_steps
+
1
,
restart_times
,
2
]}
else
:
restart_list
=
dict
()
temp_list
=
dict
()
for
key
,
value
in
restart_list
.
items
():
temp_list
[
int
(
torch
.
argmin
(
abs
(
sigmas
-
key
),
dim
=
0
))]
=
value
restart_list
=
temp_list
for
i
in
trange
(
len
(
sigmas
)
-
1
,
disable
=
disable
):
x
=
heun_step
(
x
,
sigmas
[
i
],
sigmas
[
i
+
1
])
if
i
+
1
in
restart_list
:
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录