server.py 4.1 KB
Newer Older
W
wuzewu 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16
#coding:utf-8
# Copyright (c) 2020  PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from collections import OrderedDict
W
wuzewu 已提交
17
from typing import List
W
wuzewu 已提交
18

W
wuzewu 已提交
19
import paddlehub.config as hubconf
W
wuzewu 已提交
20
from paddlehub.server import ServerSource, GitSource
W
wuzewu 已提交
21
from paddlehub.utils import utils
W
wuzewu 已提交
22 23 24


class HubServer(object):
    '''
    PaddleHub server: a registry of module sources (GitSource or ServerSource).

    Registered sources are kept in insertion order and are queried in that
    order; the first source that returns a non-empty result wins.
    '''

    def __init__(self):
        # key -> source object, in registration order (this is the search order).
        self.sources = OrderedDict()
        # url -> key, so a source can be looked up / removed by the url it was added with.
        self.keysmap = OrderedDict()

    def _generate_source(self, url: str, source_type: str = 'git'):
        '''
        Create a source object for `url`.

        Args:
            url(str) : Address of the source.
            source_type(str) : 'git' to build a GitSource, 'server' to build a ServerSource.

        Returns:
            The newly created source object.

        Raises:
            ValueError: if `source_type` is neither 'git' nor 'server'.
        '''
        if source_type == 'server':
            return ServerSource(url)
        if source_type == 'git':
            return GitSource(url)
        raise ValueError('Unknown source type {}.'.format(source_type))

    def _get_source_key(self, url: str):
        '''Build the default key for `url` from the `utils.md5` hash of the url.'''
        return 'source_{}'.format(utils.md5(url))

    def add_source(self, url: str, source_type: str = 'git', key: str = ''):
        '''
        Add a module source (GitSource or ServerSource).

        Args:
            url(str) : Address of the source.
            source_type(str) : 'git' or 'server'.
            key(str) : Key to register the source under. When empty, a key
                derived from `url` is used. An existing source with the same
                key is replaced.
        '''
        key = key if key else self._get_source_key(url)
        self.keysmap[url] = key
        self.sources[key] = self._generate_source(url, source_type)

    def remove_source(self, url: str = None, key: str = None):
        '''
        Remove a module source, identified either by its key or by its url.

        Args:
            url(str) : Url the source was registered with. Only consulted when
                `key` is not given.
            key(str) : Key the source was registered under.

        Raises:
            KeyError: if no registered source matches the given key / url.
        '''
        if not key and url:
            key = self.keysmap.get(url)
        if key not in self.sources:
            raise KeyError('No module source registered for url={!r}, key={!r}.'.format(url, key))
        self.sources.pop(key)
        # Drop every url mapping that pointed at the removed key so that
        # get_source() does not keep resolving to a stale key.
        for stale_url in [u for u, k in self.keysmap.items() if k == key]:
            del self.keysmap[stale_url]

    def get_source(self, url: str):
        '''Get a module source by url; returns None if the url is not registered.'''
        key = self.keysmap.get(url)
        if not key:
            return None
        return self.sources.get(key)

    def get_source_by_key(self, key: str):
        '''Get a module source by key; returns None if the key is not registered.'''
        return self.sources.get(key)

    def search_module(self,
                      name: str,
                      version: str = None,
                      source: str = None,
                      update: bool = False,
                      branch: str = None) -> List[dict]:
        '''
        Search PaddleHub module

        Args:
            name(str) : PaddleHub module name
            version(str) : PaddleHub module version
            source(str) : Optional source url to search instead of the registered sources
            update(bool) : Whether to update git sources before searching
            branch(str) : Optional branch to check out on git sources before searching

        Returns:
            The first non-empty search result, or an empty list.
        '''
        return self.search_resource(
            type='module', name=name, version=version, source=source, update=update, branch=branch)

    def search_resource(self,
                        type: str,
                        name: str,
                        version: str = None,
                        source: str = None,
                        update: bool = False,
                        branch: str = None) -> List[dict]:
        '''
        Search PaddleHub Resource

        Args:
            type(str) : Resource type
            name(str) : Resource name
            version(str) : Resource version
            source(str) : Optional source url. When given, only this source is
                searched, and it is built as a git source.
            update(bool) : Whether to call update() on git sources before searching
            branch(str) : Optional branch to check out on git sources before searching

        Returns:
            The first non-empty result from the searched sources, or an empty list.
        '''
        if source:
            # An explicit source url bypasses the registered sources.
            candidates = [self._generate_source(source)]
        else:
            candidates = self.sources.values()

        for src in candidates:
            if isinstance(src, GitSource):
                if update:
                    src.update()
                if branch:
                    src.checkout(branch)

            result = src.search_resource(name=name, type=type, version=version)
            if result:
                return result
        return []

    def get_module_compat_info(self, name: str, source: str = None) -> dict:
        '''
        Get the version compatibility information of the model.

        Args:
            name(str) : Module name
            source(str) : Optional source url (built as a git source) to query
                instead of the registered sources

        Returns:
            The first non-empty result from the queried sources, or an empty dict.
        '''
        if source:
            candidates = [self._generate_source(source)]
        else:
            candidates = self.sources.values()

        for src in candidates:
            result = src.get_module_compat_info(name=name)
            if result:
                return result
        return {}


# Process-wide server instance, pre-registered with the configured PaddleHub
# server address (hubconf.server) as a 'server' source keyed 'default_hub_server'.
module_server = HubServer()
module_server.add_source(url=hubconf.server, key='default_hub_server', source_type='server')