提交 de5705ec 编写于 作者: V Vojtech Bardiovsky 提交者: TensorFlower Gardener

Fix saved_model_cli when trying to print pure ConcreteFunction arguments.

PiperOrigin-RevId: 286384508
Change-Id: I67a3997ba67a7b474a9e1157f6f3fb12dbe84007
上级 6ce4169f
......@@ -31,12 +31,14 @@ import sys
import warnings
import numpy as np
from six import integer_types
import six
from tensorflow.core.example import example_pb2
from tensorflow.core.framework import types_pb2
from tensorflow.python.client import session
from tensorflow.python.debug.wrappers import local_cli_wrapper
from tensorflow.python.eager import def_function
from tensorflow.python.eager import function as defun
from tensorflow.python.framework import meta_graph as meta_graph_lib
from tensorflow.python.framework import ops as ops_lib
from tensorflow.python.framework import tensor_spec
......@@ -182,14 +184,25 @@ def _show_defined_functions(saved_model_dir):
functions = sorted(functions.items(), key=lambda x: x[0])
for name, function in functions:
print('\n Function Name: \'%s\'' % name)
concrete_functions = \
function._list_all_concrete_functions_for_serialization() # pylint: disable=protected-access
concrete_functions = []
if isinstance(function, defun.ConcreteFunction):
concrete_functions.append(function)
if isinstance(function, def_function.Function):
concrete_functions.extend(
function._list_all_concrete_functions_for_serialization()) # pylint: disable=protected-access
concrete_functions = sorted(concrete_functions, key=lambda x: x.name)
for index, concrete_function in enumerate(concrete_functions, 1):
args, kwargs = concrete_function.structured_input_signature
print(' Option #%d' % index)
print(' Callable with:')
_print_args(args, indent=4)
args, kwargs = None, None
if concrete_function.structured_input_signature:
args, kwargs = concrete_function.structured_input_signature
elif concrete_function._arg_keywords: # pylint: disable=protected-access
# For pure ConcreteFunctions we might have nothing better than
# _arg_keywords.
args = concrete_function._arg_keywords # pylint: disable=protected-access
if args:
print(' Option #%d' % index)
print(' Callable with:')
_print_args(args, indent=4)
if kwargs:
_print_args(kwargs, 'Named Argument', indent=4)
......@@ -215,7 +228,9 @@ def _print_args(arguments, argument_type='Argument', indent=0):
for index, element in enumerate(arguments, 1):
if indent == 4:
in_print('%s #%d' % (argument_type, index))
if isinstance(element, tensor_spec.TensorSpec):
if isinstance(element, six.string_types):
in_print(' %s' % element)
elif isinstance(element, tensor_spec.TensorSpec):
print((indent + 1) * ' ' + '%s: %s' % (element.name, repr(element)))
elif (isinstance(element, collections.Iterable) and
not isinstance(element, dict)):
......@@ -567,7 +582,7 @@ def _create_example_string(example_dict):
elif isinstance(feature_list[0], str):
example.features.feature[feature_name].bytes_list.value.extend(
feature_list)
elif isinstance(feature_list[0], integer_types):
elif isinstance(feature_list[0], six.integer_types):
example.features.feature[feature_name].int64_list.value.extend(
feature_list)
else:
......
......@@ -148,7 +148,7 @@ signature_def['serving_default']:
self.assertMultiLineEqual(output, exp_out)
self.assertEqual(err.getvalue().strip(), '')
def testShowAllWithConcreteFunctions(self):
def testShowAllWithFunctions(self):
class DummyModel(tracking.AutoTrackable):
"""Model with callable polymorphic functions specified."""
......@@ -237,6 +237,73 @@ Defined Functions:
self.assertMultiLineEqual(output, exp_out)
self.assertEqual(err.getvalue().strip(), '')
def testShowAllWithPureConcreteFunction(self):
class DummyModel(tracking.AutoTrackable):
"""Model with a callable concrete function."""
def __init__(self):
function = def_function.function(
self.multiply,
input_signature=[
tensor_spec.TensorSpec(shape=(), dtype=dtypes.float32),
tensor_spec.TensorSpec(shape=(), dtype=dtypes.float32)
])
self.pure_concrete_function = function.get_concrete_function()
super(DummyModel, self).__init__()
def multiply(self, a, b):
return a * b
saved_model_dir = os.path.join(test.get_temp_dir(), 'dummy_model')
dummy_model = DummyModel()
save.save(dummy_model, saved_model_dir)
self.parser = saved_model_cli.create_parser()
args = self.parser.parse_args(['show', '--dir', saved_model_dir, '--all'])
with captured_output() as (out, err):
saved_model_cli.show(args)
output = out.getvalue().strip()
exp_out = """MetaGraphDef with tag-set: 'serve' contains the following SignatureDefs:
signature_def['__saved_model_init_op']:
The given SavedModel SignatureDef contains the following input(s):
The given SavedModel SignatureDef contains the following output(s):
outputs['__saved_model_init_op'] tensor_info:
dtype: DT_INVALID
shape: unknown_rank
name: NoOp
Method name is:
signature_def['serving_default']:
The given SavedModel SignatureDef contains the following input(s):
inputs['a'] tensor_info:
dtype: DT_FLOAT
shape: ()
name: serving_default_a:0
inputs['b'] tensor_info:
dtype: DT_FLOAT
shape: ()
name: serving_default_b:0
The given SavedModel SignatureDef contains the following output(s):
outputs['output_0'] tensor_info:
dtype: DT_FLOAT
shape: ()
name: PartitionedCall:0
Method name is: tensorflow/serving/predict
Defined Functions:
Function Name: 'pure_concrete_function'
Option #1
Callable with:
Argument #1
a: TensorSpec(shape=(), dtype=tf.float32, name='a')
Argument #2
b: TensorSpec(shape=(), dtype=tf.float32, name='b')
""".strip() # pylint: enable=line-too-long
self.maxDiff = None # Produce a useful error msg if the comparison fails
self.assertMultiLineEqual(output, exp_out)
self.assertEqual(err.getvalue().strip(), '')
def testShowCommandTags(self):
base_path = test.test_src_dir_path(SAVED_MODEL_PATH)
self.parser = saved_model_cli.create_parser()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册