Commit fa5590b4 authored by Jiri Simsa, committed by TensorFlower Gardener

[tf.data] Options-related changes.

This CL:
- refactors all options classes to use a shared options utility
- introduces `tf.data.experimental.ThreadingOptions` for configuring threading and surfaces it through the `experimental_threading` property of `tf.data.Options`
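
For illustration, a minimal end-to-end sketch of the new surface (using the public `tf.data` names this CL exports; values are arbitrary):

```python
import tensorflow as tf

dataset = tf.data.Dataset.range(1000).map(lambda x: x * x)

options = tf.data.Options()
options.experimental_threading = tf.data.experimental.ThreadingOptions()
# Compute this pipeline on a private pool of 4 threads, and cap the
# intra-op parallelism available to its ops at 2.
options.experimental_threading.private_threadpool_size = 4
options.experimental_threading.max_intra_op_parallelism = 2
dataset = dataset.with_options(options)
```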

PiperOrigin-RevId: 222462977
Parent 78ef03f3
op {
graph_op_name: "ExperimentalMaxIntraOpParallelismDataset"
in_arg {
name: "max_intra_op_parallelism"
description: <<END
Identifies the maximum intra-op parallelism to use.
END
}
summary: <<END
Creates a dataset that overrides the maximum intra-op parallelism.
END
visibility: HIDDEN
}
op {
graph_op_name: "ExperimentalPrivateThreadPoolDataset"
in_arg {
name: "num_threads"
description: <<END
Identifies the number of threads to use for the private threadpool.
END
}
summary: <<END
Creates a dataset that uses a custom thread pool to compute `input_dataset`.
END
visibility: HIDDEN
}
......@@ -13,10 +13,12 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <memory>
#include "tensorflow/core/framework/dataset.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/resource_mgr.h"
#include "tensorflow/core/lib/core/threadpool.h"
#include "tensorflow/core/util/ptr_util.h"
#include "tensorflow/core/util/work_sharder.h"
namespace tensorflow {
......@@ -225,6 +227,221 @@ class ThreadPoolDatasetOp : public UnaryDatasetOpKernel {
};
};
class MaxIntraOpParallelismDatasetOp : public UnaryDatasetOpKernel {
public:
explicit MaxIntraOpParallelismDatasetOp(OpKernelConstruction* ctx)
: UnaryDatasetOpKernel(ctx) {}
void MakeDataset(OpKernelContext* ctx, DatasetBase* input,
DatasetBase** output) override {
int64 max_intra_op_parallelism;
OP_REQUIRES_OK(ctx,
ParseScalarArgument<int64>(ctx, "max_intra_op_parallelism",
&max_intra_op_parallelism));
OP_REQUIRES(
ctx, max_intra_op_parallelism >= 0,
errors::InvalidArgument("`max_intra_op_parallelism` must be >= 0"));
*output = new Dataset(ctx, input, max_intra_op_parallelism);
}
private:
class Dataset : public DatasetBase {
public:
Dataset(OpKernelContext* ctx, const DatasetBase* input,
int max_intra_op_parallelism)
: DatasetBase(DatasetContext(ctx)),
input_(input),
max_intra_op_parallelism_(max_intra_op_parallelism) {
input_->Ref();
}
~Dataset() override { input_->Unref(); }
std::unique_ptr<IteratorBase> MakeIteratorInternal(
const string& prefix) const override {
return std::unique_ptr<IteratorBase>(new Iterator(
{this, strings::StrCat(prefix, "::MaxIntraOpParallelism")}));
}
const DataTypeVector& output_dtypes() const override {
return input_->output_dtypes();
}
const std::vector<PartialTensorShape>& output_shapes() const override {
return input_->output_shapes();
}
string DebugString() const override {
return "MaxIntraOpParallelismDatasetOp::Dataset";
}
protected:
Status AsGraphDefInternal(SerializationContext* ctx,
DatasetGraphDefBuilder* b,
Node** output) const override {
Node* input_graph_node = nullptr;
TF_RETURN_IF_ERROR(b->AddInputDataset(ctx, input_, &input_graph_node));
Node* max_intra_op_parallelism_node = nullptr;
TF_RETURN_IF_ERROR(b->AddScalar(max_intra_op_parallelism_,
&max_intra_op_parallelism_node));
TF_RETURN_IF_ERROR(b->AddDataset(
this, {input_graph_node, max_intra_op_parallelism_node}, output));
return Status::OK();
}
private:
class Iterator : public DatasetIterator<Dataset> {
public:
explicit Iterator(const Params& params)
: DatasetIterator<Dataset>(params) {}
Status Initialize(IteratorContext* ctx) override {
return dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_);
}
Status GetNextInternal(IteratorContext* ctx,
std::vector<Tensor>* out_tensors,
bool* end_of_sequence) override {
IteratorContext::Params params(ctx);
auto max_parallelism = dataset()->max_intra_op_parallelism_;
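// Wrap the upstream runner so that every closure it schedules executes
// inside a ScopedPerThreadMaxParallelism scope, capping the intra-op
// parallelism available to work done on behalf of this iterator.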
params.runner = std::bind(
[max_parallelism](
const std::function<void(std::function<void()>)>& runner,
std::function<void()> fn) {
std::function<void()> scoped_fn = std::bind(
[max_parallelism](const std::function<void()>& fn) {
ScopedPerThreadMaxParallelism scope(max_parallelism);
fn();
},
std::move(fn));
(runner)(std::move(scoped_fn));
},
std::move(*ctx->runner()), std::placeholders::_1);
return input_impl_->GetNext(IteratorContext{std::move(params)},
out_tensors, end_of_sequence);
}
protected:
std::shared_ptr<model::Node> CreateNode(
IteratorContext* ctx, model::Node::Args args) const override {
return model::MakeKnownRatioNode(std::move(args),
/*ratio=*/1);
}
private:
std::unique_ptr<IteratorBase> input_impl_;
};
const DatasetBase* const input_;
const int max_intra_op_parallelism_;
};
};
class PrivateThreadPoolDatasetOp : public UnaryDatasetOpKernel {
public:
explicit PrivateThreadPoolDatasetOp(OpKernelConstruction* ctx)
: UnaryDatasetOpKernel(ctx) {}
void MakeDataset(OpKernelContext* ctx, DatasetBase* input,
DatasetBase** output) override {
int64 num_threads;
OP_REQUIRES_OK(
ctx, ParseScalarArgument<int64>(ctx, "num_threads", &num_threads));
OP_REQUIRES(ctx, num_threads >= 1,
errors::InvalidArgument("`num_threads` must be >= 1"));
*output = new Dataset(ctx, input, num_threads);
}
private:
class Dataset : public DatasetBase {
public:
Dataset(OpKernelContext* ctx, const DatasetBase* input, int num_threads)
: DatasetBase(DatasetContext(ctx)),
input_(input),
num_threads_(num_threads) {
thread_pool_ = MakeUnique<thread::ThreadPool>(
ctx->env(), ThreadOptions{}, "tf_data_private_threadpool",
num_threads,
/*low_latency_hint=*/false);
input_->Ref();
}
~Dataset() override { input_->Unref(); }
std::unique_ptr<IteratorBase> MakeIteratorInternal(
const string& prefix) const override {
return std::unique_ptr<IteratorBase>(
new Iterator({this, strings::StrCat(prefix, "::PrivateThreadPool")}));
}
const DataTypeVector& output_dtypes() const override {
return input_->output_dtypes();
}
const std::vector<PartialTensorShape>& output_shapes() const override {
return input_->output_shapes();
}
string DebugString() const override {
return "PrivateThreadPoolDatasetOp::Dataset";
}
protected:
Status AsGraphDefInternal(SerializationContext* ctx,
DatasetGraphDefBuilder* b,
Node** output) const override {
Node* input_graph_node = nullptr;
TF_RETURN_IF_ERROR(b->AddInputDataset(ctx, input_, &input_graph_node));
Node* num_threads_node = nullptr;
TF_RETURN_IF_ERROR(b->AddScalar(num_threads_, &num_threads_node));
TF_RETURN_IF_ERROR(b->AddDataset(
this, {input_graph_node, num_threads_node}, output));
return Status::OK();
}
private:
class Iterator : public DatasetIterator<Dataset> {
public:
explicit Iterator(const Params& params)
: DatasetIterator<Dataset>(params) {}
Status Initialize(IteratorContext* ctx) override {
return dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_);
}
Status GetNextInternal(IteratorContext* ctx,
std::vector<Tensor>* out_tensors,
bool* end_of_sequence) override {
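// Route work scheduled by downstream iterators onto the dataset's
// private thread pool instead of the shared inter-op pool, and expose
// the pool size via `runner_threadpool_size`.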
thread::ThreadPool* pool = dataset()->thread_pool_.get();
IteratorContext::Params params(ctx);
params.runner = [pool](std::function<void()> c) {
pool->Schedule(std::move(c));
};
params.runner_threadpool_size = dataset()->num_threads_;
return input_impl_->GetNext(IteratorContext{std::move(params)},
out_tensors, end_of_sequence);
}
protected:
std::shared_ptr<model::Node> CreateNode(
IteratorContext* ctx, model::Node::Args args) const override {
return model::MakeKnownRatioNode(std::move(args),
/*ratio=*/1);
}
private:
std::unique_ptr<IteratorBase> input_impl_;
};
const DatasetBase* const input_;
const int64 num_threads_;
std::unique_ptr<thread::ThreadPool> thread_pool_;
};
};
REGISTER_KERNEL_BUILDER(
Name("ExperimentalMaxIntraOpParallelismDataset").Device(DEVICE_CPU),
MaxIntraOpParallelismDatasetOp);
REGISTER_KERNEL_BUILDER(
Name("ExperimentalPrivateThreadPoolDataset").Device(DEVICE_CPU),
PrivateThreadPoolDatasetOp);
REGISTER_KERNEL_BUILDER(Name("ExperimentalThreadPoolHandle").Device(DEVICE_CPU),
ThreadPoolHandleOp);
REGISTER_KERNEL_BUILDER(
......
......@@ -140,6 +140,22 @@ REGISTER_OP("ExperimentalFunctionBufferingResourceReset")
.Input("function_buffer_resource: resource")
.SetShapeFn(shape_inference::UnknownShape);
REGISTER_OP("ExperimentalMaxIntraOpParallelismDataset")
.Input("input_dataset: variant")
.Input("max_intra_op_parallelism: int64")
.Output("handle: variant")
.Attr("output_types: list(type) >= 1")
.Attr("output_shapes: list(shape) >= 1")
.SetShapeFn(shape_inference::ScalarShape);
REGISTER_OP("ExperimentalPrivateThreadPoolDataset")
.Input("input_dataset: variant")
.Input("num_threads: int64")
.Output("handle: variant")
.Attr("output_types: list(type) >= 1")
.Attr("output_shapes: list(shape) >= 1")
.SetShapeFn(shape_inference::ScalarShape);
REGISTER_OP("ExperimentalThreadPoolDataset")
.Input("input_dataset: variant")
.Input("thread_pool: resource")
......
......@@ -32,6 +32,7 @@ See [Importing Data](https://tensorflow.org/guide/datasets) for an overview.
@@StatsAggregator
@@StatsOptions
@@TFRecordWriter
@@ThreadingOptions
@@bucket_by_sequence_length
@@choose_from_datasets
......@@ -101,6 +102,7 @@ from tensorflow.python.data.experimental.ops.shuffle_ops import shuffle_and_repe
from tensorflow.python.data.experimental.ops.stats_aggregator import StatsAggregator
from tensorflow.python.data.experimental.ops.stats_ops import latency_stats
from tensorflow.python.data.experimental.ops.stats_options import StatsOptions
from tensorflow.python.data.experimental.ops.threading_options import ThreadingOptions
from tensorflow.python.data.experimental.ops.unique import unique
from tensorflow.python.data.experimental.ops.writers import TFRecordWriter
from tensorflow.python.data.ops.iterator_ops import get_next_as_optional
......
......@@ -60,7 +60,8 @@ class LatencyAllEdgesTest(stats_dataset_test_base.StatsDatasetTestBase):
["LatencyStats", "Map", "LatencyStats", "Prefetch",
"LatencyStats"])).map(lambda x: x * x).prefetch(1)
options = dataset_ops.Options()
options.experimental_stats = stats_options.StatsOptions(aggregator)
options.experimental_stats = stats_options.StatsOptions()
options.experimental_stats.aggregator = aggregator
dataset = dataset.with_options(options)
self.assertDatasetProduces(
dataset,
......
......@@ -22,6 +22,7 @@ import threading
from absl.testing import parameterized
import numpy as np
from tensorflow.python.data.experimental.ops import threading_options
from tensorflow.python.data.experimental.ops import threadpool
from tensorflow.python.data.experimental.ops import unique
from tensorflow.python.data.kernel_tests import test_base
......@@ -35,18 +36,7 @@ from tensorflow.python.platform import test
class OverrideThreadpoolTest(test_base.DatasetTestBase,
parameterized.TestCase):
@parameterized.named_parameters(
("1", 1, None),
("2", 2, None),
("3", 4, None),
("4", 8, None),
("5", 16, None),
("6", 4, -1),
("7", 4, 0),
("8", 4, 1),
("9", 4, 4),
)
def testNumThreads(self, num_threads, max_intra_op_parallelism):
def _testNumThreadsHelper(self, num_threads, override_threadpool_fn):
def get_thread_id(_):
# Python creates a dummy thread object to represent the current
......@@ -60,14 +50,7 @@ class OverrideThreadpoolTest(test_base.DatasetTestBase,
dataset_ops.Dataset.range(1000).map(
lambda x: script_ops.py_func(get_thread_id, [x], dtypes.int64),
num_parallel_calls=32).apply(unique.unique()))
dataset = threadpool.override_threadpool(
dataset,
threadpool.PrivateThreadPool(
num_threads,
max_intra_op_parallelism=max_intra_op_parallelism,
display_name="private_thread_pool_%d" % num_threads))
dataset = override_threadpool_fn(dataset)
iterator = dataset.make_initializable_iterator()
next_element = iterator.get_next()
......@@ -79,12 +62,64 @@ class OverrideThreadpoolTest(test_base.DatasetTestBase,
thread_ids.append(sess.run(next_element))
except errors.OutOfRangeError:
pass
self.assertEqual(len(thread_ids), len(set(thread_ids)))
self.assertGreater(len(thread_ids), 0)
# NOTE(mrry): We don't control the thread pool scheduling, and
# so cannot guarantee that all of the threads in the pool will
# perform work.
self.assertLessEqual(len(thread_ids), num_threads)
self.assertLen(thread_ids, len(set(thread_ids)))
self.assertNotEmpty(thread_ids)
if num_threads:
# NOTE(mrry): We don't control the thread pool scheduling, and
# so cannot guarantee that all of the threads in the pool will
# perform work.
self.assertLessEqual(len(thread_ids), num_threads)
@parameterized.named_parameters(
("1", 1, None),
("2", 2, None),
("3", 4, None),
("4", 8, None),
("5", 16, None),
("6", 4, -1),
("7", 4, 0),
("8", 4, 1),
("9", 4, 4),
)
def testNumThreadsDeprecated(self, num_threads, max_intra_op_parallelism):
def override_threadpool_fn(dataset):
return threadpool.override_threadpool(
dataset,
threadpool.PrivateThreadPool(
num_threads,
max_intra_op_parallelism=max_intra_op_parallelism,
display_name="private_thread_pool_%d" % num_threads))
self._testNumThreadsHelper(num_threads, override_threadpool_fn)
@parameterized.named_parameters(
("1", 1, None),
("2", 2, None),
("3", 4, None),
("4", 8, None),
("5", 16, None),
("6", None, 0),
("7", None, 1),
("8", None, 4),
("9", 4, 0),
("10", 4, 1),
("11", 4, 4),
("12", None, None),
)
def testNumThreads(self, num_threads, max_intra_op_parallelism):
def override_threadpool_fn(dataset):
t_options = threading_options.ThreadingOptions()
if max_intra_op_parallelism is not None:
t_options.max_intra_op_parallelism = max_intra_op_parallelism
if num_threads is not None:
t_options.private_threadpool_size = num_threads
options = dataset_ops.Options()
options.experimental_threading = t_options
return dataset.with_options(options)
self._testNumThreadsHelper(num_threads, override_threadpool_fn)
if __name__ == "__main__":
......
......@@ -45,22 +45,18 @@ def function_set_stats_aggregator(dataset,
def function_apply_options(dataset, aggregator, prefix="", counter_prefix=""):
options = dataset_ops.Options()
options.experimental_stats = stats_options.StatsOptions(aggregator)
options.experimental_stats = stats_options.StatsOptions()
options.experimental_stats.aggregator = aggregator
options.experimental_stats.prefix = prefix
options.experimental_stats.counter_prefix = counter_prefix
options.experimental_stats.latency_all_edges = False
if prefix:
options.experimental_stats.prefix = prefix
if counter_prefix:
options.experimental_stats.counter_prefix = counter_prefix
return dataset.with_options(options)
@parameterized.named_parameters(
dict(
testcase_name="SetStatsAggregator",
dataset_transformation=function_set_stats_aggregator),
dict(
testcase_name="StatsOptions",
dataset_transformation=function_apply_options))
("SetStatsAggregator", function_set_stats_aggregator),
("StatsOptions", function_apply_options),
)
class StatsDatasetTest(stats_dataset_test_base.StatsDatasetTestBase):
def testBytesProduced(self, dataset_transformation):
......
......@@ -188,6 +188,17 @@ py_library(
],
)
py_library(
name = "map_defun",
srcs = ["map_defun.py"],
srcs_version = "PY2AND3",
deps = [
"//tensorflow/python:dataset_ops_gen",
"//tensorflow/python:framework_ops",
"//tensorflow/python:tensor_shape",
],
)
py_library(
name = "optimization",
srcs = ["optimization.py"],
......@@ -217,17 +228,6 @@ py_library(
],
)
py_library(
name = "map_defun",
srcs = ["map_defun.py"],
srcs_version = "PY2AND3",
deps = [
"//tensorflow/python:dataset_ops_gen",
"//tensorflow/python:framework_ops",
"//tensorflow/python:tensor_shape",
],
)
py_library(
name = "resampling",
srcs = ["resampling.py"],
......@@ -303,6 +303,18 @@ py_library(
srcs_version = "PY2AND3",
deps = [
":stats_aggregator",
"//tensorflow/python:util",
"//tensorflow/python/data/util:options",
],
)
py_library(
name = "threading_options",
srcs = ["threading_options.py"],
srcs_version = "PY2AND3",
deps = [
"//tensorflow/python:util",
"//tensorflow/python/data/util:options",
],
)
......@@ -313,9 +325,8 @@ py_library(
deps = [
"//tensorflow/python:experimental_dataset_ops_gen",
"//tensorflow/python:resource_variable_ops",
"//tensorflow/python:util",
"//tensorflow/python/data/ops:dataset_ops",
"//tensorflow/python/data/util:nest",
"//tensorflow/python/data/util:sparse",
"//tensorflow/python/eager:context",
],
)
......
......@@ -20,11 +20,12 @@ from __future__ import division
from __future__ import print_function
from tensorflow.python.data.experimental.ops import stats_aggregator
from tensorflow.python.data.util import options
from tensorflow.python.util.tf_export import tf_export
@tf_export("data.experimental.StatsOptions")
class StatsOptions(object):
class StatsOptions(options.OptionsBase):
"""Represents options for collecting dataset stats using `StatsAggregator`.
To apply `StatsOptions` with a `tf.data.Dataset` object, use the following
......@@ -52,52 +53,29 @@ class StatsOptions(object):
```
"""
for _name, _ty, _default, _docstring in [
("aggregator", stats_aggregator.StatsAggregator, None,
"Associate the given statistics options with the dataset pipeline."),
("prefix", str, "",
"Prefix to prepend all statistics recorded for the input `dataset` with."
),
("counter_prefix", str, "",
"Prefix for the statistics recorded as counter."),
("latency_all_edges", bool, True,
"Whether to add latency measurements on all edges."),
]:
def _make_getter(name): # pylint: disable=no-self-argument
def getter(self):
return getattr(self, "_" + name)
return getter
def _make_setter(name, ty): # pylint: disable=no-self-argument
def setter(self, value):
if not isinstance(value, ty):
raise TypeError(
"Attempting to set the option %s to incompatible value: %r when "
"it expects %r" % (name, value, ty))
setattr(self, "_" + name, value)
return setter
vars()["_" + _name] = _default
vars()[_name] = property(
_make_getter(_name), _make_setter(_name, _ty), _default, _docstring)
def __init__(self, aggregator=None):
if aggregator:
self.aggregator = aggregator
def __eq__(self, other):
if isinstance(other, self.__class__):
return self.__dict__ == other.__dict__
else:
return False
def __ne__(self, other):
return not self.__eq__(other)
def __str__(self):
return str(self.__dict__)
aggregator = options.create_option(
name="aggregator",
ty=stats_aggregator.StatsAggregator,
docstring=
"Associates the given statistics aggregator with the dataset pipeline.")
prefix = options.create_option(
name="prefix",
ty=str,
docstring=
"Prefix to prepend all statistics recorded for the input `dataset` with.",
default="")
counter_prefix = options.create_option(
name="counter_prefix",
ty=str,
docstring=
"Prefix for the statistics recorded as counter.",
default="")
latency_all_edges = options.create_option(
name="latency_all_edges",
ty=bool,
docstring=
"Whether to add latency measurements on all edges.",
default=True)
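For reference, a minimal sketch of the resulting usage pattern (mirroring the updated tests; `aggregator` is a `tf.data.experimental.StatsAggregator`):

```python
aggregator = tf.data.experimental.StatsAggregator()

options = tf.data.Options()
options.experimental_stats = tf.data.experimental.StatsOptions()
options.experimental_stats.aggregator = aggregator
options.experimental_stats.latency_all_edges = True
dataset = dataset.with_options(options)
```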
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Experimental API for controlling threading in `tf.data` pipelines."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tensorflow.python.data.util import options
from tensorflow.python.util.tf_export import tf_export
@tf_export("data.experimental.ThreadingOptions")
class ThreadingOptions(options.OptionsBase):
"""Represents options for dataset threading.
To apply `ThreadingOptions` to a `dataset` object, use the following pattern:
```python
options = tf.data.Options()
options.experimental_threading = tf.data.experimental.ThreadingOptions()
options.experimental_threading.private_threadpool_size = 10
dataset = dataset.with_options(options)
```
"""
max_intra_op_parallelism = options.create_option(
name="max_intra_op_parallelism",
ty=int,
docstring=
"If set, it overrides the maximum degree of intra-op parallelism.")
private_threadpool_size = options.create_option(
name="private_threadpool_size",
ty=int,
docstring=
"If set, the dataset will use a private threadpool of the given size.",
default=None)
......@@ -239,7 +239,7 @@ class DatasetOpsTest(test_base.DatasetTestBase, parameterized.TestCase):
options2 = dataset_ops.Options()
options2.experimental_autotune = False
with self.assertRaisesRegexp(ValueError,
"Cannot merge incompatible values of option"):
"Cannot merge incompatible values"):
dataset_ops.Dataset.range(0).with_options(options1).with_options(options2)
def testOptionsMergeOptionsFromMultipleInputs(self):
......
......@@ -14,6 +14,7 @@ py_library(
"//tensorflow/python:control_flow_ops",
"//tensorflow/python:dataset_ops_gen",
"//tensorflow/python:dtypes",
"//tensorflow/python:experimental_dataset_ops_gen",
"//tensorflow/python:framework_ops",
"//tensorflow/python:function",
"//tensorflow/python:math_ops",
......@@ -26,7 +27,9 @@ py_library(
"//tensorflow/python:tensor_util",
"//tensorflow/python:util",
"//tensorflow/python/data/experimental/ops:stats_options",
"//tensorflow/python/data/experimental/ops:threading_options",
"//tensorflow/python/data/util:nest",
"//tensorflow/python/data/util:options",
"//tensorflow/python/data/util:random_seed",
"//tensorflow/python/data/util:sparse",
"//tensorflow/python/data/util:structure",
......
......@@ -27,8 +27,10 @@ import six
from tensorflow.python.compat import compat
from tensorflow.python.data.experimental.ops import stats_options
from tensorflow.python.data.experimental.ops import threading_options
from tensorflow.python.data.ops import iterator_ops
from tensorflow.python.data.util import nest
from tensorflow.python.data.util import options as options_lib
from tensorflow.python.data.util import random_seed
from tensorflow.python.data.util import sparse
from tensorflow.python.data.util import structure as structure_lib
......@@ -45,6 +47,7 @@ from tensorflow.python.framework import tensor_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import gen_dataset_ops
from tensorflow.python.ops import gen_experimental_dataset_ops as ged_ops
from tensorflow.python.ops import gen_io_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import script_ops
......@@ -107,6 +110,14 @@ class DatasetV2(object):
dataset = self
options = self.options()
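# Threading options are applied by wrapping the pipeline in identity
# datasets (defined at the bottom of this file) that override the runner
# and the per-op parallelism limit.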
if options.experimental_threading is not None:
t_options = options.experimental_threading
if t_options.private_threadpool_size is not None:
dataset = _PrivateThreadPoolDataset(dataset,
t_options.private_threadpool_size)
if t_options.max_intra_op_parallelism is not None:
dataset = _MaxIntraOpParallelismDataset(
dataset, t_options.max_intra_op_parallelism)
static_optimizations = options._static_optimizations() # pylint: disable=protected-access
if static_optimizations:
dataset = _OptimizeDataset(dataset, static_optimizations)
......@@ -1371,10 +1382,9 @@ class DatasetV2(object):
def with_options(self, options):
"""Returns a new `tf.data.Dataset` with the given options set.
The options are "global" in the sense they apply to the entire input
pipeline in which the `with_options` transformation is used. If options are
set multiple times, they are merged if possible (see
`tf.data.Options.merge()` for details).
The options are "global" in the sense they apply to the entire dataset.
If options are set multiple times, they are merged as long as different
options do not use different non-default values.
Args:
options: A `tf.data.Options` that identifies the options to use.
......@@ -1383,7 +1393,7 @@ class DatasetV2(object):
Dataset: A `Dataset` with the given options.
Raises:
ValueError: if options are set more than once
ValueError: when an option is set more than once to a non-default value
"""
return _OptionsDataset(self, options)
......@@ -1571,7 +1581,7 @@ class DatasetV1Adapter(DatasetV1):
@tf_export("data.Options")
class Options(object):
class Options(options_lib.OptionsBase):
"""Represents options for tf.data.Dataset.
An `Options` object can be for instance used to control which static
......@@ -1579,69 +1589,81 @@ class Options(object):
tune the parallelism of operations such as `tf.data.Dataset.map` or
`tf.data.Dataset.interleave`.
"""
for _name, _ty, _docstring in [
("experimental_autotune", bool,
"Whether to dynamically adjust the values of tunable parameters (e.g. "
"degrees of parallelism)."),
("experimental_deterministic", bool,
"Whether the outputs need to be produced in deterministic order."),
("experimental_filter_fusion", bool,
"Whether to fuse filter transformations."),
("experimental_hoist_random_uniform", bool,
"Whether to hoist `tf.random_uniform()` ops out of map transformations."
),
("experimental_stats", stats_options.StatsOptions,
"Associate the given statistics options with the dataset pipeline."),
("experimental_map_and_batch_fusion", bool,
"Whether to fuse map and batch transformations."),
("experimental_map_and_filter_fusion", bool,
"Whether to fuse map and filter transformations."),
("experimental_map_fusion", bool, "Whether to fuse map transformations."),
("experimental_map_parallelization", bool,
"Whether to parallelize stateless map transformations."),
("experimental_map_vectorization", bool,
"Whether to vectorize map transformations."),
("experimental_noop_elimination", bool,
"Whether to eliminate no-op transformations."),
("experimental_shuffle_and_repeat_fusion", bool,
"Whether to fuse shuffle and repeat transformations."),
("experimental_numa_aware", bool,
"Whether to use NUMA-aware operations."),
]:
def _make_getter(name): # pylint: disable=no-self-argument
def getter(self):
return getattr(self, "_" + name)
return getter
def _make_setter(name, ty): # pylint: disable=no-self-argument
def setter(self, value):
if not isinstance(value, ty):
raise TypeError(
"Attempting to set the option %s to incompatible value: %r when "
"it expects %r" % (name, value, ty))
setattr(self, "_" + name, value)
return setter
vars()["_" + _name] = None
vars()[_name] = property(
_make_getter(_name), _make_setter(_name, _ty), None, _docstring)
def __init__(self):
pass
def __eq__(self, other):
if isinstance(other, self.__class__):
return self.__dict__ == other.__dict__
else:
return False
def __ne__(self, other):
return not self.__eq__(other)
experimental_autotune = options_lib.create_option(
name="experimental_autotune",
ty=bool,
docstring=
"Whether to dynamically adjust the values of tunable parameters (e.g. "
"degrees of parallelism).")
experimental_deterministic = options_lib.create_option(
name="experimental_deterministic",
ty=bool,
docstring=
"Whether the outputs need to be produced in deterministic order.")
experimental_filter_fusion = options_lib.create_option(
name="experimental_filter_fusion",
ty=bool,
docstring="Whether to fuse filter transformations.")
experimental_hoist_random_uniform = options_lib.create_option(
name="experimental_hoist_random_uniform",
ty=bool,
docstring=
"Whether to hoist `tf.random_uniform()` ops out of map transformations.")
experimental_map_and_batch_fusion = options_lib.create_option(
name="experimental_map_and_batch_fusion",
ty=bool,
docstring="Whether to fuse map and batch transformations.")
experimental_map_and_filter_fusion = options_lib.create_option(
name="experimental_map_and_filter_fusion",
ty=bool,
docstring="Whether to fuse map and filter transformations.")
experimental_map_fusion = options_lib.create_option(
name="experimental_map_fusion",
ty=bool,
docstring="Whether to fuse map transformations.")
experimental_map_parallelization = options_lib.create_option(
name="experimental_map_parallelization",
ty=bool,
docstring="Whether to parallelize stateless map transformations.")
experimental_map_vectorization = options_lib.create_option(
name="experimental_map_vectorization",
ty=bool,
docstring="Whether to vectorize map transformations.")
experimental_noop_elimination = options_lib.create_option(
name="experimental_noop_elimination",
ty=bool,
docstring="Whether to eliminate no-op transformations.")
experimental_numa_aware = options_lib.create_option(
name="experimental_numa_aware",
ty=bool,
docstring="Whether to use NUMA-aware operations.")
experimental_shuffle_and_repeat_fusion = options_lib.create_option(
name="experimental_shuffle_and_repeat_fusion",
ty=bool,
docstring="Whether to fuse shuffle and repeat transformations.")
experimental_stats = options_lib.create_option(
name="experimental_stats",
ty=stats_options.StatsOptions,
docstring="Associates the given statistics options with the dataset.")
experimental_threading = options_lib.create_option(
name="experimental_threading",
ty=threading_options.ThreadingOptions,
docstring="Associates the given threading options with the dataset.")
def _static_optimizations(self):
"""Produces the list of enabled static optimizations."""
......@@ -1687,32 +1709,7 @@ class Options(object):
New `tf.data.Options()` object which is the result of merging self with
the input `tf.data.Options`.
"""
result = Options()
for other in [self, options]:
for name in [
"experimental_autotune",
"experimental_deterministic",
"experimental_filter_fusion",
"experimental_hoist_random_uniform",
"experimental_map_and_batch_fusion",
"experimental_map_and_filter_fusion",
"experimental_map_fusion",
"experimental_map_parallelization",
"experimental_map_vectorization",
"experimental_noop_elimination",
"experimental_numa_aware",
"experimental_shuffle_and_repeat_fusion",
"experimental_stats",
]:
this = getattr(result, name)
that = getattr(other, name)
if that is not None:
if this is None:
setattr(result, name, that)
elif this != that:
raise ValueError(
"Cannot merge incompatible values of option: %s" % (name))
return result
return options_lib.merge_options(self, options)
class DatasetSource(DatasetV2):
......@@ -3065,7 +3062,7 @@ class _OptimizeDataset(UnaryUnchangedStructureDataset):
class _SetStatsAggregatorDataset(UnaryUnchangedStructureDataset):
"""A `Dataset` that acts as an identity, and sets stats aggregator."""
"""A `Dataset` that acts as an identity, and sets a stats aggregator."""
def __init__(self, input_dataset, aggregator, prefix, counter_prefix):
super(_SetStatsAggregatorDataset, self).__init__(input_dataset)
......@@ -3081,3 +3078,37 @@ class _SetStatsAggregatorDataset(UnaryUnchangedStructureDataset):
self._prefix,
self._counter_prefix,
**flat_structure(self))
class _MaxIntraOpParallelismDataset(UnaryUnchangedStructureDataset):
"""A `Dataset` that acts as an identity, overriding intra-op parallelism."""
def __init__(self, input_dataset, max_intra_op_parallelism):
super(_MaxIntraOpParallelismDataset, self).__init__(input_dataset)
self._input_dataset = input_dataset
self._max_intra_op_parallelism = ops.convert_to_tensor(
max_intra_op_parallelism,
dtype=dtypes.int64,
name="max_intra_op_parallelism")
def _as_variant_tensor(self):
return ged_ops.experimental_max_intra_op_parallelism_dataset(
self._input_dataset._as_variant_tensor(), # pylint: disable=protected-access
self._max_intra_op_parallelism,
**flat_structure(self))
class _PrivateThreadPoolDataset(UnaryUnchangedStructureDataset):
"""A `Dataset` that acts as an identity, setting a private threadpool."""
def __init__(self, input_dataset, num_threads):
super(_PrivateThreadPoolDataset, self).__init__(input_dataset)
self._input_dataset = input_dataset
self._num_threads = ops.convert_to_tensor(
num_threads, dtype=dtypes.int64, name="num_threads")
def _as_variant_tensor(self):
return ged_ops.experimental_private_thread_pool_dataset(
self._input_dataset._as_variant_tensor(), # pylint: disable=protected-access
self._num_threads,
**flat_structure(self))
......@@ -97,6 +97,23 @@ py_test(
],
)
py_library(
name = "options",
srcs = ["options.py"],
srcs_version = "PY2AND3",
)
py_test(
name = "options_test",
size = "small",
srcs = ["options_test.py"],
srcs_version = "PY2AND3",
deps = [
":options",
"//tensorflow/python:client_testlib",
],
)
py_library(
name = "convert",
srcs = ["convert.py"],
......
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Utilities for tf.data options."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
def _internal_attr_name(name):
return "_" + name
class OptionsBase(object):
"""Base class for representing a set of tf.data options.
Attributes:
_options: Stores the option values.
"""
def __init__(self):
self._options = {}
def __eq__(self, other):
if not isinstance(other, self.__class__):
return NotImplemented
for name in set(self._options) | set(other._options): # pylint: disable=protected-access
if getattr(self, name) != getattr(other, name):
return False
return True
def __ne__(self, other):
if isinstance(other, self.__class__):
return not self.__eq__(other)
else:
return NotImplemented
def create_option(name, ty, docstring, default=None):
"""Creates a type-checked property.
Args:
name: the name to use
ty: the type to use
docstring: the docstring to use
default: the default value to use
Returns:
A type-checked property.
"""
def get_fn(self):
return self._options.get(name, default) # pylint: disable=protected-access
def set_fn(self, value):
if not isinstance(value, ty):
raise TypeError("Property \"%s\" must be of type %s, got: %r (type: %r)" %
(name, ty, value, type(value)))
self._options[name] = value # pylint: disable=protected-access
return property(get_fn, set_fn, None, docstring)
def merge_options(*options_list):
"""Merges the given options, returning the result as a new options object.
The input arguments are expected to have a matching type that derives from
`OptionsBase` (and thus each represent a set of options). The method outputs
an object of the same type created by merging the sets of options represented
by the input arguments.
The sets of options can be merged as long as no option is set to different
non-default values across the inputs.
If an option is an instance of `OptionsBase` itself, then this method is
applied recursively to the set of options represented by this option.
Args:
*options_list: options to merge
Raises:
TypeError: if the input arguments are incompatible or not derived from
`OptionsBase`
ValueError: if the given options cannot be merged
Returns:
A new options object which is the result of merging the given options.
"""
if len(options_list) < 1:
raise ValueError("At least one options should be provided")
result_type = type(options_list[0])
for options in options_list:
if not isinstance(options, result_type):
raise TypeError("Incompatible options type: %r vs %r" % (type(options),
result_type))
if not isinstance(options_list[0], OptionsBase):
raise TypeError("The inputs should inherit from `OptionsBase`")
default_options = result_type()
result = result_type()
for options in options_list:
# Iterate over all set options and merge them into the result.
for name in options._options: # pylint: disable=protected-access
this = getattr(result, name)
that = getattr(options, name)
default = getattr(default_options, name)
if that == default:
continue
elif this == default:
setattr(result, name, that)
elif isinstance(this, OptionsBase):
setattr(result, name, merge_options(this, that))
elif this != that:
raise ValueError(
"Cannot merge incompatible values (%r and %r) of option: %s" %
(this, that, name))
return result
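A minimal sketch of the merge semantics, using a hypothetical `OptionsBase` subclass (analogous to the `_TestOptions` helper in the test file below):

```python
class _Example(OptionsBase):
  x = create_option(name="x", ty=int, docstring="an example option", default=1)

a, b = _Example(), _Example()
a.x = 5
assert merge_options(a, b).x == 5  # a non-default value wins over the default

b.x = 7  # conflicting non-default values cannot be merged
# merge_options(a, b) now raises:
#   ValueError: Cannot merge incompatible values (5 and 7) of option: x
```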
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for dataset options utilities."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tensorflow.python.data.util import options
from tensorflow.python.platform import test
class _TestOptions(options.OptionsBase):
x = options.create_option(
name="x", ty=int, docstring="the answer to everything", default=42)
y = options.create_option(
name="y", ty=float, docstring="a tasty pie", default=3.14)
class _NestedTestOptions(options.OptionsBase):
opts = options.create_option(
name="opts", ty=_TestOptions, docstring="nested options")
class OptionsTest(test.TestCase):
def testDocumentation(self):
self.assertEqual(_TestOptions.x.__doc__, "the answer to everything")
self.assertEqual(_TestOptions.y.__doc__, "a tasty pie")
def testCreateOption(self):
opts = _TestOptions()
self.assertEqual(opts.x, 42)
self.assertEqual(opts.y, 3.14)
self.assertIsInstance(opts.x, int)
self.assertIsInstance(opts.y, float)
opts.x = 0
self.assertEqual(opts.x, 0)
with self.assertRaises(TypeError):
opts.x = 3.14
opts.y = 0.0
self.assertEqual(opts.y, 0.0)
with self.assertRaises(TypeError):
opts.y = 42
def testMergeOptions(self):
options1, options2 = _TestOptions(), _TestOptions()
with self.assertRaises(ValueError):
options.merge_options()
merged_options = options.merge_options(options1, options2)
self.assertEqual(merged_options.x, 42)
self.assertEqual(merged_options.y, 3.14)
options1.x = 0
options2.y = 0.0
merged_options = options.merge_options(options1, options2)
self.assertEqual(merged_options.x, 0)
self.assertEqual(merged_options.y, 0.0)
def testMergeNestedOptions(self):
options1, options2 = _NestedTestOptions(), _NestedTestOptions()
merged_options = options.merge_options(options1, options2)
self.assertEqual(merged_options.opts, None)
options1.opts = _TestOptions()
merged_options = options.merge_options(options1, options2)
self.assertEqual(merged_options.opts, _TestOptions())
options2.opts = _TestOptions()
merged_options = options.merge_options(options1, options2)
self.assertEqual(merged_options.opts, _TestOptions())
options1.opts.x = 0
options2.opts.y = 0.0
merged_options = options.merge_options(options1, options2)
self.assertEqual(merged_options.opts.x, 0)
self.assertEqual(merged_options.opts.y, 0.0)
def testMergeOptionsInvalid(self):
with self.assertRaises(TypeError):
options.merge_options(0)
options1, options2 = _TestOptions(), _NestedTestOptions()
with self.assertRaises(TypeError):
options.merge_options(options1, options2)
if __name__ == "__main__":
test.main()
path: "tensorflow.data.Options"
tf_class {
is_instance: "<class \'tensorflow.python.data.ops.dataset_ops.Options\'>"
is_instance: "<class \'tensorflow.python.data.util.options.OptionsBase\'>"
is_instance: "<type \'object\'>"
member {
name: "experimental_autotune"
......@@ -54,6 +55,10 @@ tf_class {
name: "experimental_stats"
mtype: "<type \'property\'>"
}
member {
name: "experimental_threading"
mtype: "<type \'property\'>"
}
member_method {
name: "__init__"
argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
......
path: "tensorflow.data.experimental.StatsOptions"
tf_class {
is_instance: "<class \'tensorflow.python.data.experimental.ops.stats_options.StatsOptions\'>"
is_instance: "<class \'tensorflow.python.data.util.options.OptionsBase\'>"
is_instance: "<type \'object\'>"
member {
name: "aggregator"
......@@ -20,6 +21,6 @@ tf_class {
}
member_method {
name: "__init__"
argspec: "args=[\'self\', \'aggregator\'], varargs=None, keywords=None, defaults=[\'None\'], "
argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
}
}
path: "tensorflow.data.experimental.ThreadingOptions"
tf_class {
is_instance: "<class \'tensorflow.python.data.experimental.ops.threading_options.ThreadingOptions\'>"
is_instance: "<class \'tensorflow.python.data.util.options.OptionsBase\'>"
is_instance: "<type \'object\'>"
member {
name: "max_intra_op_parallelism"
mtype: "<type \'property\'>"
}
member {
name: "private_threadpool_size"
mtype: "<type \'property\'>"
}
member_method {
name: "__init__"
argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
}
}
......@@ -40,6 +40,10 @@ tf_module {
name: "TFRecordWriter"
mtype: "<type \'type\'>"
}
member {
name: "ThreadingOptions"
mtype: "<type \'type\'>"
}
member_method {
name: "Counter"
argspec: "args=[\'start\', \'step\', \'dtype\'], varargs=None, keywords=None, defaults=[\'0\', \'1\', \"<dtype: \'int64\'>\"], "
......
path: "tensorflow.data.Options"
tf_class {
is_instance: "<class \'tensorflow.python.data.ops.dataset_ops.Options\'>"
is_instance: "<class \'tensorflow.python.data.util.options.OptionsBase\'>"
is_instance: "<type \'object\'>"
member {
name: "experimental_autotune"
......@@ -54,6 +55,10 @@ tf_class {
name: "experimental_stats"
mtype: "<type \'property\'>"
}
member {
name: "experimental_threading"
mtype: "<type \'property\'>"
}
member_method {
name: "__init__"
argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
......
path: "tensorflow.data.experimental.StatsOptions"
tf_class {
is_instance: "<class \'tensorflow.python.data.experimental.ops.stats_options.StatsOptions\'>"
is_instance: "<class \'tensorflow.python.data.util.options.OptionsBase\'>"
is_instance: "<type \'object\'>"
member {
name: "aggregator"
......@@ -20,6 +21,6 @@ tf_class {
}
member_method {
name: "__init__"
argspec: "args=[\'self\', \'aggregator\'], varargs=None, keywords=None, defaults=[\'None\'], "
argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
}
}
path: "tensorflow.data.experimental.ThreadingOptions"
tf_class {
is_instance: "<class \'tensorflow.python.data.experimental.ops.threading_options.ThreadingOptions\'>"
is_instance: "<class \'tensorflow.python.data.util.options.OptionsBase\'>"
is_instance: "<type \'object\'>"
member {
name: "max_intra_op_parallelism"
mtype: "<type \'property\'>"
}
member {
name: "private_threadpool_size"
mtype: "<type \'property\'>"
}
member_method {
name: "__init__"
argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
}
}
......@@ -40,6 +40,10 @@ tf_module {
name: "TFRecordWriter"
mtype: "<type \'type\'>"
}
member {
name: "ThreadingOptions"
mtype: "<type \'type\'>"
}
member_method {
name: "Counter"
argspec: "args=[\'start\', \'step\', \'dtype\'], varargs=None, keywords=None, defaults=[\'0\', \'1\', \"<dtype: \'int64\'>\"], "
......