提交 6669e041 编写于 作者: W wuzewu

Add module v2

上级 5797e613
......@@ -38,7 +38,7 @@ from .common.logger import logger
from .common.paddle_helper import connect_program
from .common.hub_server import default_hub_server
from .module.module import Module, create_module
from .module.module import Module
from .module.base_processor import BaseProcessor
from .module.signature import Signature, create_signature
from .module.manager import default_module_manager
......
......@@ -71,10 +71,10 @@ class RunCommand(BaseCommand):
if not result:
return None
return hub.Module(module_dir=module_dir)
return hub.Module(directory=module_dir[0])
def add_module_config_arg(self):
configs = self.module.processor.configs()
configs = self.module.configs()
for config in configs:
if not config["dest"].startswith("--"):
config["dest"] = "--%s" % config["dest"]
......@@ -104,8 +104,8 @@ class RunCommand(BaseCommand):
def add_module_input_arg(self):
module_type = self.module.type.lower()
expect_data_format = self.module.processor.data_format(
self.module.default_signature.name)
expect_data_format = self.module.data_format(
self.module.default_signature)
self.arg_input_group.add_argument(
'--input_file',
type=str,
......@@ -144,15 +144,15 @@ class RunCommand(BaseCommand):
if self.args.config:
yaml_config = yaml_parser.parse(self.args.config)
module_config = yaml_config.get("config", {})
for _config in self.module.processor.configs():
for _config in self.module.configs():
key = _config['dest']
module_config[key] = self.args.__dict__[key]
return module_config
def get_data(self):
module_type = self.module.type.lower()
expect_data_format = self.module.processor.data_format(
self.module.default_signature.name)
expect_data_format = self.module.data_format(
self.module.default_signature)
input_data = {}
if len(expect_data_format) == 1:
key = list(expect_data_format.keys())[0]
......@@ -176,8 +176,8 @@ class RunCommand(BaseCommand):
return input_data
def check_data(self, data):
expect_data_format = self.module.processor.data_format(
self.module.default_signature.name)
expect_data_format = self.module.data_format(
self.module.default_signature)
if len(data.keys()) != len(expect_data_format.keys()):
print(
......@@ -260,7 +260,7 @@ class RunCommand(BaseCommand):
return False
results = self.module(
sign_name=self.module.default_signature.name,
sign_name=self.module.default_signature,
data=data,
use_gpu=self.args.use_gpu,
batch_size=self.args.batch_size,
......
......@@ -50,6 +50,7 @@ message CheckInfo {
string paddle_version = 1;
string hub_version = 2;
string module_proto_version = 3;
repeated FileInfo file_infos = 4;
repeated Requires requires = 5;
string module_code_version = 4;
repeated FileInfo file_infos = 5;
repeated Requires requires = 6;
};
#coding:utf-8
# Generated by the protocol buffer compiler. DO NOT EDIT!
# source: check_info.proto
......@@ -19,7 +18,7 @@ DESCRIPTOR = _descriptor.FileDescriptor(
package='paddlehub.module.checkinfo',
syntax='proto3',
serialized_pb=_b(
'\n\x10\x63heck_info.proto\x12\x1apaddlehub.module.checkinfo\"\x85\x01\n\x08\x46ileInfo\x12\x11\n\tfile_name\x18\x01 \x01(\t\x12\x33\n\x04type\x18\x02 \x01(\x0e\x32%.paddlehub.module.checkinfo.FILE_TYPE\x12\x0f\n\x07is_need\x18\x03 \x01(\x08\x12\x0b\n\x03md5\x18\x04 \x01(\t\x12\x13\n\x0b\x64\x65scription\x18\x05 \x01(\t\"\x84\x01\n\x08Requires\x12>\n\x0crequire_type\x18\x01 \x01(\x0e\x32(.paddlehub.module.checkinfo.REQUIRE_TYPE\x12\x0f\n\x07version\x18\x02 \x01(\t\x12\x12\n\ngreat_than\x18\x03 \x01(\x08\x12\x13\n\x0b\x64\x65scription\x18\x04 \x01(\t\"\xc8\x01\n\tCheckInfo\x12\x16\n\x0epaddle_version\x18\x01 \x01(\t\x12\x13\n\x0bhub_version\x18\x02 \x01(\t\x12\x1c\n\x14module_proto_version\x18\x03 \x01(\t\x12\x38\n\nfile_infos\x18\x04 \x03(\x0b\x32$.paddlehub.module.checkinfo.FileInfo\x12\x36\n\x08requires\x18\x05 \x03(\x0b\x32$.paddlehub.module.checkinfo.Requires*\x1e\n\tFILE_TYPE\x12\x08\n\x04\x46ILE\x10\x00\x12\x07\n\x03\x44IR\x10\x01*[\n\x0cREQUIRE_TYPE\x12\x12\n\x0ePYTHON_PACKAGE\x10\x00\x12\x0e\n\nHUB_MODULE\x10\x01\x12\n\n\x06SYSTEM\x10\x02\x12\x0b\n\x07\x43OMMAND\x10\x03\x12\x0e\n\nPY_VERSION\x10\x04\x42\x02H\x03\x62\x06proto3'
'\n\x10\x63heck_info.proto\x12\x1apaddlehub.module.checkinfo\"\x85\x01\n\x08\x46ileInfo\x12\x11\n\tfile_name\x18\x01 \x01(\t\x12\x33\n\x04type\x18\x02 \x01(\x0e\x32%.paddlehub.module.checkinfo.FILE_TYPE\x12\x0f\n\x07is_need\x18\x03 \x01(\x08\x12\x0b\n\x03md5\x18\x04 \x01(\t\x12\x13\n\x0b\x64\x65scription\x18\x05 \x01(\t\"\x84\x01\n\x08Requires\x12>\n\x0crequire_type\x18\x01 \x01(\x0e\x32(.paddlehub.module.checkinfo.REQUIRE_TYPE\x12\x0f\n\x07version\x18\x02 \x01(\t\x12\x12\n\ngreat_than\x18\x03 \x01(\x08\x12\x13\n\x0b\x64\x65scription\x18\x04 \x01(\t\"\xe5\x01\n\tCheckInfo\x12\x16\n\x0epaddle_version\x18\x01 \x01(\t\x12\x13\n\x0bhub_version\x18\x02 \x01(\t\x12\x1c\n\x14module_proto_version\x18\x03 \x01(\t\x12\x1b\n\x13module_code_version\x18\x04 \x01(\t\x12\x38\n\nfile_infos\x18\x05 \x03(\x0b\x32$.paddlehub.module.checkinfo.FileInfo\x12\x36\n\x08requires\x18\x06 \x03(\x0b\x32$.paddlehub.module.checkinfo.Requires*\x1e\n\tFILE_TYPE\x12\x08\n\x04\x46ILE\x10\x00\x12\x07\n\x03\x44IR\x10\x01*[\n\x0cREQUIRE_TYPE\x12\x12\n\x0ePYTHON_PACKAGE\x10\x00\x12\x0e\n\nHUB_MODULE\x10\x01\x12\n\n\x06SYSTEM\x10\x02\x12\x0b\n\x07\x43OMMAND\x10\x03\x12\x0e\n\nPY_VERSION\x10\x04\x42\x02H\x03\x62\x06proto3'
))
_sym_db.RegisterFileDescriptor(DESCRIPTOR)
......@@ -36,8 +35,8 @@ _FILE_TYPE = _descriptor.EnumDescriptor(
],
containing_type=None,
options=None,
serialized_start=522,
serialized_end=552,
serialized_start=551,
serialized_end=581,
)
_sym_db.RegisterEnumDescriptor(_FILE_TYPE)
......@@ -61,8 +60,8 @@ _REQUIRE_TYPE = _descriptor.EnumDescriptor(
],
containing_type=None,
options=None,
serialized_start=554,
serialized_end=645,
serialized_start=583,
serialized_end=674,
)
_sym_db.RegisterEnumDescriptor(_REQUIRE_TYPE)
......@@ -316,10 +315,26 @@ _CHECKINFO = _descriptor.Descriptor(
extension_scope=None,
options=None),
_descriptor.FieldDescriptor(
name='file_infos',
full_name='paddlehub.module.checkinfo.CheckInfo.file_infos',
name='module_code_version',
full_name='paddlehub.module.checkinfo.CheckInfo.module_code_version',
index=3,
number=4,
type=9,
cpp_type=9,
label=1,
has_default_value=False,
default_value=_b("").decode('utf-8'),
message_type=None,
enum_type=None,
containing_type=None,
is_extension=False,
extension_scope=None,
options=None),
_descriptor.FieldDescriptor(
name='file_infos',
full_name='paddlehub.module.checkinfo.CheckInfo.file_infos',
index=4,
number=5,
type=11,
cpp_type=10,
label=3,
......@@ -334,8 +349,8 @@ _CHECKINFO = _descriptor.Descriptor(
_descriptor.FieldDescriptor(
name='requires',
full_name='paddlehub.module.checkinfo.CheckInfo.requires',
index=4,
number=5,
index=5,
number=6,
type=11,
cpp_type=10,
label=3,
......@@ -357,7 +372,7 @@ _CHECKINFO = _descriptor.Descriptor(
extension_ranges=[],
oneofs=[],
serialized_start=320,
serialized_end=520,
serialized_end=549,
)
_FILEINFO.fields_by_name['type'].enum_type = _FILE_TYPE
......
......@@ -32,20 +32,22 @@ FILE_SEP = "/"
class ModuleChecker(object):
def __init__(self, module_path):
self.module_path = module_path
def __init__(self, directory):
self._directory = directory
self._pb_path = os.path.join(self.directory, CHECK_INFO_PB_FILENAME)
def generate_check_info(self):
check_info = check_info_pb2.CheckInfo()
check_info.paddle_version = paddle.__version__
check_info.hub_version = hub_version
check_info.module_proto_version = module_proto_version
check_info.module_code_version = "v2"
file_infos = check_info.file_infos
file_list = [file for file in os.listdir(self.module_path)]
file_list = [file for file in os.listdir(self.directory)]
while file_list:
file = file_list[0]
file_list = file_list[1:]
abs_path = os.path.join(self.module_path, file)
abs_path = os.path.join(self.directory, file)
if os.path.isdir(abs_path):
for sub_file in os.listdir(abs_path):
sub_file = os.path.join(file, sub_file)
......@@ -62,9 +64,12 @@ class ModuleChecker(object):
file_info.type = check_info_pb2.FILE
file_info.is_need = True
with open(os.path.join(self.module_path, CHECK_INFO_PB_FILENAME),
"wb") as fi:
fi.write(check_info.SerializeToString())
with open(self.pb_path, "wb") as file:
file.write(check_info.SerializeToString())
@property
def module_code_version(self):
return self.check_info.module_code_version
@property
def module_proto_version(self):
......@@ -82,20 +87,25 @@ class ModuleChecker(object):
def file_infos(self):
return self.check_info.file_infos
@property
def directory(self):
return self._directory
@property
def pb_path(self):
return self._pb_path
def check(self):
result = True
self.check_info_pb_path = os.path.join(self.module_path,
CHECK_INFO_PB_FILENAME)
if not (os.path.exists(self.check_info_pb_path)
or os.path.isfile(self.check_info_pb_path)):
if not (os.path.exists(self.pb_path) or os.path.isfile(self.pb_path)):
logger.warning(
"This module lacks core file %s" % CHECK_INFO_PB_FILENAME)
result = False
self.check_info = check_info_pb2.CheckInfo()
try:
with open(self.check_info_pb_path, "rb") as fi:
with open(self.pb_path, "rb") as fi:
pb_string = fi.read()
result = self.check_info.ParseFromString(pb_string)
if len(pb_string) == 0 or (result is not None
......@@ -182,7 +192,7 @@ class ModuleChecker(object):
for file_info in self.file_infos:
file_type = file_info.type
file_path = file_info.file_name.replace(FILE_SEP, os.sep)
file_path = os.path.join(self.module_path, file_path)
file_path = os.path.join(self.directory, file_path)
if not os.path.exists(file_path):
if file_info.is_need:
logger.warning(
......
此差异已折叠。
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册