From 4b26fc422c7ad6c90a1dd269c6fb0f27dcd81bd7 Mon Sep 17 00:00:00 2001 From: lyuwenyu Date: Tue, 6 Apr 2021 12:01:02 +0800 Subject: [PATCH] reorg `_load_pretrained_parameters` --- hubconf.py | 186 ++++++++++++++--------------------------------------- 1 file changed, 47 insertions(+), 139 deletions(-) diff --git a/hubconf.py b/hubconf.py index a843cba6..a779cf72 100644 --- a/hubconf.py +++ b/hubconf.py @@ -81,9 +81,7 @@ def AlexNet(pretrained=False, **kwargs): model = _alexnet.AlexNet(**kwargs) if pretrained: - assert 'AlexNet' in _checkpoints, 'Not provide `AlexNet` pretrained model.' - path = paddle.utils.download.get_weights_path_from_url(_checkpoints['AlexNet']) - model.set_state_dict(paddle.load(path)) + model = _load_pretrained_parameters(model, 'AlexNet') return model @@ -96,9 +94,7 @@ def VGG11(pretrained=False, **kwargs): model = _vgg.VGG11(**kwargs) if pretrained: - assert 'VGG11' in _checkpoints, 'Not provide `VGG11` pretrained model.' - path = paddle.utils.download.get_weights_path_from_url(_checkpoints['VGG11']) - model.set_state_dict(paddle.load(path)) + model = _load_pretrained_parameters(model, 'VGG11') return model @@ -110,9 +106,7 @@ def VGG13(pretrained=False, **kwargs): model = _vgg.VGG13(**kwargs) if pretrained: - assert 'VGG13' in _checkpoints, 'Not provide `VGG13` pretrained model.' - path = paddle.utils.download.get_weights_path_from_url(_checkpoints['VGG13']) - model.set_state_dict(paddle.load(path)) + model = _load_pretrained_parameters(model, 'VGG13') return model @@ -124,9 +118,7 @@ def VGG16(pretrained=False, **kwargs): model = _vgg.VGG16(**kwargs) if pretrained: - assert 'VGG16' in _checkpoints, 'Not provide `VGG16` pretrained model.' - path = paddle.utils.download.get_weights_path_from_url(_checkpoints['VGG16']) - model.set_state_dict(paddle.load(path)) + model = _load_pretrained_parameters(model, 'VGG16') return model @@ -138,9 +130,7 @@ def VGG19(pretrained=False, **kwargs): model = _vgg.VGG19(**kwargs) if pretrained: - assert 'VGG19' in _checkpoints, 'Not provide `VGG19` pretrained model.' - path = paddle.utils.download.get_weights_path_from_url(_checkpoints['VGG19']) - model.set_state_dict(paddle.load(path)) + model = _load_pretrained_parameters(model, 'VGG19') return model @@ -154,9 +144,7 @@ def ResNet18(pretrained=False, **kwargs): model = _resnet.ResNet18(**kwargs) if pretrained: - assert 'ResNet18' in _checkpoints, 'Not provide `ResNet18` pretrained model.' - path = paddle.utils.download.get_weights_path_from_url(_checkpoints['ResNet18']) - model.set_state_dict(paddle.load(path)) + model = _load_pretrained_parameters(model, 'ResNet18') return model @@ -168,9 +156,7 @@ def ResNet34(pretrained=False, **kwargs): model = _resnet.ResNet34(**kwargs) if pretrained: - assert 'ResNet34' in _checkpoints, 'Not provide `ResNet34` pretrained model.' - path = paddle.utils.download.get_weights_path_from_url(_checkpoints['ResNet34']) - model.set_state_dict(paddle.load(path)) + model = _load_pretrained_parameters(model, 'ResNet34') return model @@ -182,10 +168,8 @@ def ResNet50(pretrained=False, **kwargs): model = _resnet.ResNet50(**kwargs) if pretrained: - assert 'ResNet50' in _checkpoints, 'Not provide `ResNet50` pretrained model.' - path = paddle.utils.download.get_weights_path_from_url(_checkpoints['ResNet50']) - model.set_state_dict(paddle.load(path)) - + model = _load_pretrained_parameters(model, 'ResNet50') + return model @@ -196,9 +180,7 @@ def ResNet101(pretrained=False, **kwargs): model = _resnet.ResNet101(**kwargs) if pretrained: - assert 'ResNet101' in _checkpoints, 'Not provide `ResNet101` pretrained model.' - path = paddle.utils.download.get_weights_path_from_url(_checkpoints['ResNet101']) - model.set_state_dict(paddle.load(path)) + model = _load_pretrained_parameters(model, 'ResNet101') return model @@ -210,9 +192,7 @@ def ResNet152(pretrained=False, **kwargs): model = _resnet.ResNet152(**kwargs) if pretrained: - assert 'ResNet152' in _checkpoints, 'Not provide `ResNet152` pretrained model.' - path = paddle.utils.download.get_weights_path_from_url(_checkpoints['ResNet152']) - model.set_state_dict(paddle.load(path)) + model = _load_pretrained_parameters(model, 'ResNet152') return model @@ -225,9 +205,7 @@ def SqueezeNet1_0(pretrained=False, **kwargs): model = _squeezenet.SqueezeNet1_0(**kwargs) if pretrained: - assert 'SqueezeNet1_0' in _checkpoints, 'Not provide `SqueezeNet1_0` pretrained model.' - path = paddle.utils.download.get_weights_path_from_url(_checkpoints['SqueezeNet1_0']) - model.set_state_dict(paddle.load(path)) + model = _load_pretrained_parameters(model, 'SqueezeNet1_0') return model @@ -239,9 +217,7 @@ def SqueezeNet1_1(pretrained=False, **kwargs): model = _squeezenet.SqueezeNet1_1(**kwargs) if pretrained: - assert 'SqueezeNet1_1' in _checkpoints, 'Not provide `SqueezeNet1_1` pretrained model.' - path = paddle.utils.download.get_weights_path_from_url(_checkpoints['SqueezeNet1_1']) - model.set_state_dict(paddle.load(path)) + model = _load_pretrained_parameters(model, 'SqueezeNet1_1') return model @@ -255,9 +231,7 @@ def DenseNet121(pretrained=False, **kwargs): model = _densenet.DenseNet121(**kwargs) if pretrained: - assert 'DenseNet121' in _checkpoints, 'Not provide `DenseNet121` pretrained model.' - path = paddle.utils.download.get_weights_path_from_url(_checkpoints['DenseNet121']) - model.set_state_dict(paddle.load(path)) + model = _load_pretrained_parameters(model, 'DenseNet121') return model @@ -269,9 +243,7 @@ def DenseNet161(pretrained=False, **kwargs): model = _densenet.DenseNet161(**kwargs) if pretrained: - assert 'DenseNet161' in _checkpoints, 'Not provide `DenseNet161` pretrained model.' - path = paddle.utils.download.get_weights_path_from_url(_checkpoints['DenseNet161']) - model.set_state_dict(paddle.load(path)) + model = _load_pretrained_parameters(model, 'DenseNet161') return model @@ -283,9 +255,7 @@ def DenseNet169(pretrained=False, **kwargs): model = _densenet.DenseNet169(**kwargs) if pretrained: - assert 'DenseNet169' in _checkpoints, 'Not provide `DenseNet169` pretrained model.' - path = paddle.utils.download.get_weights_path_from_url(_checkpoints['DenseNet169']) - model.set_state_dict(paddle.load(path)) + model = _load_pretrained_parameters(model, 'DenseNet169') return model @@ -297,9 +267,7 @@ def DenseNet201(pretrained=False, **kwargs): model = _densenet.DenseNet201(**kwargs) if pretrained: - assert 'DenseNet201' in _checkpoints, 'Not provide `DenseNet201` pretrained model.' - path = paddle.utils.download.get_weights_path_from_url(_checkpoints['DenseNet201']) - model.set_state_dict(paddle.load(path)) + model = _load_pretrained_parameters(model, 'DenseNet201') return model @@ -311,9 +279,7 @@ def DenseNet264(pretrained=False, **kwargs): model = _densenet.DenseNet264(**kwargs) if pretrained: - assert 'DenseNet264' in _checkpoints, 'Not provide `DenseNet264` pretrained model.' - path = paddle.utils.download.get_weights_path_from_url(_checkpoints['DenseNet264']) - model.set_state_dict(paddle.load(path)) + model = _load_pretrained_parameters(model, 'DenseNet264') return model @@ -326,9 +292,7 @@ def InceptionV3(pretrained=False, **kwargs): model = _inception_v3.InceptionV3(**kwargs) if pretrained: - assert 'InceptionV3' in _checkpoints, 'Not provide `InceptionV3` pretrained model.' - path = paddle.utils.download.get_weights_path_from_url(_checkpoints['InceptionV3']) - model.set_state_dict(paddle.load(path)) + model = _load_pretrained_parameters(model, 'InceptionV3') return model @@ -340,9 +304,7 @@ def InceptionV4(pretrained=False, **kwargs): model = _inception_v4.InceptionV4(**kwargs) if pretrained: - assert 'InceptionV4' in _checkpoints, 'Not provide `InceptionV4` pretrained model.' - path = paddle.utils.download.get_weights_path_from_url(_checkpoints['InceptionV4']) - model.set_state_dict(paddle.load(path)) + model = _load_pretrained_parameters(model, 'InceptionV4') return model @@ -355,9 +317,7 @@ def GoogLeNet(pretrained=False, **kwargs): model = _googlenet.GoogLeNet(**kwargs) if pretrained: - assert 'GoogLeNet' in _checkpoints, 'Not provide `GoogLeNet` pretrained model.' - path = paddle.utils.download.get_weights_path_from_url(_checkpoints['GoogLeNet']) - model.set_state_dict(paddle.load(path)) + model = _load_pretrained_parameters(model, 'GoogLeNet') return model @@ -370,9 +330,7 @@ def ShuffleNet(pretrained=False, **kwargs): model = _shufflenet_v2.ShuffleNet(**kwargs) if pretrained: - assert 'ShuffleNet' in _checkpoints, 'Not provide `ShuffleNet` pretrained model.' - path = paddle.utils.download.get_weights_path_from_url(_checkpoints['ShuffleNet']) - model.set_state_dict(paddle.load(path)) + model = _load_pretrained_parameters(model, 'ShuffleNet') return model @@ -385,9 +343,7 @@ def MobileNetV1(pretrained=False, **kwargs): model = _mobilenet_v1.MobileNetV1(**kwargs) if pretrained: - assert 'MobileNetV1' in _checkpoints, 'Not provide `MobileNetV1` pretrained model.' - path = paddle.utils.download.get_weights_path_from_url(_checkpoints['MobileNetV1']) - model.set_state_dict(paddle.load(path)) + model = _load_pretrained_parameters(model, 'MobileNetV1') return model @@ -399,9 +355,7 @@ def MobileNetV1_x0_25(pretrained=False, **kwargs): model = _mobilenet_v1.MobileNetV1_x0_25(**kwargs) if pretrained: - assert 'MobileNetV1_x0_25' in _checkpoints, 'Not provide `MobileNetV1_x0_25` pretrained model.' - path = paddle.utils.download.get_weights_path_from_url(_checkpoints['MobileNetV1_x0_25']) - model.set_state_dict(paddle.load(path)) + model = _load_pretrained_parameters(model, 'MobileNetV1_x0_25') return model @@ -413,9 +367,7 @@ def MobileNetV1_x0_5(pretrained=False, **kwargs): model = _mobilenet_v1.MobileNetV1_x0_5(**kwargs) if pretrained: - assert 'MobileNetV1_x0_5' in _checkpoints, 'Not provide `MobileNetV1_x0_5` pretrained model.' - path = paddle.utils.download.get_weights_path_from_url(_checkpoints['MobileNetV1_x0_5']) - model.set_state_dict(paddle.load(path)) + model = _load_pretrained_parameters(model, 'MobileNetV1_x0_5') return model @@ -427,9 +379,7 @@ def MobileNetV1_x0_75(pretrained=False, **kwargs): model = _mobilenet_v1.MobileNetV1_x0_75(**kwargs) if pretrained: - assert 'MobileNetV1_x0_75' in _checkpoints, 'Not provide `MobileNetV1_x0_75` pretrained model.' - path = paddle.utils.download.get_weights_path_from_url(_checkpoints['MobileNetV1_x0_75']) - model.set_state_dict(paddle.load(path)) + model = _load_pretrained_parameters(model, 'MobileNetV1_x0_75') return model @@ -441,9 +391,7 @@ def MobileNetV2_x0_25(pretrained=False, **kwargs): model = _mobilenet_v2.MobileNetV2_x0_25(**kwargs) if pretrained: - assert 'MobileNetV2_x0_25' in _checkpoints, 'Not provide `MobileNetV2_x0_25` pretrained model.' - path = paddle.utils.download.get_weights_path_from_url(_checkpoints['MobileNetV2_x0_25']) - model.set_state_dict(paddle.load(path)) + model = _load_pretrained_parameters(model, 'MobileNetV2_x0_25') return model @@ -455,9 +403,7 @@ def MobileNetV2_x0_5(pretrained=False, **kwargs): model = _mobilenet_v2.MobileNetV2_x0_5(**kwargs) if pretrained: - assert 'MobileNetV2_x0_5' in _checkpoints, 'Not provide `MobileNetV2_x0_5` pretrained model.' - path = paddle.utils.download.get_weights_path_from_url(_checkpoints['MobileNetV2_x0_5']) - model.set_state_dict(paddle.load(path)) + model = _load_pretrained_parameters(model, 'MobileNetV2_x0_5') return model @@ -469,9 +415,7 @@ def MobileNetV2_x0_75(pretrained=False, **kwargs): model = _mobilenet_v2.MobileNetV2_x0_75(**kwargs) if pretrained: - assert 'MobileNetV2_x0_75' in _checkpoints, 'Not provide `MobileNetV2_x0_75` pretrained model.' - path = paddle.utils.download.get_weights_path_from_url(_checkpoints['MobileNetV2_x0_75']) - model.set_state_dict(paddle.load(path)) + model = _load_pretrained_parameters(model, 'MobileNetV2_x0_75') return model @@ -483,9 +427,7 @@ def MobileNetV2_x1_5(pretrained=False, **kwargs): model = _mobilenet_v2.MobileNetV2_x1_5(**kwargs) if pretrained: - assert 'MobileNetV2_x1_5' in _checkpoints, 'Not provide `MobileNetV2_x1_5` pretrained model.' - path = paddle.utils.download.get_weights_path_from_url(_checkpoints['MobileNetV2_x1_5']) - model.set_state_dict(paddle.load(path)) + model = _load_pretrained_parameters(model, 'MobileNetV2_x1_5') return model @@ -497,9 +439,7 @@ def MobileNetV2_x2_0(pretrained=False, **kwargs): model = _mobilenet_v2.MobileNetV2_x2_0(**kwargs) if pretrained: - assert 'MobileNetV2_x2_0' in _checkpoints, 'Not provide `MobileNetV2_x2_0` pretrained model.' - path = paddle.utils.download.get_weights_path_from_url(_checkpoints['MobileNetV2_x2_0']) - model.set_state_dict(paddle.load(path)) + model = _load_pretrained_parameters(model, 'MobileNetV2_x2_0') return model @@ -511,9 +451,7 @@ def MobileNetV3_large_x0_35(pretrained=False, **kwargs): model = _mobilenet_v3.MobileNetV3_large_x0_35(**kwargs) if pretrained: - assert 'MobileNetV3_large_x0_35' in _checkpoints, 'Not provide `MobileNetV3_large_x0_35` pretrained model.' - path = paddle.utils.download.get_weights_path_from_url(_checkpoints['MobileNetV3_large_x0_35']) - model.set_state_dict(paddle.load(path)) + model = _load_pretrained_parameters(model, 'MobileNetV3_large_x0_35') return model @@ -525,9 +463,7 @@ def MobileNetV3_large_x0_5(pretrained=False, **kwargs): model = _mobilenet_v3.MobileNetV3_large_x0_5(**kwargs) if pretrained: - assert 'MobileNetV3_large_x0_5' in _checkpoints, 'Not provide `MobileNetV3_large_x0_5` pretrained model.' - path = paddle.utils.download.get_weights_path_from_url(_checkpoints['MobileNetV3_large_x0_5']) - model.set_state_dict(paddle.load(path)) + model = _load_pretrained_parameters(model, 'MobileNetV3_large_x0_5') return model @@ -539,9 +475,7 @@ def MobileNetV3_large_x0_75(pretrained=False, **kwargs): model = _mobilenet_v3.MobileNetV3_large_x0_75(**kwargs) if pretrained: - assert 'MobileNetV3_large_x0_75' in _checkpoints, 'Not provide `MobileNetV3_large_x0_75` pretrained model.' - path = paddle.utils.download.get_weights_path_from_url(_checkpoints['MobileNetV3_large_x0_75']) - model.set_state_dict(paddle.load(path)) + model = _load_pretrained_parameters(model, 'MobileNetV3_large_x0_75') return model @@ -553,9 +487,7 @@ def MobileNetV3_large_x1_0(pretrained=False, **kwargs): model = _mobilenet_v3.MobileNetV3_large_x1_0(**kwargs) if pretrained: - assert 'MobileNetV3_large_x1_0' in _checkpoints, 'Not provide `MobileNetV3_large_x1_0` pretrained model.' - path = paddle.utils.download.get_weights_path_from_url(_checkpoints['MobileNetV3_large_x1_0']) - model.set_state_dict(paddle.load(path)) + model = _load_pretrained_parameters(model, 'MobileNetV3_large_x1_0') return model @@ -567,9 +499,7 @@ def MobileNetV3_large_x1_25(pretrained=False, **kwargs): model = _mobilenet_v3.MobileNetV3_large_x1_25(**kwargs) if pretrained: - assert 'MobileNetV3_large_x1_25' in _checkpoints, 'Not provide `MobileNetV3_large_x1_25` pretrained model.' - path = paddle.utils.download.get_weights_path_from_url(_checkpoints['MobileNetV3_large_x1_25']) - model.set_state_dict(paddle.load(path)) + model = _load_pretrained_parameters(model, 'MobileNetV3_large_x1_25') return model @@ -581,9 +511,7 @@ def MobileNetV3_small_x0_35(pretrained=False, **kwargs): model = _mobilenet_v3.MobileNetV3_small_x0_35(**kwargs) if pretrained: - assert 'MobileNetV3_small_x0_35' in _checkpoints, 'Not provide `MobileNetV3_small_x0_35` pretrained model.' - path = paddle.utils.download.get_weights_path_from_url(_checkpoints['MobileNetV3_small_x0_35']) - model.set_state_dict(paddle.load(path)) + model = _load_pretrained_parameters(model, 'MobileNetV3_small_x0_35') return model @@ -595,9 +523,7 @@ def MobileNetV3_small_x0_5(pretrained=False, **kwargs): model = _mobilenet_v3.MobileNetV3_small_x0_5(**kwargs) if pretrained: - assert 'MobileNetV3_small_x0_5' in _checkpoints, 'Not provide `MobileNetV3_small_x0_5` pretrained model.' - path = paddle.utils.download.get_weights_path_from_url(_checkpoints['MobileNetV3_small_x0_5']) - model.set_state_dict(paddle.load(path)) + model = _load_pretrained_parameters(model, 'MobileNetV3_small_x0_5') return model @@ -609,9 +535,7 @@ def MobileNetV3_small_x0_75(pretrained=False, **kwargs): model = _mobilenet_v3.MobileNetV3_small_x0_75(**kwargs) if pretrained: - assert 'MobileNetV3_small_x0_75' in _checkpoints, 'Not provide `MobileNetV3_small_x0_75` pretrained model.' - path = paddle.utils.download.get_weights_path_from_url(_checkpoints['MobileNetV3_small_x0_75']) - model.set_state_dict(paddle.load(path)) + model = _load_pretrained_parameters(model, 'MobileNetV3_small_x0_75') return model @@ -623,9 +547,7 @@ def MobileNetV3_small_x1_0(pretrained=False, **kwargs): model = _mobilenet_v3.MobileNetV3_small_x1_0(**kwargs) if pretrained: - assert 'MobileNetV3_small_x1_0' in _checkpoints, 'Not provide `MobileNetV3_small_x1_0` pretrained model.' - path = paddle.utils.download.get_weights_path_from_url(_checkpoints['MobileNetV3_small_x1_0']) - model.set_state_dict(paddle.load(path)) + model = _load_pretrained_parameters(model, 'MobileNetV3_small_x1_0') return model @@ -637,9 +559,7 @@ def MobileNetV3_small_x1_25(pretrained=False, **kwargs): model = _mobilenet_v3.MobileNetV3_small_x1_25(**kwargs) if pretrained: - assert 'MobileNetV3_small_x1_25' in _checkpoints, 'Not provide `MobileNetV3_small_x1_25` pretrained model.' - path = paddle.utils.download.get_weights_path_from_url(_checkpoints['MobileNetV3_small_x1_25']) - model.set_state_dict(paddle.load(path)) + model = _load_pretrained_parameters(model, 'MobileNetV3_small_x1_25') return model @@ -652,9 +572,7 @@ def ResNeXt101_32x4d(pretrained=False, **kwargs): model = _resnext.ResNeXt101_32x4d(**kwargs) if pretrained: - assert 'ResNeXt101_32x4d' in _checkpoints, 'Not provide `ResNeXt101_32x4d` pretrained model.' - path = paddle.utils.download.get_weights_path_from_url(_checkpoints['ResNeXt101_32x4d']) - model.set_state_dict(paddle.load(path)) + model = _load_pretrained_parameters(model, 'ResNeXt101_32x4d') return model @@ -666,9 +584,7 @@ def ResNeXt101_64x4d(pretrained=False, **kwargs): model = _resnext.ResNeXt101_64x4d(**kwargs) if pretrained: - assert 'ResNeXt101_64x4d' in _checkpoints, 'Not provide `ResNeXt101_64x4d` pretrained model.' - path = paddle.utils.download.get_weights_path_from_url(_checkpoints['ResNeXt101_64x4d']) - model.set_state_dict(paddle.load(path)) + model = _load_pretrained_parameters(model, 'ResNeXt101_64x4d') return model @@ -680,9 +596,7 @@ def ResNeXt152_32x4d(pretrained=False, **kwargs): model = _resnext.ResNeXt152_32x4d(**kwargs) if pretrained: - assert 'ResNeXt152_32x4d' in _checkpoints, 'Not provide `ResNeXt152_32x4d` pretrained model.' - path = paddle.utils.download.get_weights_path_from_url(_checkpoints['ResNeXt152_32x4d']) - model.set_state_dict(paddle.load(path)) + model = _load_pretrained_parameters(model, 'ResNeXt152_32x4d') return model @@ -694,9 +608,7 @@ def ResNeXt152_64x4d(pretrained=False, **kwargs): model = _resnext.ResNeXt152_64x4d(**kwargs) if pretrained: - assert 'ResNeXt152_64x4d' in _checkpoints, 'Not provide `ResNeXt152_64x4d` pretrained model.' - path = paddle.utils.download.get_weights_path_from_url(_checkpoints['ResNeXt152_64x4d']) - model.set_state_dict(paddle.load(path)) + model = _load_pretrained_parameters(model, 'ResNeXt152_64x4d') return model @@ -708,9 +620,7 @@ def ResNeXt50_32x4d(pretrained=False, **kwargs): model = _resnext.ResNeXt50_32x4d(**kwargs) if pretrained: - assert 'ResNeXt50_32x4d' in _checkpoints, 'Not provide `ResNeXt50_32x4d` pretrained model.' - path = paddle.utils.download.get_weights_path_from_url(_checkpoints['ResNeXt50_32x4d']) - model.set_state_dict(paddle.load(path)) + model = _load_pretrained_parameters(model, 'ResNeXt50_32x4d') return model @@ -722,8 +632,6 @@ def ResNeXt50_64x4d(pretrained=False, **kwargs): model = _resnext.ResNeXt50_64x4d(**kwargs) if pretrained: - assert 'ResNeXt50_64x4d' in _checkpoints, 'Not provide `ResNeXt50_64x4d` pretrained model.' - path = paddle.utils.download.get_weights_path_from_url(_checkpoints['ResNeXt50_64x4d']) - model.set_state_dict(paddle.load(path)) + model = _load_pretrained_parameters(model, 'ResNeXt50_64x4d') return model \ No newline at end of file -- GitLab