Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
Overbill1683
Stable Diffusion Webui
提交
d6fcc6b8
S
Stable Diffusion Webui
项目概览
Overbill1683
/
Stable Diffusion Webui
大约 1 年 前同步成功
通知
1786
Star
81
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
分析
仓库
DevOps
项目成员
Pages
S
Stable Diffusion Webui
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Pages
分析
分析
仓库分析
DevOps
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
提交
体验新版 GitCode,发现更多精彩内容 >>
提交
d6fcc6b8
编写于
10月 11, 2022
作者:
A
AUTOMATIC
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
apply lr schedule to hypernets
上级
12f4f476
变更
4
隐藏空白更改
内联
并排
Showing
4 changed file
with
54 addition
and
45 deletion
+54
-45
modules/hypernetworks/hypernetwork.py
modules/hypernetworks/hypernetwork.py
+15
-4
modules/textual_inversion/learn_schedule.py
modules/textual_inversion/learn_schedule.py
+34
-0
modules/textual_inversion/textual_inversion.py
modules/textual_inversion/textual_inversion.py
+4
-40
modules/ui.py
modules/ui.py
+1
-1
未找到文件。
modules/hypernetworks/hypernetwork.py
浏览文件 @
d6fcc6b8
...
...
@@ -14,6 +14,7 @@ import torch
from
torch
import
einsum
from
einops
import
rearrange
,
repeat
import
modules.textual_inversion.dataset
from
modules.textual_inversion.learn_schedule
import
LearnSchedule
class
HypernetworkModule
(
torch
.
nn
.
Module
):
...
...
@@ -202,8 +203,6 @@ def train_hypernetwork(hypernetwork_name, learn_rate, data_root, log_directory,
for
weight
in
weights
:
weight
.
requires_grad
=
True
optimizer
=
torch
.
optim
.
AdamW
(
weights
,
lr
=
learn_rate
)
losses
=
torch
.
zeros
((
32
,))
last_saved_file
=
"<none>"
...
...
@@ -213,12 +212,24 @@ def train_hypernetwork(hypernetwork_name, learn_rate, data_root, log_directory,
if
ititial_step
>
steps
:
return
hypernetwork
,
filename
schedules
=
iter
(
LearnSchedule
(
learn_rate
,
steps
,
ititial_step
))
(
learn_rate
,
end_step
)
=
next
(
schedules
)
print
(
f
'Training at rate of
{
learn_rate
}
until step
{
end_step
}
'
)
optimizer
=
torch
.
optim
.
AdamW
(
weights
,
lr
=
learn_rate
)
pbar
=
tqdm
.
tqdm
(
enumerate
(
ds
),
total
=
steps
-
ititial_step
)
for
i
,
(
x
,
text
,
cond
)
in
pbar
:
hypernetwork
.
step
=
i
+
ititial_step
if
hypernetwork
.
step
>
steps
:
break
if
hypernetwork
.
step
>
end_step
:
try
:
(
learn_rate
,
end_step
)
=
next
(
schedules
)
except
Exception
:
break
tqdm
.
tqdm
.
write
(
f
'Training at rate of
{
learn_rate
}
until step
{
end_step
}
'
)
for
pg
in
optimizer
.
param_groups
:
pg
[
'lr'
]
=
learn_rate
if
shared
.
state
.
interrupted
:
break
...
...
modules/textual_inversion/learn_schedule.py
0 → 100644
浏览文件 @
d6fcc6b8
class
LearnSchedule
:
def
__init__
(
self
,
learn_rate
,
max_steps
,
cur_step
=
0
):
pairs
=
learn_rate
.
split
(
','
)
self
.
rates
=
[]
self
.
it
=
0
self
.
maxit
=
0
for
i
,
pair
in
enumerate
(
pairs
):
tmp
=
pair
.
split
(
':'
)
if
len
(
tmp
)
==
2
:
step
=
int
(
tmp
[
1
])
if
step
>
cur_step
:
self
.
rates
.
append
((
float
(
tmp
[
0
]),
min
(
step
,
max_steps
)))
self
.
maxit
+=
1
if
step
>
max_steps
:
return
elif
step
==
-
1
:
self
.
rates
.
append
((
float
(
tmp
[
0
]),
max_steps
))
self
.
maxit
+=
1
return
else
:
self
.
rates
.
append
((
float
(
tmp
[
0
]),
max_steps
))
self
.
maxit
+=
1
return
def
__iter__
(
self
):
return
self
def
__next__
(
self
):
if
self
.
it
<
self
.
maxit
:
self
.
it
+=
1
return
self
.
rates
[
self
.
it
-
1
]
else
:
raise
StopIteration
modules/textual_inversion/textual_inversion.py
浏览文件 @
d6fcc6b8
...
...
@@ -10,6 +10,7 @@ import datetime
from
modules
import
shared
,
devices
,
sd_hijack
,
processing
,
sd_models
import
modules.textual_inversion.dataset
from
modules.textual_inversion.learn_schedule
import
LearnSchedule
class
Embedding
:
...
...
@@ -198,11 +199,8 @@ def train_embedding(embedding_name, learn_rate, data_root, log_directory, traini
if
ititial_step
>
steps
:
return
embedding
,
filename
tr_img_len
=
len
([
os
.
path
.
join
(
data_root
,
file_path
)
for
file_path
in
os
.
listdir
(
data_root
)])
epoch_len
=
(
tr_img_len
*
num_repeats
)
+
tr_img_len
scheduleIter
=
iter
(
LearnSchedule
(
learn_rate
,
steps
,
ititial_step
))
(
learn_rate
,
end_step
)
=
next
(
scheduleIter
)
schedules
=
iter
(
LearnSchedule
(
learn_rate
,
steps
,
ititial_step
))
(
learn_rate
,
end_step
)
=
next
(
schedules
)
print
(
f
'Training at rate of
{
learn_rate
}
until step
{
end_step
}
'
)
optimizer
=
torch
.
optim
.
AdamW
([
embedding
.
vec
],
lr
=
learn_rate
)
...
...
@@ -213,7 +211,7 @@ def train_embedding(embedding_name, learn_rate, data_root, log_directory, traini
if
embedding
.
step
>
end_step
:
try
:
(
learn_rate
,
end_step
)
=
next
(
schedule
Iter
)
(
learn_rate
,
end_step
)
=
next
(
schedule
s
)
except
:
break
tqdm
.
tqdm
.
write
(
f
'Training at rate of
{
learn_rate
}
until step
{
end_step
}
'
)
...
...
@@ -288,37 +286,3 @@ Last saved image: {html.escape(last_saved_image)}<br/>
embedding
.
save
(
filename
)
return
embedding
,
filename
class
LearnSchedule
:
def
__init__
(
self
,
learn_rate
,
max_steps
,
cur_step
=
0
):
pairs
=
learn_rate
.
split
(
','
)
self
.
rates
=
[]
self
.
it
=
0
self
.
maxit
=
0
for
i
,
pair
in
enumerate
(
pairs
):
tmp
=
pair
.
split
(
':'
)
if
len
(
tmp
)
==
2
:
step
=
int
(
tmp
[
1
])
if
step
>
cur_step
:
self
.
rates
.
append
((
float
(
tmp
[
0
]),
min
(
step
,
max_steps
)))
self
.
maxit
+=
1
if
step
>
max_steps
:
return
elif
step
==
-
1
:
self
.
rates
.
append
((
float
(
tmp
[
0
]),
max_steps
))
self
.
maxit
+=
1
return
else
:
self
.
rates
.
append
((
float
(
tmp
[
0
]),
max_steps
))
self
.
maxit
+=
1
return
def
__iter__
(
self
):
return
self
def
__next__
(
self
):
if
self
.
it
<
self
.
maxit
:
self
.
it
+=
1
return
self
.
rates
[
self
.
it
-
1
]
else
:
raise
StopIteration
modules/ui.py
浏览文件 @
d6fcc6b8
...
...
@@ -1070,7 +1070,7 @@ def create_ui(wrap_gradio_gpu_call):
gr
.
HTML
(
value
=
"<p style='margin-bottom: 0.7em'>Train an embedding; must specify a directory with a set of 1:1 ratio images</p>"
)
train_embedding_name
=
gr
.
Dropdown
(
label
=
'Embedding'
,
choices
=
sorted
(
sd_hijack
.
model_hijack
.
embedding_db
.
word_embeddings
.
keys
()))
train_hypernetwork_name
=
gr
.
Dropdown
(
label
=
'Hypernetwork'
,
choices
=
[
x
for
x
in
shared
.
hypernetworks
.
keys
()])
learn_rate
=
gr
.
Textbox
(
label
=
'Learning rate'
,
placeholder
=
"Learning rate"
,
value
=
"5.0e-03
"
)
learn_rate
=
gr
.
Textbox
(
label
=
'Learning rate'
,
placeholder
=
"Learning rate"
,
value
=
"0.005
"
)
dataset_directory
=
gr
.
Textbox
(
label
=
'Dataset directory'
,
placeholder
=
"Path to directory with input images"
)
log_directory
=
gr
.
Textbox
(
label
=
'Log directory'
,
placeholder
=
"Path to directory where to write outputs"
,
value
=
"textual_inversion"
)
template_file
=
gr
.
Textbox
(
label
=
'Prompt template file'
,
value
=
os
.
path
.
join
(
script_path
,
"textual_inversion_templates"
,
"style_filewords.txt"
))
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录