提交 467ce09a 编写于 作者: W wuzewu

add processor dump func

上级 89213e05
......@@ -17,6 +17,7 @@ from . import tools
from . import data
from .paddle_extend import regularizer
from .module.module import Module, create_module
from .module.base_processor import BaseProcessor
from .module.signature import Signature, create_signature
from .tools.logger import logger
from .tools.paddle_helper import connect_program
......
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
class BaseProcessor(object):
    """Base class for the processor attached to a hub module.

    Every method except ``__init__`` raises ``NotImplementedError``;
    subclasses are expected to override ``reader``, ``postprocess`` and
    ``data_format``.

    The class derives from ``object`` explicitly so that it is a new-style
    class on Python 2 as well, which lets subclasses rely on ``super()``
    and descriptors.
    """

    def __init__(self):
        pass

    def reader(self, sign_name, data_dict):
        """Hook to be overridden by subclasses; always raises here.

        Args:
            sign_name: presumably the name of the signature being run --
                TODO confirm against callers.
            data_dict: presumably the raw input data -- TODO confirm.

        Raises:
            NotImplementedError: always, in this base implementation.
        """
        raise NotImplementedError("BaseProcessor' reader should not be call!")

    def postprocess(self, sign_name, data_out, config):
        """Hook to be overridden by subclasses; always raises here.

        Args:
            sign_name: presumably the name of the signature being run --
                TODO confirm against callers.
            data_out: presumably the raw output to post-process -- TODO
                confirm.
            config: presumably caller-supplied options -- TODO confirm.

        Raises:
            NotImplementedError: always, in this base implementation.
        """
        raise NotImplementedError(
            "BaseProcessor' postprocess should not be call!")

    def data_format(self, sign_name):
        """Hook to be overridden by subclasses; always raises here.

        Args:
            sign_name: presumably the name of the signature being
                described -- TODO confirm against callers.

        Raises:
            NotImplementedError: always, in this base implementation.
        """
        raise NotImplementedError(
            "BaseProcessor' data_format should not be call!")
......@@ -21,7 +21,9 @@ from paddle_hub.tools import downloader
from paddle_hub.tools import paddle_helper
from paddle_hub.module import module_desc_pb2
from paddle_hub.module.signature import Signature, create_signature
from paddle_hub.data.reader import yaml_reader
from paddle_hub import version
from paddle_hub.module.base_processor import BaseProcessor
import os
import functools
import paddle
......@@ -30,9 +32,18 @@ import paddle.fluid as fluid
__all__ = ['Module', 'create_module']
def create_module(sign_arr,
                  module_dir,
                  processor,
                  assets=None,
                  module_info=None,
                  exe=None):
    """Build a ``Module`` from signatures and serialize it to ``module_dir``.

    Args:
        sign_arr: a signature or a collection of signatures; normalized
            with ``utils.to_list`` before use.
        module_dir: path handed to ``Module.serialize_to_path``.
        processor: forwarded to ``Module`` as its ``processor`` argument.
        assets: forwarded unchanged to ``Module``.
        module_info: forwarded unchanged to ``Module``.
        exe: forwarded unchanged to ``Module.serialize_to_path``.
    """
    sign_arr = utils.to_list(sign_arr)
    module = Module(
        signatures=sign_arr,
        processor=processor,
        assets=assets,
        module_info=module_info)
    module.serialize_to_path(path=module_dir, exe=exe)
......@@ -76,21 +87,32 @@ class ModuleHelper:
class Module:
def __init__(self,
             url=None,
             module_dir=None,
             signatures=None,
             module_info=None,
             assets=None,
             processor=None):
    """Initialize the module from a URL, a module directory or signatures.

    Exactly one source is used, checked in this order: ``url``, then
    ``module_dir``, then ``signatures``.

    Args:
        url: passed to ``_init_with_url``.
        module_dir: passed to ``_init_with_module_file``.
        signatures: passed to ``_init_with_signature``; this path also
            requires ``processor``.
        module_info: passed to ``_generate_module_info`` on the
            ``signatures`` path only.
        assets: accepted for API compatibility.
        processor: a subclass of ``BaseProcessor`` (the class itself, as
            it is checked with ``issubclass``); required on the
            ``signatures`` path.

    Raises:
        AssertionError: on the ``signatures`` path when ``processor`` is
            missing or is not a subclass of ``BaseProcessor``.
        TypeError: when none of ``url``, ``module_dir`` or ``signatures``
            is given.
    """
    self.desc = module_desc_pb2.ModuleDesc()
    self.program = None
    # NOTE(review): the ``assets`` argument is not stored here; the
    # attribute always starts out empty -- confirm whether that is intended.
    self.assets = []
    self.helper = None
    self.signatures = {}
    self.default_signature = None
    self.module_info = None
    self.processor = None

    if url:
        self._init_with_url(url=url)
    elif module_dir:
        self._init_with_module_file(module_dir=module_dir)
    elif signatures:
        assert processor, "lack of module processor"
        assert issubclass(
            processor, BaseProcessor
        ), "processor should be sub class of hub.BaseProcessor"
        self.processor = processor
        self._generate_module_info(module_info)
        self._init_with_signature(signatures=signatures)
    else:
        # Raising a bare string is itself a TypeError ("exceptions must
        # derive from BaseException") and hides the message, so raise a
        # real TypeError that carries the intended text.
        raise TypeError("Error! HubModule Can't init with nothing")
......@@ -100,6 +122,17 @@ class Module:
module_dir = downloader.download_and_uncompress(module_url)
self._init_with_module_file(module_dir)
def _dump_processor(self):
    """Save the source of the processor's defining module to disk.

    The whole Python module that ``self.processor`` belongs to is written
    (not only the processor class), to
    ``<helper.processor_path()>/<helper.processor_name()>.py``. The target
    directory is created via ``utils.mkdir`` before the file is opened.
    """
    import inspect

    source_module = inspect.getmodule(self.processor)
    source_code = inspect.getsource(source_module)

    target_dir = self.helper.processor_path()
    target_file = os.path.join(target_dir,
                               self.helper.processor_name() + ".py")

    utils.mkdir(target_dir)
    with open(target_file, "w") as out:
        out.write(source_code)
def _load_processor(self):
import sys
processor_path = self.helper.processor_path()
......@@ -118,29 +151,7 @@ class Module:
self._recovery_parameter(self.program)
self._recover_variable_info(self.program)
self._load_processor()
inputs = []
outputs = []
feed_names = []
fetch_names = []
for sign, module_var in self.desc.sign2var.items():
for var in module_var.feed_desc:
variable = self.program.global_block().vars[var.var_name]
inputs.append(variable)
feed_names.append(var.alias)
for var in module_var.fetch_desc:
variable = self.program.global_block().vars[var.var_name]
outputs.append(variable)
fetch_names.append(var.alias)
self.signatures[sign] = create_signature(
sign,
inputs=inputs,
outputs=outputs,
feed_names=feed_names,
fetch_names=fetch_names)
self._recover_from_desc()
self._generate_sign_attr()
def _init_with_signature(self, signatures):
......@@ -192,12 +203,74 @@ class Module:
var = block.vars[var_name]
var.stop_gradient = stop_gradient
def _generate_module_info(self, module_info=None):
if not module_info:
self.module_info = {}
else:
if not utils.is_yaml_file(module_info):
logger.critical("module info file should in yaml format")
exit(1)
module_info = yaml_reader.read(module_info)
self.author = module_info.get('author', 'UNKNOWN')
self.author_email = module_info.get('author_email', 'UNKNOWN')
self.summary = module_info.get('summary', 'UNKNOWN')
self.type = module_info.get('type', 'UNKNOWN')
self.version = module_info.get('version', 'UNKNOWN')
self.name = module_info.get('name', 'UNKNOWN')
# self.author = module_info['author'] if 'author' in module_info else "UNKNOWN"
# self.author_email = module_info['author_email'] if 'author_email' in module_info else "UNKNOWN"
# self.summary = module_info['summary'] if 'summary' in module_info else "UNKNOWN"
# self.type = module_info['type'] if 'type' in module_info else "UNKNOWN"
# self.version = module_info['version'] if 'version' in module_info else "UNKNOWN"
# self.name = module_info['name'] if 'name' in module_info else "UNKNOWN"
def _generate_sign_attr(self):
self._check_signatures()
for sign in self.signatures:
self.__dict__[sign] = functools.partial(
self.__call__, sign_name=sign)
def _recover_from_desc(self):
    """Rebuild signatures and module info from ``self.desc``.

    For every entry of ``self.desc.sign2var`` the feed/fetch variables are
    looked up by name in the global block of ``self.program`` and a
    signature is created from them together with their aliases. The
    ``name``, ``author``, ``author_email``, ``version``, ``type`` and
    ``summary`` attributes are then restored from the ``module_info``
    entry of ``self.desc.extra_info``.
    """

    def _resolve(var_desc):
        # Map a stored variable description to the live program variable.
        return self.program.global_block().vars[var_desc.var_name]

    # recover signature
    for sign_name, sign_vars in self.desc.sign2var.items():
        feed_vars, feed_aliases = [], []
        for feed in sign_vars.feed_desc:
            feed_vars.append(_resolve(feed))
            feed_aliases.append(feed.alias)

        fetch_vars, fetch_aliases = [], []
        for fetch in sign_vars.fetch_desc:
            fetch_vars.append(_resolve(fetch))
            fetch_aliases.append(fetch.alias)

        self.signatures[sign_name] = create_signature(
            sign_name,
            inputs=feed_vars,
            outputs=fetch_vars,
            feed_names=feed_aliases,
            fetch_names=fetch_aliases)

    # recover module info
    stored_info = self.desc.extra_info.map.data['module_info']
    for field in ('name', 'author', 'author_email', 'version', 'type',
                  'summary'):
        value = utils.from_flexible_data_to_pyobj(
            stored_info.map.data[field])
        setattr(self, field, value)
def _generate_desc(self):
# save fluid Parameter
extra_info = self.desc.extra_info
......@@ -237,6 +310,22 @@ class Module:
fetch_var.var_name = HUB_VAR_PREFIX + output.name
fetch_var.alias = fetch_names[index]
# save module info
module_info = extra_info.map.data['module_info']
module_info.type = module_desc_pb2.MAP
utils.from_pyobj_to_flexible_data(self.name,
module_info.map.data['name'])
utils.from_pyobj_to_flexible_data(self.version,
module_info.map.data['version'])
utils.from_pyobj_to_flexible_data(self.author,
module_info.map.data['author'])
utils.from_pyobj_to_flexible_data(self.author_email,
module_info.map.data['author_email'])
utils.from_pyobj_to_flexible_data(self.type,
module_info.map.data['type'])
utils.from_pyobj_to_flexible_data(self.summary,
module_info.map.data['summary'])
def __call__(self, sign_name, data, config=None):
feed_dict, fetch_dict, program = self.context(sign_name)
#TODO(wuzewu): more option
......@@ -383,3 +472,6 @@ class Module:
module_pb = self.desc.SerializeToString()
with open(self.helper.module_desc_path(), "wb") as f:
f.write(module_pb)
# create processor file
self._dump_processor()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册