Commit dc2aa7a3 authored by Jiri Simsa, committed by TensorFlower Gardener

[tf.data] Have public tf.data symbols alias V1 or V2 depending on TensorFlow version so that internal TensorFlow uses of the symbols return a matching version of those symbols.

PiperOrigin-RevId: 275107124
Change-Id: I5b62e309933b79ce68492f7c96026acbe3764222
Parent 37b111bb
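The whole change follows one pattern: every public tf.data symbol is bound to its V1 or V2 implementation at module import time based on `tf2.enabled()`, and `enable_v2_behavior()`/`disable_v2_behavior()` additionally rebind the same module attributes at runtime. A minimal sketch of the import-time gate, using placeholder classes rather than the real tf.data implementations:

# Sketch of the version-dependent aliasing used throughout this commit.
# CounterV1/CounterV2 are stand-in placeholders here, not the real classes.
from tensorflow.python import tf2

class CounterV1(object):
  """Placeholder for the TF 1.x implementation."""

class CounterV2(object):
  """Placeholder for the TF 2.x implementation."""

# Bind the public name once, at import time, to the matching version.
if tf2.enabled():
  Counter = CounterV2
else:
  Counter = CounterV1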
@@ -19,7 +19,12 @@ from __future__ import division
 from __future__ import print_function
 
 from tensorflow.python import tf2
+from tensorflow.python.data.experimental.ops import counter
+from tensorflow.python.data.experimental.ops import interleave_ops
+from tensorflow.python.data.experimental.ops import random_ops
+from tensorflow.python.data.experimental.ops import readers as exp_readers
 from tensorflow.python.data.ops import dataset_ops
+from tensorflow.python.data.ops import readers
 from tensorflow.python.framework import ops
 from tensorflow.python.framework import tensor_shape
 from tensorflow.python.ops import control_flow_v2_toggles
@@ -50,8 +55,20 @@ def enable_v2_behavior():
   ops.enable_tensor_equality()
   # Enables TensorArrayV2 and control flow V2.
   control_flow_v2_toggles.enable_control_flow_v2()
-  # Make sure internal uses of the `dataset_ops.Dataset` map to DatasetV2.
+  # Make sure internal uses of tf.data symbols map to V2 versions.
   dataset_ops.Dataset = dataset_ops.DatasetV2
+  readers.FixedLengthRecordDataset = readers.FixedLengthRecordDatasetV2
+  readers.TFRecordDataset = readers.TFRecordDatasetV2
+  readers.TextLineDataset = readers.TextLineDatasetV2
+  counter.Counter = counter.CounterV2
+  interleave_ops.choose_from_datasets = interleave_ops.choose_from_datasets_v2
+  interleave_ops.sample_from_datasets = interleave_ops.sample_from_datasets_v2
+  random_ops.RandomDataset = random_ops.RandomDatasetV2
+  exp_readers.CsvDataset = exp_readers.CsvDatasetV2
+  exp_readers.SqlDataset = exp_readers.SqlDatasetV2
+  exp_readers.make_batched_features_dataset = (
+      exp_readers.make_batched_features_dataset_v2)
+  exp_readers.make_csv_dataset = exp_readers.make_csv_dataset_v2
 
 
 @tf_export(v1=["disable_v2_behavior"])
@@ -72,5 +89,17 @@ def disable_v2_behavior():
   ops.disable_tensor_equality()
   # Disables TensorArrayV2 and control flow V2.
   control_flow_v2_toggles.disable_control_flow_v2()
-  # Make sure internal uses of the `dataset_ops.Dataset` map to DatasetV1.
+  # Make sure internal uses of tf.data symbols map to V1 versions.
   dataset_ops.Dataset = dataset_ops.DatasetV1
+  readers.FixedLengthRecordDataset = readers.FixedLengthRecordDatasetV1
+  readers.TFRecordDataset = readers.TFRecordDatasetV1
+  readers.TextLineDataset = readers.TextLineDatasetV1
+  counter.Counter = counter.CounterV1
+  interleave_ops.choose_from_datasets = interleave_ops.choose_from_datasets_v1
+  interleave_ops.sample_from_datasets = interleave_ops.sample_from_datasets_v1
+  random_ops.RandomDataset = random_ops.RandomDatasetV1
+  exp_readers.CsvDataset = exp_readers.CsvDatasetV1
+  exp_readers.SqlDataset = exp_readers.SqlDatasetV1
+  exp_readers.make_batched_features_dataset = (
+      exp_readers.make_batched_features_dataset_v1)
+  exp_readers.make_csv_dataset = exp_readers.make_csv_dataset_v1

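Note that `enable_v2_behavior()` assigns through the module objects (`readers.TFRecordDataset = ...`), so code that looks a symbol up via its module at call time sees the switch, while a name captured earlier with `from ... import TFRecordDataset` keeps its old binding. A self-contained sketch of that distinction, using an illustrative stand-in module rather than the real tf.data ones:

# Illustrative-only sketch of why the toggles assign through module objects.
# `mod.Thing` is looked up at attribute-access time, so rebinding mod.Thing
# takes effect everywhere; a name bound via `from mod import Thing` was
# captured once and does not follow later reassignment.
import types

mod = types.ModuleType("mod")  # stand-in for e.g. dataset_ops
mod.Thing = "V1"

captured = mod.Thing           # like `from mod import Thing`
mod.Thing = "V2"               # like enable_v2_behavior()'s rebinding

print(mod.Thing)   # "V2" -- attribute lookups see the new binding
print(captured)    # "V1" -- previously captured names do not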
@@ -17,8 +17,8 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
+from tensorflow.python import tf2
 from tensorflow.python.data.experimental.ops import scan_ops
 from tensorflow.python.data.ops import dataset_ops
 from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import ops
@@ -60,6 +60,7 @@ def CounterV1(start=0, step=1, dtype=dtypes.int64):
   return dataset_ops.DatasetV1Adapter(CounterV2(start, step, dtype))
 CounterV1.__doc__ = CounterV2.__doc__
 
-# TODO(b/119044825): Until all `tf.data` unit tests are converted to V2, keep
-# this alias in place.
-Counter = CounterV1  # pylint: disable=invalid-name
+if tf2.enabled():
+  Counter = CounterV2
+else:
+  Counter = CounterV1

@@ -17,6 +17,7 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
+from tensorflow.python import tf2
 from tensorflow.python.data.experimental.ops import random_ops
 from tensorflow.python.data.ops import dataset_ops
 from tensorflow.python.data.ops import readers
@@ -280,7 +281,9 @@ def choose_from_datasets_v1(datasets, choice_dataset):
 choose_from_datasets_v1.__doc__ = choose_from_datasets_v2.__doc__
 
-# TODO(b/119044825): Until all `tf.data` unit tests are converted to V2, keep
-# these aliases in place.
-choose_from_datasets = choose_from_datasets_v1
-sample_from_datasets = sample_from_datasets_v1
+if tf2.enabled():
+  choose_from_datasets = choose_from_datasets_v2
+  sample_from_datasets = sample_from_datasets_v2
+else:
+  choose_from_datasets = choose_from_datasets_v1
+  sample_from_datasets = sample_from_datasets_v1

@@ -19,6 +19,7 @@ from __future__ import print_function
 
 import functools
 
+from tensorflow.python import tf2
 from tensorflow.python.data.ops import dataset_ops
 from tensorflow.python.data.util import random_seed
 from tensorflow.python.framework import dtypes
@@ -53,6 +54,7 @@ class RandomDatasetV1(dataset_ops.DatasetV1Adapter):
     super(RandomDatasetV1, self).__init__(wrapped)
 
-# TODO(b/119044825): Until all `tf.data` unit tests are converted to V2, keep
-# this alias in place.
-RandomDataset = RandomDatasetV1
+if tf2.enabled():
+  RandomDataset = RandomDatasetV2
+else:
+  RandomDataset = RandomDatasetV1

@@ -24,6 +24,7 @@ import gzip
 
 import numpy as np
 
+from tensorflow.python import tf2
 from tensorflow.python.data.experimental.ops import error_ops
 from tensorflow.python.data.experimental.ops import parsing_ops
 from tensorflow.python.data.ops import dataset_ops
@@ -750,7 +751,7 @@ class CsvDatasetV1(dataset_ops.DatasetV1Adapter):
 def make_batched_features_dataset_v2(file_pattern,
                                      batch_size,
                                      features,
-                                     reader=core_readers.TFRecordDataset,
+                                     reader=None,
                                      label_key=None,
                                      reader_args=None,
                                      num_epochs=None,
@@ -852,6 +853,8 @@ def make_batched_features_dataset_v2(file_pattern,
     TypeError: If `reader` is a `tf.compat.v1.ReaderBase` subclass.
     ValueError: If `label_key` is not one of the `features` keys.
   """
+  if reader is None:
+    reader = core_readers.TFRecordDataset
   if reader_num_threads is None:
     reader_num_threads = 1
@@ -932,7 +935,7 @@ def make_batched_features_dataset_v2(file_pattern,
 def make_batched_features_dataset_v1(file_pattern,  # pylint: disable=missing-docstring
                                      batch_size,
                                      features,
-                                     reader=core_readers.TFRecordDataset,
+                                     reader=None,
                                      label_key=None,
                                      reader_args=None,
                                      num_epochs=None,
@@ -1042,9 +1045,13 @@ class SqlDatasetV1(dataset_ops.DatasetV1Adapter):
     super(SqlDatasetV1, self).__init__(wrapped)
 
-# TODO(b/119044825): Until all `tf.data` unit tests are converted to V2, keep
-# these aliases in place.
-CsvDataset = CsvDatasetV1
-SqlDataset = SqlDatasetV1
-make_batched_features_dataset = make_batched_features_dataset_v1
-make_csv_dataset = make_csv_dataset_v1
+if tf2.enabled():
+  CsvDataset = CsvDatasetV2
+  SqlDataset = SqlDatasetV2
+  make_batched_features_dataset = make_batched_features_dataset_v2
+  make_csv_dataset = make_csv_dataset_v2
+else:
+  CsvDataset = CsvDatasetV1
+  SqlDataset = SqlDatasetV1
+  make_batched_features_dataset = make_batched_features_dataset_v1
+  make_csv_dataset = make_csv_dataset_v1

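The `reader=core_readers.TFRecordDataset` to `reader=None` change above is the usual sentinel-default pattern: a Python default is evaluated once, when the `def` executes, so the old signature froze whichever reader class was bound at import time. Resolving `None` to `core_readers.TFRecordDataset` inside the body defers the lookup to call time, letting the function track the V1/V2 aliasing; it is also why the golden API files below now record the default as 'None'. A minimal sketch with illustrative names (`current_reader` and `make_dataset` are not TensorFlow APIs):

# Minimal sketch of the sentinel-default pattern; the names here are
# illustrative only.
current_reader = "ReaderV1"

def make_dataset(reader=None):
  # Resolve the sentinel at call time, not definition time, so later
  # rebindings of `current_reader` are picked up.
  if reader is None:
    reader = current_reader
  return reader

print(make_dataset())        # "ReaderV1"
current_reader = "ReaderV2"  # analogous to flipping V1 -> V2 behavior
print(make_dataset())        # "ReaderV2" -- the new binding is used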
@@ -17,6 +17,7 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
+from tensorflow.python import tf2
 from tensorflow.python.data.ops import dataset_ops
 from tensorflow.python.data.util import convert
 from tensorflow.python.framework import dtypes
@@ -516,8 +517,11 @@ class FixedLengthRecordDatasetV1(dataset_ops.DatasetV1Adapter):
     self._dataset._filenames = value  # pylint: disable=protected-access
 
-# TODO(b/119044825): Until all `tf.data` unit tests are converted to V2, keep
-# these aliases in place.
-FixedLengthRecordDataset = FixedLengthRecordDatasetV1
-TFRecordDataset = TFRecordDatasetV1
-TextLineDataset = TextLineDatasetV1
+if tf2.enabled():
+  FixedLengthRecordDataset = FixedLengthRecordDatasetV2
+  TFRecordDataset = TFRecordDatasetV2
+  TextLineDataset = TextLineDatasetV2
+else:
+  FixedLengthRecordDataset = FixedLengthRecordDatasetV1
+  TFRecordDataset = TFRecordDatasetV1
+  TextLineDataset = TextLineDatasetV1

@@ -162,7 +162,7 @@ tf_module {
   }
   member_method {
     name: "make_batched_features_dataset"
-    argspec: "args=[\'file_pattern\', \'batch_size\', \'features\', \'reader\', \'label_key\', \'reader_args\', \'num_epochs\', \'shuffle\', \'shuffle_buffer_size\', \'shuffle_seed\', \'prefetch_buffer_size\', \'reader_num_threads\', \'parser_num_threads\', \'sloppy_ordering\', \'drop_final_batch\'], varargs=None, keywords=None, defaults=[\"<class \'tensorflow.python.data.ops.readers.TFRecordDatasetV1\'>\", \'None\', \'None\', \'None\', \'True\', \'10000\', \'None\', \'None\', \'None\', \'None\', \'False\', \'False\'], "
+    argspec: "args=[\'file_pattern\', \'batch_size\', \'features\', \'reader\', \'label_key\', \'reader_args\', \'num_epochs\', \'shuffle\', \'shuffle_buffer_size\', \'shuffle_seed\', \'prefetch_buffer_size\', \'reader_num_threads\', \'parser_num_threads\', \'sloppy_ordering\', \'drop_final_batch\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'True\', \'10000\', \'None\', \'None\', \'None\', \'None\', \'False\', \'False\'], "
   }
   member_method {
     name: "make_csv_dataset"

@@ -134,7 +134,7 @@ tf_module {
   }
   member_method {
     name: "make_batched_features_dataset"
-    argspec: "args=[\'file_pattern\', \'batch_size\', \'features\', \'reader\', \'label_key\', \'reader_args\', \'num_epochs\', \'shuffle\', \'shuffle_buffer_size\', \'shuffle_seed\', \'prefetch_buffer_size\', \'reader_num_threads\', \'parser_num_threads\', \'sloppy_ordering\', \'drop_final_batch\'], varargs=None, keywords=None, defaults=[\"<class \'tensorflow.python.data.ops.readers.TFRecordDatasetV1\'>\", \'None\', \'None\', \'None\', \'True\', \'10000\', \'None\', \'None\', \'None\', \'None\', \'False\', \'False\'], "
+    argspec: "args=[\'file_pattern\', \'batch_size\', \'features\', \'reader\', \'label_key\', \'reader_args\', \'num_epochs\', \'shuffle\', \'shuffle_buffer_size\', \'shuffle_seed\', \'prefetch_buffer_size\', \'reader_num_threads\', \'parser_num_threads\', \'sloppy_ordering\', \'drop_final_batch\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'True\', \'10000\', \'None\', \'None\', \'None\', \'None\', \'False\', \'False\'], "
  }
  member_method {
    name: "make_csv_dataset"