未验证 提交 a25545de 编写于 作者: W wuzewu 提交者: GitHub

Add Module V2 support (#274)

* Add module v2
上级 b2dc77ed
...@@ -38,7 +38,7 @@ from .common.logger import logger ...@@ -38,7 +38,7 @@ from .common.logger import logger
from .common.paddle_helper import connect_program from .common.paddle_helper import connect_program
from .common.hub_server import default_hub_server from .common.hub_server import default_hub_server
from .module.module import Module, create_module from .module.module import Module
from .module.base_processor import BaseProcessor from .module.base_processor import BaseProcessor
from .module.signature import Signature, create_signature from .module.signature import Signature, create_signature
from .module.manager import default_module_manager from .module.manager import default_module_manager
......
...@@ -18,6 +18,7 @@ from __future__ import division ...@@ -18,6 +18,7 @@ from __future__ import division
from __future__ import print_function from __future__ import print_function
import argparse import argparse
import os
from paddlehub.common import utils from paddlehub.common import utils
from paddlehub.module.manager import default_module_manager from paddlehub.module.manager import default_module_manager
...@@ -42,14 +43,23 @@ class InstallCommand(BaseCommand): ...@@ -42,14 +43,23 @@ class InstallCommand(BaseCommand):
print("ERROR: Please specify a module name.\n") print("ERROR: Please specify a module name.\n")
self.help() self.help()
return False return False
extra = {"command": "install"}
if argv[0].endswith("tar.gz") or argv[0].endswith("phm"):
result, tips, module_dir = default_module_manager.install_module(
module_package=argv[0], extra=extra)
elif os.path.exists(argv[0]) and os.path.isdir(argv[0]):
result, tips, module_dir = default_module_manager.install_module(
module_dir=argv[0], extra=extra)
else:
module_name = argv[0] module_name = argv[0]
module_version = None if "==" not in module_name else module_name.split( module_version = None if "==" not in module_name else module_name.split(
"==")[1] "==")[1]
module_name = module_name if "==" not in module_name else module_name.split( module_name = module_name if "==" not in module_name else module_name.split(
"==")[0] "==")[0]
extra = {"command": "install"}
result, tips, module_dir = default_module_manager.install_module( result, tips, module_dir = default_module_manager.install_module(
module_name=module_name, module_version=module_version, extra=extra) module_name=module_name,
module_version=module_version,
extra=extra)
print(tips) print(tips)
return True return True
......
...@@ -71,7 +71,7 @@ class RunCommand(BaseCommand): ...@@ -71,7 +71,7 @@ class RunCommand(BaseCommand):
if not result: if not result:
return None return None
return hub.Module(module_dir=module_dir) return hub.Module(directory=module_dir[0])
def add_module_config_arg(self): def add_module_config_arg(self):
configs = self.module.processor.configs() configs = self.module.processor.configs()
...@@ -105,7 +105,7 @@ class RunCommand(BaseCommand): ...@@ -105,7 +105,7 @@ class RunCommand(BaseCommand):
def add_module_input_arg(self): def add_module_input_arg(self):
module_type = self.module.type.lower() module_type = self.module.type.lower()
expect_data_format = self.module.processor.data_format( expect_data_format = self.module.processor.data_format(
self.module.default_signature.name) self.module.default_signature)
self.arg_input_group.add_argument( self.arg_input_group.add_argument(
'--input_file', '--input_file',
type=str, type=str,
...@@ -152,7 +152,7 @@ class RunCommand(BaseCommand): ...@@ -152,7 +152,7 @@ class RunCommand(BaseCommand):
def get_data(self): def get_data(self):
module_type = self.module.type.lower() module_type = self.module.type.lower()
expect_data_format = self.module.processor.data_format( expect_data_format = self.module.processor.data_format(
self.module.default_signature.name) self.module.default_signature)
input_data = {} input_data = {}
if len(expect_data_format) == 1: if len(expect_data_format) == 1:
key = list(expect_data_format.keys())[0] key = list(expect_data_format.keys())[0]
...@@ -177,7 +177,7 @@ class RunCommand(BaseCommand): ...@@ -177,7 +177,7 @@ class RunCommand(BaseCommand):
def check_data(self, data): def check_data(self, data):
expect_data_format = self.module.processor.data_format( expect_data_format = self.module.processor.data_format(
self.module.default_signature.name) self.module.default_signature)
if len(data.keys()) != len(expect_data_format.keys()): if len(data.keys()) != len(expect_data_format.keys()):
print( print(
...@@ -236,10 +236,13 @@ class RunCommand(BaseCommand): ...@@ -236,10 +236,13 @@ class RunCommand(BaseCommand):
return False return False
# If the module is not executable, give an alarm and exit # If the module is not executable, give an alarm and exit
if not self.module.default_signature: if not self.module.is_runable:
print("ERROR! Module %s is not executable." % module_name) print("ERROR! Module %s is not executable." % module_name)
return False return False
if self.module.code_version == "v2":
results = self.module(argv[1:])
else:
self.module.check_processor() self.module.check_processor()
self.add_module_config_arg() self.add_module_config_arg()
self.add_module_input_arg() self.add_module_input_arg()
...@@ -260,7 +263,7 @@ class RunCommand(BaseCommand): ...@@ -260,7 +263,7 @@ class RunCommand(BaseCommand):
return False return False
results = self.module( results = self.module(
sign_name=self.module.default_signature.name, sign_name=self.module.default_signature,
data=data, data=data,
use_gpu=self.args.use_gpu, use_gpu=self.args.use_gpu,
batch_size=self.args.batch_size, batch_size=self.args.batch_size,
......
...@@ -125,8 +125,6 @@ class ShowCommand(BaseCommand): ...@@ -125,8 +125,6 @@ class ShowCommand(BaseCommand):
cwd = os.getcwd() cwd = os.getcwd()
module_dir = default_module_manager.search_module(module_name) module_dir = default_module_manager.search_module(module_name)
module_dir = (os.path.join(cwd, module_name),
None) if not module_dir else module_dir
if not module_dir or not os.path.exists(module_dir[0]): if not module_dir or not os.path.exists(module_dir[0]):
print("%s is not existed!" % module_name) print("%s is not existed!" % module_name)
return True return True
......
...@@ -50,6 +50,7 @@ message CheckInfo { ...@@ -50,6 +50,7 @@ message CheckInfo {
string paddle_version = 1; string paddle_version = 1;
string hub_version = 2; string hub_version = 2;
string module_proto_version = 3; string module_proto_version = 3;
repeated FileInfo file_infos = 4; string module_code_version = 4;
repeated Requires requires = 5; repeated FileInfo file_infos = 5;
repeated Requires requires = 6;
}; };
#coding:utf-8
# Generated by the protocol buffer compiler. DO NOT EDIT! # Generated by the protocol buffer compiler. DO NOT EDIT!
# source: check_info.proto # source: check_info.proto
...@@ -19,7 +18,7 @@ DESCRIPTOR = _descriptor.FileDescriptor( ...@@ -19,7 +18,7 @@ DESCRIPTOR = _descriptor.FileDescriptor(
package='paddlehub.module.checkinfo', package='paddlehub.module.checkinfo',
syntax='proto3', syntax='proto3',
serialized_pb=_b( serialized_pb=_b(
'\n\x10\x63heck_info.proto\x12\x1apaddlehub.module.checkinfo\"\x85\x01\n\x08\x46ileInfo\x12\x11\n\tfile_name\x18\x01 \x01(\t\x12\x33\n\x04type\x18\x02 \x01(\x0e\x32%.paddlehub.module.checkinfo.FILE_TYPE\x12\x0f\n\x07is_need\x18\x03 \x01(\x08\x12\x0b\n\x03md5\x18\x04 \x01(\t\x12\x13\n\x0b\x64\x65scription\x18\x05 \x01(\t\"\x84\x01\n\x08Requires\x12>\n\x0crequire_type\x18\x01 \x01(\x0e\x32(.paddlehub.module.checkinfo.REQUIRE_TYPE\x12\x0f\n\x07version\x18\x02 \x01(\t\x12\x12\n\ngreat_than\x18\x03 \x01(\x08\x12\x13\n\x0b\x64\x65scription\x18\x04 \x01(\t\"\xc8\x01\n\tCheckInfo\x12\x16\n\x0epaddle_version\x18\x01 \x01(\t\x12\x13\n\x0bhub_version\x18\x02 \x01(\t\x12\x1c\n\x14module_proto_version\x18\x03 \x01(\t\x12\x38\n\nfile_infos\x18\x04 \x03(\x0b\x32$.paddlehub.module.checkinfo.FileInfo\x12\x36\n\x08requires\x18\x05 \x03(\x0b\x32$.paddlehub.module.checkinfo.Requires*\x1e\n\tFILE_TYPE\x12\x08\n\x04\x46ILE\x10\x00\x12\x07\n\x03\x44IR\x10\x01*[\n\x0cREQUIRE_TYPE\x12\x12\n\x0ePYTHON_PACKAGE\x10\x00\x12\x0e\n\nHUB_MODULE\x10\x01\x12\n\n\x06SYSTEM\x10\x02\x12\x0b\n\x07\x43OMMAND\x10\x03\x12\x0e\n\nPY_VERSION\x10\x04\x42\x02H\x03\x62\x06proto3' '\n\x10\x63heck_info.proto\x12\x1apaddlehub.module.checkinfo\"\x85\x01\n\x08\x46ileInfo\x12\x11\n\tfile_name\x18\x01 \x01(\t\x12\x33\n\x04type\x18\x02 \x01(\x0e\x32%.paddlehub.module.checkinfo.FILE_TYPE\x12\x0f\n\x07is_need\x18\x03 \x01(\x08\x12\x0b\n\x03md5\x18\x04 \x01(\t\x12\x13\n\x0b\x64\x65scription\x18\x05 \x01(\t\"\x84\x01\n\x08Requires\x12>\n\x0crequire_type\x18\x01 \x01(\x0e\x32(.paddlehub.module.checkinfo.REQUIRE_TYPE\x12\x0f\n\x07version\x18\x02 \x01(\t\x12\x12\n\ngreat_than\x18\x03 \x01(\x08\x12\x13\n\x0b\x64\x65scription\x18\x04 \x01(\t\"\xe5\x01\n\tCheckInfo\x12\x16\n\x0epaddle_version\x18\x01 \x01(\t\x12\x13\n\x0bhub_version\x18\x02 \x01(\t\x12\x1c\n\x14module_proto_version\x18\x03 \x01(\t\x12\x1b\n\x13module_code_version\x18\x04 \x01(\t\x12\x38\n\nfile_infos\x18\x05 \x03(\x0b\x32$.paddlehub.module.checkinfo.FileInfo\x12\x36\n\x08requires\x18\x06 \x03(\x0b\x32$.paddlehub.module.checkinfo.Requires*\x1e\n\tFILE_TYPE\x12\x08\n\x04\x46ILE\x10\x00\x12\x07\n\x03\x44IR\x10\x01*[\n\x0cREQUIRE_TYPE\x12\x12\n\x0ePYTHON_PACKAGE\x10\x00\x12\x0e\n\nHUB_MODULE\x10\x01\x12\n\n\x06SYSTEM\x10\x02\x12\x0b\n\x07\x43OMMAND\x10\x03\x12\x0e\n\nPY_VERSION\x10\x04\x42\x02H\x03\x62\x06proto3'
)) ))
_sym_db.RegisterFileDescriptor(DESCRIPTOR) _sym_db.RegisterFileDescriptor(DESCRIPTOR)
...@@ -36,8 +35,8 @@ _FILE_TYPE = _descriptor.EnumDescriptor( ...@@ -36,8 +35,8 @@ _FILE_TYPE = _descriptor.EnumDescriptor(
], ],
containing_type=None, containing_type=None,
options=None, options=None,
serialized_start=522, serialized_start=551,
serialized_end=552, serialized_end=581,
) )
_sym_db.RegisterEnumDescriptor(_FILE_TYPE) _sym_db.RegisterEnumDescriptor(_FILE_TYPE)
...@@ -61,8 +60,8 @@ _REQUIRE_TYPE = _descriptor.EnumDescriptor( ...@@ -61,8 +60,8 @@ _REQUIRE_TYPE = _descriptor.EnumDescriptor(
], ],
containing_type=None, containing_type=None,
options=None, options=None,
serialized_start=554, serialized_start=583,
serialized_end=645, serialized_end=674,
) )
_sym_db.RegisterEnumDescriptor(_REQUIRE_TYPE) _sym_db.RegisterEnumDescriptor(_REQUIRE_TYPE)
...@@ -316,10 +315,26 @@ _CHECKINFO = _descriptor.Descriptor( ...@@ -316,10 +315,26 @@ _CHECKINFO = _descriptor.Descriptor(
extension_scope=None, extension_scope=None,
options=None), options=None),
_descriptor.FieldDescriptor( _descriptor.FieldDescriptor(
name='file_infos', name='module_code_version',
full_name='paddlehub.module.checkinfo.CheckInfo.file_infos', full_name='paddlehub.module.checkinfo.CheckInfo.module_code_version',
index=3, index=3,
number=4, number=4,
type=9,
cpp_type=9,
label=1,
has_default_value=False,
default_value=_b("").decode('utf-8'),
message_type=None,
enum_type=None,
containing_type=None,
is_extension=False,
extension_scope=None,
options=None),
_descriptor.FieldDescriptor(
name='file_infos',
full_name='paddlehub.module.checkinfo.CheckInfo.file_infos',
index=4,
number=5,
type=11, type=11,
cpp_type=10, cpp_type=10,
label=3, label=3,
...@@ -334,8 +349,8 @@ _CHECKINFO = _descriptor.Descriptor( ...@@ -334,8 +349,8 @@ _CHECKINFO = _descriptor.Descriptor(
_descriptor.FieldDescriptor( _descriptor.FieldDescriptor(
name='requires', name='requires',
full_name='paddlehub.module.checkinfo.CheckInfo.requires', full_name='paddlehub.module.checkinfo.CheckInfo.requires',
index=4, index=5,
number=5, number=6,
type=11, type=11,
cpp_type=10, cpp_type=10,
label=3, label=3,
...@@ -357,7 +372,7 @@ _CHECKINFO = _descriptor.Descriptor( ...@@ -357,7 +372,7 @@ _CHECKINFO = _descriptor.Descriptor(
extension_ranges=[], extension_ranges=[],
oneofs=[], oneofs=[],
serialized_start=320, serialized_start=320,
serialized_end=520, serialized_end=549,
) )
_FILEINFO.fields_by_name['type'].enum_type = _FILE_TYPE _FILEINFO.fields_by_name['type'].enum_type = _FILE_TYPE
......
...@@ -32,20 +32,22 @@ FILE_SEP = "/" ...@@ -32,20 +32,22 @@ FILE_SEP = "/"
class ModuleChecker(object): class ModuleChecker(object):
def __init__(self, module_path): def __init__(self, directory):
self.module_path = module_path self._directory = directory
self._pb_path = os.path.join(self.directory, CHECK_INFO_PB_FILENAME)
def generate_check_info(self): def generate_check_info(self):
check_info = check_info_pb2.CheckInfo() check_info = check_info_pb2.CheckInfo()
check_info.paddle_version = paddle.__version__ check_info.paddle_version = paddle.__version__
check_info.hub_version = hub_version check_info.hub_version = hub_version
check_info.module_proto_version = module_proto_version check_info.module_proto_version = module_proto_version
check_info.module_code_version = "v2"
file_infos = check_info.file_infos file_infos = check_info.file_infos
file_list = [file for file in os.listdir(self.module_path)] file_list = [file for file in os.listdir(self.directory)]
while file_list: while file_list:
file = file_list[0] file = file_list[0]
file_list = file_list[1:] file_list = file_list[1:]
abs_path = os.path.join(self.module_path, file) abs_path = os.path.join(self.directory, file)
if os.path.isdir(abs_path): if os.path.isdir(abs_path):
for sub_file in os.listdir(abs_path): for sub_file in os.listdir(abs_path):
sub_file = os.path.join(file, sub_file) sub_file = os.path.join(file, sub_file)
...@@ -62,9 +64,12 @@ class ModuleChecker(object): ...@@ -62,9 +64,12 @@ class ModuleChecker(object):
file_info.type = check_info_pb2.FILE file_info.type = check_info_pb2.FILE
file_info.is_need = True file_info.is_need = True
with open(os.path.join(self.module_path, CHECK_INFO_PB_FILENAME), with open(self.pb_path, "wb") as file:
"wb") as fi: file.write(check_info.SerializeToString())
fi.write(check_info.SerializeToString())
@property
def module_code_version(self):
return self.check_info.module_code_version
@property @property
def module_proto_version(self): def module_proto_version(self):
...@@ -82,20 +87,25 @@ class ModuleChecker(object): ...@@ -82,20 +87,25 @@ class ModuleChecker(object):
def file_infos(self): def file_infos(self):
return self.check_info.file_infos return self.check_info.file_infos
@property
def directory(self):
return self._directory
@property
def pb_path(self):
return self._pb_path
def check(self): def check(self):
result = True result = True
self.check_info_pb_path = os.path.join(self.module_path,
CHECK_INFO_PB_FILENAME)
if not (os.path.exists(self.check_info_pb_path) if not (os.path.exists(self.pb_path) or os.path.isfile(self.pb_path)):
or os.path.isfile(self.check_info_pb_path)):
logger.warning( logger.warning(
"This module lacks core file %s" % CHECK_INFO_PB_FILENAME) "This module lacks core file %s" % CHECK_INFO_PB_FILENAME)
result = False result = False
self.check_info = check_info_pb2.CheckInfo() self.check_info = check_info_pb2.CheckInfo()
try: try:
with open(self.check_info_pb_path, "rb") as fi: with open(self.pb_path, "rb") as fi:
pb_string = fi.read() pb_string = fi.read()
result = self.check_info.ParseFromString(pb_string) result = self.check_info.ParseFromString(pb_string)
if len(pb_string) == 0 or (result is not None if len(pb_string) == 0 or (result is not None
...@@ -182,7 +192,7 @@ class ModuleChecker(object): ...@@ -182,7 +192,7 @@ class ModuleChecker(object):
for file_info in self.file_infos: for file_info in self.file_infos:
file_type = file_info.type file_type = file_info.type
file_path = file_info.file_name.replace(FILE_SEP, os.sep) file_path = file_info.file_name.replace(FILE_SEP, os.sep)
file_path = os.path.join(self.module_path, file_path) file_path = os.path.join(self.directory, file_path)
if not os.path.exists(file_path): if not os.path.exists(file_path):
if file_info.is_need: if file_info.is_need:
logger.warning( logger.warning(
......
...@@ -19,6 +19,7 @@ from __future__ import print_function ...@@ -19,6 +19,7 @@ from __future__ import print_function
import os import os
import shutil import shutil
import tarfile
from paddlehub.common import utils from paddlehub.common import utils
from paddlehub.common import srv_utils from paddlehub.common import srv_utils
...@@ -77,10 +78,15 @@ class LocalModuleManager(object): ...@@ -77,10 +78,15 @@ class LocalModuleManager(object):
return self.modules_dict.get(module_name, None) return self.modules_dict.get(module_name, None)
def install_module(self, def install_module(self,
module_name, module_name=None,
module_dir=None,
module_package=None,
module_version=None, module_version=None,
upgrade=False, upgrade=False,
extra=None): extra=None):
md5_value = installed_module_version = None
from_user_dir = True if module_dir else False
if module_name:
self.all_modules(update=True) self.all_modules(update=True)
module_info = self.modules_dict.get(module_name, None) module_info = self.modules_dict.get(module_name, None)
if module_info: if module_info:
...@@ -99,8 +105,9 @@ class LocalModuleManager(object): ...@@ -99,8 +105,9 @@ class LocalModuleManager(object):
url = search_result.get('url', None) url = search_result.get('url', None)
md5_value = search_result.get('md5', None) md5_value = search_result.get('md5', None)
installed_module_version = search_result.get('version', None) installed_module_version = search_result.get('version', None)
if not url or (module_version is not None and installed_module_version if not url or (module_version is not None
!= module_version) or (name != module_name): and installed_module_version != module_version) or (
name != module_name):
if default_hub_server._server_check() is False: if default_hub_server._server_check() is False:
tips = "Request Hub-Server unsuccessfully, please check your network." tips = "Request Hub-Server unsuccessfully, please check your network."
else: else:
...@@ -123,13 +130,42 @@ class LocalModuleManager(object): ...@@ -123,13 +130,42 @@ class LocalModuleManager(object):
delete_file=True, delete_file=True,
print_progress=True) print_progress=True)
if module_package:
with tarfile.open(module_package, "r:gz") as tar:
file_names = tar.getnames()
size = len(file_names) - 1
module_dir = os.path.split(file_names[0])[0]
module_dir = os.path.join(hub.CACHE_HOME, module_dir)
# remove cache
if os.path.exists(module_dir):
shutil.rmtree(module_dir)
for index, file_name in enumerate(file_names):
tar.extract(file_name, hub.CACHE_HOME)
if module_dir: if module_dir:
with open(os.path.join(MODULE_HOME, module_dir, "md5.txt"), if not module_name:
module_name = hub.Module(directory=module_dir).name
self.all_modules(update=False)
module_info = self.modules_dict.get(module_name, None)
if module_info:
module_dir = self.modules_dict[module_name][0]
module_tag = module_name if not module_version else '%s-%s' % (
module_name, module_version)
tips = "Module %s already installed in %s" % (module_tag,
module_dir)
return True, tips, self.modules_dict[module_name]
if md5_value:
with open(
os.path.join(MODULE_HOME, module_dir, "md5.txt"),
"w") as fp: "w") as fp:
fp.write(md5_value) fp.write(md5_value)
save_path = os.path.join(MODULE_HOME, module_name) save_path = os.path.join(MODULE_HOME, module_name)
if os.path.exists(save_path): if os.path.exists(save_path):
shutil.rmtree(save_path) shutil.move(save_path)
if from_user_dir:
shutil.copytree(module_dir, save_path)
else:
shutil.move(module_dir, save_path) shutil.move(module_dir, save_path)
module_dir = save_path module_dir = save_path
tips = "Successfully installed %s" % module_name tips = "Successfully installed %s" % module_name
......
...@@ -21,6 +21,10 @@ import os ...@@ -21,6 +21,10 @@ import os
import time import time
import sys import sys
import functools import functools
import inspect
import importlib
import tarfile
from collections import defaultdict
from shutil import copyfile from shutil import copyfile
import paddle import paddle
...@@ -28,22 +32,19 @@ import paddle.fluid as fluid ...@@ -28,22 +32,19 @@ import paddle.fluid as fluid
from paddlehub.common import utils from paddlehub.common import utils
from paddlehub.common import paddle_helper from paddlehub.common import paddle_helper
from paddlehub.common.logger import logger from paddlehub.common.dir import CACHE_HOME
from paddlehub.common.lock import lock from paddlehub.common.lock import lock
from paddlehub.common.downloader import default_downloader from paddlehub.common.logger import logger
from paddlehub.common.hub_server import CacheUpdater
from paddlehub.module import module_desc_pb2 from paddlehub.module import module_desc_pb2
from paddlehub.common.dir import CONF_HOME
from paddlehub.module import check_info_pb2 from paddlehub.module import check_info_pb2
from paddlehub.common.hub_server import CacheUpdater
from paddlehub.module.signature import Signature, create_signature
from paddlehub.module.checker import ModuleChecker
from paddlehub.module.manager import default_module_manager from paddlehub.module.manager import default_module_manager
from paddlehub.module.checker import ModuleChecker
from paddlehub.module.signature import Signature, create_signature
from paddlehub.module.base_processor import BaseProcessor from paddlehub.module.base_processor import BaseProcessor
from paddlehub.io.parser import yaml_parser from paddlehub.io.parser import yaml_parser
from paddlehub import version from paddlehub import version
__all__ = ['Module', 'create_module']
# PaddleHub module dir name # PaddleHub module dir name
ASSETS_DIRNAME = "assets" ASSETS_DIRNAME = "assets"
MODEL_DIRNAME = "model" MODEL_DIRNAME = "model"
...@@ -52,67 +53,226 @@ PYTHON_DIR = "python" ...@@ -52,67 +53,226 @@ PYTHON_DIR = "python"
PROCESSOR_NAME = "processor" PROCESSOR_NAME = "processor"
# PaddleHub var prefix # PaddleHub var prefix
HUB_VAR_PREFIX = "@HUB_%s@" HUB_VAR_PREFIX = "@HUB_%s@"
# PaddleHub Module package suffix
HUB_PACKAGE_SUFFIX = "phm"
def create_module(directory, name, author, email, module_type, summary,
version):
save_file_name = "{}-{}.{}".format(name, version, HUB_PACKAGE_SUFFIX)
# record module info and serialize
desc = module_desc_pb2.ModuleDesc()
attr = desc.attr
attr.type = module_desc_pb2.MAP
module_info = attr.map.data['module_info']
module_info.type = module_desc_pb2.MAP
utils.from_pyobj_to_module_attr(name, module_info.map.data['name'])
utils.from_pyobj_to_module_attr(author, module_info.map.data['author'])
utils.from_pyobj_to_module_attr(email, module_info.map.data['author_email'])
utils.from_pyobj_to_module_attr(module_type, module_info.map.data['type'])
utils.from_pyobj_to_module_attr(summary, module_info.map.data['summary'])
utils.from_pyobj_to_module_attr(version, module_info.map.data['version'])
module_desc_path = os.path.join(directory, "module_desc.pb")
with open(module_desc_path, "wb") as f:
f.write(desc.SerializeToString())
# generate check info
checker = ModuleChecker(directory)
checker.generate_check_info()
# add __init__
module_init_1 = os.path.join(directory, "__init__.py")
with open(module_init_1, "a") as file:
file.write("")
module_init_2 = os.path.join(directory, "python", "__init__.py")
with open(module_init_2, "a") as file:
file.write("")
# package the module
with tarfile.open(save_file_name, "w:gz") as tar:
for dirname, _, files in os.walk(directory):
for file in files:
tar.add(os.path.join(dirname, file))
os.remove(module_desc_path)
os.remove(checker.pb_path)
os.remove(module_init_1)
os.remove(module_init_2)
class Module(object):
def __new__(cls, name=None, directory=None, module_dir=None, version=None):
module = None
if cls.__name__ == "Module":
if name:
module = cls.init_with_name(name=name, version=version)
elif directory:
module = cls.init_with_directory(directory=directory)
elif module_dir:
logger.warning(
"Parameter module_dir is deprecated, please use directory to specify the path"
)
if isinstance(module_dir, list) or isinstance(
module_dir, tuple):
directory = module_dir[0]
version = module_dir[1]
else:
directory = module_dir
module = cls.init_with_directory(directory=directory)
if not module:
module = object.__new__(cls)
CacheUpdater(name, version).start()
return module
def __init__(self, name=None, directory=None, module_dir=None,
version=None):
if not directory:
return
self._code_version = "v2"
self._directory = directory
self.module_desc_path = os.path.join(self.directory, MODULE_DESC_PBNAME)
self._desc = module_desc_pb2.ModuleDesc()
with open(self.module_desc_path, "rb") as file:
self._desc.ParseFromString(file.read())
module_info = self.desc.attr.map.data['module_info']
self._name = utils.from_module_attr_to_pyobj(
module_info.map.data['name'])
self._author = utils.from_module_attr_to_pyobj(
module_info.map.data['author'])
self._author_email = utils.from_module_attr_to_pyobj(
module_info.map.data['author_email'])
self._version = utils.from_module_attr_to_pyobj(
module_info.map.data['version'])
self._type = utils.from_module_attr_to_pyobj(
module_info.map.data['type'])
self._summary = utils.from_module_attr_to_pyobj(
module_info.map.data['summary'])
self._initialize()
@classmethod
def init_with_name(cls, name, version=None):
fp_lock = open(os.path.join(CACHE_HOME, name), "a")
lock.flock(fp_lock, lock.LOCK_EX)
log_msg = "Installing %s module" % name
if version:
log_msg += "-%s" % version
logger.info(log_msg)
extra = {"command": "install"}
result, tips, module_dir = default_module_manager.install_module(
module_name=name, module_version=version, extra=extra)
if not result:
logger.error(tips)
raise RuntimeError(tips)
logger.info(tips)
lock.flock(fp_lock, lock.LOCK_UN)
return cls.init_with_directory(directory=module_dir[0])
@classmethod
def init_with_directory(cls, directory):
desc_file = os.path.join(directory, MODULE_DESC_PBNAME)
checker = ModuleChecker(directory)
checker.check()
def create_module(sign_arr, module_code_version = checker.module_code_version
module_dir, if module_code_version == "v2":
processor=None, basename = os.path.split(directory)[-1]
assets=None, dirname = os.path.join(*list(os.path.split(directory)[:-1]))
module_info=None, sys.path.append(dirname)
exe=None, pymodule = importlib.import_module(
extra_info=None): "{}.python.module".format(basename))
sign_arr = utils.to_list(sign_arr) return pymodule.HubModule(directory=directory)
module = Module( return ModuleV1(directory=directory)
signatures=sign_arr,
processor=processor, @property
assets=assets, def desc(self):
module_info=module_info, return self._desc
extra_info=extra_info)
module.serialize_to_path(path=module_dir, exe=exe) @property
def directory(self):
return self._directory
@property
def author(self):
return self._author
@property
def author_email(self):
return self._author_email
@property
def summary(self):
return self._summary
@property
def type(self):
return self._type
@property
def version(self):
return self._version
@property
def name(self):
return self._name
@property
def name_prefix(self):
return self._name_prefix
@property
def code_version(self):
return self._code_version
@property
def is_runable(self):
return False
def _initialize(self):
pass
class ModuleHelper(object): class ModuleHelper(object):
def __init__(self, module_dir): def __init__(self, directory):
self.module_dir = module_dir self.directory = directory
def module_desc_path(self): def module_desc_path(self):
return os.path.join(self.module_dir, MODULE_DESC_PBNAME) return os.path.join(self.directory, MODULE_DESC_PBNAME)
def model_path(self): def model_path(self):
return os.path.join(self.module_dir, MODEL_DIRNAME) return os.path.join(self.directory, MODEL_DIRNAME)
def processor_path(self): def processor_path(self):
return os.path.join(self.module_dir, PYTHON_DIR) return os.path.join(self.directory, PYTHON_DIR)
def processor_name(self): def processor_name(self):
return PROCESSOR_NAME return PROCESSOR_NAME
def assets_path(self): def assets_path(self):
return os.path.join(self.module_dir, ASSETS_DIRNAME) return os.path.join(self.directory, ASSETS_DIRNAME)
class Module(object): class ModuleV1(Module):
def __init__(self, def __init__(self, name=None, directory=None, module_dir=None,
name=None,
module_dir=None,
signatures=None,
module_info=None,
assets=None,
processor=None,
extra_info=None,
version=None): version=None):
self.desc = module_desc_pb2.ModuleDesc() if not directory:
return
super(ModuleV1, self).__init__(name, directory, module_dir, version)
self._code_version = "v1"
self.program = None self.program = None
self.assets = [] self.assets = []
self.helper = None self.helper = None
self.signatures = {} self.signatures = {}
self.default_signature = None self.default_signature = None
self.module_info = None
self.processor = None self.processor = None
self.extra_info = {} if extra_info is None else extra_info self.extra_info = {}
if not isinstance(self.extra_info, dict):
raise TypeError(
"The extra_info should be an instance of python dict")
# cache data # cache data
self.last_call_name = None self.last_call_name = None
...@@ -120,62 +280,21 @@ class Module(object): ...@@ -120,62 +280,21 @@ class Module(object):
self.cache_fetch_dict = None self.cache_fetch_dict = None
self.cache_program = None self.cache_program = None
fp_lock = open(os.path.join(CONF_HOME, 'config.json')) self.helper = ModuleHelper(directory)
lock.flock(fp_lock, lock.LOCK_EX) exe = fluid.Executor(fluid.CPUPlace())
if name: self.program, _, _ = fluid.io.load_inference_model(
self._init_with_name(name=name, version=version) self.helper.model_path(), executor=exe)
lock.flock(fp_lock, lock.LOCK_UN) for block in self.program.blocks:
elif module_dir: for op in block.ops:
self._init_with_module_file(module_dir=module_dir[0]) if "op_callstack" in op.all_attrs():
lock.flock(fp_lock, lock.LOCK_UN) op._set_attr("op_callstack", [""])
name = module_dir[0].split("/")[-1] self._load_processor()
if len(module_dir) > 1: self._load_assets()
version = module_dir[1] self._recover_from_desc()
else: self._generate_sign_attr()
version = default_module_manager.search_module(name)[1] self._generate_extra_info()
elif signatures: self._restore_parameter(self.program)
if processor: self._recover_variable_info(self.program)
if not issubclass(processor, BaseProcessor):
raise TypeError(
"Processor shoule be an instance of paddlehub.BaseProcessor"
)
if assets:
self.assets = utils.to_list(assets)
# for asset in assets:
# utils.check_path(assets)
self.processor = processor
self._generate_module_info(module_info)
self._init_with_signature(signatures=signatures)
lock.flock(fp_lock, lock.LOCK_UN)
else:
lock.flock(fp_lock, lock.LOCK_UN)
raise ValueError("Module initialized parameter is empty")
CacheUpdater(name, version).start()
def _init_with_name(self, name, version=None):
log_msg = "Installing %s module" % name
if version:
log_msg += "-%s" % version
logger.info(log_msg)
extra = {"command": "install"}
result, tips, module_dir = default_module_manager.install_module(
module_name=name, module_version=version, extra=extra)
if not result:
logger.error(tips)
raise RuntimeError(tips)
else:
logger.info(tips)
self._init_with_module_file(module_dir[0])
def _init_with_url(self, url):
utils.check_url(url)
result, tips, module_dir = default_downloader.download_file_and_uncompress(
url, save_path=".")
if not result:
logger.error(tips)
raise RuntimeError(tips)
else:
self._init_with_module_file(module_dir)
def _dump_processor(self): def _dump_processor(self):
import inspect import inspect
...@@ -216,52 +335,6 @@ class Module(object): ...@@ -216,52 +335,6 @@ class Module(object):
filepath = os.path.join(self.helper.assets_path(), file) filepath = os.path.join(self.helper.assets_path(), file)
self.assets.append(filepath) self.assets.append(filepath)
def _init_with_module_file(self, module_dir):
checker = ModuleChecker(module_dir)
checker.check()
self.helper = ModuleHelper(module_dir)
with open(self.helper.module_desc_path(), "rb") as fi:
self.desc.ParseFromString(fi.read())
exe = fluid.Executor(fluid.CPUPlace())
self.program, _, _ = fluid.io.load_inference_model(
self.helper.model_path(), executor=exe)
for block in self.program.blocks:
for op in block.ops:
if "op_callstack" in op.all_attrs():
op._set_attr("op_callstack", [""])
self._load_processor()
self._load_assets()
self._recover_from_desc()
self._generate_sign_attr()
self._generate_extra_info()
self._restore_parameter(self.program)
self._recover_variable_info(self.program)
def _init_with_signature(self, signatures):
self.name_prefix = HUB_VAR_PREFIX % self.name
self._process_signatures(signatures)
self._check_signatures()
self._generate_desc()
self._generate_sign_attr()
self._generate_extra_info()
def _init_with_program(self, program):
pass
def _process_signatures(self, signatures):
self.signatures = {}
self.program = signatures[0].inputs[0].block.program
for sign in signatures:
if sign.name in self.signatures:
raise ValueError(
"Error! Signature array contains duplicated signatrues %s" %
sign)
if self.default_signature is None and sign.for_predict:
self.default_signature = sign
self.signatures[sign.name] = sign
def _restore_parameter(self, program): def _restore_parameter(self, program):
global_block = program.global_block() global_block = program.global_block()
param_attrs = self.desc.attr.map.data['param_attrs'] param_attrs = self.desc.attr.map.data['param_attrs']
...@@ -302,21 +375,6 @@ class Module(object): ...@@ -302,21 +375,6 @@ class Module(object):
self.__dict__["get_%s" % key] = functools.partial( self.__dict__["get_%s" % key] = functools.partial(
self.get_extra_info, key=key) self.get_extra_info, key=key)
def _generate_module_info(self, module_info=None):
if not module_info:
self.module_info = {}
else:
if not utils.is_yaml_file(module_info):
logger.critical("Module info file should be yaml format")
exit(1)
self.module_info = yaml_parser.parse(module_info)
self.author = self.module_info.get('author', 'UNKNOWN')
self.author_email = self.module_info.get('author_email', 'UNKNOWN')
self.summary = self.module_info.get('summary', 'UNKNOWN')
self.type = self.module_info.get('type', 'UNKNOWN')
self.version = self.module_info.get('version', 'UNKNOWN')
self.name = self.module_info.get('name', 'UNKNOWN')
def _generate_sign_attr(self): def _generate_sign_attr(self):
self._check_signatures() self._check_signatures()
for sign in self.signatures: for sign in self.signatures:
...@@ -369,21 +427,21 @@ class Module(object): ...@@ -369,21 +427,21 @@ class Module(object):
default_signature_name = utils.from_module_attr_to_pyobj( default_signature_name = utils.from_module_attr_to_pyobj(
self.desc.attr.map.data['default_signature']) self.desc.attr.map.data['default_signature'])
self.default_signature = self.signatures[ self.default_signature = self.signatures[
default_signature_name] if default_signature_name else None default_signature_name].name if default_signature_name else None
# recover module info # recover module info
module_info = self.desc.attr.map.data['module_info'] module_info = self.desc.attr.map.data['module_info']
self.name = utils.from_module_attr_to_pyobj( self._name = utils.from_module_attr_to_pyobj(
module_info.map.data['name']) module_info.map.data['name'])
self.author = utils.from_module_attr_to_pyobj( self._author = utils.from_module_attr_to_pyobj(
module_info.map.data['author']) module_info.map.data['author'])
self.author_email = utils.from_module_attr_to_pyobj( self._author_email = utils.from_module_attr_to_pyobj(
module_info.map.data['author_email']) module_info.map.data['author_email'])
self.version = utils.from_module_attr_to_pyobj( self._version = utils.from_module_attr_to_pyobj(
module_info.map.data['version']) module_info.map.data['version'])
self.type = utils.from_module_attr_to_pyobj( self._type = utils.from_module_attr_to_pyobj(
module_info.map.data['type']) module_info.map.data['type'])
self.summary = utils.from_module_attr_to_pyobj( self._summary = utils.from_module_attr_to_pyobj(
module_info.map.data['summary']) module_info.map.data['summary'])
# recover extra info # recover extra info
...@@ -393,77 +451,9 @@ class Module(object): ...@@ -393,77 +451,9 @@ class Module(object):
self.extra_info[key] = utils.from_module_attr_to_pyobj(value) self.extra_info[key] = utils.from_module_attr_to_pyobj(value)
# recover name prefix # recover name prefix
self.name_prefix = utils.from_module_attr_to_pyobj( self._name_prefix = utils.from_module_attr_to_pyobj(
self.desc.attr.map.data["name_prefix"])
def _generate_desc(self):
# save fluid Parameter
attr = self.desc.attr
attr.type = module_desc_pb2.MAP
param_attrs = attr.map.data['param_attrs']
param_attrs.type = module_desc_pb2.MAP
for param in self.program.global_block().iter_parameters():
param_attr = param_attrs.map.data[param.name]
paddle_helper.from_param_to_module_attr(param, param_attr)
# save Variable Info
var_infos = attr.map.data['var_infos']
var_infos.type = module_desc_pb2.MAP
for block in self.program.blocks:
for var in block.vars.values():
var_info = var_infos.map.data[var.name]
var_info.type = module_desc_pb2.MAP
utils.from_pyobj_to_module_attr(
var.stop_gradient, var_info.map.data['stop_gradient'])
utils.from_pyobj_to_module_attr(block.idx,
var_info.map.data['block_id'])
# save signarture info
for key, sign in self.signatures.items():
var = self.desc.sign2var[sign.name]
feed_desc = var.feed_desc
fetch_desc = var.fetch_desc
feed_names = sign.feed_names
fetch_names = sign.fetch_names
for index, input in enumerate(sign.inputs):
feed_var = feed_desc.add()
feed_var.var_name = self.get_var_name_with_prefix(input.name)
feed_var.alias = feed_names[index]
for index, output in enumerate(sign.outputs):
fetch_var = fetch_desc.add()
fetch_var.var_name = self.get_var_name_with_prefix(output.name)
fetch_var.alias = fetch_names[index]
# save default signature
utils.from_pyobj_to_module_attr(
self.default_signature.name if self.default_signature else None,
attr.map.data['default_signature'])
# save name prefix
utils.from_pyobj_to_module_attr(self.name_prefix,
self.desc.attr.map.data["name_prefix"]) self.desc.attr.map.data["name_prefix"])
# save module info
module_info = attr.map.data['module_info']
module_info.type = module_desc_pb2.MAP
utils.from_pyobj_to_module_attr(self.name, module_info.map.data['name'])
utils.from_pyobj_to_module_attr(self.version,
module_info.map.data['version'])
utils.from_pyobj_to_module_attr(self.author,
module_info.map.data['author'])
utils.from_pyobj_to_module_attr(self.author_email,
module_info.map.data['author_email'])
utils.from_pyobj_to_module_attr(self.type, module_info.map.data['type'])
utils.from_pyobj_to_module_attr(self.summary,
module_info.map.data['summary'])
# save extra info
extra_info = attr.map.data['extra_info']
extra_info.type = module_desc_pb2.MAP
for key, value in self.extra_info.items():
utils.from_pyobj_to_module_attr(value, extra_info.map.data[key])
def __call__(self, sign_name, data, use_gpu=False, batch_size=1, **kwargs): def __call__(self, sign_name, data, use_gpu=False, batch_size=1, **kwargs):
self.check_processor() self.check_processor()
...@@ -525,6 +515,10 @@ class Module(object): ...@@ -525,6 +515,10 @@ class Module(object):
if not self.processor: if not self.processor:
raise ValueError("This Module is not callable!") raise ValueError("This Module is not callable!")
@property
def is_runable(self):
return self.default_signature != None
def context(self, def context(self,
sign_name=None, sign_name=None,
for_test=False, for_test=False,
...@@ -664,93 +658,3 @@ class Module(object): ...@@ -664,93 +658,3 @@ class Module(object):
raise ValueError( raise ValueError(
"All input and outputs variables in signature should come from the same Program" "All input and outputs variables in signature should come from the same Program"
) )
def serialize_to_path(self, path=None, exe=None):
self._check_signatures()
self._generate_desc()
# create module path for saving
if path is None:
path = os.path.join(".", self.name)
self.helper = ModuleHelper(path)
utils.mkdir(self.helper.module_dir)
# create module pb
module_desc = module_desc_pb2.ModuleDesc()
logger.info("PaddleHub version = %s" % version.hub_version)
logger.info("PaddleHub Module proto version = %s" %
version.module_proto_version)
logger.info("Paddle version = %s" % paddle.__version__)
feeded_var_names = [
input.name for key, sign in self.signatures.items()
for input in sign.inputs
]
target_vars = [
output for key, sign in self.signatures.items()
for output in sign.outputs
]
feeded_var_names = list(set(feeded_var_names))
target_vars = list(set(target_vars))
# save inference program
program = self.program.clone()
for block in program.blocks:
for op in block.ops:
if "op_callstack" in op.all_attrs():
op._set_attr("op_callstack", [""])
if not exe:
place = fluid.CPUPlace()
exe = fluid.Executor(place=place)
utils.mkdir(self.helper.model_path())
fluid.io.save_inference_model(
self.helper.model_path(),
feeded_var_names=list(feeded_var_names),
target_vars=list(target_vars),
main_program=program,
executor=exe)
with open(os.path.join(self.helper.model_path(), "__model__"),
"rb") as file:
program_desc_str = file.read()
rename_program = fluid.framework.Program.parse_from_string(
program_desc_str)
varlist = {
var: block
for block in rename_program.blocks for var in block.vars
if self.get_name_prefix() not in var
}
for var, block in varlist.items():
old_name = var
new_name = self.get_var_name_with_prefix(old_name)
block._rename_var(old_name, new_name)
utils.mkdir(self.helper.model_path())
with open(
os.path.join(self.helper.model_path(), "__model__"),
"wb") as f:
f.write(rename_program.desc.serialize_to_string())
for file in os.listdir(self.helper.model_path()):
if (file == "__model__" or self.get_name_prefix() in file):
continue
os.rename(
os.path.join(self.helper.model_path(), file),
os.path.join(self.helper.model_path(),
self.get_var_name_with_prefix(file)))
# create processor file
if self.processor:
self._dump_processor()
# create assets
self._dump_assets()
# create check info
checker = ModuleChecker(self.helper.module_dir)
checker.generate_check_info()
# Serialize module_desc pb
module_pb = self.desc.SerializeToString()
with open(self.helper.module_desc_path(), "wb") as f:
f.write(module_pb)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册