Commit 52dc7286 authored by Shivani Agrawal, committed by TensorFlower Gardener

[Checkpointable] Make Iterator checkpointable.

Use object-based save/restore to make the dataset/iterator checkpointable in both graph and eager modes.

PiperOrigin-RevId: 206998349
Parent a28ad4b2
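For orientation, here is a minimal eager-mode sketch of the object-based save/restore this change enables. It only uses APIs that appear in this diff (`checkpointable_utils.Checkpoint`, `Dataset.range`, `make_one_shot_iterator`); enabling eager execution via `ops.enable_eager_execution()` and the temporary checkpoint prefix are illustrative assumptions, not part of the change itself.

import os
import tempfile

from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.framework import ops
from tensorflow.python.training.checkpointable import util as checkpointable_utils

ops.enable_eager_execution()

dataset = dataset_ops.Dataset.range(10)
iterator = dataset.make_one_shot_iterator()
# The iterator is now a checkpointable object, so it can be attached to a
# Checkpoint as a tracked dependency.
checkpoint = checkpointable_utils.Checkpoint(iterator=iterator)
checkpoint_prefix = os.path.join(tempfile.mkdtemp(), "ckpt")  # illustrative path

print(iterator.get_next().numpy())              # 0
save_path = checkpoint.save(checkpoint_prefix)  # serializes the iterator state
print(iterator.get_next().numpy())              # 1
checkpoint.restore(save_path)                   # rewinds to the saved position
print(iterator.get_next().numpy())              # 1 again

In graph mode the same objects are used, but the restore runs through a session, e.g. checkpoint.restore(save_path).run_restore_ops(sess) or initialize_or_restore(sess), as exercised by the tests below.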
@@ -22,12 +22,9 @@ from tensorflow.contrib.data.python.ops import prefetching_ops
from tensorflow.python.data.ops import iterator_ops
from tensorflow.python.eager import context
from tensorflow.python.framework import ops
from tensorflow.python.ops import gen_dataset_ops
from tensorflow.python.training.checkpointable import base as checkpointable
from tensorflow.python.training.saver import BaseSaverBuilder
class Iterator(iterator_ops.EagerIterator, checkpointable.CheckpointableBase):
class Iterator(iterator_ops.EagerIterator):
"""An iterator producing tf.Tensor objects from a tf.data.Dataset.
NOTE: Unlike the iterator created by the
@@ -82,30 +79,3 @@ class Iterator(iterator_ops.EagerIterator, checkpointable.CheckpointableBase):
# TODO(b/77291417): Fix
with context.execution_mode(context.SYNC):
return super(Iterator, self)._next_internal()
# TODO(shivaniagrawal): Expose checkpointable stateful objects from dataset
# attributes(potential).
class _Saveable(BaseSaverBuilder.SaveableObject):
"""SaveableObject for saving/restoring iterator state."""
def __init__(self, iterator_resource, name):
serialized_iterator = gen_dataset_ops.serialize_iterator(
iterator_resource)
specs = [
BaseSaverBuilder.SaveSpec(serialized_iterator, "", name + "_STATE")
]
# pylint: disable=protected-access
super(Iterator._Saveable, self).__init__(iterator_resource, specs, name)
def restore(self, restored_tensors, restored_shapes):
with ops.colocate_with(self.op):
return gen_dataset_ops.deserialize_iterator(self.op,
restored_tensors[0])
def _gather_saveables_for_checkpoint(self):
def _saveable_factory(name):
return self._Saveable(self._resource, name)
return {"ITERATOR": _saveable_factory}
@@ -37,6 +37,7 @@ from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import script_ops
from tensorflow.python.training import saver
from tensorflow.python.training.checkpointable import util as checkpointable_utils
@@ -306,6 +307,18 @@ class IteratorTest(test.TestCase):
checkpoint.restore(save_path)
self.assertEqual(2, iterator.get_next().numpy())
def testRestoreInReconstructedIterator(self):
checkpoint_directory = self.get_temp_dir()
checkpoint_prefix = os.path.join(checkpoint_directory, 'ckpt')
dataset = Dataset.range(10)
for i in range(5):
iterator = datasets.Iterator(dataset)
checkpoint = checkpointable_utils.Checkpoint(iterator=iterator)
checkpoint.restore(saver.latest_checkpoint(checkpoint_directory))
for j in range(2):
self.assertEqual(i * 2 + j, iterator.get_next().numpy())
checkpoint.save(file_prefix=checkpoint_prefix)
class DatasetConstructorBenchmark(test.Benchmark):
@@ -3216,6 +3216,7 @@ py_library(
# The following targets have their own build rules (same name as the
# file):
"training/saveable_object.py",
"training/saver.py",
"training/training_util.py",
],
),
@@ -3247,6 +3248,7 @@ py_library(
":random_ops",
":resource_variable_ops",
":resources",
"saver",
":saveable_object",
":sdca_ops",
":sparse_ops",
@@ -3277,6 +3279,40 @@ py_library(
srcs_version = "PY2AND3",
)
py_library(
name = "saver",
srcs = ["training/saver.py"],
srcs_version = "PY2AND3",
deps = [
":array_ops",
":constant_op",
":control_flow_ops",
":device",
":errors",
":framework",
":framework_ops",
":io_ops",
":io_ops_gen",
":lib",
":platform",
":protos_all_py",
":pywrap_tensorflow",
":resource_variable_ops",
":saveable_object",
":session",
":state_ops",
":string_ops",
":training_util",
":util",
":variables",
"//tensorflow/core:protos_all_py",
"//tensorflow/python/eager:context",
"//tensorflow/python/training/checkpointable:base",
"//third_party/py/numpy",
"@six_archive//:six",
],
)
py_library(
name = "device_util",
srcs = ["training/device_util.py"],
@@ -329,6 +329,8 @@ cuda_py_test(
"//tensorflow/python/data/ops:dataset_ops",
"//tensorflow/python/data/ops:iterator_ops",
"//tensorflow/python/data/util:sparse",
"//tensorflow/python/eager:context",
"//tensorflow/python/training/checkpointable:util",
"//tensorflow/python:array_ops",
"//tensorflow/python:client_testlib",
"//tensorflow/python:constant_op",
@@ -350,6 +352,8 @@ cuda_py_test(
"//tensorflow/python:tensor_shape",
"//tensorflow/python:training",
"//tensorflow/python/compat:compat",
"//tensorflow/python:util",
"//tensorflow/python:variables",
],
grpc_enabled = True,
)
@@ -17,6 +17,7 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import functools
import os
import warnings
@@ -46,7 +47,9 @@ from tensorflow.python.ops import parsing_ops
from tensorflow.python.ops import script_ops
from tensorflow.python.ops import variables
from tensorflow.python.platform import test
from tensorflow.python.training import saver
from tensorflow.python.training import server_lib
from tensorflow.python.training.checkpointable import util as checkpointable_utils
from tensorflow.python.util import compat
@@ -788,5 +791,98 @@ class IteratorTest(test.TestCase):
val += 1
class IteratorCheckpointingTest(test.TestCase):
@test_util.run_in_graph_and_eager_modes
def testSaveRestoreOneShotIterator(self):
checkpoint_directory = self.get_temp_dir()
checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt")
dataset = dataset_ops.Dataset.from_tensor_slices([1, 2, 3, 4, 5, 6]).map(
math_ops.square).batch(2)
iterator = dataset.make_one_shot_iterator()
get_next = iterator.get_next if context.executing_eagerly(
) else functools.partial(self.evaluate, iterator.get_next())
checkpoint = checkpointable_utils.Checkpoint(iterator=iterator)
with self.test_session() as sess:
self.assertAllEqual([1, 4], get_next())
save_path = checkpoint.save(checkpoint_prefix)
self.assertAllEqual([9, 16], get_next())
self.assertAllEqual([25, 36], get_next())
checkpoint.restore(save_path).run_restore_ops(sess)
self.assertAllEqual([9, 16], get_next())
self.assertAllEqual([25, 36], get_next())
with self.assertRaises(errors.OutOfRangeError):
get_next()
@test_util.run_in_graph_and_eager_modes
def testSaveRestoreMultipleIterator(self):
checkpoint_directory = self.get_temp_dir()
checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt")
dataset = dataset_ops.Dataset.from_tensor_slices(
[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11])
dataset = dataset.map(math_ops.square).batch(2)
iterator_1 = dataset.make_one_shot_iterator()
get_next_1 = iterator_1.get_next if context.executing_eagerly(
) else functools.partial(self.evaluate, iterator_1.get_next())
iterator_2 = dataset.make_one_shot_iterator()
get_next_2 = iterator_2.get_next if context.executing_eagerly(
) else functools.partial(self.evaluate, iterator_2.get_next())
dataset_2 = dataset_ops.Dataset.range(10)
iterator_3 = dataset_2.make_one_shot_iterator()
get_next_3 = iterator_3.get_next if context.executing_eagerly(
) else functools.partial(self.evaluate, iterator_3.get_next())
checkpoint = checkpointable_utils.Checkpoint(
iterator_1=iterator_1, iterator_2=iterator_2, iterator_3=iterator_3)
with self.test_session() as sess:
self.assertAllEqual([1, 4], get_next_1())
self.assertAllEqual(0, get_next_3())
self.assertAllEqual(1, get_next_3())
self.assertAllEqual(2, get_next_3())
save_path = checkpoint.save(checkpoint_prefix)
self.assertAllEqual([1, 4], get_next_2())
self.assertAllEqual([9, 16], get_next_2())
self.assertAllEqual(3, get_next_3())
checkpoint.restore(save_path).run_restore_ops(sess)
self.assertAllEqual([9, 16], get_next_1())
self.assertAllEqual([1, 4], get_next_2())
self.assertAllEqual(3, get_next_3())
@test_util.run_in_graph_and_eager_modes
def testRestoreExhaustedIterator(self):
checkpoint_directory = self.get_temp_dir()
checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt")
dataset = dataset_ops.Dataset.range(3)
iterator = dataset.make_one_shot_iterator()
get_next = iterator.get_next if context.executing_eagerly(
) else functools.partial(self.evaluate, iterator.get_next())
checkpoint = checkpointable_utils.Checkpoint(iterator=iterator)
with self.test_session() as sess:
self.assertAllEqual(0, get_next())
self.assertAllEqual(1, get_next())
save_path = checkpoint.save(checkpoint_prefix)
self.assertAllEqual(2, get_next())
checkpoint.restore(save_path).run_restore_ops(sess)
self.assertAllEqual(2, get_next())
save_path = checkpoint.save(checkpoint_prefix)
checkpoint.restore(save_path).run_restore_ops(sess)
with self.assertRaises(errors.OutOfRangeError):
get_next()
def testRestoreInReconstructedIteratorInitializable(self):
checkpoint_directory = self.get_temp_dir()
checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt")
dataset = dataset_ops.Dataset.range(10)
iterator = dataset.make_initializable_iterator()
get_next = iterator.get_next()
checkpoint = checkpointable_utils.Checkpoint(iterator=iterator)
for i in range(5):
with self.test_session() as sess:
checkpoint.restore(saver.latest_checkpoint(
checkpoint_directory)).initialize_or_restore(sess)
for j in range(2):
self.assertEqual(i * 2 + j, sess.run(get_next))
checkpoint.save(file_prefix=checkpoint_prefix)
if __name__ == "__main__":
test.main()
@@ -54,10 +54,12 @@ py_library(
"//tensorflow/python:dtypes",
"//tensorflow/python:framework_ops",
"//tensorflow/python:resource_variable_ops",
"//tensorflow/python:saver",
"//tensorflow/python:tensor_shape",
"//tensorflow/python/compat",
"//tensorflow/python/data/util:nest",
"//tensorflow/python/data/util:sparse",
"//tensorflow/python/eager:context",
"//tensorflow/python/training/checkpointable:base",
],
)
@@ -30,6 +30,8 @@ from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.ops import gen_dataset_ops
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.training.checkpointable import base as checkpointable
from tensorflow.python.training.saver import BaseSaverBuilder
from tensorflow.python.util.tf_export import tf_export
@@ -65,7 +67,7 @@ def _device_stack_is_empty():
@tf_export("data.Iterator")
class Iterator(object):
class Iterator(checkpointable.CheckpointableBase):
"""Represents the state of iterating through a `Dataset`."""
def __init__(self, iterator_resource, initializer, output_types,
@@ -464,6 +466,13 @@ class Iterator(object):
"""
return self._output_types
def _gather_saveables_for_checkpoint(self):
def _saveable_factory(name):
return _IteratorSaveable(self._iterator_resource, name)
return {"ITERATOR": _saveable_factory}
_uid_counter = 0
_uid_lock = threading.Lock()
@@ -477,7 +486,7 @@ def _generate_shared_name(prefix):
return "{}{}".format(prefix, uid)
class EagerIterator(object):
class EagerIterator(checkpointable.CheckpointableBase):
"""An iterator producing tf.Tensor objects from a tf.data.Dataset."""
def __init__(self, dataset):
@@ -610,3 +619,28 @@ class EagerIterator(object):
"""
del name
return self._next_internal()
def _gather_saveables_for_checkpoint(self):
def _saveable_factory(name):
return _IteratorSaveable(self._resource, name)
return {"ITERATOR": _saveable_factory}
# TODO(b/71645805): Expose checkpointable stateful objects from dataset
# attributes(potential).
class _IteratorSaveable(BaseSaverBuilder.SaveableObject):
"""SaveableObject for saving/restoring iterator state."""
def __init__(self, iterator_resource, name):
serialized_iterator = gen_dataset_ops.serialize_iterator(iterator_resource)
specs = [
BaseSaverBuilder.SaveSpec(serialized_iterator, "", name + "_STATE")
]
# pylint: disable=protected-access
super(_IteratorSaveable, self).__init__(iterator_resource, specs, name)
def restore(self, restored_tensors, restored_shapes):
with ops.colocate_with(self.op):
return gen_dataset_ops.deserialize_iterator(self.op, restored_tensors[0])
path: "tensorflow.data.Iterator"
tf_class {
is_instance: "<class \'tensorflow.python.data.ops.iterator_ops.Iterator\'>"
is_instance: "<class \'tensorflow.python.training.checkpointable.base.CheckpointableBase\'>"
is_instance: "<type \'object\'>"
member {
name: "initializer"