提交 be093813 编写于 作者: W wuzewu

Add paddlex utils

上级 66081648
......@@ -20,6 +20,8 @@ from easydict import EasyDict
__version__ = '2.0.0a0'
from paddlehub.utils import log, parser, utils
from paddlehub.utils.paddlex import download, ResourceNotFoundError
from paddlehub.server.server_source import ServerConnectionError
from paddlehub.module import Module
# In order to maintain the compatibility of the old version, we put the relevant
......
......@@ -22,6 +22,15 @@ import paddlehub
from paddlehub.utils import utils
class ServerConnectionError(Exception):
def __init__(self, url: str):
self.url = url
def __str__(self):
tips = 'Can\'t connect to Hub Server: {}'.format(self.url)
return tips
class ServerSource(object):
'''
PaddleHub server source
......@@ -30,6 +39,7 @@ class ServerSource(object):
url(str) : Url of the server
timeout(int) : Request timeout
'''
def __init__(self, url: str, timeout: int = 10):
self._url = url
self._timeout = timeout
......@@ -71,14 +81,20 @@ class ServerSource(object):
payload['environments']['platform_type'] = platform.platform()
api = '{}/search'.format(self._url)
result = requests.get(api, payload, timeout=self._timeout)
result = result.json()
if result['status'] == 0 and len(result['data']) > 0:
for item in result['data']:
if name.lower() == item['name'].lower() and utils.Version(item['version']).match(version):
return item
return None
try:
result = requests.get(api, payload, timeout=self._timeout)
result = result.json()
if result['status'] == 0 and len(result['data']) > 0:
for item in result['data']:
if name.lower() == item['name'].lower() and utils.Version(item['version']).match(version):
return item
else:
print(result)
return None
except requests.exceptions.ConnectionError as e:
raise ServerConnectionError(self._url)
@classmethod
def check(cls, url: str) -> bool:
......
# coding:utf-8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import shutil
from paddlehub.server.server import module_server
from paddlehub.utils import log, utils, xarfile
class ResourceNotFoundError(Exception):
def __init__(self, name: str, version: str = None):
self.name = name
self.version = version
def __str__(self):
if not self.version:
tips = 'No resource named {} was found'.format(self.name)
else:
tips = 'No resource named {}-{} was found'.format(self.name, self.version)
return tips
def download(name: str, save_path: str, version: str = None):
'''
'''
file = os.path.join(save_path, name)
file = os.path.realpath(file)
if os.path.exists(file):
return
resource = module_server.search_resouce(name=name, version=version, type='Model')
if not resource:
raise ResourceNotFoundError(name, version)
url = resource['url']
with utils.generate_tempdir() as _dir:
if not os.path.exists(save_path):
os.makedirs(save_path)
with log.ProgressBar('Download {}'.format(url)) as _bar:
for savefile, dsize, tsize in utils.download_with_progress(url, _dir):
_bar.update(float(dsize / tsize))
if xarfile.is_xarfile(savefile):
with log.ProgressBar('Decompress {}'.format(savefile)) as _bar:
for savefile, usize, tsize in xarfile.unarchive_with_progress(savefile, _dir):
_bar.update(float(usize / tsize))
savefile = os.path.join(_dir, savefile.split(os.sep)[0])
shutil.move(savefile, file)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册