Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
Overbill1683
Stable Diffusion Webui
提交
a70dfb64
S
Stable Diffusion Webui
项目概览
Overbill1683
/
Stable Diffusion Webui
大约 1 年 前同步成功
通知
1785
Star
81
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
分析
仓库
DevOps
项目成员
Pages
S
Stable Diffusion Webui
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Pages
分析
分析
仓库分析
DevOps
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
提交
体验新版 GitCode,发现更多精彩内容 >>
提交
a70dfb64
编写于
12月 31, 2023
作者:
A
AUTOMATIC1111
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
change import statements for #14478
上级
be5f1acc
变更
7
隐藏空白更改
内联
并排
Showing
7 changed file
with
14 addition
and
17 deletion
+14
-17
modules/devices.py
modules/devices.py
+2
-2
modules/interrogate.py
modules/interrogate.py
+2
-3
modules/sd_models_xl.py
modules/sd_models_xl.py
+2
-2
modules/upscaler_utils.py
modules/upscaler_utils.py
+2
-3
modules/xlmr.py
modules/xlmr.py
+2
-2
modules/xlmr_m18.py
modules/xlmr_m18.py
+2
-3
test/test_torch_utils.py
test/test_torch_utils.py
+2
-2
未找到文件。
modules/devices.py
浏览文件 @
a70dfb64
...
...
@@ -4,7 +4,7 @@ from functools import lru_cache
import
torch
from
modules
import
errors
,
shared
from
modules
.torch_utils
import
get_param
from
modules
import
torch_utils
if
sys
.
platform
==
"darwin"
:
from
modules
import
mac_specific
...
...
@@ -132,7 +132,7 @@ patch_module_list = [
def
manual_cast_forward
(
self
,
*
args
,
**
kwargs
):
org_dtype
=
get_param
(
self
).
dtype
org_dtype
=
torch_utils
.
get_param
(
self
).
dtype
self
.
to
(
dtype
)
args
=
[
arg
.
to
(
dtype
)
if
isinstance
(
arg
,
torch
.
Tensor
)
else
arg
for
arg
in
args
]
kwargs
=
{
k
:
v
.
to
(
dtype
)
if
isinstance
(
v
,
torch
.
Tensor
)
else
v
for
k
,
v
in
kwargs
.
items
()}
...
...
modules/interrogate.py
浏览文件 @
a70dfb64
...
...
@@ -10,8 +10,7 @@ import torch.hub
from
torchvision
import
transforms
from
torchvision.transforms.functional
import
InterpolationMode
from
modules
import
devices
,
paths
,
shared
,
lowvram
,
modelloader
,
errors
from
modules.torch_utils
import
get_param
from
modules
import
devices
,
paths
,
shared
,
lowvram
,
modelloader
,
errors
,
torch_utils
blip_image_eval_size
=
384
clip_model_name
=
'ViT-L/14'
...
...
@@ -132,7 +131,7 @@ class InterrogateModels:
self
.
clip_model
=
self
.
clip_model
.
to
(
devices
.
device_interrogate
)
self
.
dtype
=
get_param
(
self
.
clip_model
).
dtype
self
.
dtype
=
torch_utils
.
get_param
(
self
.
clip_model
).
dtype
def
send_clip_to_ram
(
self
):
if
not
shared
.
opts
.
interrogate_keep_models_in_memory
:
...
...
modules/sd_models_xl.py
浏览文件 @
a70dfb64
...
...
@@ -6,7 +6,7 @@ import sgm.models.diffusion
import
sgm.modules.diffusionmodules.denoiser_scaling
import
sgm.modules.diffusionmodules.discretizer
from
modules
import
devices
,
shared
,
prompt_parser
from
modules
.torch_utils
import
get_param
from
modules
import
torch_utils
def
get_learned_conditioning
(
self
:
sgm
.
models
.
diffusion
.
DiffusionEngine
,
batch
:
prompt_parser
.
SdConditioning
|
list
[
str
]):
...
...
@@ -91,7 +91,7 @@ sgm.modules.GeneralConditioner.get_target_prompt_token_count = get_target_prompt
def
extend_sdxl
(
model
):
"""this adds a bunch of parameters to make SDXL model look a bit more like SD1.5 to the rest of the codebase."""
dtype
=
get_param
(
model
.
model
.
diffusion_model
).
dtype
dtype
=
torch_utils
.
get_param
(
model
.
model
.
diffusion_model
).
dtype
model
.
model
.
diffusion_model
.
dtype
=
dtype
model
.
model
.
conditioning_key
=
'crossattn'
model
.
cond_stage_key
=
'txt'
...
...
modules/upscaler_utils.py
浏览文件 @
a70dfb64
...
...
@@ -6,8 +6,7 @@ import torch
import
tqdm
from
PIL
import
Image
from
modules
import
images
,
shared
from
modules.torch_utils
import
get_param
from
modules
import
images
,
shared
,
torch_utils
logger
=
logging
.
getLogger
(
__name__
)
...
...
@@ -18,7 +17,7 @@ def upscale_without_tiling(model, img: Image.Image):
img
=
np
.
ascontiguousarray
(
np
.
transpose
(
img
,
(
2
,
0
,
1
)))
/
255
img
=
torch
.
from_numpy
(
img
).
float
()
param
=
get_param
(
model
)
param
=
torch_utils
.
get_param
(
model
)
img
=
img
.
unsqueeze
(
0
).
to
(
device
=
param
.
device
,
dtype
=
param
.
dtype
)
with
torch
.
no_grad
():
...
...
modules/xlmr.py
浏览文件 @
a70dfb64
...
...
@@ -5,7 +5,7 @@ from transformers.models.xlm_roberta.configuration_xlm_roberta import XLMRoberta
from
transformers
import
XLMRobertaModel
,
XLMRobertaTokenizer
from
typing
import
Optional
from
modules
.torch_utils
import
get_param
from
modules
import
torch_utils
class
BertSeriesConfig
(
BertConfig
):
...
...
@@ -65,7 +65,7 @@ class BertSeriesModelWithTransformation(BertPreTrainedModel):
self
.
post_init
()
def
encode
(
self
,
c
):
device
=
get_param
(
self
).
device
device
=
torch_utils
.
get_param
(
self
).
device
text
=
self
.
tokenizer
(
c
,
truncation
=
True
,
max_length
=
77
,
...
...
modules/xlmr_m18.py
浏览文件 @
a70dfb64
...
...
@@ -4,8 +4,7 @@ import torch
from
transformers.models.xlm_roberta.configuration_xlm_roberta
import
XLMRobertaConfig
from
transformers
import
XLMRobertaModel
,
XLMRobertaTokenizer
from
typing
import
Optional
from
modules.torch_utils
import
get_param
from
modules
import
torch_utils
class
BertSeriesConfig
(
BertConfig
):
...
...
@@ -71,7 +70,7 @@ class BertSeriesModelWithTransformation(BertPreTrainedModel):
self
.
post_init
()
def
encode
(
self
,
c
):
device
=
get_param
(
self
).
device
device
=
torch_utils
.
get_param
(
self
).
device
text
=
self
.
tokenizer
(
c
,
truncation
=
True
,
max_length
=
77
,
...
...
test/test_torch_utils.py
浏览文件 @
a70dfb64
...
...
@@ -3,7 +3,7 @@ import types
import
pytest
import
torch
from
modules
.torch_utils
import
get_param
from
modules
import
torch_utils
@
pytest
.
mark
.
parametrize
(
"wrapped"
,
[
True
,
False
])
...
...
@@ -14,6 +14,6 @@ def test_get_param(wrapped):
if
wrapped
:
# more or less how spandrel wraps a thing
mod
=
types
.
SimpleNamespace
(
model
=
mod
)
p
=
get_param
(
mod
)
p
=
torch_utils
.
get_param
(
mod
)
assert
p
.
dtype
==
torch
.
float16
assert
p
.
device
==
cpu
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录