提交 73e06533 编写于 作者: W wuzewu

Update config cmd

上级 76e08faa
......@@ -19,13 +19,14 @@ from easydict import EasyDict
__version__ = '2.0.0a0'
from paddlehub.config import config
from paddlehub.utils import log, parser, utils
from paddlehub.utils.paddlex import download, ResourceNotFoundError
from paddlehub.server.server_source import ServerConnectionError
from paddlehub.module import Module
# In order to maintain the compatibility of the old version, we put the relevant
# compatible code in the paddlehub/compat package, and mapped some modules referenced
# compatible code in the paddlehub.compat package, and mapped some modules referenced
# in the old version
from paddlehub.compat import paddle_utils
from paddlehub.compat.module.processor import BaseProcessor
......
......@@ -14,24 +14,11 @@
# limitations under the License.
import argparse
import json
import os
import re
import hashlib
import uuid
import time
import ast
import paddlehub.config as hubconf
from paddlehub.env import CONF_HOME
from paddlehub.commands import register
from paddlehub.utils.utils import md5
default_server_config = {
"server_url": ["http://paddlepaddle.org.cn/paddlehub"],
"resource_storage_server_url": "https://bj.bcebos.com/paddlehub-data/",
"debug": False,
"log_level": "DEBUG",
"hub_name": md5(str(uuid.uuid1())[-12:]) + "-" + str(int(time.time()))
}
@register(name='hub.config', description='Configure PaddleHub.')
......@@ -39,40 +26,7 @@ class ConfigCommand:
@staticmethod
def show_config():
print("The current configuration is shown below.")
with open(os.path.join(CONF_HOME, "config.json"), "r") as fp:
print(json.dumps(json.load(fp), indent=4))
@staticmethod
def set_server_url(server_url):
with open(os.path.join(CONF_HOME, "config.json"), "r") as fp:
config = json.load(fp)
re_str = r"^(?:http(s)?:\/\/)?[\w.-]+(?:\.[\w\.-]+)+[\w\-\._~:/?#[\]@!\$&'\*\+,;=.]+$"
if re.match(re_str, server_url) is not None:
config["server_url"] = list([server_url])
ConfigCommand.set_config(config)
else:
print("The format of the input url is invalid.")
@staticmethod
def set_config(config):
with open(os.path.join(CONF_HOME, "config.json"), "w") as fp:
fp.write(json.dumps(config))
print("Set success! The current configuration is shown below.")
print(json.dumps(config, indent=4))
@staticmethod
def set_log_level(level):
level = str(level).upper()
if level not in ["NOLOG", "DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"]:
print("Allowed values include: " "NOLOG, DEBUG, INFO, WARNING, ERROR, CRITICAL")
return
with open(os.path.join(CONF_HOME, "config.json"), "r") as fp:
current_config = json.load(fp)
with open(os.path.join(CONF_HOME, "config.json"), "w") as fp:
current_config["log_level"] = level
fp.write(json.dumps(current_config))
print("Set success! The current configuration is shown below.")
print(json.dumps(current_config, indent=4))
print(hubconf)
@staticmethod
def show_help():
......@@ -80,11 +34,13 @@ class ConfigCommand:
str += "\tShow PaddleHub config without any option.\n"
str += "option:\n"
str += "reset\n"
str += "\tReset config as default.\n"
str += "\tReset config as default.\n\n"
str += "server==[URL]\n"
str += "\tSet PaddleHub Server url as [URL].\n"
str += "log==[LEVEL]\n"
str += "\tSet log level as [LEVEL:NOLOG, DEBUG, INFO, WARNING, ERROR, CRITICAL].\n"
str += "\tSet PaddleHub Server url as [URL].\n\n"
str += "log.level==[LEVEL]\n"
str += "\tSet log level.\n\n"
str += "log.enable==True|False\n"
str += "\tEnable or disable logger in PaddleHub.\n"
print(str)
def execute(self, argv):
......@@ -92,11 +48,14 @@ class ConfigCommand:
ConfigCommand.show_config()
for arg in argv:
if arg == "reset":
ConfigCommand.set_config(default_server_config)
hubconf.reset()
print(hubconf)
elif arg.startswith("server=="):
ConfigCommand.set_server_url(arg.split("==")[1])
elif arg.startswith("log=="):
ConfigCommand.set_log_level(arg.split("==")[1])
hubconf.server_url = arg.split("==")[1]
elif arg.startswith("log.level=="):
hubconf.log_level = arg.split("==")[1]
elif arg.startswith("log.enable=="):
hubconf.log_enable = ast.literal_eval(arg.split("==")[1])
else:
ConfigCommand.show_help()
return True
# coding:utf-8
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import json
from typing import Any
import yaml
from easydict import EasyDict
import paddlehub.env as hubenv
class HubConfig:
'''
'''
def __init__(self):
self._initialize()
self.file = os.path.join(hubenv.CONF_HOME, 'config.yaml')
with open(self.file, 'r') as file:
try:
cfg = yaml.load(file, Loader=yaml.FullLoader)
self.data.update(cfg)
except:
...
def _initialize(self):
self.data = EasyDict()
self.data.server = 'http://paddlepaddle.org.cn/paddlehub'
self.data.log = EasyDict()
self.data.log.enable = True
self.data.log.level = 'DEBUG'
def reset(self):
self._initialize()
self.flush()
@property
def log_level(self):
return self.data.log.level
@log_level.setter
def log_level(self, level: str):
from paddlehub.utils import log
if not level in log.log_config.keys():
raise ValueError('Unknown log level {}.'.format(level))
self.data.log.level = level
self.flush()
@property
def log_enable(self):
return self.data.log.enable
@log_enable.setter
def log_enable(self, enable: bool):
self.data.log.enable = enable
self.flush()
@property
def server(self):
return self.data.server
@server.setter
def server(self, url: str):
self.data.server = url
self.flush()
def flush(self):
with open(self.file, 'w') as file:
# convert EasyDict to dict
cfg = json.loads(json.dumps(self.data))
yaml.dump(cfg, file)
def __str__(self):
cfg = json.loads(json.dumps(self.data))
return yaml.dump(cfg)
def _load_old_config(config: HubConfig):
# The old version of the configuration file is obsolete, read the configuration value and delete it.
old_cfg_file = os.path.join(hubenv.CONF_HOME, 'config.json')
if os.path.exists(old_cfg_file):
with open(old_cfg_file) as file:
try:
cfg = json.loads(file.read())
config.server = cfg['server_url']
config.log_level = cfg['log_level']
except:
...
os.remove(old_cfg_file)
config = HubConfig()
_load_old_config(config)
......@@ -16,8 +16,6 @@
import os
import shutil
from paddlehub.utils import log
def _get_user_home():
return os.path.expanduser('~')
......@@ -30,7 +28,7 @@ def _get_hub_home():
if os.path.isdir(home_path):
return home_path
else:
log.logger.warning('')
raise RuntimeError('The environment variable HUB_HOME {} is not a directory.'.format(home_path))
else:
return home_path
return os.path.join(_get_user_home(), '.paddlehub')
......
......@@ -16,7 +16,6 @@
import os
import shutil
import sys
import traceback
from collections import OrderedDict
from typing import List
......@@ -218,11 +217,7 @@ class LocalModuleManager(object):
try:
module = self._local_modules[name] = HubModule.load(module_dir)
except Exception as e:
msg = traceback.format_exc()
file = utils.record(msg)
log.logger.warning(
'An error was encountered while loading {}. Detailed error information can be found in the {}.'.
format(name, file))
utils.record_exception('An error was encountered while loading {}.'.format(name))
if not module:
return None
......@@ -242,11 +237,7 @@ class LocalModuleManager(object):
try:
self._local_modules[subdir] = HubModule.load(fulldir)
except Exception as e:
msg = traceback.format_exc()
file = utils.record(msg)
log.logger.warning(
'An error was encountered while loading {}. Detailed error information can be found in the {}.'.
format(subdir, file))
utils.record_exception('An error was encountered while loading {}.'.format(subdir))
return [module for module in self._local_modules.values()]
......
......@@ -17,7 +17,6 @@ import inspect
import importlib
import os
import sys
import traceback
from collections import OrderedDict
from typing import List
......@@ -61,11 +60,7 @@ class GitSource(object):
# reload modules
self.load_hub_modules()
except:
msg = traceback.format_exc()
file = utils.record(msg)
log.logger.warning(
'An error occurred while checkout {}. Detailed error information can be found in the {}.'.format(
self.path, file))
utils.record_exception('An error occurred while checkout {}.'.format(self.path))
def update(self):
try:
......@@ -74,11 +69,7 @@ class GitSource(object):
self.load_hub_modules()
except Exception as e:
self.hub_modules = OrderedDict()
msg = traceback.format_exc()
file = utils.record(msg)
log.logger.warning(
'An error occurred while update {}. Detailed error information can be found in the {}.'.format(
self.path, file))
utils.record_exception('An error occurred while update {}.'.format(self.path))
def load_hub_modules(self):
if 'hubconf' in sys.modules:
......@@ -93,11 +84,7 @@ class GitSource(object):
self.hub_modules[_item.name] = _item
except Exception as e:
self.hub_modules = OrderedDict()
msg = traceback.format_exc()
file = utils.record(msg)
log.logger.warning(
'An error occurred while loading {}. Detailed error information can be found in the {}.'.format(
self.path, file))
utils.record_exception('An error occurred while loading {}.'.format(self.path))
sys.path.remove(self.path)
......
......@@ -16,17 +16,17 @@
from collections import OrderedDict
from typing import List
import paddlehub.config as hubconf
from paddlehub.server import ServerSource, GitSource
from paddlehub.utils import utils
PADDLEHUB_PUBLIC_SERVER = 'http://paddlepaddle.org.cn/paddlehub'
class HubServer(object):
'''PaddleHub server'''
def __init__(self):
self.sources = OrderedDict()
self.keysmap = OrderedDict()
def _generate_source(self, url: str, source_type: str = 'git'):
if source_type == 'server':
......@@ -34,15 +34,16 @@ class HubServer(object):
elif source_type == 'git':
source = GitSource(url)
else:
raise RuntimeError('Unknown source type {}.'.format(source_type))
raise ValueError('Unknown source type {}.'.format(source_type))
return source
def _get_source_key(self, url: str):
return 'source_{}'.format(utils.md5(url))
def add_source(self, url: str, source_type: str = 'git'):
def add_source(self, url: str, source_type: str = 'git', key: str = ''):
'''Add a module source(GitSource or ServerSource)'''
key = self._get_source_key(url)
key = self._get_source_key(url) if not key else key
self.keysmap[url] = key
self.sources[key] = self._generate_source(url, source_type)
def remove_source(self, url: str = None, key: str = None):
......@@ -51,8 +52,14 @@ class HubServer(object):
def get_source(self, url: str):
''''''
key = self._get_source_key(url)
return self.sources.get(key, None)
key = self.keysmap.get(url)
if not key:
return None
return self.sources.get(key)
def get_source_by_key(self, key: str):
''''''
return self.sources.get(key)
def search_module(self,
name: str,
......@@ -110,4 +117,4 @@ class HubServer(object):
module_server = HubServer()
module_server.add_source(PADDLEHUB_PUBLIC_SERVER, source_type='server')
module_server.add_source(hubconf.server, source_type='server', key='default_hub_server')
......@@ -25,6 +25,8 @@ from typing import List
import colorlog
from colorama import Fore
import paddlehub.config as hubconf
log_config = {
'DEBUG': {
'level': 10,
......@@ -83,10 +85,10 @@ class Logger(object):
self.handler.setFormatter(self.format)
self.logger.addHandler(self.handler)
self.logLevel = "DEBUG"
self.logLevel = hubconf.log_level
self.logger.setLevel(logging.DEBUG)
self.logger.propagate = False
self._is_enable = True
self._is_enable = hubconf.log_enable
def disable(self):
self._is_enable = False
......
......@@ -24,6 +24,7 @@ import requests
import sys
import time
import tempfile
import traceback
import types
from typing import Generator
from urllib.parse import urlparse
......@@ -294,3 +295,11 @@ def record(msg: str) -> str:
file.write(str(msg) + '\n' * 3)
return logfile
def record_exception(msg: str) -> str:
'''
'''
tb = traceback.format_exc()
file = record(tb)
utils.log.logger.warning('{}. Detailed error information can be found in the {}.'.format(msg, file))
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册