Commit fe10b94f authored by: W wuzewu

Improve the model export interface.

Parent commit: 23493590
...@@ -282,6 +282,7 @@ class ModuleV1(object): ...@@ -282,6 +282,7 @@ class ModuleV1(object):
model_filename=model_filename, model_filename=model_filename,
params_filename=params_filename) params_filename=params_filename)
@paddle_utils.run_in_static_mode
def export_onnx_model(self, dirname: str, **kwargs): def export_onnx_model(self, dirname: str, **kwargs):
''' '''
Export the model to ONNX format. Export the model to ONNX format.
...@@ -289,8 +290,8 @@ class ModuleV1(object): ...@@ -289,8 +290,8 @@ class ModuleV1(object):
Args: Args:
dirname(str): The directory to save the onnx model. dirname(str): The directory to save the onnx model.
**kwargs(dict|optional): Other export configuration options for compatibility, some may be removed **kwargs(dict|optional): Other export configuration options for compatibility, some may be removed
in the future. Don't use them If not necessary. Refer to https://github.com/PaddlePaddle/paddle2onnx in the future. Don't use them If not necessary. Refer to https://github.com/PaddlePaddle/paddle2onnx
for more information. for more information.
''' '''
feed_dict, fetch_dict, program = self.context(for_test=True, trainable=False) feed_dict, fetch_dict, program = self.context(for_test=True, trainable=False)
inputs = set([var.name for var in feed_dict.values()]) inputs = set([var.name for var in feed_dict.values()])
...@@ -308,3 +309,12 @@ class ModuleV1(object): ...@@ -308,3 +309,12 @@ class ModuleV1(object):
target_vars=outputs, target_vars=outputs,
save_file=save_file, save_file=save_file,
**kwargs) **kwargs)
def sub_modules(self, recursive: bool = True):
    '''
    Get all sub modules.

    A ModuleV1 object never holds sub modules, so the result is always
    empty. An empty dict (rather than the original empty list) is returned
    so the type matches RunModule.sub_modules, whose callers treat the
    result as a mapping of attribute name -> module and iterate
    (key, value) pairs from it.

    Args:
        recursive(bool): Whether to get sub modules recursively. Default to True.

    Returns:
        dict: Always empty for a ModuleV1.
    '''
    return {}
...@@ -135,7 +135,7 @@ class RunModule(object): ...@@ -135,7 +135,7 @@ class RunModule(object):
@property @property
def _pretrained_model_path(self): def _pretrained_model_path(self):
_pretrained_model_attrs = [ _pretrained_model_attrs = [
'pretrained_model_path', 'rec_pretrained_model_path', 'default_pretrained_model_path' 'pretrained_model_path', 'rec_pretrained_model_path', 'default_pretrained_model_path', 'model_path'
] ]
for _attr in _pretrained_model_attrs: for _attr in _pretrained_model_attrs:
...@@ -147,30 +147,77 @@ class RunModule(object): ...@@ -147,30 +147,77 @@ class RunModule(object):
return None return None
def sub_modules(self, recursive: bool = True):
    '''
    Get all sub modules held by this module as instance attributes.

    Args:
        recursive(bool): Whether to also collect the sub modules of each
            sub module. Default to True.

    Returns:
        dict: Mapping from attribute path (e.g. 'backbone/embedding') to
            the sub module object.
    '''
    _sub_modules = {}
    for key, item in self.__dict__.items():
        # Skip self-references so the recursion cannot loop forever.
        if id(item) == id(self):
            continue
        if isinstance(item, (RunModule, ModuleV1)):
            _sub_modules[key] = item
            if not recursive:
                continue
            # BUGFIX: sub_modules() returns a dict; iterating it directly
            # yields only the keys, so the `_k, _v` unpacking would raise
            # ValueError. Iterate .items() to get (name, module) pairs.
            for _k, _v in item.sub_modules(recursive).items():
                _sub_modules['{}/{}'.format(key, _k)] = _v
    return _sub_modules
def export_onnx_model(self,
dirname: str,
input_spec: List[paddle.static.InputSpec] = None,
export_sub_modules: bool = True,
**kwargs):
''' '''
Export the model to ONNX format. Export the model to ONNX format.
Args: Args:
dirname(str): The directory to save the onnx model. dirname(str): The directory to save the onnx model.
**kwargs(dict|optional): Other export configuration options for compatibility, some may be removed input_spec(list): Describes the input of the saved model's forward method, which can be described by
in the future. Don't use them If not necessary. Refer to https://github.com/PaddlePaddle/paddle2onnx InputSpec or example Tensor. If None, all input variables of the original Layer's forward method
for more information. would be the inputs of the saved model. Default None.
export_sub_modules(bool): Whether to export sub modules. Default to True.
**kwargs(dict|optional): Other export configuration options for compatibility, some may be removed in
the future. Don't use them If not necessary. Refer to https://github.com/PaddlePaddle/paddle2onnx
for more information.
''' '''
if export_sub_modules:
for key, _sub_module in self.sub_modules().items():
try:
sub_dirname = os.path.normpath(os.path.join(dirname, key))
_sub_module.export_onnx_model(sub_dirname, export_sub_modules=export_sub_modules, **kwargs)
except:
utils.record_exception('Failed to export sub module {}'.format(_sub_module.name))
if not self._pretrained_model_path: if not self._pretrained_model_path:
if isinstance(self, paddle.nn.Layer): if isinstance(self, paddle.nn.Layer):
save_file = os.path.join(dirname, '{}'.format(self.name)) save_file = os.path.join(dirname, '{}'.format(self.name))
if hasattr(self, 'input_spec'): if not input_spec:
input_spec = self.input_sepc if hasattr(self, 'input_spec'):
else: input_spec = self.input_spec
_type = self.type.lower()
if _type.startswith('cv/image'):
input_spec = paddle.static.InputSpec(shape=[None, 3, None, None], dtype='float32')
else: else:
raise NotImplementedError _type = self.type.lower()
paddle.onnx.export(self, save_file, input_spec=[input_spec]) if _type.startswith('cv/image'):
input_spec = [paddle.static.InputSpec(shape=[None, 3, None, None], dtype='float32')]
else:
raise RuntimeError(
'Module {} lacks `input_spec`, please specify it when calling `export_onnx_model`.'.
format(self.name))
paddle.onnx.export(self, save_file, input_spec=input_spec, **kwargs)
return return
raise NotImplementedError
raise RuntimeError('Module {} does not support exporting models in ONNX format.'.format(self.name))
if not os.path.exists(self._pretrained_model_path):
log.logger.warning('The model path of Module {} does not exist'.format(self.name))
return
place = paddle.CPUPlace() place = paddle.CPUPlace()
exe = paddle.static.Executor(place) exe = paddle.static.Executor(place)
......
...@@ -407,6 +407,13 @@ class TransformerModule(RunModule, TextServing): ...@@ -407,6 +407,13 @@ class TransformerModule(RunModule, TextServing):
'text-matching', 'text-matching',
] ]
@property
def input_spec(self):
return [
paddle.static.InputSpec(shape=[None, None], dtype='int64'),
paddle.static.InputSpec(shape=[None, None], dtype='int64')
]
def _convert_text_to_input(self, tokenizer, texts: List[str], max_seq_len: int, split_char: str): def _convert_text_to_input(self, tokenizer, texts: List[str], max_seq_len: int, split_char: str):
pad_to_max_seq_len = False if self.task is None else True pad_to_max_seq_len = False if self.task is None else True
if self.task == 'token-cls': # Extra processing of token-cls task if self.task == 'token-cls': # Extra processing of token-cls task
...@@ -442,7 +449,7 @@ class TransformerModule(RunModule, TextServing): ...@@ -442,7 +449,7 @@ class TransformerModule(RunModule, TextServing):
pad_to_max_seq_len=True, is_split_into_words=is_split_into_words, return_length=True)) pad_to_max_seq_len=True, is_split_into_words=is_split_into_words, return_length=True))
else: else:
raise RuntimeError( raise RuntimeError(
'The input text must have one or two sequence, but got %d. Please check your inputs.' % len(text)) 'The input text must have one or two sequence, but got %d. Please check your inputs.' % len(texts))
return encoded_inputs return encoded_inputs
def _batchify(self, data: List[List[str]], max_seq_len: int, batch_size: int, split_char: str): def _batchify(self, data: List[List[str]], max_seq_len: int, batch_size: int, split_char: str):
...@@ -605,10 +612,9 @@ class TransformerModule(RunModule, TextServing): ...@@ -605,10 +612,9 @@ class TransformerModule(RunModule, TextServing):
results.extend(token_labels) results.extend(token_labels)
elif self.task == None: elif self.task == None:
sequence_output, pooled_output = self(input_ids, segment_ids) sequence_output, pooled_output = self(input_ids, segment_ids)
results.append([ results.append(
pooled_output.squeeze(0).numpy().tolist(), [pooled_output.squeeze(0).numpy().tolist(),
sequence_output.squeeze(0).numpy().tolist() sequence_output.squeeze(0).numpy().tolist()])
])
return results return results
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册