提交 7e47624f 编写于 作者: A Asim Shankar 提交者: TensorFlower Gardener

eager: Initial support for iteration over tf.contrib.data.Dataset objects.

TODO:
- Support function-valued operation attributes in eager
  (Required for MapDataset, FilterDataset etc. which encode the
  per-element computation in a TensorFlow function)
PiperOrigin-RevId: 168418250
上级 b0a397fc
...@@ -2,13 +2,14 @@ licenses(["notice"]) # Apache 2.0 ...@@ -2,13 +2,14 @@ licenses(["notice"]) # Apache 2.0
package(default_visibility = ["//tensorflow:internal"]) package(default_visibility = ["//tensorflow:internal"])
load("//tensorflow:tensorflow.bzl", "cuda_py_test") load("//tensorflow:tensorflow.bzl", "py_test", "cuda_py_test")
py_library( py_library(
name = "tfe", name = "tfe",
srcs = ["tfe.py"], srcs = ["tfe.py"],
srcs_version = "PY2AND3", srcs_version = "PY2AND3",
deps = [ deps = [
":datasets",
":saver", ":saver",
"//tensorflow/python:framework_ops", "//tensorflow/python:framework_ops",
"//tensorflow/python:util", "//tensorflow/python:util",
...@@ -31,6 +32,34 @@ cuda_py_test( ...@@ -31,6 +32,34 @@ cuda_py_test(
], ],
) )
py_library(
name = "datasets",
srcs = ["datasets.py"],
srcs_version = "PY2AND3",
visibility = ["//tensorflow:internal"],
deps = [
"//tensorflow/contrib/data/python/ops:dataset_ops",
"//tensorflow/contrib/data/python/util:nest",
"//tensorflow/python:dataset_ops_gen",
"//tensorflow/python:errors",
"//tensorflow/python:resource_variable_ops",
"//tensorflow/python/eager:context",
],
)
py_test(
name = "datasets_test",
srcs = ["datasets_test.py"],
srcs_version = "PY2AND3",
deps = [
":datasets",
"//tensorflow/contrib/data",
"//tensorflow/python:math_ops",
"//tensorflow/python/eager:test",
"//third_party/py/numpy",
],
)
py_library( py_library(
name = "saver", name = "saver",
srcs = ["saver.py"], srcs = ["saver.py"],
......
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Support for tf.contrib.data when eager execution is enabled."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import threading
from tensorflow.contrib.data.python.util import nest
from tensorflow.python.eager import context
from tensorflow.python.framework import errors
from tensorflow.python.ops import gen_dataset_ops
from tensorflow.python.ops import resource_variable_ops
_uid_counter = 0
_uid_lock = threading.Lock()
def _iterator_shared_name():
with _uid_lock:
global _uid_counter
uid = _uid_counter
_uid_counter += 1
return "eager_iterator_{}".format(uid)
class Iterator(object):
"""An iterator producing tf.Tensor objects from a tf.contrib.data.Dataset."""
def __init__(self, dataset):
"""Creates a new iterator over the given dataset.
For example:
```python
dataset = tf.contrib.data.Dataset.range(4)
for x in Iterator(dataset):
print(x)
```
Args:
dataset: A `tf.contrib.data.Dataset` object.
Raises:
RuntimeError: When invoked without eager execution enabled.
"""
if not context.in_eager_mode():
raise RuntimeError(
"{} objects only make sense when eager execution is enabled".format(
type(self)))
ds_variant = dataset.make_dataset_resource()
self._output_types = dataset.output_types
self._flat_output_types = nest.flatten(dataset.output_types)
self._flat_output_shapes = nest.flatten(dataset.output_shapes)
self._resource = gen_dataset_ops.iterator(
container="",
shared_name=_iterator_shared_name(),
output_types=self._flat_output_types,
output_shapes=self._flat_output_shapes)
gen_dataset_ops.make_iterator(ds_variant, self._resource)
def __del__(self):
if self._resource is not None:
resource_variable_ops.destroy_resource_op(self._resource)
self._resource = None
def __iter__(self):
return self
def __next__(self): # For Python 3 compatibility
return self.next()
def next(self):
"""Return the next tf.Tensor from the dataset."""
try:
ret = gen_dataset_ops.iterator_get_next(
self._resource,
output_types=self._flat_output_types,
output_shapes=self._flat_output_shapes)
return nest.pack_sequence_as(self._output_types, ret)
except errors.OutOfRangeError:
raise StopIteration
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tensorflow.contrib.data import Dataset
from tensorflow.contrib.eager.python import datasets
from tensorflow.python.eager import test
from tensorflow.python.ops import math_ops
class IteratorTest(test.TestCase):
def testBasic(self):
got = []
for t in datasets.Iterator(Dataset.range(4)):
got.append(t.numpy())
self.assertAllEqual([0, 1, 2, 3], got)
def testMultipleIteratorsOnTheSameDataset(self):
ds = Dataset.range(4)
it1 = datasets.Iterator(ds)
it2 = datasets.Iterator(ds)
got = [x.numpy() for x in it1]
self.assertAllEqual([0, 1, 2, 3], got)
got = [x.numpy() for x in it2]
self.assertAllEqual([0, 1, 2, 3], got)
def testNestedOutputs(self):
ds = Dataset.zip((Dataset.range(4), Dataset.zip((Dataset.range(4),
Dataset.range(4)))))
total = 0
# The Iterator will return a nested structure of Tensor objects.
# Some funkiness to compare against simple integers.
for (i, x) in enumerate(datasets.Iterator(ds)):
want = (i, (i, i))
got = (x[0].numpy(), (x[1][0].numpy(), x[1][1].numpy()))
self.assertEqual(got, want)
total += 1
self.assertEqual(4, total)
def testMapAndFilter(self):
# TODO(ashankar): Address this
self.skipTest('Not working yet, requires function attribute support')
def even(x):
return math_ops.equal(math_ops.mod(x, 2), 0)
it = datasets.Iterator(Dataset.range(8).map(math_ops.square).filter(even))
got = [x.numpy() for x in it]
self.assertAllEqual([0, 4, 16, 36], got)
if __name__ == '__main__':
test.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册