diff --git a/paddlehub/__init__.py b/paddlehub/__init__.py index ec027d275ba7e2fed33b2b5958db1e6b799225dd..b44bcc4ab18d1d0a891365990352a2a812b97a09 100644 --- a/paddlehub/__init__.py +++ b/paddlehub/__init__.py @@ -13,9 +13,21 @@ # See the License for the specific language governing permissions and # limitations under the License. +import sys + __version__ = '2.0.0a0' +from paddlehub.utils import log, parser, utils from paddlehub.module import Module - +# In order to maintain the compatibility of the old version, we put the relevant +# compatible code in the paddlehub/compat package, and mapped some modules referenced +# in the old version +from paddlehub.compat import paddle_utils from paddlehub.compat.module.processor import BaseProcessor +from paddlehub.compat.module.nlp_module import NLPPredictionModule, TransformerModule from paddlehub.compat.type import DataType + +sys.modules['paddlehub.io.parser'] = parser +sys.modules['paddlehub.common.logger'] = log +sys.modules['paddlehub.common.paddle_helper'] = paddle_utils +sys.modules['paddlehub.common.utils'] = utils diff --git a/paddlehub/compat/module/module_desc.proto b/paddlehub/compat/module/module_desc.proto new file mode 100644 index 0000000000000000000000000000000000000000..1108d586f4b9d32be98673a3f812aa914999f407 --- /dev/null +++ b/paddlehub/compat/module/module_desc.proto @@ -0,0 +1,81 @@ +// Copyright 2018 The Paddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+// ============================================================================= + +syntax = "proto3"; +option optimize_for = LITE_RUNTIME; + +package paddlehub.module.desc; + +enum DataType { + NONE = 0; + INT = 1; + FLOAT = 2; + STRING = 3; + BOOLEAN = 4; + LIST = 5; + MAP = 6; + SET = 7; + OBJECT = 8; +} + +message KVData { + map key_type = 1; + map data = 2; +} + +message ModuleAttr { + // Basic type + DataType type = 1; + int64 i = 2; + double f = 3; + bool b = 4; + string s = 5; + KVData map = 6; + KVData list = 7; + KVData set = 8; + KVData object = 9; + // + string name = 10; + string info = 11; + +} + +// Feed Variable Description +message FeedDesc { + string var_name = 1; + string alias = 2; +}; + +// Fetch Variable Description +message FetchDesc { + string var_name = 1; + string alias = 2; +}; + +// Module Variable +message ModuleVar { + repeated FetchDesc fetch_desc = 1; + repeated FeedDesc feed_desc = 2; +} + +// A Hub Module is stored in a directory with a file 'module_desc.pb' +// containing a serialized protocol message of this type. The further contents +// of the directory depend on the storage format described by the message. +message ModuleDesc { + // signature to module variable + map sign2var = 2; + + ModuleAttr attr = 3; +}; diff --git a/paddlehub/compat/module/module_v1.py b/paddlehub/compat/module/module_v1.py index d3f40ac6e4f0846cac10e55c608147a94e602966..926f9d77f978a98c05fb9e4bb2d0a43a0367df02 100644 --- a/paddlehub/compat/module/module_v1.py +++ b/paddlehub/compat/module/module_v1.py @@ -47,6 +47,10 @@ class ModuleV1(object): self._generate_func() def _load_processor(self): + # Some module does not have a processor(e.g. 
import copy
from typing import Callable, List

import paddle

from paddlehub.utils.utils import Version

# Lookup from Paddle VarType enum members to the dtype name stored in variable descriptions.
dtype_map = {
    paddle.device.core.VarDesc.VarType.FP32: "float32",
    paddle.device.core.VarDesc.VarType.FP64: "float64",
    paddle.device.core.VarDesc.VarType.FP16: "float16",
    paddle.device.core.VarDesc.VarType.INT32: "int32",
    paddle.device.core.VarDesc.VarType.INT16: "int16",
    paddle.device.core.VarDesc.VarType.INT64: "int64",
    paddle.device.core.VarDesc.VarType.BOOL: "bool",
    paddle.device.core.VarDesc.VarType.UINT8: "uint8",
    paddle.device.core.VarDesc.VarType.INT8: "int8",
}


def convert_dtype_to_string(dtype: paddle.device.core.VarDesc.VarType) -> str:
    '''
    Translate a Paddle VarType into its dtype name.

    Args:
        dtype: One of the keys of dtype_map.

    Returns:
        The dtype name registered in dtype_map, e.g. "float32".

    Raises:
        TypeError: If dtype is not a key of dtype_map.
    '''
    if dtype in dtype_map:
        return dtype_map[dtype]
    raise TypeError("dtype shoule in %s" % list(dtype_map.keys()))


def get_variable_info(var: paddle.Variable) -> dict:
    '''
    Collect the attributes of a variable into a plain dict.

    The result is what _copy_vars_and_ops_in_blocks passes as keyword arguments to
    Block.create_var / Block.create_parameter.

    Args:
        var: The variable to describe.

    Returns:
        A dict that always holds 'name', 'stop_gradient', 'is_data', 'error_clip' and 'type';
        'dtype', 'lod_level' and 'shape' when they can be read; the parameter-only fields for
        Parameter instances, and 'persistable' for every other variable.

    Raises:
        TypeError: If var is not a paddle.Variable.
    '''
    if not isinstance(var, paddle.Variable):
        raise TypeError("var shoule be an instance of paddle.Variable")

    var_info = {
        'name': var.name,
        'stop_gradient': var.stop_gradient,
        'is_data': var.is_data,
        'error_clip': var.error_clip,
        'type': var.type
    }

    try:
        var_info['dtype'] = convert_dtype_to_string(var.dtype)
        var_info['lod_level'] = var.lod_level
        var_info['shape'] = var.shape
    except Exception:
        # Deliberately best effort: a variable whose dtype is not in dtype_map (or whose
        # attributes cannot be read) is described without these fields. Only Exception is
        # caught so that KeyboardInterrupt / SystemExit are no longer swallowed.
        pass

    if isinstance(var, paddle.device.framework.Parameter):
        var_info['trainable'] = var.trainable
        var_info['optimize_attr'] = var.optimize_attr
        var_info['regularizer'] = var.regularizer
        # gradient_clip_attr is only read on Paddle releases older than 1.8.
        if Version(paddle.__version__) < '1.8':
            var_info['gradient_clip_attr'] = var.gradient_clip_attr
        var_info['do_model_average'] = var.do_model_average
    else:
        var_info['persistable'] = var.persistable

    return var_info


def rename_var(block: paddle.device.framework.Block, old_name: str, new_name: str):
    '''
    Rename a variable of a block together with every operator argument that refers to it.

    Args:
        block: The block that owns the variable and the operators.
        old_name: Current name of the variable.
        new_name: Name the variable should carry afterwards.
    '''
    for op in block.ops:
        for input_name in op.input_arg_names:
            if input_name == old_name:
                op._rename_input(old_name, new_name)

        for output_name in op.output_arg_names:
            if output_name == old_name:
                op._rename_output(old_name, new_name)

    block._rename_var(old_name, new_name)


def add_vars_prefix(program: paddle.static.Program,
                    prefix: str,
                    vars: List[str] = None,
                    excludes: List[str] = None):
    '''
    Prepend a prefix to variable names in the global block of a program.

    Args:
        program: The program whose global block is modified in place.
        prefix: Text to put in front of each selected variable name.
        vars: Names of the variables to rename; every variable of the global block when empty or None.
        excludes: Names that must be left untouched; tested with `in`, so any container of names works.
    '''
    block = program.global_block()
    # Snapshot the names first: renaming mutates block.vars while we iterate.
    names = list(vars) if vars else list(block.vars.keys())
    if excludes:
        names = [name for name in names if name not in excludes]
    for name in names:
        rename_var(block, name, prefix + name)


def remove_vars_prefix(program: paddle.static.Program,
                       prefix: str,
                       vars: List[str] = None,
                       excludes: List[str] = None):
    '''
    Strip a leading prefix from variable names in the global block of a program.

    Args:
        program: The program whose global block is modified in place.
        prefix: Leading text to remove; names that do not start with it are skipped.
        vars: Names of the variables to consider; every variable of the global block when empty or None.
        excludes: Names that must be left untouched; tested with `in`, so any container of names works.
    '''
    block = program.global_block()
    candidates = vars if vars else block.vars.keys()
    # Materialise the selection before renaming, because renaming mutates block.vars.
    names = [name for name in candidates if name.startswith(prefix)]
    if excludes:
        names = [name for name in names if name not in excludes]
    for name in names:
        # Only the first occurrence is dropped, i.e. the leading prefix.
        rename_var(block, name, name.replace(prefix, '', 1))


def clone_program(origin_program: paddle.static.Program, for_test: bool = False) -> paddle.static.Program:
    '''
    Build a copy of a program by re-creating its variables and operators in a new Program.

    Args:
        origin_program: The program to copy; only read, never modified.
        for_test: Forwarded to Program.clone of the rebuilt program.

    Returns:
        The rebuilt and cloned program. When for_test is False, the stop_gradient flag of each
        global-block variable is copied over from origin_program after cloning.
    '''
    dest_program = paddle.static.Program()

    _copy_vars_and_ops_in_blocks(origin_program.global_block(), dest_program.global_block())

    dest_program = dest_program.clone(for_test=for_test)
    if not for_test:
        dest_vars = dest_program.global_block().vars
        for name, var in origin_program.global_block().vars.items():
            dest_vars[name].stop_gradient = var.stop_gradient

    return dest_program


def _copy_vars_and_ops_in_blocks(from_block: paddle.device.framework.Block, to_block: paddle.device.framework.Block):
    '''
    Re-create every variable and operator of from_block inside to_block.

    Operators that carry a 'sub_block' attribute get a freshly created block in the destination
    program, which is filled by a recursive call.
    '''
    for var_name in from_block.vars:
        source_var = from_block.var(var_name)
        var_info = copy.deepcopy(get_variable_info(source_var))
        if isinstance(source_var, paddle.device.framework.Parameter):
            to_block.create_parameter(**var_info)
        else:
            to_block.create_var(**var_info)

    for op in from_block.ops:
        all_attrs = op.all_attrs()
        if 'sub_block' in all_attrs:
            _sub_block = to_block.program._create_block()
            _copy_vars_and_ops_in_blocks(all_attrs['sub_block'], _sub_block)
            # Leave the newly created block again before appending the operator to to_block.
            to_block.program._rollback()
            # The block object itself is shared, every other attribute is deep-copied.
            new_attrs = {'sub_block': _sub_block}
            new_attrs.update({key: copy.deepcopy(value) for key, value in all_attrs.items() if key != 'sub_block'})
        else:
            new_attrs = copy.deepcopy(all_attrs)

        # Operator arguments are re-bound to the variables that now live in the destination.
        inputs = {slot: [to_block._find_var_recursive(name) for name in op.input(slot)] for slot in op.input_names}
        outputs = {slot: [to_block._find_var_recursive(name) for name in op.output(slot)] for slot in op.output_names}
        to_block.append_op(type=op.type, inputs=inputs, outputs=outputs, attrs=new_attrs)


def set_op_attr(program: paddle.static.Program, is_test: bool = False):
    '''
    Set the 'is_test' attribute of every operator that has one, in all blocks of the program.

    Args:
        program: The program to modify in place.
        is_test: Value written to the attribute.
    '''
    for block in program.blocks:
        for op in block.ops:
            if op.has_attr('is_test'):
                op._set_attr('is_test', is_test)
LocalModuleManager(object): - """ + ''' LocalModuleManager is used to manage PaddleHub's local Module, which supports the installation, uninstallation, and search of HubModule. LocalModuleManager is a singleton object related to the path, in other words, when the LocalModuleManager object of the same home directory is generated multiple times, the same object is returned. Args: home (str): The directory where PaddleHub modules are stored, the default is ~/.paddlehub/modules - """ + ''' _instance_map = {} def __new__(cls, home: str = MODULE_HOME): diff --git a/paddlehub/module/module.py b/paddlehub/module/module.py index 2cc6e5601716918eab55afba06418bbf952da0ed..d3dfdf454445ddfa8f55d42dbe92a3669922786d 100644 --- a/paddlehub/module/module.py +++ b/paddlehub/module/module.py @@ -17,9 +17,9 @@ import inspect import importlib import os import sys -from typing import Callable, List, Optional, Generic +from typing import Callable, Generic, List, Optional -from paddlehub.utils import utils +from paddlehub.utils import log, utils from paddlehub.compat.module.module_v1 import ModuleV1 @@ -58,9 +58,10 @@ def serving(func: Callable) -> Callable: class Module(object): ''' ''' + def __new__(cls, name: str = None, directory: str = None, version: str = None, **kwargs): if cls.__name__ == 'Module': - # This branch come from hub.Module(name='xxx') + # This branch come from hub.Module(name='xxx') or hub.Module(directory='xxx') if name: module = cls.init_with_name(name=name, version=version, **kwargs) elif directory: @@ -72,19 +73,19 @@ class Module(object): @classmethod def load(cls, directory: str) -> Generic: + ''' + ''' if directory.endswith(os.sep): directory = directory[:-1] - # if module description file existed, try to load as ModuleV1 + # If module description file existed, try to load as ModuleV1 desc_file = os.path.join(directory, 'module_desc.pb') if os.path.exists(desc_file): return ModuleV1.load(desc_file) basename = os.path.split(directory)[-1] dirname = 
os.path.join(*list(os.path.split(directory)[:-1])) - - sys.path.insert(0, dirname) - py_module = importlib.import_module('{}.module'.format(basename)) + py_module = utils.load_py_module(dirname, '{}.module'.format(basename)) for _item, _cls in inspect.getmembers(py_module, inspect.isclass): _item = py_module.__dict__[_item] @@ -93,13 +94,14 @@ class Module(object): break else: raise InvalidHubModule(directory) - sys.path.pop(0) user_module_cls.directory = directory return user_module_cls @classmethod def init_with_name(cls, name: str, version: str = None, **kwargs): + ''' + ''' from paddlehub.module.manager import LocalModuleManager manager = LocalModuleManager() user_module_cls = manager.search(name) @@ -107,15 +109,39 @@ class Module(object): user_module_cls = manager.install(name, version) directory = manager._get_normalized_path(name) + + # The HubModule in the old version will use the _initialize method to initialize, + # this function will be obsolete in a future version + if hasattr(user_module_cls, '_initialize'): + log.logger.warning( + 'The _initialize method in HubModule will soon be deprecated, you can use the __init__() to handle the initialization of the object' + ) + user_module = user_module_cls(directory=directory) + user_module._initialize(**kwargs) + return user_module return user_module_cls(directory=directory, **kwargs) @classmethod def init_with_directory(cls, directory: str, **kwargs): + ''' + ''' user_module_cls = cls.load(directory) - return user_module_cls(**kwargs) + + # The HubModule in the old version will use the _initialize method to initialize, + # this function will be obsolete in a future version + if hasattr(user_module_cls, '_initialize'): + log.logger.warning( + 'The _initialize method in HubModule will soon be deprecated, you can use the __init__() to handle the initialization of the object' + ) + user_module = user_module_cls(directory=directory) + user_module._initialize(**kwargs) + return user_module + return 
user_module_cls(directory=directory, **kwargs) @classmethod def get_py_requirements(cls): + ''' + ''' req_file = os.path.join(cls.directory, 'requirements.txt') if not os.path.exists(req_file): return [] @@ -125,6 +151,9 @@ class Module(object): class RunModule(object): + ''' + ''' + def __init__(self, *args, **kwargs): # Avoid module being initialized multiple times if '_is_initialize' in self.__dict__ and self._is_initialize: @@ -149,6 +178,8 @@ class RunModule(object): @classmethod def get_py_requirements(cls) -> List[str]: + ''' + ''' py_module = sys.modules[cls.__module__] directory = os.path.dirname(py_module.__file__) req_file = os.path.join(directory, 'requirements.txt') @@ -172,6 +203,9 @@ def moduleinfo(name: str, summary: str = None, type: str = None, meta=None) -> Callable: + ''' + ''' + def _wrapper(cls: Generic) -> Generic: wrap_cls = cls _meta = RunModule if not meta else meta diff --git a/paddlehub/utils/log.py b/paddlehub/utils/log.py index 18f9c4f249a63ef871827c126f178a9a347e93d6..8165300e309fffa4a59a5dde99ed287309ba2d48 100644 --- a/paddlehub/utils/log.py +++ b/paddlehub/utils/log.py @@ -170,8 +170,8 @@ class FormattedText(object): self.width = width def __repr__(self) -> str: - form = ':{}{}'.format(self.align, self.width) - text = ('{' + form + '}').format(self.text) + form = '{{:{}{}}}'.format(self.align, self.width) + text = form.format(self.text) if not self.color: return text return self.color + text + Fore.RESET diff --git a/paddlehub/utils/parser.py b/paddlehub/utils/parser.py new file mode 100644 index 0000000000000000000000000000000000000000..8f6125c44256c83c3a5579fb3561109d02b70618 --- /dev/null +++ b/paddlehub/utils/parser.py @@ -0,0 +1,75 @@ +# coding:utf-8 +# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License" +# you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import codecs
import sys
from typing import List


def _stdin_encoding() -> str:
    '''Return the encoding reported by paddlehub.utils.utils.sys_stdin_encoding.'''
    # Resolved at call time so that this module can be imported on its own.
    from paddlehub.utils.utils import sys_stdin_encoding
    return sys_stdin_encoding()


class CSVFileParser(object):
    '''Parser for simple comma-separated files (no quoting or escaping support).'''

    def parse(self, csv_file: str) -> dict:
        '''
        Read a CSV file into a dict of columns.

        The first line supplies the column titles; every following non-empty line is split on
        ',' and its items are appended to the column at the same position. The titles and the
        result are also kept on the instance as self.title and self.content.

        Args:
            csv_file: Path of the file, decoded with the stdin encoding.

        Returns:
            A dict mapping each title to the list of values found in that column.

        Raises:
            IndexError: If a row holds more items than there are titles.
        '''
        with codecs.open(csv_file, 'r', _stdin_encoding()) as file:
            raw = file.read()

        # codecs.open does no newline translation, so CRLF files would otherwise leave a
        # trailing '\r' on the last title and on the last item of every row.
        rows = raw.replace('\r\n', '\n').split('\n')
        self.title = rows[0].split(',')
        self.content = {key: [] for key in self.title}

        for row in rows[1:]:
            if row == '':
                continue
            for index, item in enumerate(row.split(',')):
                self.content[self.title[index]].append(item)

        return self.content


class YAMLFileParser(object):
    '''Parser for YAML files.'''

    def parse(self, yaml_file: str) -> dict:
        '''
        Load a YAML file with yaml.BaseLoader.

        Args:
            yaml_file: Path of the file, decoded with the stdin encoding.

        Returns:
            The object built by yaml.load for the file content.
        '''
        # Imported at call time so the other parsers stay usable when PyYAML is not importable.
        import yaml

        with codecs.open(yaml_file, 'r', _stdin_encoding()) as file:
            content = file.read()
        return yaml.load(content, Loader=yaml.BaseLoader)


class TextFileParser(object):
    '''Parser that reads a text file line by line, trying UTF-8 first and GBK second.'''

    def parse(self, txt_file: str, use_strip: bool = True) -> List:
        '''
        Read the lines of a text file.

        Args:
            txt_file: Path of the file.
            use_strip: Strip surrounding whitespace from every line and drop lines that end up empty.

        Returns:
            The list of lines; with use_strip=False the lines keep their line endings.

        Raises:
            UnicodeDecodeError: If the file is valid in neither UTF-8 nor GBK.
        '''
        try:
            return self._read_lines(txt_file, 'utf8', use_strip)
        except UnicodeError:
            # A decoding error can surface after some lines were already read. The retry
            # therefore starts from a fresh list instead of appending to the partial result,
            # which used to duplicate the leading lines of GBK files.
            return self._read_lines(txt_file, 'gbk', use_strip)

    @staticmethod
    def _read_lines(txt_file: str, encoding: str, use_strip: bool) -> List:
        '''Read txt_file with the given encoding and return its (optionally stripped) lines.'''
        lines = []
        with codecs.open(txt_file, 'r', encoding=encoding) as file:
            for line in file:
                if use_strip:
                    line = line.strip()
                if line:
                    lines.append(line)
        return lines


csv_parser = CSVFileParser()
yaml_parser = YAMLFileParser()
txt_parser = TextFileParser()
def get_platform_default_encoding() -> str:
    '''
    Return the fallback text encoding for the current platform.

    Returns:
        'gbk' when paddlehub.utils.platform.is_windows() is true, 'utf8' otherwise.
    '''
    # Imported at call time: this module is itself part of the paddlehub.utils package, so a
    # module-level import of that package would be a self-referencing import.
    from paddlehub.utils import platform as hub_platform

    if hub_platform.is_windows():
        return 'gbk'
    return 'utf8'


def _stream_encoding(stream) -> str:
    '''Return the encoding of a standard stream, falling back to the interpreter/platform default.'''
    # sys.stdin / sys.stdout can be None or be replaced by objects without an `encoding`
    # attribute, hence getattr instead of plain attribute access.
    encoding = getattr(stream, 'encoding', None)
    if encoding is None:
        encoding = sys.getdefaultencoding()

    if encoding is None:
        encoding = get_platform_default_encoding()
    return encoding


def sys_stdin_encoding() -> str:
    '''
    Return the encoding to use for text read from standard input.

    Returns:
        sys.stdin.encoding when it is available, otherwise sys.getdefaultencoding(), otherwise
        the value of get_platform_default_encoding().
    '''
    return _stream_encoding(sys.stdin)


def sys_stdout_encoding() -> str:
    '''
    Return the encoding to use for text written to standard output.

    Returns:
        sys.stdout.encoding when it is available, otherwise sys.getdefaultencoding(), otherwise
        the value of get_platform_default_encoding().
    '''
    return _stream_encoding(sys.stdout)