提交 d8ed5400 编写于 作者: W wuzewu

add default signature

上级 33a0fc45
...@@ -48,7 +48,7 @@ def create_module(args): ...@@ -48,7 +48,7 @@ def create_module(args):
# create paddle hub module # create paddle hub module
assets = ["resources/label_list.txt"] assets = ["resources/label_list.txt"]
sign1 = hub.create_signature( sign1 = hub.create_signature(
"classification", inputs=[image], outputs=[predition]) "classification", inputs=[image], outputs=[predition], for_predict=True)
sign2 = hub.create_signature( sign2 = hub.create_signature(
"feature_map", inputs=[image], outputs=[feature_map]) "feature_map", inputs=[image], outputs=[feature_map])
hub.create_module( hub.create_module(
......
...@@ -37,7 +37,10 @@ def create_module(): ...@@ -37,7 +37,10 @@ def create_module():
# create a module and save as hub_module_lac # create a module and save as hub_module_lac
sign = hub.create_signature( sign = hub.create_signature(
name="lexical_analysis", inputs=[word], outputs=[crf_decode]) name="lexical_analysis",
inputs=[word],
outputs=[crf_decode],
for_predict=True)
hub.create_module( hub.create_module(
sign_arr=[sign], sign_arr=[sign],
module_dir="hub_module_lac", module_dir="hub_module_lac",
......
...@@ -42,7 +42,10 @@ def create_module(): ...@@ -42,7 +42,10 @@ def create_module():
# create a module # create a module
sign = hub.create_signature( sign = hub.create_signature(
name="sentiment_classify", inputs=[data], outputs=[pred]) name="sentiment_classify",
inputs=[data],
outputs=[pred],
for_predict=True)
hub.create_module( hub.create_module(
sign_arr=[sign], sign_arr=[sign],
module_dir="hub_module_senta", module_dir="hub_module_senta",
......
import os
import io
import paddle import paddle
import paddle.fluid as fluid import paddle.fluid as fluid
import paddle_hub as hub
import numpy as np import numpy as np
import os
import io import paddle_hub as hub
from paddle_hub import BaseProcessor
from paddle_hub.hub_server import default_hub_server
from paddle_hub.module.manager import default_module_manager
def load_vocab(file_path): def load_vocab(file_path):
...@@ -37,17 +36,17 @@ def get_predict_label(pos_prob): ...@@ -37,17 +36,17 @@ def get_predict_label(pos_prob):
return label, key return label, key
class Processor(BaseProcessor): class Processor(hub.BaseProcessor):
def __init__(self, module): def __init__(self, module):
self.module = module self.module = module
assets_path = self.module.helper.assets_path() assets_path = self.module.helper.assets_path()
word_dict_path = os.path.join(assets_path, "train.vocab") word_dict_path = os.path.join(assets_path, "train.vocab")
self.word_dict = load_vocab(word_dict_path) self.word_dict = load_vocab(word_dict_path)
path = default_module_manager.search_module("lac") path = hub.default_module_manager.search_module("lac")
if path: if path:
self.lac = hub.Module(module_dir=path) self.lac = hub.Module(module_dir=path)
else: else:
result, _, path = default_module_manager.install_module("lac") result, _, path = hub.default_module_manager.install_module("lac")
assert path, "can't found necessary module lac" assert path, "can't found necessary module lac"
self.lac = hub.Module(module_dir=path) self.lac = hub.Module(module_dir=path)
......
...@@ -41,10 +41,14 @@ def create_module(): ...@@ -41,10 +41,14 @@ def create_module():
assets = ["resources/label_list.txt"] assets = ["resources/label_list.txt"]
sign = hub.create_signature( sign = hub.create_signature(
"object_detection", inputs=[image], outputs=[nmsed_out]) "object_detection",
inputs=[image],
outputs=[nmsed_out],
for_predict=True)
hub.create_module( hub.create_module(
sign_arr=[sign], sign_arr=[sign],
module_dir="hub_module_ssd", module_dir="hub_module_ssd",
module_info="resources/module_info.yml",
exe=exe, exe=exe,
processor=processor.Processor, processor=processor.Processor,
assets=assets) assets=assets)
......
...@@ -42,29 +42,10 @@ class RunCommand(BaseCommand): ...@@ -42,29 +42,10 @@ class RunCommand(BaseCommand):
# yapf: disable # yapf: disable
self.add_arg('--config', str, None, "config file in yaml format" ) self.add_arg('--config', str, None, "config file in yaml format" )
self.add_arg('--dataset', str, None, "dataset be used" ) self.add_arg('--dataset', str, None, "dataset be used" )
self.add_arg('--data', str, None, "data be used" )
self.add_arg('--signature', str, None, "signature to run" ) self.add_arg('--signature', str, None, "signature to run" )
# yapf: enable # yapf: enable
def _check_dataset(self):
if not self.args.dataset:
print("Error! Lack of dataset file")
self.help()
exit(1)
if not utils.is_csv_file(self.args.dataset):
print("Error! Dataset file should in csv format")
self.help()
exit(1)
def _check_config(self):
if not self.args.config:
print("Error! Lack of config file")
self.help()
exit(1)
if not utils.is_yaml_file(self.args.config):
print("Error! Config file should in yaml format")
self.help()
exit(1)
def exec(self, argv): def exec(self, argv):
if not argv: if not argv:
print("ERROR: Please specify a key\n") print("ERROR: Please specify a key\n")
...@@ -72,8 +53,6 @@ class RunCommand(BaseCommand): ...@@ -72,8 +53,6 @@ class RunCommand(BaseCommand):
return False return False
module_name = argv[0] module_name = argv[0]
self.args = self.parser.parse_args(argv[1:]) self.args = self.parser.parse_args(argv[1:])
self._check_dataset()
self._check_config()
module_dir = default_module_manager.search_module(module_name) module_dir = default_module_manager.search_module(module_name)
if not module_dir: if not module_dir:
...@@ -88,15 +67,40 @@ class RunCommand(BaseCommand): ...@@ -88,15 +67,40 @@ class RunCommand(BaseCommand):
return False return False
module = hub.Module(module_dir=module_dir) module = hub.Module(module_dir=module_dir)
yaml_config = yaml_reader.read(self.args.config) if not module.default_signature:
print("ERROR! Module %s is not callable" % module_name)
if not self.args.signature: if not self.args.signature:
self.args.signature = module.default_signature().name self.args.signature = module.default_signature.name
# module processor check # module processor check
module.check_processor() module.check_processor()
# data_format check
expect_data_format = module.processor.data_format(self.args.signature) expect_data_format = module.processor.data_format(self.args.signature)
# get data dict
if self.args.data:
input_data_key = list(expect_data_format.keys())[0]
origin_data = {input_data_key: [self.args.data]}
elif self.args.dataset:
origin_data = csv_reader.read(self.args.dataset)
else:
print("ERROR! Please specify data to predict")
self.help()
exit(1)
# data_format check
if not self.args.config:
assert len(expect_data_format) == 1
origin_data_key = list(origin_data.keys())[0]
input_data_key = list(expect_data_format.keys())[0]
input_data = {input_data_key: origin_data[origin_data_key]}
config = {}
else:
yaml_config = yaml_reader.read(self.args.config)
if len(expect_data_format) == 1:
origin_data_key = list(origin_data.keys())[0]
input_data_key = list(expect_data_format.keys())[0]
input_data = {input_data_key: origin_data[origin_data_key]}
else:
input_data_format = yaml_config['input_data'] input_data_format = yaml_config['input_data']
assert len(input_data_format) == len(expect_data_format) assert len(input_data_format) == len(expect_data_format)
for key, value in expect_data_format.items(): for key, value in expect_data_format.items():
...@@ -104,14 +108,11 @@ class RunCommand(BaseCommand): ...@@ -104,14 +108,11 @@ class RunCommand(BaseCommand):
assert value['type'] == hub.DataType.type( assert value['type'] == hub.DataType.type(
input_data_format[key]['type']) input_data_format[key]['type'])
# get data dict
origin_data = csv_reader.read(self.args.dataset)
input_data = {} input_data = {}
for key, value in yaml_config['input_data'].items(): for key, value in yaml_config['input_data'].items():
input_data[key] = origin_data[value['key']] input_data[key] = origin_data[value['key']]
# run module with data
config = yaml_config.get("config", {}) config = yaml_config.get("config", {})
# run module with data
print(module(sign_name=self.args.signature, data=input_data, **config)) print(module(sign_name=self.args.signature, data=input_data, **config))
......
...@@ -224,6 +224,8 @@ class Module(object): ...@@ -224,6 +224,8 @@ class Module(object):
for sign in signatures: for sign in signatures:
if sign.name in self.signatures: if sign.name in self.signatures:
raise "Error! signature array contains repeat signatrue %s" % sign raise "Error! signature array contains repeat signatrue %s" % sign
if self.default_signature is None and sign.for_predict:
self.default_signature = sign
self.signatures[sign.name] = sign self.signatures[sign.name] = sign
def _recovery_parameter(self, program): def _recovery_parameter(self, program):
...@@ -308,6 +310,12 @@ class Module(object): ...@@ -308,6 +310,12 @@ class Module(object):
feed_names=feed_names, feed_names=feed_names,
fetch_names=fetch_names) fetch_names=fetch_names)
# recover default signature
default_signature_name = utils.from_flexible_data_to_pyobj(
self.desc.extra_info.map.data['default_signature'])
self.default_signature = self.signatures[
default_signature_name] if default_signature_name else None
# recover module info # recover module info
module_info = self.desc.extra_info.map.data['module_info'] module_info = self.desc.extra_info.map.data['module_info']
self.name = utils.from_flexible_data_to_pyobj( self.name = utils.from_flexible_data_to_pyobj(
...@@ -362,6 +370,11 @@ class Module(object): ...@@ -362,6 +370,11 @@ class Module(object):
fetch_var.var_name = self.get_var_name_with_prefix(output.name) fetch_var.var_name = self.get_var_name_with_prefix(output.name)
fetch_var.alias = fetch_names[index] fetch_var.alias = fetch_names[index]
# save default signature
utils.from_pyobj_to_flexible_data(
self.default_signature.name if self.default_signature else None,
extra_info.map.data['default_signature'])
# save module info # save module info
module_info = extra_info.map.data['module_info'] module_info = extra_info.map.data['module_info']
module_info.type = module_desc_pb2.MAP module_info.type = module_desc_pb2.MAP
...@@ -512,15 +525,6 @@ class Module(object): ...@@ -512,15 +525,6 @@ class Module(object):
def get_var_name_with_prefix(self, var_name): def get_var_name_with_prefix(self, var_name):
return self.get_name_prefix() + var_name return self.get_name_prefix() + var_name
def parameters(self):
pass
def parameter_attrs(self):
pass
def default_signature(self):
return self.default_signature
def _check_signatures(self): def _check_signatures(self):
assert self.signatures, "signature array should not be None" assert self.signatures, "signature array should not be None"
......
...@@ -20,8 +20,13 @@ from paddle_hub.common.utils import to_list ...@@ -20,8 +20,13 @@ from paddle_hub.common.utils import to_list
class Signature: class Signature:
def __init__(self, name, inputs, outputs, feed_names=None, def __init__(self,
fetch_names=None): name,
inputs,
outputs,
feed_names=None,
fetch_names=None,
for_predict=False):
inputs = to_list(inputs) inputs = to_list(inputs)
outputs = to_list(outputs) outputs = to_list(outputs)
...@@ -52,16 +57,19 @@ class Signature: ...@@ -52,16 +57,19 @@ class Signature:
self.outputs = outputs self.outputs = outputs
self.feed_names = feed_names self.feed_names = feed_names
self.fetch_names = fetch_names self.fetch_names = fetch_names
self.for_predict = for_predict
def create_signature(name="default", def create_signature(name="default",
inputs=[], inputs=[],
outputs=[], outputs=[],
feed_names=None, feed_names=None,
fetch_names=None): fetch_names=None,
for_predict=False):
return Signature( return Signature(
name=name, name=name,
inputs=inputs, inputs=inputs,
outputs=outputs, outputs=outputs,
feed_names=feed_names, feed_names=feed_names,
fetch_names=fetch_names) fetch_names=fetch_names,
for_predict=for_predict)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册