提交 4b26fc42 编写于 作者: L lyuwenyu

reorg `_load_pretrained_parameters`

上级 bdd8178c
......@@ -81,9 +81,7 @@ def AlexNet(pretrained=False, **kwargs):
model = _alexnet.AlexNet(**kwargs)
if pretrained:
assert 'AlexNet' in _checkpoints, 'Not provide `AlexNet` pretrained model.'
path = paddle.utils.download.get_weights_path_from_url(_checkpoints['AlexNet'])
model.set_state_dict(paddle.load(path))
model = _load_pretrained_parameters(model, 'AlexNet')
return model
......@@ -96,9 +94,7 @@ def VGG11(pretrained=False, **kwargs):
model = _vgg.VGG11(**kwargs)
if pretrained:
assert 'VGG11' in _checkpoints, 'Not provide `VGG11` pretrained model.'
path = paddle.utils.download.get_weights_path_from_url(_checkpoints['VGG11'])
model.set_state_dict(paddle.load(path))
model = _load_pretrained_parameters(model, 'VGG11')
return model
......@@ -110,9 +106,7 @@ def VGG13(pretrained=False, **kwargs):
model = _vgg.VGG13(**kwargs)
if pretrained:
assert 'VGG13' in _checkpoints, 'Not provide `VGG13` pretrained model.'
path = paddle.utils.download.get_weights_path_from_url(_checkpoints['VGG13'])
model.set_state_dict(paddle.load(path))
model = _load_pretrained_parameters(model, 'VGG13')
return model
......@@ -124,9 +118,7 @@ def VGG16(pretrained=False, **kwargs):
model = _vgg.VGG16(**kwargs)
if pretrained:
assert 'VGG16' in _checkpoints, 'Not provide `VGG16` pretrained model.'
path = paddle.utils.download.get_weights_path_from_url(_checkpoints['VGG16'])
model.set_state_dict(paddle.load(path))
model = _load_pretrained_parameters(model, 'VGG16')
return model
......@@ -138,9 +130,7 @@ def VGG19(pretrained=False, **kwargs):
model = _vgg.VGG19(**kwargs)
if pretrained:
assert 'VGG19' in _checkpoints, 'Not provide `VGG19` pretrained model.'
path = paddle.utils.download.get_weights_path_from_url(_checkpoints['VGG19'])
model.set_state_dict(paddle.load(path))
model = _load_pretrained_parameters(model, 'VGG19')
return model
......@@ -154,9 +144,7 @@ def ResNet18(pretrained=False, **kwargs):
model = _resnet.ResNet18(**kwargs)
if pretrained:
assert 'ResNet18' in _checkpoints, 'Not provide `ResNet18` pretrained model.'
path = paddle.utils.download.get_weights_path_from_url(_checkpoints['ResNet18'])
model.set_state_dict(paddle.load(path))
model = _load_pretrained_parameters(model, 'ResNet18')
return model
......@@ -168,9 +156,7 @@ def ResNet34(pretrained=False, **kwargs):
model = _resnet.ResNet34(**kwargs)
if pretrained:
assert 'ResNet34' in _checkpoints, 'Not provide `ResNet34` pretrained model.'
path = paddle.utils.download.get_weights_path_from_url(_checkpoints['ResNet34'])
model.set_state_dict(paddle.load(path))
model = _load_pretrained_parameters(model, 'ResNet34')
return model
......@@ -182,10 +168,8 @@ def ResNet50(pretrained=False, **kwargs):
model = _resnet.ResNet50(**kwargs)
if pretrained:
assert 'ResNet50' in _checkpoints, 'Not provide `ResNet50` pretrained model.'
path = paddle.utils.download.get_weights_path_from_url(_checkpoints['ResNet50'])
model.set_state_dict(paddle.load(path))
model = _load_pretrained_parameters(model, 'ResNet50')
return model
......@@ -196,9 +180,7 @@ def ResNet101(pretrained=False, **kwargs):
model = _resnet.ResNet101(**kwargs)
if pretrained:
assert 'ResNet101' in _checkpoints, 'Not provide `ResNet101` pretrained model.'
path = paddle.utils.download.get_weights_path_from_url(_checkpoints['ResNet101'])
model.set_state_dict(paddle.load(path))
model = _load_pretrained_parameters(model, 'ResNet101')
return model
......@@ -210,9 +192,7 @@ def ResNet152(pretrained=False, **kwargs):
model = _resnet.ResNet152(**kwargs)
if pretrained:
assert 'ResNet152' in _checkpoints, 'Not provide `ResNet152` pretrained model.'
path = paddle.utils.download.get_weights_path_from_url(_checkpoints['ResNet152'])
model.set_state_dict(paddle.load(path))
model = _load_pretrained_parameters(model, 'ResNet152')
return model
......@@ -225,9 +205,7 @@ def SqueezeNet1_0(pretrained=False, **kwargs):
model = _squeezenet.SqueezeNet1_0(**kwargs)
if pretrained:
assert 'SqueezeNet1_0' in _checkpoints, 'Not provide `SqueezeNet1_0` pretrained model.'
path = paddle.utils.download.get_weights_path_from_url(_checkpoints['SqueezeNet1_0'])
model.set_state_dict(paddle.load(path))
model = _load_pretrained_parameters(model, 'SqueezeNet1_0')
return model
......@@ -239,9 +217,7 @@ def SqueezeNet1_1(pretrained=False, **kwargs):
model = _squeezenet.SqueezeNet1_1(**kwargs)
if pretrained:
assert 'SqueezeNet1_1' in _checkpoints, 'Not provide `SqueezeNet1_1` pretrained model.'
path = paddle.utils.download.get_weights_path_from_url(_checkpoints['SqueezeNet1_1'])
model.set_state_dict(paddle.load(path))
model = _load_pretrained_parameters(model, 'SqueezeNet1_1')
return model
......@@ -255,9 +231,7 @@ def DenseNet121(pretrained=False, **kwargs):
model = _densenet.DenseNet121(**kwargs)
if pretrained:
assert 'DenseNet121' in _checkpoints, 'Not provide `DenseNet121` pretrained model.'
path = paddle.utils.download.get_weights_path_from_url(_checkpoints['DenseNet121'])
model.set_state_dict(paddle.load(path))
model = _load_pretrained_parameters(model, 'DenseNet121')
return model
......@@ -269,9 +243,7 @@ def DenseNet161(pretrained=False, **kwargs):
model = _densenet.DenseNet161(**kwargs)
if pretrained:
assert 'DenseNet161' in _checkpoints, 'Not provide `DenseNet161` pretrained model.'
path = paddle.utils.download.get_weights_path_from_url(_checkpoints['DenseNet161'])
model.set_state_dict(paddle.load(path))
model = _load_pretrained_parameters(model, 'DenseNet161')
return model
......@@ -283,9 +255,7 @@ def DenseNet169(pretrained=False, **kwargs):
model = _densenet.DenseNet169(**kwargs)
if pretrained:
assert 'DenseNet169' in _checkpoints, 'Not provide `DenseNet169` pretrained model.'
path = paddle.utils.download.get_weights_path_from_url(_checkpoints['DenseNet169'])
model.set_state_dict(paddle.load(path))
model = _load_pretrained_parameters(model, 'DenseNet169')
return model
......@@ -297,9 +267,7 @@ def DenseNet201(pretrained=False, **kwargs):
model = _densenet.DenseNet201(**kwargs)
if pretrained:
assert 'DenseNet201' in _checkpoints, 'Not provide `DenseNet201` pretrained model.'
path = paddle.utils.download.get_weights_path_from_url(_checkpoints['DenseNet201'])
model.set_state_dict(paddle.load(path))
model = _load_pretrained_parameters(model, 'DenseNet201')
return model
......@@ -311,9 +279,7 @@ def DenseNet264(pretrained=False, **kwargs):
model = _densenet.DenseNet264(**kwargs)
if pretrained:
assert 'DenseNet264' in _checkpoints, 'Not provide `DenseNet264` pretrained model.'
path = paddle.utils.download.get_weights_path_from_url(_checkpoints['DenseNet264'])
model.set_state_dict(paddle.load(path))
model = _load_pretrained_parameters(model, 'DenseNet264')
return model
......@@ -326,9 +292,7 @@ def InceptionV3(pretrained=False, **kwargs):
model = _inception_v3.InceptionV3(**kwargs)
if pretrained:
assert 'InceptionV3' in _checkpoints, 'Not provide `InceptionV3` pretrained model.'
path = paddle.utils.download.get_weights_path_from_url(_checkpoints['InceptionV3'])
model.set_state_dict(paddle.load(path))
model = _load_pretrained_parameters(model, 'InceptionV3')
return model
......@@ -340,9 +304,7 @@ def InceptionV4(pretrained=False, **kwargs):
model = _inception_v4.InceptionV4(**kwargs)
if pretrained:
assert 'InceptionV4' in _checkpoints, 'Not provide `InceptionV4` pretrained model.'
path = paddle.utils.download.get_weights_path_from_url(_checkpoints['InceptionV4'])
model.set_state_dict(paddle.load(path))
model = _load_pretrained_parameters(model, 'InceptionV4')
return model
......@@ -355,9 +317,7 @@ def GoogLeNet(pretrained=False, **kwargs):
model = _googlenet.GoogLeNet(**kwargs)
if pretrained:
assert 'GoogLeNet' in _checkpoints, 'Not provide `GoogLeNet` pretrained model.'
path = paddle.utils.download.get_weights_path_from_url(_checkpoints['GoogLeNet'])
model.set_state_dict(paddle.load(path))
model = _load_pretrained_parameters(model, 'GoogLeNet')
return model
......@@ -370,9 +330,7 @@ def ShuffleNet(pretrained=False, **kwargs):
model = _shufflenet_v2.ShuffleNet(**kwargs)
if pretrained:
assert 'ShuffleNet' in _checkpoints, 'Not provide `ShuffleNet` pretrained model.'
path = paddle.utils.download.get_weights_path_from_url(_checkpoints['ShuffleNet'])
model.set_state_dict(paddle.load(path))
model = _load_pretrained_parameters(model, 'ShuffleNet')
return model
......@@ -385,9 +343,7 @@ def MobileNetV1(pretrained=False, **kwargs):
model = _mobilenet_v1.MobileNetV1(**kwargs)
if pretrained:
assert 'MobileNetV1' in _checkpoints, 'Not provide `MobileNetV1` pretrained model.'
path = paddle.utils.download.get_weights_path_from_url(_checkpoints['MobileNetV1'])
model.set_state_dict(paddle.load(path))
model = _load_pretrained_parameters(model, 'MobileNetV1')
return model
......@@ -399,9 +355,7 @@ def MobileNetV1_x0_25(pretrained=False, **kwargs):
model = _mobilenet_v1.MobileNetV1_x0_25(**kwargs)
if pretrained:
assert 'MobileNetV1_x0_25' in _checkpoints, 'Not provide `MobileNetV1_x0_25` pretrained model.'
path = paddle.utils.download.get_weights_path_from_url(_checkpoints['MobileNetV1_x0_25'])
model.set_state_dict(paddle.load(path))
model = _load_pretrained_parameters(model, 'MobileNetV1_x0_25')
return model
......@@ -413,9 +367,7 @@ def MobileNetV1_x0_5(pretrained=False, **kwargs):
model = _mobilenet_v1.MobileNetV1_x0_5(**kwargs)
if pretrained:
assert 'MobileNetV1_x0_5' in _checkpoints, 'Not provide `MobileNetV1_x0_5` pretrained model.'
path = paddle.utils.download.get_weights_path_from_url(_checkpoints['MobileNetV1_x0_5'])
model.set_state_dict(paddle.load(path))
model = _load_pretrained_parameters(model, 'MobileNetV1_x0_5')
return model
......@@ -427,9 +379,7 @@ def MobileNetV1_x0_75(pretrained=False, **kwargs):
model = _mobilenet_v1.MobileNetV1_x0_75(**kwargs)
if pretrained:
assert 'MobileNetV1_x0_75' in _checkpoints, 'Not provide `MobileNetV1_x0_75` pretrained model.'
path = paddle.utils.download.get_weights_path_from_url(_checkpoints['MobileNetV1_x0_75'])
model.set_state_dict(paddle.load(path))
model = _load_pretrained_parameters(model, 'MobileNetV1_x0_75')
return model
......@@ -441,9 +391,7 @@ def MobileNetV2_x0_25(pretrained=False, **kwargs):
model = _mobilenet_v2.MobileNetV2_x0_25(**kwargs)
if pretrained:
assert 'MobileNetV2_x0_25' in _checkpoints, 'Not provide `MobileNetV2_x0_25` pretrained model.'
path = paddle.utils.download.get_weights_path_from_url(_checkpoints['MobileNetV2_x0_25'])
model.set_state_dict(paddle.load(path))
model = _load_pretrained_parameters(model, 'MobileNetV2_x0_25')
return model
......@@ -455,9 +403,7 @@ def MobileNetV2_x0_5(pretrained=False, **kwargs):
model = _mobilenet_v2.MobileNetV2_x0_5(**kwargs)
if pretrained:
assert 'MobileNetV2_x0_5' in _checkpoints, 'Not provide `MobileNetV2_x0_5` pretrained model.'
path = paddle.utils.download.get_weights_path_from_url(_checkpoints['MobileNetV2_x0_5'])
model.set_state_dict(paddle.load(path))
model = _load_pretrained_parameters(model, 'MobileNetV2_x0_5')
return model
......@@ -469,9 +415,7 @@ def MobileNetV2_x0_75(pretrained=False, **kwargs):
model = _mobilenet_v2.MobileNetV2_x0_75(**kwargs)
if pretrained:
assert 'MobileNetV2_x0_75' in _checkpoints, 'Not provide `MobileNetV2_x0_75` pretrained model.'
path = paddle.utils.download.get_weights_path_from_url(_checkpoints['MobileNetV2_x0_75'])
model.set_state_dict(paddle.load(path))
model = _load_pretrained_parameters(model, 'MobileNetV2_x0_75')
return model
......@@ -483,9 +427,7 @@ def MobileNetV2_x1_5(pretrained=False, **kwargs):
model = _mobilenet_v2.MobileNetV2_x1_5(**kwargs)
if pretrained:
assert 'MobileNetV2_x1_5' in _checkpoints, 'Not provide `MobileNetV2_x1_5` pretrained model.'
path = paddle.utils.download.get_weights_path_from_url(_checkpoints['MobileNetV2_x1_5'])
model.set_state_dict(paddle.load(path))
model = _load_pretrained_parameters(model, 'MobileNetV2_x1_5')
return model
......@@ -497,9 +439,7 @@ def MobileNetV2_x2_0(pretrained=False, **kwargs):
model = _mobilenet_v2.MobileNetV2_x2_0(**kwargs)
if pretrained:
assert 'MobileNetV2_x2_0' in _checkpoints, 'Not provide `MobileNetV2_x2_0` pretrained model.'
path = paddle.utils.download.get_weights_path_from_url(_checkpoints['MobileNetV2_x2_0'])
model.set_state_dict(paddle.load(path))
model = _load_pretrained_parameters(model, 'MobileNetV2_x2_0')
return model
......@@ -511,9 +451,7 @@ def MobileNetV3_large_x0_35(pretrained=False, **kwargs):
model = _mobilenet_v3.MobileNetV3_large_x0_35(**kwargs)
if pretrained:
assert 'MobileNetV3_large_x0_35' in _checkpoints, 'Not provide `MobileNetV3_large_x0_35` pretrained model.'
path = paddle.utils.download.get_weights_path_from_url(_checkpoints['MobileNetV3_large_x0_35'])
model.set_state_dict(paddle.load(path))
model = _load_pretrained_parameters(model, 'MobileNetV3_large_x0_35')
return model
......@@ -525,9 +463,7 @@ def MobileNetV3_large_x0_5(pretrained=False, **kwargs):
model = _mobilenet_v3.MobileNetV3_large_x0_5(**kwargs)
if pretrained:
assert 'MobileNetV3_large_x0_5' in _checkpoints, 'Not provide `MobileNetV3_large_x0_5` pretrained model.'
path = paddle.utils.download.get_weights_path_from_url(_checkpoints['MobileNetV3_large_x0_5'])
model.set_state_dict(paddle.load(path))
model = _load_pretrained_parameters(model, 'MobileNetV3_large_x0_5')
return model
......@@ -539,9 +475,7 @@ def MobileNetV3_large_x0_75(pretrained=False, **kwargs):
model = _mobilenet_v3.MobileNetV3_large_x0_75(**kwargs)
if pretrained:
assert 'MobileNetV3_large_x0_75' in _checkpoints, 'Not provide `MobileNetV3_large_x0_75` pretrained model.'
path = paddle.utils.download.get_weights_path_from_url(_checkpoints['MobileNetV3_large_x0_75'])
model.set_state_dict(paddle.load(path))
model = _load_pretrained_parameters(model, 'MobileNetV3_large_x0_75')
return model
......@@ -553,9 +487,7 @@ def MobileNetV3_large_x1_0(pretrained=False, **kwargs):
model = _mobilenet_v3.MobileNetV3_large_x1_0(**kwargs)
if pretrained:
assert 'MobileNetV3_large_x1_0' in _checkpoints, 'Not provide `MobileNetV3_large_x1_0` pretrained model.'
path = paddle.utils.download.get_weights_path_from_url(_checkpoints['MobileNetV3_large_x1_0'])
model.set_state_dict(paddle.load(path))
model = _load_pretrained_parameters(model, 'MobileNetV3_large_x1_0')
return model
......@@ -567,9 +499,7 @@ def MobileNetV3_large_x1_25(pretrained=False, **kwargs):
model = _mobilenet_v3.MobileNetV3_large_x1_25(**kwargs)
if pretrained:
assert 'MobileNetV3_large_x1_25' in _checkpoints, 'Not provide `MobileNetV3_large_x1_25` pretrained model.'
path = paddle.utils.download.get_weights_path_from_url(_checkpoints['MobileNetV3_large_x1_25'])
model.set_state_dict(paddle.load(path))
model = _load_pretrained_parameters(model, 'MobileNetV3_large_x1_25')
return model
......@@ -581,9 +511,7 @@ def MobileNetV3_small_x0_35(pretrained=False, **kwargs):
model = _mobilenet_v3.MobileNetV3_small_x0_35(**kwargs)
if pretrained:
assert 'MobileNetV3_small_x0_35' in _checkpoints, 'Not provide `MobileNetV3_small_x0_35` pretrained model.'
path = paddle.utils.download.get_weights_path_from_url(_checkpoints['MobileNetV3_small_x0_35'])
model.set_state_dict(paddle.load(path))
model = _load_pretrained_parameters(model, 'MobileNetV3_small_x0_35')
return model
......@@ -595,9 +523,7 @@ def MobileNetV3_small_x0_5(pretrained=False, **kwargs):
model = _mobilenet_v3.MobileNetV3_small_x0_5(**kwargs)
if pretrained:
assert 'MobileNetV3_small_x0_5' in _checkpoints, 'Not provide `MobileNetV3_small_x0_5` pretrained model.'
path = paddle.utils.download.get_weights_path_from_url(_checkpoints['MobileNetV3_small_x0_5'])
model.set_state_dict(paddle.load(path))
model = _load_pretrained_parameters(model, 'MobileNetV3_small_x0_5')
return model
......@@ -609,9 +535,7 @@ def MobileNetV3_small_x0_75(pretrained=False, **kwargs):
model = _mobilenet_v3.MobileNetV3_small_x0_75(**kwargs)
if pretrained:
assert 'MobileNetV3_small_x0_75' in _checkpoints, 'Not provide `MobileNetV3_small_x0_75` pretrained model.'
path = paddle.utils.download.get_weights_path_from_url(_checkpoints['MobileNetV3_small_x0_75'])
model.set_state_dict(paddle.load(path))
model = _load_pretrained_parameters(model, 'MobileNetV3_small_x0_75')
return model
......@@ -623,9 +547,7 @@ def MobileNetV3_small_x1_0(pretrained=False, **kwargs):
model = _mobilenet_v3.MobileNetV3_small_x1_0(**kwargs)
if pretrained:
assert 'MobileNetV3_small_x1_0' in _checkpoints, 'Not provide `MobileNetV3_small_x1_0` pretrained model.'
path = paddle.utils.download.get_weights_path_from_url(_checkpoints['MobileNetV3_small_x1_0'])
model.set_state_dict(paddle.load(path))
model = _load_pretrained_parameters(model, 'MobileNetV3_small_x1_0')
return model
......@@ -637,9 +559,7 @@ def MobileNetV3_small_x1_25(pretrained=False, **kwargs):
model = _mobilenet_v3.MobileNetV3_small_x1_25(**kwargs)
if pretrained:
assert 'MobileNetV3_small_x1_25' in _checkpoints, 'Not provide `MobileNetV3_small_x1_25` pretrained model.'
path = paddle.utils.download.get_weights_path_from_url(_checkpoints['MobileNetV3_small_x1_25'])
model.set_state_dict(paddle.load(path))
model = _load_pretrained_parameters(model, 'MobileNetV3_small_x1_25')
return model
......@@ -652,9 +572,7 @@ def ResNeXt101_32x4d(pretrained=False, **kwargs):
model = _resnext.ResNeXt101_32x4d(**kwargs)
if pretrained:
assert 'ResNeXt101_32x4d' in _checkpoints, 'Not provide `ResNeXt101_32x4d` pretrained model.'
path = paddle.utils.download.get_weights_path_from_url(_checkpoints['ResNeXt101_32x4d'])
model.set_state_dict(paddle.load(path))
model = _load_pretrained_parameters(model, 'ResNeXt101_32x4d')
return model
......@@ -666,9 +584,7 @@ def ResNeXt101_64x4d(pretrained=False, **kwargs):
model = _resnext.ResNeXt101_64x4d(**kwargs)
if pretrained:
assert 'ResNeXt101_64x4d' in _checkpoints, 'Not provide `ResNeXt101_64x4d` pretrained model.'
path = paddle.utils.download.get_weights_path_from_url(_checkpoints['ResNeXt101_64x4d'])
model.set_state_dict(paddle.load(path))
model = _load_pretrained_parameters(model, 'ResNeXt101_64x4d')
return model
......@@ -680,9 +596,7 @@ def ResNeXt152_32x4d(pretrained=False, **kwargs):
model = _resnext.ResNeXt152_32x4d(**kwargs)
if pretrained:
assert 'ResNeXt152_32x4d' in _checkpoints, 'Not provide `ResNeXt152_32x4d` pretrained model.'
path = paddle.utils.download.get_weights_path_from_url(_checkpoints['ResNeXt152_32x4d'])
model.set_state_dict(paddle.load(path))
model = _load_pretrained_parameters(model, 'ResNeXt152_32x4d')
return model
......@@ -694,9 +608,7 @@ def ResNeXt152_64x4d(pretrained=False, **kwargs):
model = _resnext.ResNeXt152_64x4d(**kwargs)
if pretrained:
assert 'ResNeXt152_64x4d' in _checkpoints, 'Not provide `ResNeXt152_64x4d` pretrained model.'
path = paddle.utils.download.get_weights_path_from_url(_checkpoints['ResNeXt152_64x4d'])
model.set_state_dict(paddle.load(path))
model = _load_pretrained_parameters(model, 'ResNeXt152_64x4d')
return model
......@@ -708,9 +620,7 @@ def ResNeXt50_32x4d(pretrained=False, **kwargs):
model = _resnext.ResNeXt50_32x4d(**kwargs)
if pretrained:
assert 'ResNeXt50_32x4d' in _checkpoints, 'Not provide `ResNeXt50_32x4d` pretrained model.'
path = paddle.utils.download.get_weights_path_from_url(_checkpoints['ResNeXt50_32x4d'])
model.set_state_dict(paddle.load(path))
model = _load_pretrained_parameters(model, 'ResNeXt50_32x4d')
return model
......@@ -722,8 +632,6 @@ def ResNeXt50_64x4d(pretrained=False, **kwargs):
model = _resnext.ResNeXt50_64x4d(**kwargs)
if pretrained:
assert 'ResNeXt50_64x4d' in _checkpoints, 'Not provide `ResNeXt50_64x4d` pretrained model.'
path = paddle.utils.download.get_weights_path_from_url(_checkpoints['ResNeXt50_64x4d'])
model.set_state_dict(paddle.load(path))
model = _load_pretrained_parameters(model, 'ResNeXt50_64x4d')
return model
\ No newline at end of file
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册