提交 5053bec7 编写于 作者: W wuzewu

add download api for paddlex

上级 b9107194
......@@ -38,6 +38,7 @@ from .common.logger import logger
from .common.paddle_helper import connect_program
from .common.hub_server import HubServer
from .common.hub_server import server_check
from .common.downloader import download, ResourceNotFoundError, ServerConnectionError
from .module.module import Module
from .module.base_processor import BaseProcessor
......
#coding:utf-8
# coding:utf-8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"
# Licensed under the Apache License, Version 2.0 (the 'License'
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# distributed under the License is distributed on an 'AS IS' BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
......@@ -28,6 +28,8 @@ import tarfile
from paddlehub.common import utils
from paddlehub.common.logger import logger
from paddlehub.common import tmp_dir
import paddlehub as hub
__all__ = ['Downloader', 'progress']
FLUSH_INTERVAL = 0.1
......@@ -38,10 +40,10 @@ lasttime = time.time()
def progress(str, end=False):
    """Write a throttled, carriage-return-prefixed status line to stdout.

    Args:
        str: Text to display. The parameter name shadows the builtin but is
            kept because callers may pass it by keyword.
        end: When true, a newline is appended and the throttle is reset so
            the final line is always written.
    """
    global lasttime
    if end:
        str += '\n'
        # Zeroing the timestamp guarantees the interval check below passes.
        lasttime = 0
    # Skip the write when the previous one happened less than
    # FLUSH_INTERVAL seconds ago.
    if time.time() - lasttime >= FLUSH_INTERVAL:
        sys.stdout.write('\r%s' % str)
        lasttime = time.time()
        sys.stdout.flush()
......@@ -67,7 +69,7 @@ class Downloader(object):
if retry_times < retry_limit:
retry_times += 1
else:
tips = "Cannot download {0} within retry limit {1}".format(
tips = 'Cannot download {0} within retry limit {1}'.format(
url, retry_limit)
return False, tips, None
r = requests.get(url, stream=True)
......@@ -82,19 +84,19 @@ class Downloader(object):
total_length = int(total_length)
starttime = time.time()
if print_progress:
print("Downloading %s" % save_name)
print('Downloading %s' % save_name)
for data in r.iter_content(chunk_size=4096):
dl += len(data)
f.write(data)
if print_progress:
done = int(50 * dl / total_length)
progress(
"[%-50s] %.2f%%" %
'[%-50s] %.2f%%' %
('=' * done, float(dl / total_length * 100)))
if print_progress:
progress("[%-50s] %.2f%%" % ('=' * 50, 100), end=True)
progress('[%-50s] %.2f%%' % ('=' * 50, 100), end=True)
tips = "File %s download completed!" % (file_name)
tips = 'File %s download completed!' % (file_name)
return True, tips, file_name
def uncompress(self,
......@@ -104,24 +106,25 @@ class Downloader(object):
print_progress=False):
dirname = os.path.dirname(file) if dirname is None else dirname
if print_progress:
print("Uncompress %s" % file)
with tarfile.open(file, "r:gz") as tar:
print('Uncompress %s' % file)
with tarfile.open(file, 'r:*') as tar:
file_names = tar.getnames()
size = len(file_names) - 1
module_dir = os.path.join(dirname, file_names[0])
for index, file_name in enumerate(file_names):
if print_progress:
done = int(50 * float(index) / size)
progress("[%-50s] %.2f%%" % ('=' * done,
progress('[%-50s] %.2f%%' % ('=' * done,
float(index / size * 100)))
tar.extract(file_name, dirname)
if print_progress:
progress("[%-50s] %.2f%%" % ('=' * 50, 100), end=True)
progress('[%-50s] %.2f%%' % ('=' * 50, 100), end=True)
if delete_file:
os.remove(file)
return True, "File %s uncompress completed!" % file, module_dir
return True, 'File %s uncompress completed!' % file, module_dir
def download_file_and_uncompress(self,
url,
......@@ -147,8 +150,59 @@ class Downloader(object):
if save_name:
save_name = os.path.join(save_path, save_name)
shutil.move(file, save_name)
return result, "%s\n%s" % (tips_1, tips_2), save_name
return result, "%s\n%s" % (tips_1, tips_2), file
return result, '%s\n%s' % (tips_1, tips_2), save_name
return result, '%s\n%s' % (tips_1, tips_2), file
# Module-level Downloader instance; download() in this module uses it to fetch
# and unpack files.
default_downloader = Downloader()
class ResourceNotFoundError(Exception):
    """Raised when a lookup for a named resource returns no result.

    Attributes are set in ``__init__``:
        name: Name of the resource that was requested.
        version: Requested version, or None when no version was given.
    """

    def __init__(self, name, version=None):
        self.name = name
        self.version = version

    def __str__(self):
        # Mention the version only when one was actually requested; the
        # branches were previously swapped, producing "name-None".
        if self.version:
            tips = 'No resource named {}-{} was found'.format(
                self.name, self.version)
        else:
            tips = 'No resource named {} was found'.format(self.name)
        return tips
class ServerConnectionError(Exception):
    """Raised by download() when the Hub Server check fails."""

    def __str__(self):
        # Report the first entry of the server's URL sequence.
        first_url = hub.HubServer().server_url[0]
        return 'Can\'t connect to Hub Server:{}'.format(first_url)
def download(name,
             save_path,
             version=None,
             decompress=True,
             resource_type='Model',
             extra=None):
    """Download a named resource from the Hub Server into ``save_path``.

    The resource is fetched into a temporary directory, optionally unpacked,
    and then moved to ``<save_path>/<name>``. If that target already exists
    the function returns without contacting the download URL.

    Args:
        name: Resource name; also used as the final file/directory name.
        save_path: Directory under which the resource is placed.
        version: Optional version passed to the server lookup.
        decompress: When true and the downloaded file is a tar archive, it is
            unpacked and the unpacked result is moved instead of the archive.
        resource_type: Passed through to ``get_resource_url``.
        extra: Passed through to ``get_resource_url``.

    Returns:
        None.

    Raises:
        ServerConnectionError: If the server check fails.
        ResourceNotFoundError: If the lookup returns an empty result.
        RuntimeError: If the file could not be downloaded; the message is the
            one reported by the downloader.
    """
    if not hub.HubServer()._server_check():
        raise ServerConnectionError

    search_result = hub.HubServer().get_resource_url(
        name, resource_type=resource_type, version=version, extra=extra)
    if not search_result:
        raise ResourceNotFoundError(name, version)
    url = search_result['url']

    file = os.path.realpath(os.path.join(save_path, name))
    if os.path.exists(file):
        return

    # Everything that touches the temporary directory must happen inside the
    # ``with`` block, including the final move out of it.
    with tmp_dir() as _dir:
        succeeded, tips, savefile = default_downloader.download_file(
            url=url, save_path=_dir, print_progress=True)
        # download_file reports failure as (False, tips, None); surface that
        # message instead of passing None on to tarfile/shutil.
        if not succeeded or savefile is None:
            raise RuntimeError(tips)
        if tarfile.is_tarfile(savefile) and decompress:
            _, _, savefile = default_downloader.uncompress(
                file=savefile, print_progress=True)
        shutil.move(savefile, file)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册