From 73e06533843dd47aacfda288fd6cb64930cb7401 Mon Sep 17 00:00:00 2001 From: wuzewu Date: Wed, 28 Oct 2020 10:34:04 +0800 Subject: [PATCH] Update config cmd --- paddlehub/__init__.py | 3 +- paddlehub/commands/config.py | 73 +++++----------------- paddlehub/config.py | 108 +++++++++++++++++++++++++++++++++ paddlehub/env.py | 4 +- paddlehub/module/manager.py | 13 +--- paddlehub/server/git_source.py | 19 +----- paddlehub/server/server.py | 23 ++++--- paddlehub/utils/log.py | 6 +- paddlehub/utils/utils.py | 9 +++ 9 files changed, 160 insertions(+), 98 deletions(-) create mode 100644 paddlehub/config.py diff --git a/paddlehub/__init__.py b/paddlehub/__init__.py index d04200f2..29445851 100644 --- a/paddlehub/__init__.py +++ b/paddlehub/__init__.py @@ -19,13 +19,14 @@ from easydict import EasyDict __version__ = '2.0.0a0' +from paddlehub.config import config from paddlehub.utils import log, parser, utils from paddlehub.utils.paddlex import download, ResourceNotFoundError from paddlehub.server.server_source import ServerConnectionError from paddlehub.module import Module # In order to maintain the compatibility of the old version, we put the relevant -# compatible code in the paddlehub/compat package, and mapped some modules referenced +# compatible code in the paddlehub.compat package, and mapped some modules referenced # in the old version from paddlehub.compat import paddle_utils from paddlehub.compat.module.processor import BaseProcessor diff --git a/paddlehub/commands/config.py b/paddlehub/commands/config.py index f97186c1..e6870368 100644 --- a/paddlehub/commands/config.py +++ b/paddlehub/commands/config.py @@ -14,24 +14,11 @@ # limitations under the License. import argparse -import json -import os -import re -import hashlib -import uuid -import time +import ast +import paddlehub.config as hubconf from paddlehub.env import CONF_HOME from paddlehub.commands import register -from paddlehub.utils.utils import md5 - -default_server_config = { - "server_url": ["http://paddlepaddle.org.cn/paddlehub"], - "resource_storage_server_url": "https://bj.bcebos.com/paddlehub-data/", - "debug": False, - "log_level": "DEBUG", - "hub_name": md5(str(uuid.uuid1())[-12:]) + "-" + str(int(time.time())) -} @register(name='hub.config', description='Configure PaddleHub.') @@ -39,40 +26,7 @@ class ConfigCommand: @staticmethod def show_config(): print("The current configuration is shown below.") - with open(os.path.join(CONF_HOME, "config.json"), "r") as fp: - print(json.dumps(json.load(fp), indent=4)) - - @staticmethod - def set_server_url(server_url): - with open(os.path.join(CONF_HOME, "config.json"), "r") as fp: - config = json.load(fp) - re_str = r"^(?:http(s)?:\/\/)?[\w.-]+(?:\.[\w\.-]+)+[\w\-\._~:/?#[\]@!\$&'\*\+,;=.]+$" - if re.match(re_str, server_url) is not None: - config["server_url"] = list([server_url]) - ConfigCommand.set_config(config) - else: - print("The format of the input url is invalid.") - - @staticmethod - def set_config(config): - with open(os.path.join(CONF_HOME, "config.json"), "w") as fp: - fp.write(json.dumps(config)) - print("Set success! The current configuration is shown below.") - print(json.dumps(config, indent=4)) - - @staticmethod - def set_log_level(level): - level = str(level).upper() - if level not in ["NOLOG", "DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"]: - print("Allowed values include: " "NOLOG, DEBUG, INFO, WARNING, ERROR, CRITICAL") - return - with open(os.path.join(CONF_HOME, "config.json"), "r") as fp: - current_config = json.load(fp) - with open(os.path.join(CONF_HOME, "config.json"), "w") as fp: - current_config["log_level"] = level - fp.write(json.dumps(current_config)) - print("Set success! The current configuration is shown below.") - print(json.dumps(current_config, indent=4)) + print(hubconf) @staticmethod def show_help(): @@ -80,11 +34,13 @@ class ConfigCommand: str += "\tShow PaddleHub config without any option.\n" str += "option:\n" str += "reset\n" - str += "\tReset config as default.\n" + str += "\tReset config as default.\n\n" str += "server==[URL]\n" - str += "\tSet PaddleHub Server url as [URL].\n" - str += "log==[LEVEL]\n" - str += "\tSet log level as [LEVEL:NOLOG, DEBUG, INFO, WARNING, ERROR, CRITICAL].\n" + str += "\tSet PaddleHub Server url as [URL].\n\n" + str += "log.level==[LEVEL]\n" + str += "\tSet log level.\n\n" + str += "log.enable==True|False\n" + str += "\tEnable or disable logger in PaddleHub.\n" print(str) def execute(self, argv): @@ -92,11 +48,14 @@ class ConfigCommand: ConfigCommand.show_config() for arg in argv: if arg == "reset": - ConfigCommand.set_config(default_server_config) + hubconf.reset() + print(hubconf) elif arg.startswith("server=="): - ConfigCommand.set_server_url(arg.split("==")[1]) - elif arg.startswith("log=="): - ConfigCommand.set_log_level(arg.split("==")[1]) + hubconf.server_url = arg.split("==")[1] + elif arg.startswith("log.level=="): + hubconf.log_level = arg.split("==")[1] + elif arg.startswith("log.enable=="): + hubconf.log_enable = ast.literal_eval(arg.split("==")[1]) else: ConfigCommand.show_help() return True diff --git a/paddlehub/config.py b/paddlehub/config.py new file mode 100644 index 00000000..950f9517 --- /dev/null +++ b/paddlehub/config.py @@ -0,0 +1,108 @@ +# coding:utf-8 +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License" +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import json +from typing import Any + +import yaml +from easydict import EasyDict + +import paddlehub.env as hubenv + + +class HubConfig: + ''' + ''' + + def __init__(self): + self._initialize() + self.file = os.path.join(hubenv.CONF_HOME, 'config.yaml') + with open(self.file, 'r') as file: + try: + cfg = yaml.load(file, Loader=yaml.FullLoader) + self.data.update(cfg) + except: + ... + + def _initialize(self): + self.data = EasyDict() + self.data.server = 'http://paddlepaddle.org.cn/paddlehub' + self.data.log = EasyDict() + self.data.log.enable = True + self.data.log.level = 'DEBUG' + + def reset(self): + self._initialize() + self.flush() + + @property + def log_level(self): + return self.data.log.level + + @log_level.setter + def log_level(self, level: str): + from paddlehub.utils import log + if not level in log.log_config.keys(): + raise ValueError('Unknown log level {}.'.format(level)) + + self.data.log.level = level + self.flush() + + @property + def log_enable(self): + return self.data.log.enable + + @log_enable.setter + def log_enable(self, enable: bool): + self.data.log.enable = enable + self.flush() + + @property + def server(self): + return self.data.server + + @server.setter + def server(self, url: str): + self.data.server = url + self.flush() + + def flush(self): + with open(self.file, 'w') as file: + # convert EasyDict to dict + cfg = json.loads(json.dumps(self.data)) + yaml.dump(cfg, file) + + def __str__(self): + cfg = json.loads(json.dumps(self.data)) + return yaml.dump(cfg) + + +def _load_old_config(config: HubConfig): + # The old version of the configuration file is obsolete, read the configuration value and delete it. + old_cfg_file = os.path.join(hubenv.CONF_HOME, 'config.json') + if os.path.exists(old_cfg_file): + with open(old_cfg_file) as file: + try: + cfg = json.loads(file.read()) + config.server = cfg['server_url'] + config.log_level = cfg['log_level'] + except: + ... + os.remove(old_cfg_file) + + +config = HubConfig() +_load_old_config(config) diff --git a/paddlehub/env.py b/paddlehub/env.py index aa8e838d..30cc499a 100644 --- a/paddlehub/env.py +++ b/paddlehub/env.py @@ -16,8 +16,6 @@ import os import shutil -from paddlehub.utils import log - def _get_user_home(): return os.path.expanduser('~') @@ -30,7 +28,7 @@ def _get_hub_home(): if os.path.isdir(home_path): return home_path else: - log.logger.warning('') + raise RuntimeError('The environment variable HUB_HOME {} is not a directory.'.format(home_path)) else: return home_path return os.path.join(_get_user_home(), '.paddlehub') diff --git a/paddlehub/module/manager.py b/paddlehub/module/manager.py index 605ebee0..562b1f40 100644 --- a/paddlehub/module/manager.py +++ b/paddlehub/module/manager.py @@ -16,7 +16,6 @@ import os import shutil import sys -import traceback from collections import OrderedDict from typing import List @@ -218,11 +217,7 @@ class LocalModuleManager(object): try: module = self._local_modules[name] = HubModule.load(module_dir) except Exception as e: - msg = traceback.format_exc() - file = utils.record(msg) - log.logger.warning( - 'An error was encountered while loading {}. Detailed error information can be found in the {}.'. - format(name, file)) + utils.record_exception('An error was encountered while loading {}.'.format(name)) if not module: return None @@ -242,11 +237,7 @@ class LocalModuleManager(object): try: self._local_modules[subdir] = HubModule.load(fulldir) except Exception as e: - msg = traceback.format_exc() - file = utils.record(msg) - log.logger.warning( - 'An error was encountered while loading {}. Detailed error information can be found in the {}.'. - format(subdir, file)) + utils.record_exception('An error was encountered while loading {}.'.format(subdir)) return [module for module in self._local_modules.values()] diff --git a/paddlehub/server/git_source.py b/paddlehub/server/git_source.py index bc416647..67425187 100644 --- a/paddlehub/server/git_source.py +++ b/paddlehub/server/git_source.py @@ -17,7 +17,6 @@ import inspect import importlib import os import sys -import traceback from collections import OrderedDict from typing import List @@ -61,11 +60,7 @@ class GitSource(object): # reload modules self.load_hub_modules() except: - msg = traceback.format_exc() - file = utils.record(msg) - log.logger.warning( - 'An error occurred while checkout {}. Detailed error information can be found in the {}.'.format( - self.path, file)) + utils.record_exception('An error occurred while checkout {}.'.format(self.path)) def update(self): try: @@ -74,11 +69,7 @@ class GitSource(object): self.load_hub_modules() except Exception as e: self.hub_modules = OrderedDict() - msg = traceback.format_exc() - file = utils.record(msg) - log.logger.warning( - 'An error occurred while update {}. Detailed error information can be found in the {}.'.format( - self.path, file)) + utils.record_exception('An error occurred while update {}.'.format(self.path)) def load_hub_modules(self): if 'hubconf' in sys.modules: @@ -93,11 +84,7 @@ class GitSource(object): self.hub_modules[_item.name] = _item except Exception as e: self.hub_modules = OrderedDict() - msg = traceback.format_exc() - file = utils.record(msg) - log.logger.warning( - 'An error occurred while loading {}. Detailed error information can be found in the {}.'.format( - self.path, file)) + utils.record_exception('An error occurred while loading {}.'.format(self.path)) sys.path.remove(self.path) diff --git a/paddlehub/server/server.py b/paddlehub/server/server.py index ee228d68..56c4fce5 100644 --- a/paddlehub/server/server.py +++ b/paddlehub/server/server.py @@ -16,17 +16,17 @@ from collections import OrderedDict from typing import List +import paddlehub.config as hubconf from paddlehub.server import ServerSource, GitSource from paddlehub.utils import utils -PADDLEHUB_PUBLIC_SERVER = 'http://paddlepaddle.org.cn/paddlehub' - class HubServer(object): '''PaddleHub server''' def __init__(self): self.sources = OrderedDict() + self.keysmap = OrderedDict() def _generate_source(self, url: str, source_type: str = 'git'): if source_type == 'server': @@ -34,15 +34,16 @@ class HubServer(object): elif source_type == 'git': source = GitSource(url) else: - raise RuntimeError('Unknown source type {}.'.format(source_type)) + raise ValueError('Unknown source type {}.'.format(source_type)) return source def _get_source_key(self, url: str): return 'source_{}'.format(utils.md5(url)) - def add_source(self, url: str, source_type: str = 'git'): + def add_source(self, url: str, source_type: str = 'git', key: str = ''): '''Add a module source(GitSource or ServerSource)''' - key = self._get_source_key(url) + key = self._get_source_key(url) if not key else key + self.keysmap[url] = key self.sources[key] = self._generate_source(url, source_type) def remove_source(self, url: str = None, key: str = None): @@ -51,8 +52,14 @@ class HubServer(object): def get_source(self, url: str): '''''' - key = self._get_source_key(url) - return self.sources.get(key, None) + key = self.keysmap.get(url) + if not key: + return None + return self.sources.get(key) + + def get_source_by_key(self, key: str): + '''''' + return self.sources.get(key) def search_module(self, name: str, @@ -110,4 +117,4 @@ class HubServer(object): module_server = HubServer() -module_server.add_source(PADDLEHUB_PUBLIC_SERVER, source_type='server') +module_server.add_source(hubconf.server, source_type='server', key='default_hub_server') diff --git a/paddlehub/utils/log.py b/paddlehub/utils/log.py index d91b0f04..5615ee47 100644 --- a/paddlehub/utils/log.py +++ b/paddlehub/utils/log.py @@ -25,6 +25,8 @@ from typing import List import colorlog from colorama import Fore +import paddlehub.config as hubconf + log_config = { 'DEBUG': { 'level': 10, @@ -83,10 +85,10 @@ class Logger(object): self.handler.setFormatter(self.format) self.logger.addHandler(self.handler) - self.logLevel = "DEBUG" + self.logLevel = hubconf.log_level self.logger.setLevel(logging.DEBUG) self.logger.propagate = False - self._is_enable = True + self._is_enable = hubconf.log_enable def disable(self): self._is_enable = False diff --git a/paddlehub/utils/utils.py b/paddlehub/utils/utils.py index ddf295e1..571c88b8 100644 --- a/paddlehub/utils/utils.py +++ b/paddlehub/utils/utils.py @@ -24,6 +24,7 @@ import requests import sys import time import tempfile +import traceback import types from typing import Generator from urllib.parse import urlparse @@ -294,3 +295,11 @@ def record(msg: str) -> str: file.write(str(msg) + '\n' * 3) return logfile + + +def record_exception(msg: str) -> str: + ''' + ''' + tb = traceback.format_exc() + file = record(tb) + utils.log.logger.warning('{}. Detailed error information can be found in the {}.'.format(msg, file)) -- GitLab