diff --git a/paddlehub/__init__.py b/paddlehub/__init__.py index b940e22925c18fb6a050a85e5270331a73d03510..e511d44dd1e399a9b0da56992c13064b1a127311 100644 --- a/paddlehub/__init__.py +++ b/paddlehub/__init__.py @@ -38,7 +38,7 @@ from .common.logger import logger from .common.paddle_helper import connect_program from .common.hub_server import default_hub_server -from .module.module import Module, create_module +from .module.module import Module from .module.base_processor import BaseProcessor from .module.signature import Signature, create_signature from .module.manager import default_module_manager diff --git a/paddlehub/commands/install.py b/paddlehub/commands/install.py index e9ba9ba46fca22106736699675ca9824a857916b..1e6407a49b79d0594c37f5e202d30675f328ed5f 100644 --- a/paddlehub/commands/install.py +++ b/paddlehub/commands/install.py @@ -18,6 +18,7 @@ from __future__ import division from __future__ import print_function import argparse +import os from paddlehub.common import utils from paddlehub.module.manager import default_module_manager @@ -42,14 +43,23 @@ class InstallCommand(BaseCommand): print("ERROR: Please specify a module name.\n") self.help() return False - module_name = argv[0] - module_version = None if "==" not in module_name else module_name.split( - "==")[1] - module_name = module_name if "==" not in module_name else module_name.split( - "==")[0] extra = {"command": "install"} - result, tips, module_dir = default_module_manager.install_module( - module_name=module_name, module_version=module_version, extra=extra) + if argv[0].endswith("tar.gz") or argv[0].endswith("phm"): + result, tips, module_dir = default_module_manager.install_module( + module_package=argv[0], extra=extra) + elif os.path.exists(argv[0]) and os.path.isdir(argv[0]): + result, tips, module_dir = default_module_manager.install_module( + module_dir=argv[0], extra=extra) + else: + module_name = argv[0] + module_version = None if "==" not in module_name else module_name.split( + "==")[1] + module_name = module_name if "==" not in module_name else module_name.split( + "==")[0] + result, tips, module_dir = default_module_manager.install_module( + module_name=module_name, + module_version=module_version, + extra=extra) print(tips) return True diff --git a/paddlehub/commands/run.py b/paddlehub/commands/run.py index 30876bc7ca5dcb84b3712fbafd905539e9ff2c5e..9754d27f2d4ae862f0d3e4ad44bce63ab7bbadc5 100644 --- a/paddlehub/commands/run.py +++ b/paddlehub/commands/run.py @@ -71,7 +71,7 @@ class RunCommand(BaseCommand): if not result: return None - return hub.Module(module_dir=module_dir) + return hub.Module(directory=module_dir[0]) def add_module_config_arg(self): configs = self.module.processor.configs() @@ -105,7 +105,7 @@ class RunCommand(BaseCommand): def add_module_input_arg(self): module_type = self.module.type.lower() expect_data_format = self.module.processor.data_format( - self.module.default_signature.name) + self.module.default_signature) self.arg_input_group.add_argument( '--input_file', type=str, @@ -152,7 +152,7 @@ class RunCommand(BaseCommand): def get_data(self): module_type = self.module.type.lower() expect_data_format = self.module.processor.data_format( - self.module.default_signature.name) + self.module.default_signature) input_data = {} if len(expect_data_format) == 1: key = list(expect_data_format.keys())[0] @@ -177,7 +177,7 @@ class RunCommand(BaseCommand): def check_data(self, data): expect_data_format = self.module.processor.data_format( - self.module.default_signature.name) + self.module.default_signature) if len(data.keys()) != len(expect_data_format.keys()): print( @@ -236,35 +236,38 @@ class RunCommand(BaseCommand): return False # If the module is not executable, give an alarm and exit - if not self.module.default_signature: + if not self.module.is_runable: print("ERROR! Module %s is not executable." % module_name) return False - self.module.check_processor() - self.add_module_config_arg() - self.add_module_input_arg() + if self.module.code_version == "v2": + results = self.module(argv[1:]) + else: + self.module.check_processor() + self.add_module_config_arg() + self.add_module_input_arg() - if not argv[1:]: - self.help() - return False + if not argv[1:]: + self.help() + return False - self.args = self.parser.parse_args(argv[1:]) + self.args = self.parser.parse_args(argv[1:]) - config = self.get_config() - data = self.get_data() + config = self.get_config() + data = self.get_data() - try: - self.check_data(data) - except DataFormatError: - self.help() - return False - - results = self.module( - sign_name=self.module.default_signature.name, - data=data, - use_gpu=self.args.use_gpu, - batch_size=self.args.batch_size, - **config) + try: + self.check_data(data) + except DataFormatError: + self.help() + return False + + results = self.module( + sign_name=self.module.default_signature, + data=data, + use_gpu=self.args.use_gpu, + batch_size=self.args.batch_size, + **config) if six.PY2: try: diff --git a/paddlehub/commands/show.py b/paddlehub/commands/show.py index e160da0210556072450cee679b5f2c9a16fd5f5a..6e54d0863a39353423d73ec8724a627777cfc6af 100644 --- a/paddlehub/commands/show.py +++ b/paddlehub/commands/show.py @@ -125,8 +125,6 @@ class ShowCommand(BaseCommand): cwd = os.getcwd() module_dir = default_module_manager.search_module(module_name) - module_dir = (os.path.join(cwd, module_name), - None) if not module_dir else module_dir if not module_dir or not os.path.exists(module_dir[0]): print("%s is not existed!" % module_name) return True diff --git a/paddlehub/module/check_info.proto b/paddlehub/module/check_info.proto index 923de58cefe8bffd6499378733e29bbb2e7a508f..56c1b584de7afcd958eb3edaffc0fdef8b0d7363 100644 --- a/paddlehub/module/check_info.proto +++ b/paddlehub/module/check_info.proto @@ -50,6 +50,7 @@ message CheckInfo { string paddle_version = 1; string hub_version = 2; string module_proto_version = 3; - repeated FileInfo file_infos = 4; - repeated Requires requires = 5; + string module_code_version = 4; + repeated FileInfo file_infos = 5; + repeated Requires requires = 6; }; diff --git a/paddlehub/module/check_info_pb2.py b/paddlehub/module/check_info_pb2.py index 78f5546c49c417508d26fa0f809340459987fc66..8ed17a9ac532ad5bd7a7242d27793ca53a235b40 100644 --- a/paddlehub/module/check_info_pb2.py +++ b/paddlehub/module/check_info_pb2.py @@ -1,4 +1,3 @@ -#coding:utf-8 # Generated by the protocol buffer compiler. DO NOT EDIT! # source: check_info.proto @@ -19,7 +18,7 @@ DESCRIPTOR = _descriptor.FileDescriptor( package='paddlehub.module.checkinfo', syntax='proto3', serialized_pb=_b( - '\n\x10\x63heck_info.proto\x12\x1apaddlehub.module.checkinfo\"\x85\x01\n\x08\x46ileInfo\x12\x11\n\tfile_name\x18\x01 \x01(\t\x12\x33\n\x04type\x18\x02 \x01(\x0e\x32%.paddlehub.module.checkinfo.FILE_TYPE\x12\x0f\n\x07is_need\x18\x03 \x01(\x08\x12\x0b\n\x03md5\x18\x04 \x01(\t\x12\x13\n\x0b\x64\x65scription\x18\x05 \x01(\t\"\x84\x01\n\x08Requires\x12>\n\x0crequire_type\x18\x01 \x01(\x0e\x32(.paddlehub.module.checkinfo.REQUIRE_TYPE\x12\x0f\n\x07version\x18\x02 \x01(\t\x12\x12\n\ngreat_than\x18\x03 \x01(\x08\x12\x13\n\x0b\x64\x65scription\x18\x04 \x01(\t\"\xc8\x01\n\tCheckInfo\x12\x16\n\x0epaddle_version\x18\x01 \x01(\t\x12\x13\n\x0bhub_version\x18\x02 \x01(\t\x12\x1c\n\x14module_proto_version\x18\x03 \x01(\t\x12\x38\n\nfile_infos\x18\x04 \x03(\x0b\x32$.paddlehub.module.checkinfo.FileInfo\x12\x36\n\x08requires\x18\x05 \x03(\x0b\x32$.paddlehub.module.checkinfo.Requires*\x1e\n\tFILE_TYPE\x12\x08\n\x04\x46ILE\x10\x00\x12\x07\n\x03\x44IR\x10\x01*[\n\x0cREQUIRE_TYPE\x12\x12\n\x0ePYTHON_PACKAGE\x10\x00\x12\x0e\n\nHUB_MODULE\x10\x01\x12\n\n\x06SYSTEM\x10\x02\x12\x0b\n\x07\x43OMMAND\x10\x03\x12\x0e\n\nPY_VERSION\x10\x04\x42\x02H\x03\x62\x06proto3' + '\n\x10\x63heck_info.proto\x12\x1apaddlehub.module.checkinfo\"\x85\x01\n\x08\x46ileInfo\x12\x11\n\tfile_name\x18\x01 \x01(\t\x12\x33\n\x04type\x18\x02 \x01(\x0e\x32%.paddlehub.module.checkinfo.FILE_TYPE\x12\x0f\n\x07is_need\x18\x03 \x01(\x08\x12\x0b\n\x03md5\x18\x04 \x01(\t\x12\x13\n\x0b\x64\x65scription\x18\x05 \x01(\t\"\x84\x01\n\x08Requires\x12>\n\x0crequire_type\x18\x01 \x01(\x0e\x32(.paddlehub.module.checkinfo.REQUIRE_TYPE\x12\x0f\n\x07version\x18\x02 \x01(\t\x12\x12\n\ngreat_than\x18\x03 \x01(\x08\x12\x13\n\x0b\x64\x65scription\x18\x04 \x01(\t\"\xe5\x01\n\tCheckInfo\x12\x16\n\x0epaddle_version\x18\x01 \x01(\t\x12\x13\n\x0bhub_version\x18\x02 \x01(\t\x12\x1c\n\x14module_proto_version\x18\x03 \x01(\t\x12\x1b\n\x13module_code_version\x18\x04 \x01(\t\x12\x38\n\nfile_infos\x18\x05 \x03(\x0b\x32$.paddlehub.module.checkinfo.FileInfo\x12\x36\n\x08requires\x18\x06 \x03(\x0b\x32$.paddlehub.module.checkinfo.Requires*\x1e\n\tFILE_TYPE\x12\x08\n\x04\x46ILE\x10\x00\x12\x07\n\x03\x44IR\x10\x01*[\n\x0cREQUIRE_TYPE\x12\x12\n\x0ePYTHON_PACKAGE\x10\x00\x12\x0e\n\nHUB_MODULE\x10\x01\x12\n\n\x06SYSTEM\x10\x02\x12\x0b\n\x07\x43OMMAND\x10\x03\x12\x0e\n\nPY_VERSION\x10\x04\x42\x02H\x03\x62\x06proto3' )) _sym_db.RegisterFileDescriptor(DESCRIPTOR) @@ -36,8 +35,8 @@ _FILE_TYPE = _descriptor.EnumDescriptor( ], containing_type=None, options=None, - serialized_start=522, - serialized_end=552, + serialized_start=551, + serialized_end=581, ) _sym_db.RegisterEnumDescriptor(_FILE_TYPE) @@ -61,8 +60,8 @@ _REQUIRE_TYPE = _descriptor.EnumDescriptor( ], containing_type=None, options=None, - serialized_start=554, - serialized_end=645, + serialized_start=583, + serialized_end=674, ) _sym_db.RegisterEnumDescriptor(_REQUIRE_TYPE) @@ -316,10 +315,26 @@ _CHECKINFO = _descriptor.Descriptor( extension_scope=None, options=None), _descriptor.FieldDescriptor( - name='file_infos', - full_name='paddlehub.module.checkinfo.CheckInfo.file_infos', + name='module_code_version', + full_name='paddlehub.module.checkinfo.CheckInfo.module_code_version', index=3, number=4, + type=9, + cpp_type=9, + label=1, + has_default_value=False, + default_value=_b("").decode('utf-8'), + message_type=None, + enum_type=None, + containing_type=None, + is_extension=False, + extension_scope=None, + options=None), + _descriptor.FieldDescriptor( + name='file_infos', + full_name='paddlehub.module.checkinfo.CheckInfo.file_infos', + index=4, + number=5, type=11, cpp_type=10, label=3, @@ -334,8 +349,8 @@ _CHECKINFO = _descriptor.Descriptor( _descriptor.FieldDescriptor( name='requires', full_name='paddlehub.module.checkinfo.CheckInfo.requires', - index=4, - number=5, + index=5, + number=6, type=11, cpp_type=10, label=3, @@ -357,7 +372,7 @@ _CHECKINFO = _descriptor.Descriptor( extension_ranges=[], oneofs=[], serialized_start=320, - serialized_end=520, + serialized_end=549, ) _FILEINFO.fields_by_name['type'].enum_type = _FILE_TYPE diff --git a/paddlehub/module/checker.py b/paddlehub/module/checker.py index d76ca6bd74ab4d354ba5fa72e8f1e7215c0ed6f0..b1470af4d16774688af53a0a7293691d30bc3e6c 100644 --- a/paddlehub/module/checker.py +++ b/paddlehub/module/checker.py @@ -32,20 +32,22 @@ FILE_SEP = "/" class ModuleChecker(object): - def __init__(self, module_path): - self.module_path = module_path + def __init__(self, directory): + self._directory = directory + self._pb_path = os.path.join(self.directory, CHECK_INFO_PB_FILENAME) def generate_check_info(self): check_info = check_info_pb2.CheckInfo() check_info.paddle_version = paddle.__version__ check_info.hub_version = hub_version check_info.module_proto_version = module_proto_version + check_info.module_code_version = "v2" file_infos = check_info.file_infos - file_list = [file for file in os.listdir(self.module_path)] + file_list = [file for file in os.listdir(self.directory)] while file_list: file = file_list[0] file_list = file_list[1:] - abs_path = os.path.join(self.module_path, file) + abs_path = os.path.join(self.directory, file) if os.path.isdir(abs_path): for sub_file in os.listdir(abs_path): sub_file = os.path.join(file, sub_file) @@ -62,9 +64,12 @@ class ModuleChecker(object): file_info.type = check_info_pb2.FILE file_info.is_need = True - with open(os.path.join(self.module_path, CHECK_INFO_PB_FILENAME), - "wb") as fi: - fi.write(check_info.SerializeToString()) + with open(self.pb_path, "wb") as file: + file.write(check_info.SerializeToString()) + + @property + def module_code_version(self): + return self.check_info.module_code_version @property def module_proto_version(self): @@ -82,20 +87,25 @@ class ModuleChecker(object): def file_infos(self): return self.check_info.file_infos + @property + def directory(self): + return self._directory + + @property + def pb_path(self): + return self._pb_path + def check(self): result = True - self.check_info_pb_path = os.path.join(self.module_path, - CHECK_INFO_PB_FILENAME) - if not (os.path.exists(self.check_info_pb_path) - or os.path.isfile(self.check_info_pb_path)): + if not (os.path.exists(self.pb_path) or os.path.isfile(self.pb_path)): logger.warning( "This module lacks core file %s" % CHECK_INFO_PB_FILENAME) result = False self.check_info = check_info_pb2.CheckInfo() try: - with open(self.check_info_pb_path, "rb") as fi: + with open(self.pb_path, "rb") as fi: pb_string = fi.read() result = self.check_info.ParseFromString(pb_string) if len(pb_string) == 0 or (result is not None @@ -182,7 +192,7 @@ class ModuleChecker(object): for file_info in self.file_infos: file_type = file_info.type file_path = file_info.file_name.replace(FILE_SEP, os.sep) - file_path = os.path.join(self.module_path, file_path) + file_path = os.path.join(self.directory, file_path) if not os.path.exists(file_path): if file_info.is_need: logger.warning( diff --git a/paddlehub/module/manager.py b/paddlehub/module/manager.py index e8f5d653f6ebee8c9a9f07af657d471d99057c4f..9e6caf182a17ab6cf95c2739578ee821951679b1 100644 --- a/paddlehub/module/manager.py +++ b/paddlehub/module/manager.py @@ -19,6 +19,7 @@ from __future__ import print_function import os import shutil +import tarfile from paddlehub.common import utils from paddlehub.common import srv_utils @@ -77,15 +78,76 @@ class LocalModuleManager(object): return self.modules_dict.get(module_name, None) def install_module(self, - module_name, + module_name=None, + module_dir=None, + module_package=None, module_version=None, upgrade=False, extra=None): - self.all_modules(update=True) - module_info = self.modules_dict.get(module_name, None) - if module_info: - if not module_version or module_version == self.modules_dict[ - module_name][1]: + md5_value = installed_module_version = None + from_user_dir = True if module_dir else False + if module_name: + self.all_modules(update=True) + module_info = self.modules_dict.get(module_name, None) + if module_info: + if not module_version or module_version == self.modules_dict[ + module_name][1]: + module_dir = self.modules_dict[module_name][0] + module_tag = module_name if not module_version else '%s-%s' % ( + module_name, module_version) + tips = "Module %s already installed in %s" % (module_tag, + module_dir) + return True, tips, self.modules_dict[module_name] + + search_result = hub.default_hub_server.get_module_url( + module_name, version=module_version, extra=extra) + name = search_result.get('name', None) + url = search_result.get('url', None) + md5_value = search_result.get('md5', None) + installed_module_version = search_result.get('version', None) + if not url or (module_version is not None + and installed_module_version != module_version) or ( + name != module_name): + if default_hub_server._server_check() is False: + tips = "Request Hub-Server unsuccessfully, please check your network." + else: + tips = "Can't find module %s" % module_name + if module_version: + tips += " with version %s" % module_version + module_tag = module_name if not module_version else '%s-%s' % ( + module_name, module_version) + return False, tips, None + + result, tips, module_zip_file = default_downloader.download_file( + url=url, + save_path=hub.CACHE_HOME, + save_name=module_name, + replace=True, + print_progress=True) + result, tips, module_dir = default_downloader.uncompress( + file=module_zip_file, + dirname=MODULE_HOME, + delete_file=True, + print_progress=True) + + if module_package: + with tarfile.open(module_package, "r:gz") as tar: + file_names = tar.getnames() + size = len(file_names) - 1 + module_dir = os.path.split(file_names[0])[0] + module_dir = os.path.join(hub.CACHE_HOME, module_dir) + # remove cache + if os.path.exists(module_dir): + shutil.rmtree(module_dir) + for index, file_name in enumerate(file_names): + tar.extract(file_name, hub.CACHE_HOME) + + if module_dir: + if not module_name: + module_name = hub.Module(directory=module_dir).name + self.all_modules(update=False) + module_info = self.modules_dict.get(module_name, None) + if module_info: module_dir = self.modules_dict[module_name][0] module_tag = module_name if not module_version else '%s-%s' % ( module_name, module_version) @@ -93,44 +155,18 @@ class LocalModuleManager(object): module_dir) return True, tips, self.modules_dict[module_name] - search_result = hub.default_hub_server.get_module_url( - module_name, version=module_version, extra=extra) - name = search_result.get('name', None) - url = search_result.get('url', None) - md5_value = search_result.get('md5', None) - installed_module_version = search_result.get('version', None) - if not url or (module_version is not None and installed_module_version - != module_version) or (name != module_name): - if default_hub_server._server_check() is False: - tips = "Request Hub-Server unsuccessfully, please check your network." - else: - tips = "Can't find module %s" % module_name - if module_version: - tips += " with version %s" % module_version - module_tag = module_name if not module_version else '%s-%s' % ( - module_name, module_version) - return False, tips, None - - result, tips, module_zip_file = default_downloader.download_file( - url=url, - save_path=hub.CACHE_HOME, - save_name=module_name, - replace=True, - print_progress=True) - result, tips, module_dir = default_downloader.uncompress( - file=module_zip_file, - dirname=MODULE_HOME, - delete_file=True, - print_progress=True) - - if module_dir: - with open(os.path.join(MODULE_HOME, module_dir, "md5.txt"), - "w") as fp: - fp.write(md5_value) + if md5_value: + with open( + os.path.join(MODULE_HOME, module_dir, "md5.txt"), + "w") as fp: + fp.write(md5_value) save_path = os.path.join(MODULE_HOME, module_name) if os.path.exists(save_path): - shutil.rmtree(save_path) - shutil.move(module_dir, save_path) + shutil.move(save_path) + if from_user_dir: + shutil.copytree(module_dir, save_path) + else: + shutil.move(module_dir, save_path) module_dir = save_path tips = "Successfully installed %s" % module_name if installed_module_version: diff --git a/paddlehub/module/module.py b/paddlehub/module/module.py index 18fcec7366c11042a305ca93ac6ea2d25b3a81a0..5260e108b13c637adecc5a02edfe055f3966a1ec 100644 --- a/paddlehub/module/module.py +++ b/paddlehub/module/module.py @@ -21,6 +21,10 @@ import os import time import sys import functools +import inspect +import importlib +import tarfile +from collections import defaultdict from shutil import copyfile import paddle @@ -28,22 +32,19 @@ import paddle.fluid as fluid from paddlehub.common import utils from paddlehub.common import paddle_helper -from paddlehub.common.logger import logger +from paddlehub.common.dir import CACHE_HOME from paddlehub.common.lock import lock -from paddlehub.common.downloader import default_downloader +from paddlehub.common.logger import logger +from paddlehub.common.hub_server import CacheUpdater from paddlehub.module import module_desc_pb2 -from paddlehub.common.dir import CONF_HOME from paddlehub.module import check_info_pb2 -from paddlehub.common.hub_server import CacheUpdater -from paddlehub.module.signature import Signature, create_signature -from paddlehub.module.checker import ModuleChecker from paddlehub.module.manager import default_module_manager +from paddlehub.module.checker import ModuleChecker +from paddlehub.module.signature import Signature, create_signature from paddlehub.module.base_processor import BaseProcessor from paddlehub.io.parser import yaml_parser from paddlehub import version -__all__ = ['Module', 'create_module'] - # PaddleHub module dir name ASSETS_DIRNAME = "assets" MODEL_DIRNAME = "model" @@ -52,67 +53,226 @@ PYTHON_DIR = "python" PROCESSOR_NAME = "processor" # PaddleHub var prefix HUB_VAR_PREFIX = "@HUB_%s@" +# PaddleHub Module package suffix +HUB_PACKAGE_SUFFIX = "phm" + + +def create_module(directory, name, author, email, module_type, summary, + version): + save_file_name = "{}-{}.{}".format(name, version, HUB_PACKAGE_SUFFIX) + + # record module info and serialize + desc = module_desc_pb2.ModuleDesc() + attr = desc.attr + attr.type = module_desc_pb2.MAP + module_info = attr.map.data['module_info'] + module_info.type = module_desc_pb2.MAP + utils.from_pyobj_to_module_attr(name, module_info.map.data['name']) + utils.from_pyobj_to_module_attr(author, module_info.map.data['author']) + utils.from_pyobj_to_module_attr(email, module_info.map.data['author_email']) + utils.from_pyobj_to_module_attr(module_type, module_info.map.data['type']) + utils.from_pyobj_to_module_attr(summary, module_info.map.data['summary']) + utils.from_pyobj_to_module_attr(version, module_info.map.data['version']) + + module_desc_path = os.path.join(directory, "module_desc.pb") + with open(module_desc_path, "wb") as f: + f.write(desc.SerializeToString()) + + # generate check info + checker = ModuleChecker(directory) + checker.generate_check_info() + + # add __init__ + module_init_1 = os.path.join(directory, "__init__.py") + with open(module_init_1, "a") as file: + file.write("") + + module_init_2 = os.path.join(directory, "python", "__init__.py") + with open(module_init_2, "a") as file: + file.write("") + + # package the module + with tarfile.open(save_file_name, "w:gz") as tar: + for dirname, _, files in os.walk(directory): + for file in files: + tar.add(os.path.join(dirname, file)) + + os.remove(module_desc_path) + os.remove(checker.pb_path) + os.remove(module_init_1) + os.remove(module_init_2) + + +class Module(object): + def __new__(cls, name=None, directory=None, module_dir=None, version=None): + module = None + + if cls.__name__ == "Module": + if name: + module = cls.init_with_name(name=name, version=version) + elif directory: + module = cls.init_with_directory(directory=directory) + elif module_dir: + logger.warning( + "Parameter module_dir is deprecated, please use directory to specify the path" + ) + if isinstance(module_dir, list) or isinstance( + module_dir, tuple): + directory = module_dir[0] + version = module_dir[1] + else: + directory = module_dir + module = cls.init_with_directory(directory=directory) + + if not module: + module = object.__new__(cls) + CacheUpdater(name, version).start() + return module + + def __init__(self, name=None, directory=None, module_dir=None, + version=None): + if not directory: + return + self._code_version = "v2" + self._directory = directory + self.module_desc_path = os.path.join(self.directory, MODULE_DESC_PBNAME) + self._desc = module_desc_pb2.ModuleDesc() + with open(self.module_desc_path, "rb") as file: + self._desc.ParseFromString(file.read()) + + module_info = self.desc.attr.map.data['module_info'] + self._name = utils.from_module_attr_to_pyobj( + module_info.map.data['name']) + self._author = utils.from_module_attr_to_pyobj( + module_info.map.data['author']) + self._author_email = utils.from_module_attr_to_pyobj( + module_info.map.data['author_email']) + self._version = utils.from_module_attr_to_pyobj( + module_info.map.data['version']) + self._type = utils.from_module_attr_to_pyobj( + module_info.map.data['type']) + self._summary = utils.from_module_attr_to_pyobj( + module_info.map.data['summary']) + + self._initialize() + + @classmethod + def init_with_name(cls, name, version=None): + fp_lock = open(os.path.join(CACHE_HOME, name), "a") + lock.flock(fp_lock, lock.LOCK_EX) + log_msg = "Installing %s module" % name + if version: + log_msg += "-%s" % version + logger.info(log_msg) + extra = {"command": "install"} + result, tips, module_dir = default_module_manager.install_module( + module_name=name, module_version=version, extra=extra) + if not result: + logger.error(tips) + raise RuntimeError(tips) + logger.info(tips) + lock.flock(fp_lock, lock.LOCK_UN) + return cls.init_with_directory(directory=module_dir[0]) -def create_module(sign_arr, - module_dir, - processor=None, - assets=None, - module_info=None, - exe=None, - extra_info=None): - sign_arr = utils.to_list(sign_arr) - module = Module( - signatures=sign_arr, - processor=processor, - assets=assets, - module_info=module_info, - extra_info=extra_info) - module.serialize_to_path(path=module_dir, exe=exe) + @classmethod + def init_with_directory(cls, directory): + desc_file = os.path.join(directory, MODULE_DESC_PBNAME) + checker = ModuleChecker(directory) + checker.check() + + module_code_version = checker.module_code_version + if module_code_version == "v2": + basename = os.path.split(directory)[-1] + dirname = os.path.join(*list(os.path.split(directory)[:-1])) + sys.path.append(dirname) + pymodule = importlib.import_module( + "{}.python.module".format(basename)) + return pymodule.HubModule(directory=directory) + return ModuleV1(directory=directory) + + @property + def desc(self): + return self._desc + + @property + def directory(self): + return self._directory + + @property + def author(self): + return self._author + + @property + def author_email(self): + return self._author_email + + @property + def summary(self): + return self._summary + + @property + def type(self): + return self._type + + @property + def version(self): + return self._version + + @property + def name(self): + return self._name + + @property + def name_prefix(self): + return self._name_prefix + + @property + def code_version(self): + return self._code_version + + @property + def is_runable(self): + return False + + def _initialize(self): + pass class ModuleHelper(object): - def __init__(self, module_dir): - self.module_dir = module_dir + def __init__(self, directory): + self.directory = directory def module_desc_path(self): - return os.path.join(self.module_dir, MODULE_DESC_PBNAME) + return os.path.join(self.directory, MODULE_DESC_PBNAME) def model_path(self): - return os.path.join(self.module_dir, MODEL_DIRNAME) + return os.path.join(self.directory, MODEL_DIRNAME) def processor_path(self): - return os.path.join(self.module_dir, PYTHON_DIR) + return os.path.join(self.directory, PYTHON_DIR) def processor_name(self): return PROCESSOR_NAME def assets_path(self): - return os.path.join(self.module_dir, ASSETS_DIRNAME) + return os.path.join(self.directory, ASSETS_DIRNAME) -class Module(object): - def __init__(self, - name=None, - module_dir=None, - signatures=None, - module_info=None, - assets=None, - processor=None, - extra_info=None, +class ModuleV1(Module): + def __init__(self, name=None, directory=None, module_dir=None, version=None): - self.desc = module_desc_pb2.ModuleDesc() + if not directory: + return + super(ModuleV1, self).__init__(name, directory, module_dir, version) + self._code_version = "v1" self.program = None self.assets = [] self.helper = None self.signatures = {} self.default_signature = None - self.module_info = None self.processor = None - self.extra_info = {} if extra_info is None else extra_info - if not isinstance(self.extra_info, dict): - raise TypeError( - "The extra_info should be an instance of python dict") + self.extra_info = {} # cache data self.last_call_name = None @@ -120,62 +280,21 @@ class Module(object): self.cache_fetch_dict = None self.cache_program = None - fp_lock = open(os.path.join(CONF_HOME, 'config.json')) - lock.flock(fp_lock, lock.LOCK_EX) - if name: - self._init_with_name(name=name, version=version) - lock.flock(fp_lock, lock.LOCK_UN) - elif module_dir: - self._init_with_module_file(module_dir=module_dir[0]) - lock.flock(fp_lock, lock.LOCK_UN) - name = module_dir[0].split("/")[-1] - if len(module_dir) > 1: - version = module_dir[1] - else: - version = default_module_manager.search_module(name)[1] - elif signatures: - if processor: - if not issubclass(processor, BaseProcessor): - raise TypeError( - "Processor shoule be an instance of paddlehub.BaseProcessor" - ) - if assets: - self.assets = utils.to_list(assets) - # for asset in assets: - # utils.check_path(assets) - self.processor = processor - self._generate_module_info(module_info) - self._init_with_signature(signatures=signatures) - lock.flock(fp_lock, lock.LOCK_UN) - else: - lock.flock(fp_lock, lock.LOCK_UN) - raise ValueError("Module initialized parameter is empty") - CacheUpdater(name, version).start() - - def _init_with_name(self, name, version=None): - log_msg = "Installing %s module" % name - if version: - log_msg += "-%s" % version - logger.info(log_msg) - extra = {"command": "install"} - result, tips, module_dir = default_module_manager.install_module( - module_name=name, module_version=version, extra=extra) - if not result: - logger.error(tips) - raise RuntimeError(tips) - else: - logger.info(tips) - self._init_with_module_file(module_dir[0]) - - def _init_with_url(self, url): - utils.check_url(url) - result, tips, module_dir = default_downloader.download_file_and_uncompress( - url, save_path=".") - if not result: - logger.error(tips) - raise RuntimeError(tips) - else: - self._init_with_module_file(module_dir) + self.helper = ModuleHelper(directory) + exe = fluid.Executor(fluid.CPUPlace()) + self.program, _, _ = fluid.io.load_inference_model( + self.helper.model_path(), executor=exe) + for block in self.program.blocks: + for op in block.ops: + if "op_callstack" in op.all_attrs(): + op._set_attr("op_callstack", [""]) + self._load_processor() + self._load_assets() + self._recover_from_desc() + self._generate_sign_attr() + self._generate_extra_info() + self._restore_parameter(self.program) + self._recover_variable_info(self.program) def _dump_processor(self): import inspect @@ -216,52 +335,6 @@ class Module(object): filepath = os.path.join(self.helper.assets_path(), file) self.assets.append(filepath) - def _init_with_module_file(self, module_dir): - checker = ModuleChecker(module_dir) - checker.check() - - self.helper = ModuleHelper(module_dir) - with open(self.helper.module_desc_path(), "rb") as fi: - self.desc.ParseFromString(fi.read()) - - exe = fluid.Executor(fluid.CPUPlace()) - self.program, _, _ = fluid.io.load_inference_model( - self.helper.model_path(), executor=exe) - for block in self.program.blocks: - for op in block.ops: - if "op_callstack" in op.all_attrs(): - op._set_attr("op_callstack", [""]) - self._load_processor() - self._load_assets() - self._recover_from_desc() - self._generate_sign_attr() - self._generate_extra_info() - self._restore_parameter(self.program) - self._recover_variable_info(self.program) - - def _init_with_signature(self, signatures): - self.name_prefix = HUB_VAR_PREFIX % self.name - self._process_signatures(signatures) - self._check_signatures() - self._generate_desc() - self._generate_sign_attr() - self._generate_extra_info() - - def _init_with_program(self, program): - pass - - def _process_signatures(self, signatures): - self.signatures = {} - self.program = signatures[0].inputs[0].block.program - for sign in signatures: - if sign.name in self.signatures: - raise ValueError( - "Error! Signature array contains duplicated signatrues %s" % - sign) - if self.default_signature is None and sign.for_predict: - self.default_signature = sign - self.signatures[sign.name] = sign - def _restore_parameter(self, program): global_block = program.global_block() param_attrs = self.desc.attr.map.data['param_attrs'] @@ -302,21 +375,6 @@ class Module(object): self.__dict__["get_%s" % key] = functools.partial( self.get_extra_info, key=key) - def _generate_module_info(self, module_info=None): - if not module_info: - self.module_info = {} - else: - if not utils.is_yaml_file(module_info): - logger.critical("Module info file should be yaml format") - exit(1) - self.module_info = yaml_parser.parse(module_info) - self.author = self.module_info.get('author', 'UNKNOWN') - self.author_email = self.module_info.get('author_email', 'UNKNOWN') - self.summary = self.module_info.get('summary', 'UNKNOWN') - self.type = self.module_info.get('type', 'UNKNOWN') - self.version = self.module_info.get('version', 'UNKNOWN') - self.name = self.module_info.get('name', 'UNKNOWN') - def _generate_sign_attr(self): self._check_signatures() for sign in self.signatures: @@ -369,21 +427,21 @@ class Module(object): default_signature_name = utils.from_module_attr_to_pyobj( self.desc.attr.map.data['default_signature']) self.default_signature = self.signatures[ - default_signature_name] if default_signature_name else None + default_signature_name].name if default_signature_name else None # recover module info module_info = self.desc.attr.map.data['module_info'] - self.name = utils.from_module_attr_to_pyobj( + self._name = utils.from_module_attr_to_pyobj( module_info.map.data['name']) - self.author = utils.from_module_attr_to_pyobj( + self._author = utils.from_module_attr_to_pyobj( module_info.map.data['author']) - self.author_email = utils.from_module_attr_to_pyobj( + self._author_email = utils.from_module_attr_to_pyobj( module_info.map.data['author_email']) - self.version = utils.from_module_attr_to_pyobj( + self._version = utils.from_module_attr_to_pyobj( module_info.map.data['version']) - self.type = utils.from_module_attr_to_pyobj( + self._type = utils.from_module_attr_to_pyobj( module_info.map.data['type']) - self.summary = utils.from_module_attr_to_pyobj( + self._summary = utils.from_module_attr_to_pyobj( module_info.map.data['summary']) # recover extra info @@ -393,77 +451,9 @@ class Module(object): self.extra_info[key] = utils.from_module_attr_to_pyobj(value) # recover name prefix - self.name_prefix = utils.from_module_attr_to_pyobj( + self._name_prefix = utils.from_module_attr_to_pyobj( self.desc.attr.map.data["name_prefix"]) - def _generate_desc(self): - # save fluid Parameter - attr = self.desc.attr - attr.type = module_desc_pb2.MAP - param_attrs = attr.map.data['param_attrs'] - param_attrs.type = module_desc_pb2.MAP - for param in self.program.global_block().iter_parameters(): - param_attr = param_attrs.map.data[param.name] - paddle_helper.from_param_to_module_attr(param, param_attr) - - # save Variable Info - var_infos = attr.map.data['var_infos'] - var_infos.type = module_desc_pb2.MAP - for block in self.program.blocks: - for var in block.vars.values(): - var_info = var_infos.map.data[var.name] - var_info.type = module_desc_pb2.MAP - utils.from_pyobj_to_module_attr( - var.stop_gradient, var_info.map.data['stop_gradient']) - utils.from_pyobj_to_module_attr(block.idx, - var_info.map.data['block_id']) - - # save signarture info - for key, sign in self.signatures.items(): - var = self.desc.sign2var[sign.name] - feed_desc = var.feed_desc - fetch_desc = var.fetch_desc - feed_names = sign.feed_names - fetch_names = sign.fetch_names - for index, input in enumerate(sign.inputs): - feed_var = feed_desc.add() - feed_var.var_name = self.get_var_name_with_prefix(input.name) - feed_var.alias = feed_names[index] - - for index, output in enumerate(sign.outputs): - fetch_var = fetch_desc.add() - fetch_var.var_name = self.get_var_name_with_prefix(output.name) - fetch_var.alias = fetch_names[index] - - # save default signature - utils.from_pyobj_to_module_attr( - self.default_signature.name if self.default_signature else None, - attr.map.data['default_signature']) - - # save name prefix - utils.from_pyobj_to_module_attr(self.name_prefix, - self.desc.attr.map.data["name_prefix"]) - - # save module info - module_info = attr.map.data['module_info'] - module_info.type = module_desc_pb2.MAP - utils.from_pyobj_to_module_attr(self.name, module_info.map.data['name']) - utils.from_pyobj_to_module_attr(self.version, - module_info.map.data['version']) - utils.from_pyobj_to_module_attr(self.author, - module_info.map.data['author']) - utils.from_pyobj_to_module_attr(self.author_email, - module_info.map.data['author_email']) - utils.from_pyobj_to_module_attr(self.type, module_info.map.data['type']) - utils.from_pyobj_to_module_attr(self.summary, - module_info.map.data['summary']) - - # save extra info - extra_info = attr.map.data['extra_info'] - extra_info.type = module_desc_pb2.MAP - for key, value in self.extra_info.items(): - utils.from_pyobj_to_module_attr(value, extra_info.map.data[key]) - def __call__(self, sign_name, data, use_gpu=False, batch_size=1, **kwargs): self.check_processor() @@ -525,6 +515,10 @@ class Module(object): if not self.processor: raise ValueError("This Module is not callable!") + @property + def is_runable(self): + return self.default_signature != None + def context(self, sign_name=None, for_test=False, @@ -664,93 +658,3 @@ class Module(object): raise ValueError( "All input and outputs variables in signature should come from the same Program" ) - - def serialize_to_path(self, path=None, exe=None): - self._check_signatures() - self._generate_desc() - # create module path for saving - if path is None: - path = os.path.join(".", self.name) - self.helper = ModuleHelper(path) - utils.mkdir(self.helper.module_dir) - - # create module pb - module_desc = module_desc_pb2.ModuleDesc() - logger.info("PaddleHub version = %s" % version.hub_version) - logger.info("PaddleHub Module proto version = %s" % - version.module_proto_version) - logger.info("Paddle version = %s" % paddle.__version__) - - feeded_var_names = [ - input.name for key, sign in self.signatures.items() - for input in sign.inputs - ] - target_vars = [ - output for key, sign in self.signatures.items() - for output in sign.outputs - ] - feeded_var_names = list(set(feeded_var_names)) - target_vars = list(set(target_vars)) - - # save inference program - program = self.program.clone() - - for block in program.blocks: - for op in block.ops: - if "op_callstack" in op.all_attrs(): - op._set_attr("op_callstack", [""]) - - if not exe: - place = fluid.CPUPlace() - exe = fluid.Executor(place=place) - utils.mkdir(self.helper.model_path()) - fluid.io.save_inference_model( - self.helper.model_path(), - feeded_var_names=list(feeded_var_names), - target_vars=list(target_vars), - main_program=program, - executor=exe) - - with open(os.path.join(self.helper.model_path(), "__model__"), - "rb") as file: - program_desc_str = file.read() - rename_program = fluid.framework.Program.parse_from_string( - program_desc_str) - varlist = { - var: block - for block in rename_program.blocks for var in block.vars - if self.get_name_prefix() not in var - } - for var, block in varlist.items(): - old_name = var - new_name = self.get_var_name_with_prefix(old_name) - block._rename_var(old_name, new_name) - utils.mkdir(self.helper.model_path()) - with open( - os.path.join(self.helper.model_path(), "__model__"), - "wb") as f: - f.write(rename_program.desc.serialize_to_string()) - - for file in os.listdir(self.helper.model_path()): - if (file == "__model__" or self.get_name_prefix() in file): - continue - os.rename( - os.path.join(self.helper.model_path(), file), - os.path.join(self.helper.model_path(), - self.get_var_name_with_prefix(file))) - - # create processor file - if self.processor: - self._dump_processor() - - # create assets - self._dump_assets() - - # create check info - checker = ModuleChecker(self.helper.module_dir) - checker.generate_check_info() - - # Serialize module_desc pb - module_pb = self.desc.SerializeToString() - with open(self.helper.module_desc_path(), "wb") as f: - f.write(module_pb)