未验证 提交 9acfceca 编写于 作者: J Javier 提交者: GitHub

Merge pull request #171 from jrzaurin/default-vision-models

Default vision models
......@@ -26,18 +26,19 @@ from pytorch_widedeep.models._base_wd_model_component import (
# googlenet
# inception
allowed_pretrained_models = [
"resnet",
"shufflenet",
"resnext",
"wide_resnet",
"regnet",
"densenet",
"mobilenet",
"mnasnet",
"efficientnet",
"squeezenet",
]
# {Arch: Default}
allowed_pretrained_models = {
"resnet": "resnet18",
"shufflenet": "shufflenet_v2_x0_5",
"resnext": "resnext50_32x4d",
"wide_resnet": "wide_resnet50_2",
"regnet": "regnet_x_1_6gf",
"densenet": "densenet121",
"mobilenet": "mobilenet_v2",
"mnasnet": "mnasnet1_0",
"efficientnet": "efficientnet_b0",
"squeezenet": "squeezenet1_0",
}
class Vision(BaseWDModelComponent):
......@@ -199,22 +200,33 @@ class Vision(BaseWDModelComponent):
def _get_features(self) -> Tuple[nn.Module, int]:
if self.pretrained_model_setup is not None:
if isinstance(self.pretrained_model_setup, str):
try:
pretrained_model = torchvision.models.__dict__[
self.pretrained_model_setup
](weights="IMAGENET1K_V2")
except KeyError:
if self.pretrained_model_setup in allowed_pretrained_models.keys():
model = allowed_pretrained_models[self.pretrained_model_setup]
pretrained_model = torchvision.models.__dict__[model](
weights=torchvision.models.get_model_weights(model).DEFAULT
)
warnings.warn(
f"{self.pretrained_model_setup} defaulting to {model}",
UserWarning,
)
else:
pretrained_model = torchvision.models.__dict__[
self.pretrained_model_setup
](weights="IMAGENET1K_V1")
elif isinstance(self.pretrained_model_setup, Dict):
model_name = list(self.pretrained_model_setup.keys())[0]
model_name = next(iter(self.pretrained_model_setup))
model_weights = self.pretrained_model_setup[model_name]
if model_name in allowed_pretrained_models.keys():
model_name = allowed_pretrained_models[model_name]
pretrained_model = torchvision.models.__dict__[model_name](
weights=model_weights
)
output_dim: int = self.get_backbone_output_dim(pretrained_model)
features = nn.Sequential(*(list(pretrained_model.children())[:-1]))
else:
features = self._basic_cnn()
output_dim = self.channel_sizes[-1]
......@@ -297,7 +309,7 @@ class Vision(BaseWDModelComponent):
if not valid_pretrained_model_name:
raise ValueError(
f"{pretrained_model_setup} is not among the allowed pretrained models."
f" These are {allowed_pretrained_models}. Please choose a variant of these architectures"
f" These are {allowed_pretrained_models.keys()}. Please choose a variant of these architectures"
)
if n_trainable is not None and trainable_params is not None:
raise UserWarning(
......
......@@ -57,7 +57,7 @@ def test_n_trainable():
({"squeezenet1_0": SqueezeNet1_0_Weights.IMAGENET1K_V1}, 512),
],
)
def test_archiectures(arch, expected_out_shape):
def test_architectures(arch, expected_out_shape):
model = Vision(pretrained_model_setup=arch, n_trainable=0)
out = model(X_images)
assert out.size(0) == 10 and out.size(1) == expected_out_shape
......@@ -85,3 +85,30 @@ def test_all_frozen():
for p in model.parameters():
is_trainable.append(not p.requires_grad)
assert all(is_trainable)
###############################################################################
# Check defaulting for arch classes
###############################################################################
@pytest.mark.parametrize(
"arch, expected_out_shape",
[
("resnet", 512),
("shufflenet", 1024),
("resnext", 2048),
("wide_resnet", 2048),
("regnet", 912),
("mobilenet", 1280),
("mnasnet", 1280),
("efficientnet", 1280),
("squeezenet", 512),
({"shufflenet": ShuffleNet_V2_X0_5_Weights.IMAGENET1K_V1}, 1024),
({"resnext": ResNeXt50_32X4D_Weights.IMAGENET1K_V2}, 2048),
],
)
def test_pretrained_model_setup_defaults(arch, expected_out_shape):
model = Vision(pretrained_model_setup=arch, n_trainable=0)
out = model(X_images)
assert out.size(0) == 10 and out.size(1) == expected_out_shape
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册