From de5705ecbe6bd237e313fe8b65f097e70ae8a771 Mon Sep 17 00:00:00 2001 From: Vojtech Bardiovsky Date: Thu, 19 Dec 2019 06:53:31 -0800 Subject: [PATCH] Fix saved_model_cli when trying to print pure ConcreteFunction arguments. PiperOrigin-RevId: 286384508 Change-Id: I67a3997ba67a7b474a9e1157f6f3fb12dbe84007 --- tensorflow/python/tools/saved_model_cli.py | 33 ++++++--- .../python/tools/saved_model_cli_test.py | 69 ++++++++++++++++++- 2 files changed, 92 insertions(+), 10 deletions(-) diff --git a/tensorflow/python/tools/saved_model_cli.py b/tensorflow/python/tools/saved_model_cli.py index 4cc9cabec4f..15ee6d1909e 100644 --- a/tensorflow/python/tools/saved_model_cli.py +++ b/tensorflow/python/tools/saved_model_cli.py @@ -31,12 +31,14 @@ import sys import warnings import numpy as np -from six import integer_types +import six from tensorflow.core.example import example_pb2 from tensorflow.core.framework import types_pb2 from tensorflow.python.client import session from tensorflow.python.debug.wrappers import local_cli_wrapper +from tensorflow.python.eager import def_function +from tensorflow.python.eager import function as defun from tensorflow.python.framework import meta_graph as meta_graph_lib from tensorflow.python.framework import ops as ops_lib from tensorflow.python.framework import tensor_spec @@ -182,14 +184,25 @@ def _show_defined_functions(saved_model_dir): functions = sorted(functions.items(), key=lambda x: x[0]) for name, function in functions: print('\n Function Name: \'%s\'' % name) - concrete_functions = \ - function._list_all_concrete_functions_for_serialization() # pylint: disable=protected-access + concrete_functions = [] + if isinstance(function, defun.ConcreteFunction): + concrete_functions.append(function) + if isinstance(function, def_function.Function): + concrete_functions.extend( + function._list_all_concrete_functions_for_serialization()) # pylint: disable=protected-access concrete_functions = sorted(concrete_functions, key=lambda x: x.name) for index, concrete_function in enumerate(concrete_functions, 1): - args, kwargs = concrete_function.structured_input_signature - print(' Option #%d' % index) - print(' Callable with:') - _print_args(args, indent=4) + args, kwargs = None, None + if concrete_function.structured_input_signature: + args, kwargs = concrete_function.structured_input_signature + elif concrete_function._arg_keywords: # pylint: disable=protected-access + # For pure ConcreteFunctions we might have nothing better than + # _arg_keywords. + args = concrete_function._arg_keywords # pylint: disable=protected-access + if args: + print(' Option #%d' % index) + print(' Callable with:') + _print_args(args, indent=4) if kwargs: _print_args(kwargs, 'Named Argument', indent=4) @@ -215,7 +228,9 @@ def _print_args(arguments, argument_type='Argument', indent=0): for index, element in enumerate(arguments, 1): if indent == 4: in_print('%s #%d' % (argument_type, index)) - if isinstance(element, tensor_spec.TensorSpec): + if isinstance(element, six.string_types): + in_print(' %s' % element) + elif isinstance(element, tensor_spec.TensorSpec): print((indent + 1) * ' ' + '%s: %s' % (element.name, repr(element))) elif (isinstance(element, collections.Iterable) and not isinstance(element, dict)): @@ -567,7 +582,7 @@ def _create_example_string(example_dict): elif isinstance(feature_list[0], str): example.features.feature[feature_name].bytes_list.value.extend( feature_list) - elif isinstance(feature_list[0], integer_types): + elif isinstance(feature_list[0], six.integer_types): example.features.feature[feature_name].int64_list.value.extend( feature_list) else: diff --git a/tensorflow/python/tools/saved_model_cli_test.py b/tensorflow/python/tools/saved_model_cli_test.py index 488bce93974..74acbf82d56 100644 --- a/tensorflow/python/tools/saved_model_cli_test.py +++ b/tensorflow/python/tools/saved_model_cli_test.py @@ -148,7 +148,7 @@ signature_def['serving_default']: self.assertMultiLineEqual(output, exp_out) self.assertEqual(err.getvalue().strip(), '') - def testShowAllWithConcreteFunctions(self): + def testShowAllWithFunctions(self): class DummyModel(tracking.AutoTrackable): """Model with callable polymorphic functions specified.""" @@ -237,6 +237,73 @@ Defined Functions: self.assertMultiLineEqual(output, exp_out) self.assertEqual(err.getvalue().strip(), '') + def testShowAllWithPureConcreteFunction(self): + + class DummyModel(tracking.AutoTrackable): + """Model with a callable concrete function.""" + + def __init__(self): + function = def_function.function( + self.multiply, + input_signature=[ + tensor_spec.TensorSpec(shape=(), dtype=dtypes.float32), + tensor_spec.TensorSpec(shape=(), dtype=dtypes.float32) + ]) + self.pure_concrete_function = function.get_concrete_function() + super(DummyModel, self).__init__() + + def multiply(self, a, b): + return a * b + + saved_model_dir = os.path.join(test.get_temp_dir(), 'dummy_model') + dummy_model = DummyModel() + save.save(dummy_model, saved_model_dir) + self.parser = saved_model_cli.create_parser() + args = self.parser.parse_args(['show', '--dir', saved_model_dir, '--all']) + with captured_output() as (out, err): + saved_model_cli.show(args) + output = out.getvalue().strip() + exp_out = """MetaGraphDef with tag-set: 'serve' contains the following SignatureDefs: + +signature_def['__saved_model_init_op']: + The given SavedModel SignatureDef contains the following input(s): + The given SavedModel SignatureDef contains the following output(s): + outputs['__saved_model_init_op'] tensor_info: + dtype: DT_INVALID + shape: unknown_rank + name: NoOp + Method name is: + +signature_def['serving_default']: + The given SavedModel SignatureDef contains the following input(s): + inputs['a'] tensor_info: + dtype: DT_FLOAT + shape: () + name: serving_default_a:0 + inputs['b'] tensor_info: + dtype: DT_FLOAT + shape: () + name: serving_default_b:0 + The given SavedModel SignatureDef contains the following output(s): + outputs['output_0'] tensor_info: + dtype: DT_FLOAT + shape: () + name: PartitionedCall:0 + Method name is: tensorflow/serving/predict + +Defined Functions: + Function Name: 'pure_concrete_function' + Option #1 + Callable with: + Argument #1 + a: TensorSpec(shape=(), dtype=tf.float32, name='a') + Argument #2 + b: TensorSpec(shape=(), dtype=tf.float32, name='b') +""".strip() # pylint: enable=line-too-long + self.maxDiff = None # Produce a useful error msg if the comparison fails + self.assertMultiLineEqual(output, exp_out) + self.assertEqual(err.getvalue().strip(), '') + def testShowCommandTags(self): base_path = test.test_src_dir_path(SAVED_MODEL_PATH) self.parser = saved_model_cli.create_parser() -- GitLab