Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
Overbill1683
Stable Diffusion Webui
提交
c4bfd20f
S
Stable Diffusion Webui
项目概览
Overbill1683
/
Stable Diffusion Webui
10 个月 前同步成功
通知
1749
Star
81
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
分析
仓库
DevOps
项目成员
Pages
S
Stable Diffusion Webui
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Pages
分析
分析
仓库分析
DevOps
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
提交
体验新版 GitCode,发现更多精彩内容 >>
提交
c4bfd20f
编写于
1月 12, 2023
作者:
S
Shondoit
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Hijack to add weighted_forward to model: return loss * weight map
上级
3715ece0
变更
1
显示空白变更内容
内联
并排
Showing
1 changed file
with
52 addition
and
0 deletion
+52
-0
modules/sd_hijack.py
modules/sd_hijack.py
+52
-0
未找到文件。
modules/sd_hijack.py
浏览文件 @
c4bfd20f
import
torch
from
torch.nn.functional
import
silu
from
types
import
MethodType
import
modules.textual_inversion.textual_inversion
from
modules
import
devices
,
sd_hijack_optimizations
,
shared
,
sd_hijack_checkpoint
...
...
@@ -76,6 +77,54 @@ def fix_checkpoint():
pass
def
weighted_loss
(
sd_model
,
pred
,
target
,
mean
=
True
):
#Calculate the weight normally, but ignore the mean
loss
=
sd_model
.
_old_get_loss
(
pred
,
target
,
mean
=
False
)
#Check if we have weights available
weight
=
getattr
(
sd_model
,
'_custom_loss_weight'
,
None
)
if
weight
is
not
None
:
loss
*=
weight
#Return the loss, as mean if specified
return
loss
.
mean
()
if
mean
else
loss
def
weighted_forward
(
sd_model
,
x
,
c
,
w
,
*
args
,
**
kwargs
):
try
:
#Temporarily append weights to a place accessible during loss calc
sd_model
.
_custom_loss_weight
=
w
#Replace 'get_loss' with a weight-aware one. Otherwise we need to reimplement 'forward' completely
#Keep 'get_loss', but don't overwrite the previous old_get_loss if it's already set
if
not
hasattr
(
sd_model
,
'_old_get_loss'
):
sd_model
.
_old_get_loss
=
sd_model
.
get_loss
sd_model
.
get_loss
=
MethodType
(
weighted_loss
,
sd_model
)
#Run the standard forward function, but with the patched 'get_loss'
return
sd_model
.
forward
(
x
,
c
,
*
args
,
**
kwargs
)
finally
:
try
:
#Delete temporary weights if appended
del
sd_model
.
_custom_loss_weight
except
AttributeError
as
e
:
pass
#If we have an old loss function, reset the loss function to the original one
if
hasattr
(
sd_model
,
'_old_get_loss'
):
sd_model
.
get_loss
=
sd_model
.
_old_get_loss
del
sd_model
.
_old_get_loss
def
apply_weighted_forward
(
sd_model
):
#Add new function 'weighted_forward' that can be called to calc weighted loss
sd_model
.
weighted_forward
=
MethodType
(
weighted_forward
,
sd_model
)
def
undo_weighted_forward
(
sd_model
):
try
:
del
sd_model
.
weighted_forward
except
AttributeError
as
e
:
pass
class
StableDiffusionModelHijack
:
fixes
=
None
comments
=
[]
...
...
@@ -104,6 +153,8 @@ class StableDiffusionModelHijack:
m
.
cond_stage_model
.
model
.
token_embedding
=
EmbeddingsWithFixes
(
m
.
cond_stage_model
.
model
.
token_embedding
,
self
)
m
.
cond_stage_model
=
sd_hijack_open_clip
.
FrozenOpenCLIPEmbedderWithCustomWords
(
m
.
cond_stage_model
,
self
)
apply_weighted_forward
(
m
)
self
.
optimization_method
=
apply_optimizations
()
self
.
clip
=
m
.
cond_stage_model
...
...
@@ -132,6 +183,7 @@ class StableDiffusionModelHijack:
m
.
cond_stage_model
=
m
.
cond_stage_model
.
wrapped
undo_optimizations
()
undo_weighted_forward
(
m
)
self
.
apply_circular
(
False
)
self
.
layers
=
None
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录