提交 dc2aa7a3 编写于 作者: J Jiri Simsa 提交者: TensorFlower Gardener

[tf.data] Have public tf.data symbols alias V1 or V2 depending on TensorFlow...

[tf.data] Have public tf.data symbols alias V1 or V2 depending on TensorFlow version so that internal TensorFlow uses of the symbols return a matching version of those symbols.

PiperOrigin-RevId: 275107124
Change-Id: I5b62e309933b79ce68492f7c96026acbe3764222
上级 37b111bb
......@@ -19,7 +19,12 @@ from __future__ import division
from __future__ import print_function
from tensorflow.python import tf2
from tensorflow.python.data.experimental.ops import counter
from tensorflow.python.data.experimental.ops import interleave_ops
from tensorflow.python.data.experimental.ops import random_ops
from tensorflow.python.data.experimental.ops import readers as exp_readers
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.data.ops import readers
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.ops import control_flow_v2_toggles
......@@ -50,8 +55,20 @@ def enable_v2_behavior():
ops.enable_tensor_equality()
# Enables TensorArrayV2 and control flow V2.
control_flow_v2_toggles.enable_control_flow_v2()
# Make sure internal uses of the `dataset_ops.Dataset` map to DatasetV2.
# Make sure internal uses of tf.data symbols map to V2 versions.
dataset_ops.Dataset = dataset_ops.DatasetV2
readers.FixedLengthRecordDataset = readers.FixedLengthRecordDatasetV2
readers.TFRecordDataset = readers.TFRecordDatasetV2
readers.TextLineDataset = readers.TextLineDatasetV2
counter.Counter = counter.CounterV2
interleave_ops.choose_from_datasets = interleave_ops.choose_from_datasets_v2
interleave_ops.sample_from_datasets = interleave_ops.sample_from_datasets_v2
random_ops.RandomDataset = random_ops.RandomDatasetV2
exp_readers.CsvDataset = exp_readers.CsvDatasetV2
exp_readers.SqlDataset = exp_readers.SqlDatasetV2
exp_readers.make_batched_features_dataset = (
exp_readers.make_batched_features_dataset_v2)
exp_readers.make_csv_dataset = exp_readers.make_csv_dataset_v2
@tf_export(v1=["disable_v2_behavior"])
......@@ -72,5 +89,17 @@ def disable_v2_behavior():
ops.disable_tensor_equality()
# Disables TensorArrayV2 and control flow V2.
control_flow_v2_toggles.disable_control_flow_v2()
# Make sure internal uses of the `dataset_ops.Dataset` map to DatasetV1.
# Make sure internal uses of tf.data symbols map to V1 versions.
dataset_ops.Dataset = dataset_ops.DatasetV1
readers.FixedLengthRecordDataset = readers.FixedLengthRecordDatasetV1
readers.TFRecordDataset = readers.TFRecordDatasetV1
readers.TextLineDataset = readers.TextLineDatasetV1
counter.Counter = counter.CounterV1
interleave_ops.choose_from_datasets = interleave_ops.choose_from_datasets_v1
interleave_ops.sample_from_datasets = interleave_ops.sample_from_datasets_v1
random_ops.RandomDataset = random_ops.RandomDatasetV1
exp_readers.CsvDataset = exp_readers.CsvDatasetV1
exp_readers.SqlDataset = exp_readers.SqlDatasetV1
exp_readers.make_batched_features_dataset = (
exp_readers.make_batched_features_dataset_v1)
exp_readers.make_csv_dataset = exp_readers.make_csv_dataset_v1
......@@ -17,8 +17,8 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tensorflow.python import tf2
from tensorflow.python.data.experimental.ops import scan_ops
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
......@@ -60,6 +60,7 @@ def CounterV1(start=0, step=1, dtype=dtypes.int64):
return dataset_ops.DatasetV1Adapter(CounterV2(start, step, dtype))
CounterV1.__doc__ = CounterV2.__doc__
# TODO(b/119044825): Until all `tf.data` unit tests are converted to V2, keep
# this alias in place.
Counter = CounterV1 # pylint: disable=invalid-name
if tf2.enabled():
Counter = CounterV2
else:
Counter = CounterV1
......@@ -17,6 +17,7 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tensorflow.python import tf2
from tensorflow.python.data.experimental.ops import random_ops
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.data.ops import readers
......@@ -280,7 +281,9 @@ def choose_from_datasets_v1(datasets, choice_dataset):
choose_from_datasets_v1.__doc__ = choose_from_datasets_v2.__doc__
# TODO(b/119044825): Until all `tf.data` unit tests are converted to V2, keep
# these aliases in place.
choose_from_datasets = choose_from_datasets_v1
sample_from_datasets = sample_from_datasets_v1
if tf2.enabled():
choose_from_datasets = choose_from_datasets_v2
sample_from_datasets = sample_from_datasets_v2
else:
choose_from_datasets = choose_from_datasets_v1
sample_from_datasets = sample_from_datasets_v1
......@@ -19,6 +19,7 @@ from __future__ import print_function
import functools
from tensorflow.python import tf2
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.data.util import random_seed
from tensorflow.python.framework import dtypes
......@@ -53,6 +54,7 @@ class RandomDatasetV1(dataset_ops.DatasetV1Adapter):
super(RandomDatasetV1, self).__init__(wrapped)
# TODO(b/119044825): Until all `tf.data` unit tests are converted to V2, keep
# this alias in place.
RandomDataset = RandomDatasetV1
if tf2.enabled():
RandomDataset = RandomDatasetV2
else:
RandomDataset = RandomDatasetV1
......@@ -24,6 +24,7 @@ import gzip
import numpy as np
from tensorflow.python import tf2
from tensorflow.python.data.experimental.ops import error_ops
from tensorflow.python.data.experimental.ops import parsing_ops
from tensorflow.python.data.ops import dataset_ops
......@@ -750,7 +751,7 @@ class CsvDatasetV1(dataset_ops.DatasetV1Adapter):
def make_batched_features_dataset_v2(file_pattern,
batch_size,
features,
reader=core_readers.TFRecordDataset,
reader=None,
label_key=None,
reader_args=None,
num_epochs=None,
......@@ -852,6 +853,8 @@ def make_batched_features_dataset_v2(file_pattern,
TypeError: If `reader` is a `tf.compat.v1.ReaderBase` subclass.
ValueError: If `label_key` is not one of the `features` keys.
"""
if reader is None:
reader = core_readers.TFRecordDataset
if reader_num_threads is None:
reader_num_threads = 1
......@@ -932,7 +935,7 @@ def make_batched_features_dataset_v2(file_pattern,
def make_batched_features_dataset_v1(file_pattern, # pylint: disable=missing-docstring
batch_size,
features,
reader=core_readers.TFRecordDataset,
reader=None,
label_key=None,
reader_args=None,
num_epochs=None,
......@@ -1042,9 +1045,13 @@ class SqlDatasetV1(dataset_ops.DatasetV1Adapter):
super(SqlDatasetV1, self).__init__(wrapped)
# TODO(b/119044825): Until all `tf.data` unit tests are converted to V2, keep
# these aliases in place.
CsvDataset = CsvDatasetV1
SqlDataset = SqlDatasetV1
make_batched_features_dataset = make_batched_features_dataset_v1
make_csv_dataset = make_csv_dataset_v1
if tf2.enabled():
CsvDataset = CsvDatasetV2
SqlDataset = SqlDatasetV2
make_batched_features_dataset = make_batched_features_dataset_v2
make_csv_dataset = make_csv_dataset_v2
else:
CsvDataset = CsvDatasetV1
SqlDataset = SqlDatasetV1
make_batched_features_dataset = make_batched_features_dataset_v1
make_csv_dataset = make_csv_dataset_v1
......@@ -17,6 +17,7 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tensorflow.python import tf2
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.data.util import convert
from tensorflow.python.framework import dtypes
......@@ -516,8 +517,11 @@ class FixedLengthRecordDatasetV1(dataset_ops.DatasetV1Adapter):
self._dataset._filenames = value # pylint: disable=protected-access
# TODO(b/119044825): Until all `tf.data` unit tests are converted to V2, keep
# these aliases in place.
FixedLengthRecordDataset = FixedLengthRecordDatasetV1
TFRecordDataset = TFRecordDatasetV1
TextLineDataset = TextLineDatasetV1
if tf2.enabled():
FixedLengthRecordDataset = FixedLengthRecordDatasetV2
TFRecordDataset = TFRecordDatasetV2
TextLineDataset = TextLineDatasetV2
else:
FixedLengthRecordDataset = FixedLengthRecordDatasetV1
TFRecordDataset = TFRecordDatasetV1
TextLineDataset = TextLineDatasetV1
......@@ -162,7 +162,7 @@ tf_module {
}
member_method {
name: "make_batched_features_dataset"
argspec: "args=[\'file_pattern\', \'batch_size\', \'features\', \'reader\', \'label_key\', \'reader_args\', \'num_epochs\', \'shuffle\', \'shuffle_buffer_size\', \'shuffle_seed\', \'prefetch_buffer_size\', \'reader_num_threads\', \'parser_num_threads\', \'sloppy_ordering\', \'drop_final_batch\'], varargs=None, keywords=None, defaults=[\"<class \'tensorflow.python.data.ops.readers.TFRecordDatasetV1\'>\", \'None\', \'None\', \'None\', \'True\', \'10000\', \'None\', \'None\', \'None\', \'None\', \'False\', \'False\'], "
argspec: "args=[\'file_pattern\', \'batch_size\', \'features\', \'reader\', \'label_key\', \'reader_args\', \'num_epochs\', \'shuffle\', \'shuffle_buffer_size\', \'shuffle_seed\', \'prefetch_buffer_size\', \'reader_num_threads\', \'parser_num_threads\', \'sloppy_ordering\', \'drop_final_batch\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'True\', \'10000\', \'None\', \'None\', \'None\', \'None\', \'False\', \'False\'], "
}
member_method {
name: "make_csv_dataset"
......
......@@ -134,7 +134,7 @@ tf_module {
}
member_method {
name: "make_batched_features_dataset"
argspec: "args=[\'file_pattern\', \'batch_size\', \'features\', \'reader\', \'label_key\', \'reader_args\', \'num_epochs\', \'shuffle\', \'shuffle_buffer_size\', \'shuffle_seed\', \'prefetch_buffer_size\', \'reader_num_threads\', \'parser_num_threads\', \'sloppy_ordering\', \'drop_final_batch\'], varargs=None, keywords=None, defaults=[\"<class \'tensorflow.python.data.ops.readers.TFRecordDatasetV1\'>\", \'None\', \'None\', \'None\', \'True\', \'10000\', \'None\', \'None\', \'None\', \'None\', \'False\', \'False\'], "
argspec: "args=[\'file_pattern\', \'batch_size\', \'features\', \'reader\', \'label_key\', \'reader_args\', \'num_epochs\', \'shuffle\', \'shuffle_buffer_size\', \'shuffle_seed\', \'prefetch_buffer_size\', \'reader_num_threads\', \'parser_num_threads\', \'sloppy_ordering\', \'drop_final_batch\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'True\', \'10000\', \'None\', \'None\', \'None\', \'None\', \'False\', \'False\'], "
}
member_method {
name: "make_csv_dataset"
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册