提交 b729e4ec 编写于 作者: A A. Unique TensorFlower

Internal change

PiperOrigin-RevId: 448644197
上级 bce895d9
......@@ -61,6 +61,7 @@ class ExportModule(export_base.ExportModule):
preprocessor=preprocessor,
inference_step=inference_step,
postprocessor=postprocessor)
self.eval_postprocessor = eval_postprocessor
self.input_signature = input_signature
@tf.function
......
......@@ -15,7 +15,7 @@
r"""Vision models export utility function for serving/inference."""
import os
from typing import Optional, List
from typing import Optional, List, Union, Text, Dict
from absl import logging
import tensorflow as tf
......@@ -44,7 +44,8 @@ def export_inference_graph(
save_options: Optional[tf.saved_model.SaveOptions] = None,
log_model_flops_and_params: bool = False,
checkpoint: Optional[tf.train.Checkpoint] = None,
input_name: Optional[str] = None):
input_name: Optional[str] = None,
function_keys: Optional[Union[List[Text], Dict[Text, Text]]] = None,):
"""Exports inference graph for the model specified in the exp config.
Saved model is stored at export_dir/saved_model, checkpoint is saved
......@@ -72,6 +73,9 @@ def export_inference_graph(
will use it to read the weights.
input_name: The input tensor name, default at `None` which produces input
tensor name `inputs`.
function_keys: a list of string keys to retrieve pre-defined serving
signatures. The signaute keys will be set with defaults. If a dictionary
is provided, the values will be used as signature keys.
"""
if export_checkpoint_subdir:
......@@ -130,7 +134,7 @@ def export_inference_graph(
export_base.export(
export_module,
function_keys=[input_type],
function_keys=function_keys if function_keys else [input_type],
export_savedmodel_dir=output_saved_model_directory,
checkpoint=checkpoint,
checkpoint_path=checkpoint_path,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册