diff --git a/paddlehub/__init__.py b/paddlehub/__init__.py index e0ca307efa8404b88f752483b1c7ff574c53e1da..d04200f22214228d41fb975fa0ae3ab9da4b291a 100644 --- a/paddlehub/__init__.py +++ b/paddlehub/__init__.py @@ -20,6 +20,8 @@ from easydict import EasyDict __version__ = '2.0.0a0' from paddlehub.utils import log, parser, utils +from paddlehub.utils.paddlex import download, ResourceNotFoundError +from paddlehub.server.server_source import ServerConnectionError from paddlehub.module import Module # In order to maintain the compatibility of the old version, we put the relevant diff --git a/paddlehub/server/server_source.py b/paddlehub/server/server_source.py index 256c39582a5469bfce8702c40bcee42e795372e4..212c3a6d1614df043fe600db6e9e25c90467c3c0 100644 --- a/paddlehub/server/server_source.py +++ b/paddlehub/server/server_source.py @@ -22,6 +22,15 @@ import paddlehub from paddlehub.utils import utils +class ServerConnectionError(Exception): + def __init__(self, url: str): + self.url = url + + def __str__(self): + tips = 'Can\'t connect to Hub Server: {}'.format(self.url) + return tips + + class ServerSource(object): ''' PaddleHub server source @@ -30,6 +39,7 @@ class ServerSource(object): url(str) : Url of the server timeout(int) : Request timeout ''' + def __init__(self, url: str, timeout: int = 10): self._url = url self._timeout = timeout @@ -71,14 +81,20 @@ class ServerSource(object): payload['environments']['platform_type'] = platform.platform() api = '{}/search'.format(self._url) - result = requests.get(api, payload, timeout=self._timeout) - result = result.json() - - if result['status'] == 0 and len(result['data']) > 0: - for item in result['data']: - if name.lower() == item['name'].lower() and utils.Version(item['version']).match(version): - return item - return None + + try: + result = requests.get(api, payload, timeout=self._timeout) + result = result.json() + + if result['status'] == 0 and len(result['data']) > 0: + for item in result['data']: + if name.lower() == item['name'].lower() and utils.Version(item['version']).match(version): + return item + else: + print(result) + return None + except requests.exceptions.ConnectionError as e: + raise ServerConnectionError(self._url) @classmethod def check(cls, url: str) -> bool: diff --git a/paddlehub/utils/paddlex.py b/paddlehub/utils/paddlex.py new file mode 100644 index 0000000000000000000000000000000000000000..bdc3467aebd2715ff898ac53313b35a3836c20b6 --- /dev/null +++ b/paddlehub/utils/paddlex.py @@ -0,0 +1,64 @@ +# coding:utf-8 +# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License" +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import shutil + +from paddlehub.server.server import module_server +from paddlehub.utils import log, utils, xarfile + + +class ResourceNotFoundError(Exception): + def __init__(self, name: str, version: str = None): + self.name = name + self.version = version + + def __str__(self): + if not self.version: + tips = 'No resource named {} was found'.format(self.name) + else: + tips = 'No resource named {}-{} was found'.format(self.name, self.version) + return tips + + +def download(name: str, save_path: str, version: str = None): + ''' + ''' + file = os.path.join(save_path, name) + file = os.path.realpath(file) + if os.path.exists(file): + return + + resource = module_server.search_resouce(name=name, version=version, type='Model') + if not resource: + raise ResourceNotFoundError(name, version) + + url = resource['url'] + with utils.generate_tempdir() as _dir: + if not os.path.exists(save_path): + os.makedirs(save_path) + + with log.ProgressBar('Download {}'.format(url)) as _bar: + for savefile, dsize, tsize in utils.download_with_progress(url, _dir): + _bar.update(float(dsize / tsize)) + + if xarfile.is_xarfile(savefile): + with log.ProgressBar('Decompress {}'.format(savefile)) as _bar: + for savefile, usize, tsize in xarfile.unarchive_with_progress(savefile, _dir): + _bar.update(float(usize / tsize)) + + savefile = os.path.join(_dir, savefile.split(os.sep)[0]) + + shutil.move(savefile, file)