Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
Overbill1683
Stable Diffusion Webui
提交
8d8a05a3
S
Stable Diffusion Webui
项目概览
Overbill1683
/
Stable Diffusion Webui
大约 1 年 前同步成功
通知
1784
Star
81
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
分析
仓库
DevOps
项目成员
Pages
S
Stable Diffusion Webui
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Pages
分析
分析
仓库分析
DevOps
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
提交
体验新版 GitCode,发现更多精彩内容 >>
提交
8d8a05a3
编写于
1月 04, 2023
作者:
A
AUTOMATIC
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
find configs for models at runtime rather than when starting
上级
02d7abf5
变更
2
隐藏空白更改
内联
并排
Showing
2 changed files
with
22 additions
and
14 deletions
+22
-14
modules/sd_hijack_inpainting.py
modules/sd_hijack_inpainting.py
+4
-1
modules/sd_models.py
modules/sd_models.py
+18
-13
未找到文件。
modules/sd_hijack_inpainting.py
浏览文件 @
8d8a05a3
...
...
def should_hijack_inpainting(checkpoint_info):
    """Return True when the checkpoint filename looks like an inpainting model
    but its resolved config does not, i.e. the inpainting hijack is needed.

    checkpoint_info: a CheckpointInfo with at least a `filename` attribute
    (path to the .ckpt/.safetensors file).
    """
    # Imported lazily to avoid a circular import between this module and sd_models.
    from modules import sd_models

    ckpt_basename = os.path.basename(checkpoint_info.filename).lower()
    # Per this commit, the config is resolved at runtime via
    # find_checkpoint_config() instead of being stored on CheckpointInfo.
    cfg_basename = os.path.basename(sd_models.find_checkpoint_config(checkpoint_info)).lower()

    # Idiom fix: `x not in y` instead of `not x in y`; behavior is identical.
    return "inpainting" in ckpt_basename and "inpainting" not in cfg_basename
...
...
modules/sd_models.py
浏览文件 @
8d8a05a3
...
...
@@ -20,7 +20,7 @@ from modules.sd_hijack_inpainting import do_inpainting_hijack, should_hijack_inp
# Directory (relative to models_path) holding Stable Diffusion checkpoints.
model_dir = "Stable-diffusion"
model_path = os.path.abspath(os.path.join(models_path, model_dir))

# NOTE: the `config` field was removed from CheckpointInfo by this commit;
# the config path is now resolved at runtime (see find_checkpoint_config).
CheckpointInfo = namedtuple("CheckpointInfo", ['filename', 'title', 'hash', 'model_name'])

# title -> CheckpointInfo for every discovered checkpoint.
checkpoints_list = {}
# LRU-style cache of already-loaded checkpoints (insertion-ordered).
checkpoints_loaded = collections.OrderedDict()
...
...
def checkpoint_tiles():
    """Return every known checkpoint title, sorted naturally for UI display."""
    titles = (checkpoint.title for checkpoint in checkpoints_list.values())
    return sorted(titles, key=alphanumeric_key)
def find_checkpoint_config(info):
    """Resolve the model config for a checkpoint at runtime.

    Prefers a `.yaml` file sitting next to the checkpoint and sharing its base
    name; otherwise falls back to the config supplied on the command line.
    """
    sibling_yaml = os.path.splitext(info.filename)[0] + ".yaml"
    if os.path.exists(sibling_yaml):
        return sibling_yaml

    return shared.cmd_opts.config
def
list_models
():
checkpoints_list
.
clear
()
model_list
=
modelloader
.
load_models
(
model_path
=
model_path
,
command_path
=
shared
.
cmd_opts
.
ckpt_dir
,
ext_filter
=
[
".ckpt"
,
".safetensors"
])
...
...
@@ -73,7 +81,7 @@ def list_models():
if
os
.
path
.
exists
(
cmd_ckpt
):
h
=
model_hash
(
cmd_ckpt
)
title
,
short_model_name
=
modeltitle
(
cmd_ckpt
,
h
)
checkpoints_list
[
title
]
=
CheckpointInfo
(
cmd_ckpt
,
title
,
h
,
short_model_name
,
shared
.
cmd_opts
.
config
)
checkpoints_list
[
title
]
=
CheckpointInfo
(
cmd_ckpt
,
title
,
h
,
short_model_name
)
shared
.
opts
.
data
[
'sd_model_checkpoint'
]
=
title
elif
cmd_ckpt
is
not
None
and
cmd_ckpt
!=
shared
.
default_sd_model_file
:
print
(
f
"Checkpoint in --ckpt argument not found (Possible it was moved to
{
model_path
}
:
{
cmd_ckpt
}
"
,
file
=
sys
.
stderr
)
...
...
@@ -81,12 +89,7 @@ def list_models():
h
=
model_hash
(
filename
)
title
,
short_model_name
=
modeltitle
(
filename
,
h
)
basename
,
_
=
os
.
path
.
splitext
(
filename
)
config
=
basename
+
".yaml"
if
not
os
.
path
.
exists
(
config
):
config
=
shared
.
cmd_opts
.
config
checkpoints_list
[
title
]
=
CheckpointInfo
(
filename
,
title
,
h
,
short_model_name
,
config
)
checkpoints_list
[
title
]
=
CheckpointInfo
(
filename
,
title
,
h
,
short_model_name
)
def
get_closet_checkpoint_match
(
searchString
):
...
...
@@ -282,9 +285,10 @@ def enable_midas_autodownload():
def
load_model
(
checkpoint_info
=
None
):
from
modules
import
lowvram
,
sd_hijack
checkpoint_info
=
checkpoint_info
or
select_checkpoint
()
checkpoint_config
=
find_checkpoint_config
(
checkpoint_info
)
if
checkpoint_
info
.
config
!=
shared
.
cmd_opts
.
config
:
print
(
f
"Loading config from:
{
checkpoint_
info
.
config
}
"
)
if
checkpoint_config
!=
shared
.
cmd_opts
.
config
:
print
(
f
"Loading config from:
{
checkpoint_config
}
"
)
if
shared
.
sd_model
:
sd_hijack
.
model_hijack
.
undo_hijack
(
shared
.
sd_model
)
...
...
@@ -292,7 +296,7 @@ def load_model(checkpoint_info=None):
gc
.
collect
()
devices
.
torch_gc
()
sd_config
=
OmegaConf
.
load
(
checkpoint_
info
.
config
)
sd_config
=
OmegaConf
.
load
(
checkpoint_config
)
if
should_hijack_inpainting
(
checkpoint_info
):
# Hardcoded config for now...
...
...
@@ -302,7 +306,7 @@ def load_model(checkpoint_info=None):
sd_config
.
model
.
params
.
finetune_keys
=
None
# Create a "fake" config with a different name so that we know to unload it when switching models.
checkpoint_info
=
checkpoint_info
.
_replace
(
config
=
checkpoint_
info
.
config
.
replace
(
".yaml"
,
"-inpainting.yaml"
))
checkpoint_info
=
checkpoint_info
.
_replace
(
config
=
checkpoint_config
.
replace
(
".yaml"
,
"-inpainting.yaml"
))
if
not
hasattr
(
sd_config
.
model
.
params
,
"use_ema"
):
sd_config
.
model
.
params
.
use_ema
=
False
...
...
@@ -343,11 +347,12 @@ def reload_model_weights(sd_model=None, info=None):
sd_model
=
shared
.
sd_model
current_checkpoint_info
=
sd_model
.
sd_checkpoint_info
checkpoint_config
=
find_checkpoint_config
(
current_checkpoint_info
)
if
sd_model
.
sd_model_checkpoint
==
checkpoint_info
.
filename
:
return
if
sd_model
.
sd_checkpoint_info
.
config
!=
checkpoint_info
.
config
or
should_hijack_inpainting
(
checkpoint_info
)
!=
should_hijack_inpainting
(
sd_model
.
sd_checkpoint_info
):
if
checkpoint_config
!=
find_checkpoint_config
(
checkpoint_info
)
or
should_hijack_inpainting
(
checkpoint_info
)
!=
should_hijack_inpainting
(
sd_model
.
sd_checkpoint_info
):
del
sd_model
checkpoints_loaded
.
clear
()
load_model
(
checkpoint_info
)
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录