Commit 84f93b87 authored by J Jiri Simsa, committed by TensorFlower Gardener

[tf.data] Nesting `tf.data.Options()` optimization options under `experimental_optimization`.

PiperOrigin-RevId: 223417385
Parent 18944dd7
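In short, per-optimization flags such as `experimental_map_and_batch_fusion` move off `tf.data.Options()` and become plain attributes of a nested `tf.data.experimental.OptimizationOptions` object. A minimal before/after sketch of the migration, based on the test changes in this diff (assumes TensorFlow at this revision):

```python
import tensorflow as tf

dataset = tf.data.Dataset.range(10).map(lambda x: x * x).batch(10)

# Before this change, each optimization was a top-level flag on tf.data.Options():
#   options = tf.data.Options()
#   options.experimental_map_and_batch_fusion = True

# After this change, optimizations are grouped under `experimental_optimization`:
options = tf.data.Options()
options.experimental_optimization = tf.data.experimental.OptimizationOptions()
options.experimental_optimization.map_and_batch_fusion = True
dataset = dataset.with_options(options)
```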
......@@ -25,6 +25,7 @@ See [Importing Data](https://tensorflow.org/guide/datasets) for an overview.
@@Counter
@@CheckpointInputPipelineHook
@@CsvDataset
@@OptimizationOptions
@@Optional
@@RandomDataset
@@Reducer
......@@ -86,10 +87,8 @@ from tensorflow.python.data.experimental.ops.interleave_ops import parallel_inte
from tensorflow.python.data.experimental.ops.interleave_ops import sample_from_datasets
from tensorflow.python.data.experimental.ops.iterator_ops import CheckpointInputPipelineHook
from tensorflow.python.data.experimental.ops.iterator_ops import make_saveable_from_iterator
# Optimization constant that can be used to enable auto-tuning.
from tensorflow.python.data.experimental.ops.optimization import AUTOTUNE
from tensorflow.python.data.experimental.ops.optimization_options import OptimizationOptions
from tensorflow.python.data.experimental.ops.parsing_ops import parse_example_dataset
from tensorflow.python.data.experimental.ops.prefetching_ops import copy_to_device
from tensorflow.python.data.experimental.ops.prefetching_ops import prefetch_to_device
......
......@@ -42,6 +42,7 @@ py_test(
"//tensorflow/python:errors",
"//tensorflow/python:math_ops",
"//tensorflow/python/data/experimental/ops:optimization",
"//tensorflow/python/data/experimental/ops:optimization_options",
"//tensorflow/python/data/kernel_tests:test_base",
"//tensorflow/python/data/ops:dataset_ops",
"@absl_py//absl/testing:parameterized",
......@@ -68,6 +69,7 @@ py_test(
"//tensorflow/python:math_ops",
"//tensorflow/python:random_ops",
"//tensorflow/python/data/experimental/ops:optimization",
"//tensorflow/python/data/experimental/ops:optimization_options",
"//tensorflow/python/data/kernel_tests:test_base",
"//tensorflow/python/data/ops:dataset_ops",
"@absl_py//absl/testing:parameterized",
......@@ -127,6 +129,7 @@ py_test(
"//tensorflow/python:client_testlib",
"//tensorflow/python:errors",
"//tensorflow/python/data/experimental/ops:optimization",
"//tensorflow/python/data/experimental/ops:optimization_options",
"//tensorflow/python/data/kernel_tests:test_base",
"//tensorflow/python/data/ops:dataset_ops",
],
......@@ -148,6 +151,7 @@ py_test(
"//tensorflow/python:errors",
"//tensorflow/python:math_ops",
"//tensorflow/python/data/experimental/ops:optimization",
"//tensorflow/python/data/experimental/ops:optimization_options",
"//tensorflow/python/data/kernel_tests:test_base",
"//tensorflow/python/data/ops:dataset_ops",
"@absl_py//absl/testing:parameterized",
......@@ -167,6 +171,7 @@ py_test(
"//tensorflow/python:client_testlib",
"//tensorflow/python:errors",
"//tensorflow/python/data/experimental/ops:optimization",
"//tensorflow/python/data/experimental/ops:optimization_options",
"//tensorflow/python/data/kernel_tests:test_base",
"//tensorflow/python/data/ops:dataset_ops",
"@absl_py//absl/testing:parameterized",
......@@ -192,6 +197,7 @@ py_test(
"//tensorflow/python:math_ops",
"//tensorflow/python:random_ops",
"//tensorflow/python/data/experimental/ops:optimization",
"//tensorflow/python/data/experimental/ops:optimization_options",
"//tensorflow/python/data/kernel_tests:test_base",
"//tensorflow/python/data/ops:dataset_ops",
"@absl_py//absl/testing:parameterized",
......@@ -227,6 +233,7 @@ py_test(
"//tensorflow/python:parsing_ops",
"//tensorflow/python:sparse_tensor",
"//tensorflow/python/data/experimental/ops:optimization",
"//tensorflow/python/data/experimental/ops:optimization_options",
"//tensorflow/python/data/kernel_tests:test_base",
"//tensorflow/python/data/ops:dataset_ops",
"//third_party/py/numpy",
......@@ -272,6 +279,7 @@ py_test(
"//tensorflow/python:errors",
"//tensorflow/python:math_ops",
"//tensorflow/python/data/experimental/ops:optimization",
"//tensorflow/python/data/experimental/ops:optimization_options",
"//tensorflow/python/data/kernel_tests:test_base",
"//tensorflow/python/data/ops:dataset_ops",
],
......@@ -313,6 +321,7 @@ py_test(
"//tensorflow/python:client_testlib",
"//tensorflow/python:errors",
"//tensorflow/python/data/experimental/ops:optimization",
"//tensorflow/python/data/experimental/ops:optimization_options",
"//tensorflow/python/data/kernel_tests:test_base",
"//tensorflow/python/data/ops:dataset_ops",
],
......
......@@ -20,6 +20,7 @@ from __future__ import print_function
from absl.testing import parameterized
from tensorflow.python.data.experimental.ops import optimization
from tensorflow.python.data.experimental.ops.optimization_options import OptimizationOptions
from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.framework import constant_op
......@@ -71,7 +72,8 @@ class FilterFusionTest(test_base.DatasetTestBase, parameterized.TestCase):
dataset = dataset.cache()
options = dataset_ops.Options()
options.experimental_filter_fusion = True
options.experimental_optimization = OptimizationOptions()
options.experimental_optimization.filter_fusion = True
dataset = dataset.with_options(options)
expected_output = []
for x in range(5):
......
......@@ -20,6 +20,7 @@ from __future__ import print_function
from absl.testing import parameterized
from tensorflow.python.data.experimental.ops import optimization
from tensorflow.python.data.experimental.ops.optimization_options import OptimizationOptions
from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.eager import context
......@@ -91,7 +92,8 @@ class HoistRandomUniformTest(test_base.DatasetTestBase, parameterized.TestCase):
["Zip[0]", "Map"] if will_optimize else ["Map"])).map(function)
options = dataset_ops.Options()
options.experimental_hoist_random_uniform = True
options.experimental_optimization = OptimizationOptions()
options.experimental_optimization.hoist_random_uniform = True
dataset = dataset.with_options(options)
self._testDataset(dataset)
......@@ -107,7 +109,8 @@ class HoistRandomUniformTest(test_base.DatasetTestBase, parameterized.TestCase):
dataset = dataset_ops.Dataset.range(5).apply(
optimization.assert_next(["Zip[0]", "Map"])).map(random_with_capture)
options = dataset_ops.Options()
options.experimental_hoist_random_uniform = True
options.experimental_optimization = OptimizationOptions()
options.experimental_optimization.hoist_random_uniform = True
dataset = dataset.with_options(options)
self._testDataset(dataset)
......
......@@ -18,6 +18,7 @@ from __future__ import division
from __future__ import print_function
from tensorflow.python.data.experimental.ops import optimization
from tensorflow.python.data.experimental.ops.optimization_options import OptimizationOptions
from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.framework import test_util
......@@ -32,7 +33,8 @@ class MapAndBatchFusionTest(test_base.DatasetTestBase):
optimization.assert_next(
["MapAndBatch"])).map(lambda x: x * x).batch(10)
options = dataset_ops.Options()
options.experimental_map_and_batch_fusion = True
options.experimental_optimization = OptimizationOptions()
options.experimental_optimization.map_and_batch_fusion = True
dataset = dataset.with_options(options)
self.assertDatasetProduces(
dataset, expected_output=[[x * x for x in range(10)]])
......
......@@ -20,6 +20,7 @@ from __future__ import print_function
from absl.testing import parameterized
from tensorflow.python.data.experimental.ops import optimization
from tensorflow.python.data.experimental.ops.optimization_options import OptimizationOptions
from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.framework import constant_op
......@@ -83,7 +84,8 @@ class MapAndFilterFusionTest(test_base.DatasetTestBase, parameterized.TestCase):
optimization.assert_next(
["Map", "FilterByLastComponent"])).map(function).filter(predicate)
options = dataset_ops.Options()
options.experimental_map_and_filter_fusion = True
options.experimental_optimization = OptimizationOptions()
options.experimental_optimization.map_and_filter_fusion = True
dataset = dataset.with_options(options)
self._testMapAndFilter(dataset, function, predicate)
......@@ -101,7 +103,8 @@ class MapAndFilterFusionTest(test_base.DatasetTestBase, parameterized.TestCase):
optimization.assert_next(["Map",
"Filter"])).map(function).filter(predicate)
options = dataset_ops.Options()
options.experimental_map_and_filter_fusion = True
options.experimental_optimization = OptimizationOptions()
options.experimental_optimization.map_and_filter_fusion = True
dataset = dataset.with_options(options)
self._testMapAndFilter(dataset, function, predicate)
......
......@@ -20,6 +20,7 @@ from __future__ import print_function
from absl.testing import parameterized
from tensorflow.python.data.experimental.ops import optimization
from tensorflow.python.data.experimental.ops.optimization_options import OptimizationOptions
from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.framework import test_util
......@@ -74,7 +75,8 @@ class MapFusionTest(test_base.DatasetTestBase, parameterized.TestCase):
dataset = dataset.cache()
options = dataset_ops.Options()
options.experimental_map_fusion = True
options.experimental_optimization = OptimizationOptions()
options.experimental_optimization.map_fusion = True
dataset = dataset.with_options(options)
expected_output = []
for x in range(5):
......
......@@ -20,6 +20,7 @@ from __future__ import print_function
from absl.testing import parameterized
from tensorflow.python.data.experimental.ops import optimization
from tensorflow.python.data.experimental.ops.optimization_options import OptimizationOptions
from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.framework import dtypes
......@@ -67,7 +68,8 @@ class MapParallelizationTest(test_base.DatasetTestBase, parameterized.TestCase):
dataset = dataset_ops.Dataset.range(5).apply(
optimization.assert_next(next_nodes)).map(function)
options = dataset_ops.Options()
options.experimental_map_parallelization = True
options.experimental_optimization = OptimizationOptions()
options.experimental_optimization.map_parallelization = True
dataset = dataset.with_options(options)
if should_optimize:
self.assertDatasetProduces(
......
......@@ -23,6 +23,7 @@ import numpy as np
from tensorflow.core.example import example_pb2
from tensorflow.core.example import feature_pb2
from tensorflow.python.data.experimental.ops import optimization
from tensorflow.python.data.experimental.ops.optimization_options import OptimizationOptions
from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.framework import constant_op
......@@ -353,7 +354,8 @@ class MapVectorizationTest(test_base.DatasetTestBase, parameterized.TestCase):
optimized = _make_dataset(["Batch", map_node_name]
if expect_optimized else [map_node_name, "Batch"])
options = dataset_ops.Options()
options.experimental_map_vectorization = True
options.experimental_optimization = OptimizationOptions()
options.experimental_optimization.map_vectorization = True
optimized = optimized.with_options(options)
return unoptimized, optimized
......
......@@ -18,6 +18,7 @@ from __future__ import division
from __future__ import print_function
from tensorflow.python.data.experimental.ops import optimization
from tensorflow.python.data.experimental.ops.optimization_options import OptimizationOptions
from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.framework import constant_op
......@@ -42,7 +43,8 @@ class NoopEliminationTest(test_base.DatasetTestBase):
dataset = dataset.repeat(some_tensor).skip(5).take(-1).skip(0).repeat(
1).prefetch(0).prefetch(1).cache()
options = dataset_ops.Options()
options.experimental_noop_elimination = True
options.experimental_optimization = OptimizationOptions()
options.experimental_optimization.noop_elimination = True
dataset = dataset.with_options(options)
self.assertDatasetProduces(dataset, expected_output=range(5))
......
......@@ -18,6 +18,7 @@ from __future__ import division
from __future__ import print_function
from tensorflow.python.data.experimental.ops import optimization
from tensorflow.python.data.experimental.ops.optimization_options import OptimizationOptions
from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.framework import errors
......@@ -32,7 +33,8 @@ class ShuffleAndRepeatFusionTest(test_base.DatasetTestBase):
dataset = dataset_ops.Dataset.range(10).apply(
optimization.assert_next(["ShuffleAndRepeat"])).shuffle(10).repeat(2)
options = dataset_ops.Options()
options.experimental_shuffle_and_repeat_fusion = True
options.experimental_optimization = OptimizationOptions()
options.experimental_optimization.shuffle_and_repeat_fusion = True
dataset = dataset.with_options(options)
get_next = self.getNext(dataset)
......
......@@ -235,6 +235,16 @@ py_library(
],
)
py_library(
name = "optimization_options",
srcs = ["optimization_options.py"],
srcs_version = "PY2AND3",
deps = [
"//tensorflow/python:util",
"//tensorflow/python/data/util:options",
],
)
py_library(
name = "parsing_ops",
srcs = ["parsing_ops.py"],
......
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Experimental API for controlling optimizations in `tf.data` pipelines."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tensorflow.python.data.util import options
from tensorflow.python.util.tf_export import tf_export
@tf_export("data.experimental.OptimizationOptions")
class OptimizationOptions(options.OptionsBase):
"""Represents options for dataset optimizations.
You can apply `OptimizationOptions` to a `dataset` object, as follows:
```python
options = tf.data.Options()
options.experimental_optimization = tf.data.experimental.OptimizationOptions()
options.experimental_optimization.map_and_batch_fusion = True
dataset = dataset.with_options(options)
```
"""
filter_fusion = options.create_option(
name="filter_fusion",
ty=bool,
docstring="Whether to fuse filter transformations.")
hoist_random_uniform = options.create_option(
name="hoist_random_uniform",
ty=bool,
docstring=
"Whether to hoist `tf.random_uniform()` ops out of map transformations.")
map_and_batch_fusion = options.create_option(
name="map_and_batch_fusion",
ty=bool,
docstring="Whether to fuse map and batch transformations.")
map_and_filter_fusion = options.create_option(
name="map_and_filter_fusion",
ty=bool,
docstring="Whether to fuse map and filter transformations.")
map_fusion = options.create_option(
name="map_and_filter_fusion",
ty=bool,
docstring="Whether to fuse map transformations.")
map_parallelization = options.create_option(
name="map_parallelization",
ty=bool,
docstring="Whether to parallelize stateless map transformations.")
map_vectorization = options.create_option(
name="map_vectorization",
ty=bool,
docstring="Whether to vectorize map transformations.")
noop_elimination = options.create_option(
name="noop_elimination",
ty=bool,
docstring="Whether to eliminate no-op transformations.")
shuffle_and_repeat_fusion = options.create_option(
name="shuffle_and_repeat_fusion",
ty=bool,
docstring="Whether to fuse shuffle and repeat transformations.")
......@@ -34,7 +34,7 @@ class StatsOptions(options.OptionsBase):
```python
aggregator = tf.data.experimental.StatsAggregator()
options = dataset_ops.Options()
options = tf.data.Options()
options.experimental_stats = tf.data.experimental.StatsOptions()
options.experimental_stats.aggregator = aggregator
dataset = dataset.with_options(options)
......
......@@ -29,7 +29,7 @@ class ThreadingOptions(options.OptionsBase):
To apply `ThreadingOptions` to a `dataset` object, use the following pattern:
```python
options = dataset_ops.Options()
options = tf.data.Options()
options.experimental_threading = tf.data.experimental.ThreadingOptions()
options.experimental_threading.private_threadpool_size = 10
dataset = dataset.with_options(options)
......
......@@ -408,6 +408,7 @@ cuda_py_test(
"//tensorflow/python/data/ops:multi_device_iterator_ops",
"//tensorflow/python/data/ops:iterator_ops",
"//tensorflow/python/data/experimental/ops:optimization",
"//tensorflow/python/data/experimental/ops:optimization_options",
"//tensorflow/python:array_ops",
"//tensorflow/python:client_testlib",
"//tensorflow/python:dtypes",
......
......@@ -227,12 +227,12 @@ class DatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
options1 = dataset_ops.Options()
options1.experimental_autotune = True
options2 = dataset_ops.Options()
options2.experimental_filter_fusion = False
options2.experimental_deterministic = False
ds = dataset_ops.Dataset.range(0).with_options(options1).with_options(
options2)
self.assertTrue(ds.options().experimental_autotune)
# Explicitly check that flag is False since assertFalse allows None
self.assertIs(ds.options().experimental_filter_fusion, False)
self.assertIs(ds.options().experimental_deterministic, False)
def testOptionsTwiceDifferentError(self):
options1 = dataset_ops.Options()
......@@ -247,12 +247,12 @@ class DatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
options1 = dataset_ops.Options()
options1.experimental_autotune = True
options2 = dataset_ops.Options()
options2.experimental_filter_fusion = True
options2.experimental_deterministic = True
ds = dataset_ops.Dataset.zip(
(dataset_ops.Dataset.range(0).with_options(options1),
dataset_ops.Dataset.range(0).with_options(options2)))
self.assertTrue(ds.options().experimental_autotune)
self.assertTrue(ds.options().experimental_filter_fusion)
self.assertTrue(ds.options().experimental_deterministic)
# TODO(b/119882922): use-after-free bug in eager mode.
# pylint: disable=g-long-lambda
......
......@@ -20,6 +20,7 @@ from __future__ import print_function
from tensorflow.core.protobuf import config_pb2
from tensorflow.python.data.experimental.ops import optimization
from tensorflow.python.data.experimental.ops.optimization_options import OptimizationOptions
from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.data.ops import multi_device_iterator_ops
......@@ -265,7 +266,8 @@ class MultiDeviceIteratorTest(test_base.DatasetTestBase):
dataset = dataset.cache()
options = dataset_ops.Options()
options.experimental_noop_elimination = True
options.experimental_optimization = OptimizationOptions()
options.experimental_optimization.noop_elimination = True
dataset = dataset.with_options(options)
multi_device_iterator = multi_device_iterator_ops.MultiDeviceIterator(
......
......@@ -27,6 +27,7 @@ py_library(
"//tensorflow/python:tensor_util",
"//tensorflow/python:util",
"//tensorflow/python/data/experimental/ops:filter_for_shard_ops",
"//tensorflow/python/data/experimental/ops:optimization_options",
"//tensorflow/python/data/experimental/ops:stats_options",
"//tensorflow/python/data/experimental/ops:threading_options",
"//tensorflow/python/data/util:nest",
......
......@@ -27,6 +27,7 @@ import six
from tensorflow.python.compat import compat
from tensorflow.python.data.experimental.ops import filter_for_shard_ops
from tensorflow.python.data.experimental.ops import optimization_options
from tensorflow.python.data.experimental.ops import stats_options
from tensorflow.python.data.experimental.ops import threading_options
from tensorflow.python.data.ops import iterator_ops
......@@ -1587,56 +1588,15 @@ class Options(options_lib.OptionsBase):
"Whether to dynamically adjust the values of tunable parameters (e.g. "
"degrees of parallelism).")
experimental_filter_fusion = options_lib.create_option(
name="experimental_filter_fusion",
ty=bool,
docstring="Whether to fuse filter transformations.")
experimental_hoist_random_uniform = options_lib.create_option(
name="experimental_hoist_random_uniform",
ty=bool,
docstring=
"Whether to hoist `tf.random_uniform()` ops out of map transformations.")
experimental_map_and_batch_fusion = options_lib.create_option(
name="experimental_map_and_batch_fusion",
ty=bool,
docstring="Whether to fuse map and batch transformations.")
experimental_map_and_filter_fusion = options_lib.create_option(
name="experimental_map_and_filter_fusion",
ty=bool,
docstring="Whether to fuse map and filter transformations.")
experimental_map_fusion = options_lib.create_option(
name="experimental_map_and_filter_fusion",
ty=bool,
docstring="Whether to fuse map transformations.")
experimental_map_parallelization = options_lib.create_option(
name="experimental_map_parallelization",
ty=bool,
docstring="Whether to parallelize stateless map transformations.")
experimental_map_vectorization = options_lib.create_option(
name="experimental_map_vectorization",
ty=bool,
docstring="Whether to vectorize map transformations.")
experimental_noop_elimination = options_lib.create_option(
name="experimental_noop_elimination",
ty=bool,
docstring="Whether to eliminate no-op transformations.")
experimental_numa_aware = options_lib.create_option(
name="experimental_numa_aware",
ty=bool,
docstring="Whether to use NUMA-aware operations.")
experimental_shuffle_and_repeat_fusion = options_lib.create_option(
name="experimental_shuffle_and_repeat_fusion",
ty=bool,
docstring="Whether to fuse shuffle and repeat transformations.")
experimental_optimization = options_lib.create_option(
name="experimental_optimization",
ty=optimization_options.OptimizationOptions,
docstring="Associates the given optimization options with the dataset.")
experimental_stats = options_lib.create_option(
name="experimental_stats",
......@@ -1650,29 +1610,30 @@ class Options(options_lib.OptionsBase):
def _static_optimizations(self):
"""Produces the list of enabled static optimizations."""
experimental_optimizations = [
"filter_fusion",
"hoist_random_uniform",
"map_and_batch_fusion",
"map_and_filter_fusion",
"map_fusion",
"map_parallelization",
"map_vectorization",
"noop_elimination",
"shuffle_and_repeat_fusion",
]
result = []
for exp_opt in experimental_optimizations:
if getattr(self, "experimental_" + exp_opt):
result.append(exp_opt)
if getattr(self, "experimental_numa_aware"):
result = []
exp_optimization_options = self.experimental_optimization
if exp_optimization_options:
optimizations = [
"filter_fusion",
"hoist_random_uniform",
"map_and_batch_fusion",
"map_and_filter_fusion",
"map_fusion",
"map_parallelization",
"map_vectorization",
"noop_elimination",
"shuffle_and_repeat_fusion",
]
for optimization in optimizations:
if getattr(exp_optimization_options, optimization):
result.append(optimization)
if self.experimental_numa_aware:
result.append("make_numa_aware")
if getattr(self, "experimental_deterministic") is False:
if self.experimental_deterministic is False:
result.append("make_sloppy")
experimental_stats_options = getattr(self, "experimental_stats")
if experimental_stats_options and getattr(experimental_stats_options,
"latency_all_edges"):
exp_stats_options = self.experimental_stats
if exp_stats_options and exp_stats_options.latency_all_edges:
result.append("latency_all_edges")
return result
......
......@@ -11,44 +11,12 @@ tf_class {
name: "experimental_deterministic"
mtype: "<type \'property\'>"
}
member {
name: "experimental_filter_fusion"
mtype: "<type \'property\'>"
}
member {
name: "experimental_hoist_random_uniform"
mtype: "<type \'property\'>"
}
member {
name: "experimental_map_and_batch_fusion"
mtype: "<type \'property\'>"
}
member {
name: "experimental_map_and_filter_fusion"
mtype: "<type \'property\'>"
}
member {
name: "experimental_map_fusion"
mtype: "<type \'property\'>"
}
member {
name: "experimental_map_parallelization"
mtype: "<type \'property\'>"
}
member {
name: "experimental_map_vectorization"
mtype: "<type \'property\'>"
}
member {
name: "experimental_noop_elimination"
mtype: "<type \'property\'>"
}
member {
name: "experimental_numa_aware"
mtype: "<type \'property\'>"
}
member {
name: "experimental_shuffle_and_repeat_fusion"
name: "experimental_optimization"
mtype: "<type \'property\'>"
}
member {
......
path: "tensorflow.data.experimental.OptimizationOptions"
tf_class {
is_instance: "<class \'tensorflow.python.data.experimental.ops.optimization_options.OptimizationOptions\'>"
is_instance: "<class \'tensorflow.python.data.util.options.OptionsBase\'>"
is_instance: "<type \'object\'>"
member {
name: "filter_fusion"
mtype: "<type \'property\'>"
}
member {
name: "hoist_random_uniform"
mtype: "<type \'property\'>"
}
member {
name: "map_and_batch_fusion"
mtype: "<type \'property\'>"
}
member {
name: "map_and_filter_fusion"
mtype: "<type \'property\'>"
}
member {
name: "map_fusion"
mtype: "<type \'property\'>"
}
member {
name: "map_parallelization"
mtype: "<type \'property\'>"
}
member {
name: "map_vectorization"
mtype: "<type \'property\'>"
}
member {
name: "noop_elimination"
mtype: "<type \'property\'>"
}
member {
name: "shuffle_and_repeat_fusion"
mtype: "<type \'property\'>"
}
member_method {
name: "__init__"
argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
}
}
......@@ -12,6 +12,10 @@ tf_module {
name: "CsvDataset"
mtype: "<type \'type\'>"
}
member {
name: "OptimizationOptions"
mtype: "<type \'type\'>"
}
member {
name: "Optional"
mtype: "<type \'type\'>"
......
......@@ -11,44 +11,12 @@ tf_class {
name: "experimental_deterministic"
mtype: "<type \'property\'>"
}
member {
name: "experimental_filter_fusion"
mtype: "<type \'property\'>"
}
member {
name: "experimental_hoist_random_uniform"
mtype: "<type \'property\'>"
}
member {
name: "experimental_map_and_batch_fusion"
mtype: "<type \'property\'>"
}
member {
name: "experimental_map_and_filter_fusion"
mtype: "<type \'property\'>"
}
member {
name: "experimental_map_fusion"
mtype: "<type \'property\'>"
}
member {
name: "experimental_map_parallelization"
mtype: "<type \'property\'>"
}
member {
name: "experimental_map_vectorization"
mtype: "<type \'property\'>"
}
member {
name: "experimental_noop_elimination"
mtype: "<type \'property\'>"
}
member {
name: "experimental_numa_aware"
mtype: "<type \'property\'>"
}
member {
name: "experimental_shuffle_and_repeat_fusion"
name: "experimental_optimization"
mtype: "<type \'property\'>"
}
member {
......
path: "tensorflow.data.experimental.OptimizationOptions"
tf_class {
is_instance: "<class \'tensorflow.python.data.experimental.ops.optimization_options.OptimizationOptions\'>"
is_instance: "<class \'tensorflow.python.data.util.options.OptionsBase\'>"
is_instance: "<type \'object\'>"
member {
name: "filter_fusion"
mtype: "<type \'property\'>"
}
member {
name: "hoist_random_uniform"
mtype: "<type \'property\'>"
}
member {
name: "map_and_batch_fusion"
mtype: "<type \'property\'>"
}
member {
name: "map_and_filter_fusion"
mtype: "<type \'property\'>"
}
member {
name: "map_fusion"
mtype: "<type \'property\'>"
}
member {
name: "map_parallelization"
mtype: "<type \'property\'>"
}
member {
name: "map_vectorization"
mtype: "<type \'property\'>"
}
member {
name: "noop_elimination"
mtype: "<type \'property\'>"
}
member {
name: "shuffle_and_repeat_fusion"
mtype: "<type \'property\'>"
}
member_method {
name: "__init__"
argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
}
}
......@@ -12,6 +12,10 @@ tf_module {
name: "CsvDataset"
mtype: "<type \'type\'>"
}
member {
name: "OptimizationOptions"
mtype: "<type \'type\'>"
}
member {
name: "Optional"
mtype: "<type \'type\'>"
......