Commit 172bc601 authored by Rohan Jain, committed by TensorFlower Gardener

Handle ExternalStatePolicy during Iterator::Save(). This CL adds an optional external_state_policy attr to the SerializeIterator op, which defaults to FAIL, keeping the behavior the same as before; users can now override it if they want a different policy. This CL also exposes the option in the CheckpointInputPipelineHook.

PiperOrigin-RevId: 286314588
Change-Id: I4f8eae88f42c74b168348afca8b295eb37d845dc
Parent ec23b813
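For reference, a minimal sketch (TF1-style graph mode; the dataset, checkpoint path, and collection wiring below are illustrative, not part of the CL) of the user-facing change: `make_saveable_from_iterator`, and analogously `CheckpointInputPipelineHook`, now accept an `external_state_policy` argument, with `'fail'` remaining the default.

```python
import tensorflow.compat.v1 as tf

tf.disable_eager_execution()

# The map function pulls in external state (the RNG), which previously made
# iterator checkpointing fail unconditionally with FailedPreconditionError.
dataset = tf.data.Dataset.range(10).map(lambda _: tf.random.uniform(()))
iterator = tf.data.make_initializable_iterator(dataset)

# 'warn' logs a warning about the external state instead of failing;
# 'fail' stays the default.
saveable = tf.data.experimental.make_saveable_from_iterator(
    iterator, external_state_policy="warn")
tf.add_to_collection(tf.GraphKeys.SAVEABLE_OBJECTS, saveable)

saver = tf.train.Saver()
with tf.Session() as sess:
  sess.run(iterator.initializer)
  saver.save(sess, "/tmp/iterator_ckpt")  # hypothetical path
```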
......@@ -887,7 +887,17 @@ class DatasetBaseIterator : public IteratorBase {
   }
   Status Save(SerializationContext* ctx, IteratorStateWriter* writer) final {
-    TF_RETURN_IF_ERROR(params_.dataset->CheckExternalState());
+    Status s = params_.dataset->CheckExternalState();
+    if (!s.ok()) {
+      if (ctx->external_state_policy() ==
+          SerializationContext::ExternalStatePolicy::kWarn) {
+        LOG(WARNING) << "Dataset contains external state: " << s.ToString();
+      }
+      if (ctx->external_state_policy() ==
+          SerializationContext::ExternalStatePolicy::kFail) {
+        return s;
+      }
+    }
     return IteratorBase::Save(ctx, writer);
   }
......
......@@ -1115,6 +1115,7 @@ tf_kernel_library(
     hdrs = ["iterator_ops.h"],
     deps = [
         ":captured_function",
+        ":dataset_ops",
         ":dataset_utils",
         ":optional_ops",
         ":unbounded_thread_pool",
......
......@@ -33,6 +33,7 @@ limitations under the License.
 #include "tensorflow/core/framework/variant_tensor_data.h"
 #include "tensorflow/core/graph/graph_constructor.h"
 #include "tensorflow/core/kernels/data/captured_function.h"
+#include "tensorflow/core/kernels/data/dataset_ops.h"
 #include "tensorflow/core/kernels/data/dataset_utils.h"
 #include "tensorflow/core/kernels/data/optional_ops.h"
 #include "tensorflow/core/kernels/ops_util.h"
......@@ -294,10 +295,10 @@ class IteratorVariantSerializer {
   // Calls `Save` on the iterator_resource to build up the list of
   // IteratorStateVariant objects.
-  Status InitializeFromIterator(IteratorResource* iterator_resource) {
-    SerializationContext serialization_ctx({});
+  Status InitializeFromIterator(SerializationContext* serialization_ctx,
+                                IteratorResource* iterator_resource) {
     VariantTensorDataWriter writer;
-    TF_RETURN_IF_ERROR(iterator_resource->Save(&serialization_ctx, &writer));
+    TF_RETURN_IF_ERROR(iterator_resource->Save(serialization_ctx, &writer));
     std::vector<std::unique_ptr<VariantTensorData>> data;
     writer.ReleaseData(&data);
     variants_.clear();
......@@ -1070,13 +1071,20 @@ namespace {
 class SerializeIteratorOp : public OpKernel {
  public:
-  explicit SerializeIteratorOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
+  explicit SerializeIteratorOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
+    if (ctx->HasAttr(DatasetToGraphOp::kExternalStatePolicy)) {
+      int64 state_change_option;
+      OP_REQUIRES_OK(ctx, ctx->GetAttr(DatasetToGraphOp::kExternalStatePolicy,
+                                       &state_change_option));
+      external_state_policy_ =
+          SerializationContext::ExternalStatePolicy(state_change_option);
+    }
+  }

   void Compute(OpKernelContext* ctx) override {
     const Tensor& resource_handle_t = ctx->input(0);
     OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(resource_handle_t.shape()),
                 errors::InvalidArgument("resource_handle must be a scalar"));

     // Validate that the handle corresponds to a real resource, and
     // that it is an IteratorResource.
     IteratorResource* iterator_resource;
......@@ -1084,13 +1092,21 @@ class SerializeIteratorOp : public OpKernel {
         ctx, LookupResource(ctx, HandleFromInput(ctx, 0), &iterator_resource));
     core::ScopedUnref unref_iterator(iterator_resource);
     IteratorVariantSerializer serializer;
-    OP_REQUIRES_OK(ctx, serializer.InitializeFromIterator(iterator_resource));
+    SerializationContext::Params params;
+    params.external_state_policy = external_state_policy_;
+    SerializationContext serialization_ctx(params);
+    OP_REQUIRES_OK(ctx, serializer.InitializeFromIterator(&serialization_ctx,
+                                                          iterator_resource));
     Tensor* serialized_t;
     OP_REQUIRES_OK(
         ctx, ctx->allocate_output(0, TensorShape({serializer.NumTensors()}),
                                   &serialized_t));
     OP_REQUIRES_OK(ctx, serializer.Serialize(serialized_t));
   }

+ private:
+  SerializationContext::ExternalStatePolicy external_state_policy_ =
+      SerializationContext::ExternalStatePolicy::kWarn;
 };

 class DeserializeIteratorOp : public OpKernel {
......
......@@ -695,6 +695,7 @@ REGISTER_OP("IteratorFromStringHandleV2")
 REGISTER_OP("SerializeIterator")
     .Input("resource_handle: resource")
+    .Attr("external_state_policy: int = 0")
     .Output("serialized: variant")
     .SetShapeFn([](shape_inference::InferenceContext* c) {
       c->set_output(0, c->Vector(c->UnknownDim()));
......
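At the op level the new attr is a plain integer; the Python layer passes `ExternalStatePolicy(...).value` straight through, so the values follow the enum added in this CL (WARN=0, IGNORE=1, FAIL=2). A rough graph-mode sketch using the internal generated-op module (illustrative only, not a public API):

```python
import tensorflow.compat.v1 as tf
from tensorflow.python.data.experimental.ops import distribute_options
from tensorflow.python.ops import gen_dataset_ops

tf.disable_eager_execution()

dataset = tf.data.Dataset.range(5)
iterator = tf.data.make_one_shot_iterator(dataset)

# The op consumes the iterator's resource handle plus the policy integer and
# returns a variant tensor (a vector of state variants) that
# DeserializeIterator can later consume to restore the iterator.
serialized = gen_dataset_ops.serialize_iterator(
    iterator._iterator_resource,  # pylint: disable=protected-access
    external_state_policy=distribute_options.ExternalStatePolicy.FAIL.value)
```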
......@@ -25,6 +25,7 @@ from tensorflow.core.protobuf import tensorflow_server_pb2
 from tensorflow.python import pywrap_tensorflow
 from tensorflow.python.compat import compat
 from tensorflow.python.data.experimental.ops import distribute
+from tensorflow.python.data.experimental.ops import distribute_options
 from tensorflow.python.data.kernel_tests import test_base
 from tensorflow.python.data.ops import dataset_ops
 from tensorflow.python.eager import context
......@@ -99,7 +100,7 @@ class LocalReplicateTest(test_base.DatasetTestBase, parameterized.TestCase):
               dtype=dtypes.float32))
       opt = dataset_ops.Options()
       opt.experimental_external_state_policy = (
-          dataset_ops.ExternalStatePolicy.IGNORE)
+          distribute_options.ExternalStatePolicy.IGNORE)
       dataset0 = dataset0.with_options(opt)
       replicated_ds = distribute.replicate(dataset0,
                                            [self._device1, self._device2])
......@@ -131,7 +132,7 @@ class LocalReplicateTest(test_base.DatasetTestBase, parameterized.TestCase):
               dtype=dtypes.float32))
       opt = dataset_ops.Options()
       opt.experimental_external_state_policy = (
-          dataset_ops.ExternalStatePolicy.WARN)
+          distribute_options.ExternalStatePolicy.WARN)
       dataset0 = dataset0.with_options(opt)
       replicated_ds = distribute.replicate(dataset0,
                                            [self._device1, self._device2])
......@@ -163,7 +164,7 @@ class LocalReplicateTest(test_base.DatasetTestBase, parameterized.TestCase):
               dtype=dtypes.float32))
       opt = dataset_ops.Options()
       opt.experimental_external_state_policy = (
-          dataset_ops.ExternalStatePolicy.FAIL)
+          distribute_options.ExternalStatePolicy.FAIL)
       dataset0 = dataset0.with_options(opt)
       with self.assertRaises(errors.FailedPreconditionError):
         replicated_ds = distribute.replicate(dataset0,
......
......@@ -36,6 +36,12 @@ class AutoShardPolicy(enum.IntEnum):
   DATA = 2
+
+
+class ExternalStatePolicy(enum.Enum):
+  WARN = 0
+  IGNORE = 1
+  FAIL = 2


 @tf_export("data.experimental.DistributeOptions")
 class DistributeOptions(options.OptionsBase):
   """Represents options for distributed data processing.
......
......@@ -17,6 +17,8 @@
 from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function

+from tensorflow.python.data.experimental.ops import distribute_options
 from tensorflow.python.data.ops import iterator_ops
 from tensorflow.python.framework import ops
 from tensorflow.python.training import basic_session_run_hooks
......@@ -26,18 +28,39 @@ from tensorflow.python.training import session_run_hook
 from tensorflow.python.util.tf_export import tf_export


+def _convert_external_state_policy_to_enum(external_state_policy):
+  if external_state_policy == "warn":
+    return distribute_options.ExternalStatePolicy.WARN
+  if external_state_policy == "ignore":
+    return distribute_options.ExternalStatePolicy.IGNORE
+  if external_state_policy == "fail":
+    return distribute_options.ExternalStatePolicy.FAIL
+  raise ValueError(
+      "Failed to convert {} to an instance of ExternalStatePolicy."
+      "Supported values include: 'warn', 'ignore' and 'fail'".format(
+          external_state_policy))
+
+
 @tf_export("data.experimental.make_saveable_from_iterator")
-def make_saveable_from_iterator(iterator):
+def make_saveable_from_iterator(iterator, external_state_policy="fail"):
   """Returns a SaveableObject for saving/restoring iterator state using Saver.

   Args:
     iterator: Iterator.
+    external_state_policy: A string that identifies how to handle input
+      pipelines that depend on external state. Possible values are
+      'ignore': The external state is silently ignored.
+      'warn': The external state is ignored, logging a warning.
+      'fail': The operation fails upon encountering external state.
+      By default we set it to 'fail'.

   Returns:
     A SaveableObject for saving/restoring iterator state using Saver.

   Raises:
     ValueError: If iterator does not support checkpointing.
+    ValueError: If `external_state_policy` is not one of 'warn', 'ignore' or
+      'fail'.

   For example:
......@@ -68,8 +91,11 @@ def make_saveable_from_iterator(iterator):
   Note: Not all iterators support checkpointing yet. Attempting to save the
   state of an unsupported iterator will throw an error.
   """
-  return iterator_ops._IteratorSaveable(iterator._iterator_resource,  # pylint: disable=protected-access
-                                        iterator._iterator_resource.name)  # pylint: disable=protected-access
+  policy_enum = _convert_external_state_policy_to_enum(external_state_policy)
+  return iterator_ops._IteratorSaveable(  # pylint: disable=protected-access
+      iterator._iterator_resource,  # pylint: disable=protected-access
+      iterator._iterator_resource.name,  # pylint: disable=protected-access
+      external_state_policy=policy_enum)


 @tf_export("data.experimental.CheckpointInputPipelineHook")
......@@ -118,16 +144,45 @@ class CheckpointInputPipelineHook(session_run_hook.SessionRunHook):
     collector when building the eval graph.
   """

-  def __init__(self, estimator):
+  def __init__(self, estimator, external_state_policy="fail"):
     """Initializes a `CheckpointInputPipelineHook`.

+    If the input pipeline depends on external state (e.g. seeds for
+    RandomUniform) beyond the input pipeline, this hook would be unable to
+    serialize and deserialize that state. If it's acceptable to ignore that
+    state, change the external_state_policy argument to 'warn' or 'ignore'.
+    For example:
+
+    ```python
+    est = tf.estimator.Estimator(model_fn)
+    while True:
+      est.train(
+          train_input_fn,
+          hooks=[tf.data.experimental.CheckpointInputPipelineHook(
+              est, external_state_policy='warn')],
+          steps=train_steps_per_eval)
+      # Note: We do not pass the hook here.
+      metrics = est.evaluate(eval_input_fn)
+      if should_stop_the_training(metrics):
+        break
+    ```
+
     Args:
       estimator: Estimator.
+      external_state_policy: A string that identifies how to handle input
+        pipelines that depend on external state. Possible values are
+        'ignore': The external state is silently ignored.
+        'warn': The external state is ignored, logging a warning.
+        'fail': The operation fails upon encountering external state.
+        By default we set it to 'fail'.

     Raises:
       ValueError: One of `save_steps` or `save_secs` should be set.
       ValueError: At most one of saver or scaffold should be set.
+      ValueError: If `external_state_policy` is not one of 'warn', 'ignore' or
+        'fail'.
     """
+    self._external_state_policy = _convert_external_state_policy_to_enum(
+        external_state_policy)
+
     # `checkpoint_basename` is "input.ckpt" for non-distributed pipelines or
     # of the form "input_<task_type>_<task_id>.ckpt" for distributed pipelines.
     # Note: The default `checkpoint_basename` used by `CheckpointSaverHook` is
......@@ -172,7 +227,11 @@ class CheckpointInputPipelineHook(session_run_hook.SessionRunHook):
     if (self._checkpoint_saver_hook._saver is None and
         self._checkpoint_saver_hook._scaffold is None):
       iterators = ops.get_collection(iterator_ops.GLOBAL_ITERATORS)
-      saveables = [iterator_ops._IteratorSaveable(i, i.name) for i in iterators]
+      saveables = [
+          iterator_ops._IteratorSaveable(
+              i, i.name, external_state_policy=self._external_state_policy)
+          for i in iterators
+      ]
       self._checkpoint_saver_hook._saver = _CustomSaver(
           saveables, self._latest_filename, sharded=True)
     # pylint: enable=protected-access
......
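A slightly more self-contained variant of the docstring example above; the toy `model_fn`, input pipeline, and `model_dir` are stand-ins for illustration, not part of this CL:

```python
import tensorflow.compat.v1 as tf

def train_input_fn():
  # External state: the RNG inside the map function.
  return tf.data.Dataset.range(100).map(
      lambda _: (tf.random.uniform(()), 0.0)).batch(10)

def model_fn(features, labels, mode):
  w = tf.get_variable("w", [], initializer=tf.zeros_initializer())
  loss = tf.reduce_mean(tf.square(w * features - labels))
  train_op = tf.train.GradientDescentOptimizer(0.1).minimize(
      loss, global_step=tf.train.get_or_create_global_step())
  return tf.estimator.EstimatorSpec(mode, loss=loss, train_op=train_op)

est = tf.estimator.Estimator(model_fn, model_dir="/tmp/ext_state_demo")
# 'warn' lets the hook checkpoint the pipeline and log the dropped RNG state;
# the default policy remains 'fail'.
hook = tf.data.experimental.CheckpointInputPipelineHook(
    est, external_state_policy="warn")
est.train(train_input_fn, hooks=[hook], steps=5)
```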
......@@ -24,6 +24,7 @@ from absl.testing import parameterized
 import numpy as np

 from tensorflow.core.framework import graph_pb2
+from tensorflow.python.data.experimental.ops import distribute_options
 from tensorflow.python.data.kernel_tests import test_base
 from tensorflow.python.data.ops import dataset_ops
 from tensorflow.python.data.ops import optional_ops
......@@ -58,8 +59,9 @@ class DatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
     dataset = dataset_ops.Dataset.range(10).map(
         lambda _: random_ops.random_uniform(()))
     with self.assertRaises(errors.FailedPreconditionError):
-      self.evaluate(dataset._as_serialized_graph(
-          external_state_policy=dataset_ops.ExternalStatePolicy.FAIL))
+      self.evaluate(
+          dataset._as_serialized_graph(external_state_policy=distribute_options
+                                       .ExternalStatePolicy.FAIL))

   @combinations.generate(test_base.default_test_combinations())
   def testAsFunctionWithMap(self):
......
......@@ -18,7 +18,6 @@ from __future__ import division
 from __future__ import print_function

 import abc
-import enum
 import functools
 import sys
 import threading
......@@ -94,12 +93,6 @@ AUTOTUNE = -1
 tf_export("data.experimental.AUTOTUNE").export_constant(__name__, "AUTOTUNE")


-class ExternalStatePolicy(enum.Enum):
-  WARN = 0
-  IGNORE = 1
-  FAIL = 2
-
-
 @tf_export("data.Dataset", v1=[])
 @six.add_metaclass(abc.ABCMeta)
 class DatasetV2(tracking_base.Trackable, composite_tensor.CompositeTensor):
......@@ -210,10 +203,11 @@ class DatasetV2(tracking_base.Trackable, composite_tensor.CompositeTensor):
   @deprecation.deprecated_args(None, "Use external_state_policy instead",
                                "allow_stateful")
-  def _as_serialized_graph(self,
-                           allow_stateful=None,
-                           strip_device_assignment=None,
-                           external_state_policy=ExternalStatePolicy.WARN):
+  def _as_serialized_graph(
+      self,
+      allow_stateful=None,
+      strip_device_assignment=None,
+      external_state_policy=distribute_options.ExternalStatePolicy.WARN):
     """Produces serialized graph representation of the dataset.

     Args:
......@@ -2660,7 +2654,7 @@ class Options(options_lib.OptionsBase):
   experimental_external_state_policy = options_lib.create_option(
       name="experimental_external_state_policy",
-      ty=ExternalStatePolicy,
+      ty=distribute_options.ExternalStatePolicy,
       docstring="By default, tf.data will refuse to serialize a dataset or "
       "checkpoint its iterator if the dataset contains a stateful op as the "
       "serialization / checkpointing won't be able to capture its state. "
......@@ -2669,7 +2663,7 @@ class Options(options_lib.OptionsBase):
"in these ops. There are three settings available - IGNORE: in which we"
"completely ignore any state; WARN: We warn the user that some state "
"might be thrown away; FAIL: We fail if any state is being captured.",
-      default_factory=lambda: ExternalStatePolicy.WARN)
+      default_factory=lambda: distribute_options.ExternalStatePolicy.WARN)

   def _graph_rewrites(self):
     """Produces the list of enabled static graph rewrites."""
......
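The same enum also backs the dataset-level option documented above; a small sketch mirroring the LocalReplicateTest changes in this CL (internal module paths, as used in the tests):

```python
from tensorflow.python.data.experimental.ops import distribute_options
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.ops import random_ops

# A dataset whose elements depend on external state (the RNG in the map fn).
dataset = dataset_ops.Dataset.range(10).map(
    lambda _: random_ops.random_uniform(()))

options = dataset_ops.Options()
# IGNORE silently drops the random-op state when the dataset graph is
# serialized; WARN logs a warning; FAIL raises FailedPreconditionError.
options.experimental_external_state_policy = (
    distribute_options.ExternalStatePolicy.IGNORE)
dataset = dataset.with_options(options)
```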
......@@ -20,6 +20,7 @@ from __future__ import print_function
 import threading
 import warnings

+from tensorflow.python.data.experimental.ops import distribute_options
 from tensorflow.python.data.ops import optional_ops
 from tensorflow.python.data.util import nest
 from tensorflow.python.data.util import structure
......@@ -793,8 +794,13 @@ class IteratorSpec(type_spec.TypeSpec):
 class _IteratorSaveable(BaseSaverBuilder.SaveableObject):
   """SaveableObject for saving/restoring iterator state."""

-  def __init__(self, iterator_resource, name):
-    serialized_iterator = gen_dataset_ops.serialize_iterator(iterator_resource)
+  def __init__(
+      self,
+      iterator_resource,
+      name,
+      external_state_policy=distribute_options.ExternalStatePolicy.FAIL):
+    serialized_iterator = gen_dataset_ops.serialize_iterator(
+        iterator_resource, external_state_policy=external_state_policy.value)
     specs = [
         BaseSaverBuilder.SaveSpec(
             serialized_iterator,
......
......@@ -399,7 +399,7 @@ class TensorLikeDataAdapter(DataAdapter):
     if self._shuffle:
       # See b/141490660 for more details.
       options.experimental_external_state_policy = (
-          dataset_ops.ExternalStatePolicy.IGNORE)
+          distribute_options.ExternalStatePolicy.IGNORE)
     dataset = dataset.with_options(options)
     return dataset
......
......@@ -5,7 +5,7 @@ tf_class {
   is_instance: "<type \'object\'>"
   member_method {
     name: "__init__"
-    argspec: "args=[\'self\', \'estimator\'], varargs=None, keywords=None, defaults=None"
+    argspec: "args=[\'self\', \'estimator\', \'external_state_policy\'], varargs=None, keywords=None, defaults=[\'fail\'], "
   }
   member_method {
     name: "after_create_session"
......
......@@ -174,7 +174,7 @@ tf_module {
   }
   member_method {
     name: "make_saveable_from_iterator"
-    argspec: "args=[\'iterator\'], varargs=None, keywords=None, defaults=None"
+    argspec: "args=[\'iterator\', \'external_state_policy\'], varargs=None, keywords=None, defaults=[\'fail\'], "
   }
   member_method {
     name: "map_and_batch"
......
......@@ -3730,7 +3730,7 @@ tf_module {
   }
   member_method {
     name: "SerializeIterator"
-    argspec: "args=[\'resource_handle\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+    argspec: "args=[\'resource_handle\', \'external_state_policy\', \'name\'], varargs=None, keywords=None, defaults=[\'0\', \'None\'], "
   }
   member_method {
     name: "SerializeManySparse"
......
......@@ -5,7 +5,7 @@ tf_class {
   is_instance: "<type \'object\'>"
   member_method {
     name: "__init__"
-    argspec: "args=[\'self\', \'estimator\'], varargs=None, keywords=None, defaults=None"
+    argspec: "args=[\'self\', \'estimator\', \'external_state_policy\'], varargs=None, keywords=None, defaults=[\'fail\'], "
   }
   member_method {
     name: "after_create_session"
......
......@@ -146,7 +146,7 @@ tf_module {
   }
   member_method {
     name: "make_saveable_from_iterator"
-    argspec: "args=[\'iterator\'], varargs=None, keywords=None, defaults=None"
+    argspec: "args=[\'iterator\', \'external_state_policy\'], varargs=None, keywords=None, defaults=[\'fail\'], "
   }
   member_method {
     name: "map_and_batch"
......
......@@ -3730,7 +3730,7 @@ tf_module {
   }
   member_method {
     name: "SerializeIterator"
-    argspec: "args=[\'resource_handle\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+    argspec: "args=[\'resource_handle\', \'external_state_policy\', \'name\'], varargs=None, keywords=None, defaults=[\'0\', \'None\'], "
   }
   member_method {
     name: "SerializeManySparse"
......