提交 5053bec7 编写于 作者: W wuzewu

add download api for paddlex

上级 b9107194
...@@ -38,6 +38,7 @@ from .common.logger import logger ...@@ -38,6 +38,7 @@ from .common.logger import logger
from .common.paddle_helper import connect_program from .common.paddle_helper import connect_program
from .common.hub_server import HubServer from .common.hub_server import HubServer
from .common.hub_server import server_check from .common.hub_server import server_check
from .common.downloader import download, ResourceNotFoundError, ServerConnectionError
from .module.module import Module from .module.module import Module
from .module.base_processor import BaseProcessor from .module.base_processor import BaseProcessor
......
#coding:utf-8 # coding:utf-8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. # Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
# #
# Licensed under the Apache License, Version 2.0 (the "License" # Licensed under the Apache License, Version 2.0 (the 'License'
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
# You may obtain a copy of the License at # You may obtain a copy of the License at
# #
# http://www.apache.org/licenses/LICENSE-2.0 # http://www.apache.org/licenses/LICENSE-2.0
# #
# Unless required by applicable law or agreed to in writing, software # Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, # distributed under the License is distributed on an 'AS IS' BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
...@@ -28,6 +28,8 @@ import tarfile ...@@ -28,6 +28,8 @@ import tarfile
from paddlehub.common import utils from paddlehub.common import utils
from paddlehub.common.logger import logger from paddlehub.common.logger import logger
from paddlehub.common import tmp_dir
import paddlehub as hub
__all__ = ['Downloader', 'progress'] __all__ = ['Downloader', 'progress']
FLUSH_INTERVAL = 0.1 FLUSH_INTERVAL = 0.1
...@@ -38,10 +40,10 @@ lasttime = time.time() ...@@ -38,10 +40,10 @@ lasttime = time.time()
def progress(str, end=False): def progress(str, end=False):
global lasttime global lasttime
if end: if end:
str += "\n" str += '\n'
lasttime = 0 lasttime = 0
if time.time() - lasttime >= FLUSH_INTERVAL: if time.time() - lasttime >= FLUSH_INTERVAL:
sys.stdout.write("\r%s" % str) sys.stdout.write('\r%s' % str)
lasttime = time.time() lasttime = time.time()
sys.stdout.flush() sys.stdout.flush()
...@@ -67,7 +69,7 @@ class Downloader(object): ...@@ -67,7 +69,7 @@ class Downloader(object):
if retry_times < retry_limit: if retry_times < retry_limit:
retry_times += 1 retry_times += 1
else: else:
tips = "Cannot download {0} within retry limit {1}".format( tips = 'Cannot download {0} within retry limit {1}'.format(
url, retry_limit) url, retry_limit)
return False, tips, None return False, tips, None
r = requests.get(url, stream=True) r = requests.get(url, stream=True)
...@@ -82,19 +84,19 @@ class Downloader(object): ...@@ -82,19 +84,19 @@ class Downloader(object):
total_length = int(total_length) total_length = int(total_length)
starttime = time.time() starttime = time.time()
if print_progress: if print_progress:
print("Downloading %s" % save_name) print('Downloading %s' % save_name)
for data in r.iter_content(chunk_size=4096): for data in r.iter_content(chunk_size=4096):
dl += len(data) dl += len(data)
f.write(data) f.write(data)
if print_progress: if print_progress:
done = int(50 * dl / total_length) done = int(50 * dl / total_length)
progress( progress(
"[%-50s] %.2f%%" % '[%-50s] %.2f%%' %
('=' * done, float(dl / total_length * 100))) ('=' * done, float(dl / total_length * 100)))
if print_progress: if print_progress:
progress("[%-50s] %.2f%%" % ('=' * 50, 100), end=True) progress('[%-50s] %.2f%%' % ('=' * 50, 100), end=True)
tips = "File %s download completed!" % (file_name) tips = 'File %s download completed!' % (file_name)
return True, tips, file_name return True, tips, file_name
def uncompress(self, def uncompress(self,
...@@ -104,24 +106,25 @@ class Downloader(object): ...@@ -104,24 +106,25 @@ class Downloader(object):
print_progress=False): print_progress=False):
dirname = os.path.dirname(file) if dirname is None else dirname dirname = os.path.dirname(file) if dirname is None else dirname
if print_progress: if print_progress:
print("Uncompress %s" % file) print('Uncompress %s' % file)
with tarfile.open(file, "r:gz") as tar:
with tarfile.open(file, 'r:*') as tar:
file_names = tar.getnames() file_names = tar.getnames()
size = len(file_names) - 1 size = len(file_names) - 1
module_dir = os.path.join(dirname, file_names[0]) module_dir = os.path.join(dirname, file_names[0])
for index, file_name in enumerate(file_names): for index, file_name in enumerate(file_names):
if print_progress: if print_progress:
done = int(50 * float(index) / size) done = int(50 * float(index) / size)
progress("[%-50s] %.2f%%" % ('=' * done, progress('[%-50s] %.2f%%' % ('=' * done,
float(index / size * 100))) float(index / size * 100)))
tar.extract(file_name, dirname) tar.extract(file_name, dirname)
if print_progress: if print_progress:
progress("[%-50s] %.2f%%" % ('=' * 50, 100), end=True) progress('[%-50s] %.2f%%' % ('=' * 50, 100), end=True)
if delete_file: if delete_file:
os.remove(file) os.remove(file)
return True, "File %s uncompress completed!" % file, module_dir return True, 'File %s uncompress completed!' % file, module_dir
def download_file_and_uncompress(self, def download_file_and_uncompress(self,
url, url,
...@@ -147,8 +150,59 @@ class Downloader(object): ...@@ -147,8 +150,59 @@ class Downloader(object):
if save_name: if save_name:
save_name = os.path.join(save_path, save_name) save_name = os.path.join(save_path, save_name)
shutil.move(file, save_name) shutil.move(file, save_name)
return result, "%s\n%s" % (tips_1, tips_2), save_name return result, '%s\n%s' % (tips_1, tips_2), save_name
return result, "%s\n%s" % (tips_1, tips_2), file return result, '%s\n%s' % (tips_1, tips_2), file
default_downloader = Downloader() default_downloader = Downloader()
class ResourceNotFoundError(Exception):
def __init__(self, name, version=None):
self.name = name
self.version = version
def __str__(self):
if self.version:
tips = 'No resource named {} was found'.format(self.name)
else:
tips = 'No resource named {}-{} was found'.format(
self.name, self.version)
return tips
class ServerConnectionError(Exception):
def __str__(self):
tips = 'Can\'t connect to Hub Server:{}'.format(
hub.HubServer().server_url[0])
return tips
def download(name,
save_path,
version=None,
decompress=True,
resource_type='Model',
extra=None):
if not hub.HubServer()._server_check():
raise ServerConnectionError
search_result = hub.HubServer().get_resource_url(
name, resource_type=resource_type, version=version, extra=extra)
if not search_result:
raise ResourceNotFoundError(name, version)
url = search_result['url']
file = os.path.join(save_path, name)
file = os.path.realpath(file)
if os.path.exists(file):
return
with tmp_dir() as _dir:
_, _, savefile = default_downloader.download_file(
url=url, save_path=_dir, print_progress=True)
if tarfile.is_tarfile(savefile) and decompress:
_, _, savefile = default_downloader.uncompress(
file=savefile, print_progress=True)
shutil.move(savefile, file)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册