server.py 3.7 KB
Newer Older
W
wuzewu 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16
#coding:utf-8
# Copyright (c) 2020  PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from collections import OrderedDict
W
wuzewu 已提交
17
from typing import List
W
wuzewu 已提交
18 19

from paddlehub.server import ServerSource, GitSource
W
wuzewu 已提交
20
from paddlehub.utils import utils
W
wuzewu 已提交
21 22 23 24 25

# URL of the public PaddleHub server; the module-level `module_server` registers it as a 'server' source.
PADDLEHUB_PUBLIC_SERVER = 'http://paddlepaddle.org.cn/paddlehub'


class HubServer(object):
    '''
    PaddleHub server.

    Keeps an ordered registry of module sources (GitSource or ServerSource) and
    queries them in registration order when looking up resources.
    '''

    def __init__(self):
        # Maps the key derived from a source URL (see `_get_source_key`) to the
        # source object. Insertion order is the order in which sources are searched.
        self.sources = OrderedDict()

    def _generate_source(self, url: str, source_type: str = 'git'):
        '''
        Create a source object for `url`.

        Args:
            url(str) : Address of the source.
            source_type(str) : 'server' for a ServerSource, 'git' for a GitSource.

        Returns:
            The newly created source object.

        Raises:
            RuntimeError: if `source_type` is neither 'server' nor 'git'.
        '''
        if source_type == 'server':
            return ServerSource(url)
        if source_type == 'git':
            return GitSource(url)
        raise RuntimeError('Unknown source type {}.'.format(source_type))

    def _get_source_key(self, url: str):
        '''Return the registry key for `url`: 'source_' followed by the md5 of the URL.'''
        return 'source_{}'.format(utils.md5(url))

    def add_source(self, url: str, source_type: str = 'git'):
        '''
        Add a module source (GitSource or ServerSource).

        Registering the same URL again replaces the previous source object.

        Args:
            url(str) : Address of the source.
            source_type(str) : 'server' or 'git' (default).
        '''
        key = self._get_source_key(url)
        self.sources[key] = self._generate_source(url, source_type)

    def remove_source(self, url: str = None, key: str = None):
        '''
        Remove a module source.

        Args:
            url(str) : URL the source was registered with; used to derive the
                registry key when `key` is not given.
            key(str) : Registry key of the source. Takes precedence over `url`.

        Raises:
            ValueError: if neither `url` nor `key` is given.
            KeyError: if no source is registered under the resolved key.
        '''
        if key is None:
            # Previously `url` was ignored entirely, so removing by URL always
            # failed with KeyError(None).
            if url is None:
                raise ValueError('Either url or key must be specified.')
            key = self._get_source_key(url)
        self.sources.pop(key)

    def get_source(self, url: str):
        '''
        Return the source registered for `url`, or None if there is none.

        Args:
            url(str) : URL the source was registered with.
        '''
        key = self._get_source_key(url)
        return self.sources.get(key, None)

    def search_module(self,
                      name: str,
                      version: str = None,
                      source: str = None,
                      update: bool = False,
                      branch: str = None) -> List[dict]:
        '''
        Search PaddleHub module

        Args:
            name(str) : PaddleHub module name
            version(str) : PaddleHub module version
            source(str) : Optional source URL to search instead of the registered
                sources (treated as a git source).
            update(bool) : If True, update each git source before searching.
            branch(str) : If given, check out this branch on each git source
                before searching.

        Returns:
            The first non-empty result returned by a source, or [] if none matched.
        '''
        return self.search_resource(
            type='module', name=name, version=version, source=source, update=update, branch=branch)

    def search_resource(self,
                        type: str,
                        name: str,
                        version: str = None,
                        source: str = None,
                        update: bool = False,
                        branch: str = None) -> List[dict]:
        '''
        Search PaddleHub Resource

        Args:
            type(str) : Resource type (the name shadows the builtin but is kept
                for caller compatibility).
            name(str) : Resource name
            version(str) : Resource version
            source(str) : Optional source URL to search instead of the registered
                sources (treated as a git source).
            update(bool) : If True, call `update()` on each git source before searching.
            branch(str) : If given, call `checkout(branch)` on each git source
                before searching.

        Returns:
            The first non-empty result returned by a source, or [] if none matched.
        '''
        if source:
            # An explicit URL bypasses the registry; `_generate_source` defaults to a git source.
            sources = [self._generate_source(source)]
        else:
            sources = self.sources.values()

        for src in sources:
            if isinstance(src, GitSource):
                if update:
                    src.update()
                if branch:
                    src.checkout(branch)

            result = src.search_resource(name=name, type=type, version=version)
            if result:
                return result
        return []

    def get_module_info(self, name: str, source: str = None) -> dict:
        '''
        Get information about a module.

        Args:
            name(str) : PaddleHub module name
            source(str) : Optional source URL to query instead of the registered
                sources (treated as a git source).

        Returns:
            The first non-empty info returned by a source, or {} if none matched.
        '''
        if source:
            sources = [self._generate_source(source)]
        else:
            sources = self.sources.values()

        for src in sources:
            result = src.get_module_info(name=name)
            if result:
                return result
        return {}
W
wuzewu 已提交
110 111 112


# Shared HubServer instance created at import time.
module_server = HubServer()
# Pre-register the public PaddleHub server as a 'server' source.
module_server.add_source(url=PADDLEHUB_PUBLIC_SERVER, source_type='server')