提交 d8ed5400 编写于 作者: W wuzewu

add default signature

上级 33a0fc45
......@@ -48,7 +48,7 @@ def create_module(args):
# create paddle hub module
assets = ["resources/label_list.txt"]
sign1 = hub.create_signature(
"classification", inputs=[image], outputs=[predition])
"classification", inputs=[image], outputs=[predition], for_predict=True)
sign2 = hub.create_signature(
"feature_map", inputs=[image], outputs=[feature_map])
hub.create_module(
......
......@@ -37,7 +37,10 @@ def create_module():
# create a module and save as hub_module_lac
sign = hub.create_signature(
name="lexical_analysis", inputs=[word], outputs=[crf_decode])
name="lexical_analysis",
inputs=[word],
outputs=[crf_decode],
for_predict=True)
hub.create_module(
sign_arr=[sign],
module_dir="hub_module_lac",
......
......@@ -42,7 +42,10 @@ def create_module():
# create a module
sign = hub.create_signature(
name="sentiment_classify", inputs=[data], outputs=[pred])
name="sentiment_classify",
inputs=[data],
outputs=[pred],
for_predict=True)
hub.create_module(
sign_arr=[sign],
module_dir="hub_module_senta",
......
import os
import io
import paddle
import paddle.fluid as fluid
import paddle_hub as hub
import numpy as np
import os
import io
from paddle_hub import BaseProcessor
from paddle_hub.hub_server import default_hub_server
from paddle_hub.module.manager import default_module_manager
import paddle_hub as hub
def load_vocab(file_path):
......@@ -37,17 +36,17 @@ def get_predict_label(pos_prob):
return label, key
class Processor(BaseProcessor):
class Processor(hub.BaseProcessor):
def __init__(self, module):
self.module = module
assets_path = self.module.helper.assets_path()
word_dict_path = os.path.join(assets_path, "train.vocab")
self.word_dict = load_vocab(word_dict_path)
path = default_module_manager.search_module("lac")
path = hub.default_module_manager.search_module("lac")
if path:
self.lac = hub.Module(module_dir=path)
else:
result, _, path = default_module_manager.install_module("lac")
result, _, path = hub.default_module_manager.install_module("lac")
assert path, "can't found necessary module lac"
self.lac = hub.Module(module_dir=path)
......
......@@ -41,10 +41,14 @@ def create_module():
assets = ["resources/label_list.txt"]
sign = hub.create_signature(
"object_detection", inputs=[image], outputs=[nmsed_out])
"object_detection",
inputs=[image],
outputs=[nmsed_out],
for_predict=True)
hub.create_module(
sign_arr=[sign],
module_dir="hub_module_ssd",
module_info="resources/module_info.yml",
exe=exe,
processor=processor.Processor,
assets=assets)
......
......@@ -42,29 +42,10 @@ class RunCommand(BaseCommand):
# yapf: disable
self.add_arg('--config', str, None, "config file in yaml format" )
self.add_arg('--dataset', str, None, "dataset be used" )
self.add_arg('--data', str, None, "data be used" )
self.add_arg('--signature', str, None, "signature to run" )
# yapf: enable
def _check_dataset(self):
if not self.args.dataset:
print("Error! Lack of dataset file")
self.help()
exit(1)
if not utils.is_csv_file(self.args.dataset):
print("Error! Dataset file should in csv format")
self.help()
exit(1)
def _check_config(self):
if not self.args.config:
print("Error! Lack of config file")
self.help()
exit(1)
if not utils.is_yaml_file(self.args.config):
print("Error! Config file should in yaml format")
self.help()
exit(1)
def exec(self, argv):
if not argv:
print("ERROR: Please specify a key\n")
......@@ -72,8 +53,6 @@ class RunCommand(BaseCommand):
return False
module_name = argv[0]
self.args = self.parser.parse_args(argv[1:])
self._check_dataset()
self._check_config()
module_dir = default_module_manager.search_module(module_name)
if not module_dir:
......@@ -88,30 +67,52 @@ class RunCommand(BaseCommand):
return False
module = hub.Module(module_dir=module_dir)
yaml_config = yaml_reader.read(self.args.config)
if not module.default_signature:
print("ERROR! Module %s is not callable" % module_name)
if not self.args.signature:
self.args.signature = module.default_signature().name
self.args.signature = module.default_signature.name
# module processor check
module.check_processor()
# data_format check
expect_data_format = module.processor.data_format(self.args.signature)
input_data_format = yaml_config['input_data']
assert len(input_data_format) == len(expect_data_format)
for key, value in expect_data_format.items():
assert key in input_data_format
assert value['type'] == hub.DataType.type(
input_data_format[key]['type'])
# get data dict
origin_data = csv_reader.read(self.args.dataset)
input_data = {}
for key, value in yaml_config['input_data'].items():
input_data[key] = origin_data[value['key']]
if self.args.data:
input_data_key = list(expect_data_format.keys())[0]
origin_data = {input_data_key: [self.args.data]}
elif self.args.dataset:
origin_data = csv_reader.read(self.args.dataset)
else:
print("ERROR! Please specify data to predict")
self.help()
exit(1)
# data_format check
if not self.args.config:
assert len(expect_data_format) == 1
origin_data_key = list(origin_data.keys())[0]
input_data_key = list(expect_data_format.keys())[0]
input_data = {input_data_key: origin_data[origin_data_key]}
config = {}
else:
yaml_config = yaml_reader.read(self.args.config)
if len(expect_data_format) == 1:
origin_data_key = list(origin_data.keys())[0]
input_data_key = list(expect_data_format.keys())[0]
input_data = {input_data_key: origin_data[origin_data_key]}
else:
input_data_format = yaml_config['input_data']
assert len(input_data_format) == len(expect_data_format)
for key, value in expect_data_format.items():
assert key in input_data_format
assert value['type'] == hub.DataType.type(
input_data_format[key]['type'])
input_data = {}
for key, value in yaml_config['input_data'].items():
input_data[key] = origin_data[value['key']]
config = yaml_config.get("config", {})
# run module with data
config = yaml_config.get("config", {})
print(module(sign_name=self.args.signature, data=input_data, **config))
......
......@@ -224,6 +224,8 @@ class Module(object):
for sign in signatures:
if sign.name in self.signatures:
raise "Error! signature array contains repeat signatrue %s" % sign
if self.default_signature is None and sign.for_predict:
self.default_signature = sign
self.signatures[sign.name] = sign
def _recovery_parameter(self, program):
......@@ -308,6 +310,12 @@ class Module(object):
feed_names=feed_names,
fetch_names=fetch_names)
# recover default signature
default_signature_name = utils.from_flexible_data_to_pyobj(
self.desc.extra_info.map.data['default_signature'])
self.default_signature = self.signatures[
default_signature_name] if default_signature_name else None
# recover module info
module_info = self.desc.extra_info.map.data['module_info']
self.name = utils.from_flexible_data_to_pyobj(
......@@ -362,6 +370,11 @@ class Module(object):
fetch_var.var_name = self.get_var_name_with_prefix(output.name)
fetch_var.alias = fetch_names[index]
# save default signature
utils.from_pyobj_to_flexible_data(
self.default_signature.name if self.default_signature else None,
extra_info.map.data['default_signature'])
# save module info
module_info = extra_info.map.data['module_info']
module_info.type = module_desc_pb2.MAP
......@@ -512,15 +525,6 @@ class Module(object):
def get_var_name_with_prefix(self, var_name):
return self.get_name_prefix() + var_name
def parameters(self):
pass
def parameter_attrs(self):
pass
def default_signature(self):
return self.default_signature
def _check_signatures(self):
assert self.signatures, "signature array should not be None"
......
......@@ -20,8 +20,13 @@ from paddle_hub.common.utils import to_list
class Signature:
def __init__(self, name, inputs, outputs, feed_names=None,
fetch_names=None):
def __init__(self,
name,
inputs,
outputs,
feed_names=None,
fetch_names=None,
for_predict=False):
inputs = to_list(inputs)
outputs = to_list(outputs)
......@@ -52,16 +57,19 @@ class Signature:
self.outputs = outputs
self.feed_names = feed_names
self.fetch_names = fetch_names
self.for_predict = for_predict
def create_signature(name="default",
inputs=[],
outputs=[],
feed_names=None,
fetch_names=None):
fetch_names=None,
for_predict=False):
return Signature(
name=name,
inputs=inputs,
outputs=outputs,
feed_names=feed_names,
fetch_names=fetch_names)
fetch_names=fetch_names,
for_predict=for_predict)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册