Commit 7ea3ea31 authored by Frank Chen, committed by TensorFlower Gardener

Add option to turn off autosharding under distribution strategy

PiperOrigin-RevId: 246027714
Parent eea4fd60
......@@ -26,6 +26,7 @@ See [Importing Data](https://tensorflow.org/guide/datasets) for an overview.
@@CheckpointInputPipelineHook
@@CsvDataset
@@DatasetStructure
@@DistributeOptions
@@MapVectorizationOptions
@@NestedStructure
@@OptimizationOptions
......@@ -85,7 +86,6 @@ from __future__ import division
from __future__ import print_function
# pylint: disable=unused-import
from tensorflow.python.data.experimental.ops.batching import dense_to_sparse_batch
from tensorflow.python.data.experimental.ops.batching import map_and_batch
from tensorflow.python.data.experimental.ops.batching import map_and_batch_with_legacy_function
......@@ -94,6 +94,7 @@ from tensorflow.python.data.experimental.ops.cardinality import cardinality
from tensorflow.python.data.experimental.ops.cardinality import INFINITE as INFINITE_CARDINALITY
from tensorflow.python.data.experimental.ops.cardinality import UNKNOWN as UNKNOWN_CARDINALITY
from tensorflow.python.data.experimental.ops.counter import Counter
from tensorflow.python.data.experimental.ops.distribute_options import DistributeOptions
from tensorflow.python.data.experimental.ops.enumerate_ops import enumerate_dataset
from tensorflow.python.data.experimental.ops.error_ops import ignore_errors
from tensorflow.python.data.experimental.ops.get_single_element import get_single_element
......
......@@ -152,6 +152,16 @@ py_library(
],
)
py_library(
    name = "distribute_options",
    srcs = ["distribute_options.py"],
    srcs_version = "PY2AND3",
    deps = [
        "//tensorflow/python:util",
        "//tensorflow/python/data/util:options",
    ],
)
py_library(
    name = "snapshot",
    srcs = [
......
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Experimental API for controlling distribution in `tf.data` pipelines."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tensorflow.python.data.util import options
from tensorflow.python.util.tf_export import tf_export
@tf_export("data.experimental.DistributeOptions")
class DistributeOptions(options.OptionsBase):
"""Represents options for distributed data processing.
You can set the distribution options of a dataset through the
`experimental_distribute` property of `tf.data.Options`; the property is
an instance of `tf.data.experimental.DistributeOptions`.
```python
options = tf.data.Options()
options.experimental_distribute.auto_shard = False
dataset = dataset.with_options(options)
```
"""
  auto_shard = options.create_option(
      name="auto_shard",
      ty=bool,
      docstring=
      "Whether the dataset should be automatically sharded when processed "
      "in a distributed fashion. This is applicable when using Keras with a "
      "multi-worker/TPU distribution strategy, and when using "
      "strategy.experimental_distribute_dataset(). In other cases, this "
      "option does nothing. If None, defaults to True.",
      default_factory=lambda: True)
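For orientation, here is a minimal sketch of how the new option is used end to end. This is hedged: it assumes a TensorFlow build that includes this change, and the `range` pipeline and numbers are only stand-ins.

```python
import tensorflow as tf

# Any input pipeline; range() is only a stand-in here.
dataset = tf.data.Dataset.range(100).batch(10)

# Opt out of automatic sharding before handing the dataset
# to a distribution strategy.
options = tf.data.Options()
options.experimental_distribute.auto_shard = False
dataset = dataset.with_options(options)

# The setting round-trips through the dataset's effective options.
assert dataset.options().experimental_distribute.auto_shard is False
```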
......@@ -26,6 +26,7 @@ py_library(
"//tensorflow/python:tensor_shape",
"//tensorflow/python:tensor_util",
"//tensorflow/python:util",
"//tensorflow/python/data/experimental/ops:distribute_options",
"//tensorflow/python/data/experimental/ops:optimization_options",
"//tensorflow/python/data/experimental/ops:stats_options",
"//tensorflow/python/data/experimental/ops:threading_options",
......
......@@ -30,6 +30,7 @@ from six.moves import queue as Queue # pylint: disable=redefined-builtin
from tensorflow.core.framework import graph_pb2
from tensorflow.python.compat import compat
from tensorflow.python.data.experimental.ops import distribute_options
from tensorflow.python.data.experimental.ops import optimization_options
from tensorflow.python.data.experimental.ops import stats_options
from tensorflow.python.data.experimental.ops import threading_options
......@@ -2028,6 +2029,14 @@ class Options(options_lib.OptionsBase):
"Whether the outputs need to be produced in deterministic order. If None,"
" defaults to True.")
  experimental_distribute = options_lib.create_option(
      name="experimental_distribute",
      ty=distribute_options.DistributeOptions,
      docstring=
      "The distribution options associated with the dataset. See "
      "`tf.data.experimental.DistributeOptions` for more details.",
      default_factory=distribute_options.DistributeOptions)
  experimental_optimization = options_lib.create_option(
      name="experimental_optimization",
      ty=optimization_options.OptimizationOptions,
......
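As a quick illustration of the `default_factory` wiring above: accessing the new property on a fresh `tf.data.Options` object should lazily create a `DistributeOptions` instance whose `auto_shard` defaults to True. A sketch, assuming the defaults shown in this diff:

```python
import tensorflow as tf

options = tf.data.Options()

# experimental_distribute is created on demand by default_factory,
# and auto_shard defaults to True per the create_option call above.
print(type(options.experimental_distribute).__name__)  # DistributeOptions
print(options.experimental_distribute.auto_shard)      # True
```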
......@@ -433,9 +433,27 @@ class Strategy(object):
def experimental_distribute_dataset(self, dataset):
"""Distributes a tf.data.Dataset instance provided via `dataset`.
Data from the given dataset will be distributed evenly across all the
compute replicas. This function assumes that the input dataset is batched
by the global batch size.
In a multi-worker setting, we will first attempt to distribute the dataset
by detecting whether it is being created out of reader datasets (e.g.
TFRecordDataset, TextLineDataset, etc.), and if so, sharding the input
files across workers. Note that there has to be at least one input file per
worker; if you have fewer input files than workers, we suggest disabling
auto-sharding of your dataset using the option described below.
If that attempt is unsuccessful (e.g. the dataset is created from
`Dataset.range`), we will instead shard the dataset evenly by appending a
`.shard` operation to the end of the processing pipeline. In that case the
entire preprocessing pipeline for all the data is run on every worker, so
each worker does redundant work; we will print a warning if this method of
sharding is selected.
You can disable auto-sharding using the `auto_shard` option in
`tf.data.experimental.DistributeOptions`.
Within each host, we will also split the data among all the worker devices
(if more than one is present); this happens even if multi-worker sharding
is disabled using the option above.
The following is an example:
......@@ -443,7 +461,8 @@ class Strategy(object):
strategy = tf.distribute.MirroredStrategy()
# Create a dataset
dataset = dataset_ops.Dataset.range(10).batch(2)
dataset = dataset_ops.Dataset.TFRecordDataset([
"/a/1.tfr", "/a/2.tfr", "/a/3.tfr", /a/4.tfr"])
# Distribute that dataset
dist_dataset = strategy.experimental_distribute_dataset(dataset)
......@@ -454,8 +473,8 @@ class Strategy(object):
```
Args:
dataset: `tf.data.Dataset` that will be distributed evenly across all
replicas.
dataset: `tf.data.Dataset` that will be sharded across all replicas using
the rules stated above.
Returns:
A `DistributedDataset` which returns inputs for each step of the
......
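Tying the docstring together, here is a hedged sketch of disabling file-based auto-sharding when there are fewer input files than workers. The file paths, batch size, and choice of strategy are placeholders, not part of this commit:

```python
import tensorflow as tf

GLOBAL_BATCH_SIZE = 16  # placeholder value

strategy = tf.distribute.MirroredStrategy()

# Placeholder paths; assume fewer files than workers here.
dataset = tf.data.TFRecordDataset(["/a/1.tfr", "/a/2.tfr"])
dataset = dataset.batch(GLOBAL_BATCH_SIZE)

# With too few files for file-based sharding, turn auto-sharding off
# so the strategy does not try to split the input by file.
options = tf.data.Options()
options.experimental_distribute.auto_shard = False
dataset = dataset.with_options(options)

dist_dataset = strategy.experimental_distribute_dataset(dataset)
```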
......@@ -296,6 +296,34 @@ class DistributedIteratorMultiWorkerTest(
])
]
  @combinations.generate(combinations.combine(
      mode=["graph"],
      input_type=["dataset"],
      api_type=["wrap_into_iterator", "wrap_into_dataset"],
      iteration_type=["get_next", "for_loop"],
      autoshard=[True, False]))
  def testAutoshardingOption(self, input_type, api_type, iteration_type,
                             autoshard):
    ds_option = dataset_ops.Options()
    ds_option.experimental_distribute.auto_shard = autoshard
    worker_devices = self._cpu_devices()
    with context.graph_mode(), self.cached_session() as sess:
      if tf2.enabled():
        dataset_fn = (
            lambda _: dataset_ops.DatasetV2.range(4).with_options(ds_option))
      else:
        dataset_fn = (
            lambda _: dataset_ops.Dataset.range(4).with_options(ds_option))
      if autoshard:
        expected_values = [[0, 1], [2, 3]]
      else:
        expected_values = [[0, 0], [1, 1], [2, 2], [3, 3]]
      self._test_input_iteration(input_type, api_type, iteration_type,
                                 dataset_fn, worker_devices,
                                 expected_values, sess)
  @combinations.generate(combinations.combine(
      mode=["graph"],
      input_type=["input_fn", "dataset"],
......
......@@ -42,10 +42,13 @@ def auto_shard_dataset(dataset, num_shards, index):
files. The input dataset will be returned if we cannot automatically
determine a good way to shard the input dataset.
"""
-  if isinstance(dataset, dataset_ops.DatasetV1):
-    return distribute._AutoShardDatasetV1(dataset, num_shards, index)
+  if dataset.options().experimental_distribute.auto_shard:
+    if isinstance(dataset, dataset_ops.DatasetV1):
+      return distribute._AutoShardDatasetV1(dataset, num_shards, index)
+    else:
+      return distribute._AutoShardDataset(dataset, num_shards, index)
   else:
-    return distribute._AutoShardDataset(dataset, num_shards, index)
+    return dataset
def _clone_dataset(dataset):
......
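For intuition on the `_AutoShardDataset` fallback described earlier (appending a `.shard` to the end of the pipeline), here is a small sketch of what index-based sharding looks like with two workers; the dataset and shard counts are illustrative only:

```python
import tensorflow as tf

dataset = tf.data.Dataset.range(8)

# Each worker runs the whole pipeline but keeps only every
# num_shards-th element, which is the redundant-work case the
# strategy docstring warns about.
worker_0 = dataset.shard(num_shards=2, index=0)  # yields 0, 2, 4, 6
worker_1 = dataset.shard(num_shards=2, index=1)  # yields 1, 3, 5, 7
```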
......@@ -7,6 +7,10 @@ tf_class {
name: "experimental_deterministic"
mtype: "<type \'property\'>"
}
member {
name: "experimental_distribute"
mtype: "<type \'property\'>"
}
member {
name: "experimental_optimization"
mtype: "<type \'property\'>"
......
path: "tensorflow.data.experimental.DistributeOptions"
tf_class {
is_instance: "<class \'tensorflow.python.data.experimental.ops.distribute_options.DistributeOptions\'>"
is_instance: "<class \'tensorflow.python.data.util.options.OptionsBase\'>"
is_instance: "<type \'object\'>"
member {
name: "auto_shard"
mtype: "<type \'property\'>"
}
member_method {
name: "__init__"
argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
}
}
......@@ -16,6 +16,10 @@ tf_module {
name: "DatasetStructure"
mtype: "<type \'type\'>"
}
member {
name: "DistributeOptions"
mtype: "<type \'type\'>"
}
member {
name: "INFINITE_CARDINALITY"
mtype: "<type \'int\'>"
......
......@@ -7,6 +7,10 @@ tf_class {
name: "experimental_deterministic"
mtype: "<type \'property\'>"
}
member {
name: "experimental_distribute"
mtype: "<type \'property\'>"
}
member {
name: "experimental_optimization"
mtype: "<type \'property\'>"
......
path: "tensorflow.data.experimental.DistributeOptions"
tf_class {
is_instance: "<class \'tensorflow.python.data.experimental.ops.distribute_options.DistributeOptions\'>"
is_instance: "<class \'tensorflow.python.data.util.options.OptionsBase\'>"
is_instance: "<type \'object\'>"
member {
name: "auto_shard"
mtype: "<type \'property\'>"
}
member_method {
name: "__init__"
argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
}
}
......@@ -16,6 +16,10 @@ tf_module {
name: "DatasetStructure"
mtype: "<type \'type\'>"
}
member {
name: "DistributeOptions"
mtype: "<type \'type\'>"
}
member {
name: "INFINITE_CARDINALITY"
mtype: "<type \'int\'>"
......