Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
Overbill1683
Stable Diffusion Webui
提交
c1512ef9
S
Stable Diffusion Webui
项目概览
Overbill1683
/
Stable Diffusion Webui
11 个月 前同步成功
通知
1759
Star
81
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
分析
仓库
DevOps
项目成员
Pages
S
Stable Diffusion Webui
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Pages
分析
分析
仓库分析
DevOps
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
提交
体验新版 GitCode,发现更多精彩内容 >>
未验证
提交
c1512ef9
编写于
12月 25, 2022
作者:
A
AUTOMATIC1111
提交者:
GitHub
12月 25, 2022
浏览文件
操作
浏览文件
下载
差异文件
Merge pull request #5999 from vladmandic/trainapi
implement train api
上级
8eef9d8e
5f1dfbbc
变更
4
隐藏空白更改
内联
并排
Showing
4 changed file
with
132 addition
and
28 deletion
+132
-28
modules/api/api.py
modules/api/api.py
+93
-1
modules/api/models.py
modules/api/models.py
+9
-0
modules/hypernetworks/hypernetwork.py
modules/hypernetworks/hypernetwork.py
+26
-0
modules/hypernetworks/ui.py
modules/hypernetworks/ui.py
+4
-27
未找到文件。
modules/api/api.py
浏览文件 @
c1512ef9
...
...
@@ -10,13 +10,17 @@ from fastapi.security import HTTPBasic, HTTPBasicCredentials
from
secrets
import
compare_digest
import
modules.shared
as
shared
from
modules
import
sd_samplers
,
deepbooru
from
modules
import
sd_samplers
,
deepbooru
,
sd_hijack
from
modules.api.models
import
*
from
modules.processing
import
StableDiffusionProcessingTxt2Img
,
StableDiffusionProcessingImg2Img
,
process_images
from
modules.extras
import
run_extras
,
run_pnginfo
from
modules.textual_inversion.textual_inversion
import
create_embedding
,
train_embedding
from
modules.textual_inversion.preprocess
import
preprocess
from
modules.hypernetworks.hypernetwork
import
create_hypernetwork
,
train_hypernetwork
from
PIL
import
PngImagePlugin
,
Image
from
modules.sd_models
import
checkpoints_list
from
modules.realesrgan_model
import
get_realesrgan_models
from
modules
import
devices
from
typing
import
List
def
upscaler_to_index
(
name
:
str
):
...
...
@@ -97,6 +101,11 @@ class Api:
self
.
add_api_route
(
"/sdapi/v1/artist-categories"
,
self
.
get_artists_categories
,
methods
=
[
"GET"
],
response_model
=
List
[
str
])
self
.
add_api_route
(
"/sdapi/v1/artists"
,
self
.
get_artists
,
methods
=
[
"GET"
],
response_model
=
List
[
ArtistItem
])
self
.
add_api_route
(
"/sdapi/v1/refresh-checkpoints"
,
self
.
refresh_checkpoints
,
methods
=
[
"POST"
])
self
.
add_api_route
(
"/sdapi/v1/create/embedding"
,
self
.
create_embedding
,
methods
=
[
"POST"
],
response_model
=
CreateResponse
)
self
.
add_api_route
(
"/sdapi/v1/create/hypernetwork"
,
self
.
create_hypernetwork
,
methods
=
[
"POST"
],
response_model
=
CreateResponse
)
self
.
add_api_route
(
"/sdapi/v1/preprocess"
,
self
.
preprocess
,
methods
=
[
"POST"
],
response_model
=
PreprocessResponse
)
self
.
add_api_route
(
"/sdapi/v1/train/embedding"
,
self
.
train_embedding
,
methods
=
[
"POST"
],
response_model
=
TrainResponse
)
self
.
add_api_route
(
"/sdapi/v1/train/hypernetwork"
,
self
.
train_hypernetwork
,
methods
=
[
"POST"
],
response_model
=
TrainResponse
)
def
add_api_route
(
self
,
path
:
str
,
endpoint
,
**
kwargs
):
if
shared
.
cmd_opts
.
api_auth
:
...
...
@@ -326,6 +335,89 @@ class Api:
def
refresh_checkpoints
(
self
):
shared
.
refresh_checkpoints
()
def
create_embedding
(
self
,
args
:
dict
):
try
:
shared
.
state
.
begin
()
filename
=
create_embedding
(
**
args
)
# create empty embedding
sd_hijack
.
model_hijack
.
embedding_db
.
load_textual_inversion_embeddings
()
# reload embeddings so new one can be immediately used
shared
.
state
.
end
()
return
CreateResponse
(
info
=
"create embedding filename: {filename}"
.
format
(
filename
=
filename
))
except
AssertionError
as
e
:
shared
.
state
.
end
()
return
TrainResponse
(
info
=
"create embedding error: {error}"
.
format
(
error
=
e
))
def
create_hypernetwork
(
self
,
args
:
dict
):
try
:
shared
.
state
.
begin
()
filename
=
create_hypernetwork
(
**
args
)
# create empty embedding
shared
.
state
.
end
()
return
CreateResponse
(
info
=
"create hypernetwork filename: {filename}"
.
format
(
filename
=
filename
))
except
AssertionError
as
e
:
shared
.
state
.
end
()
return
TrainResponse
(
info
=
"create hypernetwork error: {error}"
.
format
(
error
=
e
))
def
preprocess
(
self
,
args
:
dict
):
try
:
shared
.
state
.
begin
()
preprocess
(
**
args
)
# quick operation unless blip/booru interrogation is enabled
shared
.
state
.
end
()
return
PreprocessResponse
(
info
=
'preprocess complete'
)
except
KeyError
as
e
:
shared
.
state
.
end
()
return
PreprocessResponse
(
info
=
"preprocess error: invalid token: {error}"
.
format
(
error
=
e
))
except
AssertionError
as
e
:
shared
.
state
.
end
()
return
PreprocessResponse
(
info
=
"preprocess error: {error}"
.
format
(
error
=
e
))
except
FileNotFoundError
as
e
:
shared
.
state
.
end
()
return
PreprocessResponse
(
info
=
'preprocess error: {error}'
.
format
(
error
=
e
))
def
train_embedding
(
self
,
args
:
dict
):
try
:
shared
.
state
.
begin
()
apply_optimizations
=
shared
.
opts
.
training_xattention_optimizations
error
=
None
filename
=
''
if
not
apply_optimizations
:
sd_hijack
.
undo_optimizations
()
try
:
embedding
,
filename
=
train_embedding
(
**
args
)
# can take a long time to complete
except
Exception
as
e
:
error
=
e
finally
:
if
not
apply_optimizations
:
sd_hijack
.
apply_optimizations
()
shared
.
state
.
end
()
return
TrainResponse
(
info
=
"train embedding complete: filename: {filename} error: {error}"
.
format
(
filename
=
filename
,
error
=
error
))
except
AssertionError
as
msg
:
shared
.
state
.
end
()
return
TrainResponse
(
info
=
"train embedding error: {msg}"
.
format
(
msg
=
msg
))
def
train_hypernetwork
(
self
,
args
:
dict
):
try
:
shared
.
state
.
begin
()
initial_hypernetwork
=
shared
.
loaded_hypernetwork
apply_optimizations
=
shared
.
opts
.
training_xattention_optimizations
error
=
None
filename
=
''
if
not
apply_optimizations
:
sd_hijack
.
undo_optimizations
()
try
:
hypernetwork
,
filename
=
train_hypernetwork
(
*
args
)
except
Exception
as
e
:
error
=
e
finally
:
shared
.
loaded_hypernetwork
=
initial_hypernetwork
shared
.
sd_model
.
cond_stage_model
.
to
(
devices
.
device
)
shared
.
sd_model
.
first_stage_model
.
to
(
devices
.
device
)
if
not
apply_optimizations
:
sd_hijack
.
apply_optimizations
()
shared
.
state
.
end
()
return
TrainResponse
(
info
=
"train embedding complete: filename: {filename} error: {error}"
.
format
(
filename
=
filename
,
error
=
error
))
except
AssertionError
as
msg
:
shared
.
state
.
end
()
return
TrainResponse
(
info
=
"train embedding error: {error}"
.
format
(
error
=
error
))
def
launch
(
self
,
server_name
,
port
):
self
.
app
.
include_router
(
self
.
router
)
uvicorn
.
run
(
self
.
app
,
host
=
server_name
,
port
=
port
)
modules/api/models.py
浏览文件 @
c1512ef9
...
...
@@ -175,6 +175,15 @@ class InterrogateRequest(BaseModel):
class
InterrogateResponse
(
BaseModel
):
caption
:
str
=
Field
(
default
=
None
,
title
=
"Caption"
,
description
=
"The generated caption for the image."
)
class
TrainResponse
(
BaseModel
):
info
:
str
=
Field
(
title
=
"Train info"
,
description
=
"Response string from train embedding or hypernetwork task."
)
class
CreateResponse
(
BaseModel
):
info
:
str
=
Field
(
title
=
"Create info"
,
description
=
"Response string from create embedding or hypernetwork task."
)
class
PreprocessResponse
(
BaseModel
):
info
:
str
=
Field
(
title
=
"Preprocess info"
,
description
=
"Response string from preprocessing task."
)
fields
=
{}
for
key
,
metadata
in
opts
.
data_labels
.
items
():
value
=
opts
.
data
.
get
(
key
)
...
...
modules/hypernetworks/hypernetwork.py
浏览文件 @
c1512ef9
...
...
@@ -378,6 +378,32 @@ def report_statistics(loss_info:dict):
print
(
e
)
def
create_hypernetwork
(
name
,
enable_sizes
,
overwrite_old
,
layer_structure
=
None
,
activation_func
=
None
,
weight_init
=
None
,
add_layer_norm
=
False
,
use_dropout
=
False
):
# Remove illegal characters from name.
name
=
""
.
join
(
x
for
x
in
name
if
(
x
.
isalnum
()
or
x
in
"._- "
))
fn
=
os
.
path
.
join
(
shared
.
cmd_opts
.
hypernetwork_dir
,
f
"
{
name
}
.pt"
)
if
not
overwrite_old
:
assert
not
os
.
path
.
exists
(
fn
),
f
"file
{
fn
}
already exists"
if
type
(
layer_structure
)
==
str
:
layer_structure
=
[
float
(
x
.
strip
())
for
x
in
layer_structure
.
split
(
","
)]
hypernet
=
modules
.
hypernetworks
.
hypernetwork
.
Hypernetwork
(
name
=
name
,
enable_sizes
=
[
int
(
x
)
for
x
in
enable_sizes
],
layer_structure
=
layer_structure
,
activation_func
=
activation_func
,
weight_init
=
weight_init
,
add_layer_norm
=
add_layer_norm
,
use_dropout
=
use_dropout
,
)
hypernet
.
save
(
fn
)
shared
.
reload_hypernetworks
()
return
fn
def
train_hypernetwork
(
hypernetwork_name
,
learn_rate
,
batch_size
,
gradient_step
,
data_root
,
log_directory
,
training_width
,
training_height
,
steps
,
shuffle_tags
,
tag_drop_out
,
latent_sampling_method
,
create_image_every
,
save_hypernetwork_every
,
template_file
,
preview_from_txt2img
,
preview_prompt
,
preview_negative_prompt
,
preview_steps
,
preview_sampler_index
,
preview_cfg_scale
,
preview_seed
,
preview_width
,
preview_height
):
# images allows training previews to have infotext. Importing it at the top causes a circular import problem.
...
...
modules/hypernetworks/ui.py
浏览文件 @
c1512ef9
...
...
@@ -3,39 +3,16 @@ import os
import
re
import
gradio
as
gr
import
modules.textual_inversion.preprocess
import
modules.textual_inversion.textual_inversion
import
modules.hypernetworks.hypernetwork
from
modules
import
devices
,
sd_hijack
,
shared
from
modules.hypernetworks
import
hypernetwork
not_available
=
[
"hardswish"
,
"multiheadattention"
]
keys
=
list
(
x
for
x
in
hypernetwork
.
HypernetworkModule
.
activation_dict
.
keys
()
if
x
not
in
not_available
)
keys
=
list
(
x
for
x
in
modules
.
hypernetworks
.
hypernetwork
.
HypernetworkModule
.
activation_dict
.
keys
()
if
x
not
in
not_available
)
def
create_hypernetwork
(
name
,
enable_sizes
,
overwrite_old
,
layer_structure
=
None
,
activation_func
=
None
,
weight_init
=
None
,
add_layer_norm
=
False
,
use_dropout
=
False
):
# Remove illegal characters from name.
name
=
""
.
join
(
x
for
x
in
name
if
(
x
.
isalnum
()
or
x
in
"._- "
))
filename
=
modules
.
hypernetworks
.
hypernetwork
.
create_hypernetwork
(
name
,
enable_sizes
,
overwrite_old
,
layer_structure
,
activation_func
,
weight_init
,
add_layer_norm
,
use_dropout
)
fn
=
os
.
path
.
join
(
shared
.
cmd_opts
.
hypernetwork_dir
,
f
"
{
name
}
.pt"
)
if
not
overwrite_old
:
assert
not
os
.
path
.
exists
(
fn
),
f
"file
{
fn
}
already exists"
if
type
(
layer_structure
)
==
str
:
layer_structure
=
[
float
(
x
.
strip
())
for
x
in
layer_structure
.
split
(
","
)]
hypernet
=
modules
.
hypernetworks
.
hypernetwork
.
Hypernetwork
(
name
=
name
,
enable_sizes
=
[
int
(
x
)
for
x
in
enable_sizes
],
layer_structure
=
layer_structure
,
activation_func
=
activation_func
,
weight_init
=
weight_init
,
add_layer_norm
=
add_layer_norm
,
use_dropout
=
use_dropout
,
)
hypernet
.
save
(
fn
)
shared
.
reload_hypernetworks
()
return
gr
.
Dropdown
.
update
(
choices
=
sorted
([
x
for
x
in
shared
.
hypernetworks
.
keys
()])),
f
"Created:
{
fn
}
"
,
""
return
gr
.
Dropdown
.
update
(
choices
=
sorted
([
x
for
x
in
shared
.
hypernetworks
.
keys
()])),
f
"Created:
{
filename
}
"
,
""
def
train_hypernetwork
(
*
args
):
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录