提交 30142577 编写于 作者: W wuzewu

Add comments

上级 9ca66900
......@@ -31,7 +31,7 @@ class DownloadCommand:
for _arg in argv:
result = module_server.search_module(_arg)
if result:
url = result['url']
url = result[0]['url']
with log.ProgressBar('Download {}'.format(url)) as bar:
for file, ds, ts in utils.download_with_progress(url):
bar.update(float(ds) / ts)
......
......@@ -27,8 +27,12 @@ from paddlehub.utils import utils, log
class ModuleV1(object):
'''
ModuleV1 is an old version of the PaddleHub Module format, which is no longer in use. In order to maintain
compatibility, users can still load the corresponding Module for prediction. User should call `hub.Module`
to initialize the corresponding object, rather than `ModuleV1`.
'''
# All ModuleV1 in PaddleHub is static graph model
@paddle_utils.run_in_static_mode
def __init__(self, name: str = None, directory: str = None, version: str = None):
if not directory:
......@@ -78,6 +82,8 @@ class ModuleV1(object):
num_param_loaded += 1
var = global_block.vars[name]
# Since the pre-trained model saved by the old version of Paddle cannot restore the corresponding
# parameters, we need to restore them manually.
global_block.create_parameter(
name=name,
shape=var.shape,
......@@ -114,8 +120,7 @@ class ModuleV1(object):
@paddle_utils.run_in_static_mode
def context(self, signature: str = None, for_test: bool = False,
trainable: bool = True) -> Tuple[dict, dict, paddle.static.Program]:
'''
'''
'''Get module context information, including graph structure and graph input and output variables.'''
program = self.program.clone(for_test=for_test)
paddle_utils.remove_feed_fetch_op(program)
......@@ -140,8 +145,7 @@ class ModuleV1(object):
@paddle_utils.run_in_static_mode
def __call__(self, sign_name: str, data: dict, use_gpu: bool = False, batch_size: int = 1, **kwargs):
'''
'''
'''Call the specified signature function for prediction.'''
def _get_reader_and_feeder(data_format, data, place):
def _reader(process_data):
......@@ -178,10 +182,12 @@ class ModuleV1(object):
@classmethod
def get_py_requirements(cls) -> List[str]:
'''Get Module's python package dependency list.'''
return []
@classmethod
def load(cls, directory: str) -> EasyDict:
'''Load the Module object defined in the specified directory.'''
module_info = cls.load_module_info(directory)
# Generate a uuid based on the class information, and dynamically create a new type.
......@@ -202,6 +208,7 @@ class ModuleV1(object):
@classmethod
def load_module_info(cls, directory: str) -> EasyDict:
'''Load the information of Module object defined in the specified directory.'''
desc_file = os.path.join(directory, 'module_desc.pb')
desc = module_v1_utils.convert_module_desc(desc_file)
return desc.module_info
......@@ -211,4 +218,8 @@ class ModuleV1(object):
@property
def is_runnable(self):
'''
Whether the Module is runnable, in other words, whether can we execute the Module through the
`hub run` command.
'''
return self.default_signature != None
......@@ -15,7 +15,6 @@
import os
import json
from typing import Any
import yaml
from easydict import EasyDict
......@@ -25,6 +24,9 @@ import paddlehub.env as hubenv
class HubConfig:
'''
PaddleHub configuration management class. Each time the PaddleHub package is loaded, PaddleHub will set the
corresponding functions according to the configuration obtained in HubConfig, such as the log level of printing,
server address and so on. When the configuration is modified, PaddleHub needs to be reloaded to take effect.
'''
def __init__(self):
......@@ -38,6 +40,7 @@ class HubConfig:
...
def _initialize(self):
# Set default configuration values.
self.data = EasyDict()
self.data.server = 'http://paddlepaddle.org.cn/paddlehub'
self.data.log = EasyDict()
......@@ -45,24 +48,30 @@ class HubConfig:
self.data.log.level = 'DEBUG'
def reset(self):
'''Reset configuration to default.'''
self._initialize()
self.flush()
@property
def log_level(self):
'''
The lowest output level of PaddleHub logger. Logs below the specified level will not be displayed. The default
is Debug.
'''
return self.data.log.level
@log_level.setter
def log_level(self, level: str):
from paddlehub.utils import log
if not level in log.log_config.keys():
raise ValueError('Unknown log level {}.'.format(level))
raise ValueError('Unknown log level {}. The valid values are {}'.format(level, list(log.log_config.keys())))
self.data.log.level = level
self.flush()
@property
def log_enable(self):
'''Whether to enable the PaddleHub logger to take effect. The default is True.'''
return self.data.log.enable
@log_enable.setter
......@@ -72,6 +81,7 @@ class HubConfig:
@property
def server(self):
'''PaddleHub Module server url.'''
return self.data.server
@server.setter
......@@ -80,6 +90,7 @@ class HubConfig:
self.flush()
def flush(self):
'''Flush the current configuration into the configuration file.'''
with open(self.file, 'w') as file:
# convert EasyDict to dict
cfg = json.loads(json.dumps(self.data))
......
......@@ -12,9 +12,25 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
'''
This module is used to store environmental variables in PaddleHub.
HUB_HOME --> the root directory for storing PaddleHub related data. Default to ~/.paddlehub. Users can change the
├ default value through the HUB_HOME environment variable.
├── MODULE_HOME --> Store the installed PaddleHub Module.
├── CACHE_HOME --> Store the cached data.
├── DATA_HOME --> Store the automatically downloaded datasets.
├── CONF_HOME --> Store the default configuration files.
├── THIRD_PARTY_HOME --> Store third-party libraries.
├── TMP_HOME --> Store temporary files generated during running, such as intermediate products of installing modules,
├ files in this directory will generally be automatically cleared.
├── SOURCES_HOME --> Store the installed code sources.
└── LOG_HOME --> Store log files generated during operation, including some non-fatal errors. The log will be stored
daily.
'''
import os
import shutil
def _get_user_home():
......
# coding:utf-8
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
class Compose(object):
    '''
    Chain several schedulers into a single callable.

    Args:
        schedulers: Iterable of callables. Each one is invoked as `scheduler(global_lr, parameters)` and must
            return a `(global_lr, parameters)` pair, which is handed to the next scheduler.
    '''

    def __init__(self, schedulers):
        self.schedulers = schedulers

    def __call__(self, global_lr, parameters):
        '''Run every scheduler in order and return the final `(global_lr, parameters)` pair.'''
        lr, params = global_lr, parameters
        for apply_scheduler in self.schedulers:
            lr, params = apply_scheduler(lr, params)
        return lr, params
class WarmUpLR(object):
    '''
    Linear learning-rate warm-up over a window of steps.

    The object keeps an internal step counter that starts at 0 and is increased by one on every call. While the
    counter lies in `[start_step, end_step]`, the incoming learning rate is multiplied by
    `(step - start_step) / (end_step - start_step)`; outside the window it is returned unchanged.

    Args:
        start_step(int): First step (inclusive) of the warm-up window.
        end_step(int): Last step (inclusive) of the warm-up window.
        wtype: Accepted for interface compatibility; not used by this class.
    '''

    def __init__(self, start_step: int, end_step: int, wtype=None):
        self.start_step = start_step
        self.end_step = end_step
        self._curr_step = 0

    def __call__(self, global_lr, parameters):
        '''Scale `global_lr` for the current step and return `(global_lr, parameters)`.'''
        span = self.end_step - self.start_step
        in_window = self.start_step <= self._curr_step <= self.end_step
        # A zero-length window (start_step == end_step) has no ramp to interpolate; leave the learning rate
        # untouched instead of dividing by zero.
        if in_window and span != 0:
            global_lr *= float(self._curr_step - self.start_step) / span

        self._curr_step += 1
        return global_lr, parameters
class DecayLR(object):
    '''
    Linear learning-rate decay over a window of steps.

    The object keeps an internal step counter that starts at 0 and is increased by one on every call. While the
    counter lies in `[start_step, end_step]`, the incoming learning rate is multiplied by
    `(end_step - step) / (end_step - start_step)`; outside the window it is returned unchanged.

    Args:
        start_step(int): First step (inclusive) of the decay window.
        end_step(int): Last step (inclusive) of the decay window.
        wtype: Accepted for interface compatibility; not used by this class.
    '''

    def __init__(self, start_step: int, end_step: int, wtype=None):
        self.start_step = start_step
        self.end_step = end_step
        self._curr_step = 0

    def __call__(self, global_lr, parameters):
        '''Scale `global_lr` for the current step and return `(global_lr, parameters)`.'''
        span = self.end_step - self.start_step
        in_window = self.start_step <= self._curr_step <= self.end_step
        # A zero-length window (start_step == end_step) has no slope to interpolate; leave the learning rate
        # untouched instead of dividing by zero.
        if in_window and span != 0:
            global_lr *= float(self.end_step - self._curr_step) / span

        self._curr_step += 1
        return global_lr, parameters
class SlantedTriangleLR(object):
    '''
    Learning-rate schedule built from a warm-up stage followed by a decay stage.

    The switch-over step is `int(global_step * warmup_prop)`. A `WarmUpLR` covers steps `0` up to that step and
    a `DecayLR` covers that step up to `global_step - 1`; both are chained with `Compose`.

    Args:
        global_step(int): Total number of steps the schedule spans.
        warmup_prop(float): Fraction of `global_step` assigned to the warm-up stage.
    '''

    def __init__(self, global_step: int, warmup_prop: float):
        self.global_step = global_step
        self.warmup_prop = warmup_prop

        turning_point = int(global_step * warmup_prop)
        warmup_stage = WarmUpLR(start_step=0, end_step=turning_point)
        decay_stage = DecayLR(start_step=turning_point, end_step=global_step - 1)
        self.scheduler = Compose([warmup_stage, decay_stage])

    def __call__(self, global_lr, parameters):
        '''Delegate to the composed warm-up/decay scheduler and return its `(global_lr, parameters)` result.'''
        return self.scheduler(global_lr, parameters)
class GradualUnfreeze(object):
    '''Placeholder class: it currently defines no attributes or methods of its own.'''
class LayeredLR(object):
    '''Placeholder class: it currently defines no attributes or methods of its own.'''
class L2SP(object):
    '''
    Holder for a regularization coefficient. Calling the object is not implemented yet.

    Args:
        regularization_coeff: Coefficient stored on the instance. Default to 1e-3.
    '''

    def __init__(self, regularization_coeff=1e-3):
        self.regularization_coeff = regularization_coeff

    def __call__(self, global_lr, parameters):
        '''Not implemented yet: both arguments are ignored and None is returned.'''
        return None
......@@ -22,7 +22,7 @@ from typing import Any, Callable, Generic, List
import paddle
from visualdl import LogWriter
from paddlehub.utils.log import logger, processing
from paddlehub.utils.log import logger
from paddlehub.utils.utils import Timer
......@@ -163,7 +163,11 @@ class Trainer(object):
batch_sampler = paddle.io.DistributedBatchSampler(
train_dataset, batch_size=batch_size, shuffle=True, drop_last=False)
loader = paddle.io.DataLoader(
train_dataset, batch_sampler=batch_sampler, num_workers=num_workers, return_list=True)
train_dataset,
batch_sampler=batch_sampler,
num_workers=num_workers,
return_list=True,
use_buffer_reader=True)
steps_per_epoch = len(batch_sampler)
timer = Timer(steps_per_epoch * epochs)
......@@ -258,7 +262,7 @@ class Trainer(object):
sum_metrics = defaultdict(int)
avg_metrics = defaultdict(int)
with processing('Evaluation on validation dataset'):
with logger.processing('Evaluation on validation dataset'):
for batch_idx, batch in enumerate(loader):
result = self.validation_step(batch, batch_idx)
loss = result.get('loss', None)
......
# coding:utf-8
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
def check_platform(module):
    '''Placeholder: accepts `module`, performs no check and returns None.'''
    return None
def check_py_requirements(module):
    '''Placeholder: accepts `module`, performs no check and returns None.'''
    return None
def check_py_requirement(module, package):
    '''Placeholder: accepts `module` and `package`, performs no check and returns None.'''
    return None
def get_py_requirements(module):
    '''Placeholder: accepts `module`, collects nothing and returns None.'''
    return None
def check_hub_requirements(module):
    '''Placeholder: accepts `module`, performs no check and returns None.'''
    return None
def check_hub_requirement(module, r_module):
    '''Placeholder: accepts `module` and `r_module`, performs no check and returns None.'''
    return None
def get_hub_requirements(module):
    '''Placeholder: accepts `module`, collects nothing and returns None.'''
    return None
......@@ -217,7 +217,7 @@ class LocalModuleManager(object):
if os.path.exists(module_dir):
try:
module = self._local_modules[name] = HubModule.load(module_dir)
except Exception as e:
except:
utils.record_exception('An error was encountered while loading {}.'.format(name))
if not module:
......@@ -237,7 +237,7 @@ class LocalModuleManager(object):
fulldir = os.path.join(self.home, subdir)
try:
self._local_modules[subdir] = HubModule.load(fulldir)
except Exception as e:
except:
utils.record_exception('An error was encountered while loading {}.'.format(subdir))
return [module for module in self._local_modules.values()]
......@@ -258,7 +258,7 @@ class LocalModuleManager(object):
if name.lower() == item['name'].lower() and utils.Version(item['version']).match(version):
return self._install_from_url(item['url'])
module_infos = module_server.get_module_info(name=name)
module_infos = module_server.get_module_compat_info(name=name)
# The HubModule with the specified name cannot be found
if not module_infos:
raise HubModuleNotFoundError(name=name, version=version)
......
......@@ -16,11 +16,10 @@
import ast
import builtins
import inspect
import importlib
import os
import re
import sys
from typing import Callable, Generic, List, Optional
from typing import Callable, Generic, List, Optional, Union
from easydict import EasyDict
......@@ -43,6 +42,7 @@ _module_runnable_func = {}
def runnable(func: Callable) -> Callable:
'''Mark a Module method as runnable, when the command `hub run` is used, the method will be called.'''
mod = func.__module__ + '.' + inspect.stack()[1][3]
_module_runnable_func[mod] = func.__name__
......@@ -53,6 +53,7 @@ def runnable(func: Callable) -> Callable:
def serving(func: Callable) -> Callable:
'''Mark a Module method as serving method, when the command `hub serving` is used, the method will be called.'''
mod = func.__module__ + '.' + inspect.stack()[1][3]
_module_serving_func[mod] = func.__name__
......@@ -62,8 +63,89 @@ def serving(func: Callable) -> Callable:
return _wrapper
class RunModule(object):
    '''
    The base class of PaddleHub Module. Users can inherit from this class to implement a custom Module.

    On construction the instance looks up, for its class and then its base classes, the method names registered
    in `_module_runnable_func` and `_module_serving_func`, and stores the results in `_run_func` and
    `_serving_func_name`.
    '''

    def __init__(self, *args, **kwargs):
        # Avoid module being initialized multiple times
        if '_is_initialize' in self.__dict__ and self._is_initialize:
            return

        super(RunModule, self).__init__()
        _run_func_name = self._get_func_name(self.__class__, _module_runnable_func)
        self._run_func = getattr(self, _run_func_name) if _run_func_name else None
        self._serving_func_name = self._get_func_name(self.__class__, _module_serving_func)
        self._is_initialize = True

    def _get_func_name(self, current_cls: Generic, module_func_dict: dict) -> Optional[str]:
        '''
        Find the function name registered for `current_cls` or, failing that, for one of its base classes.

        Args:
            current_cls: Class whose `'<module>.<class name>'` key is looked up first.
            module_func_dict(dict): Mapping from `'<module>.<class name>'` to a function name.

        Returns:
            The registered function name, or None when neither the class nor any ancestor is registered.
        '''
        mod = current_cls.__module__ + '.' + current_cls.__name__
        if mod in module_func_dict:
            return module_func_dict[mod]

        # Search every base class (depth first, in declaration order). Returning the result of the first base
        # unconditionally would hide registrations reachable only through the second or later base.
        for base_class in current_cls.__bases__:
            _func_name = self._get_func_name(base_class, module_func_dict)
            if _func_name is not None:
                return _func_name
        return None

    # After the 2.0.0rc version, paddle uses the dynamic graph mode by default, which will cause the
    # execution of the static graph model to fail, so compatibility protection is required.
    def __getattribute__(self, attr):
        _attr = object.__getattribute__(self, attr)

        # If the acquired attribute is a built-in property of the object, skip it.
        if re.match('__.*__', attr):
            return _attr
        # If the module is a dynamic graph model, skip it.
        elif isinstance(self, paddle.nn.Layer):
            return _attr
        # If the acquired attribute is not a class method, skip it.
        elif not inspect.ismethod(_attr):
            return _attr

        return paddle_utils.run_in_static_mode(_attr)

    @classmethod
    def get_py_requirements(cls) -> List[str]:
        '''
        Get Module's python package dependency list.

        The list is read from the `requirements.txt` file located next to the python file that defines the
        class.

        Returns:
            The lines of `requirements.txt` (split on newlines, so blank lines show up as empty strings), or an
            empty list when the file does not exist.
        '''
        py_module = sys.modules[cls.__module__]
        directory = os.path.dirname(py_module.__file__)
        req_file = os.path.join(directory, 'requirements.txt')
        if not os.path.exists(req_file):
            return []

        with open(req_file, 'r') as file:
            # Return a list in both branches, as the annotation promises, rather than the raw file content.
            return file.read().split('\n')

    @property
    def is_runnable(self) -> bool:
        '''
        Whether the Module is runnable, in other words, whether can we execute the Module through the
        `hub run` command.
        '''
        return self._run_func is not None
class Module(object):
'''
In PaddleHub, Module represents an executable module, which is usually a pre-trained model that can be used for end-to-end
prediction, such as a face detection model or a lexical analysis model, or a pre-trained model that requires finetuning,
such as BERT/ERNIE. When loading a Module with a specified name, if the Module does not exist locally, PaddleHub will
automatically request the server or the specified Git source to download the resource.
Args:
name(str): Module name.
directory(str|optional): Directory of the module to be loaded, only takes effect when the `name` is not specified.
version(str|optional): The version limit of the module, only takes effect when the `name` is specified. When the local
Module does not meet the specified version conditions, PaddleHub will re-request the server to
download the appropriate Module. Default to None, This means that the local Module will be used.
If the Module does not exist, PaddleHub will download the latest version available from the
server according to the usage environment.
source(str|optional): Url of a git repository. If this parameter is specified, PaddleHub will no longer download the
specified Module from the default server, but will look for it in the specified repository.
Default to None.
update(bool|optional): Whether to update the locally cached git repository, only takes effect when the `source`
is specified. Default to False.
branch(str|optional): The branch of the specified git repository. Default to None.
'''
def __new__(cls,
......@@ -72,11 +154,13 @@ class Module(object):
version: str = None,
source: str = None,
update: bool = False,
branch: str = None,
**kwargs):
if cls.__name__ == 'Module':
# This branch come from hub.Module(name='xxx') or hub.Module(directory='xxx')
if name:
module = cls.init_with_name(name=name, version=version, source=source, update=update, **kwargs)
module = cls.init_with_name(
name=name, version=version, source=source, update=update, branch=branch, **kwargs)
elif directory:
module = cls.init_with_directory(directory=directory, **kwargs)
else:
......@@ -86,8 +170,7 @@ class Module(object):
@classmethod
def load(cls, directory: str) -> Generic:
'''
'''
'''Load the Module object defined in the specified directory.'''
if directory.endswith(os.sep):
directory = directory[:-1]
......@@ -122,6 +205,7 @@ class Module(object):
@classmethod
def load_module_info(cls, directory: str) -> EasyDict:
'''Load the information of Module object defined in the specified directory.'''
# If is ModuleV1
desc_file = os.path.join(directory, 'module_desc.pb')
if os.path.exists(desc_file):
......@@ -153,9 +237,8 @@ class Module(object):
source: str = None,
update: bool = False,
branch: str = None,
**kwargs):
'''
'''
**kwargs) -> Union[RunModule, ModuleV1]:
'''Initialize Module according to the specified name.'''
from paddlehub.module.manager import LocalModuleManager
manager = LocalModuleManager()
user_module_cls = manager.search(name, source=source, branch=branch)
......@@ -181,9 +264,8 @@ class Module(object):
return user_module_cls(**kwargs)
@classmethod
def init_with_directory(cls, directory: str, **kwargs):
'''
'''
def init_with_directory(cls, directory: str, **kwargs) -> Union[RunModule, ModuleV1]:
'''Initialize Module according to the specified directory.'''
user_module_cls = cls.load(directory)
# The HubModule in the old version will use the _initialize method to initialize,
......@@ -202,77 +284,6 @@ class Module(object):
user_module_cls.directory = directory
return user_module_cls(**kwargs)
@classmethod
def get_py_requirements(cls):
'''
'''
req_file = os.path.join(cls.directory, 'requirements.txt')
if not os.path.exists(req_file):
return []
with open(req_file, 'r') as file:
return file.read().split('\n')
class RunModule(object):
    '''
    The base class of PaddleHub Module. Users can inherit from this class to implement a custom Module.

    On construction the instance looks up, for its class and then its base classes, the method names registered
    in `_module_runnable_func` and `_module_serving_func`, and stores the results in `_run_func` and
    `_serving_func_name`.
    '''

    def __init__(self, *args, **kwargs):
        # Avoid module being initialized multiple times
        if '_is_initialize' in self.__dict__ and self._is_initialize:
            return

        super(RunModule, self).__init__()
        _run_func_name = self._get_func_name(self.__class__, _module_runnable_func)
        self._run_func = getattr(self, _run_func_name) if _run_func_name else None
        self._serving_func_name = self._get_func_name(self.__class__, _module_serving_func)
        self._is_initialize = True

    def _get_func_name(self, current_cls: Generic, module_func_dict: dict) -> Optional[str]:
        '''
        Find the function name registered for `current_cls` or, failing that, for one of its base classes.

        Args:
            current_cls: Class whose `'<module>.<class name>'` key is looked up first.
            module_func_dict(dict): Mapping from `'<module>.<class name>'` to a function name.

        Returns:
            The registered function name, or None when neither the class nor any ancestor is registered.
        '''
        mod = current_cls.__module__ + '.' + current_cls.__name__
        if mod in module_func_dict:
            return module_func_dict[mod]

        # Search every base class (depth first, in declaration order). Returning the result of the first base
        # unconditionally would hide registrations reachable only through the second or later base.
        for base_class in current_cls.__bases__:
            _func_name = self._get_func_name(base_class, module_func_dict)
            if _func_name is not None:
                return _func_name
        return None

    # After the 2.0.0rc version, paddle uses the dynamic graph mode by default, which will cause the
    # execution of the static graph model to fail, so compatibility protection is required.
    def __getattribute__(self, attr):
        _attr = object.__getattribute__(self, attr)

        # If the acquired attribute is a built-in property of the object, skip it.
        if re.match('__.*__', attr):
            return _attr
        # If the module is a dygraph model, skip it.
        elif isinstance(self, paddle.nn.Layer):
            return _attr
        # If the acquired attribute is not a class method, skip it.
        elif not inspect.ismethod(_attr):
            return _attr

        return paddle_utils.run_in_static_mode(_attr)

    @classmethod
    def get_py_requirements(cls) -> List[str]:
        '''
        Get Module's python package dependency list.

        The list is read from the `requirements.txt` file located next to the python file that defines the
        class.

        Returns:
            The lines of `requirements.txt` (split on newlines, so blank lines show up as empty strings), or an
            empty list when the file does not exist.
        '''
        py_module = sys.modules[cls.__module__]
        directory = os.path.dirname(py_module.__file__)
        req_file = os.path.join(directory, 'requirements.txt')
        if not os.path.exists(req_file):
            return []

        with open(req_file, 'r') as file:
            # Return a list in both branches, as the annotation promises, rather than the raw file content.
            return file.read().split('\n')

    @property
    def is_runnable(self) -> bool:
        '''Whether a runnable function was found for this Module, i.e. `_run_func` is set.'''
        return self._run_func is not None
def moduleinfo(name: str,
version: str,
......@@ -282,6 +293,8 @@ def moduleinfo(name: str,
type: str = None,
meta=None) -> Callable:
'''
Mark Module information for a python class, and the class will automatically be extended to inherit HubModule. In other words, python classes
marked with moduleinfo can be loaded through hub.Module.
'''
def _wrapper(cls: Generic) -> Generic:
......
......@@ -55,6 +55,7 @@ class GitSource(object):
self.load_hub_modules()
def checkout(self, branch: str):
'''Checkout the current repo to the specified branch.'''
try:
self.repo.git.checkout(branch)
# reload modules
......@@ -63,11 +64,12 @@ class GitSource(object):
utils.record_exception('An error occurred while checkout {}.'.format(self.path))
def update(self):
'''Update the current repo.'''
try:
self.repo.remote().pull(self.repo.branches[0])
# reload modules
self.load_hub_modules()
except Exception as e:
except:
self.hub_modules = OrderedDict()
utils.record_exception('An error occurred while update {}.'.format(self.path))
......@@ -82,7 +84,7 @@ class GitSource(object):
_item = py_module.__dict__[_item]
if issubclass(_item, RunModule):
self.hub_modules[_item.name] = _item
except Exception as e:
except:
self.hub_modules = OrderedDict()
utils.record_exception('An error occurred while loading {}.'.format(self.path))
......@@ -118,6 +120,10 @@ class GitSource(object):
}]
return None
def get_module_compat_info(self, name: str) -> dict:
'''Get the version compatibility information of the model.'''
return {}
@classmethod
def check(cls, url: str) -> bool:
'''
......
......@@ -51,14 +51,14 @@ class HubServer(object):
self.sources.pop(key)
def get_source(self, url: str):
''''''
'''Get a module source by url'''
key = self.keysmap.get(url)
if not key:
return None
return self.sources.get(key)
def get_source_by_key(self, key: str):
''''''
'''Get a module source by key'''
return self.sources.get(key)
def search_module(self,
......@@ -105,9 +105,8 @@ class HubServer(object):
return result
return []
def get_module_info(self, name: str, source: str = None) -> dict:
'''
'''
def get_module_compat_info(self, name: str, source: str = None) -> dict:
'''Get the version compatibility information of the model.'''
sources = self.sources.values() if not source else [self._generate_source(source)]
for source in sources:
result = source.get_module_info(name=name)
......
......@@ -18,7 +18,7 @@ import requests
from typing import List
import paddlehub
from paddlehub.utils import utils, platform
from paddlehub.utils import platform
class ServerConnectionError(Exception):
......@@ -79,9 +79,8 @@ class ServerSource(object):
return result['data']
return None
def get_module_info(self, name: str) -> dict:
'''
'''
def get_module_compat_info(self, name: str) -> dict:
'''Get the version compatibility information of the model.'''
def _convert_version(version: str) -> List:
result = []
......@@ -112,8 +111,7 @@ class ServerSource(object):
return {}
def request(self, path: str, params: dict) -> dict:
'''
'''
'''Request server.'''
api = '{}/{}'.format(self._url, path)
try:
result = requests.get(api, params, timeout=self._timeout)
......
# coding:utf-8
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from paddlehub.env import DATA_HOME
......
......@@ -113,6 +113,32 @@ class Logger(object):
yield
self.handler.terminator = old_terminator
@contextlib.contextmanager
def processing(self, msg: str, interval: float = 0.1):
    '''
    Continuously print a progress bar with rotating special effects.

    A background thread logs `'<msg>: <flag>'` through `self.info` with a carriage-return terminator, cycling
    the flag through `\\ | / -` and sleeping `interval` seconds between updates, for as long as the `with` body
    is running.

    Args:
        msg(str): Message to be printed.
        interval(float): Rotation interval. Default to 0.1.
    '''
    end = False

    def _printer():
        index = 0
        flags = ['\\', '|', '/', '-']
        while not end:
            flag = flags[index % len(flags)]
            with self.use_terminator('\r'):
                self.info('{}: {}'.format(msg, flag))
            time.sleep(interval)
            index += 1

    t = threading.Thread(target=_printer)
    t.start()
    try:
        yield
    finally:
        # Always stop the printer, even when the with-body raises; otherwise the thread would spin forever and
        # keep the process alive. Joining guarantees nothing is printed after the context has exited.
        end = True
        t.join()
class ProgressBar(object):
'''
......@@ -172,28 +198,6 @@ class ProgressBar(object):
sys.stdout.write('\n')
@contextlib.contextmanager
def processing(msg: str, interval: float = 0.1):
    '''
    Continuously print a progress bar with rotating special effects.

    A background thread logs `'<msg>: <flag>'` through the module-level `logger` with a carriage-return
    terminator, cycling the flag through `\\ | / -` and sleeping `interval` seconds between updates, for as
    long as the `with` body is running.

    Args:
        msg(str): Message to be printed.
        interval(float): Rotation interval. Default to 0.1.
    '''
    end = False

    def _printer():
        index = 0
        flags = ['\\', '|', '/', '-']
        while not end:
            flag = flags[index % len(flags)]
            with logger.use_terminator('\r'):
                logger.info('{}: {}'.format(msg, flag))
            time.sleep(interval)
            index += 1

    t = threading.Thread(target=_printer)
    t.start()
    try:
        yield
    finally:
        # Always stop the printer, even when the with-body raises; otherwise the thread would spin forever and
        # keep the process alive. Joining guarantees nothing is printed after the context has exited.
        end = True
        t.join()
class FormattedText(object):
'''
Cross-platform formatted string
......
......@@ -34,8 +34,7 @@ class ResourceNotFoundError(Exception):
def download(name: str, save_path: str, version: str = None):
'''
'''
'''The download interface provided to PaddleX for downloading the specified model and resource files.'''
file = os.path.join(save_path, name)
file = os.path.realpath(file)
if os.path.exists(file):
......
......@@ -14,7 +14,6 @@
# limitations under the License.
import codecs
import sys
from typing import List
import yaml
......
......@@ -13,12 +13,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import optparse
import os
import pip
import sys
from pip._internal.utils.misc import get_installed_distributions
from typing import List, Tuple
from paddlehub.utils.utils import Version
from paddlehub.utils.io import discard_oe, typein
......
......@@ -246,16 +246,14 @@ def load_py_module(python_path: str, py_module_name: str) -> types.ModuleType:
def get_platform_default_encoding() -> str:
'''
'''
'''Get the default encoding of the current platform.'''
if utils.platform.is_windows():
return 'gbk'
return 'utf8'
def sys_stdin_encoding() -> str:
'''
'''
'''Get the default encoding of the standard input stream.'''
encoding = sys.stdin.encoding
if encoding is None:
encoding = sys.getdefaultencoding()
......@@ -266,8 +264,7 @@ def sys_stdin_encoding() -> str:
def sys_stdout_encoding() -> str:
'''
'''
'''Get the default encoding of the standard output stream.'''
encoding = sys.stdout.encoding
if encoding is None:
encoding = sys.getdefaultencoding()
......@@ -278,15 +275,13 @@ def sys_stdout_encoding() -> str:
def md5(text: str):
    '''Return the hexadecimal MD5 digest of `text`, encoded with the default `str.encode()` codec.'''
    digest = hashlib.md5()
    digest.update(text.encode())
    return digest.hexdigest()
def record(msg: str) -> str:
'''
'''
'''Record the specified text into the PaddleHub log file, which will be automatically stored according to date.'''
logfile = os.path.join(hubenv.LOG_HOME, time.strftime('%Y%m%d.log'))
with open(logfile, 'a') as file:
file.write('=' * 50 + '\n')
......@@ -298,8 +293,7 @@ def record(msg: str) -> str:
def record_exception(msg: str) -> str:
'''
'''
'''Record the current exception information into the PaddleHub log file, which will be automatically stored according to date.'''
tb = traceback.format_exc()
file = record(tb)
utils.log.logger.warning('{}. Detailed error information can be found in the {}.'.format(msg, file))
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册