提交 30142577 编写于 作者: W wuzewu

Add comments

上级 9ca66900
...@@ -31,7 +31,7 @@ class DownloadCommand: ...@@ -31,7 +31,7 @@ class DownloadCommand:
for _arg in argv: for _arg in argv:
result = module_server.search_module(_arg) result = module_server.search_module(_arg)
if result: if result:
url = result['url'] url = result[0]['url']
with log.ProgressBar('Download {}'.format(url)) as bar: with log.ProgressBar('Download {}'.format(url)) as bar:
for file, ds, ts in utils.download_with_progress(url): for file, ds, ts in utils.download_with_progress(url):
bar.update(float(ds) / ts) bar.update(float(ds) / ts)
......
...@@ -27,8 +27,12 @@ from paddlehub.utils import utils, log ...@@ -27,8 +27,12 @@ from paddlehub.utils import utils, log
class ModuleV1(object): class ModuleV1(object):
''' '''
ModuleV1 is an old version of the PaddleHub Module format, which is no longer in use. In order to maintain
compatibility, users can still load the corresponding Module for prediction. User should call `hub.Module`
to initialize the corresponding object, rather than `ModuleV1`.
''' '''
# All ModuleV1 in PaddleHub is static graph model
@paddle_utils.run_in_static_mode @paddle_utils.run_in_static_mode
def __init__(self, name: str = None, directory: str = None, version: str = None): def __init__(self, name: str = None, directory: str = None, version: str = None):
if not directory: if not directory:
...@@ -78,6 +82,8 @@ class ModuleV1(object): ...@@ -78,6 +82,8 @@ class ModuleV1(object):
num_param_loaded += 1 num_param_loaded += 1
var = global_block.vars[name] var = global_block.vars[name]
# Since the pre-trained model saved by the old version of Paddle cannot restore the corresponding
# parameters, we need to restore them manually.
global_block.create_parameter( global_block.create_parameter(
name=name, name=name,
shape=var.shape, shape=var.shape,
...@@ -114,8 +120,7 @@ class ModuleV1(object): ...@@ -114,8 +120,7 @@ class ModuleV1(object):
@paddle_utils.run_in_static_mode @paddle_utils.run_in_static_mode
def context(self, signature: str = None, for_test: bool = False, def context(self, signature: str = None, for_test: bool = False,
trainable: bool = True) -> Tuple[dict, dict, paddle.static.Program]: trainable: bool = True) -> Tuple[dict, dict, paddle.static.Program]:
''' '''Get module context information, including graph structure and graph input and output variables.'''
'''
program = self.program.clone(for_test=for_test) program = self.program.clone(for_test=for_test)
paddle_utils.remove_feed_fetch_op(program) paddle_utils.remove_feed_fetch_op(program)
...@@ -140,8 +145,7 @@ class ModuleV1(object): ...@@ -140,8 +145,7 @@ class ModuleV1(object):
@paddle_utils.run_in_static_mode @paddle_utils.run_in_static_mode
def __call__(self, sign_name: str, data: dict, use_gpu: bool = False, batch_size: int = 1, **kwargs): def __call__(self, sign_name: str, data: dict, use_gpu: bool = False, batch_size: int = 1, **kwargs):
''' '''Call the specified signature function for prediction.'''
'''
def _get_reader_and_feeder(data_format, data, place): def _get_reader_and_feeder(data_format, data, place):
def _reader(process_data): def _reader(process_data):
...@@ -178,10 +182,12 @@ class ModuleV1(object): ...@@ -178,10 +182,12 @@ class ModuleV1(object):
@classmethod @classmethod
def get_py_requirements(cls) -> List[str]: def get_py_requirements(cls) -> List[str]:
'''Get Module's python package dependency list.'''
return [] return []
@classmethod @classmethod
def load(cls, directory: str) -> EasyDict: def load(cls, directory: str) -> EasyDict:
'''Load the Module object defined in the specified directory.'''
module_info = cls.load_module_info(directory) module_info = cls.load_module_info(directory)
# Generate a uuid based on the class information, and dynamically create a new type. # Generate a uuid based on the class information, and dynamically create a new type.
...@@ -202,6 +208,7 @@ class ModuleV1(object): ...@@ -202,6 +208,7 @@ class ModuleV1(object):
@classmethod @classmethod
def load_module_info(cls, directory: str) -> EasyDict: def load_module_info(cls, directory: str) -> EasyDict:
'''Load the information of Module object defined in the specified directory.'''
desc_file = os.path.join(directory, 'module_desc.pb') desc_file = os.path.join(directory, 'module_desc.pb')
desc = module_v1_utils.convert_module_desc(desc_file) desc = module_v1_utils.convert_module_desc(desc_file)
return desc.module_info return desc.module_info
...@@ -211,4 +218,8 @@ class ModuleV1(object): ...@@ -211,4 +218,8 @@ class ModuleV1(object):
@property @property
def is_runnable(self): def is_runnable(self):
'''
Whether the Module is runnable, in other words, whether we can execute the Module through the
`hub run` command.
'''
return self.default_signature != None return self.default_signature != None
...@@ -15,7 +15,6 @@ ...@@ -15,7 +15,6 @@
import os import os
import json import json
from typing import Any
import yaml import yaml
from easydict import EasyDict from easydict import EasyDict
...@@ -25,6 +24,9 @@ import paddlehub.env as hubenv ...@@ -25,6 +24,9 @@ import paddlehub.env as hubenv
class HubConfig: class HubConfig:
''' '''
PaddleHub configuration management class. Each time the PaddleHub package is loaded, PaddleHub will set the
corresponding functions according to the configuration obtained in HubConfig, such as the log level of printing,
server address and so on. When the configuration is modified, PaddleHub needs to be reloaded to take effect.
''' '''
def __init__(self): def __init__(self):
...@@ -38,6 +40,7 @@ class HubConfig: ...@@ -38,6 +40,7 @@ class HubConfig:
... ...
def _initialize(self): def _initialize(self):
# Set default configuration values.
self.data = EasyDict() self.data = EasyDict()
self.data.server = 'http://paddlepaddle.org.cn/paddlehub' self.data.server = 'http://paddlepaddle.org.cn/paddlehub'
self.data.log = EasyDict() self.data.log = EasyDict()
...@@ -45,24 +48,30 @@ class HubConfig: ...@@ -45,24 +48,30 @@ class HubConfig:
self.data.log.level = 'DEBUG' self.data.log.level = 'DEBUG'
def reset(self): def reset(self):
'''Reset configuration to default.'''
self._initialize() self._initialize()
self.flush() self.flush()
@property @property
def log_level(self): def log_level(self):
'''
The lowest output level of PaddleHub logger. Logs below the specified level will not be displayed. The default
is Debug.
'''
return self.data.log.level return self.data.log.level
@log_level.setter @log_level.setter
def log_level(self, level: str): def log_level(self, level: str):
from paddlehub.utils import log from paddlehub.utils import log
if not level in log.log_config.keys(): if not level in log.log_config.keys():
raise ValueError('Unknown log level {}.'.format(level)) raise ValueError('Unknown log level {}. The valid values are {}'.format(level, list(log.log_config.keys())))
self.data.log.level = level self.data.log.level = level
self.flush() self.flush()
@property @property
def log_enable(self): def log_enable(self):
'''Whether to enable the PaddleHub logger to take effect. The default is True.'''
return self.data.log.enable return self.data.log.enable
@log_enable.setter @log_enable.setter
...@@ -72,6 +81,7 @@ class HubConfig: ...@@ -72,6 +81,7 @@ class HubConfig:
@property @property
def server(self): def server(self):
'''PaddleHub Module server url.'''
return self.data.server return self.data.server
@server.setter @server.setter
...@@ -80,6 +90,7 @@ class HubConfig: ...@@ -80,6 +90,7 @@ class HubConfig:
self.flush() self.flush()
def flush(self): def flush(self):
'''Flush the current configuration into the configuration file.'''
with open(self.file, 'w') as file: with open(self.file, 'w') as file:
# convert EasyDict to dict # convert EasyDict to dict
cfg = json.loads(json.dumps(self.data)) cfg = json.loads(json.dumps(self.data))
......
...@@ -12,9 +12,25 @@ ...@@ -12,9 +12,25 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
'''
This module is used to store environmental variables in PaddleHub.
HUB_HOME --> the root directory for storing PaddleHub related data. Default to ~/.paddlehub. Users can change the
├ default value through the HUB_HOME environment variable.
├── MODULE_HOME --> Store the installed PaddleHub Module.
├── CACHE_HOME --> Store the cached data.
├── DATA_HOME --> Store the automatically downloaded datasets.
├── CONF_HOME --> Store the default configuration files.
├── THIRD_PARTY_HOME --> Store third-party libraries.
├── TMP_HOME --> Store temporary files generated during running, such as intermediate products of installing modules,
├ files in this directory will generally be automatically cleared.
├── SOURCES_HOME --> Store the installed code sources.
└── LOG_HOME --> Store log files generated during operation, including some non-fatal errors. The log will be stored
daily.
'''
import os import os
import shutil
def _get_user_home(): def _get_user_home():
......
# coding:utf-8
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
class Compose(object):
    '''
    Chain several learning-rate schedulers into a single callable.

    Args:
        schedulers: Iterable of callables. Each one is invoked as `scheduler(global_lr, parameters)` and must
                    return a `(global_lr, parameters)` pair, which is fed to the next scheduler.
    '''

    def __init__(self, schedulers):
        self.schedulers = schedulers

    def __call__(self, global_lr, parameters):
        '''Run `(global_lr, parameters)` through every scheduler in order and return the final pair.'''
        state = (global_lr, parameters)
        for step in self.schedulers:
            # Unpack into exactly two names so that a scheduler returning something else fails loudly here.
            lr, params = step(*state)
            state = (lr, params)
        return state
class WarmUpLR(object):
    '''
    Linear learning-rate warm-up.

    While the internal step counter lies in `[start_step, end_step]`, the incoming learning rate is multiplied by
    `(step - start_step) / (end_step - start_step)`, a factor that grows linearly from 0 to 1. Outside of that
    window the learning rate is returned unchanged. The counter advances by one on every call.

    Args:
        start_step(int): First step (inclusive) at which the scaling is applied.
        end_step(int): Last step (inclusive) at which the scaling is applied.
        wtype: Accepted for interface compatibility; it is not used by this implementation.
    '''

    def __init__(self, start_step: int, end_step: int, wtype=None):
        self.start_step = start_step
        self.end_step = end_step
        self._curr_step = 0

    def __call__(self, global_lr, parameters):
        '''Return `(scaled_lr, parameters)` for the current step and advance the step counter.'''
        span = self.end_step - self.start_step
        # A zero-length window (start_step == end_step) used to raise ZeroDivisionError; it is now treated as
        # "no warm-up" and the learning rate is passed through unchanged.
        if span != 0 and self.start_step <= self._curr_step <= self.end_step:
            global_lr *= float(self._curr_step - self.start_step) / span
        self._curr_step += 1
        return global_lr, parameters
class DecayLR(object):
    '''
    Linear learning-rate decay.

    While the internal step counter lies in `[start_step, end_step]`, the incoming learning rate is multiplied by
    `(end_step - step) / (end_step - start_step)`, a factor that shrinks linearly from 1 to 0. Outside of that
    window the learning rate is returned unchanged. The counter advances by one on every call.

    Args:
        start_step(int): First step (inclusive) at which the scaling is applied.
        end_step(int): Last step (inclusive) at which the scaling is applied.
        wtype: Accepted for interface compatibility; it is not used by this implementation.
    '''

    def __init__(self, start_step: int, end_step: int, wtype=None):
        self.start_step = start_step
        self.end_step = end_step
        self._curr_step = 0

    def __call__(self, global_lr, parameters):
        '''Return `(scaled_lr, parameters)` for the current step and advance the step counter.'''
        span = self.end_step - self.start_step
        # A zero-length window (start_step == end_step) used to raise ZeroDivisionError; it is now treated as
        # "no decay" and the learning rate is passed through unchanged.
        if span != 0 and self.start_step <= self._curr_step <= self.end_step:
            global_lr *= float(self.end_step - self._curr_step) / span
        self._curr_step += 1
        return global_lr, parameters
class SlantedTriangleLR(object):
    '''
    Learning-rate schedule made of a linear warm-up followed by a linear decay.

    Args:
        global_step(int): Total number of steps covered by the schedule; the decay ends at `global_step - 1`.
        warmup_prop(float): Proportion of `global_step` spent in warm-up. The switch-over step is
                            `int(global_step * warmup_prop)`.
    '''

    def __init__(self, global_step: int, warmup_prop: float):
        self.global_step = global_step
        self.warmup_prop = warmup_prop

        # Step at which the warm-up ends and the decay begins.
        peak_step = int(global_step * warmup_prop)
        rising = WarmUpLR(start_step=0, end_step=peak_step)
        falling = DecayLR(start_step=peak_step, end_step=global_step - 1)
        self.scheduler = Compose([rising, falling])

    def __call__(self, global_lr, parameters):
        '''Delegate to the composed warm-up + decay scheduler.'''
        pipeline = self.scheduler
        return pipeline(global_lr, parameters)
class GradualUnfreeze(object):
    '''Placeholder class: it defines no attributes or behavior of its own yet.'''
class LayeredLR(object):
    '''Placeholder class: it defines no attributes or behavior of its own yet.'''
class L2SP(object):
    '''
    Holder for a regularization coefficient.

    NOTE: the regularization itself is not implemented in this class yet; calling an instance leaves the
    learning rate and the parameters untouched.

    Args:
        regularization_coeff(float): Regularization strength. Default to 1e-3.
    '''

    def __init__(self, regularization_coeff=1e-3):
        self.regularization_coeff = regularization_coeff

    def __call__(self, global_lr, parameters):
        '''Return `(global_lr, parameters)` unchanged.'''
        # The previous bare `pass` returned None, which broke the tuple unpacking performed by `Compose`.
        # Returning the inputs as a pair keeps this class chainable like the other schedulers in this file.
        return global_lr, parameters
...@@ -22,7 +22,7 @@ from typing import Any, Callable, Generic, List ...@@ -22,7 +22,7 @@ from typing import Any, Callable, Generic, List
import paddle import paddle
from visualdl import LogWriter from visualdl import LogWriter
from paddlehub.utils.log import logger, processing from paddlehub.utils.log import logger
from paddlehub.utils.utils import Timer from paddlehub.utils.utils import Timer
...@@ -163,7 +163,11 @@ class Trainer(object): ...@@ -163,7 +163,11 @@ class Trainer(object):
batch_sampler = paddle.io.DistributedBatchSampler( batch_sampler = paddle.io.DistributedBatchSampler(
train_dataset, batch_size=batch_size, shuffle=True, drop_last=False) train_dataset, batch_size=batch_size, shuffle=True, drop_last=False)
loader = paddle.io.DataLoader( loader = paddle.io.DataLoader(
train_dataset, batch_sampler=batch_sampler, num_workers=num_workers, return_list=True) train_dataset,
batch_sampler=batch_sampler,
num_workers=num_workers,
return_list=True,
use_buffer_reader=True)
steps_per_epoch = len(batch_sampler) steps_per_epoch = len(batch_sampler)
timer = Timer(steps_per_epoch * epochs) timer = Timer(steps_per_epoch * epochs)
...@@ -258,7 +262,7 @@ class Trainer(object): ...@@ -258,7 +262,7 @@ class Trainer(object):
sum_metrics = defaultdict(int) sum_metrics = defaultdict(int)
avg_metrics = defaultdict(int) avg_metrics = defaultdict(int)
with processing('Evaluation on validation dataset'): with logger.processing('Evaluation on validation dataset'):
for batch_idx, batch in enumerate(loader): for batch_idx, batch in enumerate(loader):
result = self.validation_step(batch, batch_idx) result = self.validation_step(batch, batch_idx)
loss = result.get('loss', None) loss = result.get('loss', None)
......
# coding:utf-8
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
def check_platform(module):
    '''Placeholder, not implemented yet: accepts `module`, does nothing and returns None.'''
    return None
def check_py_requirements(module):
    '''Placeholder, not implemented yet: accepts `module`, does nothing and returns None.'''
    return None
def check_py_requirement(module, package):
    '''Placeholder, not implemented yet: accepts `module` and `package`, does nothing and returns None.'''
    return None
def get_py_requirements(module):
    '''Placeholder, not implemented yet: accepts `module`, does nothing and returns None.'''
    return None
def check_hub_requirements(module):
    '''Placeholder, not implemented yet: accepts `module`, does nothing and returns None.'''
    return None
def check_hub_requirement(module, r_module):
    '''Placeholder, not implemented yet: accepts `module` and `r_module`, does nothing and returns None.'''
    return None
def get_hub_requirements(module):
    '''Placeholder, not implemented yet: accepts `module`, does nothing and returns None.'''
    return None
...@@ -217,7 +217,7 @@ class LocalModuleManager(object): ...@@ -217,7 +217,7 @@ class LocalModuleManager(object):
if os.path.exists(module_dir): if os.path.exists(module_dir):
try: try:
module = self._local_modules[name] = HubModule.load(module_dir) module = self._local_modules[name] = HubModule.load(module_dir)
except Exception as e: except:
utils.record_exception('An error was encountered while loading {}.'.format(name)) utils.record_exception('An error was encountered while loading {}.'.format(name))
if not module: if not module:
...@@ -237,7 +237,7 @@ class LocalModuleManager(object): ...@@ -237,7 +237,7 @@ class LocalModuleManager(object):
fulldir = os.path.join(self.home, subdir) fulldir = os.path.join(self.home, subdir)
try: try:
self._local_modules[subdir] = HubModule.load(fulldir) self._local_modules[subdir] = HubModule.load(fulldir)
except Exception as e: except:
utils.record_exception('An error was encountered while loading {}.'.format(subdir)) utils.record_exception('An error was encountered while loading {}.'.format(subdir))
return [module for module in self._local_modules.values()] return [module for module in self._local_modules.values()]
...@@ -258,7 +258,7 @@ class LocalModuleManager(object): ...@@ -258,7 +258,7 @@ class LocalModuleManager(object):
if name.lower() == item['name'].lower() and utils.Version(item['version']).match(version): if name.lower() == item['name'].lower() and utils.Version(item['version']).match(version):
return self._install_from_url(item['url']) return self._install_from_url(item['url'])
module_infos = module_server.get_module_info(name=name) module_infos = module_server.get_module_compat_info(name=name)
# The HubModule with the specified name cannot be found # The HubModule with the specified name cannot be found
if not module_infos: if not module_infos:
raise HubModuleNotFoundError(name=name, version=version) raise HubModuleNotFoundError(name=name, version=version)
......
...@@ -16,11 +16,10 @@ ...@@ -16,11 +16,10 @@
import ast import ast
import builtins import builtins
import inspect import inspect
import importlib
import os import os
import re import re
import sys import sys
from typing import Callable, Generic, List, Optional from typing import Callable, Generic, List, Optional, Union
from easydict import EasyDict from easydict import EasyDict
...@@ -43,6 +42,7 @@ _module_runnable_func = {} ...@@ -43,6 +42,7 @@ _module_runnable_func = {}
def runnable(func: Callable) -> Callable: def runnable(func: Callable) -> Callable:
'''Mark a Module method as runnable, when the command `hub run` is used, the method will be called.'''
mod = func.__module__ + '.' + inspect.stack()[1][3] mod = func.__module__ + '.' + inspect.stack()[1][3]
_module_runnable_func[mod] = func.__name__ _module_runnable_func[mod] = func.__name__
...@@ -53,6 +53,7 @@ def runnable(func: Callable) -> Callable: ...@@ -53,6 +53,7 @@ def runnable(func: Callable) -> Callable:
def serving(func: Callable) -> Callable: def serving(func: Callable) -> Callable:
'''Mark a Module method as serving method, when the command `hub serving` is used, the method will be called.'''
mod = func.__module__ + '.' + inspect.stack()[1][3] mod = func.__module__ + '.' + inspect.stack()[1][3]
_module_serving_func[mod] = func.__name__ _module_serving_func[mod] = func.__name__
...@@ -62,8 +63,89 @@ def serving(func: Callable) -> Callable: ...@@ -62,8 +63,89 @@ def serving(func: Callable) -> Callable:
return _wrapper return _wrapper
class RunModule(object):
    '''The base class of PaddleHub Module; users can inherit from this class to implement a custom Module.'''

    def __init__(self, *args, **kwargs):
        # Avoid module being initialized multiple times
        if '_is_initialize' in self.__dict__ and self._is_initialize:
            return

        super(RunModule, self).__init__()
        _run_func_name = self._get_func_name(self.__class__, _module_runnable_func)
        self._run_func = getattr(self, _run_func_name) if _run_func_name else None
        self._serving_func_name = self._get_func_name(self.__class__, _module_serving_func)
        self._is_initialize = True

    def _get_func_name(self, current_cls: Generic, module_func_dict: dict) -> Optional[str]:
        '''
        Look up the function name registered for `current_cls` or, failing that, for one of its ancestors.

        Args:
            current_cls: Class to look up; the key used is `<module>.<class name>`.
            module_func_dict(dict): Mapping from `<module>.<class name>` to a method name.

        Returns:
            The registered method name, or None if neither the class nor any of its ancestors is registered.
        '''
        mod = current_cls.__module__ + '.' + current_cls.__name__
        if mod in module_func_dict:
            return module_func_dict[mod]

        # Search every base class in declaration order. The previous code returned from inside the loop on its
        # first iteration, so with multiple inheritance only the first base was ever examined.
        for base_class in current_cls.__bases__:
            _func_name = self._get_func_name(base_class, module_func_dict)
            if _func_name is not None:
                return _func_name
        return None

    # After the 2.0.0rc version, paddle uses the dynamic graph mode by default, which will cause the
    # execution of the static graph model to fail, so compatibility protection is required.
    def __getattribute__(self, attr):
        _attr = object.__getattribute__(self, attr)

        # If the acquired attribute is a built-in property of the object, skip it.
        if re.match('__.*__', attr):
            return _attr
        # If the module is a dynamic graph model, skip it.
        elif isinstance(self, paddle.nn.Layer):
            return _attr
        # If the acquired attribute is not a class method, skip it.
        elif not inspect.ismethod(_attr):
            return _attr

        return paddle_utils.run_in_static_mode(_attr)

    @classmethod
    def get_py_requirements(cls) -> List[str]:
        '''
        Get Module's python package dependency list.

        The list is read from the `requirements.txt` file located next to the python file that defines the
        class. Returns an empty list when that file does not exist.
        '''
        py_module = sys.modules[cls.__module__]
        directory = os.path.dirname(py_module.__file__)
        req_file = os.path.join(directory, 'requirements.txt')
        if not os.path.exists(req_file):
            return []

        with open(req_file, 'r') as file:
            # Return one entry per non-blank line so the result matches the annotated List[str]; the previous
            # code returned the raw file content as a single string.
            return [line.strip() for line in file if line.strip()]

    @property
    def is_runnable(self) -> bool:
        '''
        Whether the Module is runnable, in other words, whether we can execute the Module through the
        `hub run` command.
        '''
        return self._run_func is not None
class Module(object): class Module(object):
''' '''
In PaddleHub, Module represents an executable module, which is usually a pre-trained model that can be used for end-to-end
prediction, such as a face detection model or a lexical analysis model, or a pre-trained model that requires finetuning,
such as BERT/ERNIE. When loading a Module with a specified name, if the Module does not exist locally, PaddleHub will
automatically request the server or the specified Git source to download the resource.
Args:
name(str): Module name.
directory(str|optional): Directory of the module to be loaded, only takes effect when the `name` is not specified.
version(str|optional): The version limit of the module, only takes effect when the `name` is specified. When the local
Module does not meet the specified version conditions, PaddleHub will re-request the server to
download the appropriate Module. Default to None, which means that the local Module will be used.
If the Module does not exist, PaddleHub will download the latest version available from the
server according to the usage environment.
source(str|optional): Url of a git repository. If this parameter is specified, PaddleHub will no longer download the
specified Module from the default server, but will look for it in the specified repository.
Default to None.
update(bool|optional): Whether to update the locally cached git repository, only takes effect when the `source`
is specified. Default to False.
branch(str|optional): The branch of the specified git repository. Default to None.
''' '''
def __new__(cls, def __new__(cls,
...@@ -72,11 +154,13 @@ class Module(object): ...@@ -72,11 +154,13 @@ class Module(object):
version: str = None, version: str = None,
source: str = None, source: str = None,
update: bool = False, update: bool = False,
branch: str = None,
**kwargs): **kwargs):
if cls.__name__ == 'Module': if cls.__name__ == 'Module':
# This branch come from hub.Module(name='xxx') or hub.Module(directory='xxx') # This branch come from hub.Module(name='xxx') or hub.Module(directory='xxx')
if name: if name:
module = cls.init_with_name(name=name, version=version, source=source, update=update, **kwargs) module = cls.init_with_name(
name=name, version=version, source=source, update=update, branch=branch, **kwargs)
elif directory: elif directory:
module = cls.init_with_directory(directory=directory, **kwargs) module = cls.init_with_directory(directory=directory, **kwargs)
else: else:
...@@ -86,8 +170,7 @@ class Module(object): ...@@ -86,8 +170,7 @@ class Module(object):
@classmethod @classmethod
def load(cls, directory: str) -> Generic: def load(cls, directory: str) -> Generic:
''' '''Load the Module object defined in the specified directory.'''
'''
if directory.endswith(os.sep): if directory.endswith(os.sep):
directory = directory[:-1] directory = directory[:-1]
...@@ -122,6 +205,7 @@ class Module(object): ...@@ -122,6 +205,7 @@ class Module(object):
@classmethod @classmethod
def load_module_info(cls, directory: str) -> EasyDict: def load_module_info(cls, directory: str) -> EasyDict:
'''Load the information of Module object defined in the specified directory.'''
# If is ModuleV1 # If is ModuleV1
desc_file = os.path.join(directory, 'module_desc.pb') desc_file = os.path.join(directory, 'module_desc.pb')
if os.path.exists(desc_file): if os.path.exists(desc_file):
...@@ -153,9 +237,8 @@ class Module(object): ...@@ -153,9 +237,8 @@ class Module(object):
source: str = None, source: str = None,
update: bool = False, update: bool = False,
branch: str = None, branch: str = None,
**kwargs): **kwargs) -> Union[RunModule, ModuleV1]:
''' '''Initialize Module according to the specified name.'''
'''
from paddlehub.module.manager import LocalModuleManager from paddlehub.module.manager import LocalModuleManager
manager = LocalModuleManager() manager = LocalModuleManager()
user_module_cls = manager.search(name, source=source, branch=branch) user_module_cls = manager.search(name, source=source, branch=branch)
...@@ -181,9 +264,8 @@ class Module(object): ...@@ -181,9 +264,8 @@ class Module(object):
return user_module_cls(**kwargs) return user_module_cls(**kwargs)
@classmethod @classmethod
def init_with_directory(cls, directory: str, **kwargs): def init_with_directory(cls, directory: str, **kwargs) -> Union[RunModule, ModuleV1]:
''' '''Initialize Module according to the specified directory.'''
'''
user_module_cls = cls.load(directory) user_module_cls = cls.load(directory)
# The HubModule in the old version will use the _initialize method to initialize, # The HubModule in the old version will use the _initialize method to initialize,
...@@ -202,77 +284,6 @@ class Module(object): ...@@ -202,77 +284,6 @@ class Module(object):
user_module_cls.directory = directory user_module_cls.directory = directory
return user_module_cls(**kwargs) return user_module_cls(**kwargs)
@classmethod
def get_py_requirements(cls):
    '''
    Return the content of `requirements.txt` found in `cls.directory`, split on newlines, or an empty list
    when the file does not exist.
    '''
    requirements_path = os.path.join(cls.directory, 'requirements.txt')
    if os.path.exists(requirements_path):
        with open(requirements_path, 'r') as handle:
            content = handle.read()
        return content.split('\n')
    return []
class RunModule(object):
    '''The base class of PaddleHub Module; users can inherit from this class to implement a custom Module.'''

    def __init__(self, *args, **kwargs):
        # Avoid module being initialized multiple times
        if '_is_initialize' in self.__dict__ and self._is_initialize:
            return

        super(RunModule, self).__init__()
        _run_func_name = self._get_func_name(self.__class__, _module_runnable_func)
        self._run_func = getattr(self, _run_func_name) if _run_func_name else None
        self._serving_func_name = self._get_func_name(self.__class__, _module_serving_func)
        self._is_initialize = True

    def _get_func_name(self, current_cls: Generic, module_func_dict: dict) -> Optional[str]:
        '''
        Look up the function name registered for `current_cls` or, failing that, for one of its ancestors.

        Args:
            current_cls: Class to look up; the key used is `<module>.<class name>`.
            module_func_dict(dict): Mapping from `<module>.<class name>` to a method name.

        Returns:
            The registered method name, or None if neither the class nor any of its ancestors is registered.
        '''
        mod = current_cls.__module__ + '.' + current_cls.__name__
        if mod in module_func_dict:
            return module_func_dict[mod]

        # Search every base class in declaration order. The previous code returned from inside the loop on its
        # first iteration, so with multiple inheritance only the first base was ever examined.
        for base_class in current_cls.__bases__:
            _func_name = self._get_func_name(base_class, module_func_dict)
            if _func_name is not None:
                return _func_name
        return None

    # After the 2.0.0rc version, paddle uses the dynamic graph mode by default, which will cause the
    # execution of the static graph model to fail, so compatibility protection is required.
    def __getattribute__(self, attr):
        _attr = object.__getattribute__(self, attr)

        # If the acquired attribute is a built-in property of the object, skip it.
        if re.match('__.*__', attr):
            return _attr
        # If the module is a dygraph model, skip it.
        elif isinstance(self, paddle.nn.Layer):
            return _attr
        # If the acquired attribute is not a class method, skip it.
        elif not inspect.ismethod(_attr):
            return _attr

        return paddle_utils.run_in_static_mode(_attr)

    @classmethod
    def get_py_requirements(cls) -> List[str]:
        '''
        Get Module's python package dependency list.

        The list is read from the `requirements.txt` file located next to the python file that defines the
        class. Returns an empty list when that file does not exist.
        '''
        py_module = sys.modules[cls.__module__]
        directory = os.path.dirname(py_module.__file__)
        req_file = os.path.join(directory, 'requirements.txt')
        if not os.path.exists(req_file):
            return []

        with open(req_file, 'r') as file:
            # Return one entry per non-blank line so the result matches the annotated List[str]; the previous
            # code returned the raw file content as a single string.
            return [line.strip() for line in file if line.strip()]

    @property
    def is_runnable(self) -> bool:
        '''Whether the Module has a method registered to be executed through the `hub run` command.'''
        return self._run_func is not None
def moduleinfo(name: str, def moduleinfo(name: str,
version: str, version: str,
...@@ -282,6 +293,8 @@ def moduleinfo(name: str, ...@@ -282,6 +293,8 @@ def moduleinfo(name: str,
type: str = None, type: str = None,
meta=None) -> Callable: meta=None) -> Callable:
''' '''
Mark Module information for a python class, and the class will automatically be extended to inherit HubModule. In other words, python classes
marked with moduleinfo can be loaded through hub.Module.
''' '''
def _wrapper(cls: Generic) -> Generic: def _wrapper(cls: Generic) -> Generic:
......
...@@ -55,6 +55,7 @@ class GitSource(object): ...@@ -55,6 +55,7 @@ class GitSource(object):
self.load_hub_modules() self.load_hub_modules()
def checkout(self, branch: str): def checkout(self, branch: str):
'''Checkout the current repo to the specified branch.'''
try: try:
self.repo.git.checkout(branch) self.repo.git.checkout(branch)
# reload modules # reload modules
...@@ -63,11 +64,12 @@ class GitSource(object): ...@@ -63,11 +64,12 @@ class GitSource(object):
utils.record_exception('An error occurred while checkout {}.'.format(self.path)) utils.record_exception('An error occurred while checkout {}.'.format(self.path))
def update(self): def update(self):
'''Update the current repo.'''
try: try:
self.repo.remote().pull(self.repo.branches[0]) self.repo.remote().pull(self.repo.branches[0])
# reload modules # reload modules
self.load_hub_modules() self.load_hub_modules()
except Exception as e: except:
self.hub_modules = OrderedDict() self.hub_modules = OrderedDict()
utils.record_exception('An error occurred while update {}.'.format(self.path)) utils.record_exception('An error occurred while update {}.'.format(self.path))
...@@ -82,7 +84,7 @@ class GitSource(object): ...@@ -82,7 +84,7 @@ class GitSource(object):
_item = py_module.__dict__[_item] _item = py_module.__dict__[_item]
if issubclass(_item, RunModule): if issubclass(_item, RunModule):
self.hub_modules[_item.name] = _item self.hub_modules[_item.name] = _item
except Exception as e: except:
self.hub_modules = OrderedDict() self.hub_modules = OrderedDict()
utils.record_exception('An error occurred while loading {}.'.format(self.path)) utils.record_exception('An error occurred while loading {}.'.format(self.path))
...@@ -118,6 +120,10 @@ class GitSource(object): ...@@ -118,6 +120,10 @@ class GitSource(object):
}] }]
return None return None
def get_module_compat_info(self, name: str) -> dict:
    '''Get the version compatibility information of the model. This implementation always returns an empty dict.'''
    compat_info = dict()
    return compat_info
@classmethod @classmethod
def check(cls, url: str) -> bool: def check(cls, url: str) -> bool:
''' '''
......
...@@ -51,14 +51,14 @@ class HubServer(object): ...@@ -51,14 +51,14 @@ class HubServer(object):
self.sources.pop(key) self.sources.pop(key)
def get_source(self, url: str): def get_source(self, url: str):
'''''' '''Get a module source by url'''
key = self.keysmap.get(url) key = self.keysmap.get(url)
if not key: if not key:
return None return None
return self.sources.get(key) return self.sources.get(key)
def get_source_by_key(self, key: str): def get_source_by_key(self, key: str):
'''''' '''Get a module source by key'''
return self.sources.get(key) return self.sources.get(key)
def search_module(self, def search_module(self,
...@@ -105,9 +105,8 @@ class HubServer(object): ...@@ -105,9 +105,8 @@ class HubServer(object):
return result return result
return [] return []
def get_module_info(self, name: str, source: str = None) -> dict: def get_module_compat_info(self, name: str, source: str = None) -> dict:
''' '''Get the version compatibility information of the model.'''
'''
sources = self.sources.values() if not source else [self._generate_source(source)] sources = self.sources.values() if not source else [self._generate_source(source)]
for source in sources: for source in sources:
result = source.get_module_info(name=name) result = source.get_module_info(name=name)
......
...@@ -18,7 +18,7 @@ import requests ...@@ -18,7 +18,7 @@ import requests
from typing import List from typing import List
import paddlehub import paddlehub
from paddlehub.utils import utils, platform from paddlehub.utils import platform
class ServerConnectionError(Exception): class ServerConnectionError(Exception):
...@@ -79,9 +79,8 @@ class ServerSource(object): ...@@ -79,9 +79,8 @@ class ServerSource(object):
return result['data'] return result['data']
return None return None
def get_module_info(self, name: str) -> dict: def get_module_compat_info(self, name: str) -> dict:
''' '''Get the version compatibility information of the model.'''
'''
def _convert_version(version: str) -> List: def _convert_version(version: str) -> List:
result = [] result = []
...@@ -112,8 +111,7 @@ class ServerSource(object): ...@@ -112,8 +111,7 @@ class ServerSource(object):
return {} return {}
def request(self, path: str, params: dict) -> dict: def request(self, path: str, params: dict) -> dict:
''' '''Request server.'''
'''
api = '{}/{}'.format(self._url, path) api = '{}/{}'.format(self._url, path)
try: try:
result = requests.get(api, params, timeout=self._timeout) result = requests.get(api, params, timeout=self._timeout)
......
# coding:utf-8
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os import os
from paddlehub.env import DATA_HOME from paddlehub.env import DATA_HOME
......
...@@ -113,6 +113,32 @@ class Logger(object): ...@@ -113,6 +113,32 @@ class Logger(object):
yield yield
self.handler.terminator = old_terminator self.handler.terminator = old_terminator
@contextlib.contextmanager
def processing(self, msg: str, interval: float = 0.1):
    '''
    Continuously print a progress indicator with a rotating character while the `with` body runs.

    A background thread repeatedly logs `msg` followed by one of `\\`, `|`, `/`, `-`, using a carriage
    return as the line terminator so that each frame overwrites the previous one. The thread is always
    stopped and joined when the `with` block exits, even if the body raises an exception.

    Args:
        msg(str): Message to be printed.
        interval(float): Rotation interval in seconds. Default to 0.1.
    '''
    # An Event (instead of a plain flag) lets the printer wake up immediately when asked to stop,
    # rather than sleeping out the remainder of the interval.
    stop = threading.Event()

    def _printer():
        index = 0
        flags = ['\\', '|', '/', '-']
        while not stop.is_set():
            flag = flags[index % len(flags)]
            with self.use_terminator('\r'):
                self.info('{}: {}'.format(msg, flag))
            stop.wait(interval)
            index += 1

    t = threading.Thread(target=_printer)
    t.start()
    try:
        yield
    finally:
        # Without the finally clause an exception in the `with` body would leave the non-daemon
        # printer thread spinning forever and keep the process from exiting.
        stop.set()
        # Join so that no spinner frame is emitted after the context has been left.
        t.join()
class ProgressBar(object): class ProgressBar(object):
''' '''
...@@ -172,28 +198,6 @@ class ProgressBar(object): ...@@ -172,28 +198,6 @@ class ProgressBar(object):
sys.stdout.write('\n') sys.stdout.write('\n')
@contextlib.contextmanager
def processing(msg: str, interval: float = 0.1):
    '''
    Continuously print a progress indicator with a rotating character while the `with` body runs.

    A background thread repeatedly logs `msg` through the module-level `logger`, followed by one of
    `\\`, `|`, `/`, `-`, using a carriage return as the line terminator so that each frame overwrites
    the previous one. The thread is always stopped and joined when the `with` block exits, even if
    the body raises an exception.

    Args:
        msg(str): Message to be printed.
        interval(float): Rotation interval in seconds. Default to 0.1.
    '''
    # An Event (instead of a plain flag) lets the printer wake up immediately when asked to stop,
    # rather than sleeping out the remainder of the interval.
    stop = threading.Event()

    def _printer():
        index = 0
        flags = ['\\', '|', '/', '-']
        while not stop.is_set():
            flag = flags[index % len(flags)]
            with logger.use_terminator('\r'):
                logger.info('{}: {}'.format(msg, flag))
            stop.wait(interval)
            index += 1

    t = threading.Thread(target=_printer)
    t.start()
    try:
        yield
    finally:
        # Without the finally clause an exception in the `with` body would leave the non-daemon
        # printer thread spinning forever and keep the process from exiting.
        stop.set()
        # Join so that no spinner frame is emitted after the context has been left.
        t.join()
class FormattedText(object): class FormattedText(object):
''' '''
Cross-platform formatted string Cross-platform formatted string
......
...@@ -34,8 +34,7 @@ class ResourceNotFoundError(Exception): ...@@ -34,8 +34,7 @@ class ResourceNotFoundError(Exception):
def download(name: str, save_path: str, version: str = None): def download(name: str, save_path: str, version: str = None):
''' '''The download interface provided to PaddleX for downloading the specified model and resource files.'''
'''
file = os.path.join(save_path, name) file = os.path.join(save_path, name)
file = os.path.realpath(file) file = os.path.realpath(file)
if os.path.exists(file): if os.path.exists(file):
......
...@@ -14,7 +14,6 @@ ...@@ -14,7 +14,6 @@
# limitations under the License. # limitations under the License.
import codecs import codecs
import sys
from typing import List from typing import List
import yaml import yaml
......
...@@ -13,12 +13,8 @@ ...@@ -13,12 +13,8 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import optparse
import os
import pip import pip
import sys
from pip._internal.utils.misc import get_installed_distributions from pip._internal.utils.misc import get_installed_distributions
from typing import List, Tuple
from paddlehub.utils.utils import Version from paddlehub.utils.utils import Version
from paddlehub.utils.io import discard_oe, typein from paddlehub.utils.io import discard_oe, typein
......
...@@ -246,16 +246,14 @@ def load_py_module(python_path: str, py_module_name: str) -> types.ModuleType: ...@@ -246,16 +246,14 @@ def load_py_module(python_path: str, py_module_name: str) -> types.ModuleType:
def get_platform_default_encoding() -> str: def get_platform_default_encoding() -> str:
''' '''Get the default encoding of the current platform.'''
'''
if utils.platform.is_windows(): if utils.platform.is_windows():
return 'gbk' return 'gbk'
return 'utf8' return 'utf8'
def sys_stdin_encoding() -> str: def sys_stdin_encoding() -> str:
''' '''Get the standard input stream default encoding.'''
'''
encoding = sys.stdin.encoding encoding = sys.stdin.encoding
if encoding is None: if encoding is None:
encoding = sys.getdefaultencoding() encoding = sys.getdefaultencoding()
...@@ -266,8 +264,7 @@ def sys_stdin_encoding() -> str: ...@@ -266,8 +264,7 @@ def sys_stdin_encoding() -> str:
def sys_stdout_encoding() -> str: def sys_stdout_encoding() -> str:
''' '''Get the standard output stream default encoding.'''
'''
encoding = sys.stdout.encoding encoding = sys.stdout.encoding
if encoding is None: if encoding is None:
encoding = sys.getdefaultencoding() encoding = sys.getdefaultencoding()
...@@ -278,15 +275,13 @@ def sys_stdout_encoding() -> str: ...@@ -278,15 +275,13 @@ def sys_stdout_encoding() -> str:
def md5(text: str): def md5(text: str):
''' '''Calculate the md5 value of the input text.'''
'''
md5code = hashlib.md5(text.encode()) md5code = hashlib.md5(text.encode())
return md5code.hexdigest() return md5code.hexdigest()
def record(msg: str) -> str: def record(msg: str) -> str:
''' '''Record the specified text into the PaddleHub log file which will be automatically stored according to date.'''
'''
logfile = os.path.join(hubenv.LOG_HOME, time.strftime('%Y%m%d.log')) logfile = os.path.join(hubenv.LOG_HOME, time.strftime('%Y%m%d.log'))
with open(logfile, 'a') as file: with open(logfile, 'a') as file:
file.write('=' * 50 + '\n') file.write('=' * 50 + '\n')
...@@ -298,8 +293,7 @@ def record(msg: str) -> str: ...@@ -298,8 +293,7 @@ def record(msg: str) -> str:
def record_exception(msg: str) -> str: def record_exception(msg: str) -> str:
''' '''Record the current exception information into the PaddleHub log file which will be automatically stored according to date.'''
'''
tb = traceback.format_exc() tb = traceback.format_exc()
file = record(tb) file = record(tb)
utils.log.logger.warning('{}. Detailed error information can be found in the {}.'.format(msg, file)) utils.log.logger.warning('{}. Detailed error information can be found in the {}.'.format(msg, file))
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册