diff --git a/tensorflow/python/keras/BUILD b/tensorflow/python/keras/BUILD index faf58e0d93ff900ff822f968df6c894b7c055dc9..36fea36389dc15104cca8a0d421ba50906295e9a 100755 --- a/tensorflow/python/keras/BUILD +++ b/tensorflow/python/keras/BUILD @@ -848,7 +848,6 @@ py_test( deps = [ ":keras", "//tensorflow/python:client_testlib", - "//tensorflow/python/saved_model:save_test", "//third_party/py/numpy", "@absl_py//absl/testing:parameterized", ], diff --git a/tensorflow/python/keras/engine/training_utils_test.py b/tensorflow/python/keras/engine/training_utils_test.py index 0250e604266679123a88fe781bb98439335dcf38..d8acec32cb65ffb2bbf517007802504e7c184544 100644 --- a/tensorflow/python/keras/engine/training_utils_test.py +++ b/tensorflow/python/keras/engine/training_utils_test.py @@ -22,10 +22,13 @@ import os import numpy as np + +from tensorflow.python.client import session as session_lib from tensorflow.python import keras from tensorflow.python.eager import context from tensorflow.python.eager import def_function from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops from tensorflow.python.framework import tensor_spec from tensorflow.python.framework import tensor_util from tensorflow.python.keras import backend as K @@ -35,8 +38,10 @@ from tensorflow.python.keras.engine import training_utils from tensorflow.python.keras.utils import tf_utils from tensorflow.python.ops import array_ops from tensorflow.python.platform import test +from tensorflow.python.saved_model import loader from tensorflow.python.saved_model import save as save_lib -from tensorflow.python.saved_model import save_test +from tensorflow.python.saved_model import signature_constants +from tensorflow.python.saved_model import tag_constants class ModelInputsTest(test.TestCase): @@ -222,6 +227,25 @@ class TraceModelCallTest(keras_parameterized.TestCase): self._assert_all_close(expected_outputs, signature_outputs) +def _import_and_infer(save_dir, inputs): + """Import a SavedModel into a TF 1.x-style graph and run `signature_key`.""" + graph = ops.Graph() + with graph.as_default(), session_lib.Session() as session: + model = loader.load(session, [tag_constants.SERVING], save_dir) + signature = model.signature_def[ + signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY] + assert set(inputs.keys()) == set(signature.inputs.keys()) + feed_dict = {} + for arg_name in inputs.keys(): + feed_dict[graph.get_tensor_by_name(signature.inputs[arg_name].name)] = ( + inputs[arg_name]) + output_dict = {} + for output_name, output_tensor_info in signature.outputs.items(): + output_dict[output_name] = graph.get_tensor_by_name( + output_tensor_info.name) + return session.run(output_dict, feed_dict=feed_dict) + + class ModelSaveTest(keras_parameterized.TestCase): @keras_parameterized.run_with_all_model_types @@ -239,8 +263,7 @@ class ModelSaveTest(keras_parameterized.TestCase): self.assertAllClose( {model.output_names[0]: model.predict_on_batch(inputs)}, - save_test._import_and_infer(save_dir, - {model.input_names[0]: np.ones((8, 5))})) + _import_and_infer(save_dir, {model.input_names[0]: np.ones((8, 5))})) if __name__ == '__main__': test.main()