提交 bd1aaf9f 编写于 作者: R Rachel Lim 提交者: TensorFlower Gardener

[tf.data] Adds default values for `tf.data.Options`'s...

[tf.data] Adds default values for `tf.data.Options`'s `experimental_optimization`, `experimental_stats`, and `experimental_threading` properties. Changes the default `latency_all_edges` option on `StatsOptions`. In order to turn on latency statistics, a user now has to explicitly specify `options.experimental_stats.latency_all_edges = True`.

PiperOrigin-RevId: 225411510
上级 c4d9c9b0
......@@ -20,7 +20,6 @@ from __future__ import print_function
from absl.testing import parameterized
from tensorflow.python.data.experimental.ops import optimization
from tensorflow.python.data.experimental.ops.optimization_options import OptimizationOptions
from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.framework import constant_op
......@@ -72,7 +71,6 @@ class FilterFusionTest(test_base.DatasetTestBase, parameterized.TestCase):
dataset = dataset.cache()
options = dataset_ops.Options()
options.experimental_optimization = OptimizationOptions()
options.experimental_optimization.filter_fusion = True
dataset = dataset.with_options(options)
expected_output = []
......
......@@ -20,7 +20,6 @@ from __future__ import print_function
from absl.testing import parameterized
from tensorflow.python.data.experimental.ops import optimization
from tensorflow.python.data.experimental.ops.optimization_options import OptimizationOptions
from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.eager import context
......@@ -92,7 +91,6 @@ class HoistRandomUniformTest(test_base.DatasetTestBase, parameterized.TestCase):
["Zip[0]", "Map"] if will_optimize else ["Map"])).map(function)
options = dataset_ops.Options()
options.experimental_optimization = OptimizationOptions()
options.experimental_optimization.hoist_random_uniform = True
dataset = dataset.with_options(options)
self._testDataset(dataset)
......@@ -109,7 +107,6 @@ class HoistRandomUniformTest(test_base.DatasetTestBase, parameterized.TestCase):
dataset = dataset_ops.Dataset.range(5).apply(
optimization.assert_next(["Zip[0]", "Map"])).map(random_with_capture)
options = dataset_ops.Options()
options.experimental_optimization = OptimizationOptions()
options.experimental_optimization.hoist_random_uniform = True
dataset = dataset.with_options(options)
self._testDataset(dataset)
......
......@@ -20,7 +20,6 @@ from __future__ import print_function
from tensorflow.python.data.experimental.kernel_tests import stats_dataset_test_base
from tensorflow.python.data.experimental.ops import optimization
from tensorflow.python.data.experimental.ops import stats_aggregator
from tensorflow.python.data.experimental.ops import stats_options
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.framework import test_util
from tensorflow.python.platform import test
......@@ -36,7 +35,6 @@ class LatencyAllEdgesTest(stats_dataset_test_base.StatsDatasetTestBase):
["LatencyStats", "Map", "LatencyStats", "Prefetch",
"LatencyStats"])).map(lambda x: x * x).prefetch(1)
options = dataset_ops.Options()
options.experimental_stats = stats_options.StatsOptions()
options.experimental_stats.latency_all_edges = True
options.experimental_stats.aggregator = aggregator
dataset = dataset.with_options(options)
......@@ -53,29 +51,6 @@ class LatencyAllEdgesTest(stats_dataset_test_base.StatsDatasetTestBase):
self._assertSummaryHasCount(summary_str,
"record_latency_PrefetchDataset/_6", 1)
def testLatencyStatsOptimizationV2(self):
# Verifies that enabling stats collection inserts a `LatencyStats` node
# after each transformation (asserted via `optimization.assert_next`) and
# that one latency measurement is recorded per dataset edge.
aggregator = stats_aggregator.StatsAggregator()
dataset = dataset_ops.Dataset.from_tensors(1).apply(
optimization.assert_next(
["LatencyStats", "Map", "LatencyStats", "Prefetch",
"LatencyStats"])).map(lambda x: x * x).prefetch(1)
options = dataset_ops.Options()
options.experimental_stats = stats_options.StatsOptions()
options.experimental_stats.aggregator = aggregator
dataset = dataset.with_options(options)
# Single-element dataset; initialization is required so the aggregator is
# attached before iteration.
self.assertDatasetProduces(
dataset,
expected_output=[1],
requires_initialization=True,
num_test_iterations=1)
summary_t = aggregator.get_summary()
summary_str = self.evaluate(summary_t)
# One latency record is expected per instrumented edge of the pipeline.
self._assertSummaryHasCount(summary_str, "record_latency_TensorDataset/_1",
1)
self._assertSummaryHasCount(summary_str, "record_latency_MapDataset/_4", 1)
self._assertSummaryHasCount(summary_str,
"record_latency_PrefetchDataset/_6", 1)
# Standard test entry point: run the tests when executed as a script.
if __name__ == "__main__":
test.main()
......@@ -20,7 +20,6 @@ from __future__ import print_function
from absl.testing import parameterized
from tensorflow.python.data.experimental.ops import optimization
from tensorflow.python.data.experimental.ops.optimization_options import OptimizationOptions
from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.framework import constant_op
......@@ -84,7 +83,6 @@ class MapAndFilterFusionTest(test_base.DatasetTestBase, parameterized.TestCase):
optimization.assert_next(
["Map", "FilterByLastComponent"])).map(function).filter(predicate)
options = dataset_ops.Options()
options.experimental_optimization = OptimizationOptions()
options.experimental_optimization.map_and_filter_fusion = True
dataset = dataset.with_options(options)
self._testMapAndFilter(dataset, function, predicate)
......@@ -103,7 +101,6 @@ class MapAndFilterFusionTest(test_base.DatasetTestBase, parameterized.TestCase):
optimization.assert_next(["Map",
"Filter"])).map(function).filter(predicate)
options = dataset_ops.Options()
options.experimental_optimization = OptimizationOptions()
options.experimental_optimization.map_and_filter_fusion = True
dataset = dataset.with_options(options)
self._testMapAndFilter(dataset, function, predicate)
......
......@@ -20,7 +20,6 @@ from __future__ import print_function
from absl.testing import parameterized
from tensorflow.python.data.experimental.ops import optimization
from tensorflow.python.data.experimental.ops.optimization_options import OptimizationOptions
from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.framework import test_util
......@@ -75,7 +74,6 @@ class MapFusionTest(test_base.DatasetTestBase, parameterized.TestCase):
dataset = dataset.cache()
options = dataset_ops.Options()
options.experimental_optimization = OptimizationOptions()
options.experimental_optimization.map_fusion = True
dataset = dataset.with_options(options)
expected_output = []
......
......@@ -20,7 +20,6 @@ from __future__ import print_function
from absl.testing import parameterized
from tensorflow.python.data.experimental.ops import optimization
from tensorflow.python.data.experimental.ops.optimization_options import OptimizationOptions
from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.framework import dtypes
......@@ -68,7 +67,6 @@ class MapParallelizationTest(test_base.DatasetTestBase, parameterized.TestCase):
dataset = dataset_ops.Dataset.range(5).apply(
optimization.assert_next(next_nodes)).map(function)
options = dataset_ops.Options()
options.experimental_optimization = OptimizationOptions()
options.experimental_optimization.map_parallelization = True
dataset = dataset.with_options(options)
if should_optimize:
......
......@@ -23,7 +23,6 @@ import numpy as np
from tensorflow.core.example import example_pb2
from tensorflow.core.example import feature_pb2
from tensorflow.python.data.experimental.ops import optimization
from tensorflow.python.data.experimental.ops import optimization_options
from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.framework import constant_op
......@@ -350,9 +349,7 @@ class MapVectorizationTest(test_base.DatasetTestBase, parameterized.TestCase):
dataset = dataset.map(map_fn, num_parallel_calls)
dataset = dataset.batch(100)
options = dataset_ops.Options()
opt_options = optimization_options.OptimizationOptions()
opt_options.map_and_batch_fusion = False
options.experimental_optimization = opt_options
options.experimental_optimization.map_and_batch_fusion = False
dataset = dataset.with_options(options)
return dataset
......@@ -360,9 +357,7 @@ class MapVectorizationTest(test_base.DatasetTestBase, parameterized.TestCase):
optimized = _make_dataset(["Batch", map_node_name]
if expect_optimized else [map_node_name, "Batch"])
options = dataset_ops.Options()
opt_options = optimization_options.OptimizationOptions()
opt_options.map_vectorization = True
options.experimental_optimization = opt_options
options.experimental_optimization.map_vectorization = True
optimized = optimized.with_options(options)
return unoptimized, optimized
......
......@@ -25,7 +25,6 @@ import numpy as np
from tensorflow.python.data.experimental.ops import batching
from tensorflow.python.data.experimental.ops import grouping
from tensorflow.python.data.experimental.ops import optimization
from tensorflow.python.data.experimental.ops import optimization_options
from tensorflow.python.data.experimental.ops import scan_ops
from tensorflow.python.data.experimental.ops import threadpool
from tensorflow.python.data.kernel_tests import test_base
......@@ -168,9 +167,7 @@ class OptimizeDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
# here because of a bug with chaining _OptimizeDatasets when there are
# nested dataset functions
options = dataset_ops.Options()
opt_options = optimization_options.OptimizationOptions()
opt_options.map_and_batch_fusion = True
options.experimental_optimization = opt_options
options.experimental_optimization.map_and_batch_fusion = True
dataset = dataset.with_options(options)
self.assertDatasetProduces(dataset, expected_output=[[0]])
......@@ -217,10 +214,8 @@ class OptimizeDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
unoptimized_dataset = dataset_fn(variable)
options = dataset_ops.Options()
opt_options = optimization_options.OptimizationOptions()
opt_options.noop_elimination = True
opt_options.map_and_batch_fusion = True
options.experimental_optimization = opt_options
options.experimental_optimization.noop_elimination = True
options.experimental_optimization.map_and_batch_fusion = True
optimized_dataset = unoptimized_dataset.with_options(options)
# Check that warning is logged.
......@@ -233,7 +228,7 @@ class OptimizeDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
"tf.Variable. The following optimizations will be disabled: %s."
" To enable optimizations, use resource variables instead by "
"calling `tf.enable_resource_variables()` at the start of the "
"program." % (", ".join(opt_options._static_optimizations())))
"program." % (", ".join(options._static_optimizations())))
self.assertTrue(any([expected in str(warning) for warning in w]))
# Check that outputs are the same in the optimized and unoptimized cases,
......@@ -271,10 +266,8 @@ class OptimizeDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
only explicitly enabled optimizations will be applied.
"""
options = dataset_ops.Options()
opt_options = optimization_options.OptimizationOptions()
opt_options.hoist_random_uniform = True
opt_options.apply_default_optimizations = False
options.experimental_optimization = opt_options
options.experimental_optimization.hoist_random_uniform = True
options.experimental_optimization.apply_default_optimizations = False
expected_optimizations = ["hoist_random_uniform"]
self.assertEqual(options._static_optimizations(), expected_optimizations)
......
......@@ -26,7 +26,6 @@ from tensorflow.python.data.experimental.ops import batching
from tensorflow.python.data.experimental.ops import optimization
from tensorflow.python.data.experimental.ops import stats_aggregator
from tensorflow.python.data.experimental.ops import stats_ops
from tensorflow.python.data.experimental.ops import stats_options
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
......@@ -47,7 +46,6 @@ def function_set_stats_aggregator(dataset,
def function_apply_options(dataset, aggregator, prefix="", counter_prefix=""):
options = dataset_ops.Options()
options.experimental_stats = stats_options.StatsOptions()
options.experimental_stats.aggregator = aggregator
options.experimental_stats.prefix = prefix
options.experimental_stats.counter_prefix = counter_prefix
......
......@@ -26,12 +26,14 @@ from tensorflow.python.util.tf_export import tf_export
class OptimizationOptions(options.OptionsBase):
"""Represents options for dataset optimizations.
You can apply `OptimizationOptions` to a `dataset` object, as follows:
You can set the optimization options of a dataset through the
`experimental_optimization` property of `tf.data.Options`; the property is
an instance of `tf.data.experimental.OptimizationOptions`.
```python
options = tf.data.Options()
options.optimization = tf.data.experimental.OptimizationOptions()
options.optimization.map_and_batch_fusion = True
options.experimental_optimization.map_vectorization = True
options.experimental_optimization.apply_default_optimizations = False
dataset = dataset.with_options(options)
```
"""
......
......@@ -17,7 +17,6 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tensorflow.python.data.experimental.ops import optimization_options
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.data.ops import iterator_ops
from tensorflow.python.eager import function
......@@ -72,9 +71,7 @@ def copy_to_device(target_device, source_device="/cpu:0"):
def _apply_fn(dataset):
options = dataset_ops.Options()
options.experimental_autotune = False
opt_options = optimization_options.OptimizationOptions()
opt_options.apply_default_optimizations = False
options.experimental_optimization = opt_options
options.experimental_optimization.apply_default_optimizations = False
return _CopyToDeviceDataset(
dataset, target_device=target_device,
source_device=source_device).with_options(options)
......
......@@ -28,27 +28,19 @@ from tensorflow.python.util.tf_export import tf_export
class StatsOptions(options.OptionsBase):
"""Represents options for collecting dataset stats using `StatsAggregator`.
To apply `StatsOptions` with a `tf.data.Dataset` object, use the following
pattern:
You can set the stats options of a dataset through the `experimental_stats`
property of `tf.data.Options`; the property is an instance of
`tf.data.experimental.StatsOptions`. For example, to collect latency stats
on all dataset edges, use the following pattern:
```python
aggregator = tf.data.experimental.StatsAggregator()
options = tf.data.Options()
options.experimental_stats = tf.data.experimental.StatsOptions()
options.experimental_stats.aggregator = aggregator
options.experimental_stats.latency_all_edges = True
dataset = dataset.with_options(options)
```
Note: a `StatsAggregator` object can either be attached during construction or
be provided later, as in the example above.
```python
aggregator = tf.data.experimental.StatsAggregator()
# attach aggregator during construction
options.experimental_stats = tf.data.experimental.StatsOptions(aggregator)
.....
```
"""
aggregator = options.create_option(
......@@ -62,18 +54,15 @@ class StatsOptions(options.OptionsBase):
ty=str,
docstring=
"Prefix to prepend all statistics recorded for the input `dataset` with.",
default="")
default_factory=lambda: "")
counter_prefix = options.create_option(
name="counter_prefix",
ty=str,
docstring=
"Prefix for the statistics recorded as counter.",
default="")
docstring="Prefix for the statistics recorded as counter.",
default_factory=lambda: "")
latency_all_edges = options.create_option(
name="latency_all_edges",
ty=bool,
docstring=
"Whether to add latency measurements on all edges.",
default=True)
docstring="Whether to add latency measurements on all edges.")
......@@ -26,11 +26,12 @@ from tensorflow.python.util.tf_export import tf_export
class ThreadingOptions(options.OptionsBase):
"""Represents options for dataset threading.
To apply `ThreadingOptions` to a `dataset` object, use the following pattern:
You can set the threading options of a dataset through the
`experimental_threading` property of `tf.data.Options`; the property is
an instance of `tf.data.experimental.ThreadingOptions`.
```python
options = tf.data.Options()
options.experimental_threading = tf.data.experimental.ThreadingOptions()
options.experimental_threading.private_threadpool_size = 10
dataset = dataset.with_options(options)
```
......@@ -46,5 +47,4 @@ class ThreadingOptions(options.OptionsBase):
name="private_threadpool_size",
ty=int,
docstring=
"If set, the dataset will use a private threadpool of the given size.",
default=None)
"If set, the dataset will use a private threadpool of the given size.")
......@@ -444,6 +444,19 @@ cuda_py_test(
],
)
tf_py_test(
name = "options_test",
size = "small",
srcs = ["options_test.py"],
additional_deps = [
":test_base",
"//tensorflow/python/data/ops:dataset_ops",
"//tensorflow/python/data/experimental/ops:optimization_options",
"//tensorflow/python/data/experimental/ops:threading_options",
"//tensorflow/python:client_testlib",
],
)
tf_py_test(
name = "padded_batch_test",
size = "small",
......
......@@ -207,53 +207,6 @@ class DatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
self.assertEqual(2, inputs.count(ds2))
self.assertEqual(1, inputs.count(ds3))
def testOptionsDefault(self):
# A dataset created without `with_options` reports a default `Options`.
ds = dataset_ops.Dataset.range(0)
self.assertEqual(dataset_ops.Options(), ds.options())
def testOptionsOnce(self):
# Options applied upstream remain visible after later transformations
# (here, `cache`).
options = dataset_ops.Options()
ds = dataset_ops.Dataset.range(0).with_options(options).cache()
self.assertEqual(options, ds.options())
def testOptionsTwiceSame(self):
# Applying the same options twice merges without conflict.
options = dataset_ops.Options()
options.experimental_autotune = True
ds = dataset_ops.Dataset.range(0).with_options(options).with_options(
options)
self.assertEqual(options, ds.options())
def testOptionsTwiceDifferent(self):
# Non-conflicting options set in separate `with_options` calls are merged.
options1 = dataset_ops.Options()
options1.experimental_autotune = True
options2 = dataset_ops.Options()
options2.experimental_deterministic = False
ds = dataset_ops.Dataset.range(0).with_options(options1).with_options(
options2)
self.assertTrue(ds.options().experimental_autotune)
# Explicitly check that flag is False since assertFalse allows None
self.assertIs(ds.options().experimental_deterministic, False)
def testOptionsTwiceDifferentError(self):
# Conflicting values for the same option cannot be merged and must raise.
options1 = dataset_ops.Options()
options1.experimental_autotune = True
options2 = dataset_ops.Options()
options2.experimental_autotune = False
with self.assertRaisesRegexp(ValueError,
"Cannot merge incompatible values"):
dataset_ops.Dataset.range(0).with_options(options1).with_options(options2)
def testOptionsMergeOptionsFromMultipleInputs(self):
# Options set on different inputs of a multi-input dataset (`zip`) are
# merged into the options of the combined dataset.
options1 = dataset_ops.Options()
options1.experimental_autotune = True
options2 = dataset_ops.Options()
options2.experimental_deterministic = True
ds = dataset_ops.Dataset.zip(
(dataset_ops.Dataset.range(0).with_options(options1),
dataset_ops.Dataset.range(0).with_options(options2)))
self.assertTrue(ds.options().experimental_autotune)
self.assertTrue(ds.options().experimental_deterministic)
# TODO(b/119882922): use-after-free bug in eager mode.
# pylint: disable=g-long-lambda
@parameterized.named_parameters(
......@@ -313,5 +266,6 @@ class DatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
round_trip_dataset, [self.evaluate(tf_value_fn())],
requires_initialization=True)
# Standard test entry point: run the tests when executed as a script.
if __name__ == "__main__":
test.main()
......@@ -20,7 +20,6 @@ from __future__ import print_function
from tensorflow.core.protobuf import config_pb2
from tensorflow.python.data.experimental.ops import optimization
from tensorflow.python.data.experimental.ops.optimization_options import OptimizationOptions
from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.data.ops import multi_device_iterator_ops
......@@ -275,7 +274,6 @@ class MultiDeviceIteratorTest(test_base.DatasetTestBase):
dataset = dataset.cache()
options = dataset_ops.Options()
options.experimental_optimization = OptimizationOptions()
options.experimental_optimization.noop_elimination = True
dataset = dataset.with_options(options)
......
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for `tf.data.Options`."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tensorflow.python.data.experimental.ops import optimization_options
from tensorflow.python.data.experimental.ops import stats_options
from tensorflow.python.data.experimental.ops import threading_options
from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.platform import test
class OptionsTest(test_base.DatasetTestBase):
  """Tests for `tf.data.Options` propagation, merging, and default values."""

  def testOptionsDefault(self):
    # A dataset created without `with_options` reports a default `Options`.
    ds = dataset_ops.Dataset.range(0)
    self.assertEqual(dataset_ops.Options(), ds.options())

  def testOptionsOnce(self):
    # Options applied upstream remain visible after later transformations
    # (here, `cache`).
    options = dataset_ops.Options()
    ds = dataset_ops.Dataset.range(0).with_options(options).cache()
    self.assertEqual(options, ds.options())

  def testOptionsTwiceSame(self):
    # Applying the same options twice merges without conflict.
    options = dataset_ops.Options()
    options.experimental_autotune = True
    ds = dataset_ops.Dataset.range(0).with_options(options).with_options(
        options)
    self.assertEqual(options, ds.options())

  def testOptionsTwiceDifferent(self):
    # Non-conflicting options set in separate `with_options` calls are merged.
    options1 = dataset_ops.Options()
    options1.experimental_autotune = True
    options2 = dataset_ops.Options()
    options2.experimental_deterministic = False
    ds = dataset_ops.Dataset.range(0).with_options(options1).with_options(
        options2)
    self.assertTrue(ds.options().experimental_autotune)
    # Explicitly check that flag is False since assertFalse allows None
    self.assertIs(ds.options().experimental_deterministic, False)

  def testOptionsTwiceDifferentError(self):
    # Conflicting values for the same option cannot be merged and must raise.
    # NOTE: `assertRaisesRegexp` is kept (rather than `assertRaisesRegex`) for
    # Python 2 compatibility.
    options1 = dataset_ops.Options()
    options1.experimental_autotune = True
    options2 = dataset_ops.Options()
    options2.experimental_autotune = False
    with self.assertRaisesRegexp(ValueError,
                                 "Cannot merge incompatible values"):
      dataset_ops.Dataset.range(0).with_options(options1).with_options(options2)

  def testOptionsMergeOptionsFromMultipleInputs(self):
    # Options set on different inputs of a multi-input dataset (`zip`) are
    # merged into the options of the combined dataset.
    options1 = dataset_ops.Options()
    options1.experimental_autotune = True
    options2 = dataset_ops.Options()
    options2.experimental_deterministic = True
    ds = dataset_ops.Dataset.zip(
        (dataset_ops.Dataset.range(0).with_options(options1),
         dataset_ops.Dataset.range(0).with_options(options2)))
    self.assertTrue(ds.options().experimental_autotune)
    self.assertTrue(ds.options().experimental_deterministic)

  def testOptionsHaveDefaults(self):
    # Each `Options` instance must own fresh (not shared) sub-options
    # objects, and those defaults must equal freshly constructed instances.
    options1 = dataset_ops.Options()
    options2 = dataset_ops.Options()
    self.assertIsNot(options1.experimental_optimization,
                     options2.experimental_optimization)
    self.assertIsNot(options1.experimental_stats,
                     options2.experimental_stats)
    self.assertIsNot(options1.experimental_threading,
                     options2.experimental_threading)
    # Use `assertEqual`; `assertEquals` is a deprecated alias.
    self.assertEqual(options1.experimental_optimization,
                     optimization_options.OptimizationOptions())
    self.assertEqual(options1.experimental_stats,
                     stats_options.StatsOptions())
    self.assertEqual(options1.experimental_threading,
                     threading_options.ThreadingOptions())
# Standard test entry point: run the tests when executed as a script.
if __name__ == "__main__":
test.main()
......@@ -1712,26 +1712,26 @@ class Options(options_lib.OptionsBase):
experimental_optimization = options_lib.create_option(
name="experimental_optimization",
ty=optimization_options.OptimizationOptions,
docstring="Associates the given optimization options with the dataset.")
docstring="Associates the given optimization options with the dataset.",
default_factory=optimization_options.OptimizationOptions)
experimental_stats = options_lib.create_option(
name="experimental_stats",
ty=stats_options.StatsOptions,
docstring="Associates the given statistics options with the dataset.")
docstring="Associates the given statistics options with the dataset.",
default_factory=stats_options.StatsOptions)
experimental_threading = options_lib.create_option(
name="experimental_threading",
ty=threading_options.ThreadingOptions,
docstring="Associates the given threading options with the dataset.")
docstring="Associates the given threading options with the dataset.",
default_factory=threading_options.ThreadingOptions)
def _static_optimizations(self):
"""Produces the list of enabled static optimizations."""
result = []
exp_optimization_options = (
self.experimental_optimization or
optimization_options.OptimizationOptions()) # If not set, use default
result.extend(exp_optimization_options._static_optimizations()) # pylint: disable=protected-access
result.extend(self.experimental_optimization._static_optimizations()) # pylint: disable=protected-access
if self.experimental_numa_aware:
result.append("make_numa_aware")
......
......@@ -17,7 +17,6 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tensorflow.python.data.experimental.ops import optimization_options
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.data.ops import iterator_ops
from tensorflow.python.eager import context
......@@ -197,9 +196,7 @@ class MultiDeviceIterator(object):
# non-CPU devices.
options = dataset_ops.Options()
options.experimental_autotune = False
opt_options = optimization_options.OptimizationOptions()
opt_options.apply_default_optimizations = False
options.experimental_optimization = opt_options
options.experimental_optimization.apply_default_optimizations = False
ds = ds.with_options(options)
with ops.device(device):
self._device_iterators.append(ds.make_initializable_iterator())
......
......@@ -48,27 +48,32 @@ class OptionsBase(object):
return NotImplemented
def create_option(name, ty, docstring, default=None):
def create_option(name, ty, docstring, default_factory=lambda: None):
"""Creates a type-checked property.
Args:
name: the name to use
ty: the type to use
docstring: the docstring to use
default: the default value to use
name: The name to use.
ty: The type to use. The type of the property will be validated when it
is set.
docstring: The docstring to use.
default_factory: A callable that takes no arguments and returns a default
value to use if not set.
Returns:
A type-checked property.
"""
def get_fn(self):
return self._options.get(name, default) # pylint: disable=protected-access
def get_fn(option):
# pylint: disable=protected-access
if name not in option._options:
option._options[name] = default_factory()
return option._options.get(name)
def set_fn(self, value):
def set_fn(option, value):
if not isinstance(value, ty):
raise TypeError("Property \"%s\" must be of type %s, got: %r (type: %r)" %
(name, ty, value, type(value)))
self._options[name] = value # pylint: disable=protected-access
option._options[name] = value # pylint: disable=protected-access
return property(get_fn, set_fn, None, docstring)
......
......@@ -24,9 +24,12 @@ from tensorflow.python.platform import test
class _TestOptions(options.OptionsBase):
x = options.create_option(
name="x", ty=int, docstring="the answer to everything", default=42)
name="x",
ty=int,
docstring="the answer to everything",
default_factory=lambda: 42)
y = options.create_option(
name="y", ty=float, docstring="a tasty pie", default=3.14)
name="y", ty=float, docstring="a tasty pie", default_factory=lambda: 3.14)
class _NestedTestOptions(options.OptionsBase):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册