Commit 20e2cb97 authored by Ruoxin Sang, committed by A. Unique TensorFlower

Internal change

PiperOrigin-RevId: 328789639
Parent 454e12b8
@@ -75,3 +75,6 @@ def define_flags():
help='The Cloud TPU to use for training. This should be either the name '
'used when creating the Cloud TPU, or a grpc://ip.address.of.tpu:8470 '
'url.')
flags.DEFINE_string(
'tf_data_service', default=None, help='The tf.data service address')
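The hunk above adds a `tf_data_service` string flag. As a minimal, self-contained sketch of how such a flag is declared and consumed with `absl.flags` (the `main` wrapper and the print statement are illustrative only, not part of this commit):

from absl import app
from absl import flags

FLAGS = flags.FLAGS
flags.DEFINE_string(
    'tf_data_service', default=None,
    help='The tf.data service address')


def main(_):
  if FLAGS.tf_data_service:
    # e.g. --tf_data_service=grpc://tf-data-service:5050; parse_configuration
    # later copies this value into the data configs (see the override below).
    print('Using tf.data service at:', FLAGS.tf_data_service)


if __name__ == '__main__':
  app.run(main)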
@@ -100,6 +100,7 @@ class InputReader:
self._cache = params.cache
self._cycle_length = params.cycle_length
self._block_length = params.block_length
self._deterministic = params.deterministic
self._sharding = params.sharding
self._examples_consume = params.examples_consume
self._tfds_split = params.tfds_split
@@ -114,6 +115,11 @@ class InputReader:
self._postprocess_fn = postprocess_fn
self._seed = _get_random_integer()
self._enable_tf_data_service = (
params.enable_tf_data_service and params.tf_data_service_address)
self._tf_data_service_address = params.tf_data_service_address
self._tf_data_service_job_name = params.tf_data_service_job_name
def _read_sharded_files(
self,
input_context: Optional[tf.distribute.InputContext] = None):
@@ -134,8 +140,11 @@
seed=self._seed,
reshuffle_each_iteration=True)
# Do not enable sharding if tf.data service is enabled, as sharding will be
# handled inside tf.data service.
if self._sharding and input_context and (
input_context.num_input_pipelines > 1):
input_context.num_input_pipelines > 1 and
not self._enable_tf_data_service):
dataset = dataset.shard(input_context.num_input_pipelines,
input_context.input_pipeline_id)
if self._is_training:
@@ -145,7 +154,8 @@
map_func=self._dataset_fn,
cycle_length=self._cycle_length,
block_length=self._block_length,
num_parallel_calls=tf.data.experimental.AUTOTUNE)
num_parallel_calls=tf.data.experimental.AUTOTUNE,
deterministic=self._deterministic)
return dataset
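The `deterministic` argument added to the `interleave` call above decides whether parallel interleave must preserve the input order. A toy sketch of the setting, with in-memory datasets standing in for the real sharded file reads (only the `deterministic` argument mirrors the diff):

import tensorflow as tf

file_ids = tf.data.Dataset.range(4)


def make_dataset(i):
  # Stand-in for self._dataset_fn, which would normally open one shard file.
  return tf.data.Dataset.from_tensors(i).repeat(3)


dataset = file_ids.interleave(
    make_dataset,
    cycle_length=2,
    block_length=1,
    num_parallel_calls=tf.data.experimental.AUTOTUNE,
    # None lets tf.data choose, True preserves input order, False allows
    # elements to be produced in whatever order finishes first.
    deterministic=False)

for element in dataset:
  print(element.numpy())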
def _read_single_file(
@@ -161,8 +171,11 @@
options.experimental_distribute.auto_shard_policy = (
tf.data.experimental.AutoShardPolicy.OFF)
dataset = dataset.with_options(options)
# Do not enable sharding if tf.data service is enabled, as sharding will be
# handled inside tf.data service.
if self._sharding and input_context and (
input_context.num_input_pipelines > 1):
input_context.num_input_pipelines > 1 and
not self._enable_tf_data_service):
dataset = dataset.shard(input_context.num_input_pipelines,
input_context.input_pipeline_id)
if self._is_training:
@@ -243,4 +256,18 @@ class InputReader:
per_replica_batch_size, drop_remainder=self._drop_remainder)
dataset = maybe_map_fn(dataset, self._postprocess_fn)
if self._enable_tf_data_service:
dataset = dataset.apply(
tf.data.experimental.service.distribute(
processing_mode='parallel_epochs',
service=self._tf_data_service_address,
job_name=self._tf_data_service_job_name))
dataset = dataset.prefetch(buffer_size=tf.data.experimental.AUTOTUNE)
if self._deterministic is not None:
options = tf.data.Options()
options.experimental_deterministic = self._deterministic
dataset = dataset.with_options(options)
return dataset.prefetch(tf.data.experimental.AUTOTUNE)
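The final hunk routes the dataset through tf.data service when it is enabled. Below is a minimal sketch of the same `tf.data.experimental.service.distribute` call on a toy pipeline; the dispatcher URI and job name are hypothetical, and the snippet only yields data if a tf.data service dispatcher and worker are already running at that address:

import tensorflow as tf

# Everything defined before distribute() is executed on the tf.data service
# workers rather than on the trainer host.
dataset = tf.data.Dataset.range(1000).map(lambda x: x * 2)

dataset = dataset.apply(
    tf.data.experimental.service.distribute(
        # 'parallel_epochs': each service worker processes its own copy of
        # the dataset, so no upstream sharding is needed.
        processing_mode='parallel_epochs',
        # Hypothetical dispatcher address; the diff takes it from
        # params.tf_data_service_address / FLAGS.tf_data_service.
        service='grpc://tf-data-service:5050',
        # Datasets that pass the same job name read from one shared job
        # instead of each creating an anonymous one.
        job_name='train_job'))

dataset = dataset.prefetch(buffer_size=tf.data.experimental.AUTOTUNE)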
@@ -60,10 +60,18 @@ def parse_configuration(flags_obj):
params = hyperparams.override_params_dict(
params, config_file, is_strict=True)
# 3. Override the TPU address.
# 3. Override the TPU address and tf.data service address.
params.override({
'runtime': {
'tpu': flags_obj.tpu,
},
'task': {
'train_data': {
'tf_data_service_address': flags_obj.tf_data_service,
},
'validation_data': {
'tf_data_service_address': flags_obj.tf_data_service,
}
}
})
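The override above fans the single `--tf_data_service` flag value out to both `train_data` and `validation_data`. As a rough, stand-alone illustration of what such a nested override does (the real `params.override` belongs to the model garden's hyperparams `Config` class and is stricter, e.g. about keys that do not exist in the config):

# Toy defaults covering only the fields touched by the override in the diff.
params = {
    'runtime': {'tpu': None},
    'task': {
        'train_data': {'tf_data_service_address': None},
        'validation_data': {'tf_data_service_address': None},
    },
}


def deep_override(config, overrides):
  """Recursively merge `overrides` into `config` in place."""
  for key, value in overrides.items():
    if isinstance(value, dict) and isinstance(config.get(key), dict):
      deep_override(config[key], value)
    else:
      config[key] = value
  return config


tf_data_service = 'grpc://tf-data-service:5050'  # stand-in for flags_obj.tf_data_service
deep_override(params, {
    'runtime': {'tpu': 'my-tpu'},
    'task': {
        'train_data': {'tf_data_service_address': tf_data_service},
        'validation_data': {'tf_data_service_address': tf_data_service},
    },
})
print(params['task']['train_data']['tf_data_service_address'])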
@@ -48,11 +48,22 @@ class DataConfig(base_config.Config):
interleaving files.
block_length: The number of consecutive elements to produce from each input
element before cycling to another input element when interleaving files.
deterministic: A boolean controlling whether determinism should be enforced.
sharding: Whether sharding is used in the input pipeline.
examples_consume: An `integer` specifying the number of examples it will
produce. If positive, it only takes this number of examples and raises
tf.errors.OutOfRangeError after that. Default is -1, meaning it will
exhaust all the examples in the dataset.
enable_tf_data_service: A boolean indicating whether to enable tf.data
service for the input pipeline.
tf_data_service_address: The URI of a tf.data service to offload
preprocessing onto during training. The URI should be in the format
"protocol://address", e.g. "grpc://tf-data-service:5050". It can be
overridden by the `FLAGS.tf_data_service` flag in the binary.
tf_data_service_job_name: The name of the tf.data service job. This
argument makes it possible for multiple datasets to share the same job.
The default behavior is that the dataset creates anonymous, exclusively
owned jobs.
tfds_data_dir: A str specifying the directory to read/write TFDS data.
tfds_download: A bool to indicate whether to download data using TFDS.
tfds_as_supervised: A bool. When loading dataset from TFDS, if True, the
@@ -74,8 +85,12 @@ class DataConfig(base_config.Config):
cache: bool = False
cycle_length: int = 8
block_length: int = 1
deterministic: Optional[bool] = None
sharding: bool = True
examples_consume: int = -1
enable_tf_data_service: bool = False
tf_data_service_address: Optional[str] = None
tf_data_service_job_name: Optional[str] = None
tfds_data_dir: str = ""
tfds_download: bool = False
tfds_as_supervised: bool = False
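To see how the new `DataConfig` fields fit together, here is a self-contained sketch using a plain dataclass rather than the real `DataConfig` import (whose module path is not shown in this diff). It mirrors the check added to `InputReader.__init__`: the service is only used when `enable_tf_data_service` is true and an address is actually set.

import dataclasses
from typing import Optional


@dataclasses.dataclass
class DataConfigSketch:
  deterministic: Optional[bool] = None
  enable_tf_data_service: bool = False
  tf_data_service_address: Optional[str] = None
  tf_data_service_job_name: Optional[str] = None


cfg = DataConfigSketch(
    enable_tf_data_service=True,
    tf_data_service_address='grpc://tf-data-service:5050',  # hypothetical URI
    tf_data_service_job_name='train_job')

# Mirrors `params.enable_tf_data_service and params.tf_data_service_address`
# from InputReader.__init__ in the diff above.
use_service = bool(cfg.enable_tf_data_service and cfg.tf_data_service_address)
print(use_service)  # True only when both the switch and the address are set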