提交 fe10b94f 编写于 作者: W wuzewu

Improve the model export interface.

上级 23493590
......@@ -282,6 +282,7 @@ class ModuleV1(object):
model_filename=model_filename,
params_filename=params_filename)
@paddle_utils.run_in_static_mode
def export_onnx_model(self, dirname: str, **kwargs):
'''
Export the model to ONNX format.
......@@ -289,8 +290,8 @@ class ModuleV1(object):
Args:
dirname(str): The directory to save the onnx model.
**kwargs(dict|optional): Other export configuration options for compatibility, some may be removed
in the future. Don't use them If not necessary. Refer to https://github.com/PaddlePaddle/paddle2onnx
for more information.
in the future. Don't use them If not necessary. Refer to https://github.com/PaddlePaddle/paddle2onnx
for more information.
'''
feed_dict, fetch_dict, program = self.context(for_test=True, trainable=False)
inputs = set([var.name for var in feed_dict.values()])
......@@ -308,3 +309,12 @@ class ModuleV1(object):
target_vars=outputs,
save_file=save_file,
**kwargs)
def sub_modules(self, recursive: bool = True):
'''
Get all sub modules.
Args:
recursive(bool): Whether to get sub modules recursively. Default to True.
'''
return []
......@@ -135,7 +135,7 @@ class RunModule(object):
@property
def _pretrained_model_path(self):
_pretrained_model_attrs = [
'pretrained_model_path', 'rec_pretrained_model_path', 'default_pretrained_model_path'
'pretrained_model_path', 'rec_pretrained_model_path', 'default_pretrained_model_path', 'model_path'
]
for _attr in _pretrained_model_attrs:
......@@ -147,30 +147,77 @@ class RunModule(object):
return None
def export_onnx_model(self, dirname: str, **kwargs):
def sub_modules(self, recursive: bool = True):
'''
Get all sub modules.
Args:
recursive(bool): Whether to get sub modules recursively. Default to True.
'''
_sub_modules = {}
for key, item in self.__dict__.items():
if id(item) == id(self):
continue
if isinstance(item, (RunModule, ModuleV1)):
_sub_modules[key] = item
if not recursive:
continue
for _k, _v in item.sub_modules(recursive):
_sub_modules['{}/{}'.format(key, _k)] = _v
return _sub_modules
def export_onnx_model(self,
dirname: str,
input_spec: List[paddle.static.InputSpec] = None,
export_sub_modules: bool = True,
**kwargs):
'''
Export the model to ONNX format.
Args:
dirname(str): The directory to save the onnx model.
**kwargs(dict|optional): Other export configuration options for compatibility, some may be removed
in the future. Don't use them If not necessary. Refer to https://github.com/PaddlePaddle/paddle2onnx
for more information.
input_spec(list): Describes the input of the saved model's forward method, which can be described by
InputSpec or example Tensor. If None, all input variables of the original Layer's forward method
would be the inputs of the saved model. Default None.
export_sub_modules(bool): Whether to export sub modules. Default to True.
**kwargs(dict|optional): Other export configuration options for compatibility, some may be removed in
the future. Don't use them If not necessary. Refer to https://github.com/PaddlePaddle/paddle2onnx
for more information.
'''
if export_sub_modules:
for key, _sub_module in self.sub_modules().items():
try:
sub_dirname = os.path.normpath(os.path.join(dirname, key))
_sub_module.export_onnx_model(sub_dirname, export_sub_modules=export_sub_modules, **kwargs)
except:
utils.record_exception('Failed to export sub module {}'.format(_sub_module.name))
if not self._pretrained_model_path:
if isinstance(self, paddle.nn.Layer):
save_file = os.path.join(dirname, '{}'.format(self.name))
if hasattr(self, 'input_spec'):
input_spec = self.input_sepc
else:
_type = self.type.lower()
if _type.startswith('cv/image'):
input_spec = paddle.static.InputSpec(shape=[None, 3, None, None], dtype='float32')
if not input_spec:
if hasattr(self, 'input_spec'):
input_spec = self.input_spec
else:
raise NotImplementedError
paddle.onnx.export(self, save_file, input_spec=[input_spec])
_type = self.type.lower()
if _type.startswith('cv/image'):
input_spec = [paddle.static.InputSpec(shape=[None, 3, None, None], dtype='float32')]
else:
raise RuntimeError(
'Module {} lacks `input_spec`, please specify it when calling `export_onnx_model`.'.
format(self.name))
paddle.onnx.export(self, save_file, input_spec=input_spec, **kwargs)
return
raise NotImplementedError
raise RuntimeError('Module {} does not support exporting models in ONNX format.'.format(self.name))
if not os.path.exists(self._pretrained_model_path):
log.logger.warning('The model path of Module {} does not exist'.format(self.name))
return
place = paddle.CPUPlace()
exe = paddle.static.Executor(place)
......
......@@ -407,6 +407,13 @@ class TransformerModule(RunModule, TextServing):
'text-matching',
]
@property
def input_spec(self):
return [
paddle.static.InputSpec(shape=[None, None], dtype='int64'),
paddle.static.InputSpec(shape=[None, None], dtype='int64')
]
def _convert_text_to_input(self, tokenizer, texts: List[str], max_seq_len: int, split_char: str):
pad_to_max_seq_len = False if self.task is None else True
if self.task == 'token-cls': # Extra processing of token-cls task
......@@ -442,7 +449,7 @@ class TransformerModule(RunModule, TextServing):
pad_to_max_seq_len=True, is_split_into_words=is_split_into_words, return_length=True))
else:
raise RuntimeError(
'The input text must have one or two sequence, but got %d. Please check your inputs.' % len(text))
'The input text must have one or two sequence, but got %d. Please check your inputs.' % len(texts))
return encoded_inputs
def _batchify(self, data: List[List[str]], max_seq_len: int, batch_size: int, split_char: str):
......@@ -605,10 +612,9 @@ class TransformerModule(RunModule, TextServing):
results.extend(token_labels)
elif self.task == None:
sequence_output, pooled_output = self(input_ids, segment_ids)
results.append([
pooled_output.squeeze(0).numpy().tolist(),
sequence_output.squeeze(0).numpy().tolist()
])
results.append(
[pooled_output.squeeze(0).numpy().tolist(),
sequence_output.squeeze(0).numpy().tolist()])
return results
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册