Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
Overbill1683
Stable Diffusion Webui
提交
07be13ca
S
Stable Diffusion Webui
项目概览
Overbill1683
/
Stable Diffusion Webui
10 个月 前同步成功
通知
1748
Star
81
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
分析
仓库
DevOps
项目成员
Pages
S
Stable Diffusion Webui
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Pages
分析
分析
仓库分析
DevOps
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
提交
体验新版 GitCode,发现更多精彩内容 >>
提交
07be13ca
编写于
8月 01, 2023
作者:
A
AUTOMATIC1111
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
add metadata to checkpoint merger
上级
6d3a0c95
变更
3
隐藏空白更改
内联
并排
Showing
3 changed file
with
52 addition
and
9 deletion
+52
-9
modules/extras.py
modules/extras.py
+33
-6
modules/sd_models.py
modules/sd_models.py
+1
-1
modules/ui_checkpoint_merger.py
modules/ui_checkpoint_merger.py
+18
-2
未找到文件。
modules/extras.py
浏览文件 @
07be13ca
...
...
@@ -7,7 +7,7 @@ import json
import
torch
import
tqdm
from
modules
import
shared
,
images
,
sd_models
,
sd_vae
,
sd_models_config
from
modules
import
shared
,
images
,
sd_models
,
sd_vae
,
sd_models_config
,
errors
from
modules.ui_common
import
plaintext_to_html
import
gradio
as
gr
import
safetensors.torch
...
...
@@ -72,7 +72,20 @@ def to_half(tensor, enable):
return
tensor
def
run_modelmerger
(
id_task
,
primary_model_name
,
secondary_model_name
,
tertiary_model_name
,
interp_method
,
multiplier
,
save_as_half
,
custom_name
,
checkpoint_format
,
config_source
,
bake_in_vae
,
discard_weights
,
save_metadata
):
def
read_metadata
(
primary_model_name
,
secondary_model_name
,
tertiary_model_name
):
metadata
=
{}
for
checkpoint_name
in
[
primary_model_name
,
secondary_model_name
,
tertiary_model_name
]:
checkpoint_info
=
sd_models
.
checkpoints_list
.
get
(
checkpoint_name
,
None
)
if
checkpoint_info
is
None
:
continue
metadata
.
update
(
checkpoint_info
.
metadata
)
return
json
.
dumps
(
metadata
,
indent
=
4
,
ensure_ascii
=
False
)
def
run_modelmerger
(
id_task
,
primary_model_name
,
secondary_model_name
,
tertiary_model_name
,
interp_method
,
multiplier
,
save_as_half
,
custom_name
,
checkpoint_format
,
config_source
,
bake_in_vae
,
discard_weights
,
save_metadata
,
add_merge_recipe
,
copy_metadata_fields
,
metadata_json
):
shared
.
state
.
begin
(
job
=
"model-merge"
)
def
fail
(
message
):
...
...
@@ -241,11 +254,25 @@ def run_modelmerger(id_task, primary_model_name, secondary_model_name, tertiary_
shared
.
state
.
textinfo
=
"Saving"
print
(
f
"Saving to
{
output_modelname
}
..."
)
metadata
=
None
metadata
=
{}
if
save_metadata
and
copy_metadata_fields
:
if
primary_model_info
:
metadata
.
update
(
primary_model_info
.
metadata
)
if
secondary_model_info
:
metadata
.
update
(
secondary_model_info
.
metadata
)
if
tertiary_model_info
:
metadata
.
update
(
tertiary_model_info
.
metadata
)
if
save_metadata
:
metadata
=
{
"format"
:
"pt"
}
try
:
metadata
.
update
(
json
.
loads
(
metadata_json
))
except
Exception
as
e
:
errors
.
display
(
e
,
"readin metadata from json"
)
metadata
[
"format"
]
=
"pt"
if
save_metadata
and
add_merge_recipe
:
merge_recipe
=
{
"type"
:
"webui"
,
# indicate this model was merged with webui's built-in merger
"primary_model_hash"
:
primary_model_info
.
sha256
,
...
...
@@ -261,7 +288,6 @@ def run_modelmerger(id_task, primary_model_name, secondary_model_name, tertiary_
"is_inpainting"
:
result_is_inpainting_model
,
"is_instruct_pix2pix"
:
result_is_instruct_pix2pix_model
}
metadata
[
"sd_merge_recipe"
]
=
json
.
dumps
(
merge_recipe
)
sd_merge_models
=
{}
...
...
@@ -281,11 +307,12 @@ def run_modelmerger(id_task, primary_model_name, secondary_model_name, tertiary_
if
tertiary_model_info
:
add_model_metadata
(
tertiary_model_info
)
metadata
[
"sd_merge_recipe"
]
=
json
.
dumps
(
merge_recipe
)
metadata
[
"sd_merge_models"
]
=
json
.
dumps
(
sd_merge_models
)
_
,
extension
=
os
.
path
.
splitext
(
output_modelname
)
if
extension
.
lower
()
==
".safetensors"
:
safetensors
.
torch
.
save_file
(
theta_0
,
output_modelname
,
metadata
=
metadata
)
safetensors
.
torch
.
save_file
(
theta_0
,
output_modelname
,
metadata
=
metadata
if
len
(
metadata
)
>
0
else
None
)
else
:
torch
.
save
(
theta_0
,
output_modelname
)
...
...
modules/sd_models.py
浏览文件 @
07be13ca
...
...
@@ -85,7 +85,7 @@ class CheckpointInfo:
if
self
.
shorthash
not
in
self
.
ids
:
self
.
ids
+=
[
self
.
shorthash
,
self
.
sha256
,
f
'
{
self
.
name
}
[
{
self
.
shorthash
}
]'
]
checkpoints_list
.
pop
(
self
.
title
)
checkpoints_list
.
pop
(
self
.
title
,
None
)
self
.
title
=
f
'
{
self
.
name
}
[
{
self
.
shorthash
}
]'
self
.
register
()
...
...
modules/ui_checkpoint_merger.py
浏览文件 @
07be13ca
...
...
@@ -51,7 +51,6 @@ class UiCheckpointMerger:
with
FormRow
():
self
.
checkpoint_format
=
gr
.
Radio
(
choices
=
[
"ckpt"
,
"safetensors"
],
value
=
"safetensors"
,
label
=
"Checkpoint format"
,
elem_id
=
"modelmerger_checkpoint_format"
)
self
.
save_as_half
=
gr
.
Checkbox
(
value
=
False
,
label
=
"Save as float16"
,
elem_id
=
"modelmerger_save_as_half"
)
self
.
save_metadata
=
gr
.
Checkbox
(
value
=
True
,
label
=
"Save metadata (.safetensors only)"
,
elem_id
=
"modelmerger_save_metadata"
)
with
FormRow
():
with
gr
.
Column
():
...
...
@@ -65,16 +64,30 @@ class UiCheckpointMerger:
with
FormRow
():
self
.
discard_weights
=
gr
.
Textbox
(
value
=
""
,
label
=
"Discard weights with matching name"
,
elem_id
=
"modelmerger_discard_weights"
)
with
gr
.
Row
():
with
gr
.
Accordion
(
"Metadata"
,
open
=
False
)
as
metadata_editor
:
with
FormRow
():
self
.
save_metadata
=
gr
.
Checkbox
(
value
=
True
,
label
=
"Save metadata"
,
elem_id
=
"modelmerger_save_metadata"
)
self
.
add_merge_recipe
=
gr
.
Checkbox
(
value
=
True
,
label
=
"Add merge recipe metadata"
,
elem_id
=
"modelmerger_add_recipe"
)
self
.
copy_metadata_fields
=
gr
.
Checkbox
(
value
=
True
,
label
=
"Copy metadata from merged models"
,
elem_id
=
"modelmerger_copy_metadata"
)
self
.
metadata_json
=
gr
.
TextArea
(
'{}'
,
label
=
"Metadata in JSON format"
)
self
.
read_metadata
=
gr
.
Button
(
"Read metadata from selected checkpoints"
)
with
FormRow
():
self
.
modelmerger_merge
=
gr
.
Button
(
elem_id
=
"modelmerger_merge"
,
value
=
"Merge"
,
variant
=
'primary'
)
with
gr
.
Column
(
variant
=
'compact'
,
elem_id
=
"modelmerger_results_container"
):
with
gr
.
Group
(
elem_id
=
"modelmerger_results_panel"
):
self
.
modelmerger_result
=
gr
.
HTML
(
elem_id
=
"modelmerger_result"
,
show_label
=
False
)
self
.
metadata_editor
=
metadata_editor
self
.
blocks
=
modelmerger_interface
def
setup_ui
(
self
,
dummy_component
,
sd_model_checkpoint_component
):
self
.
checkpoint_format
.
change
(
lambda
fmt
:
gr
.
update
(
visible
=
fmt
==
'safetensors'
),
inputs
=
[
self
.
checkpoint_format
],
outputs
=
[
self
.
metadata_editor
],
show_progress
=
False
)
self
.
read_metadata
.
click
(
extras
.
read_metadata
,
inputs
=
[
self
.
primary_model_name
,
self
.
secondary_model_name
,
self
.
tertiary_model_name
],
outputs
=
[
self
.
metadata_json
])
self
.
modelmerger_merge
.
click
(
fn
=
lambda
:
''
,
inputs
=
[],
outputs
=
[
self
.
modelmerger_result
])
self
.
modelmerger_merge
.
click
(
fn
=
call_queue
.
wrap_gradio_gpu_call
(
modelmerger
,
extra_outputs
=
lambda
:
[
gr
.
update
()
for
_
in
range
(
4
)]),
...
...
@@ -93,6 +106,9 @@ class UiCheckpointMerger:
self
.
bake_in_vae
,
self
.
discard_weights
,
self
.
save_metadata
,
self
.
add_merge_recipe
,
self
.
copy_metadata_fields
,
self
.
metadata_json
,
],
outputs
=
[
self
.
primary_model_name
,
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录