提交 f39dbbfe 编写于 作者: W wuzewu

Fix module compatibility issues

上级 c380150e
...@@ -15,10 +15,13 @@ ...@@ -15,10 +15,13 @@
import sys import sys
from easydict import EasyDict
__version__ = '2.0.0a0' __version__ = '2.0.0a0'
from paddlehub.utils import log, parser, utils from paddlehub.utils import log, parser, utils
from paddlehub.module import Module from paddlehub.module import Module
# In order to maintain the compatibility of the old version, we put the relevant # In order to maintain the compatibility of the old version, we put the relevant
# compatible code in the paddlehub/compat package, and mapped some modules referenced # compatible code in the paddlehub/compat package, and mapped some modules referenced
# in the old version # in the old version
...@@ -26,8 +29,12 @@ from paddlehub.compat import paddle_utils ...@@ -26,8 +29,12 @@ from paddlehub.compat import paddle_utils
from paddlehub.compat.module.processor import BaseProcessor from paddlehub.compat.module.processor import BaseProcessor
from paddlehub.compat.module.nlp_module import NLPPredictionModule, TransformerModule from paddlehub.compat.module.nlp_module import NLPPredictionModule, TransformerModule
from paddlehub.compat.type import DataType from paddlehub.compat.type import DataType
from paddlehub.compat import task
sys.modules['paddlehub.io.parser'] = parser sys.modules['paddlehub.io.parser'] = parser
sys.modules['paddlehub.common.logger'] = log sys.modules['paddlehub.common.logger'] = log
sys.modules['paddlehub.common.paddle_helper'] = paddle_utils sys.modules['paddlehub.common.paddle_helper'] = paddle_utils
sys.modules['paddlehub.common.utils'] = utils sys.modules['paddlehub.common.utils'] = utils
sys.modules['paddlehub.reader'] = task
common = EasyDict(paddle_helper=paddle_utils)
...@@ -177,16 +177,21 @@ class ModuleV1(object): ...@@ -177,16 +177,21 @@ class ModuleV1(object):
return [] return []
@classmethod @classmethod
def load(cls, desc_file): def load(cls, directory: str) -> EasyDict:
desc = module_v1_utils.convert_module_desc(desc_file) module_info = cls.load_module_info(directory)
cls.name = module_info.name
cls.author = desc.module_info.author cls.author = module_info.author
cls.author_email = desc.module_info.author_email cls.author_email = module_info.author_email
cls.summary = desc.module_info.summary cls.type = module_info.type
cls.type = desc.module_info.type cls.summary = module_info.summary
cls.name = desc.module_info.name cls.version = utils.Version(module_info.version)
cls.version = utils.Version(desc.module_info.version)
return cls return cls
@classmethod
def load_module_info(cls, directory: str) -> EasyDict:
desc_file = os.path.join(directory, 'module_desc.pb')
desc = module_v1_utils.convert_module_desc(desc_file)
return desc.module_info
def assets_path(self): def assets_path(self):
return os.path.join(self.directory, 'assets') return os.path.join(self.directory, 'assets')
...@@ -31,13 +31,11 @@ def convert_module_desc(desc_file): ...@@ -31,13 +31,11 @@ def convert_module_desc(desc_file):
def convert_signatures(signmaps): def convert_signatures(signmaps):
_dict = EasyDict() _dict = EasyDict()
for sign, var in signmaps.items(): for sign, var in signmaps.items():
_dict[sign] = EasyDict() _dict[sign] = EasyDict(inputs=[], outputs=[])
for fetch_var in var.fetch_desc: for fetch_var in var.fetch_desc:
_dict[sign].outputs = list()
_dict[sign].outputs.append(EasyDict(name=fetch_var.var_name, alias=fetch_var.alias)) _dict[sign].outputs.append(EasyDict(name=fetch_var.var_name, alias=fetch_var.alias))
for feed_var in var.feed_desc: for feed_var in var.feed_desc:
_dict[sign].inputs = list()
_dict[sign].inputs.append(EasyDict(name=feed_var.var_name, alias=feed_var.alias)) _dict[sign].inputs.append(EasyDict(name=feed_var.var_name, alias=feed_var.alias))
return _dict return _dict
......
...@@ -22,7 +22,6 @@ from typing import Any, List, Text, Tuple ...@@ -22,7 +22,6 @@ from typing import Any, List, Text, Tuple
import paddle import paddle
import numpy as np import numpy as np
from paddle.fluid.core import PaddleTensor, AnalysisConfig, create_paddle_predictor
from paddlehub.compat import paddle_utils from paddlehub.compat import paddle_utils
from paddlehub.compat.task.transformer_emb_task import TransformerEmbeddingTask from paddlehub.compat.task.transformer_emb_task import TransformerEmbeddingTask
...@@ -51,10 +50,10 @@ class NLPBaseModule(RunModule): ...@@ -51,10 +50,10 @@ class NLPBaseModule(RunModule):
class NLPPredictionModule(NLPBaseModule): class NLPPredictionModule(NLPBaseModule):
def _set_config(self): def _set_config(self):
'''predictor config setting''' '''predictor config setting'''
cpu_config = AnalysisConfig(self.pretrained_model_path) cpu_config = paddle.device.core.AnalysisConfig(self.pretrained_model_path)
cpu_config.disable_glog_info() cpu_config.disable_glog_info()
cpu_config.disable_gpu() cpu_config.disable_gpu()
self.cpu_predictor = create_paddle_predictor(cpu_config) self.cpu_predictor = paddle.device.core.create_paddle_predictor(cpu_config)
try: try:
_places = os.environ['CUDA_VISIBLE_DEVICES'] _places = os.environ['CUDA_VISIBLE_DEVICES']
...@@ -63,10 +62,10 @@ class NLPPredictionModule(NLPBaseModule): ...@@ -63,10 +62,10 @@ class NLPPredictionModule(NLPBaseModule):
except: except:
use_gpu = False use_gpu = False
if use_gpu: if use_gpu:
gpu_config = AnalysisConfig(self.pretrained_model_path) gpu_config = paddle.device.core.AnalysisConfig(self.pretrained_model_path)
gpu_config.disable_glog_info() gpu_config.disable_glog_info()
gpu_config.enable_use_gpu(memory_pool_init_size_mb=500, device_id=0) gpu_config.enable_use_gpu(memory_pool_init_size_mb=500, device_id=0)
self.gpu_predictor = create_paddle_predictor(gpu_config) self.gpu_predictor = paddle.device.core.create_paddle_predictor(gpu_config)
def texts2tensor(self, texts: List[dict]) -> paddle.Tensor: def texts2tensor(self, texts: List[dict]) -> paddle.Tensor:
''' '''
...@@ -82,7 +81,7 @@ class NLPPredictionModule(NLPBaseModule): ...@@ -82,7 +81,7 @@ class NLPPredictionModule(NLPBaseModule):
for i, text in enumerate(texts): for i, text in enumerate(texts):
data += text['processed'] data += text['processed']
lod.append(len(text['processed']) + lod[i]) lod.append(len(text['processed']) + lod[i])
tensor = PaddleTensor(np.array(data).astype('int64')) tensor = paddle.device.core.PaddleTensor(np.array(data).astype('int64'))
tensor.name = 'words' tensor.name = 'words'
tensor.lod = [lod] tensor.lod = [lod]
tensor.shape = [lod[-1], 1] tensor.shape = [lod[-1], 1]
...@@ -183,7 +182,7 @@ class TransformerModule(NLPBaseModule): ...@@ -183,7 +182,7 @@ class TransformerModule(NLPBaseModule):
assert os.path.exists(pretraining_params_path), '[{}] cann\'t be found.'.format(pretraining_params_path) assert os.path.exists(pretraining_params_path), '[{}] cann\'t be found.'.format(pretraining_params_path)
def existed_params(var): def existed_params(var):
if not isinstance(var, paddle.fluid.framework.Parameter): if not isinstance(var, paddle.device.framework.Parameter):
return False return False
return os.path.exists(os.path.join(pretraining_params_path, var.name)) return os.path.exists(os.path.join(pretraining_params_path, var.name))
......
...@@ -15,6 +15,7 @@ ...@@ -15,6 +15,7 @@
import os import os
import shutil import shutil
import sys
from collections import OrderedDict from collections import OrderedDict
from typing import List from typing import List
...@@ -65,11 +66,18 @@ class LocalModuleManager(object): ...@@ -65,11 +66,18 @@ class LocalModuleManager(object):
self.home = home self.home = home
self._local_modules = OrderedDict() self._local_modules = OrderedDict()
# Most HubModules can be regarded as Python packages, so the home directory
# is added to sys.path to make them importable.
if not home in sys.path:
sys.path.insert(0, home)
def _get_normalized_path(self, name: str) -> str: def _get_normalized_path(self, name: str) -> str:
return os.path.join(self.home, self._get_normalized_name(name))
def _get_normalized_name(self, name: str) -> str:
# Some HubModules contain '-' in name (eg roberta_wwm_ext_chinese_L-3_H-1024_A-16). # Some HubModules contain '-' in name (eg roberta_wwm_ext_chinese_L-3_H-1024_A-16).
# Replace '-' with '_' to comply with python naming conventions. # Replace '-' with '_' to comply with python naming conventions.
name = name.replace('-', '_') return name.replace('-', '_')
return os.path.join(self.home, name)
def install(self, def install(self,
name: str = None, name: str = None,
...@@ -194,25 +202,41 @@ class LocalModuleManager(object): ...@@ -194,25 +202,41 @@ class LocalModuleManager(object):
def _install_from_directory(self, directory: str) -> HubModule: def _install_from_directory(self, directory: str) -> HubModule:
'''Install a HubModule from directory containing module.py''' '''Install a HubModule from directory containing module.py'''
hub_module_cls = HubModule.load(directory) module_info = HubModule.load_module_info(directory)
# Uninstall local module # A temporary directory is copied here for two purposes:
if self.search(hub_module_cls.name): # 1. Avoid affecting user-specified directory (for example, a __pycache__
self.uninstall(hub_module_cls.name) # directory will be generated).
# 2. HubModule is essentially a python package. When internal package
shutil.copytree(directory, os.path.join(self.home, hub_module_cls.name)) # references are made in it, the correct package name is required.
self._local_modules[hub_module_cls.name] = hub_module_cls with utils.generate_tempdir() as _dir:
tempdir = os.path.join(_dir, module_info.name)
for py_req in hub_module_cls.get_py_requirements(): tempdir = self._get_normalized_name(tempdir)
log.logger.info('Installing dependent packages: {}'.format(py_req)) shutil.copytree(directory, tempdir)
result = pypi.install(py_req)
if result: directory = tempdir
log.logger.info('Successfully installed {}'.format(py_req)) hub_module_cls = HubModule.load(directory)
else:
log.logger.info('Some errors occurred while installing {}'.format(py_req)) # Uninstall local module
if self.search(hub_module_cls.name):
self.uninstall(hub_module_cls.name)
shutil.copytree(directory, self._get_normalized_path(hub_module_cls.name))
# Reload the Module object to avoid path errors
hub_module_cls = HubModule.load(self._get_normalized_path(hub_module_cls.name))
self._local_modules[hub_module_cls.name] = hub_module_cls
for py_req in hub_module_cls.get_py_requirements():
log.logger.info('Installing dependent packages: {}'.format(py_req))
result = pypi.install(py_req)
if result:
log.logger.info('Successfully installed {}'.format(py_req))
else:
log.logger.info('Some errors occurred while installing {}'.format(py_req))
log.logger.info('Successfully installed {}-{}'.format(hub_module_cls.name, hub_module_cls.version)) log.logger.info('Successfully installed {}-{}'.format(hub_module_cls.name, hub_module_cls.version))
return hub_module_cls return hub_module_cls
def _install_from_archive(self, archive: str) -> HubModule: def _install_from_archive(self, archive: str) -> HubModule:
'''Install HubModule from archive file (eg xxx.tar.gz)''' '''Install HubModule from archive file (eg xxx.tar.gz)'''
......
...@@ -13,12 +13,15 @@ ...@@ -13,12 +13,15 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import ast
import inspect import inspect
import importlib import importlib
import os import os
import sys import sys
from typing import Callable, Generic, List, Optional from typing import Callable, Generic, List, Optional
from easydict import EasyDict
from paddlehub.utils import log, utils from paddlehub.utils import log, utils
from paddlehub.compat.module.module_v1 import ModuleV1 from paddlehub.compat.module.module_v1 import ModuleV1
...@@ -81,7 +84,7 @@ class Module(object): ...@@ -81,7 +84,7 @@ class Module(object):
# If module description file existed, try to load as ModuleV1 # If module description file existed, try to load as ModuleV1
desc_file = os.path.join(directory, 'module_desc.pb') desc_file = os.path.join(directory, 'module_desc.pb')
if os.path.exists(desc_file): if os.path.exists(desc_file):
return ModuleV1.load(desc_file) return ModuleV1.load(directory)
basename = os.path.split(directory)[-1] basename = os.path.split(directory)[-1]
dirname = os.path.join(*list(os.path.split(directory)[:-1])) dirname = os.path.join(*list(os.path.split(directory)[:-1]))
...@@ -98,6 +101,32 @@ class Module(object): ...@@ -98,6 +101,32 @@ class Module(object):
user_module_cls.directory = directory user_module_cls.directory = directory
return user_module_cls return user_module_cls
@classmethod
def load_module_info(cls, directory: str) -> EasyDict:
# If module_desc.pb exists in the directory, treat it as a ModuleV1
desc_file = os.path.join(directory, 'module_desc.pb')
if os.path.exists(desc_file):
return ModuleV1.load_module_info(directory)
# Otherwise treat it as a ModuleV2 and read its info from module.py
module_file = os.path.join(directory, 'module.py')
with open(module_file, 'r') as file:
pycode = file.read()
ast_module = ast.parse(pycode)
for _body in ast_module.body:
if not isinstance(_body, ast.ClassDef):
continue
for _decorator in _body.decorator_list:
if _decorator.func.id != 'moduleinfo':
continue
info = {key.arg: key.value.s for key in _decorator.keywords}
return EasyDict(info)
else:
raise InvalidHubModule(directory)
@classmethod @classmethod
def init_with_name(cls, name: str, version: str = None, **kwargs): def init_with_name(cls, name: str, version: str = None, **kwargs):
''' '''
...@@ -108,7 +137,7 @@ class Module(object): ...@@ -108,7 +137,7 @@ class Module(object):
if not user_module_cls or not user_module_cls.version.match(version): if not user_module_cls or not user_module_cls.version.match(version):
user_module_cls = manager.install(name, version) user_module_cls = manager.install(name, version)
directory = manager._get_normalized_path(name) directory = manager._get_normalized_path(user_module_cls.name)
# The HubModule in the old version will use the _initialize method to initialize, # The HubModule in the old version will use the _initialize method to initialize,
# this function will be obsolete in a future version # this function will be obsolete in a future version
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册