Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
Overbill1683
Stable Diffusion Webui
提交
cb31abcf
S
Stable Diffusion Webui
项目概览
Overbill1683
/
Stable Diffusion Webui
10 个月 前同步成功
通知
1748
Star
81
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
分析
仓库
DevOps
项目成员
Pages
S
Stable Diffusion Webui
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Pages
分析
分析
仓库分析
DevOps
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
提交
体验新版 GitCode,发现更多精彩内容 >>
提交
cb31abcf
编写于
10月 30, 2022
作者:
M
Muhammad Rizqi Nur
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Settings to select VAE
上级
17a2076f
变更
4
隐藏空白更改
内联
并排
Showing
4 changed file
with
141 addition
and
24 deletion
+141
-24
modules/sd_models.py
modules/sd_models.py
+10
-21
modules/sd_vae.py
modules/sd_vae.py
+121
-0
modules/shared.py
modules/shared.py
+5
-3
webui.py
webui.py
+5
-0
未找到文件。
modules/sd_models.py
浏览文件 @
cb31abcf
...
...
@@ -8,7 +8,7 @@ from omegaconf import OmegaConf
from
ldm.util
import
instantiate_from_config
from
modules
import
shared
,
modelloader
,
devices
,
script_callbacks
from
modules
import
shared
,
modelloader
,
devices
,
script_callbacks
,
sd_vae
from
modules.paths
import
models_path
from
modules.sd_hijack_inpainting
import
do_inpainting_hijack
,
should_hijack_inpainting
...
...
@@ -160,12 +160,11 @@ def get_state_dict_from_checkpoint(pl_sd):
vae_ignore_keys
=
{
"model_ema.decay"
,
"model_ema.num_updates"
}
def
load_model_weights
(
model
,
checkpoint_info
):
def
load_model_weights
(
model
,
checkpoint_info
,
force
=
False
):
checkpoint_file
=
checkpoint_info
.
filename
sd_model_hash
=
checkpoint_info
.
hash
if
checkpoint_info
not
in
checkpoints_loaded
:
if
force
or
checkpoint_info
not
in
checkpoints_loaded
:
print
(
f
"Loading weights [
{
sd_model_hash
}
] from
{
checkpoint_file
}
"
)
pl_sd
=
torch
.
load
(
checkpoint_file
,
map_location
=
shared
.
weight_load_location
)
...
...
@@ -186,17 +185,7 @@ def load_model_weights(model, checkpoint_info):
devices
.
dtype
=
torch
.
float32
if
shared
.
cmd_opts
.
no_half
else
torch
.
float16
devices
.
dtype_vae
=
torch
.
float32
if
shared
.
cmd_opts
.
no_half
or
shared
.
cmd_opts
.
no_half_vae
else
torch
.
float16
vae_file
=
os
.
path
.
splitext
(
checkpoint_file
)[
0
]
+
".vae.pt"
if
not
os
.
path
.
exists
(
vae_file
)
and
shared
.
cmd_opts
.
vae_path
is
not
None
:
vae_file
=
shared
.
cmd_opts
.
vae_path
if
os
.
path
.
exists
(
vae_file
):
print
(
f
"Loading VAE weights from:
{
vae_file
}
"
)
vae_ckpt
=
torch
.
load
(
vae_file
,
map_location
=
shared
.
weight_load_location
)
vae_dict
=
{
k
:
v
for
k
,
v
in
vae_ckpt
[
"state_dict"
].
items
()
if
k
[
0
:
4
]
!=
"loss"
and
k
not
in
vae_ignore_keys
}
model
.
first_stage_model
.
load_state_dict
(
vae_dict
)
sd_vae
.
load_vae
(
model
,
checkpoint_file
)
model
.
first_stage_model
.
to
(
devices
.
dtype_vae
)
if
shared
.
opts
.
sd_checkpoint_cache
>
0
:
...
...
@@ -213,7 +202,7 @@ def load_model_weights(model, checkpoint_info):
model
.
sd_checkpoint_info
=
checkpoint_info
def
load_model
(
checkpoint_info
=
None
):
def
load_model
(
checkpoint_info
=
None
,
force
=
False
):
from
modules
import
lowvram
,
sd_hijack
checkpoint_info
=
checkpoint_info
or
select_checkpoint
()
...
...
@@ -234,7 +223,7 @@ def load_model(checkpoint_info=None):
do_inpainting_hijack
()
sd_model
=
instantiate_from_config
(
sd_config
.
model
)
load_model_weights
(
sd_model
,
checkpoint_info
)
load_model_weights
(
sd_model
,
checkpoint_info
,
force
=
force
)
if
shared
.
cmd_opts
.
lowvram
or
shared
.
cmd_opts
.
medvram
:
lowvram
.
setup_for_low_vram
(
sd_model
,
shared
.
cmd_opts
.
medvram
)
...
...
@@ -252,16 +241,16 @@ def load_model(checkpoint_info=None):
return
sd_model
def
reload_model_weights
(
sd_model
,
info
=
None
):
def
reload_model_weights
(
sd_model
,
info
=
None
,
force
=
False
):
from
modules
import
lowvram
,
devices
,
sd_hijack
checkpoint_info
=
info
or
select_checkpoint
()
if
sd_model
.
sd_model_checkpoint
==
checkpoint_info
.
filename
:
if
sd_model
.
sd_model_checkpoint
==
checkpoint_info
.
filename
and
not
force
:
return
if
sd_model
.
sd_checkpoint_info
.
config
!=
checkpoint_info
.
config
or
should_hijack_inpainting
(
checkpoint_info
)
!=
should_hijack_inpainting
(
sd_model
.
sd_checkpoint_info
):
checkpoints_loaded
.
clear
()
load_model
(
checkpoint_info
)
load_model
(
checkpoint_info
,
force
=
force
)
return
shared
.
sd_model
if
shared
.
cmd_opts
.
lowvram
or
shared
.
cmd_opts
.
medvram
:
...
...
@@ -271,7 +260,7 @@ def reload_model_weights(sd_model, info=None):
sd_hijack
.
model_hijack
.
undo_hijack
(
sd_model
)
load_model_weights
(
sd_model
,
checkpoint_info
)
load_model_weights
(
sd_model
,
checkpoint_info
,
force
=
force
)
sd_hijack
.
model_hijack
.
hijack
(
sd_model
)
script_callbacks
.
model_loaded_callback
(
sd_model
)
...
...
modules/sd_vae.py
0 → 100644
浏览文件 @
cb31abcf
import
torch
import
os
from
collections
import
namedtuple
from
modules
import
shared
,
devices
from
modules.paths
import
models_path
import
glob
model_dir
=
"Stable-diffusion"
model_path
=
os
.
path
.
abspath
(
os
.
path
.
join
(
models_path
,
model_dir
))
vae_dir
=
"VAE"
vae_path
=
os
.
path
.
abspath
(
os
.
path
.
join
(
models_path
,
vae_dir
))
vae_ignore_keys
=
{
"model_ema.decay"
,
"model_ema.num_updates"
}
default_vae_dict
=
{
"auto"
:
"auto"
,
"None"
:
"None"
}
default_vae_list
=
[
"auto"
,
"None"
]
default_vae_values
=
[
default_vae_dict
[
x
]
for
x
in
default_vae_list
]
vae_dict
=
dict
(
default_vae_dict
)
vae_list
=
list
(
default_vae_list
)
first_load
=
True
def
get_filename
(
filepath
):
return
os
.
path
.
splitext
(
os
.
path
.
basename
(
filepath
))[
0
]
def
refresh_vae_list
(
vae_path
=
vae_path
,
model_path
=
model_path
):
global
vae_dict
,
vae_list
res
=
{}
candidates
=
[
*
glob
.
iglob
(
os
.
path
.
join
(
model_path
,
'**/*.vae.pt'
),
recursive
=
True
),
*
glob
.
iglob
(
os
.
path
.
join
(
model_path
,
'**/*.vae.ckpt'
),
recursive
=
True
),
*
glob
.
iglob
(
os
.
path
.
join
(
vae_path
,
'**/*.pt'
),
recursive
=
True
),
*
glob
.
iglob
(
os
.
path
.
join
(
vae_path
,
'**/*.ckpt'
),
recursive
=
True
)
]
if
shared
.
cmd_opts
.
vae_path
is
not
None
and
os
.
path
.
isfile
(
shared
.
cmd_opts
.
vae_path
):
candidates
.
append
(
shared
.
cmd_opts
.
vae_path
)
for
filepath
in
candidates
:
name
=
get_filename
(
filepath
)
res
[
name
]
=
filepath
vae_list
.
clear
()
vae_list
.
extend
(
default_vae_list
)
vae_list
.
extend
(
list
(
res
.
keys
()))
vae_dict
.
clear
()
vae_dict
.
update
(
default_vae_dict
)
vae_dict
.
update
(
res
)
return
vae_list
def
load_vae
(
model
,
checkpoint_file
,
vae_file
=
"auto"
):
global
first_load
,
vae_dict
,
vae_list
# save_settings = False
# if vae_file argument is provided, it takes priority
if
vae_file
and
vae_file
not
in
default_vae_list
:
if
not
os
.
path
.
isfile
(
vae_file
):
vae_file
=
"auto"
# save_settings = True
print
(
"VAE provided as function argument doesn't exist"
)
# for the first load, if vae-path is provided, it takes priority and failure is reported
if
first_load
and
shared
.
cmd_opts
.
vae_path
is
not
None
:
if
os
.
path
.
isfile
(
shared
.
cmd_opts
.
vae_path
):
vae_file
=
shared
.
cmd_opts
.
vae_path
# save_settings = True
# print("Using VAE provided as command line argument")
else
:
print
(
"VAE provided as command line argument doesn't exist"
)
# else, we load from settings
if
vae_file
==
"auto"
and
shared
.
opts
.
sd_vae
is
not
None
:
# if saved VAE settings isn't recognized, fallback to auto
vae_file
=
vae_dict
.
get
(
shared
.
opts
.
sd_vae
,
"auto"
)
# if VAE selected but not found, fallback to auto
if
vae_file
not
in
default_vae_values
and
not
os
.
path
.
isfile
(
vae_file
):
vae_file
=
"auto"
print
(
"Selected VAE doesn't exist"
)
# vae-path cmd arg takes priority for auto
if
vae_file
==
"auto"
and
shared
.
cmd_opts
.
vae_path
is
not
None
:
if
os
.
path
.
isfile
(
shared
.
cmd_opts
.
vae_path
):
vae_file
=
shared
.
cmd_opts
.
vae_path
print
(
"Using VAE provided as command line argument"
)
# if still not found, try look for ".vae.pt" beside model
model_path
=
os
.
path
.
splitext
(
checkpoint_file
)[
0
]
if
vae_file
==
"auto"
:
vae_file_try
=
model_path
+
".vae.pt"
if
os
.
path
.
isfile
(
vae_file_try
):
vae_file
=
vae_file_try
print
(
"Using VAE found beside selected model"
)
# if still not found, try look for ".vae.ckpt" beside model
if
vae_file
==
"auto"
:
vae_file_try
=
model_path
+
".vae.ckpt"
if
os
.
path
.
isfile
(
vae_file_try
):
vae_file
=
vae_file_try
print
(
"Using VAE found beside selected model"
)
# No more fallbacks for auto
if
vae_file
==
"auto"
:
vae_file
=
None
# Last check, just because
if
vae_file
and
not
os
.
path
.
exists
(
vae_file
):
vae_file
=
None
if
vae_file
:
print
(
f
"Loading VAE weights from:
{
vae_file
}
"
)
vae_ckpt
=
torch
.
load
(
vae_file
,
map_location
=
shared
.
weight_load_location
)
vae_dict_1
=
{
k
:
v
for
k
,
v
in
vae_ckpt
[
"state_dict"
].
items
()
if
k
[
0
:
4
]
!=
"loss"
and
k
not
in
vae_ignore_keys
}
model
.
first_stage_model
.
load_state_dict
(
vae_dict_1
)
# If vae used is not in dict, update it
# It will be removed on refresh though
if
vae_file
is
not
None
:
vae_opt
=
get_filename
(
vae_file
)
if
vae_opt
not
in
vae_dict
:
vae_dict
[
vae_opt
]
=
vae_file
vae_list
.
append
(
vae_opt
)
"""
# Save current VAE to VAE settings, maybe? will it work?
if save_settings:
if vae_file is None:
vae_opt = "None"
# shared.opts.sd_vae = vae_opt
"""
first_load
=
False
model
.
first_stage_model
.
to
(
devices
.
dtype_vae
)
modules/shared.py
浏览文件 @
cb31abcf
...
...
@@ -14,7 +14,7 @@ import modules.memmon
import
modules.sd_models
import
modules.styles
import
modules.devices
as
devices
from
modules
import
sd_samplers
,
sd_models
,
localization
from
modules
import
sd_samplers
,
sd_models
,
localization
,
sd_vae
from
modules.hypernetworks
import
hypernetwork
from
modules.paths
import
models_path
,
script_path
,
sd_path
...
...
@@ -295,6 +295,7 @@ options_templates.update(options_section(('training', "Training"), {
options_templates
.
update
(
options_section
((
'sd'
,
"Stable Diffusion"
),
{
"sd_model_checkpoint"
:
OptionInfo
(
None
,
"Stable Diffusion checkpoint"
,
gr
.
Dropdown
,
lambda
:
{
"choices"
:
modules
.
sd_models
.
checkpoint_tiles
()},
refresh
=
sd_models
.
list_models
),
"sd_checkpoint_cache"
:
OptionInfo
(
0
,
"Checkpoints to cache in RAM"
,
gr
.
Slider
,
{
"minimum"
:
0
,
"maximum"
:
10
,
"step"
:
1
}),
"sd_vae"
:
OptionInfo
(
"auto"
,
"SD VAE"
,
gr
.
Dropdown
,
lambda
:
{
"choices"
:
list
(
sd_vae
.
vae_list
)},
refresh
=
sd_vae
.
refresh_vae_list
),
"sd_hypernetwork"
:
OptionInfo
(
"None"
,
"Hypernetwork"
,
gr
.
Dropdown
,
lambda
:
{
"choices"
:
[
"None"
]
+
[
x
for
x
in
hypernetworks
.
keys
()]},
refresh
=
reload_hypernetworks
),
"sd_hypernetwork_strength"
:
OptionInfo
(
1.0
,
"Hypernetwork strength"
,
gr
.
Slider
,
{
"minimum"
:
0.0
,
"maximum"
:
1.0
,
"step"
:
0.001
}),
"inpainting_mask_weight"
:
OptionInfo
(
1.0
,
"Inpainting conditioning mask strength"
,
gr
.
Slider
,
{
"minimum"
:
0.0
,
"maximum"
:
1.0
,
"step"
:
0.01
}),
...
...
@@ -407,11 +408,12 @@ class Options:
if
bad_settings
>
0
:
print
(
f
"The program is likely to not work with bad settings.
\n
Settings file:
{
filename
}
\n
Either fix the file, or delete it and restart."
,
file
=
sys
.
stderr
)
def
onchange
(
self
,
key
,
func
):
def
onchange
(
self
,
key
,
func
,
call
=
True
):
item
=
self
.
data_labels
.
get
(
key
)
item
.
onchange
=
func
func
()
if
call
:
func
()
def
dumpjson
(
self
):
d
=
{
k
:
self
.
data
.
get
(
k
,
self
.
data_labels
.
get
(
k
).
default
)
for
k
in
self
.
data_labels
.
keys
()}
...
...
webui.py
浏览文件 @
cb31abcf
...
...
@@ -21,6 +21,7 @@ import modules.paths
import
modules.scripts
import
modules.sd_hijack
import
modules.sd_models
import
modules.sd_vae
import
modules.shared
as
shared
import
modules.txt2img
...
...
@@ -74,8 +75,12 @@ def initialize():
modules
.
scripts
.
load_scripts
()
modules
.
sd_vae
.
refresh_vae_list
()
modules
.
sd_models
.
load_model
()
shared
.
opts
.
onchange
(
"sd_model_checkpoint"
,
wrap_queued_call
(
lambda
:
modules
.
sd_models
.
reload_model_weights
(
shared
.
sd_model
)))
# I don't know what needs to be done to only reload VAE, with all those hijacks callbacks, and lowvram,
# so for now this reloads the whole model too, and no cache
shared
.
opts
.
onchange
(
"sd_vae"
,
wrap_queued_call
(
lambda
:
modules
.
sd_models
.
reload_model_weights
(
shared
.
sd_model
,
force
=
True
)),
call
=
False
)
shared
.
opts
.
onchange
(
"sd_hypernetwork"
,
wrap_queued_call
(
lambda
:
modules
.
hypernetworks
.
hypernetwork
.
load_hypernetwork
(
shared
.
opts
.
sd_hypernetwork
)))
shared
.
opts
.
onchange
(
"sd_hypernetwork_strength"
,
modules
.
hypernetworks
.
hypernetwork
.
apply_strength
)
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录