提交 978757a0 编写于 作者: W wuzewu

add hub server

上级 e7fc1be5
......@@ -11,13 +11,18 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .dir import USER_HOME
from .dir import HUB_HOME
from .dir import MODULE_HOME
from .dir import CACHE_HOME
from . import module
from . import tools
from . import data
from .module.module import Module, create_module
from .module.base_processor import BaseProcessor
from .module.signature import Signature, create_signature
from .module.manager import default_module_manager
from .tools.logger import logger
from .tools.paddle_helper import connect_program
from .data.type import DataType
from .hub_server import default_hub_server
......@@ -19,7 +19,7 @@ from paddle_hub.tools.logger import logger
from paddle_hub.commands.base_command import BaseCommand
from paddle_hub.tools import utils
from paddle_hub.tools.downloader import default_downloader
from paddle_hub.module.manager import default_manager
from paddle_hub.hub_server import default_hub_server
class DownloadCommand(BaseCommand):
......@@ -48,7 +48,7 @@ class DownloadCommand(BaseCommand):
self.args.output_path = "."
utils.check_path(self.args.output_path)
url = default_downloader.get_module_url(
url = default_hub_server.get_module_url(
module_name, version=module_version)
if not url:
tips = "can't found module %s" % module_name
......
......@@ -18,7 +18,7 @@ from __future__ import print_function
from paddle_hub.tools.logger import logger
from paddle_hub.commands.base_command import BaseCommand
from paddle_hub.tools import utils
from paddle_hub.module.manager import default_manager
from paddle_hub.module.manager import default_module_manager
class InstallCommand(BaseCommand):
......@@ -39,7 +39,7 @@ class InstallCommand(BaseCommand):
"==")[1]
module_name = module_name if "==" not in module_name else module_name.split(
"==")[0]
default_manager.install_module(
default_module_manager.install_module(
module_name=module_name, module_version=module_version)
......
......@@ -19,7 +19,7 @@ from paddle_hub.tools.logger import logger
from paddle_hub.commands.base_command import BaseCommand
from paddle_hub.tools import utils
from paddle_hub.tools.downloader import default_downloader
from paddle_hub.module.manager import default_manager
from paddle_hub.module.manager import default_module_manager
class ListCommand(BaseCommand):
......@@ -31,7 +31,7 @@ class ListCommand(BaseCommand):
self.description = "List all module install in current environment."
def exec(self, argv):
all_modules = default_manager.all_modules()
all_modules = default_module_manager.all_modules()
list_text = "\n"
list_text += " %-20s\t\t%s\n" % ("ModuleName", "ModulePath")
list_text += " %-20s\t\t%s\n" % ("--", "--")
......
......@@ -18,8 +18,7 @@ from __future__ import print_function
from paddle_hub.tools.logger import logger
from paddle_hub.commands.base_command import BaseCommand
from paddle_hub.tools import utils
from paddle_hub.tools.downloader import default_downloader
from paddle_hub.module.manager import default_manager
from paddle_hub.hub_server import default_hub_server
class SearchCommand(BaseCommand):
......@@ -32,7 +31,7 @@ class SearchCommand(BaseCommand):
def exec(self, argv):
module_name = argv[0]
module_list = default_downloader.search_module(module_name)
module_list = default_hub_server.search_module(module_name)
text = "\n"
text += " %-20s\t\t%s\n" % ("ModuleName", "ModuleVersion")
text += " %-20s\t\t%s\n" % ("--", "--")
......
......@@ -17,7 +17,7 @@ from __future__ import division
from __future__ import print_function
from paddle_hub.tools.logger import logger
from paddle_hub.commands.base_command import BaseCommand
from paddle_hub.module.manager import default_manager
from paddle_hub.module.manager import default_module_manager
from paddle_hub.module.module import Module
import os
......@@ -34,7 +34,7 @@ class ShowCommand(BaseCommand):
module_name = argv[0]
cwd = os.getcwd()
module_dir = default_manager.search_module(module_name)
module_dir = default_module_manager.search_module(module_name)
module_dir = os.path.join(cwd,
module_name) if not module_dir else module_dir
if not module_dir or not os.path.exists(module_dir):
......
......@@ -18,7 +18,7 @@ from __future__ import print_function
from paddle_hub.tools.logger import logger
from paddle_hub.commands.base_command import BaseCommand
from paddle_hub.tools import utils
from paddle_hub.module.manager import default_manager
from paddle_hub.module.manager import default_module_manager
class UninstallCommand(BaseCommand):
......@@ -31,7 +31,7 @@ class UninstallCommand(BaseCommand):
def exec(self, argv):
module_name = argv[0]
default_manager.uninstall_module(module_name=module_name)
default_module_manager.uninstall_module(module_name=module_name)
command = UninstallCommand.instance()
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" Store PaddleHub version string """
import os
USER_HOME = os.path.expanduser('~')
HUB_HOME = os.path.join(USER_HOME, ".hub")
MODULE_HOME = os.path.join(HUB_HOME, "modules")
CACHE_HOME = os.path.join(HUB_HOME, "cache")
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from paddle_hub.tools import utils
from paddle_hub.tools.downloader import default_downloader
from paddle_hub.data.reader import csv_reader
import os
import paddle_hub as hub
MODULE_LIST_FILE = "module_file_list.csv"
class HubServer:
def __init__(self, server_url=None):
if not server_url:
server_url = "https://paddlehub.bj.bcebos.com/"
utils.check_url(server_url)
self.server_url = server_url
self.module_file_list = []
def search_module(self, module_key, update=False):
if update or not self.module_file_list:
self.request()
match_module_index_list = [
index
for index, module in enumerate(self.module_list_file['module_name'])
if module_key in module
]
return [(self.module_list_file['module_name'][index],
self.module_list_file['version'][index])
for index in match_module_index_list]
def get_module_url(self, module_name, version=None, update=False):
if update or not self.module_file_list:
self.request()
module_index_list = [
index
for index, module in enumerate(self.module_list_file['module_name'])
if module == module_name
]
module_version_list = [
self.module_list_file['version'][index]
for index in module_index_list
]
#TODO(wuzewu): version sort method
module_version_list = sorted(module_version_list)
if not version:
if not module_version_list:
return None
version = module_version_list[-1]
for index in module_index_list:
if self.module_list_file['version'][index] == version:
return self.module_list_file['url'][index]
return None
def request(self):
file_url = self.server_url + MODULE_LIST_FILE
self.module_list_file = default_downloader.download_file(
file_url, save_path=hub.CACHE_HOME)
self.module_list_file = csv_reader.read(self.module_list_file)
return True
default_hub_server = HubServer()
if __name__ == "__main__":
print(default_hub_server.search_module("ssd"))
print(default_hub_server.get_module_url("ssd_mobilenet_pascal"))
......@@ -16,16 +16,15 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from paddle_hub.tools import utils
from paddle_hub.tools.downloader import MODULE_HOME, default_downloader
from paddle_hub.tools.downloader import default_downloader
import paddle_hub as hub
import os
import shutil
class LocalModuleManager:
def __init__(self, base_path=None):
self.base_path = base_path if base_path else os.path.expanduser('~')
utils.check_path(self.base_path)
self.local_modules_dir = MODULE_HOME
def __init__(self, module_home=None):
self.local_modules_dir = module_home if module_home else hub.MODULE_HOME
self.modules_dict = {}
if not os.path.exists(self.local_modules_dir):
utils.mkdir(self.local_modules_dir)
......@@ -61,7 +60,7 @@ class LocalModuleManager:
module_dir = self.modules_dict[module_name]
print("module %s already install in %s" % (module_name, module_dir))
return
url = default_downloader.get_module_url(
url = hub.default_hub_server.get_module_url(
module_name, version=module_version)
#TODO(wuzewu): add compatibility check
if not url:
......@@ -71,7 +70,8 @@ class LocalModuleManager:
print(tips)
return
default_downloader.download_file_and_uncompress(url=url)
default_downloader.download_file_and_uncompress(
url=url, save_path=hub.MODULE_HOME, save_name=module_name)
def uninstall_module(self, module_name):
self.all_modules(update=True)
......@@ -84,4 +84,4 @@ class LocalModuleManager:
print("Successfully uninstalled %s" % module_name)
default_manager = LocalModuleManager()
default_module_manager = LocalModuleManager()
......@@ -17,6 +17,7 @@ from __future__ import print_function
from __future__ import division
from __future__ import print_function
import shutil
import os
import sys
import hashlib
......@@ -27,26 +28,7 @@ from paddle_hub.tools import utils
from paddle_hub.tools.logger import logger
from paddle_hub.data.reader import csv_reader
__all__ = ['MODULE_HOME', 'downloader', 'md5file', 'Downloader']
# TODO(ZeyuChen) add environment varialble to set MODULE_HOME
MODULE_HOME = os.path.expanduser('~')
MODULE_HOME = os.path.join(MODULE_HOME, ".hub")
MODULE_HOME = os.path.join(MODULE_HOME, "modules")
# When running unit tests, there could be multiple processes that
# trying to create MODULE_HOME directory simultaneously, so we cannot
# use a if condition to check for the existence of the directory;
# instead, we use the filesystem as the synchronization mechanism by
# catching returned errors.
def must_mkdirs(path):
try:
os.makedirs(MODULE_HOME)
except OSError as exc:
if exc.errno != errno.EEXIST:
raise
pass
__all__ = ['Downloader']
def md5file(fname):
......@@ -59,13 +41,7 @@ def md5file(fname):
class Downloader:
def __init__(self, module_home=None):
self.module_home = module_home if module_home else MODULE_HOME
self.module_list_file = []
def download_file(self, url, save_path=None, save_name=None, retry_limit=3):
module_name = url.split("/")[-2]
save_path = self.module_home if save_path is None else save_path
def download_file(self, url, save_path, save_name=None, retry_limit=3):
if not os.path.exists(save_path):
utils.mkdir(save_path)
save_name = url.split('/')[-1] if save_name is None else save_name
......@@ -108,7 +84,6 @@ class Downloader:
dirname = os.path.dirname(file) if dirname is None else dirname
with tarfile.open(file, "r:gz") as tar:
file_names = tar.getnames()
logger.info(file_names)
module_dir = os.path.join(dirname, file_names[0])
for file_name in file_names:
tar.extract(file_name, dirname)
......@@ -120,7 +95,7 @@ class Downloader:
def download_file_and_uncompress(self,
url,
save_path=None,
save_path,
save_name=None,
retry_limit=3,
delete_file=True):
......@@ -129,53 +104,12 @@ class Downloader:
save_path=save_path,
save_name=save_name,
retry_limit=retry_limit)
return self.uncompress(file, delete_file=delete_file)
def search_module(self, module_name):
if not self.module_list_file:
#TODO(wuzewu): download file in tmp directory
self.module_list_file = self.download_file(
url="https://paddlehub.bj.bcebos.com/module_file_list.csv")
self.module_list_file = csv_reader.read(self.module_list_file)
match_module_index_list = [
index
for index, module in enumerate(self.module_list_file['module_name'])
if module_name in module
]
return [(self.module_list_file['module_name'][index],
self.module_list_file['version'][index])
for index in match_module_index_list]
def get_module_url(self, module_name, version=None):
if not self.module_list_file:
#TODO(wuzewu): download file in tmp directory
self.module_list_file = self.download_file(
url="https://paddlehub.bj.bcebos.com/module_file_list.csv")
self.module_list_file = csv_reader.read(self.module_list_file)
module_index_list = [
index
for index, module in enumerate(self.module_list_file['module_name'])
if module == module_name
]
module_version_list = [
self.module_list_file['version'][index]
for index in module_index_list
]
#TODO(wuzewu): version sort method
module_version_list = sorted(module_version_list)
if not version:
if not module_version_list:
return None
version = module_version_list[-1]
for index in module_index_list:
if self.module_list_file['version'][index] == version:
return self.module_list_file['url'][index]
return None
file = self.uncompress(file, delete_file=delete_file)
if save_name:
save_name = os.path.join(save_path, save_name)
shutil.move(file, save_name)
return save_name
return file
default_downloader = Downloader()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册