Commit 6acb2dd4 authored by W wuzewu

Fix module compat bug

Parent b33e4d14
@@ -13,9 +13,21 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import sys
__version__ = '2.0.0a0'
from paddlehub.utils import log, parser, utils
from paddlehub.module import Module
# To maintain compatibility with the old version, the relevant compatibility code
# lives in the paddlehub/compat package, and some modules referenced by the old
# version are mapped below.
from paddlehub.compat import paddle_utils
from paddlehub.compat.module.processor import BaseProcessor
from paddlehub.compat.module.nlp_module import NLPPredictionModule, TransformerModule
from paddlehub.compat.type import DataType
sys.modules['paddlehub.io.parser'] = parser
sys.modules['paddlehub.common.logger'] = log
sys.modules['paddlehub.common.paddle_helper'] = paddle_utils
sys.modules['paddlehub.common.utils'] = utils
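The sys.modules aliases above keep import paths from the 1.x API resolving against the new packages. A minimal sketch of the effect (the old-style imports and the input file name are illustrative, not part of this commit):

# A minimal sketch of the compatibility aliasing, assuming PaddleHub 2.0 is installed.
import paddlehub  # importing the package runs the sys.modules mapping above

# Old 1.x-style import paths now resolve to the new modules:
from paddlehub.common.logger import logger   # actually paddlehub.utils.log
from paddlehub.io.parser import txt_parser   # actually paddlehub.utils.parser

logger.info('old import paths still work')
lines = txt_parser.parse('corpus.txt')       # hypothetical input file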
// Copyright 2018 The Paddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
syntax = "proto3";
option optimize_for = LITE_RUNTIME;
package paddlehub.module.desc;
enum DataType {
NONE = 0;
INT = 1;
FLOAT = 2;
STRING = 3;
BOOLEAN = 4;
LIST = 5;
MAP = 6;
SET = 7;
OBJECT = 8;
}
message KVData {
map<string, DataType> key_type = 1;
map<string, ModuleAttr> data = 2;
}
message ModuleAttr {
// Basic type
DataType type = 1;
int64 i = 2;
double f = 3;
bool b = 4;
string s = 5;
KVData map = 6;
KVData list = 7;
KVData set = 8;
KVData object = 9;
//
string name = 10;
string info = 11;
}
// Feed Variable Description
message FeedDesc {
string var_name = 1;
string alias = 2;
};
// Fetch Variable Description
message FetchDesc {
string var_name = 1;
string alias = 2;
};
// Module Variable
message ModuleVar {
repeated FetchDesc fetch_desc = 1;
repeated FeedDesc feed_desc = 2;
}
// A Hub Module is stored in a directory with a file 'module_desc.pb'
// containing a serialized protocol message of this type. The further contents
// of the directory depend on the storage format described by the message.
message ModuleDesc {
// signature to module variable
map<string, ModuleVar> sign2var = 2;
ModuleAttr attr = 3;
};
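A Hub Module directory stores this message serialized as module_desc.pb. A minimal sketch of reading it with the classes protoc generates from this file (the generated module name and its location are assumptions):

# A minimal sketch, assuming protoc has generated module_desc_pb2.py from this proto.
import module_desc_pb2

desc = module_desc_pb2.ModuleDesc()
with open('module_desc.pb', 'rb') as f:
    desc.ParseFromString(f.read())

# Walk the signature -> variable mapping declared in ModuleDesc.sign2var.
for signature, module_var in desc.sign2var.items():
    feeds = [feed.var_name for feed in module_var.feed_desc]
    fetches = [fetch.var_name for fetch in module_var.fetch_desc]
    print(signature, feeds, fetches)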
@@ -47,6 +47,10 @@ class ModuleV1(object):
self._generate_func()
def _load_processor(self):
# Some modules do not have a processor (e.g. ernie)
if 'processor_info' not in self.desc:
return
python_path = os.path.join(self.directory, 'python')
processor_name = self.desc.processor_info
self.processor = utils.load_py_module(python_path, processor_name)
...
@@ -13,8 +13,64 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import copy
from typing import Callable, List
import paddle
from paddlehub.utils.utils import Version
dtype_map = {
paddle.device.core.VarDesc.VarType.FP32: "float32",
paddle.device.core.VarDesc.VarType.FP64: "float64",
paddle.device.core.VarDesc.VarType.FP16: "float16",
paddle.device.core.VarDesc.VarType.INT32: "int32",
paddle.device.core.VarDesc.VarType.INT16: "int16",
paddle.device.core.VarDesc.VarType.INT64: "int64",
paddle.device.core.VarDesc.VarType.BOOL: "bool",
paddle.device.core.VarDesc.VarType.UINT8: "uint8",
paddle.device.core.VarDesc.VarType.INT8: "int8",
}
def convert_dtype_to_string(dtype: paddle.device.core.VarDesc.VarType) -> str:
if dtype in dtype_map:
return dtype_map[dtype]
raise TypeError("dtype should be in %s" % list(dtype_map.keys()))
def get_variable_info(var: paddle.Variable) -> dict:
if not isinstance(var, paddle.Variable):
raise TypeError("var shoule be an instance of paddle.Variable")
var_info = {
'name': var.name,
'stop_gradient': var.stop_gradient,
'is_data': var.is_data,
'error_clip': var.error_clip,
'type': var.type
}
try:
var_info['dtype'] = convert_dtype_to_string(var.dtype)
var_info['lod_level'] = var.lod_level
var_info['shape'] = var.shape
except Exception:
pass
if isinstance(var, paddle.device.framework.Parameter):
var_info['trainable'] = var.trainable
var_info['optimize_attr'] = var.optimize_attr
var_info['regularizer'] = var.regularizer
if Version(paddle.__version__) < '1.8':
var_info['gradient_clip_attr'] = var.gradient_clip_attr
var_info['do_model_average'] = var.do_model_average
else:
var_info['persistable'] = var.persistable
return var_info
def remove_feed_fetch_op(program: paddle.static.Program):
'''Remove feed and fetch operators and variables for fine-tuning.'''
@@ -39,3 +95,103 @@ def remove_feed_fetch_op(program: paddle.static.Program):
block._remove_var(var)
program.desc.flush()
def rename_var(block: paddle.device.framework.Block, old_name: str, new_name: str):
'''Rename a variable in the block and update every op input/output that references the old name.'''
for op in block.ops:
for input_name in op.input_arg_names:
if input_name == old_name:
op._rename_input(old_name, new_name)
for output_name in op.output_arg_names:
if output_name == old_name:
op._rename_output(old_name, new_name)
block._rename_var(old_name, new_name)
def add_vars_prefix(program: paddle.static.Program,
prefix: str,
vars: List[str] = None,
excludes: List[str] = None):
'''Add the prefix to the given variable names (all global-block variables by default), skipping names in excludes.'''
block = program.global_block()
vars = list(vars) if vars else list(block.vars.keys())
vars = [var for var in vars if var not in excludes] if excludes else vars
for var in vars:
rename_var(block, var, prefix + var)
def remove_vars_prefix(program: paddle.static.Program,
prefix: str,
vars: List[str] = None,
excludes: List[str] = None):
'''Remove the prefix from the given variable names (all prefixed global-block variables by default), skipping names in excludes.'''
block = program.global_block()
vars = [var for var in vars
if var.startswith(prefix)] if vars else [var for var in block.vars.keys() if var.startswith(prefix)]
vars = [var for var in vars if var not in excludes] if excludes else vars
for var in vars:
rename_var(block, var, var.replace(prefix, '', 1))
def clone_program(origin_program: paddle.static.Program, for_test: bool = False) -> paddle.static.Program:
dest_program = paddle.static.Program()
_copy_vars_and_ops_in_blocks(origin_program.global_block(), dest_program.global_block())
dest_program = dest_program.clone(for_test=for_test)
if not for_test:
for name, var in origin_program.global_block().vars.items():
dest_program.global_block().vars[name].stop_gradient = var.stop_gradient
return dest_program
def _copy_vars_and_ops_in_blocks(from_block: paddle.device.framework.Block, to_block: paddle.device.framework.Block):
for var in from_block.vars:
var = from_block.var(var)
var_info = copy.deepcopy(get_variable_info(var))
if isinstance(var, paddle.device.framework.Parameter):
to_block.create_parameter(**var_info)
else:
to_block.create_var(**var_info)
for op in from_block.ops:
all_attrs = op.all_attrs()
if 'sub_block' in all_attrs:
_sub_block = to_block.program._create_block()
_copy_vars_and_ops_in_blocks(all_attrs['sub_block'], _sub_block)
to_block.program._rollback()
new_attrs = {'sub_block': _sub_block}
for key, value in all_attrs.items():
if key == 'sub_block':
continue
new_attrs[key] = copy.deepcopy(value)
else:
new_attrs = copy.deepcopy(all_attrs)
op_info = {
'type': op.type,
'inputs':
{input: [to_block._find_var_recursive(var) for var in op.input(input)]
for input in op.input_names},
'outputs':
{output: [to_block._find_var_recursive(var) for var in op.output(output)]
for output in op.output_names},
'attrs': new_attrs
}
to_block.append_op(**op_info)
def set_op_attr(program: paddle.static.Program, is_test: bool = False):
for block in program.blocks:
for op in block.ops:
if not op.has_attr('is_test'):
continue
op._set_attr('is_test', is_test)
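Taken together, these helpers support the fine-tuning workflow: clone a program, prefix its variable names to avoid clashes, and flip its ops into training mode. A minimal sketch on a tiny static-graph program (the prefix string is a made-up example):

# A minimal sketch of combining the helpers above, assuming paddle 2.x is installed.
import paddle
from paddlehub.compat import paddle_utils

paddle.enable_static()

main_program = paddle.static.Program()
with paddle.static.program_guard(main_program):
    x = paddle.static.data(name='x', shape=[None, 16], dtype='float32')
    y = paddle.static.nn.fc(x, size=4)

train_program = paddle_utils.clone_program(main_program, for_test=False)
paddle_utils.add_vars_prefix(train_program, '@HUB_demo@')   # hypothetical prefix
paddle_utils.set_op_attr(train_program, is_test=False)      # switch ops like dropout/batch_norm to training mode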
@@ -43,14 +43,14 @@ class HubModuleNotFoundError(Exception):
class LocalModuleManager(object):
- """
+ '''
LocalModuleManager manages PaddleHub's local modules, supporting the installation, uninstallation
and search of HubModule. LocalModuleManager is a per-path singleton: constructing it multiple times
with the same home directory returns the same object.
Args:
home (str): The directory where PaddleHub modules are stored, the default is ~/.paddlehub/modules
- """
+ '''
_instance_map = {}
def __new__(cls, home: str = MODULE_HOME):
...
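A short illustration of the per-path singleton behaviour described in the docstring (the alternative home path is a made-up example):

# A minimal sketch of the per-path singleton behaviour.
from paddlehub.module.manager import LocalModuleManager

m1 = LocalModuleManager()                              # defaults to ~/.paddlehub/modules
m2 = LocalModuleManager()                              # same home directory -> same object
assert m1 is m2

m3 = LocalModuleManager(home='/tmp/other_modules')     # hypothetical alternative home
assert m3 is not m1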
@@ -17,9 +17,9 @@ import inspect
import importlib
import os
import sys
- from typing import Callable, List, Optional, Generic
+ from typing import Callable, Generic, List, Optional
- from paddlehub.utils import utils
+ from paddlehub.utils import log, utils
from paddlehub.compat.module.module_v1 import ModuleV1
@@ -58,9 +58,10 @@ def serving(func: Callable) -> Callable:
class Module(object):
'''
PaddleHub Module factory: hub.Module(name=...) installs and instantiates a module by name, and
hub.Module(directory=...) loads one from a local directory.
'''
def __new__(cls, name: str = None, directory: str = None, version: str = None, **kwargs):
if cls.__name__ == 'Module':
- # This branch come from hub.Module(name='xxx')
+ # This branch comes from hub.Module(name='xxx') or hub.Module(directory='xxx')
if name:
module = cls.init_with_name(name=name, version=version, **kwargs)
elif directory:
@@ -72,19 +73,19 @@ class Module(object):
@classmethod
def load(cls, directory: str) -> Generic:
'''Load the module class defined in the given directory; a directory containing module_desc.pb is loaded as a v1 module.'''
if directory.endswith(os.sep):
directory = directory[:-1]
- # if module description file existed, try to load as ModuleV1
+ # If a module description file exists, try to load it as ModuleV1
desc_file = os.path.join(directory, 'module_desc.pb')
if os.path.exists(desc_file):
return ModuleV1.load(desc_file)
basename = os.path.split(directory)[-1]
dirname = os.path.join(*list(os.path.split(directory)[:-1]))
py_module = utils.load_py_module(dirname, '{}.module'.format(basename))
sys.path.insert(0, dirname)
py_module = importlib.import_module('{}.module'.format(basename))
for _item, _cls in inspect.getmembers(py_module, inspect.isclass):
_item = py_module.__dict__[_item]
@@ -93,13 +94,14 @@
break
else:
raise InvalidHubModule(directory)
sys.path.pop(0)
user_module_cls.directory = directory
return user_module_cls
@classmethod
def init_with_name(cls, name: str, version: str = None, **kwargs):
'''Initialize a module by name, installing it from the module server if it is not available locally.'''
from paddlehub.module.manager import LocalModuleManager
manager = LocalModuleManager()
user_module_cls = manager.search(name)
@@ -107,15 +109,39 @@
user_module_cls = manager.install(name, version)
directory = manager._get_normalized_path(name)
# Old-version HubModules use the _initialize method for initialization;
# this method will be removed in a future version
if hasattr(user_module_cls, '_initialize'):
log.logger.warning(
'The _initialize method in HubModule will soon be deprecated; you can use __init__() to handle the initialization of the object'
)
user_module = user_module_cls(directory=directory)
user_module._initialize(**kwargs)
return user_module
return user_module_cls(directory=directory, **kwargs)
@classmethod
def init_with_directory(cls, directory: str, **kwargs):
'''Initialize a module from a local directory.'''
user_module_cls = cls.load(directory)
return user_module_cls(**kwargs)
# Old-version HubModules use the _initialize method for initialization;
# this method will be removed in a future version
if hasattr(user_module_cls, '_initialize'):
log.logger.warning(
'The _initialize method in HubModule will soon be deprecated; you can use __init__() to handle the initialization of the object'
)
user_module = user_module_cls(directory=directory)
user_module._initialize(**kwargs)
return user_module
return user_module_cls(directory=directory, **kwargs)
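For context, a sketch of the old-style module that the two compatibility branches above cater for (the class below is hypothetical, not part of PaddleHub):

# A hypothetical 1.x-style module class that still defines _initialize.
# When it is loaded through init_with_name / init_with_directory, the branches
# above instantiate it first and then forward the extra kwargs to _initialize,
# so hub.Module(name='legacy_demo', use_gpu=True) keeps working.
class LegacyDemo:
    def __init__(self, directory=None):
        self.directory = directory

    def _initialize(self, use_gpu=False):
        # 1.x initialization hook; new modules should do this in __init__ instead.
        self.use_gpu = use_gpu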
@classmethod
def get_py_requirements(cls):
'''Return the Python package dependencies declared in the module's requirements.txt (an empty list if the file does not exist).'''
req_file = os.path.join(cls.directory, 'requirements.txt')
if not os.path.exists(req_file):
return []
@@ -125,6 +151,9 @@ class Module(object):
class RunModule(object):
'''Base class for PaddleHub module implementations.'''
def __init__(self, *args, **kwargs):
# Avoid module being initialized multiple times
if '_is_initialize' in self.__dict__ and self._is_initialize:
@@ -149,6 +178,8 @@ class RunModule(object):
@classmethod
def get_py_requirements(cls) -> List[str]:
'''Return the Python package dependencies declared in the requirements.txt next to the module's source file.'''
py_module = sys.modules[cls.__module__]
directory = os.path.dirname(py_module.__file__)
req_file = os.path.join(directory, 'requirements.txt')
@@ -172,6 +203,9 @@ def moduleinfo(name: str,
summary: str = None,
type: str = None,
meta=None) -> Callable:
'''Class decorator that registers a class as a PaddleHub module and records its basic information (name, summary, type, etc.).'''
def _wrapper(cls: Generic) -> Generic:
wrap_cls = cls
_meta = RunModule if not meta else meta
...
@@ -170,8 +170,8 @@ class FormattedText(object):
self.width = width
def __repr__(self) -> str:
- form = ':{}{}'.format(self.align, self.width)
- text = ('{' + form + '}').format(self.text)
+ form = '{{:{}{}}}'.format(self.align, self.width)
+ text = form.format(self.text)
if not self.color:
return text
return self.color + text + Fore.RESET
...
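The change above builds the whole format spec in one string instead of concatenating braces around it. A quick illustration with example align/width values:

# Illustration of the new format-spec construction; align and width are example values.
align, width, text = '^', 10, 'hub'

form = '{{:{}{}}}'.format(align, width)   # -> '{:^10}'
print(repr(form.format(text)))            # -> '   hub    '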
# coding:utf-8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import codecs
import sys
from typing import List
import yaml
from paddlehub.utils.utils import sys_stdin_encoding
class CSVFileParser(object):
def parse(self, csv_file: str) -> dict:
with codecs.open(csv_file, 'r', sys_stdin_encoding()) as file:
content = file.read()
content = content.split('\n')
self.title = content[0].split(',')
self.content = {}
for key in self.title:
self.content[key] = []
for text in content[1:]:
if (text == ''):
continue
for index, item in enumerate(text.split(',')):
title = self.title[index]
self.content[title].append(item)
return self.content
class YAMLFileParser(object):
def parse(self, yaml_file: str) -> dict:
with codecs.open(yaml_file, 'r', sys_stdin_encoding()) as file:
content = file.read()
return yaml.load(content, Loader=yaml.BaseLoader)
class TextFileParser(object):
def parse(self, txt_file: str, use_strip: bool = True) -> List:
contents = []
try:
with codecs.open(txt_file, 'r', encoding='utf8') as file:
for line in file:
if use_strip:
line = line.strip()
if line:
contents.append(line)
except UnicodeDecodeError:
with codecs.open(txt_file, 'r', encoding='gbk') as file:
for line in file:
if use_strip:
line = line.strip()
if line:
contents.append(line)
return contents
csv_parser = CSVFileParser()
yaml_parser = YAMLFileParser()
txt_parser = TextFileParser()
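A short usage sketch of the three parser singletons defined above (the input file names are placeholders):

# A minimal usage sketch of the parser helpers; the input files are hypothetical.
from paddlehub.utils.parser import csv_parser, yaml_parser, txt_parser

records = csv_parser.parse('labels.csv')   # dict: column title -> list of values
config = yaml_parser.parse('config.yml')   # dict loaded with yaml.BaseLoader (scalars stay strings)
lines = txt_parser.parse('corpus.txt')     # list of stripped, non-empty lines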
@@ -31,10 +31,12 @@ from urllib.parse import urlparse
import packaging.version
import paddlehub.env as hubenv
import paddlehub.utils as utils
class Version(packaging.version.Version):
'''Extended implementation of packaging.version.Version'''
def match(self, condition: str) -> bool:
'''
Determine whether the given condition is met
@@ -76,9 +78,35 @@ class Version(packaging.version.Version):
return _comp(Version(version))
def __lt__(self, other):
if isinstance(other, str):
other = Version(other)
return super().__lt__(other)
def __le__(self, other):
if isinstance(other, str):
other = Version(other)
return super().__le__(other)
def __gt__(self, other):
if isinstance(other, str):
other = Version(other)
return super().__gt__(other)
def __ge__(self, other):
if isinstance(other, str):
other = Version(other)
return super().__ge__(other)
def __eq__(self, other):
if isinstance(other, str):
other = Version(other)
return super().__eq__(other)

# Defining __eq__ would otherwise reset __hash__ to None, so keep the parent implementation.
__hash__ = packaging.version.Version.__hash__
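The comparison overloads added above let a Version be compared directly against a plain version string, which is what the paddle_utils change relies on (Version(paddle.__version__) < '1.8'). A quick sketch:

# Quick illustration of comparing a Version directly against version strings.
from paddlehub.utils.utils import Version

v = Version('1.8.4')
assert v >= '1.8'
assert v < '2.0.0'
assert Version('2.0.0a0') > '1.8'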
class Timer(object):
'''Calculate running speed and estimated time of arrival (ETA)'''
def __init__(self, total_step: int):
self.total_step = total_step
self.last_start_step = 0
@@ -217,3 +245,35 @@ def load_py_module(python_path: str, py_module_name: str) -> types.ModuleType:
sys.path.pop(0)
return py_module
def get_platform_default_encoding() -> str:
'''Return the default text encoding of the current platform: gbk on Windows, utf8 elsewhere.'''
if utils.platform.is_windows():
return 'gbk'
return 'utf8'
def sys_stdin_encoding() -> str:
'''Return the encoding of sys.stdin, falling back to the interpreter default and then the platform default.'''
encoding = sys.stdin.encoding
if encoding is None:
encoding = sys.getdefaultencoding()
if encoding is None:
encoding = get_platform_default_encoding()
return encoding
def sys_stdout_encoding() -> str:
'''Return the encoding of sys.stdout, falling back to the interpreter default and then the platform default.'''
encoding = sys.stdout.encoding
if encoding is None:
encoding = sys.getdefaultencoding()
if encoding is None:
encoding = get_platform_default_encoding()
return encoding