Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
Overbill1683
Stable Diffusion Webui
提交
c24a314c
S
Stable Diffusion Webui
项目概览
Overbill1683
/
Stable Diffusion Webui
11 个月 前同步成功
通知
1759
Star
81
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
分析
仓库
DevOps
项目成员
Pages
S
Stable Diffusion Webui
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Pages
分析
分析
仓库分析
DevOps
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
提交
体验新版 GitCode,发现更多精彩内容 >>
未验证
提交
c24a314c
编写于
12月 31, 2022
作者:
A
AUTOMATIC1111
提交者:
GitHub
12月 31, 2022
浏览文件
操作
浏览文件
下载
差异文件
Merge pull request #6149 from vladmandic/validate-embeddings
validate textual inversion embeddings
上级
f378b8d5
f55ac33d
变更
3
隐藏空白更改
内联
并排
Showing
3 changed file
with
41 addition
and
7 deletion
+41
-7
modules/sd_models.py
modules/sd_models.py
+3
-0
modules/textual_inversion/textual_inversion.py
modules/textual_inversion/textual_inversion.py
+38
-5
modules/ui.py
modules/ui.py
+0
-2
未找到文件。
modules/sd_models.py
浏览文件 @
c24a314c
...
...
@@ -325,6 +325,9 @@ def load_model(checkpoint_info=None):
script_callbacks
.
model_loaded_callback
(
sd_model
)
print
(
"Model loaded."
)
sd_hijack
.
model_hijack
.
embedding_db
.
load_textual_inversion_embeddings
(
force_reload
=
True
)
# Reload embeddings after model load as they may or may not fit the model
return
sd_model
...
...
modules/textual_inversion/textual_inversion.py
浏览文件 @
c24a314c
...
...
@@ -23,6 +23,8 @@ class Embedding:
self
.
vec
=
vec
self
.
name
=
name
self
.
step
=
step
self
.
shape
=
None
self
.
vectors
=
0
self
.
cached_checksum
=
None
self
.
sd_checkpoint
=
None
self
.
sd_checkpoint_name
=
None
...
...
@@ -57,8 +59,10 @@ class EmbeddingDatabase:
def
__init__
(
self
,
embeddings_dir
):
self
.
ids_lookup
=
{}
self
.
word_embeddings
=
{}
self
.
skipped_embeddings
=
[]
self
.
dir_mtime
=
None
self
.
embeddings_dir
=
embeddings_dir
self
.
expected_shape
=
-
1
def
register_embedding
(
self
,
embedding
,
model
):
...
...
@@ -75,14 +79,35 @@ class EmbeddingDatabase:
return
embedding
def
load_textual_inversion_embeddings
(
self
):
def
get_expected_shape
(
self
):
expected_shape
=
-
1
# initialize with unknown
idx
=
torch
.
tensor
(
0
).
to
(
shared
.
device
)
if
expected_shape
==
-
1
:
try
:
# matches sd15 signature
first_embedding
=
shared
.
sd_model
.
cond_stage_model
.
wrapped
.
transformer
.
text_model
.
embeddings
.
token_embedding
.
wrapped
(
idx
)
expected_shape
=
first_embedding
.
shape
[
0
]
except
:
pass
if
expected_shape
==
-
1
:
try
:
# matches sd20 signature
first_embedding
=
shared
.
sd_model
.
cond_stage_model
.
wrapped
.
model
.
token_embedding
.
wrapped
(
idx
)
expected_shape
=
first_embedding
.
shape
[
0
]
except
:
pass
if
expected_shape
==
-
1
:
print
(
'Could not determine expected embeddings shape from model'
)
return
expected_shape
def
load_textual_inversion_embeddings
(
self
,
force_reload
=
False
):
mt
=
os
.
path
.
getmtime
(
self
.
embeddings_dir
)
if
self
.
dir_mtime
is
not
None
and
mt
<=
self
.
dir_mtime
:
if
not
force_reload
and
self
.
dir_mtime
is
not
None
and
mt
<=
self
.
dir_mtime
:
return
self
.
dir_mtime
=
mt
self
.
ids_lookup
.
clear
()
self
.
word_embeddings
.
clear
()
self
.
skipped_embeddings
=
[]
self
.
expected_shape
=
self
.
get_expected_shape
()
def
process_file
(
path
,
filename
):
name
=
os
.
path
.
splitext
(
filename
)[
0
]
...
...
@@ -122,7 +147,14 @@ class EmbeddingDatabase:
embedding
.
step
=
data
.
get
(
'step'
,
None
)
embedding
.
sd_checkpoint
=
data
.
get
(
'sd_checkpoint'
,
None
)
embedding
.
sd_checkpoint_name
=
data
.
get
(
'sd_checkpoint_name'
,
None
)
self
.
register_embedding
(
embedding
,
shared
.
sd_model
)
embedding
.
vectors
=
vec
.
shape
[
0
]
embedding
.
shape
=
vec
.
shape
[
-
1
]
if
(
self
.
expected_shape
==
-
1
)
or
(
self
.
expected_shape
==
embedding
.
shape
):
self
.
register_embedding
(
embedding
,
shared
.
sd_model
)
else
:
self
.
skipped_embeddings
.
append
(
name
)
# print('Skipping embedding {name}: shape was {shape} expected {expected}'.format(name = name, shape = embedding.shape, expected = self.expected_shape))
for
fn
in
os
.
listdir
(
self
.
embeddings_dir
):
try
:
...
...
@@ -137,8 +169,9 @@ class EmbeddingDatabase:
print
(
traceback
.
format_exc
(),
file
=
sys
.
stderr
)
continue
print
(
f
"Loaded a total of
{
len
(
self
.
word_embeddings
)
}
textual inversion embeddings."
)
print
(
"Embeddings:"
,
', '
.
join
(
self
.
word_embeddings
.
keys
()))
print
(
"Textual inversion embeddings {num} loaded: {val}"
.
format
(
num
=
len
(
self
.
word_embeddings
),
val
=
', '
.
join
(
self
.
word_embeddings
.
keys
())))
if
(
len
(
self
.
skipped_embeddings
)
>
0
):
print
(
"Textual inversion embeddings {num} skipped: {val}"
.
format
(
num
=
len
(
self
.
skipped_embeddings
),
val
=
', '
.
join
(
self
.
skipped_embeddings
)))
def
find_embedding_at_position
(
self
,
tokens
,
offset
):
token
=
tokens
[
offset
]
...
...
modules/ui.py
浏览文件 @
c24a314c
...
...
@@ -1157,8 +1157,6 @@ def create_ui():
with
gr
.
Column
(
variant
=
'panel'
):
submit_result
=
gr
.
Textbox
(
elem_id
=
"modelmerger_result"
,
show_label
=
False
)
sd_hijack
.
model_hijack
.
embedding_db
.
load_textual_inversion_embeddings
()
with
gr
.
Blocks
(
analytics_enabled
=
False
)
as
train_interface
:
with
gr
.
Row
().
style
(
equal_height
=
False
):
gr
.
HTML
(
value
=
"<p style='margin-bottom: 0.7em'>See <b><a href=
\"
https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Textual-Inversion
\"
>wiki</a></b> for detailed explanation.</p>"
)
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录