From 7ccd4896f6616b0adc952a12ccc86446197f2f7d Mon Sep 17 00:00:00 2001 From: lyuwenyu Date: Thu, 25 Mar 2021 14:38:31 +0800 Subject: [PATCH] add alexnet vgg resnet series --- hubconf.py | 119 +++++++++++++++++++++++++++++++++++++++++++++++++++-- 1 file changed, 116 insertions(+), 3 deletions(-) diff --git a/hubconf.py b/hubconf.py index 398fffe7..d348bccc 100644 --- a/hubconf.py +++ b/hubconf.py @@ -3,9 +3,9 @@ dependencies = ['paddle', 'numpy'] import paddle +from ppcls.modeling.architectures import alexnet as _alexnet +from ppcls.modeling.architectures import vgg as _vgg from ppcls.modeling.architectures import resnet as _resnet - - # _checkpoints = { # 'ResNet18': 'https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/ResNet18_pretrained.pdparams', # 'ResNet34': 'https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/ResNet34_pretrained.pdparams', @@ -37,6 +37,80 @@ def _load_pretrained_urls(): _checkpoints = _load_pretrained_urls() + +def AlexNet(**kwargs): + '''AlexNet + ''' + pretrained = kwargs.pop('pretrained', False) + + model = _alexnet.AlexNet(**kwargs) + if pretrained: + assert 'AlexNet' in _checkpoints, 'Not provide `AlexNet` pretrained model.' + path = paddle.utils.download.get_weights_path_from_url(_checkpoints['AlexNet']) + model.set_state_dict(paddle.load(path)) + + return model + + + +def VGG11(**kwargs): + '''VGG11 + ''' + pretrained = kwargs.pop('pretrained', False) + + model = _vgg.VGG11(**kwargs) + if pretrained: + assert 'VGG11' in _checkpoints, 'Not provide `VGG11` pretrained model.' + path = paddle.utils.download.get_weights_path_from_url(_checkpoints['VGG11']) + model.set_state_dict(paddle.load(path)) + + return model + + +def VGG13(**kwargs): + '''VGG13 + ''' + pretrained = kwargs.pop('pretrained', False) + + model = _vgg.VGG13(**kwargs) + if pretrained: + assert 'VGG13' in _checkpoints, 'Not provide `VGG13` pretrained model.' + path = paddle.utils.download.get_weights_path_from_url(_checkpoints['VGG13']) + model.set_state_dict(paddle.load(path)) + + return model + + +def VGG16(**kwargs): + '''VGG16 + ''' + pretrained = kwargs.pop('pretrained', False) + + model = _vgg.VGG16(**kwargs) + if pretrained: + assert 'VGG16' in _checkpoints, 'Not provide `VGG16` pretrained model.' + path = paddle.utils.download.get_weights_path_from_url(_checkpoints['VGG16']) + model.set_state_dict(paddle.load(path)) + + return model + + +def VGG19(**kwargs): + '''VGG19 + ''' + pretrained = kwargs.pop('pretrained', False) + + model = _vgg.VGG19(**kwargs) + if pretrained: + assert 'VGG19' in _checkpoints, 'Not provide `VGG19` pretrained model.' + path = paddle.utils.download.get_weights_path_from_url(_checkpoints['VGG19']) + model.set_state_dict(paddle.load(path)) + + return model + + + + def ResNet18(**kwargs): '''ResNet18 ''' @@ -51,7 +125,6 @@ def ResNet18(**kwargs): return model - def ResNet34(**kwargs): '''ResNet34 ''' @@ -66,3 +139,43 @@ def ResNet34(**kwargs): return model +def ResNet50(**kwargs): + '''ResNet50 + ''' + pretrained = kwargs.pop('pretrained', False) + + model = _resnet.ResNet50(**kwargs) + if pretrained: + assert 'ResNet50' in _checkpoints, 'Not provide `ResNet50` pretrained model.' + path = paddle.utils.download.get_weights_path_from_url(_checkpoints['ResNet50']) + model.set_state_dict(paddle.load(path)) + + return model + + +def ResNet101(**kwargs): + '''ResNet101 + ''' + pretrained = kwargs.pop('pretrained', False) + + model = _resnet.ResNet101(**kwargs) + if pretrained: + assert 'ResNet101' in _checkpoints, 'Not provide `ResNet101` pretrained model.' + path = paddle.utils.download.get_weights_path_from_url(_checkpoints['ResNet101']) + model.set_state_dict(paddle.load(path)) + + return model + + +def ResNet152(**kwargs): + '''ResNet152 + ''' + pretrained = kwargs.pop('pretrained', False) + + model = _resnet.ResNet152(**kwargs) + if pretrained: + assert 'ResNet152' in _checkpoints, 'Not provide `ResNet152` pretrained model.' + path = paddle.utils.download.get_weights_path_from_url(_checkpoints['ResNet152']) + model.set_state_dict(paddle.load(path)) + + return model -- GitLab