Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
Overbill1683
Stable Diffusion Webui
提交
5768afc7
S
Stable Diffusion Webui
项目概览
Overbill1683
/
Stable Diffusion Webui
大约 1 年 前同步成功
通知
1785
Star
81
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
分析
仓库
DevOps
项目成员
Pages
S
Stable Diffusion Webui
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Pages
分析
分析
仓库分析
DevOps
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
提交
体验新版 GitCode,发现更多精彩内容 >>
提交
5768afc7
编写于
12月 31, 2023
作者:
A
Aarni Koskela
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Add utility to inspect a model's parameters (to get dtype/device)
上级
a84e8421
变更
8
隐藏空白更改
内联
并排
Showing
8 changed file
with
53 addition
and
7 deletion
+53
-7
modules/devices.py
modules/devices.py
+2
-1
modules/interrogate.py
modules/interrogate.py
+2
-1
modules/sd_models_xl.py
modules/sd_models_xl.py
+2
-1
modules/torch_utils.py
modules/torch_utils.py
+17
-0
modules/upscaler_utils.py
modules/upscaler_utils.py
+3
-2
modules/xlmr.py
modules/xlmr.py
+4
-1
modules/xlmr_m18.py
modules/xlmr_m18.py
+4
-1
test/test_torch_utils.py
test/test_torch_utils.py
+19
-0
未找到文件。
modules/devices.py
浏览文件 @
5768afc7
...
...
@@ -4,6 +4,7 @@ from functools import lru_cache
import
torch
from
modules
import
errors
,
shared
from
modules.torch_utils
import
get_param
if
sys
.
platform
==
"darwin"
:
from
modules
import
mac_specific
...
...
@@ -131,7 +132,7 @@ patch_module_list = [
def
manual_cast_forward
(
self
,
*
args
,
**
kwargs
):
org_dtype
=
next
(
self
.
parameters
()
).
dtype
org_dtype
=
get_param
(
self
).
dtype
self
.
to
(
dtype
)
args
=
[
arg
.
to
(
dtype
)
if
isinstance
(
arg
,
torch
.
Tensor
)
else
arg
for
arg
in
args
]
kwargs
=
{
k
:
v
.
to
(
dtype
)
if
isinstance
(
v
,
torch
.
Tensor
)
else
v
for
k
,
v
in
kwargs
.
items
()}
...
...
modules/interrogate.py
浏览文件 @
5768afc7
...
...
@@ -11,6 +11,7 @@ from torchvision import transforms
from
torchvision.transforms.functional
import
InterpolationMode
from
modules
import
devices
,
paths
,
shared
,
lowvram
,
modelloader
,
errors
from
modules.torch_utils
import
get_param
blip_image_eval_size
=
384
clip_model_name
=
'ViT-L/14'
...
...
@@ -131,7 +132,7 @@ class InterrogateModels:
self
.
clip_model
=
self
.
clip_model
.
to
(
devices
.
device_interrogate
)
self
.
dtype
=
next
(
self
.
clip_model
.
parameters
()
).
dtype
self
.
dtype
=
get_param
(
self
.
clip_model
).
dtype
def
send_clip_to_ram
(
self
):
if
not
shared
.
opts
.
interrogate_keep_models_in_memory
:
...
...
modules/sd_models_xl.py
浏览文件 @
5768afc7
...
...
@@ -6,6 +6,7 @@ import sgm.models.diffusion
import
sgm.modules.diffusionmodules.denoiser_scaling
import
sgm.modules.diffusionmodules.discretizer
from
modules
import
devices
,
shared
,
prompt_parser
from
modules.torch_utils
import
get_param
def
get_learned_conditioning
(
self
:
sgm
.
models
.
diffusion
.
DiffusionEngine
,
batch
:
prompt_parser
.
SdConditioning
|
list
[
str
]):
...
...
@@ -90,7 +91,7 @@ sgm.modules.GeneralConditioner.get_target_prompt_token_count = get_target_prompt
def
extend_sdxl
(
model
):
"""this adds a bunch of parameters to make SDXL model look a bit more like SD1.5 to the rest of the codebase."""
dtype
=
next
(
model
.
model
.
diffusion_model
.
parameters
()
).
dtype
dtype
=
get_param
(
model
.
model
.
diffusion_model
).
dtype
model
.
model
.
diffusion_model
.
dtype
=
dtype
model
.
model
.
conditioning_key
=
'crossattn'
model
.
cond_stage_key
=
'txt'
...
...
modules/torch_utils.py
0 → 100644
浏览文件 @
5768afc7
from
__future__
import
annotations
import
torch.nn
def
get_param
(
model
)
->
torch
.
nn
.
Parameter
:
"""
Find the first parameter in a model or module.
"""
if
hasattr
(
model
,
"model"
)
and
hasattr
(
model
.
model
,
"parameters"
):
# Unpeel a model descriptor to get at the actual Torch module.
model
=
model
.
model
for
param
in
model
.
parameters
():
return
param
raise
ValueError
(
f
"No parameters found in model
{
model
!
r
}
"
)
modules/upscaler_utils.py
浏览文件 @
5768afc7
...
...
@@ -7,6 +7,7 @@ import tqdm
from
PIL
import
Image
from
modules
import
images
,
shared
from
modules.torch_utils
import
get_param
logger
=
logging
.
getLogger
(
__name__
)
...
...
@@ -17,8 +18,8 @@ def upscale_without_tiling(model, img: Image.Image):
img
=
np
.
ascontiguousarray
(
np
.
transpose
(
img
,
(
2
,
0
,
1
)))
/
255
img
=
torch
.
from_numpy
(
img
).
float
()
model_weight
=
next
(
iter
(
model
.
model
.
parameters
())
)
img
=
img
.
unsqueeze
(
0
).
to
(
device
=
model_weight
.
device
,
dtype
=
model_weight
.
dtype
)
param
=
get_param
(
model
)
img
=
img
.
unsqueeze
(
0
).
to
(
device
=
param
.
device
,
dtype
=
param
.
dtype
)
with
torch
.
no_grad
():
output
=
model
(
img
)
...
...
modules/xlmr.py
浏览文件 @
5768afc7
...
...
@@ -5,6 +5,9 @@ from transformers.models.xlm_roberta.configuration_xlm_roberta import XLMRoberta
from
transformers
import
XLMRobertaModel
,
XLMRobertaTokenizer
from
typing
import
Optional
from
modules.torch_utils
import
get_param
class
BertSeriesConfig
(
BertConfig
):
def
__init__
(
self
,
vocab_size
=
30522
,
hidden_size
=
768
,
num_hidden_layers
=
12
,
num_attention_heads
=
12
,
intermediate_size
=
3072
,
hidden_act
=
"gelu"
,
hidden_dropout_prob
=
0.1
,
attention_probs_dropout_prob
=
0.1
,
max_position_embeddings
=
512
,
type_vocab_size
=
2
,
initializer_range
=
0.02
,
layer_norm_eps
=
1e-12
,
pad_token_id
=
0
,
position_embedding_type
=
"absolute"
,
use_cache
=
True
,
classifier_dropout
=
None
,
project_dim
=
512
,
pooler_fn
=
"average"
,
learn_encoder
=
False
,
model_type
=
'bert'
,
**
kwargs
):
...
...
@@ -62,7 +65,7 @@ class BertSeriesModelWithTransformation(BertPreTrainedModel):
self
.
post_init
()
def
encode
(
self
,
c
):
device
=
next
(
self
.
parameters
()
).
device
device
=
get_param
(
self
).
device
text
=
self
.
tokenizer
(
c
,
truncation
=
True
,
max_length
=
77
,
...
...
modules/xlmr_m18.py
浏览文件 @
5768afc7
...
...
@@ -5,6 +5,9 @@ from transformers.models.xlm_roberta.configuration_xlm_roberta import XLMRoberta
from
transformers
import
XLMRobertaModel
,
XLMRobertaTokenizer
from
typing
import
Optional
from
modules.torch_utils
import
get_param
class
BertSeriesConfig
(
BertConfig
):
def
__init__
(
self
,
vocab_size
=
30522
,
hidden_size
=
768
,
num_hidden_layers
=
12
,
num_attention_heads
=
12
,
intermediate_size
=
3072
,
hidden_act
=
"gelu"
,
hidden_dropout_prob
=
0.1
,
attention_probs_dropout_prob
=
0.1
,
max_position_embeddings
=
512
,
type_vocab_size
=
2
,
initializer_range
=
0.02
,
layer_norm_eps
=
1e-12
,
pad_token_id
=
0
,
position_embedding_type
=
"absolute"
,
use_cache
=
True
,
classifier_dropout
=
None
,
project_dim
=
512
,
pooler_fn
=
"average"
,
learn_encoder
=
False
,
model_type
=
'bert'
,
**
kwargs
):
...
...
@@ -68,7 +71,7 @@ class BertSeriesModelWithTransformation(BertPreTrainedModel):
self
.
post_init
()
def
encode
(
self
,
c
):
device
=
next
(
self
.
parameters
()
).
device
device
=
get_param
(
self
).
device
text
=
self
.
tokenizer
(
c
,
truncation
=
True
,
max_length
=
77
,
...
...
test/test_torch_utils.py
0 → 100644
浏览文件 @
5768afc7
import
types
import
pytest
import
torch
from
modules.torch_utils
import
get_param
@
pytest
.
mark
.
parametrize
(
"wrapped"
,
[
True
,
False
])
def
test_get_param
(
wrapped
):
mod
=
torch
.
nn
.
Linear
(
1
,
1
)
cpu
=
torch
.
device
(
"cpu"
)
mod
.
to
(
dtype
=
torch
.
float16
,
device
=
cpu
)
if
wrapped
:
# more or less how spandrel wraps a thing
mod
=
types
.
SimpleNamespace
(
model
=
mod
)
p
=
get_param
(
mod
)
assert
p
.
dtype
==
torch
.
float16
assert
p
.
device
==
cpu
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录