hubconf.py 894 字节
Newer Older
L
for hub  
lyuwenyu 已提交
1

L
lyuwenyu 已提交
2
# Third-party packages this hub module relies on; presumably checked by the
# paddle.hub loader before entry points are imported (torch.hub convention).
dependencies = [
    'paddle',
    'numpy',
]
L
for hub  
lyuwenyu 已提交
3

L
lyuwenyu 已提交
4
import paddle
L
for hub  
lyuwenyu 已提交
5

L
lyuwenyu 已提交
6 7 8
from ppcls.modeling.architectures.resnet import ResNet18 as _ResNet18
from ppcls.modeling.architectures.resnet import ResNet34 as _ResNet34
from ppcls.modeling.architectures.resnet import ResNet50 as _ResNet50
L
for hub  
lyuwenyu 已提交
9 10


L
lyuwenyu 已提交
11 12 13 14 15
_checkpoints = {
    'ResNet18': 'https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/ResNet18_pretrained.pdparams'
}


L
lyuwenyu 已提交
16 17 18

def ResNet18(**kwargs):
    """Build a ResNet18 network (paddle.hub entry point).

    Args:
        **kwargs: Forwarded unchanged to ppcls' ResNet18 constructor, except
            ``pretrained`` (default False), which is consumed here. When it is
            truthy, the weights file at ``_checkpoints['ResNet18']`` is fetched
            with ``paddle.utils.download.get_weights_path_from_url`` and loaded
            into the network.

    Returns:
        The constructed ResNet18 model.
    """
    load_weights = kwargs.pop('pretrained', False)
    net = _ResNet18(**kwargs)
    if not load_weights:
        return net

    weights_url = _checkpoints['ResNet18']
    local_file = paddle.utils.download.get_weights_path_from_url(weights_url)
    net.set_state_dict(paddle.load(local_file))
    return net

L
lyuwenyu 已提交
29 30 31 32 33


def ResNet34(**kwargs):
    """Build a ResNet34 network (paddle.hub entry point).

    Args:
        **kwargs: Forwarded to ppcls' ResNet34 constructor, except
            ``pretrained`` (default False), which is consumed here so it is
            not passed to the constructor. When it is truthy, pretrained
            weights are downloaded and loaded into the network.

    Returns:
        The constructed ResNet34 model.
    """
    # Consume ``pretrained`` here, as ResNet18 does, instead of forwarding it
    # to the architecture constructor.
    pretrained = kwargs.pop('pretrained', False)
    model = _ResNet34(**kwargs)
    if pretrained:
        # NOTE(review): URL follows the same naming pattern as the ResNet18
        # checkpoint; confirm it stays published alongside it.
        url = 'https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/ResNet34_pretrained.pdparams'
        path = paddle.utils.download.get_weights_path_from_url(url)
        model.set_state_dict(paddle.load(path))
    return model



def ResNet50(**kwargs):
    """Build a ResNet50 network (paddle.hub entry point).

    Args:
        **kwargs: Forwarded to ppcls' ResNet50 constructor, except
            ``pretrained`` (default False), which is consumed here so it is
            not passed to the constructor. When it is truthy, pretrained
            weights are downloaded and loaded into the network.

    Returns:
        The constructed ResNet50 model.
    """
    # Consume ``pretrained`` here, as ResNet18 does, instead of forwarding it
    # to the architecture constructor.
    pretrained = kwargs.pop('pretrained', False)
    model = _ResNet50(**kwargs)
    if pretrained:
        # NOTE(review): URL follows the same naming pattern as the ResNet18
        # checkpoint; confirm it stays published alongside it.
        url = 'https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/ResNet50_pretrained.pdparams'
        path = paddle.utils.download.get_weights_path_from_url(url)
        model.set_state_dict(paddle.load(path))
    return model