未验证 提交 a25545de 编写于 作者: W wuzewu 提交者: GitHub

Add Module V2 support (#274)

* Add module v2
上级 b2dc77ed
...@@ -38,7 +38,7 @@ from .common.logger import logger ...@@ -38,7 +38,7 @@ from .common.logger import logger
from .common.paddle_helper import connect_program from .common.paddle_helper import connect_program
from .common.hub_server import default_hub_server from .common.hub_server import default_hub_server
from .module.module import Module, create_module from .module.module import Module
from .module.base_processor import BaseProcessor from .module.base_processor import BaseProcessor
from .module.signature import Signature, create_signature from .module.signature import Signature, create_signature
from .module.manager import default_module_manager from .module.manager import default_module_manager
......
...@@ -18,6 +18,7 @@ from __future__ import division ...@@ -18,6 +18,7 @@ from __future__ import division
from __future__ import print_function from __future__ import print_function
import argparse import argparse
import os
from paddlehub.common import utils from paddlehub.common import utils
from paddlehub.module.manager import default_module_manager from paddlehub.module.manager import default_module_manager
...@@ -42,14 +43,23 @@ class InstallCommand(BaseCommand): ...@@ -42,14 +43,23 @@ class InstallCommand(BaseCommand):
print("ERROR: Please specify a module name.\n") print("ERROR: Please specify a module name.\n")
self.help() self.help()
return False return False
module_name = argv[0]
module_version = None if "==" not in module_name else module_name.split(
"==")[1]
module_name = module_name if "==" not in module_name else module_name.split(
"==")[0]
extra = {"command": "install"} extra = {"command": "install"}
result, tips, module_dir = default_module_manager.install_module( if argv[0].endswith("tar.gz") or argv[0].endswith("phm"):
module_name=module_name, module_version=module_version, extra=extra) result, tips, module_dir = default_module_manager.install_module(
module_package=argv[0], extra=extra)
elif os.path.exists(argv[0]) and os.path.isdir(argv[0]):
result, tips, module_dir = default_module_manager.install_module(
module_dir=argv[0], extra=extra)
else:
module_name = argv[0]
module_version = None if "==" not in module_name else module_name.split(
"==")[1]
module_name = module_name if "==" not in module_name else module_name.split(
"==")[0]
result, tips, module_dir = default_module_manager.install_module(
module_name=module_name,
module_version=module_version,
extra=extra)
print(tips) print(tips)
return True return True
......
...@@ -71,7 +71,7 @@ class RunCommand(BaseCommand): ...@@ -71,7 +71,7 @@ class RunCommand(BaseCommand):
if not result: if not result:
return None return None
return hub.Module(module_dir=module_dir) return hub.Module(directory=module_dir[0])
def add_module_config_arg(self): def add_module_config_arg(self):
configs = self.module.processor.configs() configs = self.module.processor.configs()
...@@ -105,7 +105,7 @@ class RunCommand(BaseCommand): ...@@ -105,7 +105,7 @@ class RunCommand(BaseCommand):
def add_module_input_arg(self): def add_module_input_arg(self):
module_type = self.module.type.lower() module_type = self.module.type.lower()
expect_data_format = self.module.processor.data_format( expect_data_format = self.module.processor.data_format(
self.module.default_signature.name) self.module.default_signature)
self.arg_input_group.add_argument( self.arg_input_group.add_argument(
'--input_file', '--input_file',
type=str, type=str,
...@@ -152,7 +152,7 @@ class RunCommand(BaseCommand): ...@@ -152,7 +152,7 @@ class RunCommand(BaseCommand):
def get_data(self): def get_data(self):
module_type = self.module.type.lower() module_type = self.module.type.lower()
expect_data_format = self.module.processor.data_format( expect_data_format = self.module.processor.data_format(
self.module.default_signature.name) self.module.default_signature)
input_data = {} input_data = {}
if len(expect_data_format) == 1: if len(expect_data_format) == 1:
key = list(expect_data_format.keys())[0] key = list(expect_data_format.keys())[0]
...@@ -177,7 +177,7 @@ class RunCommand(BaseCommand): ...@@ -177,7 +177,7 @@ class RunCommand(BaseCommand):
def check_data(self, data): def check_data(self, data):
expect_data_format = self.module.processor.data_format( expect_data_format = self.module.processor.data_format(
self.module.default_signature.name) self.module.default_signature)
if len(data.keys()) != len(expect_data_format.keys()): if len(data.keys()) != len(expect_data_format.keys()):
print( print(
...@@ -236,35 +236,38 @@ class RunCommand(BaseCommand): ...@@ -236,35 +236,38 @@ class RunCommand(BaseCommand):
return False return False
# If the module is not executable, give an alarm and exit # If the module is not executable, give an alarm and exit
if not self.module.default_signature: if not self.module.is_runable:
print("ERROR! Module %s is not executable." % module_name) print("ERROR! Module %s is not executable." % module_name)
return False return False
self.module.check_processor() if self.module.code_version == "v2":
self.add_module_config_arg() results = self.module(argv[1:])
self.add_module_input_arg() else:
self.module.check_processor()
self.add_module_config_arg()
self.add_module_input_arg()
if not argv[1:]: if not argv[1:]:
self.help() self.help()
return False return False
self.args = self.parser.parse_args(argv[1:]) self.args = self.parser.parse_args(argv[1:])
config = self.get_config() config = self.get_config()
data = self.get_data() data = self.get_data()
try: try:
self.check_data(data) self.check_data(data)
except DataFormatError: except DataFormatError:
self.help() self.help()
return False return False
results = self.module( results = self.module(
sign_name=self.module.default_signature.name, sign_name=self.module.default_signature,
data=data, data=data,
use_gpu=self.args.use_gpu, use_gpu=self.args.use_gpu,
batch_size=self.args.batch_size, batch_size=self.args.batch_size,
**config) **config)
if six.PY2: if six.PY2:
try: try:
......
...@@ -125,8 +125,6 @@ class ShowCommand(BaseCommand): ...@@ -125,8 +125,6 @@ class ShowCommand(BaseCommand):
cwd = os.getcwd() cwd = os.getcwd()
module_dir = default_module_manager.search_module(module_name) module_dir = default_module_manager.search_module(module_name)
module_dir = (os.path.join(cwd, module_name),
None) if not module_dir else module_dir
if not module_dir or not os.path.exists(module_dir[0]): if not module_dir or not os.path.exists(module_dir[0]):
print("%s is not existed!" % module_name) print("%s is not existed!" % module_name)
return True return True
......
...@@ -50,6 +50,7 @@ message CheckInfo { ...@@ -50,6 +50,7 @@ message CheckInfo {
string paddle_version = 1; string paddle_version = 1;
string hub_version = 2; string hub_version = 2;
string module_proto_version = 3; string module_proto_version = 3;
repeated FileInfo file_infos = 4; string module_code_version = 4;
repeated Requires requires = 5; repeated FileInfo file_infos = 5;
repeated Requires requires = 6;
}; };
#coding:utf-8
# Generated by the protocol buffer compiler. DO NOT EDIT! # Generated by the protocol buffer compiler. DO NOT EDIT!
# source: check_info.proto # source: check_info.proto
...@@ -19,7 +18,7 @@ DESCRIPTOR = _descriptor.FileDescriptor( ...@@ -19,7 +18,7 @@ DESCRIPTOR = _descriptor.FileDescriptor(
package='paddlehub.module.checkinfo', package='paddlehub.module.checkinfo',
syntax='proto3', syntax='proto3',
serialized_pb=_b( serialized_pb=_b(
'\n\x10\x63heck_info.proto\x12\x1apaddlehub.module.checkinfo\"\x85\x01\n\x08\x46ileInfo\x12\x11\n\tfile_name\x18\x01 \x01(\t\x12\x33\n\x04type\x18\x02 \x01(\x0e\x32%.paddlehub.module.checkinfo.FILE_TYPE\x12\x0f\n\x07is_need\x18\x03 \x01(\x08\x12\x0b\n\x03md5\x18\x04 \x01(\t\x12\x13\n\x0b\x64\x65scription\x18\x05 \x01(\t\"\x84\x01\n\x08Requires\x12>\n\x0crequire_type\x18\x01 \x01(\x0e\x32(.paddlehub.module.checkinfo.REQUIRE_TYPE\x12\x0f\n\x07version\x18\x02 \x01(\t\x12\x12\n\ngreat_than\x18\x03 \x01(\x08\x12\x13\n\x0b\x64\x65scription\x18\x04 \x01(\t\"\xc8\x01\n\tCheckInfo\x12\x16\n\x0epaddle_version\x18\x01 \x01(\t\x12\x13\n\x0bhub_version\x18\x02 \x01(\t\x12\x1c\n\x14module_proto_version\x18\x03 \x01(\t\x12\x38\n\nfile_infos\x18\x04 \x03(\x0b\x32$.paddlehub.module.checkinfo.FileInfo\x12\x36\n\x08requires\x18\x05 \x03(\x0b\x32$.paddlehub.module.checkinfo.Requires*\x1e\n\tFILE_TYPE\x12\x08\n\x04\x46ILE\x10\x00\x12\x07\n\x03\x44IR\x10\x01*[\n\x0cREQUIRE_TYPE\x12\x12\n\x0ePYTHON_PACKAGE\x10\x00\x12\x0e\n\nHUB_MODULE\x10\x01\x12\n\n\x06SYSTEM\x10\x02\x12\x0b\n\x07\x43OMMAND\x10\x03\x12\x0e\n\nPY_VERSION\x10\x04\x42\x02H\x03\x62\x06proto3' '\n\x10\x63heck_info.proto\x12\x1apaddlehub.module.checkinfo\"\x85\x01\n\x08\x46ileInfo\x12\x11\n\tfile_name\x18\x01 \x01(\t\x12\x33\n\x04type\x18\x02 \x01(\x0e\x32%.paddlehub.module.checkinfo.FILE_TYPE\x12\x0f\n\x07is_need\x18\x03 \x01(\x08\x12\x0b\n\x03md5\x18\x04 \x01(\t\x12\x13\n\x0b\x64\x65scription\x18\x05 \x01(\t\"\x84\x01\n\x08Requires\x12>\n\x0crequire_type\x18\x01 \x01(\x0e\x32(.paddlehub.module.checkinfo.REQUIRE_TYPE\x12\x0f\n\x07version\x18\x02 \x01(\t\x12\x12\n\ngreat_than\x18\x03 \x01(\x08\x12\x13\n\x0b\x64\x65scription\x18\x04 \x01(\t\"\xe5\x01\n\tCheckInfo\x12\x16\n\x0epaddle_version\x18\x01 \x01(\t\x12\x13\n\x0bhub_version\x18\x02 \x01(\t\x12\x1c\n\x14module_proto_version\x18\x03 \x01(\t\x12\x1b\n\x13module_code_version\x18\x04 \x01(\t\x12\x38\n\nfile_infos\x18\x05 \x03(\x0b\x32$.paddlehub.module.checkinfo.FileInfo\x12\x36\n\x08requires\x18\x06 \x03(\x0b\x32$.paddlehub.module.checkinfo.Requires*\x1e\n\tFILE_TYPE\x12\x08\n\x04\x46ILE\x10\x00\x12\x07\n\x03\x44IR\x10\x01*[\n\x0cREQUIRE_TYPE\x12\x12\n\x0ePYTHON_PACKAGE\x10\x00\x12\x0e\n\nHUB_MODULE\x10\x01\x12\n\n\x06SYSTEM\x10\x02\x12\x0b\n\x07\x43OMMAND\x10\x03\x12\x0e\n\nPY_VERSION\x10\x04\x42\x02H\x03\x62\x06proto3'
)) ))
_sym_db.RegisterFileDescriptor(DESCRIPTOR) _sym_db.RegisterFileDescriptor(DESCRIPTOR)
...@@ -36,8 +35,8 @@ _FILE_TYPE = _descriptor.EnumDescriptor( ...@@ -36,8 +35,8 @@ _FILE_TYPE = _descriptor.EnumDescriptor(
], ],
containing_type=None, containing_type=None,
options=None, options=None,
serialized_start=522, serialized_start=551,
serialized_end=552, serialized_end=581,
) )
_sym_db.RegisterEnumDescriptor(_FILE_TYPE) _sym_db.RegisterEnumDescriptor(_FILE_TYPE)
...@@ -61,8 +60,8 @@ _REQUIRE_TYPE = _descriptor.EnumDescriptor( ...@@ -61,8 +60,8 @@ _REQUIRE_TYPE = _descriptor.EnumDescriptor(
], ],
containing_type=None, containing_type=None,
options=None, options=None,
serialized_start=554, serialized_start=583,
serialized_end=645, serialized_end=674,
) )
_sym_db.RegisterEnumDescriptor(_REQUIRE_TYPE) _sym_db.RegisterEnumDescriptor(_REQUIRE_TYPE)
...@@ -316,10 +315,26 @@ _CHECKINFO = _descriptor.Descriptor( ...@@ -316,10 +315,26 @@ _CHECKINFO = _descriptor.Descriptor(
extension_scope=None, extension_scope=None,
options=None), options=None),
_descriptor.FieldDescriptor( _descriptor.FieldDescriptor(
name='file_infos', name='module_code_version',
full_name='paddlehub.module.checkinfo.CheckInfo.file_infos', full_name='paddlehub.module.checkinfo.CheckInfo.module_code_version',
index=3, index=3,
number=4, number=4,
type=9,
cpp_type=9,
label=1,
has_default_value=False,
default_value=_b("").decode('utf-8'),
message_type=None,
enum_type=None,
containing_type=None,
is_extension=False,
extension_scope=None,
options=None),
_descriptor.FieldDescriptor(
name='file_infos',
full_name='paddlehub.module.checkinfo.CheckInfo.file_infos',
index=4,
number=5,
type=11, type=11,
cpp_type=10, cpp_type=10,
label=3, label=3,
...@@ -334,8 +349,8 @@ _CHECKINFO = _descriptor.Descriptor( ...@@ -334,8 +349,8 @@ _CHECKINFO = _descriptor.Descriptor(
_descriptor.FieldDescriptor( _descriptor.FieldDescriptor(
name='requires', name='requires',
full_name='paddlehub.module.checkinfo.CheckInfo.requires', full_name='paddlehub.module.checkinfo.CheckInfo.requires',
index=4, index=5,
number=5, number=6,
type=11, type=11,
cpp_type=10, cpp_type=10,
label=3, label=3,
...@@ -357,7 +372,7 @@ _CHECKINFO = _descriptor.Descriptor( ...@@ -357,7 +372,7 @@ _CHECKINFO = _descriptor.Descriptor(
extension_ranges=[], extension_ranges=[],
oneofs=[], oneofs=[],
serialized_start=320, serialized_start=320,
serialized_end=520, serialized_end=549,
) )
_FILEINFO.fields_by_name['type'].enum_type = _FILE_TYPE _FILEINFO.fields_by_name['type'].enum_type = _FILE_TYPE
......
...@@ -32,20 +32,22 @@ FILE_SEP = "/" ...@@ -32,20 +32,22 @@ FILE_SEP = "/"
class ModuleChecker(object): class ModuleChecker(object):
def __init__(self, module_path): def __init__(self, directory):
self.module_path = module_path self._directory = directory
self._pb_path = os.path.join(self.directory, CHECK_INFO_PB_FILENAME)
def generate_check_info(self): def generate_check_info(self):
check_info = check_info_pb2.CheckInfo() check_info = check_info_pb2.CheckInfo()
check_info.paddle_version = paddle.__version__ check_info.paddle_version = paddle.__version__
check_info.hub_version = hub_version check_info.hub_version = hub_version
check_info.module_proto_version = module_proto_version check_info.module_proto_version = module_proto_version
check_info.module_code_version = "v2"
file_infos = check_info.file_infos file_infos = check_info.file_infos
file_list = [file for file in os.listdir(self.module_path)] file_list = [file for file in os.listdir(self.directory)]
while file_list: while file_list:
file = file_list[0] file = file_list[0]
file_list = file_list[1:] file_list = file_list[1:]
abs_path = os.path.join(self.module_path, file) abs_path = os.path.join(self.directory, file)
if os.path.isdir(abs_path): if os.path.isdir(abs_path):
for sub_file in os.listdir(abs_path): for sub_file in os.listdir(abs_path):
sub_file = os.path.join(file, sub_file) sub_file = os.path.join(file, sub_file)
...@@ -62,9 +64,12 @@ class ModuleChecker(object): ...@@ -62,9 +64,12 @@ class ModuleChecker(object):
file_info.type = check_info_pb2.FILE file_info.type = check_info_pb2.FILE
file_info.is_need = True file_info.is_need = True
with open(os.path.join(self.module_path, CHECK_INFO_PB_FILENAME), with open(self.pb_path, "wb") as file:
"wb") as fi: file.write(check_info.SerializeToString())
fi.write(check_info.SerializeToString())
@property
def module_code_version(self):
return self.check_info.module_code_version
@property @property
def module_proto_version(self): def module_proto_version(self):
...@@ -82,20 +87,25 @@ class ModuleChecker(object): ...@@ -82,20 +87,25 @@ class ModuleChecker(object):
def file_infos(self): def file_infos(self):
return self.check_info.file_infos return self.check_info.file_infos
@property
def directory(self):
return self._directory
@property
def pb_path(self):
return self._pb_path
def check(self): def check(self):
result = True result = True
self.check_info_pb_path = os.path.join(self.module_path,
CHECK_INFO_PB_FILENAME)
if not (os.path.exists(self.check_info_pb_path) if not (os.path.exists(self.pb_path) or os.path.isfile(self.pb_path)):
or os.path.isfile(self.check_info_pb_path)):
logger.warning( logger.warning(
"This module lacks core file %s" % CHECK_INFO_PB_FILENAME) "This module lacks core file %s" % CHECK_INFO_PB_FILENAME)
result = False result = False
self.check_info = check_info_pb2.CheckInfo() self.check_info = check_info_pb2.CheckInfo()
try: try:
with open(self.check_info_pb_path, "rb") as fi: with open(self.pb_path, "rb") as fi:
pb_string = fi.read() pb_string = fi.read()
result = self.check_info.ParseFromString(pb_string) result = self.check_info.ParseFromString(pb_string)
if len(pb_string) == 0 or (result is not None if len(pb_string) == 0 or (result is not None
...@@ -182,7 +192,7 @@ class ModuleChecker(object): ...@@ -182,7 +192,7 @@ class ModuleChecker(object):
for file_info in self.file_infos: for file_info in self.file_infos:
file_type = file_info.type file_type = file_info.type
file_path = file_info.file_name.replace(FILE_SEP, os.sep) file_path = file_info.file_name.replace(FILE_SEP, os.sep)
file_path = os.path.join(self.module_path, file_path) file_path = os.path.join(self.directory, file_path)
if not os.path.exists(file_path): if not os.path.exists(file_path):
if file_info.is_need: if file_info.is_need:
logger.warning( logger.warning(
......
...@@ -19,6 +19,7 @@ from __future__ import print_function ...@@ -19,6 +19,7 @@ from __future__ import print_function
import os import os
import shutil import shutil
import tarfile
from paddlehub.common import utils from paddlehub.common import utils
from paddlehub.common import srv_utils from paddlehub.common import srv_utils
...@@ -77,15 +78,76 @@ class LocalModuleManager(object): ...@@ -77,15 +78,76 @@ class LocalModuleManager(object):
return self.modules_dict.get(module_name, None) return self.modules_dict.get(module_name, None)
def install_module(self, def install_module(self,
module_name, module_name=None,
module_dir=None,
module_package=None,
module_version=None, module_version=None,
upgrade=False, upgrade=False,
extra=None): extra=None):
self.all_modules(update=True) md5_value = installed_module_version = None
module_info = self.modules_dict.get(module_name, None) from_user_dir = True if module_dir else False
if module_info: if module_name:
if not module_version or module_version == self.modules_dict[ self.all_modules(update=True)
module_name][1]: module_info = self.modules_dict.get(module_name, None)
if module_info:
if not module_version or module_version == self.modules_dict[
module_name][1]:
module_dir = self.modules_dict[module_name][0]
module_tag = module_name if not module_version else '%s-%s' % (
module_name, module_version)
tips = "Module %s already installed in %s" % (module_tag,
module_dir)
return True, tips, self.modules_dict[module_name]
search_result = hub.default_hub_server.get_module_url(
module_name, version=module_version, extra=extra)
name = search_result.get('name', None)
url = search_result.get('url', None)
md5_value = search_result.get('md5', None)
installed_module_version = search_result.get('version', None)
if not url or (module_version is not None
and installed_module_version != module_version) or (
name != module_name):
if default_hub_server._server_check() is False:
tips = "Request Hub-Server unsuccessfully, please check your network."
else:
tips = "Can't find module %s" % module_name
if module_version:
tips += " with version %s" % module_version
module_tag = module_name if not module_version else '%s-%s' % (
module_name, module_version)
return False, tips, None
result, tips, module_zip_file = default_downloader.download_file(
url=url,
save_path=hub.CACHE_HOME,
save_name=module_name,
replace=True,
print_progress=True)
result, tips, module_dir = default_downloader.uncompress(
file=module_zip_file,
dirname=MODULE_HOME,
delete_file=True,
print_progress=True)
if module_package:
with tarfile.open(module_package, "r:gz") as tar:
file_names = tar.getnames()
size = len(file_names) - 1
module_dir = os.path.split(file_names[0])[0]
module_dir = os.path.join(hub.CACHE_HOME, module_dir)
# remove cache
if os.path.exists(module_dir):
shutil.rmtree(module_dir)
for index, file_name in enumerate(file_names):
tar.extract(file_name, hub.CACHE_HOME)
if module_dir:
if not module_name:
module_name = hub.Module(directory=module_dir).name
self.all_modules(update=False)
module_info = self.modules_dict.get(module_name, None)
if module_info:
module_dir = self.modules_dict[module_name][0] module_dir = self.modules_dict[module_name][0]
module_tag = module_name if not module_version else '%s-%s' % ( module_tag = module_name if not module_version else '%s-%s' % (
module_name, module_version) module_name, module_version)
...@@ -93,44 +155,18 @@ class LocalModuleManager(object): ...@@ -93,44 +155,18 @@ class LocalModuleManager(object):
module_dir) module_dir)
return True, tips, self.modules_dict[module_name] return True, tips, self.modules_dict[module_name]
search_result = hub.default_hub_server.get_module_url( if md5_value:
module_name, version=module_version, extra=extra) with open(
name = search_result.get('name', None) os.path.join(MODULE_HOME, module_dir, "md5.txt"),
url = search_result.get('url', None) "w") as fp:
md5_value = search_result.get('md5', None) fp.write(md5_value)
installed_module_version = search_result.get('version', None)
if not url or (module_version is not None and installed_module_version
!= module_version) or (name != module_name):
if default_hub_server._server_check() is False:
tips = "Request Hub-Server unsuccessfully, please check your network."
else:
tips = "Can't find module %s" % module_name
if module_version:
tips += " with version %s" % module_version
module_tag = module_name if not module_version else '%s-%s' % (
module_name, module_version)
return False, tips, None
result, tips, module_zip_file = default_downloader.download_file(
url=url,
save_path=hub.CACHE_HOME,
save_name=module_name,
replace=True,
print_progress=True)
result, tips, module_dir = default_downloader.uncompress(
file=module_zip_file,
dirname=MODULE_HOME,
delete_file=True,
print_progress=True)
if module_dir:
with open(os.path.join(MODULE_HOME, module_dir, "md5.txt"),
"w") as fp:
fp.write(md5_value)
save_path = os.path.join(MODULE_HOME, module_name) save_path = os.path.join(MODULE_HOME, module_name)
if os.path.exists(save_path): if os.path.exists(save_path):
shutil.rmtree(save_path) shutil.move(save_path)
shutil.move(module_dir, save_path) if from_user_dir:
shutil.copytree(module_dir, save_path)
else:
shutil.move(module_dir, save_path)
module_dir = save_path module_dir = save_path
tips = "Successfully installed %s" % module_name tips = "Successfully installed %s" % module_name
if installed_module_version: if installed_module_version:
......
此差异已折叠。
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册