hubconf.py 1.8 KB
Newer Older
L
for hub  
lyuwenyu 已提交
1

L
lyuwenyu 已提交
2
# Third-party packages this hub entry file needs. Presumably paddle.hub
# checks that these are importable before loading the entry points below
# (hubconf convention) -- confirm against the paddle.hub version in use.
dependencies = ['paddle', 'numpy']
L
for hub  
lyuwenyu 已提交
3

L
lyuwenyu 已提交
4
import paddle
L
for hub  
lyuwenyu 已提交
5

L
lyuwenyu 已提交
6
from ppcls.modeling.architectures import resnet as _resnet 
L
for hub  
lyuwenyu 已提交
7 8


L
lyuwenyu 已提交
9 10 11 12
# _checkpoints = {
#     'ResNet18': 'https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/ResNet18_pretrained.pdparams',
#     'ResNet34': 'https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/ResNet34_pretrained.pdparams',
# }
L
lyuwenyu 已提交
13

L
lyuwenyu 已提交
14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37
def _load_pretrained_urls(path=None):
    '''Collect pretrained-weight download URLs from the repository README.

    Scans markdown table rows (lines whose stripped text starts with ``|``)
    that contain the text ``Download link``. For each such row, the first
    table cell (with any ``<br>`` removed) is taken as the model name, and
    the last parenthesised span (the markdown link target) as the URL. A
    row is kept only when the model name is non-empty and appears inside
    the URL; this filters out links that belong to some other model.

    Args:
        path (str, optional): README file to parse. Defaults to the
            ``README.md`` next to this file, so the result does not depend
            on the current working directory.

    Returns:
        OrderedDict: model name -> URL, in README order. The dict is empty
        when the README does not exist, so importing this module (and
        building models without ``pretrained=True``) still works.
    '''
    import os
    import re
    from collections import OrderedDict

    if path is None:
        # Resolve relative to this file rather than the CWD: the function
        # runs at import time, and callers may import from anywhere.
        path = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'README.md')

    urls = OrderedDict()
    try:
        with open(path, 'r', encoding='utf-8') as f:
            lines = [lin for lin in f if lin.strip().startswith('|') and 'Download link' in lin]
    except FileNotFoundError:
        # No README: no pretrained URLs are known. The entry points report
        # this themselves if pretrained weights are requested.
        return urls

    cell_re = re.compile(r'\|(.*?)\|')
    link_re = re.compile(r'\((.*?)\)')
    for lin in lines:
        cells = cell_re.findall(lin)
        links = link_re.findall(lin)
        if not cells or not links:
            # Not a "| name | ... | [text](url) |" row (e.g. a header row).
            continue
        name = cells[0].strip().replace('<br>', '')
        url = links[-1].strip()
        # An empty name would trivially match every URL, so skip it.
        if name and name in url:
            urls[name] = url

    return urls


# Computed once, when this module is imported: a mapping of model name ->
# pretrained-weights URL, read by the ResNet* entry points when called
# with pretrained=True.
_checkpoints = _load_pretrained_urls()
L
lyuwenyu 已提交
38

L
lyuwenyu 已提交
39 40 41

def ResNet18(**kwargs):
    '''Hub entry point that builds a ResNet18 network.

    Args:
        pretrained (bool, optional): popped from ``kwargs`` before the
            network is built. When truthy, the weights whose URL is listed
            under ``'ResNet18'`` in ``_checkpoints`` are fetched with
            ``paddle.utils.download.get_weights_path_from_url`` and loaded
            via ``set_state_dict``. Defaults to False.
        **kwargs: every other keyword argument is forwarded unchanged to
            ``_resnet.ResNet18``.

    Returns:
        The network built by ``_resnet.ResNet18``.

    Raises:
        AssertionError: if pretrained weights are requested but no
            ``ResNet18`` URL was found.
    '''
    load_weights = kwargs.pop('pretrained', False)
    net = _resnet.ResNet18(**kwargs)
    if not load_weights:
        return net

    assert 'ResNet18' in _checkpoints, 'Not provide `ResNet18` pretrained model.'
    weights_file = paddle.utils.download.get_weights_path_from_url(_checkpoints['ResNet18'])
    net.set_state_dict(paddle.load(weights_file))
    return net

L
lyuwenyu 已提交
53 54 55 56 57


def ResNet34(**kwargs):
    '''Hub entry point that builds a ResNet34 network.

    Args:
        pretrained (bool, optional): popped from ``kwargs`` before the
            network is built. When truthy, the weights whose URL is listed
            under ``'ResNet34'`` in ``_checkpoints`` are fetched with
            ``paddle.utils.download.get_weights_path_from_url`` and loaded
            via ``set_state_dict``. Defaults to False.
        **kwargs: every other keyword argument is forwarded unchanged to
            ``_resnet.ResNet34``.

    Returns:
        The network built by ``_resnet.ResNet34``.

    Raises:
        AssertionError: if pretrained weights are requested but no
            ``ResNet34`` URL was found.
    '''
    load_weights = kwargs.pop('pretrained', False)
    net = _resnet.ResNet34(**kwargs)
    if not load_weights:
        return net

    assert 'ResNet34' in _checkpoints, 'Not provide `ResNet34` pretrained model.'
    weights_file = paddle.utils.download.get_weights_path_from_url(_checkpoints['ResNet34'])
    net.set_state_dict(paddle.load(weights_file))
    return net
L
lyuwenyu 已提交
67 68