提交 19ef8215 编写于 作者: A A. Unique TensorFlower 提交者: TensorFlower Gardener

Make SavedModel exports include all the SAVEABLE objects and not just global variables.

Change: 150243023
上级 b05e0840
......@@ -58,6 +58,7 @@ from tensorflow.python.framework import random_seed
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import data_flow_ops
from tensorflow.python.ops import resources
from tensorflow.python.ops import variables
from tensorflow.python.platform import gfile
from tensorflow.python.platform import tf_logging as logging
......@@ -1254,13 +1255,17 @@ class Estimator(BaseEstimator):
with tf_session.Session('') as session:
variables.initialize_local_variables()
data_flow_ops.tables_initializer()
resources.initialize_resources(resources.shared_resources())
saver_for_restore = saver.Saver(
variables.global_variables(),
# pylint: disable=protected-access
variables._all_saveable_objects(),
# pylint: enable=protected-access
sharded=True)
saver_for_restore.restore(session, checkpoint_path)
init_op = control_flow_ops.group(
variables.local_variables_initializer(),
resources.initialize_resources(resources.shared_resources()),
data_flow_ops.tables_initializer())
# Perform the export
......
......@@ -50,6 +50,7 @@ from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import data_flow_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import parsing_ops
......@@ -225,6 +226,49 @@ def _build_estimator_for_export_tests(tmpdir):
return est, serving_input_fn_with_asset
def _build_estimator_for_resource_export_test():
def _input_fn():
iris = base.load_iris()
return {
'feature': constant_op.constant(iris.data, dtype=dtypes.float32)
}, constant_op.constant(
iris.target, shape=[150], dtype=dtypes.int32)
feature_columns = [
feature_column_lib.real_valued_column('feature', dimension=4)
]
def resource_constant_model_fn(unused_features, unused_labels, mode):
"""A model_fn that loads a constant from a resource and serves it."""
assert mode in (model_fn.ModeKeys.TRAIN, model_fn.ModeKeys.EVAL,
model_fn.ModeKeys.INFER)
const = constant_op.constant(-1, dtype=dtypes.int64)
table = lookup.MutableHashTable(
dtypes.string, dtypes.int64, const, name='LookupTableModel')
if mode in (model_fn.ModeKeys.TRAIN, model_fn.ModeKeys.EVAL):
key = constant_op.constant(['key'])
value = constant_op.constant([42], dtype=dtypes.int64)
train_op_1 = table.insert(key, value)
training_state = lookup.MutableHashTable(
dtypes.string, dtypes.int64, const, name='LookupTableTrainingState')
training_op_2 = training_state.insert(key, value)
return const, const, control_flow_ops.group(train_op_1, training_op_2)
if mode == model_fn.ModeKeys.INFER:
key = constant_op.constant(['key'])
prediction = table.lookup(key)
return prediction, const, control_flow_ops.no_op()
est = estimator.Estimator(model_fn=resource_constant_model_fn)
est.fit(input_fn=_input_fn, steps=1)
feature_spec = feature_column_lib.create_feature_spec_for_parsing(
feature_columns)
serving_input_fn = input_fn_utils.build_parsing_serving_input_fn(feature_spec)
return est, serving_input_fn
class CheckCallsMonitor(monitors_lib.BaseMonitor):
def __init__(self, expect_calls):
......@@ -753,6 +797,49 @@ class EstimatorTest(test.TestCase):
# cleanup
gfile.DeleteRecursively(tmpdir)
def test_export_savedmodel_with_resource(self):
tmpdir = tempfile.mkdtemp()
est, serving_input_fn = _build_estimator_for_resource_export_test()
export_dir_base = os.path.join(
compat.as_bytes(tmpdir), compat.as_bytes('export'))
export_dir = est.export_savedmodel(export_dir_base, serving_input_fn)
self.assertTrue(gfile.Exists(export_dir_base))
self.assertTrue(gfile.Exists(export_dir))
self.assertTrue(
gfile.Exists(
os.path.join(
compat.as_bytes(export_dir), compat.as_bytes(
'saved_model.pb'))))
self.assertTrue(
gfile.Exists(
os.path.join(
compat.as_bytes(export_dir), compat.as_bytes('variables'))))
self.assertTrue(
gfile.Exists(
os.path.join(
compat.as_bytes(export_dir),
compat.as_bytes('variables/variables.index'))))
self.assertTrue(
gfile.Exists(
os.path.join(
compat.as_bytes(export_dir),
compat.as_bytes('variables/variables.data-00000-of-00001'))))
# Restore, to validate that the export was well-formed.
with ops.Graph().as_default() as graph:
with session_lib.Session(graph=graph) as sess:
loader.load(sess, [tag_constants.SERVING], export_dir)
graph_ops = [x.name for x in graph.get_operations()]
self.assertTrue('input_example_tensor' in graph_ops)
self.assertTrue('ParseExample/ParseExample' in graph_ops)
self.assertTrue('LookupTableModel' in graph_ops)
self.assertFalse('LookupTableTrainingState' in graph_ops)
# cleanup
gfile.DeleteRecursively(tmpdir)
class InferRealValuedColumnsTest(test.TestCase):
......
......@@ -74,6 +74,7 @@ py_library(
":tensor_array_ops",
":training",
":ops",
":saver_test_utils",
":test_ops", # TODO: Break testing code out into separate rule.
":util",
":weights_broadcast_ops",
......@@ -2935,6 +2936,16 @@ cuda_py_tests(
],
)
py_library(
name = "saver_test_utils",
srcs = ["training/saver_test_utils.py"],
srcs_version = "PY2AND3",
deps = [
":data_flow_ops_gen",
":training",
],
)
cuda_py_test(
name = "saver_test",
size = "medium",
......@@ -2946,12 +2957,12 @@ cuda_py_test(
":client_testlib",
":control_flow_ops",
":data_flow_ops",
":data_flow_ops_gen",
":errors",
":gradients",
":math_ops",
":nn_grad",
":nn_ops",
":saver_test_utils",
":partitioned_variables",
":platform",
":platform_test",
......
......@@ -136,6 +136,7 @@ py_test(
"//tensorflow/python:control_flow_ops",
"//tensorflow/python:framework_for_generated_wrappers",
"//tensorflow/python:layers",
"//tensorflow/python:saver_test_utils",
"//tensorflow/python:session",
"//tensorflow/python:state_ops",
"//tensorflow/python:training",
......
......@@ -411,7 +411,7 @@ class Estimator(object):
with tf_session.Session() as session:
saver_for_restore = estimator_spec.scaffold.saver or saver.Saver(
variables.global_variables(),
variables._all_saveable_objects(), # pylint: disable=protected-access
sharded=True)
saver_for_restore.restore(session, checkpoint_path)
......
......@@ -48,6 +48,7 @@ from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.saved_model import loader
from tensorflow.python.saved_model import tag_constants
from tensorflow.python.training import saver
from tensorflow.python.training import saver_test_utils
from tensorflow.python.training import session_run_hook
from tensorflow.python.training import training
from tensorflow.python.util import compat
......@@ -814,6 +815,20 @@ def _model_fn_for_export_tests(features, labels, mode):
'test': export_output.ClassificationOutput(scores, classes)})
def _model_fn_with_saveables_for_export_tests(features, labels, mode):
_, _ = features, labels
table = saver_test_utils.CheckpointedOp(name='v2')
train_op = table.insert('k1', 30.0)
prediction = table.lookup('k1', 0.0)
return model_fn_lib.EstimatorSpec(
mode,
predictions=prediction,
loss=constant_op.constant(1.),
train_op=train_op,
export_outputs={
'test': export_output.PredictOutput({'prediction': prediction})})
_VOCAB_FILE_CONTENT = 'emerson\nlake\npalmer\n'
_EXTRA_FILE_CONTENT = 'kermit\npiggy\nralph\n'
......@@ -863,6 +878,50 @@ class EstimatorExportTest(test.TestCase):
# Clean up.
gfile.DeleteRecursively(tmpdir)
def test_export_savedmodel_with_saveables_proto_roundtrip(self):
tmpdir = tempfile.mkdtemp()
est = estimator.Estimator(
model_fn=_model_fn_with_saveables_for_export_tests)
est.train(input_fn=dummy_input_fn, steps=1)
feature_spec = {'x': parsing_ops.VarLenFeature(dtype=dtypes.int64),
'y': parsing_ops.VarLenFeature(dtype=dtypes.int64)}
serving_input_receiver_fn = export.build_parsing_serving_input_receiver_fn(
feature_spec)
# Perform the export.
export_dir_base = os.path.join(
compat.as_bytes(tmpdir), compat.as_bytes('export'))
export_dir = est.export_savedmodel(
export_dir_base, serving_input_receiver_fn)
# Check that all the files are in the right places.
self.assertTrue(gfile.Exists(export_dir_base))
self.assertTrue(gfile.Exists(export_dir))
self.assertTrue(gfile.Exists(os.path.join(
compat.as_bytes(export_dir),
compat.as_bytes('saved_model.pb'))))
self.assertTrue(gfile.Exists(os.path.join(
compat.as_bytes(export_dir),
compat.as_bytes('variables'))))
self.assertTrue(gfile.Exists(os.path.join(
compat.as_bytes(export_dir),
compat.as_bytes('variables/variables.index'))))
self.assertTrue(gfile.Exists(os.path.join(
compat.as_bytes(export_dir),
compat.as_bytes('variables/variables.data-00000-of-00001'))))
# Restore, to validate that the export was well-formed.
with ops.Graph().as_default() as graph:
with session.Session(graph=graph) as sess:
loader.load(sess, [tag_constants.SERVING], export_dir)
graph_ops = [x.name for x in graph.get_operations()]
self.assertTrue('input_example_tensor' in graph_ops)
self.assertTrue('ParseExample/ParseExample' in graph_ops)
self.assertTrue('save/LookupTableImport' in graph_ops)
# Clean up.
gfile.DeleteRecursively(tmpdir)
def test_export_savedmodel_assets(self):
tmpdir = tempfile.mkdtemp()
est = estimator.Estimator(model_fn=_model_fn_for_export_tests)
......
......@@ -122,6 +122,7 @@ py_test(
"//tensorflow/python:framework_for_generated_wrappers",
"//tensorflow/python:lib",
"//tensorflow/python:math_ops",
"//tensorflow/python:saver_test_utils",
"//tensorflow/python:state_ops",
"//tensorflow/python:util",
"//tensorflow/python:variables",
......
......@@ -352,10 +352,10 @@ class SavedModelBuilder(object):
else:
self._add_main_op(main_op)
# Initialize a saver to generate a sharded output for all variables in the
# Initialize a saver to generate a sharded output for all saveables in the
# current scope.
saver = tf_saver.Saver(
variables.global_variables(),
variables._all_saveable_objects(), # pylint: disable=protected-access
sharded=True,
write_version=saver_pb2.SaverDef.V2,
allow_empty=True)
......@@ -423,10 +423,10 @@ class SavedModelBuilder(object):
else:
self._add_main_op(main_op)
# Initialize a saver to generate a sharded output for all variables in the
# Initialize a saver to generate a sharded output for all saveables in the
# current scope.
saver = tf_saver.Saver(
variables.global_variables(),
variables._all_saveable_objects(), # pylint: disable=protected-access
sharded=True,
write_version=saver_pb2.SaverDef.V2,
allow_empty=True)
......
......@@ -39,6 +39,7 @@ from tensorflow.python.saved_model import loader
from tensorflow.python.saved_model import main_op
from tensorflow.python.saved_model import signature_def_utils
from tensorflow.python.saved_model import tag_constants
from tensorflow.python.training import saver_test_utils
from tensorflow.python.util import compat
SAVED_MODEL_PATH = ("cc/saved_model/testdata/half_plus_two/00000123")
......@@ -734,6 +735,35 @@ class SavedModelTest(test.TestCase):
ops.get_collection("init_op")[0].run()
self.assertEqual(3, ops.get_collection("v")[2].eval())
def testCustomSaveable(self):
export_dir = os.path.join(test.get_temp_dir(), "custom_saveable")
builder = saved_model_builder.SavedModelBuilder(export_dir)
with session.Session(
graph=ops.Graph(),
config=config_pb2.ConfigProto(device_count={"CPU": 2})) as sess:
# CheckpointedOp is a key-value table that can be saved across sessions.
# The table register itself in SAVEABLE_OBJECTS collection.
v1 = saver_test_utils.CheckpointedOp(name="v1")
variables.global_variables_initializer().run()
v1.insert("k1", 3.0).run()
# Once the table is restored, we can access it through this reference.
ops.add_to_collection("table_ref", v1.table_ref)
builder.add_meta_graph_and_variables(sess, ["foo"])
# Save the SavedModel to disk.
builder.save()
with session.Session(
graph=ops.Graph(),
config=config_pb2.ConfigProto(device_count={"CPU": 2})) as sess:
loader.load(sess, ["foo"], export_dir)
# Instantiate a wrapper object from the checkpointed reference.
v1 = saver_test_utils.CheckpointedOp(
name="v1", table_ref=ops.get_collection("table_ref")[0])
self.assertEqual(b"k1", v1.keys().eval())
self.assertEqual(3.0, v1.values().eval())
def testClearDevices(self):
export_dir = os.path.join(test.get_temp_dir(), "test_clear_devices")
builder = saved_model_builder.SavedModelBuilder(export_dir)
......
......@@ -48,7 +48,6 @@ from tensorflow.python.framework import ops as ops_lib
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import data_flow_ops
from tensorflow.python.ops import gen_data_flow_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn_ops
from tensorflow.python.ops import partitioned_variables
......@@ -65,63 +64,10 @@ from tensorflow.python.training import adam
from tensorflow.python.training import gradient_descent
from tensorflow.python.training import queue_runner_impl
from tensorflow.python.training import saver as saver_module
from tensorflow.python.training import saver_test_utils
from tensorflow.python.util import compat
class CheckpointedOp(object):
"""Op with a custom checkpointing implementation.
Defined as part of the test because the MutableHashTable Python code is
currently in contrib.
"""
def __init__(self, name):
self._table_ref = gen_data_flow_ops._mutable_hash_table(
key_dtype=dtypes.string, value_dtype=dtypes.float32, name=name)
self._name = name
self._saveable = CheckpointedOp.CustomSaveable(self, name)
ops_lib.add_to_collection(ops_lib.GraphKeys.SAVEABLE_OBJECTS,
self._saveable)
@property
def name(self):
return self._name
@property
def saveable(self):
return self._saveable
def insert(self, keys, values):
return gen_data_flow_ops._lookup_table_insert(self._table_ref, keys, values)
def keys(self):
return self._export()[0]
def values(self):
return self._export()[1]
def _export(self):
return gen_data_flow_ops._lookup_table_export(self._table_ref,
dtypes.string, dtypes.float32)
class CustomSaveable(saver_module.BaseSaverBuilder.SaveableObject):
def __init__(self, table, name):
tensors = table._export()
specs = [
saver_module.BaseSaverBuilder.SaveSpec(tensors[0], "",
name + "-keys"),
saver_module.BaseSaverBuilder.SaveSpec(tensors[1], "",
name + "-values")
]
super(CheckpointedOp.CustomSaveable, self).__init__(table, specs, name)
def restore(self, restore_tensors, shapes):
return gen_data_flow_ops._lookup_table_import(self.op._table_ref,
restore_tensors[0],
restore_tensors[1])
class SaverTest(test.TestCase):
def basicSaveRestore(self, variable_op):
......@@ -131,7 +77,7 @@ class SaverTest(test.TestCase):
# Restore nodes for them.
v0 = variable_op(10.0, name="v0")
v1 = variable_op(20.0, name="v1")
v2 = CheckpointedOp(name="v2")
v2 = saver_test_utils.CheckpointedOp(name="v2")
v2_init = v2.insert("k1", 30.0)
save = saver_module.Saver(
{
......@@ -161,7 +107,7 @@ class SaverTest(test.TestCase):
with self.test_session() as sess:
v0 = variable_op(-1.0, name="v0")
v1 = variable_op(-1.0, name="v1")
v2 = CheckpointedOp(name="v2")
v2 = saver_test_utils.CheckpointedOp(name="v2")
save = saver_module.Saver({"v0": v0, "v1": v1, "v2": v2.saveable})
# Assert that the variables are not initialized.
......@@ -183,7 +129,7 @@ class SaverTest(test.TestCase):
with self.test_session() as sess:
v0_2 = variable_op(1000.0, name="v0")
v1_2 = variable_op(2000.0, name="v1")
v2_2 = CheckpointedOp(name="v2")
v2_2 = saver_test_utils.CheckpointedOp(name="v2")
save2 = saver_module.Saver({"v0": v0_2, "v1": v1_2, "v2": v2_2.saveable})
v2_2.insert("k1000", 3000.0).run()
variables.global_variables_initializer().run()
......@@ -276,7 +222,7 @@ class SaverTest(test.TestCase):
def testSameName(self):
with ops_lib.Graph().as_default():
v0 = variables.Variable([10.0], name="v0")
v2 = CheckpointedOp(name="v2")
v2 = saver_test_utils.CheckpointedOp(name="v2")
# Saving one variable under two names raises an error.
with self.assertRaisesRegexp(
......@@ -299,7 +245,7 @@ class SaverTest(test.TestCase):
# Restore nodes for them.
v0 = variables.Variable(10.0, name="v0")
v1 = variables.Variable(20.0, name="v1")
v2 = CheckpointedOp(name="v2")
v2 = saver_test_utils.CheckpointedOp(name="v2")
v2_init = v2.insert("k1", 30.0)
save = saver_module.Saver([v0, v1, v2.saveable])
variables.global_variables_initializer().run()
......@@ -321,7 +267,7 @@ class SaverTest(test.TestCase):
with self.test_session(graph=ops_lib.Graph()) as sess:
v0 = variables.Variable(-1.0, name="v0")
v1 = variables.Variable(-1.0, name="v1")
v2 = CheckpointedOp(name="v2")
v2 = saver_test_utils.CheckpointedOp(name="v2")
save = saver_module.Saver([v0, v1, v2.saveable])
with self.assertRaisesWithPredicateMatch(
......@@ -346,7 +292,7 @@ class SaverTest(test.TestCase):
with self.test_session(graph=ops_lib.Graph()) as sess:
v0_2 = variables.Variable(1000.0, name="v0")
v1_2 = variables.Variable(2000.0, name="v1")
v2_2 = CheckpointedOp(name="v2")
v2_2 = saver_test_utils.CheckpointedOp(name="v2")
save2 = saver_module.Saver([v0_2, v1_2, v2_2.saveable])
v2_2.insert("k1000", 3000.0).run()
variables.global_variables_initializer().run()
......@@ -418,7 +364,7 @@ class SaverTest(test.TestCase):
with session.Session("", graph=ops_lib.Graph()) as sess:
one = variables.Variable(1.0)
twos = variables.Variable([2.0, 2.0, 2.0])
v2 = CheckpointedOp(name="v2")
v2 = saver_test_utils.CheckpointedOp(name="v2")
init = variables.global_variables_initializer()
save = saver_module.Saver()
init.run()
......@@ -428,7 +374,7 @@ class SaverTest(test.TestCase):
with session.Session("", graph=ops_lib.Graph()) as sess:
one = variables.Variable(0.0)
twos = variables.Variable([0.0, 0.0, 0.0])
v2 = CheckpointedOp(name="v2")
v2 = saver_test_utils.CheckpointedOp(name="v2")
# Saver with no arg, defaults to 'all variables'.
save = saver_module.Saver()
save.restore(sess, save_path)
......@@ -593,10 +539,10 @@ class SaveRestoreShardedTest(test.TestCase):
config=config_pb2.ConfigProto(device_count={"CPU": 2})) as sess:
with sess.graph.device("/cpu:0"):
v0 = variables.Variable(10, name="v0")
t0 = CheckpointedOp(name="t0")
t0 = saver_test_utils.CheckpointedOp(name="t0")
with sess.graph.device("/cpu:1"):
v1 = variables.Variable(20, name="v1")
t1 = CheckpointedOp(name="t1")
t1 = saver_test_utils.CheckpointedOp(name="t1")
save = saver_module.Saver(
{
"v0": v0,
......@@ -623,7 +569,7 @@ class SaveRestoreShardedTest(test.TestCase):
config=config_pb2.ConfigProto(device_count={"CPU": 2})) as sess:
with sess.graph.device("/cpu:0"):
v0 = variables.Variable(111, name="v0")
t0 = CheckpointedOp(name="t0")
t0 = saver_test_utils.CheckpointedOp(name="t0")
save = saver_module.Saver({"v0": v0, "t0": t0.saveable}, sharded=True)
variables.global_variables_initializer().run()
t0.insert("k11", 33.0).run()
......@@ -641,7 +587,7 @@ class SaveRestoreShardedTest(test.TestCase):
config=config_pb2.ConfigProto(device_count={"CPU": 2})) as sess:
with sess.graph.device("/cpu:0"):
v1 = variables.Variable(222)
t1 = CheckpointedOp(name="t1")
t1 = saver_test_utils.CheckpointedOp(name="t1")
save = saver_module.Saver({"v1": v1, "t1": t1.saveable}, sharded=True)
variables.global_variables_initializer().run()
t1.insert("k22", 44.0).run()
......@@ -659,10 +605,10 @@ class SaveRestoreShardedTest(test.TestCase):
config=config_pb2.ConfigProto(device_count={"CPU": 2})) as sess:
with sess.graph.device("/cpu:0"):
v0 = variables.Variable(111, name="v0")
t0 = CheckpointedOp(name="t0")
t0 = saver_test_utils.CheckpointedOp(name="t0")
with sess.graph.device("/cpu:1"):
v1 = variables.Variable(222, name="v1")
t1 = CheckpointedOp(name="t1")
t1 = saver_test_utils.CheckpointedOp(name="t1")
save = saver_module.Saver(
{
"v0": v0,
......
# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =============================================================================
"""Utility classes for testing checkpointing."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops as ops_lib
from tensorflow.python.ops import gen_data_flow_ops
from tensorflow.python.training import saver as saver_module
class CheckpointedOp(object):
"""Op with a custom checkpointing implementation.
Defined as part of the test because the MutableHashTable Python code is
currently in contrib.
"""
# pylint: disable=protected-access
def __init__(self, name, table_ref=None):
if table_ref is None:
self.table_ref = gen_data_flow_ops._mutable_hash_table(
key_dtype=dtypes.string, value_dtype=dtypes.float32, name=name)
else:
self.table_ref = table_ref
self._name = name
self._saveable = CheckpointedOp.CustomSaveable(self, name)
ops_lib.add_to_collection(ops_lib.GraphKeys.SAVEABLE_OBJECTS,
self._saveable)
@property
def name(self):
return self._name
@property
def saveable(self):
return self._saveable
def insert(self, keys, values):
return gen_data_flow_ops._lookup_table_insert(self.table_ref, keys, values)
def lookup(self, keys, default):
return gen_data_flow_ops._lookup_table_find(self.table_ref, keys, default)
def keys(self):
return self._export()[0]
def values(self):
return self._export()[1]
def _export(self):
return gen_data_flow_ops._lookup_table_export(self.table_ref, dtypes.string,
dtypes.float32)
class CustomSaveable(saver_module.BaseSaverBuilder.SaveableObject):
"""A custom saveable for CheckpointedOp."""
def __init__(self, table, name):
tensors = table._export()
specs = [
saver_module.BaseSaverBuilder.SaveSpec(tensors[0], "",
name + "-keys"),
saver_module.BaseSaverBuilder.SaveSpec(tensors[1], "",
name + "-values")
]
super(CheckpointedOp.CustomSaveable, self).__init__(table, specs, name)
def restore(self, restore_tensors, shapes):
return gen_data_flow_ops._lookup_table_import(
self.op.table_ref, restore_tensors[0], restore_tensors[1])
# pylint: enable=protected-access
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册