Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
Greenplum
Pytorch Widedeep
提交
9acfceca
P
Pytorch Widedeep
项目概览
Greenplum
/
Pytorch Widedeep
大约 1 年 前同步成功
通知
9
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Pytorch Widedeep
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
体验新版 GitCode,发现更多精彩内容 >>
未验证
提交
9acfceca
编写于
6月 15, 2023
作者:
J
Javier
提交者:
GitHub
6月 15, 2023
浏览文件
操作
浏览文件
下载
差异文件
Merge pull request #171 from jrzaurin/default-vision-models
Default vision models
上级
05e007e9
bda30f03
变更
2
隐藏空白更改
内联
并排
Showing
2 changed file
with
59 addition
and
20 deletion
+59
-20
pytorch_widedeep/models/image/vision.py
pytorch_widedeep/models/image/vision.py
+31
-19
tests/test_model_components/test_mc_image.py
tests/test_model_components/test_mc_image.py
+28
-1
未找到文件。
pytorch_widedeep/models/image/vision.py
浏览文件 @
9acfceca
...
...
@@ -26,18 +26,19 @@ from pytorch_widedeep.models._base_wd_model_component import (
# googlenet
# inception
allowed_pretrained_models
=
[
"resnet"
,
"shufflenet"
,
"resnext"
,
"wide_resnet"
,
"regnet"
,
"densenet"
,
"mobilenet"
,
"mnasnet"
,
"efficientnet"
,
"squeezenet"
,
]
# {Arch: Default}
allowed_pretrained_models
=
{
"resnet"
:
"resnet18"
,
"shufflenet"
:
"shufflenet_v2_x0_5"
,
"resnext"
:
"resnext50_32x4d"
,
"wide_resnet"
:
"wide_resnet50_2"
,
"regnet"
:
"regnet_x_1_6gf"
,
"densenet"
:
"densenet121"
,
"mobilenet"
:
"mobilenet_v2"
,
"mnasnet"
:
"mnasnet1_0"
,
"efficientnet"
:
"efficientnet_b0"
,
"squeezenet"
:
"squeezenet1_0"
,
}
class
Vision
(
BaseWDModelComponent
):
...
...
@@ -199,22 +200,33 @@ class Vision(BaseWDModelComponent):
def
_get_features
(
self
)
->
Tuple
[
nn
.
Module
,
int
]:
if
self
.
pretrained_model_setup
is
not
None
:
if
isinstance
(
self
.
pretrained_model_setup
,
str
):
try
:
pretrained_model
=
torchvision
.
models
.
__dict__
[
self
.
pretrained_model_setup
](
weights
=
"IMAGENET1K_V2"
)
except
KeyError
:
if
self
.
pretrained_model_setup
in
allowed_pretrained_models
.
keys
():
model
=
allowed_pretrained_models
[
self
.
pretrained_model_setup
]
pretrained_model
=
torchvision
.
models
.
__dict__
[
model
](
weights
=
torchvision
.
models
.
get_model_weights
(
model
).
DEFAULT
)
warnings
.
warn
(
f
"
{
self
.
pretrained_model_setup
}
defaulting to
{
model
}
"
,
UserWarning
,
)
else
:
pretrained_model
=
torchvision
.
models
.
__dict__
[
self
.
pretrained_model_setup
](
weights
=
"IMAGENET1K_V1"
)
elif
isinstance
(
self
.
pretrained_model_setup
,
Dict
):
model_name
=
list
(
self
.
pretrained_model_setup
.
keys
())[
0
]
model_name
=
next
(
iter
(
self
.
pretrained_model_setup
))
model_weights
=
self
.
pretrained_model_setup
[
model_name
]
if
model_name
in
allowed_pretrained_models
.
keys
():
model_name
=
allowed_pretrained_models
[
model_name
]
pretrained_model
=
torchvision
.
models
.
__dict__
[
model_name
](
weights
=
model_weights
)
output_dim
:
int
=
self
.
get_backbone_output_dim
(
pretrained_model
)
features
=
nn
.
Sequential
(
*
(
list
(
pretrained_model
.
children
())[:
-
1
]))
else
:
features
=
self
.
_basic_cnn
()
output_dim
=
self
.
channel_sizes
[
-
1
]
...
...
@@ -297,7 +309,7 @@ class Vision(BaseWDModelComponent):
if
not
valid_pretrained_model_name
:
raise
ValueError
(
f
"
{
pretrained_model_setup
}
is not among the allowed pretrained models."
f
" These are
{
allowed_pretrained_models
}
. Please choose a variant of these architectures"
f
" These are
{
allowed_pretrained_models
.
keys
()
}
. Please choose a variant of these architectures"
)
if
n_trainable
is
not
None
and
trainable_params
is
not
None
:
raise
UserWarning
(
...
...
tests/test_model_components/test_mc_image.py
浏览文件 @
9acfceca
...
...
@@ -57,7 +57,7 @@ def test_n_trainable():
({
"squeezenet1_0"
:
SqueezeNet1_0_Weights
.
IMAGENET1K_V1
},
512
),
],
)
def
test_archiectures
(
arch
,
expected_out_shape
):
def
test_archi
t
ectures
(
arch
,
expected_out_shape
):
model
=
Vision
(
pretrained_model_setup
=
arch
,
n_trainable
=
0
)
out
=
model
(
X_images
)
assert
out
.
size
(
0
)
==
10
and
out
.
size
(
1
)
==
expected_out_shape
...
...
@@ -85,3 +85,30 @@ def test_all_frozen():
for
p
in
model
.
parameters
():
is_trainable
.
append
(
not
p
.
requires_grad
)
assert
all
(
is_trainable
)
###############################################################################
# Check defaulting for arch classes
###############################################################################
@
pytest
.
mark
.
parametrize
(
"arch, expected_out_shape"
,
[
(
"resnet"
,
512
),
(
"shufflenet"
,
1024
),
(
"resnext"
,
2048
),
(
"wide_resnet"
,
2048
),
(
"regnet"
,
912
),
(
"mobilenet"
,
1280
),
(
"mnasnet"
,
1280
),
(
"efficientnet"
,
1280
),
(
"squeezenet"
,
512
),
({
"shufflenet"
:
ShuffleNet_V2_X0_5_Weights
.
IMAGENET1K_V1
},
1024
),
({
"resnext"
:
ResNeXt50_32X4D_Weights
.
IMAGENET1K_V2
},
2048
),
],
)
def
test_pretrained_model_setup_defaults
(
arch
,
expected_out_shape
):
model
=
Vision
(
pretrained_model_setup
=
arch
,
n_trainable
=
0
)
out
=
model
(
X_images
)
assert
out
.
size
(
0
)
==
10
and
out
.
size
(
1
)
==
expected_out_shape
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录