提交 b7edb76c 编写于 作者: W wuzewu

resource list add md5 column

上级 abe1183d
...@@ -17,6 +17,7 @@ from __future__ import division ...@@ -17,6 +17,7 @@ from __future__ import division
from __future__ import print_function from __future__ import print_function
import argparse import argparse
import os
from paddlehub.common.logger import logger from paddlehub.common.logger import logger
from paddlehub.common import utils from paddlehub.common import utils
...@@ -57,8 +58,10 @@ class DownloadCommand(BaseCommand): ...@@ -57,8 +58,10 @@ class DownloadCommand(BaseCommand):
self.args.output_path = "." self.args.output_path = "."
utils.check_path(self.args.output_path) utils.check_path(self.args.output_path)
url = default_hub_server.get_model_url( search_result = default_hub_server.get_model_url(
model_name, version=model_version) model_name, version=model_version)
url = search_result.get('url', None)
except_md5_value = search_result.get('md5', None)
if not url: if not url:
tips = "can't found model %s" % model_name tips = "can't found model %s" % model_name
if model_version: if model_version:
...@@ -66,13 +69,34 @@ class DownloadCommand(BaseCommand): ...@@ -66,13 +69,34 @@ class DownloadCommand(BaseCommand):
print(tips) print(tips)
return True return True
if self.args.uncompress: need_to_download_file = True
result, tips, file = default_downloader.download_file_and_uncompress( file_name = os.path.basename(url)
url=url, save_path=self.args.output_path, print_progress=True) file = os.path.join(self.args.output_path, file_name)
else: if os.path.exists(file):
print("File %s already existed\nWait to check the MD5 value" %
file_name)
file_md5_value = utils.md5_of_file(file)
if except_md5_value == file_md5_value:
print("MD5 check pass.")
need_to_download_file = False
else:
print("MD5 check failed!\nDelete invalid file.")
os.remove(file)
if need_to_download_file:
result, tips, file = default_downloader.download_file( result, tips, file = default_downloader.download_file(
url=url, save_path=self.args.output_path, print_progress=True) url=url, save_path=self.args.output_path, print_progress=True)
print(tips) if not result:
print(tips)
return False
if self.args.uncompress:
result, tips, file = default_downloader.uncompress(
file=file,
dirname=self.args.output_path,
delete_file=False,
print_progress=True)
print(tips)
return True return True
......
...@@ -32,14 +32,18 @@ from paddlehub.common.logger import logger ...@@ -32,14 +32,18 @@ from paddlehub.common.logger import logger
__all__ = ['Downloader'] __all__ = ['Downloader']
FLUSH_INTERVAL = 0.1 FLUSH_INTERVAL = 0.1
lasttime = time.time()
def md5file(fname):
hash_md5 = hashlib.md5() def progress(str, end=False):
f = open(fname, "rb") global lasttime
for chunk in iter(lambda: f.read(4096), b""): if end:
hash_md5.update(chunk) str += "\n"
f.close() lasttime = 0
return hash_md5.hexdigest() if time.time() - lasttime >= FLUSH_INTERVAL:
sys.stdout.write("\r%s" % str)
lasttime = time.time()
sys.stdout.flush()
class Downloader: class Downloader:
...@@ -85,30 +89,37 @@ class Downloader: ...@@ -85,30 +89,37 @@ class Downloader:
f.write(data) f.write(data)
if print_progress: if print_progress:
done = int(50 * dl / total_length) done = int(50 * dl / total_length)
if time.time() - starttime >= FLUSH_INTERVAL: progress("%s : [%-50s] %.2f%%" %
sys.stdout.write( (save_name, '=' * done,
"\r%s : [%-50s] %.2f%%" % float(dl / total_length * 100)))
(save_name, '=' * done,
float(dl / total_length * 100)))
starttime = time.time()
sys.stdout.flush()
if print_progress: if print_progress:
sys.stdout.write( progress(
"\r%s : [%-50s]%.2f%%\n" % "%s : [%-50s] %.2f%%" % (save_name, '=' * 50, 100),
(save_name, '=' * done, float(dl / total_length * 100))) end=True)
sys.stdout.flush()
tips = "file %s download completed!" % (file_name) tips = "file %s download completed!" % (file_name)
return True, tips, file_name return True, tips, file_name
def uncompress(self, file, dirname=None, delete_file=False): def uncompress(self,
file,
dirname=None,
delete_file=False,
print_progress=False):
dirname = os.path.dirname(file) if dirname is None else dirname dirname = os.path.dirname(file) if dirname is None else dirname
with tarfile.open(file, "r:gz") as tar: with tarfile.open(file, "r:gz") as tar:
file_names = tar.getnames() file_names = tar.getnames()
size = len(file_names) - 1
module_dir = os.path.join(dirname, file_names[0]) module_dir = os.path.join(dirname, file_names[0])
for file_name in file_names: for index, file_name in enumerate(file_names):
if print_progress:
done = int(50 * float(index) / size)
progress("%s : [%-50s] %.2f%%" %
(file, '=' * done, float(index / size * 100)))
tar.extract(file_name, dirname) tar.extract(file_name, dirname)
if print_progress:
progress(
"%s : [%-50s] %.2f%%" % (file, '=' * 50, 100), end=True)
if delete_file: if delete_file:
os.remove(file) os.remove(file)
...@@ -131,7 +142,8 @@ class Downloader: ...@@ -131,7 +142,8 @@ class Downloader:
replace=replace) replace=replace)
if not result: if not result:
return result, tips_1, file return result, tips_1, file
result, tips_2, file = self.uncompress(file, delete_file=delete_file) result, tips_2, file = self.uncompress(
file, delete_file=delete_file, print_progress=print_progress)
if not result: if not result:
return result, tips_2, file return result, tips_2, file
if save_name: if save_name:
......
...@@ -106,7 +106,7 @@ class HubServer: ...@@ -106,7 +106,7 @@ class HubServer:
self.request() self.request()
if not self._load_resource_list_file_if_valid(): if not self._load_resource_list_file_if_valid():
return None return {}
resource_index_list = [ resource_index_list = [
index index
...@@ -123,14 +123,17 @@ class HubServer: ...@@ -123,14 +123,17 @@ class HubServer:
resource_version_list = sorted(resource_version_list) resource_version_list = sorted(resource_version_list)
if not version: if not version:
if not resource_version_list: if not resource_version_list:
return None return {}
version = resource_version_list[-1] version = resource_version_list[-1]
for index in resource_index_list: for index in resource_index_list:
if self.resource_list_file['version'][index] == version: if self.resource_list_file['version'][index] == version:
return self.resource_list_file['url'][index] return {
'url': self.resource_list_file['url'][index],
'md5': self.resource_list_file['md5'][index]
}
return None return {}
def get_module_url(self, module_name, version=None, update=False): def get_module_url(self, module_name, version=None, update=False):
return self.get_resource_url( return self.get_resource_url(
......
...@@ -64,8 +64,10 @@ class LocalModuleManager: ...@@ -64,8 +64,10 @@ class LocalModuleManager:
tips = "Module %s already installed in %s" % (module_name, tips = "Module %s already installed in %s" % (module_name,
module_dir) module_dir)
return True, tips, module_dir return True, tips, module_dir
url = hub.default_hub_server.get_module_url( search_result = hub.default_hub_server.get_module_url(
module_name, version=module_version) module_name, version=module_version)
url = search_result.get('url', None)
md5_value = search_result.get('md5', None)
#TODO(wuzewu): add compatibility check #TODO(wuzewu): add compatibility check
if not url: if not url:
tips = "Can't find module %s" % module_name tips = "Can't find module %s" % module_name
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册