提交 04fc7f54 编写于 作者: W wuzewu

update download command

上级 eb82a970
......@@ -17,17 +17,39 @@ from __future__ import division
from __future__ import print_function
from paddle_hub.tools.logger import logger
from paddle_hub.commands.base_command import BaseCommand
from paddle_hub.tools import utils
from paddle_hub.tools.downloader import default_downloader
from paddle_hub.module.manager import default_manager
class DownloadCommand(BaseCommand):
def __init__(self):
super(DownloadCommand, self).__init__()
# yapf: disable
self.add_arg('--output_path', str, ".", "path to save the module, default in current directory" )
self.add_arg('--uncompress', bool, False, "uncompress the download package or not" )
# yapf: enable
def help(self):
pass
self.parser.print_help()
def exec(self, argv):
pass
module_name = argv[1]
self.args = self.parser.parse_args(argv[2:])
if not self.args.output_path:
self.args.output_path = "."
utils.check_path(self.args.output_path)
url = default_downloader.get_module_url(module_name)
assert url, "can't found module %s" % module_name
self.print_args()
if self.args.uncompress:
default_downloader.download_file_and_uncompress(
url=url, save_path=self.args.output_path)
else:
default_downloader.download_file(
url=url, save_path=self.args.output_path)
command = DownloadCommand.instance()
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from paddle_hub.tools import utils
import os
class LocalModuleManager:
def __init__(self, base_path=None):
self.base_path = base_path if base_path else os.path.expanduser('~')
utils.check_path(self.base_path)
self.local_hub_dir = os.path.join(self.base_path, ".hub")
self.local_modules_dir = os.path.join(self.local_hub_dir, "modules")
self.modules = []
if not os.path.exists(self.local_modules_dir):
utils.mkdir(self.local_modules_dir)
elif os.path.isfile(self.local_modules_dir):
#TODO(wuzewu): give wanring
pass
def check_module_valid(self, module_path):
#TODO(wuzewu): code
return True
def all_modules(self, update=False):
if not update and self.modules:
return self.modules
self.modules = []
for sub_dir_name in os.listdir(self.local_modules_dir):
sub_dir_path = os.path.join(self.local_modules_dir, sub_dir_name)
if os.path.isdir(sub_dir_path) and self.check_module_valid(
sub_dir_path):
#TODO(wuzewu): get module name
module_name = sub_dir_path
self.modules.append(module_name)
return self.modules
def search_module(self, module_name, update=False):
self.all_modules(update=update)
return module_name in self.all_modules
def install_module(self, upgrade=False):
pass
def uninstall_module(self):
pass
default_manager = LocalModuleManager()
......@@ -17,6 +17,7 @@ from __future__ import division
from __future__ import print_function
from paddle_hub.tools.logger import logger
import six
import distutils.util
def add_argument(argument, type, default, help, argparser, **kwargs):
......
......@@ -25,6 +25,7 @@ import tempfile
import tarfile
from paddle_hub.tools import utils
from paddle_hub.tools.logger import logger
from paddle_hub.data.reader import csv_reader
__all__ = ['MODULE_HOME', 'downloader', 'md5file', 'Downloader']
......@@ -58,6 +59,7 @@ def md5file(fname):
class Downloader:
def __init__(self, module_home=None):
self.module_home = module_home if module_home else MODULE_HOME
self.module_list_file = []
def download_file(self, url, save_path=None, save_name=None, retry_limit=3):
module_name = url.split("/")[-2]
......@@ -127,5 +129,34 @@ class Downloader:
retry_limit=retry_limit)
return self.uncompress(file, delete_file=delete_file)
def get_module_url(self, module_name, version=None):
if not self.module_list_file:
#TODO(wuzewu): download file in tmp directory
self.module_list_file = self.download_file(
url="https://paddlehub.bj.bcebos.com/module_file_list.csv")
self.module_list_file = csv_reader.read(self.module_list_file)
module_index_list = [
index
for index, module in enumerate(self.module_list_file['module_name'])
if module == module_name
]
module_version_list = [
self.module_list_file['version'][index]
for index in module_index_list
]
#TODO(wuzewu): version sort method
module_version_list = sorted(module_version_list)
if not version:
if not module_version_list:
return None
version = module_version_list[-1]
for index in module_index_list:
if self.module_list_file['version'][index] == version:
return self.module_list_file['url'][index]
return None
default_downloader = Downloader()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册