提交 5ddd3a65 编写于 作者: B BinLong

add paddlehub server

上级 d64a1b6d
...@@ -31,6 +31,7 @@ from .common.dir import USER_HOME ...@@ -31,6 +31,7 @@ from .common.dir import USER_HOME
from .common.dir import HUB_HOME from .common.dir import HUB_HOME
from .common.dir import MODULE_HOME from .common.dir import MODULE_HOME
from .common.dir import CACHE_HOME from .common.dir import CACHE_HOME
from .common.dir import CONF_HOME
from .common.logger import logger from .common.logger import logger
from .common.paddle_helper import connect_program from .common.paddle_helper import connect_program
from .common.hub_server import default_hub_server from .common.hub_server import default_hub_server
......
...@@ -21,3 +21,4 @@ HUB_HOME = os.path.join(USER_HOME, ".paddlehub") ...@@ -21,3 +21,4 @@ HUB_HOME = os.path.join(USER_HOME, ".paddlehub")
MODULE_HOME = os.path.join(HUB_HOME, "modules") MODULE_HOME = os.path.join(HUB_HOME, "modules")
CACHE_HOME = os.path.join(HUB_HOME, "cache") CACHE_HOME = os.path.join(HUB_HOME, "cache")
DATA_HOME = os.path.join(HUB_HOME, "dataset") DATA_HOME = os.path.join(HUB_HOME, "dataset")
CONF_HOME = os.path.join(HUB_HOME, "conf")
...@@ -20,9 +20,13 @@ from __future__ import print_function ...@@ -20,9 +20,13 @@ from __future__ import print_function
import os import os
import time import time
import re import re
import requests
import json
import yaml
from paddlehub.common import utils from paddlehub.common import utils
from paddlehub.common.downloader import default_downloader from paddlehub.common.downloader import default_downloader
from paddlehub.common.server_config import default_server_config
from paddlehub.io.parser import yaml_parser from paddlehub.io.parser import yaml_parser
import paddlehub as hub import paddlehub as hub
...@@ -31,11 +35,20 @@ CACHE_TIME = 60 * 10 ...@@ -31,11 +35,20 @@ CACHE_TIME = 60 * 10
class HubServer(object): class HubServer(object):
def __init__(self, server_url=None): def __init__(self, config_file_path=None):
if not server_url: if not config_file_path:
server_url = "https://paddlehub.bj.bcebos.com/" config_file_path = hub.CONF_HOME + '/config.json'
utils.check_url(server_url) if not os.path.exists(hub.CONF_HOME):
self.server_url = server_url utils.mkdir(hub.CONF_HOME)
if not os.path.exists(config_file_path):
with open(config_file_path, 'w+') as fp:
fp.write(json.dumps(default_server_config))
with open(config_file_path) as fp:
self.config = json.load(fp)
utils.check_url(self.config['server_url'])
self.server_url = self.config['server_url']
self._load_resource_list_file_if_valid() self._load_resource_list_file_if_valid()
def resource_list_file_path(self): def resource_list_file_path(self):
...@@ -67,6 +80,18 @@ class HubServer(object): ...@@ -67,6 +80,18 @@ class HubServer(object):
return True return True
def search_resource(self, resource_key, resource_type=None, update=False): def search_resource(self, resource_key, resource_type=None, update=False):
try:
payload = {'word': resource_key}
if resource_type:
payload['type'] = resource_type
r = requests.get(self.server_url + '/' + 'search', params=payload)
r = json.loads(r.text)
if r['status'] == 0 and len(r['data']) > 0:
return [(item['name'], item['type'], item['version'], item['summary'])
for item in r['data']]
except:
pass
if update or not self.resource_list_file: if update or not self.resource_list_file:
self.request() self.request()
...@@ -103,6 +128,19 @@ class HubServer(object): ...@@ -103,6 +128,19 @@ class HubServer(object):
resource_type=None, resource_type=None,
version=None, version=None,
update=False): update=False):
try:
payload = {'word': resource_name}
if resource_type:
payload['type'] = resource_type
if version:
payload['version'] = version
r = requests.get(self.server_url + '/' + 'search', params=payload)
r = json.loads(r.text)
if r['status'] == 0 and len(r['data']) > 0:
return r['data'][0]
except:
pass
if update or not self.resource_list_file: if update or not self.resource_list_file:
self.request() self.request()
...@@ -152,7 +190,18 @@ class HubServer(object): ...@@ -152,7 +190,18 @@ class HubServer(object):
update=update) update=update)
def request(self): def request(self):
file_url = self.server_url + RESOURCE_LIST_FILE if not os.path.exists(hub.CACHE_HOME):
utils.mkdir(hub.CACHE_HOME)
try:
r = requests.get(self.server_url + '/' + 'search')
data = json.loads(r.text)
with open(hub.CACHE_HOME + '/' + RESOURCE_LIST_FILE, 'w+') as fp:
yaml.safe_dump({'resource_list' : data['data']}, fp)
return True
except:
pass
file_url = self.config['resource_storage_server_url'] + RESOURCE_LIST_FILE
result, tips, self.resource_list_file = default_downloader.download_file( result, tips, self.resource_list_file = default_downloader.download_file(
file_url, save_path=hub.CACHE_HOME) file_url, save_path=hub.CACHE_HOME)
if not result: if not result:
......
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
default_server_config = {
"server_url" : "http://hub.paddlepaddle.org:8888",
"resource_storage_server_url" : "https//bj.bcebos.com/paddlehub"
}
...@@ -23,6 +23,7 @@ import shutil ...@@ -23,6 +23,7 @@ import shutil
from paddlehub.common import utils from paddlehub.common import utils
from paddlehub.common.downloader import default_downloader from paddlehub.common.downloader import default_downloader
from paddlehub.common.dir import MODULE_HOME from paddlehub.common.dir import MODULE_HOME
from paddlehub.module import module_desc_pb2
import paddlehub as hub import paddlehub as hub
...@@ -38,7 +39,17 @@ class LocalModuleManager(object): ...@@ -38,7 +39,17 @@ class LocalModuleManager(object):
def check_module_valid(self, module_path): def check_module_valid(self, module_path):
#TODO(wuzewu): code #TODO(wuzewu): code
return True info = {}
try:
desc_pb_path = os.path.join(module_path, 'module_desc.pb')
if os.path.exists(desc_pb_path) and os.path.isfile(desc_pb_path):
desc = module_desc_pb2.ModuleDesc()
with open(desc_pb_path, "rb") as fp:
desc.ParseFromString(fp.read())
info['version'] = desc.attr.map.data["module_info"].map.data["version"].s
except:
return False, None
return True, info
def all_modules(self, update=False): def all_modules(self, update=False):
if not update and self.modules_dict: if not update and self.modules_dict:
...@@ -46,32 +57,38 @@ class LocalModuleManager(object): ...@@ -46,32 +57,38 @@ class LocalModuleManager(object):
self.modules_dict = {} self.modules_dict = {}
for sub_dir_name in os.listdir(self.local_modules_dir): for sub_dir_name in os.listdir(self.local_modules_dir):
sub_dir_path = os.path.join(self.local_modules_dir, sub_dir_name) sub_dir_path = os.path.join(self.local_modules_dir, sub_dir_name)
if os.path.isdir(sub_dir_path) and self.check_module_valid( if os.path.isdir(sub_dir_path):
sub_dir_path):
#TODO(wuzewu): get module name #TODO(wuzewu): get module name
valid, info = self.check_module_valid(sub_dir_path)
if valid:
module_name = sub_dir_name module_name = sub_dir_name
self.modules_dict[module_name] = sub_dir_path self.modules_dict[module_name] = (sub_dir_path, info['version'])
return self.modules_dict return self.modules_dict
def search_module(self, module_name, update=False): def search_module(self, module_name, module_version=None, update=False):
self.all_modules(update=update) self.all_modules(update=update)
return self.modules_dict.get(module_name, None) return self.modules_dict.get(module_name, None)
def install_module(self, module_name, module_version=None, upgrade=False): def install_module(self, module_name, module_version=None, upgrade=False):
self.all_modules(update=True) self.all_modules(update=True)
if module_name in self.modules_dict: module_info = self.modules_dict.get(module_name, None)
module_dir = self.modules_dict[module_name] if module_info:
tips = "Module %s already installed in %s" % (module_name, if not module_version or module_version == self.modules_dict[module_name][1]:
module_dir = self.modules_dict[module_name][0]
module_tag = module_name if not module_version else '%s-%s' % (
module_name, module_version)
tips = "Module %s already installed in %s" % (module_tag,
module_dir) module_dir)
return True, tips, module_dir return True, tips, module_dir
search_result = hub.default_hub_server.get_module_url( search_result = hub.default_hub_server.get_module_url(
module_name, version=module_version) module_name, version=module_version)
url = search_result.get('url', None) url = search_result.get('url', None)
md5_value = search_result.get('md5', None) md5_value = search_result.get('md5', None)
installed_module_version = search_result.get('version', None) installed_module_version = search_result.get('version', None)
#TODO(wuzewu): add compatibility check #TODO(wuzewu): add compatibility check
if not url: if not url or (module_version is not None and
installed_module_version != module_version):
tips = "Can't find module %s" % module_name tips = "Can't find module %s" % module_name
if module_version: if module_version:
tips += " with version %s" % module_version tips += " with version %s" % module_version
...@@ -89,11 +106,12 @@ class LocalModuleManager(object): ...@@ -89,11 +106,12 @@ class LocalModuleManager(object):
delete_file=True, delete_file=True,
print_progress=True) print_progress=True)
if module_dir:
save_path = os.path.join(MODULE_HOME, module_name) save_path = os.path.join(MODULE_HOME, module_name)
if os.path.exists(save_path):
shutil.rmtree(save_path)
shutil.move(module_dir, save_path) shutil.move(module_dir, save_path)
module_dir = save_path module_dir = save_path
if module_dir:
tips = "Successfully installed %s" % module_name tips = "Successfully installed %s" % module_name
if installed_module_version: if installed_module_version:
tips += "-%s" % installed_module_version tips += "-%s" % installed_module_version
...@@ -101,13 +119,18 @@ class LocalModuleManager(object): ...@@ -101,13 +119,18 @@ class LocalModuleManager(object):
tips = "Download %s-%s failed" % (module_name, module_version) tips = "Download %s-%s failed" % (module_name, module_version)
return False, tips, module_dir return False, tips, module_dir
def uninstall_module(self, module_name): def uninstall_module(self, module_name, module_version=None):
self.all_modules(update=True) self.all_modules(update=True)
if not module_name in self.modules_dict: if not module_name in self.modules_dict:
tips = "%s is not installed" % module_name tips = "%s is not installed" % module_name
return True, tips return True, tips
if module_version and module_version != self.modules_dict[module_name][1]:
tips = "%s-%s is not installed" % (module_name, module_version)
return True, tips
tips = "Successfully uninstalled %s" % module_name tips = "Successfully uninstalled %s" % module_name
module_dir = self.modules_dict[module_name] if module_version:
tips += '-%s' % module_version
module_dir = self.modules_dict[module_name][0]
shutil.rmtree(module_dir) shutil.rmtree(module_dir)
return True, tips return True, tips
......
...@@ -96,7 +96,8 @@ class Module(object): ...@@ -96,7 +96,8 @@ class Module(object):
module_info=None, module_info=None,
assets=None, assets=None,
processor=None, processor=None,
extra_info=None): extra_info=None,
version=None):
self.desc = module_desc_pb2.ModuleDesc() self.desc = module_desc_pb2.ModuleDesc()
self.program = None self.program = None
self.assets = [] self.assets = []
...@@ -118,7 +119,7 @@ class Module(object): ...@@ -118,7 +119,7 @@ class Module(object):
# TODO(wuzewu): print more module loading info log # TODO(wuzewu): print more module loading info log
if name: if name:
self._init_with_name(name=name) self._init_with_name(name=name, version=version)
elif module_dir: elif module_dir:
self._init_with_module_file(module_dir=module_dir) self._init_with_module_file(module_dir=module_dir)
elif signatures: elif signatures:
...@@ -137,10 +138,13 @@ class Module(object): ...@@ -137,10 +138,13 @@ class Module(object):
else: else:
raise ValueError("Module initialized parameter is empty") raise ValueError("Module initialized parameter is empty")
def _init_with_name(self, name): def _init_with_name(self, name, version=None):
logger.info("Installing %s module" % name) log_msg = "Installing %s module" % name
if version:
log_msg += "-%s" % version
logger.info(log_msg)
result, tips, module_dir = default_module_manager.install_module( result, tips, module_dir = default_module_manager.install_module(
module_name=name) module_name=name, module_version=version)
if not result: if not result:
logger.error(tips) logger.error(tips)
exit(1) exit(1)
......
...@@ -6,3 +6,5 @@ pyyaml ...@@ -6,3 +6,5 @@ pyyaml
numpy >= 1.12.0 numpy >= 1.12.0
Pillow Pillow
six >= 1.10.0 six >= 1.10.0
chardet == 3.0.4
requests
...@@ -36,6 +36,7 @@ REQUIRED_PACKAGES = [ ...@@ -36,6 +36,7 @@ REQUIRED_PACKAGES = [
'protobuf >= 3.1.0', 'protobuf >= 3.1.0',
'pyyaml', 'pyyaml',
'Pillow', 'Pillow',
'requests',
"visualdl >= 1.3.0", "visualdl >= 1.3.0",
] ]
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册