提交 662053a4 编写于 作者: K Katherine Wu 提交者: TensorFlower Gardener

Fix kokoro tests by removing dependency on save_test.

PiperOrigin-RevId: 224913339
上级 c136aa82
......@@ -848,7 +848,6 @@ py_test(
deps = [
":keras",
"//tensorflow/python:client_testlib",
"//tensorflow/python/saved_model:save_test",
"//third_party/py/numpy",
"@absl_py//absl/testing:parameterized",
],
......
......@@ -22,10 +22,13 @@ import os
import numpy as np
from tensorflow.python.client import session as session_lib
from tensorflow.python import keras
from tensorflow.python.eager import context
from tensorflow.python.eager import def_function
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_spec
from tensorflow.python.framework import tensor_util
from tensorflow.python.keras import backend as K
......@@ -35,8 +38,10 @@ from tensorflow.python.keras.engine import training_utils
from tensorflow.python.keras.utils import tf_utils
from tensorflow.python.ops import array_ops
from tensorflow.python.platform import test
from tensorflow.python.saved_model import loader
from tensorflow.python.saved_model import save as save_lib
from tensorflow.python.saved_model import save_test
from tensorflow.python.saved_model import signature_constants
from tensorflow.python.saved_model import tag_constants
class ModelInputsTest(test.TestCase):
......@@ -222,6 +227,25 @@ class TraceModelCallTest(keras_parameterized.TestCase):
self._assert_all_close(expected_outputs, signature_outputs)
def _import_and_infer(save_dir, inputs):
"""Import a SavedModel into a TF 1.x-style graph and run `signature_key`."""
graph = ops.Graph()
with graph.as_default(), session_lib.Session() as session:
model = loader.load(session, [tag_constants.SERVING], save_dir)
signature = model.signature_def[
signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY]
assert set(inputs.keys()) == set(signature.inputs.keys())
feed_dict = {}
for arg_name in inputs.keys():
feed_dict[graph.get_tensor_by_name(signature.inputs[arg_name].name)] = (
inputs[arg_name])
output_dict = {}
for output_name, output_tensor_info in signature.outputs.items():
output_dict[output_name] = graph.get_tensor_by_name(
output_tensor_info.name)
return session.run(output_dict, feed_dict=feed_dict)
class ModelSaveTest(keras_parameterized.TestCase):
@keras_parameterized.run_with_all_model_types
......@@ -239,8 +263,7 @@ class ModelSaveTest(keras_parameterized.TestCase):
self.assertAllClose(
{model.output_names[0]: model.predict_on_batch(inputs)},
save_test._import_and_infer(save_dir,
{model.input_names[0]: np.ones((8, 5))}))
_import_and_infer(save_dir, {model.input_names[0]: np.ones((8, 5))}))
if __name__ == '__main__':
test.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册