Commit a7b4bd66 authored by Shivani Agrawal, committed by TensorFlower Gardener

[data-stats] Adds an "experimental_stats" option to `tf.data.Options` that takes a `tf.data.experimental.StatsOptions` object. `StatsOptions` configures how dataset statistics are collected with a `StatsAggregator`, and it accepts an aggregator as an argument, which attaches that aggregator to the dataset. (This also replaces the `set_stats_aggregator()` dataset transformation.)
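A minimal, hedged sketch of the new usage (TF 1.x graph mode assumed, using the `tf.data.Options` and `tf.data.experimental` export paths shown in the API goldens below; not part of the diff):

```python
import tensorflow as tf

# Attach a StatsAggregator to a pipeline via the new StatsOptions option.
aggregator = tf.data.experimental.StatsAggregator()
options = tf.data.Options()
options.experimental_stats = tf.data.experimental.StatsOptions(aggregator)

dataset = tf.data.Dataset.range(100)
dataset = dataset.with_options(options)  # stats from this pipeline flow into `aggregator`
summary_t = aggregator.get_summary()     # serialized tf.summary.Summary tensor
```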

PiperOrigin-RevId: 220230269
Parent fa0661ee
......@@ -29,6 +29,8 @@ See [Importing Data](https://tensorflow.org/guide/datasets) for an overview.
@@RandomDataset
@@Reducer
@@SqlDataset
@@StatsAggregator
@@StatsOptions
@@TFRecordWriter
@@bucket_by_sequence_length
......@@ -52,9 +54,7 @@ See [Importing Data](https://tensorflow.org/guide/datasets) for an overview.
@@rejection_resample
@@sample_from_datasets
@@scan
@@set_stats_aggregator
@@shuffle_and_repeat
@@StatsAggregator
@@unbatch
@@unique
......@@ -98,9 +98,9 @@ from tensorflow.python.data.experimental.ops.readers import SqlDataset
from tensorflow.python.data.experimental.ops.resampling import rejection_resample
from tensorflow.python.data.experimental.ops.scan_ops import scan
from tensorflow.python.data.experimental.ops.shuffle_ops import shuffle_and_repeat
from tensorflow.python.data.experimental.ops.stats_aggregator import StatsAggregator
from tensorflow.python.data.experimental.ops.stats_ops import latency_stats
from tensorflow.python.data.experimental.ops.stats_ops import set_stats_aggregator
from tensorflow.python.data.experimental.ops.stats_ops import StatsAggregator
from tensorflow.python.data.experimental.ops.stats_options import StatsOptions
from tensorflow.python.data.experimental.ops.unique import unique
from tensorflow.python.data.experimental.ops.writers import TFRecordWriter
from tensorflow.python.data.ops.iterator_ops import get_next_as_optional
......
......@@ -625,9 +625,15 @@ py_test(
"//tensorflow/python:client_testlib",
"//tensorflow/python:errors",
"//tensorflow/python:framework_ops",
"//tensorflow/python:math_ops",
"//tensorflow/python/data/experimental/ops:batching",
"//tensorflow/python/data/experimental/ops:optimization",
"//tensorflow/python/data/experimental/ops:stats_aggregator",
"//tensorflow/python/data/experimental/ops:stats_ops",
"//tensorflow/python/data/experimental/ops:stats_options",
"//tensorflow/python/data/ops:dataset_ops",
"//third_party/py/numpy",
"@absl_py//absl/testing:parameterized",
],
)
......
......@@ -89,6 +89,7 @@ py_test(
"//tensorflow/python:errors",
"//tensorflow/python/data/experimental/kernel_tests:stats_dataset_test_base",
"//tensorflow/python/data/experimental/ops:optimization",
"//tensorflow/python/data/experimental/ops:stats_aggregator",
"//tensorflow/python/data/experimental/ops:stats_ops",
"//tensorflow/python/data/ops:dataset_ops",
],
......
......@@ -19,7 +19,8 @@ from __future__ import print_function
from tensorflow.python.data.experimental.kernel_tests import stats_dataset_test_base
from tensorflow.python.data.experimental.ops import optimization
from tensorflow.python.data.experimental.ops import stats_ops
from tensorflow.python.data.experimental.ops import stats_aggregator
from tensorflow.python.data.experimental.ops import stats_options
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.framework import errors
from tensorflow.python.platform import test
......@@ -28,18 +29,45 @@ from tensorflow.python.platform import test
class LatencyAllEdgesTest(stats_dataset_test_base.StatsDatasetTestBase):
def testLatencyStatsOptimization(self):
stats_aggregator = stats_ops.StatsAggregator()
aggregator = stats_aggregator.StatsAggregator()
dataset = dataset_ops.Dataset.from_tensors(1).apply(
optimization.assert_next(
["LatencyStats", "Map", "LatencyStats", "Prefetch",
"LatencyStats"])).map(lambda x: x * x).prefetch(1).apply(
stats_ops.set_stats_aggregator(stats_aggregator))
"LatencyStats"])).map(lambda x: x * x).prefetch(1)
options = dataset_ops.Options()
options.experimental_latency_all_edges = True
options.experimental_stats = stats_options.StatsOptions()
options.experimental_stats.latency_all_edges = True
options.experimental_stats.aggregator = aggregator
dataset = dataset.with_options(options)
iterator = dataset.make_initializable_iterator()
get_next = iterator.get_next()
summary_t = stats_aggregator.get_summary()
summary_t = aggregator.get_summary()
with self.cached_session() as sess:
sess.run(iterator.initializer)
self.assertEqual(1 * 1, sess.run(get_next))
with self.assertRaises(errors.OutOfRangeError):
sess.run(get_next)
summary_str = sess.run(summary_t)
self._assertSummaryHasCount(summary_str,
"record_latency_TensorDataset/_1", 1)
self._assertSummaryHasCount(summary_str, "record_latency_MapDataset/_4",
1)
self._assertSummaryHasCount(summary_str,
"record_latency_PrefetchDataset/_6", 1)
def testLatencyStatsOptimizationV2(self):
aggregator = stats_aggregator.StatsAggregator()
dataset = dataset_ops.Dataset.from_tensors(1).apply(
optimization.assert_next(
["LatencyStats", "Map", "LatencyStats", "Prefetch",
"LatencyStats"])).map(lambda x: x * x).prefetch(1)
options = dataset_ops.Options()
options.experimental_stats = stats_options.StatsOptions(aggregator)
dataset = dataset.with_options(options)
iterator = dataset.make_initializable_iterator()
get_next = iterator.get_next()
summary_t = aggregator.get_summary()
with self.cached_session() as sess:
sess.run(iterator.initializer)
......
......@@ -651,6 +651,7 @@ py_test(
"//tensorflow/python:array_ops",
"//tensorflow/python:client_testlib",
"//tensorflow/python:framework_ops",
"//tensorflow/python/data/experimental/ops:stats_aggregator",
"//tensorflow/python/data/experimental/ops:stats_ops",
"//tensorflow/python/data/ops:dataset_ops",
],
......
......@@ -18,6 +18,7 @@ from __future__ import division
from __future__ import print_function
from tensorflow.python.data.experimental.kernel_tests.serialization import dataset_serialization_test_base
from tensorflow.python.data.experimental.ops import stats_aggregator
from tensorflow.python.data.experimental.ops import stats_ops
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.framework import errors
......@@ -92,9 +93,9 @@ class StatsDatasetSerializationTest(
None, num_outputs)
def _build_dataset_stats_aggregator(self):
stats_aggregator = stats_ops.StatsAggregator()
aggregator = stats_aggregator.StatsAggregator()
return dataset_ops.Dataset.range(10).apply(
stats_ops.set_stats_aggregator(stats_aggregator))
stats_ops.set_stats_aggregator(aggregator))
def test_set_stats_aggregator_not_support_checkpointing(self):
with self.assertRaisesRegexp(errors.UnimplementedError,
......
......@@ -17,13 +17,16 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from absl.testing import parameterized
import numpy as np
from tensorflow.python.data.experimental.kernel_tests import reader_dataset_ops_test_base
from tensorflow.python.data.experimental.kernel_tests import stats_dataset_test_base
from tensorflow.python.data.experimental.ops import batching
from tensorflow.python.data.experimental.ops import optimization
from tensorflow.python.data.experimental.ops import stats_aggregator
from tensorflow.python.data.experimental.ops import stats_ops
from tensorflow.python.data.experimental.ops import stats_options
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
......@@ -32,17 +35,43 @@ from tensorflow.python.ops import math_ops
from tensorflow.python.platform import test
def function_set_stats_aggregator(dataset,
aggregator,
prefix="",
counter_prefix=""):
return dataset.apply(
stats_ops.set_stats_aggregator(aggregator, prefix, counter_prefix))
def function_apply_options(dataset, aggregator, prefix="", counter_prefix=""):
options = dataset_ops.Options()
options.experimental_stats = stats_options.StatsOptions(aggregator)
options.experimental_stats.latency_all_edges = False
if prefix:
options.experimental_stats.prefix = prefix
if counter_prefix:
options.experimental_stats.counter_prefix = counter_prefix
return dataset.with_options(options)
@parameterized.named_parameters(
dict(
testcase_name="SetStatsAggregator",
dataset_transformation=function_set_stats_aggregator),
dict(
testcase_name="StatsOptions",
dataset_transformation=function_apply_options))
class StatsDatasetTest(stats_dataset_test_base.StatsDatasetTestBase):
def testBytesProduced(self):
stats_aggregator = stats_ops.StatsAggregator()
def testBytesProduced(self, dataset_transformation):
aggregator = stats_aggregator.StatsAggregator()
dataset = dataset_ops.Dataset.range(100).map(
lambda x: array_ops.tile([x], ops.convert_to_tensor([x]))).apply(
stats_ops.bytes_produced_stats("bytes_produced")).apply(
stats_ops.set_stats_aggregator(stats_aggregator))
stats_ops.bytes_produced_stats("bytes_produced"))
dataset = dataset_transformation(dataset, aggregator)
iterator = dataset.make_initializable_iterator()
next_element = iterator.get_next()
summary_t = stats_aggregator.get_summary()
summary_t = aggregator.get_summary()
with self.cached_session() as sess:
sess.run(iterator.initializer)
......@@ -60,14 +89,14 @@ class StatsDatasetTest(stats_dataset_test_base.StatsDatasetTestBase):
self._assertSummaryHasCount(summary_str, "bytes_produced", 100.0)
self._assertSummaryHasSum(summary_str, "bytes_produced", expected_sum)
def testLatencyStats(self):
stats_aggregator = stats_ops.StatsAggregator()
def testLatencyStats(self, dataset_transformation):
aggregator = stats_aggregator.StatsAggregator()
dataset = dataset_ops.Dataset.range(100).apply(
stats_ops.latency_stats("record_latency")).apply(
stats_ops.set_stats_aggregator(stats_aggregator))
stats_ops.latency_stats("record_latency"))
dataset = dataset_transformation(dataset, aggregator)
iterator = dataset.make_initializable_iterator()
next_element = iterator.get_next()
summary_t = stats_aggregator.get_summary()
summary_t = aggregator.get_summary()
with self.cached_session() as sess:
sess.run(iterator.initializer)
......@@ -79,14 +108,14 @@ class StatsDatasetTest(stats_dataset_test_base.StatsDatasetTestBase):
sess.run(next_element)
self._assertSummaryHasCount(sess.run(summary_t), "record_latency", 100.0)
def testPrefetchBufferUtilization(self):
stats_aggregator = stats_ops.StatsAggregator()
def testPrefetchBufferUtilization(self, dataset_transformation):
aggregator = stats_aggregator.StatsAggregator()
dataset = dataset_ops.Dataset.range(100).map(
lambda x: array_ops.tile([x], ops.convert_to_tensor([x]))).prefetch(
-1).apply(stats_ops.set_stats_aggregator(stats_aggregator))
lambda x: array_ops.tile([x], ops.convert_to_tensor([x]))).prefetch(-1)
dataset = dataset_transformation(dataset, aggregator)
iterator = dataset.make_initializable_iterator()
next_element = iterator.get_next()
summary_t = stats_aggregator.get_summary()
summary_t = aggregator.get_summary()
with self.cached_session() as sess:
sess.run(iterator.initializer)
......@@ -106,14 +135,14 @@ class StatsDatasetTest(stats_dataset_test_base.StatsDatasetTestBase):
self._assertSummaryHasCount(summary_str, "Prefetch::buffer_utilization",
100)
def testPrefetchBufferScalars(self):
stats_aggregator = stats_ops.StatsAggregator()
def testPrefetchBufferScalars(self, dataset_transformation):
aggregator = stats_aggregator.StatsAggregator()
dataset = dataset_ops.Dataset.range(10).map(
lambda x: array_ops.tile([x], ops.convert_to_tensor([x]))).prefetch(
0).apply(stats_ops.set_stats_aggregator(stats_aggregator))
lambda x: array_ops.tile([x], ops.convert_to_tensor([x]))).prefetch(0)
dataset = dataset_transformation(dataset, aggregator)
iterator = dataset.make_initializable_iterator()
next_element = iterator.get_next()
summary_t = stats_aggregator.get_summary()
summary_t = aggregator.get_summary()
with self.cached_session() as sess:
sess.run(iterator.initializer)
......@@ -128,14 +157,14 @@ class StatsDatasetTest(stats_dataset_test_base.StatsDatasetTestBase):
with self.assertRaises(errors.OutOfRangeError):
sess.run(next_element)
def testFilteredElementsStats(self):
stats_aggregator = stats_ops.StatsAggregator()
def testFilteredElementsStats(self, dataset_transformation):
aggregator = stats_aggregator.StatsAggregator()
dataset = dataset_ops.Dataset.range(101).filter(
lambda x: math_ops.equal(math_ops.mod(x, 3), 0)).apply(
stats_ops.set_stats_aggregator(stats_aggregator))
lambda x: math_ops.equal(math_ops.mod(x, 3), 0))
dataset = dataset_transformation(dataset, aggregator)
iterator = dataset.make_initializable_iterator()
next_element = iterator.get_next()
summary_t = stats_aggregator.get_summary()
summary_t = aggregator.get_summary()
with self.test_session() as sess:
sess.run(iterator.initializer)
......@@ -153,7 +182,7 @@ class StatsDatasetTest(stats_dataset_test_base.StatsDatasetTestBase):
self._assertSummaryHasScalarValue(
sess.run(summary_t), "Filter::filtered_elements", 34.0)
def testMapBufferUtilization(self):
def testMapBufferUtilization(self, dataset_transformation):
def dataset_fn():
return dataset_ops.Dataset.range(10).map(
......@@ -161,9 +190,13 @@ class StatsDatasetTest(stats_dataset_test_base.StatsDatasetTestBase):
num_parallel_calls=4)
self._testParallelCallsStats(
dataset_fn, "ParallelMap", 10, function_processing_time=True)
dataset_fn,
"ParallelMap",
10,
dataset_transformation,
function_processing_time=True)
def testMapAutoTuneBufferUtilization(self):
def testMapAutoTuneBufferUtilization(self, dataset_transformation):
def dataset_fn():
dataset = dataset_ops.Dataset.range(10).map(
......@@ -174,9 +207,13 @@ class StatsDatasetTest(stats_dataset_test_base.StatsDatasetTestBase):
return dataset.with_options(options)
self._testParallelCallsStats(
dataset_fn, "ParallelMap", 10, function_processing_time=True)
dataset_fn,
"ParallelMap",
10,
dataset_transformation,
function_processing_time=True)
def testInterleaveAutoTuneBufferUtilization(self):
def testInterleaveAutoTuneBufferUtilization(self, dataset_transformation):
def dataset_fn():
dataset = dataset_ops.Dataset.range(10).map(
......@@ -189,9 +226,10 @@ class StatsDatasetTest(stats_dataset_test_base.StatsDatasetTestBase):
options.experimental_autotune = True
return dataset.with_options(options)
self._testParallelCallsStats(dataset_fn, "ParallelInterleaveV2", 10)
self._testParallelCallsStats(dataset_fn, "ParallelInterleaveV2", 10,
dataset_transformation)
def testMapAndBatchAutoTuneBufferUtilization(self):
def testMapAndBatchAutoTuneBufferUtilization(self, dataset_transformation):
def dataset_fn():
dataset = dataset_ops.Dataset.range(100).apply(
......@@ -208,17 +246,18 @@ class StatsDatasetTest(stats_dataset_test_base.StatsDatasetTestBase):
dataset_fn,
"MapAndBatch",
num_output,
dataset_transformation,
check_elements=False,
function_processing_time=True)
def testReinitialize(self):
stats_aggregator = stats_ops.StatsAggregator()
def testReinitialize(self, dataset_transformation):
aggregator = stats_aggregator.StatsAggregator()
dataset = dataset_ops.Dataset.range(100).apply(
stats_ops.latency_stats("record_latency")).apply(
stats_ops.set_stats_aggregator(stats_aggregator))
stats_ops.latency_stats("record_latency"))
dataset = dataset_transformation(dataset, aggregator)
iterator = dataset.make_initializable_iterator()
next_element = iterator.get_next()
summary_t = stats_aggregator.get_summary()
summary_t = aggregator.get_summary()
with self.cached_session() as sess:
for j in range(5):
......@@ -232,7 +271,7 @@ class StatsDatasetTest(stats_dataset_test_base.StatsDatasetTestBase):
self._assertSummaryHasCount(
sess.run(summary_t), "record_latency", (j + 1) * 100.0)
def testNoAggregatorRegistered(self):
def testNoAggregatorRegistered(self, dataset_transformation):
dataset = dataset_ops.Dataset.range(100).apply(
stats_ops.latency_stats("record_latency"))
iterator = dataset.make_initializable_iterator()
......@@ -245,15 +284,15 @@ class StatsDatasetTest(stats_dataset_test_base.StatsDatasetTestBase):
with self.assertRaises(errors.OutOfRangeError):
sess.run(next_element)
def testMultipleTags(self):
stats_aggregator = stats_ops.StatsAggregator()
def testMultipleTags(self, dataset_transformation):
aggregator = stats_aggregator.StatsAggregator()
dataset = dataset_ops.Dataset.range(100).apply(
stats_ops.latency_stats("record_latency")).apply(
stats_ops.latency_stats("record_latency_2")).apply(
stats_ops.set_stats_aggregator(stats_aggregator))
stats_ops.latency_stats("record_latency_2"))
dataset = dataset_transformation(dataset, aggregator)
iterator = dataset.make_initializable_iterator()
next_element = iterator.get_next()
summary_t = stats_aggregator.get_summary()
summary_t = aggregator.get_summary()
with self.cached_session() as sess:
sess.run(iterator.initializer)
......@@ -269,15 +308,15 @@ class StatsDatasetTest(stats_dataset_test_base.StatsDatasetTestBase):
self._assertSummaryHasCount(
sess.run(summary_t), "record_latency_2", 100.0)
def testRepeatedTags(self):
stats_aggregator = stats_ops.StatsAggregator()
def testRepeatedTags(self, dataset_transformation):
aggregator = stats_aggregator.StatsAggregator()
dataset = dataset_ops.Dataset.range(100).apply(
stats_ops.latency_stats("record_latency")).apply(
stats_ops.latency_stats("record_latency")).apply(
stats_ops.set_stats_aggregator(stats_aggregator))
stats_ops.latency_stats("record_latency"))
dataset = dataset_transformation(dataset, aggregator)
iterator = dataset.make_initializable_iterator()
next_element = iterator.get_next()
summary_t = stats_aggregator.get_summary()
summary_t = aggregator.get_summary()
with self.cached_session() as sess:
sess.run(iterator.initializer)
......@@ -289,15 +328,15 @@ class StatsDatasetTest(stats_dataset_test_base.StatsDatasetTestBase):
sess.run(next_element)
self._assertSummaryHasCount(sess.run(summary_t), "record_latency", 200.0)
def testMultipleIteratorsSameAggregator(self):
stats_aggregator = stats_ops.StatsAggregator()
def testMultipleIteratorsSameAggregator(self, dataset_transformation):
aggregator = stats_aggregator.StatsAggregator()
dataset = dataset_ops.Dataset.range(100).apply(
stats_ops.latency_stats("record_latency")).apply(
stats_ops.set_stats_aggregator(stats_aggregator))
stats_ops.latency_stats("record_latency"))
dataset = dataset_transformation(dataset, aggregator)
iterator_0 = dataset.make_initializable_iterator()
iterator_1 = dataset.make_initializable_iterator()
next_element = iterator_0.get_next() + iterator_1.get_next()
summary_t = stats_aggregator.get_summary()
summary_t = aggregator.get_summary()
with self.cached_session() as sess:
sess.run([iterator_0.initializer, iterator_1.initializer])
......@@ -309,18 +348,18 @@ class StatsDatasetTest(stats_dataset_test_base.StatsDatasetTestBase):
sess.run(next_element)
self._assertSummaryHasCount(sess.run(summary_t), "record_latency", 200.0)
def testMultipleDatasetWithTags(self):
stats_aggregator = stats_ops.StatsAggregator()
def testMultipleDatasetWithPrefixes(self, dataset_transformation):
aggregator = stats_aggregator.StatsAggregator()
dataset = dataset_ops.Dataset.range(100).apply(
stats_ops.latency_stats("record_latency")).apply(
stats_ops.set_stats_aggregator(stats_aggregator, "dataset1"))
stats_ops.latency_stats("record_latency"))
dataset = dataset_transformation(dataset, aggregator, prefix="dataset1")
dataset2 = dataset_ops.Dataset.range(100).apply(
stats_ops.latency_stats("record_latency")).apply(
stats_ops.set_stats_aggregator(stats_aggregator, "dataset2"))
stats_ops.latency_stats("record_latency"))
dataset2 = dataset_transformation(dataset2, aggregator, prefix="dataset2")
iterator_0 = dataset.make_initializable_iterator()
iterator_1 = dataset2.make_initializable_iterator()
next_element = iterator_0.get_next() + iterator_1.get_next()
summary_t = stats_aggregator.get_summary()
summary_t = aggregator.get_summary()
with self.test_session() as sess:
sess.run([iterator_0.initializer, iterator_1.initializer])
......@@ -338,15 +377,22 @@ class StatsDatasetTest(stats_dataset_test_base.StatsDatasetTestBase):
sess.run(summary_t), "dataset2_record_latency", 100.0)
@parameterized.named_parameters(
dict(
testcase_name="SetStatsAggregator",
dataset_transformation=function_set_stats_aggregator),
dict(
testcase_name="StatsOptions",
dataset_transformation=function_apply_options))
class FeatureStatsDatasetTest(
stats_dataset_test_base.StatsDatasetTestBase,
reader_dataset_ops_test_base.MakeBatchedFeaturesDatasetTestBase):
def testFeaturesStats(self):
def testFeaturesStats(self, dataset_transformation):
num_epochs = 5
total_records = num_epochs * self._num_records
batch_size = 2
stats_aggregator = stats_ops.StatsAggregator()
aggregator = stats_aggregator.StatsAggregator()
def dataset_fn():
return self.make_batch_feature(
......@@ -362,13 +408,17 @@ class FeatureStatsDatasetTest(
num_output = total_records // batch_size + 1
self._testParallelCallsStats(
dataset_fn, "ParseExample", num_output, check_elements=False)
dataset_fn,
"ParseExample",
num_output,
dataset_transformation,
check_elements=False)
iterator = dataset_fn().apply(
stats_ops.set_stats_aggregator(
stats_aggregator, "record_stats")).make_initializable_iterator()
dataset = dataset_transformation(
dataset_fn(), aggregator, prefix="record_stats")
iterator = dataset.make_initializable_iterator()
next_element = iterator.get_next()
summary_t = stats_aggregator.get_summary()
summary_t = aggregator.get_summary()
with self.test_session() as sess:
sess.run(iterator.initializer)
......
......@@ -20,7 +20,7 @@ from __future__ import print_function
import numpy as np
from tensorflow.core.framework import summary_pb2
from tensorflow.python.data.experimental.ops import stats_ops
from tensorflow.python.data.experimental.ops import stats_aggregator
from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.framework import errors
......@@ -87,14 +87,15 @@ class StatsDatasetTestBase(test_base.DatasetTestBase):
dataset_fn,
dataset_name,
num_output,
dataset_transformation,
function_processing_time=False,
check_elements=True):
stats_aggregator = stats_ops.StatsAggregator()
dataset = dataset_fn().apply(
stats_ops.set_stats_aggregator(stats_aggregator))
aggregator = stats_aggregator.StatsAggregator()
dataset = dataset_fn()
dataset = dataset_transformation(dataset, aggregator)
iterator = dataset.make_initializable_iterator()
next_element = iterator.get_next()
summary_t = stats_aggregator.get_summary()
summary_t = aggregator.get_summary()
with self.cached_session() as sess:
sess.run(iterator.initializer)
......
......@@ -272,6 +272,16 @@ py_library(
],
)
py_library(
name = "stats_aggregator",
srcs = ["stats_aggregator.py"],
srcs_version = "PY2AND3",
deps = [
"//tensorflow/python:dataset_ops_gen",
"//tensorflow/python:util",
],
)
py_library(
name = "stats_ops",
srcs = ["stats_ops.py"],
......@@ -287,6 +297,15 @@ py_library(
],
)
py_library(
name = "stats_options",
srcs = ["stats_options.py"],
srcs_version = "PY2AND3",
deps = [
":stats_aggregator",
],
)
py_library(
name = "threadpool",
srcs = ["threadpool.py"],
......
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""StatsAggregator for aggregating statistics from `tf.data` pipelines."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tensorflow.python.ops import gen_dataset_ops
from tensorflow.python.util.tf_export import tf_export
@tf_export("data.experimental.StatsAggregator")
class StatsAggregator(object):
"""A stateful resource that aggregates statistics from one or more iterators.
To record statistics, use one of the custom transformation functions defined
in this module when defining your `tf.data.Dataset`. All statistics will be
aggregated by the `StatsAggregator` that is associated with a particular
iterator (see below). For example, to record the latency of producing each
element by iterating over a dataset:
```python
dataset = ...
dataset = dataset.apply(tf.data.experimental.latency_stats("total_bytes"))
```
To associate a `StatsAggregator` with a `tf.data.Dataset` object, use
the following pattern:
```python
aggregator = tf.data.experimental.StatsAggregator()
dataset = ...
# Apply `StatsOptions` to associate `dataset` with `aggregator`.
options = dataset_ops.Options()
options.experimental_stats = tf.data.experimental.StatsOptions(aggregator)
dataset = dataset.with_options(options)
iterator = dataset.make_one_shot_iterator()
```
To get a protocol buffer summary of the currently aggregated statistics,
use the `StatsAggregator.get_summary()` tensor. The easiest way to do this
is to add the returned tensor to the `tf.GraphKeys.SUMMARIES` collection,
so that the summaries will be included with any existing summaries.
```python
aggregator = tf.data.experimental.StatsAggregator()
# ...
stats_summary = aggregator.get_summary()
tf.add_to_collection(tf.GraphKeys.SUMMARIES, stats_summary)
```
Note: This interface is experimental and expected to change. In particular,
we expect to add other implementations of `StatsAggregator` that provide
different ways of exporting statistics, and add more types of statistics.
"""
def __init__(self):
"""Creates a `StatsAggregator`."""
self._resource = gen_dataset_ops.stats_aggregator_handle()
# TODO(b/116314787): Update this/add support for V2 summary API.
def get_summary(self):
"""Returns a string `tf.Tensor` that summarizes the aggregated statistics.
The returned tensor will contain a serialized `tf.summary.Summary` protocol
buffer, which can be used with the standard TensorBoard logging facilities.
Returns:
A scalar string `tf.Tensor` that summarizes the aggregated statistics.
"""
return gen_dataset_ops.stats_aggregator_summary(self._resource)
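To make the note about `tf.GraphKeys.SUMMARIES` concrete, here is a hedged sketch (graph mode; the log directory is illustrative, not from the diff) of writing the aggregated summary where TensorBoard can pick it up:

```python
import tensorflow as tf

aggregator = tf.data.experimental.StatsAggregator()
summary_t = aggregator.get_summary()
writer = tf.summary.FileWriter("/tmp/tf_data_stats")  # illustrative path

with tf.Session() as sess:
    # ... run the dataset iterator for a while, then flush the aggregated stats:
    writer.add_summary(sess.run(summary_t), global_step=0)
    writer.flush()
```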
......@@ -21,110 +21,18 @@ from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import gen_dataset_ops
from tensorflow.python.util import deprecation
from tensorflow.python.util.tf_export import tf_export
@tf_export("data.experimental.StatsAggregator")
class StatsAggregator(object):
"""A stateful resource that aggregates statistics from one or more iterators.
To record statistics, use one of the custom transformation functions defined
in this module when defining your `tf.data.Dataset`. All statistics will be
aggregated by the `StatsAggregator` that is associated with a particular
iterator (see below). For example, to record the latency of producing each
element by iterating over a dataset:
```python
dataset = ...
dataset = dataset.apply(tf.data.experimental.latency_stats("total_bytes"))
```
To associate a `StatsAggregator` with a `tf.data.Dataset` object, use
the following pattern:
```python
stats_aggregator = stats_ops.StatsAggregator()
dataset = ...
# Apply `set_stats_aggregator` to associate `dataset` with `stats_aggregator`.
dataset = dataset.apply(
tf.data.experimental.set_stats_aggregator(stats_aggregator))
iterator = dataset.make_one_shot_iterator()
```
To get a protocol buffer summary of the currently aggregated statistics,
use the `StatsAggregator.get_summary()` tensor. The easiest way to do this
is to add the returned tensor to the `tf.GraphKeys.SUMMARIES` collection,
so that the summaries will be included with any existing summaries.
```python
stats_aggregator = stats_ops.StatsAggregator()
# ...
stats_summary = stats_aggregator.get_summary()
tf.add_to_collection(tf.GraphKeys.SUMMARIES, stats_summary)
```
Note: This interface is experimental and expected to change. In particular,
we expect to add other implementations of `StatsAggregator` that provide
different ways of exporting statistics, and add more types of statistics.
"""
def __init__(self):
"""Creates a `StatsAggregator`."""
self._resource = gen_dataset_ops.stats_aggregator_handle()
# TODO(b/116314787): Update this/add support for V2 summary API.
def get_summary(self):
"""Returns a string `tf.Tensor` that summarizes the aggregated statistics.
The returned tensor will contain a serialized `tf.summary.Summary` protocol
buffer, which can be used with the standard TensorBoard logging facilities.
Returns:
A scalar string `tf.Tensor` that summarizes the aggregated statistics.
"""
return gen_dataset_ops.stats_aggregator_summary(self._resource)
class _SetStatsAggregatorDataset(dataset_ops.UnaryDataset):
"""A `Dataset` that acts as an identity, and sets given stats_aggregator."""
def __init__(self, input_dataset, stats_aggregator, tag, prefix):
super(_SetStatsAggregatorDataset, self).__init__(input_dataset)
self._input_dataset = input_dataset
self._stats_aggregator = stats_aggregator
self._tag = tag
self._prefix = prefix
def _as_variant_tensor(self):
return gen_dataset_ops.set_stats_aggregator_dataset(
self._input_dataset._as_variant_tensor(), # pylint: disable=protected-access
self._stats_aggregator._resource, # pylint: disable=protected-access
self._tag,
self._prefix,
**dataset_ops.flat_structure(self))
@property
def output_shapes(self):
return self._input_dataset.output_shapes
@property
def output_types(self):
return self._input_dataset.output_types
@property
def output_classes(self):
return self._input_dataset.output_classes
@tf_export("data.experimental.set_stats_aggregator")
def set_stats_aggregator(stats_aggregator, tag="", counter_prefix=""):
@deprecation.deprecated(None, "Use `tf.data.experimental.StatsOptions`.")
def set_stats_aggregator(stats_aggregator, prefix="", counter_prefix=""):
"""Set the given `stats_aggregator` for aggregating the input dataset stats.
Args:
stats_aggregator: A `tf.contrib.data.StatsAggregator` object.
tag: (Optional) String, all statistics recorded for the input `dataset`
will have given `tag` prepend with the name.
prefix: (Optional) String, all statistics recorded for the input `dataset`
will have the given `prefix` prepended to the name.
counter_prefix: (Optional) String, all statistics recorded as `counters`
will have the given `prefix` for the counter. Defaults to "/tensorflow".
......@@ -134,8 +42,8 @@ def set_stats_aggregator(stats_aggregator, tag="", counter_prefix=""):
"""
def _apply_fn(dataset):
return _SetStatsAggregatorDataset(dataset, stats_aggregator, tag,
counter_prefix)
return dataset_ops._SetStatsAggregatorDataset( # pylint: disable=protected-access
dataset, stats_aggregator, prefix, counter_prefix)
return _apply_fn
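For comparison, a hedged sketch of the now-deprecated path, using the internal module paths that the tests in this change import; it still works but now emits a deprecation warning pointing to `StatsOptions`:

```python
import tensorflow as tf
from tensorflow.python.data.experimental.ops import stats_aggregator
from tensorflow.python.data.experimental.ops import stats_ops

aggregator = stats_aggregator.StatsAggregator()
# Deprecated transformation: wraps dataset_ops._SetStatsAggregatorDataset internally.
dataset = tf.data.Dataset.range(10).apply(
    stats_ops.set_stats_aggregator(aggregator, "dataset1"))
```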
......
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""StatsOptions to configure stats aggregation options for `tf.data` pipelines.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tensorflow.python.data.experimental.ops import stats_aggregator
from tensorflow.python.util.tf_export import tf_export
@tf_export("data.experimental.StatsOptions")
class StatsOptions(object):
"""Represents options for collecting dataset stats using `StatsAggregator`.
To apply `StatsOptions` to a `tf.data.Dataset` object, use the following
pattern:
```python
aggregator = tf.data.experimental.StatsAggregator()
options = dataset_ops.Options()
options.experimental_stats = tf.data.experimental.StatsOptions()
options.experimental_stats.aggregator = aggregator
dataset = dataset.with_options(options)
iterator = dataset.make_one_shot_iterator()
```
Note: a `StatsAggregator` object can either be attached during construction or
provided later, as in the example above.
```python
aggregator = tf.data.experimental.StatsAggregator()
# attach aggregator during construction
options.experimental_stats = tf.data.experimental.StatsOptions(aggregator)
.....
```
"""
for _name, _ty, _default, _docstring in [
("aggregator", stats_aggregator.StatsAggregator, None,
"Associate the given statistics options with the dataset pipeline."),
("prefix", str, "",
"Prefix to prepend all statistics recorded for the input `dataset` with."
),
("counter_prefix", str, "",
"Prefix for the statistics recorded as counter."),
("latency_all_edges", bool, True,
"Whether to add latency measurements on all edges."),
]:
def _make_getter(name): # pylint: disable=no-self-argument
def getter(self):
return getattr(self, "_" + name)
return getter
def _make_setter(name, ty): # pylint: disable=no-self-argument
def setter(self, value):
if not isinstance(value, ty):
raise TypeError(
"Attempting to set the option %s to incompatible value: %r when "
"it expects %r" % (name, value, ty))
setattr(self, "_" + name, value)
return setter
vars()["_" + _name] = _default
vars()[_name] = property(
_make_getter(_name), _make_setter(_name, _ty), _default, _docstring)
def __init__(self, aggregator=None):
if aggregator:
self.aggregator = aggregator
def __eq__(self, other):
if isinstance(other, self.__class__):
return self.__dict__ == other.__dict__
else:
return False
def __ne__(self, other):
return not self.__eq__(other)
def __str__(self):
return str(self.__dict__)
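The class builds its attributes through the `vars()` loop above, so each option becomes a property with a type-checked setter. A hedged illustration of the resulting behavior (assuming the `tf.data.experimental.StatsOptions` export added in this change):

```python
import tensorflow as tf

opts = tf.data.experimental.StatsOptions()
opts.prefix = "dataset1"          # ok: str
opts.latency_all_edges = False    # ok: bool
try:
    opts.counter_prefix = 42      # not a str: the generated setter raises TypeError
except TypeError as err:
    print(err)  # "Attempting to set the option counter_prefix to incompatible value: 42 ..."
```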
......@@ -25,6 +25,7 @@ py_library(
"//tensorflow/python:tensor_shape",
"//tensorflow/python:tensor_util",
"//tensorflow/python:util",
"//tensorflow/python/data/experimental/ops:stats_options",
"//tensorflow/python/data/util:nest",
"//tensorflow/python/data/util:random_seed",
"//tensorflow/python/data/util:sparse",
......
......@@ -25,6 +25,7 @@ import numpy as np
import six
from tensorflow.python.compat import compat
from tensorflow.python.data.experimental.ops import stats_options
from tensorflow.python.data.ops import iterator_ops
from tensorflow.python.data.util import nest
from tensorflow.python.data.util import random_seed
......@@ -101,6 +102,8 @@ class Dataset(object):
return options
def _apply_options(self):
"""Apply options, such as optimization configuration, to the dataset."""
dataset = self
options = self.options()
static_optimizations = options._static_optimizations() # pylint: disable=protected-access
......@@ -108,6 +111,11 @@ class Dataset(object):
dataset = _OptimizeDataset(dataset, static_optimizations)
if options.experimental_autotune is not False:
dataset = _ModelDataset(dataset)
if options.experimental_stats and options.experimental_stats.aggregator: # pylint: disable=line-too-long
dataset = _SetStatsAggregatorDataset( # pylint: disable=protected-access
dataset, options.experimental_stats.aggregator,
options.experimental_stats.prefix,
options.experimental_stats.counter_prefix)
return dataset
def make_initializable_iterator(self, shared_name=None):
......@@ -1411,8 +1419,8 @@ class Options(object):
("experimental_hoist_random_uniform", bool,
"Whether to hoist `tf.random_uniform()` ops out of map transformations."
),
("experimental_latency_all_edges", bool,
"Whether to add latency measurements on all edges."),
("experimental_stats", stats_options.StatsOptions,
"Associate the given statistics options with the dataset pipeline."),
("experimental_map_and_batch_fusion", bool,
"Whether to fuse map and batch transformations."),
("experimental_map_and_filter_fusion", bool,
......@@ -1442,8 +1450,8 @@ class Options(object):
def setter(self, value):
if not isinstance(value, ty):
raise TypeError(
"Attempting to set the option %s to incompatible value: %r" %
(name, value))
"Attempting to set the option %s to incompatible value: %r when "
"it expects %r" % (name, value, ty))
setattr(self, "_" + name, value)
return setter
......@@ -1467,10 +1475,15 @@ class Options(object):
def _static_optimizations(self):
"""Produces the list of enabled static optimizations."""
experimental_optimizations = [
"filter_fusion", "hoist_random_uniform", "latency_all_edges",
"map_and_batch_fusion", "map_and_filter_fusion", "map_fusion",
"map_parallelization", "map_vectorization", "noop_elimination",
"shuffle_and_repeat_fusion"
"filter_fusion",
"hoist_random_uniform",
"map_and_batch_fusion",
"map_and_filter_fusion",
"map_fusion",
"map_parallelization",
"map_vectorization",
"noop_elimination",
"shuffle_and_repeat_fusion",
]
result = []
for exp_opt in experimental_optimizations:
......@@ -1481,6 +1494,10 @@ class Options(object):
result.append("make_numa_aware")
if getattr(self, "experimental_deterministic") is False:
result.append("make_sloppy")
experimental_stats_options = getattr(self, "experimental_stats")
if experimental_stats_options and getattr(experimental_stats_options,
"latency_all_edges"):
result.append("latency_all_edges")
return result
def merge(self, options):
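As the hunk above shows, `latency_all_edges` is no longer a top-level `Options` field; `_static_optimizations()` now reads it from `experimental_stats`. A hedged sketch of the behavior:

```python
import tensorflow as tf

options = tf.data.Options()
options.experimental_stats = tf.data.experimental.StatsOptions()
options.experimental_stats.latency_all_edges = True  # the default
# With experimental_stats set and latency_all_edges truthy, the internal
# Options._static_optimizations() list now includes "latency_all_edges".
```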
......@@ -1506,7 +1523,6 @@ class Options(object):
"experimental_deterministic",
"experimental_filter_fusion",
"experimental_hoist_random_uniform",
"experimental_latency_all_edges",
"experimental_map_and_batch_fusion",
"experimental_map_and_filter_fusion",
"experimental_map_fusion",
......@@ -1515,6 +1531,7 @@ class Options(object):
"experimental_noop_elimination",
"experimental_numa_aware",
"experimental_shuffle_and_repeat_fusion",
"experimental_stats",
]:
this = getattr(result, name)
that = getattr(other, name)
......@@ -3068,3 +3085,34 @@ class _OptimizeDataset(UnaryDataset):
@property
def output_types(self):
return self._input_dataset.output_types
class _SetStatsAggregatorDataset(UnaryDataset):
"""A `Dataset` that acts as an identity, and sets stats aggregator."""
def __init__(self, input_dataset, aggregator, prefix, counter_prefix):
super(_SetStatsAggregatorDataset, self).__init__(input_dataset)
self._input_dataset = input_dataset
self._stats_aggregator = aggregator
self._prefix = prefix
self._counter_prefix = counter_prefix
def _as_variant_tensor(self):
return gen_dataset_ops.set_stats_aggregator_dataset(
self._input_dataset._as_variant_tensor(), # pylint: disable=protected-access
self._stats_aggregator._resource, # pylint: disable=protected-access
self._prefix,
self._counter_prefix,
**flat_structure(self))
@property
def output_shapes(self):
return self._input_dataset.output_shapes
@property
def output_types(self):
return self._input_dataset.output_types
@property
def output_classes(self):
return self._input_dataset.output_classes
......@@ -18,10 +18,6 @@ tf_class {
name: "experimental_hoist_random_uniform"
mtype: "<type \'property\'>"
}
member {
name: "experimental_latency_all_edges"
mtype: "<type \'property\'>"
}
member {
name: "experimental_map_and_batch_fusion"
mtype: "<type \'property\'>"
......@@ -54,6 +50,10 @@ tf_class {
name: "experimental_shuffle_and_repeat_fusion"
mtype: "<type \'property\'>"
}
member {
name: "experimental_stats"
mtype: "<type \'property\'>"
}
member_method {
name: "__init__"
argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
......
path: "tensorflow.data.experimental.StatsAggregator"
tf_class {
is_instance: "<class \'tensorflow.python.data.experimental.ops.stats_ops.StatsAggregator\'>"
is_instance: "<class \'tensorflow.python.data.experimental.ops.stats_aggregator.StatsAggregator\'>"
is_instance: "<type \'object\'>"
member_method {
name: "__init__"
......
path: "tensorflow.data.experimental.StatsOptions"
tf_class {
is_instance: "<class \'tensorflow.python.data.experimental.ops.stats_options.StatsOptions\'>"
is_instance: "<type \'object\'>"
member {
name: "aggregator"
mtype: "<type \'property\'>"
}
member {
name: "counter_prefix"
mtype: "<type \'property\'>"
}
member {
name: "latency_all_edges"
mtype: "<type \'property\'>"
}
member {
name: "prefix"
mtype: "<type \'property\'>"
}
member_method {
name: "__init__"
argspec: "args=[\'self\', \'aggregator\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
}
......@@ -32,6 +32,10 @@ tf_module {
name: "StatsAggregator"
mtype: "<type \'type\'>"
}
member {
name: "StatsOptions"
mtype: "<type \'type\'>"
}
member {
name: "TFRecordWriter"
mtype: "<type \'type\'>"
......@@ -124,10 +128,6 @@ tf_module {
name: "scan"
argspec: "args=[\'initial_state\', \'scan_func\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "set_stats_aggregator"
argspec: "args=[\'stats_aggregator\', \'tag\', \'counter_prefix\'], varargs=None, keywords=None, defaults=[\'\', \'\'], "
}
member_method {
name: "shuffle_and_repeat"
argspec: "args=[\'buffer_size\', \'count\', \'seed\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "
......
......@@ -18,10 +18,6 @@ tf_class {
name: "experimental_hoist_random_uniform"
mtype: "<type \'property\'>"
}
member {
name: "experimental_latency_all_edges"
mtype: "<type \'property\'>"
}
member {
name: "experimental_map_and_batch_fusion"
mtype: "<type \'property\'>"
......@@ -54,6 +50,10 @@ tf_class {
name: "experimental_shuffle_and_repeat_fusion"
mtype: "<type \'property\'>"
}
member {
name: "experimental_stats"
mtype: "<type \'property\'>"
}
member_method {
name: "__init__"
argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
......
path: "tensorflow.data.experimental.StatsAggregator"
tf_class {
is_instance: "<class \'tensorflow.python.data.experimental.ops.stats_ops.StatsAggregator\'>"
is_instance: "<class \'tensorflow.python.data.experimental.ops.stats_aggregator.StatsAggregator\'>"
is_instance: "<type \'object\'>"
member_method {
name: "__init__"
......
path: "tensorflow.data.experimental.StatsOptions"
tf_class {
is_instance: "<class \'tensorflow.python.data.experimental.ops.stats_options.StatsOptions\'>"
is_instance: "<type \'object\'>"
member {
name: "aggregator"
mtype: "<type \'property\'>"
}
member {
name: "counter_prefix"
mtype: "<type \'property\'>"
}
member {
name: "latency_all_edges"
mtype: "<type \'property\'>"
}
member {
name: "prefix"
mtype: "<type \'property\'>"
}
member_method {
name: "__init__"
argspec: "args=[\'self\', \'aggregator\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
}
......@@ -32,6 +32,10 @@ tf_module {
name: "StatsAggregator"
mtype: "<type \'type\'>"
}
member {
name: "StatsOptions"
mtype: "<type \'type\'>"
}
member {
name: "TFRecordWriter"
mtype: "<type \'type\'>"
......@@ -124,10 +128,6 @@ tf_module {
name: "scan"
argspec: "args=[\'initial_state\', \'scan_func\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "set_stats_aggregator"
argspec: "args=[\'stats_aggregator\', \'tag\', \'counter_prefix\'], varargs=None, keywords=None, defaults=[\'\', \'\'], "
}
member_method {
name: "shuffle_and_repeat"
argspec: "args=[\'buffer_size\', \'count\', \'seed\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "
......