Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
qq_38905368
tensorflow
提交
e8f8e219
T
tensorflow
项目概览
qq_38905368
/
tensorflow
与 Fork 源项目一致
从无法访问的项目Fork
通知
5
Star
0
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
T
tensorflow
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
前往新版Gitcode,体验更适合开发者的 AI 搜索 >>
提交
e8f8e219
编写于
12月 13, 2018
作者:
S
Shivani Agrawal
提交者:
TensorFlower Gardener
12月 13, 2018
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
[tf.data] Adds coverage for data experimental tests.
PiperOrigin-RevId: 225403946
上级
976b23c9
变更
31
展开全部
隐藏空白更改
内联
并排
Showing
31 changed file
with
862 addition
and
1047 deletion
+862
-1047
tensorflow/python/data/experimental/benchmarks/BUILD
tensorflow/python/data/experimental/benchmarks/BUILD
+30
-0
tensorflow/python/data/experimental/benchmarks/map_defun_benchmark.py
...ython/data/experimental/benchmarks/map_defun_benchmark.py
+73
-0
tensorflow/python/data/experimental/benchmarks/optimize_benchmark.py
...python/data/experimental/benchmarks/optimize_benchmark.py
+41
-0
tensorflow/python/data/experimental/benchmarks/rejection_resample_benchmark.py
...a/experimental/benchmarks/rejection_resample_benchmark.py
+71
-0
tensorflow/python/data/experimental/benchmarks/unbatch_benchmark.py
.../python/data/experimental/benchmarks/unbatch_benchmark.py
+1
-1
tensorflow/python/data/experimental/kernel_tests/BUILD
tensorflow/python/data/experimental/kernel_tests/BUILD
+4
-24
tensorflow/python/data/experimental/kernel_tests/cardinality_test.py
...python/data/experimental/kernel_tests/cardinality_test.py
+2
-0
tensorflow/python/data/experimental/kernel_tests/copy_to_device_test.py
...hon/data/experimental/kernel_tests/copy_to_device_test.py
+1
-0
tensorflow/python/data/experimental/kernel_tests/counter_test.py
...low/python/data/experimental/kernel_tests/counter_test.py
+16
-20
tensorflow/python/data/experimental/kernel_tests/directed_interleave_dataset_test.py
...rimental/kernel_tests/directed_interleave_dataset_test.py
+18
-28
tensorflow/python/data/experimental/kernel_tests/enumerate_dataset_test.py
.../data/experimental/kernel_tests/enumerate_dataset_test.py
+8
-17
tensorflow/python/data/experimental/kernel_tests/get_single_element_test.py
...data/experimental/kernel_tests/get_single_element_test.py
+13
-22
tensorflow/python/data/experimental/kernel_tests/group_by_reducer_test.py
...n/data/experimental/kernel_tests/group_by_reducer_test.py
+26
-37
tensorflow/python/data/experimental/kernel_tests/group_by_window_test.py
...on/data/experimental/kernel_tests/group_by_window_test.py
+154
-213
tensorflow/python/data/experimental/kernel_tests/ignore_errors_test.py
...thon/data/experimental/kernel_tests/ignore_errors_test.py
+33
-47
tensorflow/python/data/experimental/kernel_tests/map_defun_op_test.py
...ython/data/experimental/kernel_tests/map_defun_op_test.py
+1
-42
tensorflow/python/data/experimental/kernel_tests/matching_files_test.py
...hon/data/experimental/kernel_tests/matching_files_test.py
+43
-70
tensorflow/python/data/experimental/kernel_tests/override_threadpool_test.py
...ata/experimental/kernel_tests/override_threadpool_test.py
+3
-6
tensorflow/python/data/experimental/kernel_tests/prefetch_to_device_test.py
...data/experimental/kernel_tests/prefetch_to_device_test.py
+1
-0
tensorflow/python/data/experimental/kernel_tests/rejection_resample_test.py
...data/experimental/kernel_tests/rejection_resample_test.py
+21
-67
tensorflow/python/data/experimental/kernel_tests/restructured_dataset_test.py
...ta/experimental/kernel_tests/restructured_dataset_test.py
+1
-0
tensorflow/python/data/experimental/kernel_tests/scan_test.py
...orflow/python/data/experimental/kernel_tests/scan_test.py
+40
-62
tensorflow/python/data/experimental/kernel_tests/shuffle_and_repeat_test.py
...data/experimental/kernel_tests/shuffle_and_repeat_test.py
+11
-20
tensorflow/python/data/experimental/kernel_tests/sleep_test.py
...rflow/python/data/experimental/kernel_tests/sleep_test.py
+9
-13
tensorflow/python/data/experimental/kernel_tests/stats_dataset_ops_test.py
.../data/experimental/kernel_tests/stats_dataset_ops_test.py
+157
-199
tensorflow/python/data/experimental/kernel_tests/stats_dataset_test_base.py
...data/experimental/kernel_tests/stats_dataset_test_base.py
+17
-22
tensorflow/python/data/experimental/kernel_tests/tf_record_writer_test.py
...n/data/experimental/kernel_tests/tf_record_writer_test.py
+14
-33
tensorflow/python/data/experimental/kernel_tests/unbatch_test.py
...low/python/data/experimental/kernel_tests/unbatch_test.py
+34
-83
tensorflow/python/data/experimental/kernel_tests/unique_test.py
...flow/python/data/experimental/kernel_tests/unique_test.py
+7
-13
tensorflow/python/data/experimental/kernel_tests/wrap_unwrap_test.py
...python/data/experimental/kernel_tests/wrap_unwrap_test.py
+7
-8
tensorflow/python/data/kernel_tests/test_base.py
tensorflow/python/data/kernel_tests/test_base.py
+5
-0
未找到文件。
tensorflow/python/data/experimental/benchmarks/BUILD
浏览文件 @
e8f8e219
...
...
@@ -58,6 +58,22 @@ py_test(
],
)
py_test
(
name
=
"map_defun_benchmark"
,
srcs
=
[
"map_defun_benchmark.py"
],
srcs_version
=
"PY2AND3"
,
deps
=
[
"//tensorflow/python:array_ops"
,
"//tensorflow/python:client_testlib"
,
"//tensorflow/python:dtypes"
,
"//tensorflow/python:functional_ops"
,
"//tensorflow/python:math_ops"
,
"//tensorflow/python:tensor_spec"
,
"//tensorflow/python/data/experimental/ops:map_defun"
,
"//tensorflow/python/eager:function"
,
],
)
py_test
(
name
=
"map_vectorization_benchmark"
,
srcs
=
[
"map_vectorization_benchmark.py"
],
...
...
@@ -108,6 +124,20 @@ py_test(
],
)
py_test
(
name
=
"rejection_resample_benchmark"
,
srcs
=
[
"rejection_resample_benchmark.py"
],
srcs_version
=
"PY2AND3"
,
tags
=
[
"no_pip"
],
deps
=
[
"//tensorflow/python:client_testlib"
,
"//tensorflow/python/data/experimental/ops:resampling"
,
"//tensorflow/python/data/ops:dataset_ops"
,
"//third_party/py/numpy"
,
"@six_archive//:six"
,
],
)
py_test
(
name
=
"unbatch_benchmark"
,
srcs
=
[
"unbatch_benchmark.py"
],
...
...
tensorflow/python/data/experimental/
kernel_tests/filter_dataset_op_test
.py
→
tensorflow/python/data/experimental/
benchmarks/map_defun_benchmark
.py
浏览文件 @
e8f8e219
...
...
@@ -12,64 +12,61 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Benchmarks
FilterDataset input pipeline o
p."""
"""Benchmarks
for MapDefunO
p."""
from
__future__
import
absolute_import
from
__future__
import
division
from
__future__
import
print_function
import
time
import
numpy
as
np
from
tensorflow.python.
client
import
session
from
tensorflow.python.
data.experimental.ops
import
optimization
from
tensorflow.python.
data.ops
import
dataset
_ops
from
tensorflow.python.
framework
import
ops
from
tensorflow.python.data.experimental.ops
import
map_defun
from
tensorflow.python.eager
import
function
from
tensorflow.python.
framework
import
dtypes
from
tensorflow.python.
framework
import
tensor_spec
from
tensorflow.python.
ops
import
array
_ops
from
tensorflow.python.
ops
import
functional_
ops
from
tensorflow.python.ops
import
math_ops
from
tensorflow.python.platform
import
test
class
FilterBenchmark
(
test
.
Benchmark
):
# TODO(b/119837791): Add eager benchmarks too.
class
MapDefunBenchmark
(
test
.
Benchmark
):
"""Benchmarks for MapDefunOp."""
def
_run
(
self
,
op
,
name
=
None
,
num_iters
=
3000
):
for
_
in
range
(
5
):
self
.
evaluate
(
op
)
start
=
time
.
time
()
for
_
in
range
(
num_iters
):
self
.
evaluate
(
op
)
end
=
time
.
time
()
mean_us
=
(
end
-
start
)
*
1e6
/
num_iters
self
.
report_benchmark
(
name
=
name
,
iters
=
num_iters
,
wall_time
=
mean_us
,
extras
=
{
"examples_per_sec"
:
num_iters
/
(
end
-
start
)})
# This benchmark compares the performance of pipeline with multiple chained
# filter with and without filter fusion.
def
benchmarkFilters
(
self
):
chain_lengths
=
[
0
,
1
,
2
,
5
,
10
,
20
,
50
]
for
chain_length
in
chain_lengths
:
self
.
_benchmarkFilters
(
chain_length
,
False
)
self
.
_benchmarkFilters
(
chain_length
,
True
)
def
benchmarkDefunVsMapFn
(
self
):
"""Benchmarks to compare the performance of MapDefun vs tf.map_fn."""
def
_benchmarkFilters
(
self
,
chain_length
,
optimize_dataset
):
with
ops
.
Graph
().
as_default
():
dataset
=
dataset_ops
.
Dataset
.
from_tensors
(
5
).
repeat
(
None
)
for
_
in
range
(
chain_length
):
dataset
=
dataset
.
filter
(
lambda
x
:
math_ops
.
greater_equal
(
x
-
5
,
0
))
if
optimize_dataset
:
dataset
=
dataset
.
apply
(
optimization
.
optimize
([
"filter_fusion"
]))
@
function
.
defun
(
input_signature
=
[
tensor_spec
.
TensorSpec
([],
dtypes
.
int32
)])
def
defun
(
x
):
return
array_ops
.
identity
(
x
)
iterator
=
dataset_ops
.
make_one_shot_iterator
(
dataset
)
next_element
=
iterator
.
get_next
(
)
def
map_fn
(
x
):
return
array_ops
.
identity
(
x
)
with
session
.
Session
()
as
sess
:
for
_
in
range
(
10
):
self
.
evaluate
(
next_element
.
op
)
deltas
=
[]
for
_
in
range
(
100
):
start
=
time
.
time
()
for
_
in
range
(
100
):
self
.
evaluate
(
next_element
.
op
)
end
=
time
.
time
()
deltas
.
append
(
end
-
start
)
base
=
math_ops
.
range
(
100
)
for
input_size
in
[
10
,
100
,
1000
,
10000
]:
num_iters
=
100000
//
input_size
map_defun_op
=
map_defun
.
map_defun
(
defun
,
[
base
],
[
dtypes
.
int32
],
[()])
map_fn_op
=
functional_ops
.
map_fn
(
map_fn
,
base
)
median_wall_time
=
np
.
median
(
deltas
)
/
100
opt_mark
=
"opt"
if
optimize_dataset
else
"no-opt"
print
(
"Filter dataset {} chain length: {} Median wall time: {}"
.
format
(
opt_mark
,
chain_length
,
median_wall_time
))
self
.
report_benchmark
(
iters
=
1000
,
wall_time
=
median_wall_time
,
name
=
"benchmark_filter_dataset_chain_latency_{}_{}"
.
format
(
opt_mark
,
chain_length
))
self
.
_run
(
map_defun_op
,
"with_defun_size_%d"
%
input_size
,
num_iters
=
num_iters
)
self
.
_run
(
map_fn_op
,
"without_defun_size_%d"
%
input_size
,
num_iters
=
num_iters
)
if
__name__
==
"__main__"
:
...
...
tensorflow/python/data/experimental/benchmarks/optimize_benchmark.py
浏览文件 @
e8f8e219
...
...
@@ -28,6 +28,7 @@ from tensorflow.python.ops import math_ops
from
tensorflow.python.platform
import
test
# TODO(b/119837791): Add eager benchmarks too.
class
OptimizationBenchmark
(
test
.
Benchmark
):
"""Benchmarks for static optimizations."""
...
...
@@ -115,6 +116,46 @@ class OptimizationBenchmark(test.Benchmark):
name
=
"map_and_filter_fusion_{}_chain_length_{}"
.
format
(
opt_mark
,
chain_length
))
# This benchmark compares the performance of pipeline with multiple chained
# filter with and without filter fusion.
def
benchmarkFilterFusion
(
self
):
chain_lengths
=
[
0
,
1
,
2
,
5
,
10
,
20
,
50
]
for
chain_length
in
chain_lengths
:
self
.
_benchmarkFilters
(
chain_length
,
False
)
self
.
_benchmarkFilters
(
chain_length
,
True
)
def
_benchmarkFilterFusion
(
self
,
chain_length
,
optimize_dataset
):
with
ops
.
Graph
().
as_default
():
dataset
=
dataset_ops
.
Dataset
.
from_tensors
(
5
).
repeat
(
None
)
for
_
in
range
(
chain_length
):
dataset
=
dataset
.
filter
(
lambda
x
:
math_ops
.
greater_equal
(
x
-
5
,
0
))
if
optimize_dataset
:
options
=
dataset_ops
.
Options
()
options
.
experimental_filter_fusion
=
True
dataset
=
dataset
.
with_options
(
options
)
iterator
=
dataset_ops
.
make_one_shot_iterator
(
dataset
)
next_element
=
iterator
.
get_next
()
for
_
in
range
(
10
):
self
.
evaluate
(
next_element
.
op
)
deltas
=
[]
for
_
in
range
(
100
):
start
=
time
.
time
()
for
_
in
range
(
100
):
self
.
evaluate
(
next_element
.
op
)
end
=
time
.
time
()
deltas
.
append
(
end
-
start
)
median_wall_time
=
np
.
median
(
deltas
)
/
100
opt_mark
=
"opt"
if
optimize_dataset
else
"no-opt"
print
(
"Filter dataset {} chain length: {} Median wall time: {}"
.
format
(
opt_mark
,
chain_length
,
median_wall_time
))
self
.
report_benchmark
(
iters
=
1000
,
wall_time
=
median_wall_time
,
name
=
"chain_length_{}_{}"
.
format
(
opt_mark
,
chain_length
))
if
__name__
==
"__main__"
:
test
.
main
()
tensorflow/python/data/experimental/benchmarks/rejection_resample_benchmark.py
0 → 100644
浏览文件 @
e8f8e219
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Benchmarks for `tf.data.experimental.rejection_resample()`."""
from
__future__
import
absolute_import
from
__future__
import
division
from
__future__
import
print_function
import
time
import
numpy
as
np
from
six.moves
import
xrange
# pylint: disable=redefined-builtin
from
tensorflow.python.data.experimental.ops
import
resampling
from
tensorflow.python.data.ops
import
dataset_ops
from
tensorflow.python.platform
import
test
def
_time_resampling
(
test_obj
,
data_np
,
target_dist
,
init_dist
,
num_to_sample
):
dataset
=
dataset_ops
.
Dataset
.
from_tensor_slices
(
data_np
).
repeat
()
# Reshape distribution via rejection sampling.
dataset
=
dataset
.
apply
(
resampling
.
rejection_resample
(
class_func
=
lambda
x
:
x
,
target_dist
=
target_dist
,
initial_dist
=
init_dist
,
seed
=
142
))
get_next
=
dataset_ops
.
make_one_shot_iterator
(
dataset
).
get_next
()
with
test_obj
.
test_session
()
as
sess
:
start_time
=
time
.
time
()
for
_
in
xrange
(
num_to_sample
):
sess
.
run
(
get_next
)
end_time
=
time
.
time
()
return
end_time
-
start_time
class
RejectionResampleBenchmark
(
test
.
Benchmark
):
"""Benchmarks for `tf.data.experimental.rejection_resample()`."""
def
benchmarkResamplePerformance
(
self
):
init_dist
=
[
0.25
,
0.25
,
0.25
,
0.25
]
target_dist
=
[
0.0
,
0.0
,
0.0
,
1.0
]
num_classes
=
len
(
init_dist
)
# We don't need many samples to test a dirac-delta target distribution
num_samples
=
1000
data_np
=
np
.
random
.
choice
(
num_classes
,
num_samples
,
p
=
init_dist
)
resample_time
=
_time_resampling
(
self
,
data_np
,
target_dist
,
init_dist
,
num_to_sample
=
1000
)
self
.
report_benchmark
(
iters
=
1000
,
wall_time
=
resample_time
,
name
=
"resample"
)
if
__name__
==
"__main__"
:
test
.
main
()
tensorflow/python/data/experimental/benchmarks/unbatch_benchmark.py
浏览文件 @
e8f8e219
...
...
@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""
Test
s for `tf.data.experimental.unbatch()`."""
"""
Benchmark
s for `tf.data.experimental.unbatch()`."""
from
__future__
import
absolute_import
from
__future__
import
division
from
__future__
import
print_function
...
...
tensorflow/python/data/experimental/kernel_tests/BUILD
浏览文件 @
e8f8e219
load
(
"//tensorflow:tensorflow.bzl"
,
"py_test"
)
load
(
"//tensorflow:tensorflow.bzl"
,
"cuda_py_test"
)
package
(
default_visibility
=
[
"//tensorflow:internal"
])
licenses
([
"notice"
])
# Apache 2.0
exports_files
([
"LICENSE"
])
load
(
"//tensorflow:tensorflow.bzl"
,
"cuda_py_test"
)
load
(
"//tensorflow:tensorflow.bzl"
,
"py_test"
)
py_test
(
name
=
"bucket_by_sequence_length_test"
,
size
=
"medium"
,
...
...
@@ -129,26 +129,6 @@ py_test(
],
)
py_test
(
name
=
"filter_dataset_op_test"
,
size
=
"medium"
,
srcs
=
[
"filter_dataset_op_test.py"
],
srcs_version
=
"PY2AND3"
,
deps
=
[
"//tensorflow/python:array_ops"
,
"//tensorflow/python:client_testlib"
,
"//tensorflow/python:errors"
,
"//tensorflow/python:framework_ops"
,
"//tensorflow/python:io_ops"
,
"//tensorflow/python:math_ops"
,
"//tensorflow/python:util"
,
"//tensorflow/python/data/experimental/ops:optimization"
,
"//tensorflow/python/data/kernel_tests:test_base"
,
"//tensorflow/python/data/ops:dataset_ops"
,
"//third_party/py/numpy"
,
],
)
py_test
(
name
=
"get_single_element_test"
,
size
=
"small"
,
...
...
@@ -622,7 +602,7 @@ py_test(
py_test
(
name
=
"stats_dataset_ops_test"
,
size
=
"
medium
"
,
size
=
"
large
"
,
srcs
=
[
"stats_dataset_ops_test.py"
],
srcs_version
=
"PY2AND3"
,
tags
=
[
...
...
tensorflow/python/data/experimental/kernel_tests/cardinality_test.py
浏览文件 @
e8f8e219
...
...
@@ -22,9 +22,11 @@ from absl.testing import parameterized
from
tensorflow.python.data.experimental.ops
import
cardinality
from
tensorflow.python.data.kernel_tests
import
test_base
from
tensorflow.python.data.ops
import
dataset_ops
from
tensorflow.python.framework
import
test_util
from
tensorflow.python.platform
import
test
@
test_util
.
run_all_in_graph_and_eager_modes
class
NumElementsTest
(
test_base
.
DatasetTestBase
,
parameterized
.
TestCase
):
"""Tests for `tf.data.experimental.cardinality()`."""
...
...
tensorflow/python/data/experimental/kernel_tests/copy_to_device_test.py
浏览文件 @
e8f8e219
...
...
@@ -33,6 +33,7 @@ from tensorflow.python.platform import test
from
tensorflow.python.util
import
compat
as
util_compat
# TODO(b/119837791): add eager coverage when supported.
class
CopyToDeviceTest
(
test_base
.
DatasetTestBase
):
@
test_util
.
run_deprecated_v1
...
...
tensorflow/python/data/experimental/kernel_tests/counter_test.py
浏览文件 @
e8f8e219
...
...
@@ -19,35 +19,31 @@ from __future__ import print_function
from
tensorflow.python.data.experimental.ops
import
counter
from
tensorflow.python.data.kernel_tests
import
test_base
from
tensorflow.python.data.ops
import
dataset_ops
from
tensorflow.python.framework
import
dtypes
from
tensorflow.python.framework
import
test_util
from
tensorflow.python.platform
import
test
@
test_util
.
run_all_in_graph_and_eager_modes
class
CounterTest
(
test_base
.
DatasetTestBase
):
@
test_util
.
run_deprecated_v1
def
testCounter
(
self
):
"""Test dataset construction using `count`."""
iterator
=
dataset_ops
.
make_one_shot_iterator
(
counter
.
Counter
(
start
=
3
,
step
=
4
))
get_next
=
iterator
.
get_next
()
self
.
assertEqual
([],
get_next
.
shape
.
as_list
())
self
.
assertEqual
(
dtypes
.
int64
,
get_next
.
dtype
)
negative_iterator
=
dataset_ops
.
make_one_shot_iterator
(
counter
.
Counter
(
start
=
0
,
step
=-
1
))
negative_get_next
=
negative_iterator
.
get_next
()
with
self
.
cached_session
()
as
sess
:
self
.
assertEqual
(
3
,
self
.
evaluate
(
get_next
))
self
.
assertEqual
(
3
+
4
,
self
.
evaluate
(
get_next
))
self
.
assertEqual
(
3
+
2
*
4
,
self
.
evaluate
(
get_next
))
self
.
assertEqual
(
0
,
self
.
evaluate
(
negative_get_next
))
self
.
assertEqual
(
-
1
,
self
.
evaluate
(
negative_get_next
))
self
.
assertEqual
(
-
2
,
self
.
evaluate
(
negative_get_next
))
dataset
=
counter
.
Counter
(
start
=
3
,
step
=
4
)
self
.
assertEqual
([],
dataset
.
output_shapes
.
as_list
())
self
.
assertEqual
(
dtypes
.
int64
,
dataset
.
output_types
)
get_next
=
self
.
getNext
(
dataset
)
negative_dataset
=
counter
.
Counter
(
start
=
0
,
step
=-
1
)
negative_get_next
=
self
.
getNext
(
negative_dataset
)
self
.
assertEqual
(
3
,
self
.
evaluate
(
get_next
()))
self
.
assertEqual
(
3
+
4
,
self
.
evaluate
(
get_next
()))
self
.
assertEqual
(
3
+
2
*
4
,
self
.
evaluate
(
get_next
()))
self
.
assertEqual
(
0
,
self
.
evaluate
(
negative_get_next
()))
self
.
assertEqual
(
-
1
,
self
.
evaluate
(
negative_get_next
()))
self
.
assertEqual
(
-
2
,
self
.
evaluate
(
negative_get_next
()))
if
__name__
==
"__main__"
:
...
...
tensorflow/python/data/experimental/kernel_tests/directed_interleave_dataset_test.py
浏览文件 @
e8f8e219
...
...
@@ -28,9 +28,9 @@ from tensorflow.python.framework import test_util
from
tensorflow.python.platform
import
test
@
test_util
.
run_all_in_graph_and_eager_modes
class
DirectedInterleaveDatasetTest
(
test_base
.
DatasetTestBase
):
@
test_util
.
run_deprecated_v1
def
testBasic
(
self
):
selector_dataset
=
dataset_ops
.
Dataset
.
range
(
10
).
repeat
(
100
)
input_datasets
=
[
...
...
@@ -38,16 +38,13 @@ class DirectedInterleaveDatasetTest(test_base.DatasetTestBase):
]
dataset
=
interleave_ops
.
_DirectedInterleaveDataset
(
selector_dataset
,
input_datasets
)
iterator
=
dataset_ops
.
make_initializable_iterator
(
dataset
)
next_element
=
iterator
.
get_next
()
next_element
=
self
.
getNext
(
dataset
)
with
self
.
cached_session
()
as
sess
:
self
.
evaluate
(
iterator
.
initializer
)
for
_
in
range
(
100
):
for
i
in
range
(
10
):
self
.
assertEqual
(
i
,
self
.
evaluate
(
next_element
))
with
self
.
assertRaises
(
errors
.
OutOfRangeError
):
self
.
evaluate
(
next_element
)
for
_
in
range
(
100
):
for
i
in
range
(
10
):
self
.
assertEqual
(
i
,
self
.
evaluate
(
next_element
()))
with
self
.
assertRaises
(
errors
.
OutOfRangeError
):
self
.
evaluate
(
next_element
())
def
_normalize
(
self
,
vec
):
return
vec
/
vec
.
sum
()
...
...
@@ -67,19 +64,16 @@ class DirectedInterleaveDatasetTest(test_base.DatasetTestBase):
for
i
in
range
(
num_datasets
)
],
weights
)
dataset
=
dataset
.
take
(
num_samples
)
iterator
=
dataset_ops
.
make_one_shot_iterator
(
dataset
)
next_element
=
iterator
.
get_next
()
with
self
.
cached_session
()
as
sess
:
freqs
=
np
.
zeros
([
num_datasets
])
for
_
in
range
(
num_samples
):
freqs
[
self
.
evaluate
(
next_element
)]
+=
1
with
self
.
assertRaises
(
errors
.
OutOfRangeError
):
self
.
evaluate
(
next_element
)
next_element
=
self
.
getNext
(
dataset
)
freqs
=
np
.
zeros
([
num_datasets
])
for
_
in
range
(
num_samples
):
freqs
[
self
.
evaluate
(
next_element
()
)]
+=
1
with
self
.
assertRaises
(
errors
.
OutOfRangeError
):
self
.
evaluate
(
next_element
()
)
return
freqs
@
test_util
.
run_deprecated_v1
def
testSampleFromDatasets
(
self
):
random_seed
.
set_random_seed
(
1619
)
num_samples
=
5000
...
...
@@ -99,21 +93,17 @@ class DirectedInterleaveDatasetTest(test_base.DatasetTestBase):
freqs
=
self
.
_testSampleFromDatasetsHelper
(
probs_ds
,
classes
,
num_samples
)
self
.
assertLess
(
self
.
_chi2
(
probs
,
freqs
/
num_samples
),
1e-2
)
@
test_util
.
run_deprecated_v1
def
testSelectFromDatasets
(
self
):
words
=
[
b
"foo"
,
b
"bar"
,
b
"baz"
]
datasets
=
[
dataset_ops
.
Dataset
.
from_tensors
(
w
).
repeat
()
for
w
in
words
]
choice_array
=
np
.
random
.
randint
(
3
,
size
=
(
15
,),
dtype
=
np
.
int64
)
choice_dataset
=
dataset_ops
.
Dataset
.
from_tensor_slices
(
choice_array
)
dataset
=
interleave_ops
.
choose_from_datasets
(
datasets
,
choice_dataset
)
iterator
=
dataset_ops
.
make_one_shot_iterator
(
dataset
)
next_element
=
iterator
.
get_next
()
with
self
.
cached_session
()
as
sess
:
for
i
in
choice_array
:
self
.
assertEqual
(
words
[
i
],
self
.
evaluate
(
next_element
))
with
self
.
assertRaises
(
errors
.
OutOfRangeError
):
self
.
evaluate
(
next_element
)
next_element
=
self
.
getNext
(
dataset
)
for
i
in
choice_array
:
self
.
assertEqual
(
words
[
i
],
self
.
evaluate
(
next_element
()))
with
self
.
assertRaises
(
errors
.
OutOfRangeError
):
self
.
evaluate
(
next_element
())
def
testErrors
(
self
):
with
self
.
assertRaisesRegexp
(
ValueError
,
...
...
tensorflow/python/data/experimental/kernel_tests/enumerate_dataset_test.py
浏览文件 @
e8f8e219
...
...
@@ -22,37 +22,28 @@ from tensorflow.python.data.kernel_tests import test_base
from
tensorflow.python.data.ops
import
dataset_ops
from
tensorflow.python.framework
import
constant_op
from
tensorflow.python.framework
import
dtypes
from
tensorflow.python.framework
import
errors
from
tensorflow.python.framework
import
tensor_shape
from
tensorflow.python.framework
import
test_util
from
tensorflow.python.platform
import
test
@
test_util
.
run_all_in_graph_and_eager_modes
class
EnumerateDatasetTest
(
test_base
.
DatasetTestBase
):
@
test_util
.
run_deprecated_v1
def
testEnumerateDataset
(
self
):
components
=
([
"a"
,
"b"
],
[
1
,
2
],
[
37.0
,
38
])
start
=
constant_op
.
constant
(
20
,
dtype
=
dtypes
.
int64
)
iterator
=
dataset_ops
.
make_initializable_iterator
(
dataset_ops
.
Dataset
.
from_tensor_slices
(
components
).
apply
(
enumerate_ops
.
enumerate_dataset
(
start
)))
init_op
=
iterator
.
initializer
get_next
=
iterator
.
get_next
()
dataset
=
dataset_ops
.
Dataset
.
from_tensor_slices
(
components
).
apply
(
enumerate_ops
.
enumerate_dataset
(
start
))
self
.
assertEqual
(
dtypes
.
int64
,
get_next
[
0
].
dtype
)
self
.
assertEqual
((),
get_next
[
0
].
shape
)
self
.
assertEqual
(
dtypes
.
int64
,
dataset
.
output_types
[
0
]
)
self
.
assertEqual
((),
dataset
.
output_shapes
[
0
]
)
self
.
assertEqual
([
tensor_shape
.
TensorShape
([])]
*
3
,
[
t
.
shape
for
t
in
get_next
[
1
]])
[
shape
for
shape
in
dataset
.
output_shapes
[
1
]])
with
self
.
cached_session
()
as
sess
:
self
.
evaluate
(
init_op
)
self
.
assertEqual
((
20
,
(
b
"a"
,
1
,
37.0
)),
self
.
evaluate
(
get_next
))
self
.
assertEqual
((
21
,
(
b
"b"
,
2
,
38.0
)),
self
.
evaluate
(
get_next
))
with
self
.
assertRaises
(
errors
.
OutOfRangeError
):
self
.
evaluate
(
get_next
)
self
.
assertDatasetProduces
(
dataset
,
[(
20
,
(
b
"a"
,
1
,
37.0
)),
(
21
,
(
b
"b"
,
2
,
38.0
))])
if
__name__
==
"__main__"
:
...
...
tensorflow/python/data/experimental/kernel_tests/get_single_element_test.py
浏览文件 @
e8f8e219
...
...
@@ -22,7 +22,6 @@ from absl.testing import parameterized
from
tensorflow.python.data.experimental.ops
import
get_single_element
from
tensorflow.python.data.kernel_tests
import
test_base
from
tensorflow.python.data.ops
import
dataset_ops
from
tensorflow.python.framework
import
dtypes
from
tensorflow.python.framework
import
errors
from
tensorflow.python.framework
import
sparse_tensor
from
tensorflow.python.framework
import
test_util
...
...
@@ -30,6 +29,7 @@ from tensorflow.python.ops import array_ops
from
tensorflow.python.platform
import
test
@
test_util
.
run_all_in_graph_and_eager_modes
class
GetSingleElementTest
(
test_base
.
DatasetTestBase
,
parameterized
.
TestCase
):
@
parameterized
.
named_parameters
(
...
...
@@ -40,34 +40,25 @@ class GetSingleElementTest(test_base.DatasetTestBase, parameterized.TestCase):
(
"MoreThanOne"
,
0
,
2
,
errors
.
InvalidArgumentError
,
"Dataset had more than one element."
),
)
@
test_util
.
run_deprecated_v1
def
testGetSingleElement
(
self
,
skip
,
take
,
error
=
None
,
error_msg
=
None
):
skip_t
=
array_ops
.
placeholder
(
dtypes
.
int64
,
shape
=
[])
take_t
=
array_ops
.
placeholder
(
dtypes
.
int64
,
shape
=
[])
def
make_sparse
(
x
):
x_1d
=
array_ops
.
reshape
(
x
,
[
1
])
x_2d
=
array_ops
.
reshape
(
x
,
[
1
,
1
])
return
sparse_tensor
.
SparseTensor
(
x_2d
,
x_1d
,
x_1d
)
dataset
=
dataset_ops
.
Dataset
.
range
(
100
).
skip
(
skip_t
).
map
(
lambda
x
:
(
x
*
x
,
make_sparse
(
x
))).
take
(
take_t
)
element
=
get_single_element
.
get_single_element
(
dataset
)
with
self
.
cached_session
()
as
sess
:
if
error
is
None
:
dense_val
,
sparse_val
=
sess
.
run
(
element
,
feed_dict
=
{
skip_t
:
skip
,
take_t
:
take
})
self
.
assertEqual
(
skip
*
skip
,
dense_val
)
self
.
assertAllEqual
([[
skip
]],
sparse_val
.
indices
)
self
.
assertAllEqual
([
skip
],
sparse_val
.
values
)
self
.
assertAllEqual
([
skip
],
sparse_val
.
dense_shape
)
else
:
with
self
.
assertRaisesRegexp
(
error
,
error_msg
):
sess
.
run
(
element
,
feed_dict
=
{
skip_t
:
skip
,
take_t
:
take
})
dataset
=
dataset_ops
.
Dataset
.
range
(
100
).
skip
(
skip
).
map
(
lambda
x
:
(
x
*
x
,
make_sparse
(
x
))).
take
(
take
)
if
error
is
None
:
dense_val
,
sparse_val
=
self
.
evaluate
(
get_single_element
.
get_single_element
(
dataset
))
self
.
assertEqual
(
skip
*
skip
,
dense_val
)
self
.
assertAllEqual
([[
skip
]],
sparse_val
.
indices
)
self
.
assertAllEqual
([
skip
],
sparse_val
.
values
)
self
.
assertAllEqual
([
skip
],
sparse_val
.
dense_shape
)
else
:
with
self
.
assertRaisesRegexp
(
error
,
error_msg
):
self
.
evaluate
(
get_single_element
.
get_single_element
(
dataset
))
def
testWindow
(
self
):
"""Test that `get_single_element()` can consume a nested dataset."""
...
...
tensorflow/python/data/experimental/kernel_tests/group_by_reducer_test.py
浏览文件 @
e8f8e219
...
...
@@ -33,19 +33,9 @@ from tensorflow.python.ops import math_ops
from
tensorflow.python.platform
import
test
@
test_util
.
run_all_in_graph_and_eager_modes
class
GroupByReducerTest
(
test_base
.
DatasetTestBase
):
def
checkResults
(
self
,
dataset
,
shapes
,
values
):
self
.
assertEqual
(
shapes
,
dataset
.
output_shapes
)
get_next
=
dataset_ops
.
make_one_shot_iterator
(
dataset
).
get_next
()
with
self
.
cached_session
()
as
sess
:
for
expected
in
values
:
got
=
self
.
evaluate
(
get_next
)
self
.
assertEqual
(
got
,
expected
)
with
self
.
assertRaises
(
errors
.
OutOfRangeError
):
self
.
evaluate
(
get_next
)
@
test_util
.
run_deprecated_v1
def
testSum
(
self
):
reducer
=
grouping
.
Reducer
(
init_func
=
lambda
_
:
np
.
int64
(
0
),
...
...
@@ -54,10 +44,11 @@ class GroupByReducerTest(test_base.DatasetTestBase):
for
i
in
range
(
1
,
11
):
dataset
=
dataset_ops
.
Dataset
.
range
(
2
*
i
).
apply
(
grouping
.
group_by_reducer
(
lambda
x
:
x
%
2
,
reducer
))
self
.
checkResults
(
dataset
,
shapes
=
tensor_shape
.
scalar
(),
values
=
[(
i
-
1
)
*
i
,
i
*
i
])
self
.
assertDatasetProduces
(
dataset
,
expected_shapes
=
tensor_shape
.
scalar
(),
expected_output
=
[(
i
-
1
)
*
i
,
i
*
i
])
@
test_util
.
run_deprecated_v1
def
testAverage
(
self
):
def
reduce_fn
(
x
,
y
):
...
...
@@ -72,10 +63,11 @@ class GroupByReducerTest(test_base.DatasetTestBase):
dataset
=
dataset_ops
.
Dataset
.
range
(
2
*
i
).
apply
(
grouping
.
group_by_reducer
(
lambda
x
:
math_ops
.
cast
(
x
,
dtypes
.
int64
)
%
2
,
reducer
))
self
.
checkResults
(
dataset
,
shapes
=
tensor_shape
.
scalar
(),
values
=
[
i
-
1
,
i
])
self
.
assertDatasetProduces
(
dataset
,
expected_shapes
=
tensor_shape
.
scalar
(),
expected_output
=
[
i
-
1
,
i
])
@
test_util
.
run_deprecated_v1
def
testConcat
(
self
):
components
=
np
.
array
(
list
(
"abcdefghijklmnopqrst"
)).
view
(
np
.
chararray
)
reducer
=
grouping
.
Reducer
(
...
...
@@ -87,12 +79,11 @@ class GroupByReducerTest(test_base.DatasetTestBase):
(
dataset_ops
.
Dataset
.
from_tensor_slices
(
components
),
dataset_ops
.
Dataset
.
range
(
2
*
i
))).
apply
(
grouping
.
group_by_reducer
(
lambda
x
,
y
:
y
%
2
,
reducer
))
self
.
checkResult
s
(
self
.
assertDatasetProduce
s
(
dataset
,
shapes
=
tensor_shape
.
scalar
(),
values
=
[
b
"acegikmoqs"
[:
i
],
b
"bdfhjlnprt"
[:
i
]])
expected_
shapes
=
tensor_shape
.
scalar
(),
expected_output
=
[
b
"acegikmoqs"
[:
i
],
b
"bdfhjlnprt"
[:
i
]])
@
test_util
.
run_deprecated_v1
def
testSparseSum
(
self
):
def
_sparse
(
i
):
return
sparse_tensor
.
SparseTensorValue
(
...
...
@@ -107,10 +98,11 @@ class GroupByReducerTest(test_base.DatasetTestBase):
for
i
in
range
(
1
,
11
):
dataset
=
dataset_ops
.
Dataset
.
range
(
2
*
i
).
map
(
_sparse
).
apply
(
grouping
.
group_by_reducer
(
lambda
x
:
x
.
values
[
0
]
%
2
,
reducer
))
self
.
checkResults
(
dataset
,
shapes
=
tensor_shape
.
scalar
(),
values
=
[(
i
-
1
)
*
i
,
i
*
i
])
self
.
assertDatasetProduces
(
dataset
,
expected_shapes
=
tensor_shape
.
scalar
(),
expected_output
=
[(
i
-
1
)
*
i
,
i
*
i
])
@
test_util
.
run_deprecated_v1
def
testChangingStateShape
(
self
):
def
reduce_fn
(
x
,
_
):
...
...
@@ -130,14 +122,12 @@ class GroupByReducerTest(test_base.DatasetTestBase):
grouping
.
group_by_reducer
(
lambda
x
:
x
,
reducer
))
self
.
assertEqual
([
None
],
dataset
.
output_shapes
[
0
].
as_list
())
self
.
assertIs
(
None
,
dataset
.
output_shapes
[
1
].
ndims
)
iterator
=
dataset_ops
.
make_one_shot_iterator
(
dataset
)
get_next
=
iterator
.
get_next
()
with
self
.
cached_session
()
as
sess
:
x
,
y
=
self
.
evaluate
(
get_next
)
self
.
assertAllEqual
([
0
]
*
(
2
**
i
),
x
)
self
.
assertAllEqual
(
np
.
array
(
1
,
ndmin
=
i
),
y
)
with
self
.
assertRaises
(
errors
.
OutOfRangeError
):
self
.
evaluate
(
get_next
)
get_next
=
self
.
getNext
(
dataset
)
x
,
y
=
self
.
evaluate
(
get_next
())
self
.
assertAllEqual
([
0
]
*
(
2
**
i
),
x
)
self
.
assertAllEqual
(
np
.
array
(
1
,
ndmin
=
i
),
y
)
with
self
.
assertRaises
(
errors
.
OutOfRangeError
):
self
.
evaluate
(
get_next
())
def
testTypeMismatch
(
self
):
reducer
=
grouping
.
Reducer
(
...
...
@@ -194,11 +184,10 @@ class GroupByReducerTest(test_base.DatasetTestBase):
dataset
=
dataset_ops
.
Dataset
.
zip
(
(
dataset_ops
.
Dataset
.
range
(
10
),
dataset_ops
.
Dataset
.
range
(
10
))).
apply
(
grouping
.
group_by_reducer
(
lambda
x
,
y
:
np
.
int64
(
0
),
reducer
))
get_next
=
dataset_ops
.
make_one_shot_iterator
(
dataset
).
get_next
()
with
self
.
cached_session
()
as
sess
:
x
,
y
=
self
.
evaluate
(
get_next
)
self
.
assertAllEqual
(
x
,
np
.
asarray
([
x
for
x
in
range
(
10
)]))
self
.
assertEqual
(
y
,
45
)
get_next
=
self
.
getNext
(
dataset
)
x
,
y
=
self
.
evaluate
(
get_next
())
self
.
assertAllEqual
(
x
,
np
.
asarray
([
x
for
x
in
range
(
10
)]))
self
.
assertEqual
(
y
,
45
)
if
__name__
==
"__main__"
:
...
...
tensorflow/python/data/experimental/kernel_tests/group_by_window_test.py
浏览文件 @
e8f8e219
...
...
@@ -37,6 +37,7 @@ from tensorflow.python.platform import test
# NOTE(mrry): These tests are based on the tests in bucket_ops_test.py.
# Currently, they use a constant batch size, though should be made to use a
# different batch size per key.
@
test_util
.
run_all_in_graph_and_eager_modes
class
GroupByWindowTest
(
test_base
.
DatasetTestBase
):
def
_dynamicPad
(
self
,
bucket
,
window
,
window_size
):
...
...
@@ -50,101 +51,87 @@ class GroupByWindowTest(test_base.DatasetTestBase):
32
,
(
tensor_shape
.
TensorShape
([]),
tensor_shape
.
TensorShape
(
[
None
]),
tensor_shape
.
TensorShape
([
3
])))))
@
test_util
.
run_deprecated_v1
def
testSingleBucket
(
self
):
def
_map_fn
(
v
):
return
(
v
,
array_ops
.
fill
([
v
],
v
),
array_ops
.
fill
([
3
],
string_ops
.
as_string
(
v
)))
input_dataset
=
(
dataset_ops
.
Dataset
.
from_tensor_slices
(
math_ops
.
range
(
32
)).
map
(
_map_fn
)
)
input_dataset
=
dataset_ops
.
Dataset
.
from_tensor_slices
(
math_ops
.
range
(
32
)).
map
(
_map_fn
)
bucketed_dataset
=
input_dataset
.
apply
(
grouping
.
group_by_window
(
lambda
x
,
y
,
z
:
0
,
lambda
k
,
bucket
:
self
.
_dynamicPad
(
k
,
bucket
,
32
),
32
))
get_next
=
self
.
getNext
(
bucketed_dataset
)
iterator
=
dataset_ops
.
make_initializable_iterator
(
bucketed_dataset
)
init_op
=
iterator
.
initializer
get_next
=
iterator
.
get_next
()
which_bucket
,
bucketed_values
=
self
.
evaluate
(
get_next
())
with
self
.
cached_session
()
as
sess
:
self
.
evaluate
(
init_op
)
self
.
assertEqual
(
0
,
which_bucket
)
which_bucket
,
bucketed_values
=
self
.
evaluate
(
get_next
)
expected_scalar_int
=
np
.
arange
(
32
,
dtype
=
np
.
int64
)
expected_unk_int64
=
np
.
zeros
((
32
,
31
)).
astype
(
np
.
int64
)
for
i
in
range
(
32
):
expected_unk_int64
[
i
,
:
i
]
=
i
expected_vec3_str
=
np
.
vstack
(
3
*
[
np
.
arange
(
32
).
astype
(
bytes
)]).
T
self
.
assertEqual
(
0
,
which_bucket
)
self
.
assertAllEqual
(
expected_scalar_int
,
bucketed_values
[
0
])
self
.
assertAllEqual
(
expected_unk_int64
,
bucketed_values
[
1
])
self
.
assertAllEqual
(
expected_vec3_str
,
bucketed_values
[
2
])
expected_scalar_int
=
np
.
arange
(
32
,
dtype
=
np
.
int64
)
expected_unk_int64
=
np
.
zeros
((
32
,
31
)).
astype
(
np
.
int64
)
for
i
in
range
(
32
):
expected_unk_int64
[
i
,
:
i
]
=
i
expected_vec3_str
=
np
.
vstack
(
3
*
[
np
.
arange
(
32
).
astype
(
bytes
)]).
T
self
.
assertAllEqual
(
expected_scalar_int
,
bucketed_values
[
0
])
self
.
assertAllEqual
(
expected_unk_int64
,
bucketed_values
[
1
])
self
.
assertAllEqual
(
expected_vec3_str
,
bucketed_values
[
2
])
@
test_util
.
run_deprecated_v1
def
testEvenOddBuckets
(
self
):
def
_map_fn
(
v
):
return
(
v
,
array_ops
.
fill
([
v
],
v
),
array_ops
.
fill
([
3
],
string_ops
.
as_string
(
v
)))
input_dataset
=
(
dataset_ops
.
Dataset
.
from_tensor_slices
(
math_ops
.
range
(
64
)).
map
(
_map_fn
)
)
input_dataset
=
dataset_ops
.
Dataset
.
from_tensor_slices
(
math_ops
.
range
(
64
)).
map
(
_map_fn
)
bucketed_dataset
=
input_dataset
.
apply
(
grouping
.
group_by_window
(
lambda
x
,
y
,
z
:
math_ops
.
cast
(
x
%
2
,
dtypes
.
int64
),
lambda
k
,
bucket
:
self
.
_dynamicPad
(
k
,
bucket
,
32
),
32
))
iterator
=
dataset_ops
.
make_initializable_iterator
(
bucketed_dataset
)
init_op
=
iterator
.
initializer
get_next
=
iterator
.
get_next
()
with
self
.
cached_session
()
as
sess
:
self
.
evaluate
(
init_op
)
# Get two minibatches (one containing even values, one containing odds)
which_bucket_even
,
bucketed_values_even
=
self
.
evaluate
(
get_next
)
which_bucket_odd
,
bucketed_values_odd
=
self
.
evaluate
(
get_next
)
# Count number of bucket_tensors.
self
.
assertEqual
(
3
,
len
(
bucketed_values_even
))
self
.
assertEqual
(
3
,
len
(
bucketed_values_odd
))
# Ensure bucket 0 was used for all minibatch entries.
self
.
assertAllEqual
(
0
,
which_bucket_even
)
self
.
assertAllEqual
(
1
,
which_bucket_odd
)
# Test the first bucket outputted, the events starting at 0
expected_scalar_int
=
np
.
arange
(
0
,
32
*
2
,
2
,
dtype
=
np
.
int64
)
expected_unk_int64
=
np
.
zeros
((
32
,
31
*
2
)).
astype
(
np
.
int64
)
for
i
in
range
(
0
,
32
):
expected_unk_int64
[
i
,
:
2
*
i
]
=
2
*
i
expected_vec3_str
=
np
.
vstack
(
3
*
[
np
.
arange
(
0
,
32
*
2
,
2
).
astype
(
bytes
)]).
T
self
.
assertAllEqual
(
expected_scalar_int
,
bucketed_values_even
[
0
])
self
.
assertAllEqual
(
expected_unk_int64
,
bucketed_values_even
[
1
])
self
.
assertAllEqual
(
expected_vec3_str
,
bucketed_values_even
[
2
])
# Test the second bucket outputted, the odds starting at 1
expected_scalar_int
=
np
.
arange
(
1
,
32
*
2
+
1
,
2
,
dtype
=
np
.
int64
)
expected_unk_int64
=
np
.
zeros
((
32
,
31
*
2
+
1
)).
astype
(
np
.
int64
)
for
i
in
range
(
0
,
32
):
expected_unk_int64
[
i
,
:
2
*
i
+
1
]
=
2
*
i
+
1
expected_vec3_str
=
np
.
vstack
(
3
*
[
np
.
arange
(
1
,
32
*
2
+
1
,
2
).
astype
(
bytes
)]).
T
self
.
assertAllEqual
(
expected_scalar_int
,
bucketed_values_odd
[
0
])
self
.
assertAllEqual
(
expected_unk_int64
,
bucketed_values_odd
[
1
])
self
.
assertAllEqual
(
expected_vec3_str
,
bucketed_values_odd
[
2
])
@
test_util
.
run_deprecated_v1
get_next
=
self
.
getNext
(
bucketed_dataset
)
# Get two minibatches (one containing even values, one containing odds)
which_bucket_even
,
bucketed_values_even
=
self
.
evaluate
(
get_next
())
which_bucket_odd
,
bucketed_values_odd
=
self
.
evaluate
(
get_next
())
# Count number of bucket_tensors.
self
.
assertEqual
(
3
,
len
(
bucketed_values_even
))
self
.
assertEqual
(
3
,
len
(
bucketed_values_odd
))
# Ensure bucket 0 was used for all minibatch entries.
self
.
assertAllEqual
(
0
,
which_bucket_even
)
self
.
assertAllEqual
(
1
,
which_bucket_odd
)
# Test the first bucket outputted, the events starting at 0
expected_scalar_int
=
np
.
arange
(
0
,
32
*
2
,
2
,
dtype
=
np
.
int64
)
expected_unk_int64
=
np
.
zeros
((
32
,
31
*
2
)).
astype
(
np
.
int64
)
for
i
in
range
(
0
,
32
):
expected_unk_int64
[
i
,
:
2
*
i
]
=
2
*
i
expected_vec3_str
=
np
.
vstack
(
3
*
[
np
.
arange
(
0
,
32
*
2
,
2
).
astype
(
bytes
)]).
T
self
.
assertAllEqual
(
expected_scalar_int
,
bucketed_values_even
[
0
])
self
.
assertAllEqual
(
expected_unk_int64
,
bucketed_values_even
[
1
])
self
.
assertAllEqual
(
expected_vec3_str
,
bucketed_values_even
[
2
])
# Test the second bucket outputted, the odds starting at 1
expected_scalar_int
=
np
.
arange
(
1
,
32
*
2
+
1
,
2
,
dtype
=
np
.
int64
)
expected_unk_int64
=
np
.
zeros
((
32
,
31
*
2
+
1
)).
astype
(
np
.
int64
)
for
i
in
range
(
0
,
32
):
expected_unk_int64
[
i
,
:
2
*
i
+
1
]
=
2
*
i
+
1
expected_vec3_str
=
np
.
vstack
(
3
*
[
np
.
arange
(
1
,
32
*
2
+
1
,
2
).
astype
(
bytes
)]).
T
self
.
assertAllEqual
(
expected_scalar_int
,
bucketed_values_odd
[
0
])
self
.
assertAllEqual
(
expected_unk_int64
,
bucketed_values_odd
[
1
])
self
.
assertAllEqual
(
expected_vec3_str
,
bucketed_values_odd
[
2
])
def
testEvenOddBucketsFilterOutAllOdd
(
self
):
def
_map_fn
(
v
):
...
...
@@ -164,35 +151,28 @@ class GroupByWindowTest(test_base.DatasetTestBase):
"z"
:
tensor_shape
.
TensorShape
([
3
])
})))
input_dataset
=
(
dataset_ops
.
Dataset
.
from_tensor_slices
(
math_ops
.
range
(
128
)).
map
(
_map_fn
)
.
filter
(
lambda
d
:
math_ops
.
equal
(
d
[
"x"
]
%
2
,
0
)))
input_dataset
=
dataset_ops
.
Dataset
.
from_tensor_slices
(
math_ops
.
range
(
128
)).
map
(
_map_fn
).
filter
(
lambda
d
:
math_ops
.
equal
(
d
[
"x"
]
%
2
,
0
))
bucketed_dataset
=
input_dataset
.
apply
(
grouping
.
group_by_window
(
lambda
d
:
math_ops
.
cast
(
d
[
"x"
]
%
2
,
dtypes
.
int64
),
lambda
k
,
bucket
:
_dynamic_pad_fn
(
k
,
bucket
,
32
),
32
))
iterator
=
dataset_ops
.
make_initializable_iterator
(
bucketed_dataset
)
init_op
=
iterator
.
initializer
get_next
=
iterator
.
get_next
()
with
self
.
cached_session
()
as
sess
:
self
.
evaluate
(
init_op
)
get_next
=
self
.
getNext
(
bucketed_dataset
)
# Get two minibatches ([0, 2, ...] and [64, 66, ...])
which_bucket0
,
bucketed_values_even0
=
self
.
evaluate
(
get_next
)
which_bucket1
,
bucketed_values_even1
=
self
.
evaluate
(
get_next
)
# Get two minibatches ([0, 2, ...] and [64, 66, ...])
which_bucket0
,
bucketed_values_even0
=
self
.
evaluate
(
get_next
()
)
which_bucket1
,
bucketed_values_even1
=
self
.
evaluate
(
get_next
()
)
# Ensure that bucket 1 was completely filtered out
self
.
assertAllEqual
(
0
,
which_bucket0
)
self
.
assertAllEqual
(
0
,
which_bucket1
)
self
.
assertAllEqual
(
np
.
arange
(
0
,
64
,
2
,
dtype
=
np
.
int64
),
bucketed_values_even0
[
"x"
])
self
.
assertAllEqual
(
np
.
arange
(
64
,
128
,
2
,
dtype
=
np
.
int64
),
bucketed_values_even1
[
"x"
])
# Ensure that bucket 1 was completely filtered out
self
.
assertAllEqual
(
0
,
which_bucket0
)
self
.
assertAllEqual
(
0
,
which_bucket1
)
self
.
assertAllEqual
(
np
.
arange
(
0
,
64
,
2
,
dtype
=
np
.
int64
),
bucketed_values_even0
[
"x"
])
self
.
assertAllEqual
(
np
.
arange
(
64
,
128
,
2
,
dtype
=
np
.
int64
),
bucketed_values_even1
[
"x"
])
@
test_util
.
run_deprecated_v1
def
testDynamicWindowSize
(
self
):
components
=
np
.
arange
(
100
).
astype
(
np
.
int64
)
...
...
@@ -207,111 +187,81 @@ class GroupByWindowTest(test_base.DatasetTestBase):
dataset
=
dataset_ops
.
Dataset
.
from_tensor_slices
(
components
).
apply
(
grouping
.
group_by_window
(
lambda
x
:
x
%
2
,
lambda
_
,
xs
:
xs
.
batch
(
20
),
None
,
window_size_func
))
iterator
=
dataset_ops
.
make_initializable_iterator
(
dataset
)
init_op
=
iterator
.
initializer
get_next
=
iterator
.
get_next
()
with
self
.
cached_session
()
as
sess
:
self
.
evaluate
(
init_op
)
with
self
.
assertRaises
(
errors
.
OutOfRangeError
):
batches
=
0
while
True
:
result
=
self
.
evaluate
(
get_next
)
is_even
=
all
(
x
%
2
==
0
for
x
in
result
)
is_odd
=
all
(
x
%
2
==
1
for
x
in
result
)
self
.
assertTrue
(
is_even
or
is_odd
)
expected_batch_size
=
5
if
is_even
else
10
self
.
assertEqual
(
expected_batch_size
,
result
.
shape
[
0
])
batches
+=
1
self
.
assertEqual
(
batches
,
15
)
@
test_util
.
run_deprecated_v1
get_next
=
self
.
getNext
(
dataset
)
with
self
.
assertRaises
(
errors
.
OutOfRangeError
):
batches
=
0
while
True
:
result
=
self
.
evaluate
(
get_next
())
is_even
=
all
(
x
%
2
==
0
for
x
in
result
)
is_odd
=
all
(
x
%
2
==
1
for
x
in
result
)
self
.
assertTrue
(
is_even
or
is_odd
)
expected_batch_size
=
5
if
is_even
else
10
self
.
assertEqual
(
expected_batch_size
,
result
.
shape
[
0
])
batches
+=
1
self
.
assertEqual
(
batches
,
15
)
def
testSimple
(
self
):
components
=
np
.
random
.
randint
(
100
,
size
=
(
200
,)).
astype
(
np
.
int64
)
iterator
=
dataset_ops
.
make_initializable_iterator
(
dataset_ops
.
Dataset
.
from_tensor_slices
(
components
).
map
(
lambda
x
:
x
*
x
)
.
apply
(
dataset
=
dataset_ops
.
Dataset
.
from_tensor_slices
(
components
).
map
(
lambda
x
:
x
*
x
).
apply
(
grouping
.
group_by_window
(
lambda
x
:
x
%
2
,
lambda
_
,
xs
:
xs
.
batch
(
4
),
4
)))
init_op
=
iterator
.
initializer
get_next
=
iterator
.
get_next
()
with
self
.
cached_session
()
as
sess
:
self
.
evaluate
(
init_op
)
counts
=
[]
with
self
.
assertRaises
(
errors
.
OutOfRangeError
):
while
True
:
result
=
self
.
evaluate
(
get_next
)
self
.
assertTrue
(
all
(
x
%
2
==
0
for
x
in
result
)
or
all
(
x
%
2
==
1
)
for
x
in
result
)
counts
.
append
(
result
.
shape
[
0
])
self
.
assertEqual
(
len
(
components
),
sum
(
counts
))
num_full_batches
=
len
([
c
for
c
in
counts
if
c
==
4
])
self
.
assertGreaterEqual
(
num_full_batches
,
24
)
self
.
assertTrue
(
all
(
c
==
4
for
c
in
counts
[:
num_full_batches
]))
@
test_util
.
run_deprecated_v1
4
))
get_next
=
self
.
getNext
(
dataset
)
counts
=
[]
with
self
.
assertRaises
(
errors
.
OutOfRangeError
):
while
True
:
result
=
self
.
evaluate
(
get_next
())
self
.
assertTrue
(
all
(
x
%
2
==
0
for
x
in
result
)
or
all
(
x
%
2
==
1
)
for
x
in
result
)
counts
.
append
(
result
.
shape
[
0
])
self
.
assertEqual
(
len
(
components
),
sum
(
counts
))
num_full_batches
=
len
([
c
for
c
in
counts
if
c
==
4
])
self
.
assertGreaterEqual
(
num_full_batches
,
24
)
self
.
assertTrue
(
all
(
c
==
4
for
c
in
counts
[:
num_full_batches
]))
def
testImmediateOutput
(
self
):
components
=
np
.
array
(
[
0
,
0
,
0
,
0
,
1
,
1
,
1
,
1
,
2
,
2
,
0
,
0
,
2
,
2
,
0
,
0
],
dtype
=
np
.
int64
)
iterator
=
dataset_ops
.
make_initializable_iterator
(
dataset_ops
.
Dataset
.
from_tensor_slices
(
components
).
repeat
(
-
1
).
apply
(
dataset
=
dataset_ops
.
Dataset
.
from_tensor_slices
(
components
).
repeat
(
-
1
).
apply
(
grouping
.
group_by_window
(
lambda
x
:
x
%
3
,
lambda
_
,
xs
:
xs
.
batch
(
4
),
4
)))
init_op
=
iterator
.
initializer
get_next
=
iterator
.
get_next
()
with
self
.
cached_session
()
as
sess
:
self
.
evaluate
(
init_op
)
# The input is infinite, so this test demonstrates that:
# 1. We produce output without having to consume the entire input,
# 2. Different buckets can produce output at different rates, and
# 3. For deterministic input, the output is deterministic.
for
_
in
range
(
3
):
self
.
assertAllEqual
([
0
,
0
,
0
,
0
],
self
.
evaluate
(
get_next
))
self
.
assertAllEqual
([
1
,
1
,
1
,
1
],
self
.
evaluate
(
get_next
))
self
.
assertAllEqual
([
2
,
2
,
2
,
2
],
self
.
evaluate
(
get_next
))
self
.
assertAllEqual
([
0
,
0
,
0
,
0
],
self
.
evaluate
(
get_next
))
@
test_util
.
run_deprecated_v1
4
))
get_next
=
self
.
getNext
(
dataset
)
# The input is infinite, so this test demonstrates that:
# 1. We produce output without having to consume the entire input,
# 2. Different buckets can produce output at different rates, and
# 3. For deterministic input, the output is deterministic.
for
_
in
range
(
3
):
self
.
assertAllEqual
([
0
,
0
,
0
,
0
],
self
.
evaluate
(
get_next
()))
self
.
assertAllEqual
([
1
,
1
,
1
,
1
],
self
.
evaluate
(
get_next
()))
self
.
assertAllEqual
([
2
,
2
,
2
,
2
],
self
.
evaluate
(
get_next
()))
self
.
assertAllEqual
([
0
,
0
,
0
,
0
],
self
.
evaluate
(
get_next
()))
def
testSmallGroups
(
self
):
components
=
np
.
array
([
0
,
0
,
0
,
0
,
1
,
1
,
1
,
1
,
1
,
0
,
0
,
0
],
dtype
=
np
.
int64
)
iterator
=
dataset_ops
.
make_initializable_iterator
(
dataset_ops
.
Dataset
.
from_tensor_slices
(
components
).
apply
(
grouping
.
group_by_window
(
lambda
x
:
x
%
2
,
lambda
_
,
xs
:
xs
.
batch
(
4
),
4
)))
init_op
=
iterator
.
initializer
get_next
=
iterator
.
get_next
()
with
self
.
cached_session
()
as
sess
:
self
.
evaluate
(
init_op
)
self
.
assertAllEqual
([
0
,
0
,
0
,
0
],
self
.
evaluate
(
get_next
))
self
.
assertAllEqual
([
1
,
1
,
1
,
1
],
self
.
evaluate
(
get_next
))
# The small outputs at the end are deterministically produced in key
# order.
self
.
assertAllEqual
([
0
,
0
,
0
],
self
.
evaluate
(
get_next
))
self
.
assertAllEqual
([
1
],
self
.
evaluate
(
get_next
))
@
test_util
.
run_deprecated_v1
dataset
=
dataset_ops
.
Dataset
.
from_tensor_slices
(
components
).
apply
(
grouping
.
group_by_window
(
lambda
x
:
x
%
2
,
lambda
_
,
xs
:
xs
.
batch
(
4
),
4
))
get_next
=
self
.
getNext
(
dataset
)
self
.
assertAllEqual
([
0
,
0
,
0
,
0
],
self
.
evaluate
(
get_next
()))
self
.
assertAllEqual
([
1
,
1
,
1
,
1
],
self
.
evaluate
(
get_next
()))
# The small outputs at the end are deterministically produced in key
# order.
self
.
assertAllEqual
([
0
,
0
,
0
],
self
.
evaluate
(
get_next
()))
self
.
assertAllEqual
([
1
],
self
.
evaluate
(
get_next
()))
def
testEmpty
(
self
):
iterator
=
dataset_ops
.
make_initializable_iterator
(
dataset_ops
.
Dataset
.
range
(
4
).
apply
(
grouping
.
group_by_window
(
lambda
_
:
0
,
lambda
_
,
xs
:
xs
,
0
)))
init_op
=
iterator
.
initializer
get_next
=
iterator
.
get_next
()
with
self
.
cached_session
()
as
sess
:
self
.
evaluate
(
init_op
)
with
self
.
assertRaisesRegexp
(
errors
.
InvalidArgumentError
,
"Window size must be greater than zero, but got 0."
):
print
(
self
.
evaluate
(
get_next
))
@
test_util
.
run_deprecated_v1
dataset
=
dataset_ops
.
Dataset
.
range
(
4
).
apply
(
grouping
.
group_by_window
(
lambda
_
:
0
,
lambda
_
,
xs
:
xs
,
0
))
get_next
=
self
.
getNext
(
dataset
)
with
self
.
assertRaisesRegexp
(
errors
.
InvalidArgumentError
,
"Window size must be greater than zero, but got 0."
):
print
(
self
.
evaluate
(
get_next
()))
def
testReduceFuncError
(
self
):
components
=
np
.
random
.
randint
(
100
,
size
=
(
200
,)).
astype
(
np
.
int64
)
...
...
@@ -323,19 +273,13 @@ class GroupByWindowTest(test_base.DatasetTestBase):
padded_shapes
=
(
tensor_shape
.
TensorShape
([]),
constant_op
.
constant
([
5
],
dtype
=
dtypes
.
int64
)
*
-
1
))
iterator
=
dataset_ops
.
make_initializable_iterator
(
dataset_ops
.
Dataset
.
from_tensor_slices
(
components
)
.
map
(
lambda
x
:
(
x
,
ops
.
convert_to_tensor
([
x
*
x
]))).
apply
(
grouping
.
group_by_window
(
lambda
x
,
_
:
x
%
2
,
reduce_func
,
32
))
)
init_op
=
iterator
.
initializer
get_next
=
iterator
.
get_next
(
)
dataset
=
dataset_ops
.
Dataset
.
from_tensor_slices
(
components
).
map
(
lambda
x
:
(
x
,
ops
.
convert_to_tensor
([
x
*
x
]))).
apply
(
grouping
.
group_by_window
(
lambda
x
,
_
:
x
%
2
,
reduce_func
,
32
))
get_next
=
self
.
getNext
(
dataset
)
with
self
.
assertRaises
(
errors
.
InvalidArgumentError
):
self
.
evaluate
(
get_next
()
)
with
self
.
cached_session
()
as
sess
:
self
.
evaluate
(
init_op
)
with
self
.
assertRaises
(
errors
.
InvalidArgumentError
):
self
.
evaluate
(
get_next
)
@
test_util
.
run_deprecated_v1
def
testConsumeWindowDatasetMoreThanOnce
(
self
):
components
=
np
.
random
.
randint
(
50
,
size
=
(
200
,)).
astype
(
np
.
int64
)
...
...
@@ -349,26 +293,23 @@ class GroupByWindowTest(test_base.DatasetTestBase):
4
,
padded_shapes
=
ops
.
convert_to_tensor
([(
key
+
1
)
*
10
])),
))
iterator
=
dataset_ops
.
make_initializable_iterator
(
dataset_ops
.
Dataset
.
from_tensor_slices
(
components
)
.
map
(
lambda
x
:
array_ops
.
fill
([
math_ops
.
cast
(
x
,
dtypes
.
int32
)],
x
))
.
apply
(
grouping
.
group_by_window
(
dataset
=
dataset_ops
.
Dataset
.
from_tensor_slices
(
components
).
map
(
lambda
x
:
array_ops
.
fill
([
math_ops
.
cast
(
x
,
dtypes
.
int32
)],
x
)).
apply
(
grouping
.
group_by_window
(
lambda
x
:
math_ops
.
cast
(
array_ops
.
shape
(
x
)[
0
]
//
10
,
dtypes
.
int64
),
reduce_func
,
4
)))
init_op
=
iterator
.
initializer
get_next
=
iterator
.
get_next
()
with
self
.
cached_session
()
as
sess
:
self
.
evaluate
(
init_op
)
counts
=
[]
with
self
.
assertRaises
(
errors
.
OutOfRangeError
):
while
True
:
tight_result
,
multiple_of_10_result
=
self
.
evaluate
(
get_next
)
self
.
assertEqual
(
0
,
multiple_of_10_result
.
shape
[
1
]
%
10
)
self
.
assertAllEqual
(
tight_result
,
multiple_of_10_result
[:,
:
tight_result
.
shape
[
1
]])
counts
.
append
(
tight_result
.
shape
[
0
])
self
.
assertEqual
(
len
(
components
),
sum
(
counts
))
reduce_func
,
4
))
get_next
=
self
.
getNext
(
dataset
)
counts
=
[]
with
self
.
assertRaises
(
errors
.
OutOfRangeError
):
while
True
:
tight_result
,
multiple_of_10_result
=
self
.
evaluate
(
get_next
())
self
.
assertEqual
(
0
,
multiple_of_10_result
.
shape
[
1
]
%
10
)
self
.
assertAllEqual
(
tight_result
,
multiple_of_10_result
[:,
:
tight_result
.
shape
[
1
]])
counts
.
append
(
tight_result
.
shape
[
0
])
self
.
assertEqual
(
len
(
components
),
sum
(
counts
))
if
__name__
==
"__main__"
:
...
...
tensorflow/python/data/experimental/kernel_tests/ignore_errors_test.py
浏览文件 @
e8f8e219
...
...
@@ -34,9 +34,9 @@ from tensorflow.python.util import compat
_NUMPY_RANDOM_SEED
=
42
@
test_util
.
run_all_in_graph_and_eager_modes
class
IgnoreErrorsTest
(
test_base
.
DatasetTestBase
):
@
test_util
.
run_deprecated_v1
def
testMapIgnoreError
(
self
):
components
=
np
.
array
([
1.
,
2.
,
3.
,
np
.
nan
,
5.
]).
astype
(
np
.
float32
)
...
...
@@ -44,18 +44,13 @@ class IgnoreErrorsTest(test_base.DatasetTestBase):
dataset_ops
.
Dataset
.
from_tensor_slices
(
components
)
.
map
(
lambda
x
:
array_ops
.
check_numerics
(
x
,
"message"
)).
apply
(
error_ops
.
ignore_errors
()))
iterator
=
dataset_ops
.
make_initializable_iterator
(
dataset
)
init_op
=
iterator
.
initializer
get_next
=
iterator
.
get_next
()
with
self
.
cached_session
()
as
sess
:
self
.
evaluate
(
init_op
)
for
x
in
[
1.
,
2.
,
3.
,
5.
]:
self
.
assertEqual
(
x
,
self
.
evaluate
(
get_next
))
with
self
.
assertRaises
(
errors
.
OutOfRangeError
):
self
.
evaluate
(
get_next
)
@
test_util
.
run_deprecated_v1
get_next
=
self
.
getNext
(
dataset
)
for
x
in
[
1.
,
2.
,
3.
,
5.
]:
self
.
assertEqual
(
x
,
self
.
evaluate
(
get_next
()))
with
self
.
assertRaises
(
errors
.
OutOfRangeError
):
self
.
evaluate
(
get_next
())
def
testParallelMapIgnoreError
(
self
):
components
=
np
.
array
([
1.
,
2.
,
3.
,
np
.
nan
,
5.
]).
astype
(
np
.
float32
)
...
...
@@ -63,18 +58,13 @@ class IgnoreErrorsTest(test_base.DatasetTestBase):
dataset_ops
.
Dataset
.
from_tensor_slices
(
components
).
map
(
lambda
x
:
array_ops
.
check_numerics
(
x
,
"message"
),
num_parallel_calls
=
2
).
prefetch
(
2
).
apply
(
error_ops
.
ignore_errors
()))
iterator
=
dataset_ops
.
make_initializable_iterator
(
dataset
)
init_op
=
iterator
.
initializer
get_next
=
iterator
.
get_next
()
with
self
.
cached_session
()
as
sess
:
self
.
evaluate
(
init_op
)
for
x
in
[
1.
,
2.
,
3.
,
5.
]:
self
.
assertEqual
(
x
,
self
.
evaluate
(
get_next
))
with
self
.
assertRaises
(
errors
.
OutOfRangeError
):
self
.
evaluate
(
get_next
)
@
test_util
.
run_deprecated_v1
get_next
=
self
.
getNext
(
dataset
)
for
x
in
[
1.
,
2.
,
3.
,
5.
]:
self
.
assertEqual
(
x
,
self
.
evaluate
(
get_next
()))
with
self
.
assertRaises
(
errors
.
OutOfRangeError
):
self
.
evaluate
(
get_next
())
def
testReadFileIgnoreError
(
self
):
def
write_string_to_file
(
value
,
filename
):
...
...
@@ -91,28 +81,24 @@ class IgnoreErrorsTest(test_base.DatasetTestBase):
dataset_ops
.
Dataset
.
from_tensor_slices
(
filenames
).
map
(
io_ops
.
read_file
,
num_parallel_calls
=
2
).
prefetch
(
2
).
apply
(
error_ops
.
ignore_errors
()))
iterator
=
dataset_ops
.
make_initializable_iterator
(
dataset
)
init_op
=
iterator
.
initializer
get_next
=
iterator
.
get_next
()
with
self
.
cached_session
()
as
sess
:
# All of the files are present.
self
.
evaluate
(
init_op
)
for
filename
in
filenames
:
self
.
assertEqual
(
compat
.
as_bytes
(
filename
),
self
.
evaluate
(
get_next
))
with
self
.
assertRaises
(
errors
.
OutOfRangeError
):
self
.
evaluate
(
get_next
)
# Delete one of the files.
os
.
remove
(
filenames
[
0
])
# Attempting to read filenames[0] will fail, but ignore_errors()
# will catch the error.
self
.
evaluate
(
init_op
)
for
filename
in
filenames
[
1
:]:
self
.
assertEqual
(
compat
.
as_bytes
(
filename
),
self
.
evaluate
(
get_next
))
with
self
.
assertRaises
(
errors
.
OutOfRangeError
):
self
.
evaluate
(
get_next
)
get_next
=
self
.
getNext
(
dataset
)
# All of the files are present.
for
filename
in
filenames
:
self
.
assertEqual
(
compat
.
as_bytes
(
filename
),
self
.
evaluate
(
get_next
()))
with
self
.
assertRaises
(
errors
.
OutOfRangeError
):
self
.
evaluate
(
get_next
())
# Delete one of the files.
os
.
remove
(
filenames
[
0
])
# Attempting to read filenames[0] will fail, but ignore_errors()
# will catch the error.
get_next
=
self
.
getNext
(
dataset
)
for
filename
in
filenames
[
1
:]:
self
.
assertEqual
(
compat
.
as_bytes
(
filename
),
self
.
evaluate
(
get_next
()))
with
self
.
assertRaises
(
errors
.
OutOfRangeError
):
self
.
evaluate
(
get_next
())
if
__name__
==
"__main__"
:
...
...
tensorflow/python/data/experimental/kernel_tests/map_defun_op_test.py
浏览文件 @
e8f8e219
...
...
@@ -31,11 +31,11 @@ from tensorflow.python.framework import tensor_spec
from
tensorflow.python.ops
import
array_ops
from
tensorflow.python.ops
import
check_ops
from
tensorflow.python.ops
import
data_flow_ops
from
tensorflow.python.ops
import
functional_ops
from
tensorflow.python.ops
import
math_ops
from
tensorflow.python.platform
import
test
# TODO(b/119837791): add eager coverage.
class
MapDefunTest
(
test_base
.
DatasetTestBase
):
def
testMapDefunSimple
(
self
):
...
...
@@ -254,46 +254,5 @@ class MapDefunTest(test_base.DatasetTestBase):
self
.
assertAllEqual
(
self
.
evaluate
(
expected
),
self
.
evaluate
(
map_defun_op
))
class
MapDefunBenchmark
(
test
.
Benchmark
):
def
_run
(
self
,
op
,
name
=
None
,
num_iters
=
3000
):
with
session
.
Session
()
as
sess
:
# Warm up the session
for
_
in
range
(
5
):
self
.
evaluate
(
op
)
start
=
time
.
time
()
for
_
in
range
(
num_iters
):
self
.
evaluate
(
op
)
end
=
time
.
time
()
mean_us
=
(
end
-
start
)
*
1e6
/
num_iters
self
.
report_benchmark
(
name
=
name
,
iters
=
num_iters
,
wall_time
=
mean_us
,
extras
=
{
"examples_per_sec"
:
num_iters
/
(
end
-
start
)})
def
benchmarkDefunVsMapFn
(
self
):
"""Benchmarks to compare the performance of MapDefun vs tf.map_fn."""
@
function
.
defun
(
input_signature
=
[
tensor_spec
.
TensorSpec
([],
dtypes
.
int32
)])
def
defun
(
x
):
return
array_ops
.
identity
(
x
)
def
map_fn
(
x
):
return
array_ops
.
identity
(
x
)
base
=
math_ops
.
range
(
100
)
for
input_size
in
[
10
,
100
,
1000
,
10000
]:
num_iters
=
100000
//
input_size
map_defun_op
=
map_defun
.
map_defun
(
defun
,
[
base
],
[
dtypes
.
int32
],
[()])
map_fn_op
=
functional_ops
.
map_fn
(
map_fn
,
base
)
self
.
_run
(
map_defun_op
,
"benchmarkMapDefun_size_%d"
%
input_size
,
num_iters
=
num_iters
)
self
.
_run
(
map_fn_op
,
"benchmarkMapFn_size_%d"
%
input_size
,
num_iters
=
num_iters
)
if
__name__
==
"__main__"
:
test
.
main
()
tensorflow/python/data/experimental/kernel_tests/matching_files_test.py
浏览文件 @
e8f8e219
...
...
@@ -23,14 +23,14 @@ import tempfile
from
tensorflow.python.data.experimental.ops
import
matching_files
from
tensorflow.python.data.kernel_tests
import
test_base
from
tensorflow.python.data.ops
import
dataset_ops
from
tensorflow.python.framework
import
errors
from
tensorflow.python.framework
import
test_util
from
tensorflow.python.platform
import
test
from
tensorflow.python.util
import
compat
class
MatchingFilesTest
(
test_base
.
DatasetTestBase
):
@
test_util
.
run_all_in_graph_and_eager_modes
class
MatchingFilesDatasetTest
(
test_base
.
DatasetTestBase
):
def
setUp
(
self
):
self
.
tmp_dir
=
tempfile
.
mkdtemp
()
...
...
@@ -42,30 +42,23 @@ class MatchingFilesTest(test_base.DatasetTestBase):
for
filename
in
filenames
:
open
(
os
.
path
.
join
(
self
.
tmp_dir
,
filename
),
'a'
).
close
()
@
test_util
.
run_deprecated_v1
def
testNonExistingDirectory
(
self
):
"""Test the MatchingFiles dataset with a non-existing directory."""
self
.
tmp_dir
=
os
.
path
.
join
(
self
.
tmp_dir
,
'nonexistingdir'
)
dataset
=
matching_files
.
MatchingFilesDataset
(
os
.
path
.
join
(
self
.
tmp_dir
,
'*'
))
with
self
.
cached_session
()
as
sess
:
next_element
=
dataset_ops
.
make_one_shot_iterator
(
dataset
).
get_next
()
with
self
.
assertRaises
(
errors
.
NotFoundError
):
sess
.
run
(
next_element
)
self
.
assertDatasetProduces
(
dataset
,
expected_error
=
(
errors
.
NotFoundError
,
''
))
@
test_util
.
run_deprecated_v1
def
testEmptyDirectory
(
self
):
"""Test the MatchingFiles dataset with an empty directory."""
dataset
=
matching_files
.
MatchingFilesDataset
(
os
.
path
.
join
(
self
.
tmp_dir
,
'*'
))
with
self
.
cached_session
()
as
sess
:
next_element
=
dataset_ops
.
make_one_shot_iterator
(
dataset
).
get_next
()
with
self
.
assertRaises
(
errors
.
NotFoundError
):
sess
.
run
(
next_element
)
self
.
assertDatasetProduces
(
dataset
,
expected_error
=
(
errors
.
NotFoundError
,
''
))
@
test_util
.
run_deprecated_v1
def
testSimpleDirectory
(
self
):
"""Test the MatchingFiles dataset with a simple directory."""
...
...
@@ -74,21 +67,14 @@ class MatchingFilesTest(test_base.DatasetTestBase):
dataset
=
matching_files
.
MatchingFilesDataset
(
os
.
path
.
join
(
self
.
tmp_dir
,
'*'
))
with
self
.
cached_session
()
as
sess
:
next_element
=
dataset_ops
.
make_one_shot_iterator
(
dataset
).
get_next
()
self
.
assertDatasetProduces
(
dataset
,
expected_output
=
[
compat
.
as_bytes
(
os
.
path
.
join
(
self
.
tmp_dir
,
filename
))
for
filename
in
filenames
],
assert_items_equal
=
True
)
expected_filenames
=
[]
actual_filenames
=
[]
for
filename
in
filenames
:
expected_filenames
.
append
(
compat
.
as_bytes
(
os
.
path
.
join
(
self
.
tmp_dir
,
filename
)))
actual_filenames
.
append
(
compat
.
as_bytes
(
sess
.
run
(
next_element
)))
self
.
assertItemsEqual
(
expected_filenames
,
actual_filenames
)
with
self
.
assertRaises
(
errors
.
OutOfRangeError
):
sess
.
run
(
next_element
)
@
test_util
.
run_deprecated_v1
def
testFileSuffixes
(
self
):
"""Test the MatchingFiles dataset using the suffixes of filename."""
...
...
@@ -97,20 +83,14 @@ class MatchingFilesTest(test_base.DatasetTestBase):
dataset
=
matching_files
.
MatchingFilesDataset
(
os
.
path
.
join
(
self
.
tmp_dir
,
'*.py'
))
with
self
.
cached_session
()
as
sess
:
next_element
=
dataset_ops
.
make_one_shot_iterator
(
dataset
).
get_next
()
expected_filenames
=
[]
actual_filenames
=
[]
for
filename
in
filenames
[
1
:
-
1
]:
expected_filenames
.
append
(
compat
.
as_bytes
(
os
.
path
.
join
(
self
.
tmp_dir
,
filename
)))
actual_filenames
.
append
(
compat
.
as_bytes
(
sess
.
run
(
next_element
)))
self
.
assertItemsEqual
(
expected_filenames
,
actual_filenames
)
with
self
.
assertRaises
(
errors
.
OutOfRangeError
):
sess
.
run
(
next_element
)
@
test_util
.
run_deprecated_v1
self
.
assertDatasetProduces
(
dataset
,
expected_output
=
[
compat
.
as_bytes
(
os
.
path
.
join
(
self
.
tmp_dir
,
filename
))
for
filename
in
filenames
[
1
:
-
1
]
],
assert_items_equal
=
True
)
def
testFileMiddles
(
self
):
"""Test the MatchingFiles dataset using the middles of filename."""
...
...
@@ -119,20 +99,14 @@ class MatchingFilesTest(test_base.DatasetTestBase):
dataset
=
matching_files
.
MatchingFilesDataset
(
os
.
path
.
join
(
self
.
tmp_dir
,
'b*.py*'
))
with
self
.
cached_session
()
as
sess
:
next_element
=
dataset_ops
.
make_one_shot_iterator
(
dataset
).
get_next
()
expected_filenames
=
[]
actual_filenames
=
[]
for
filename
in
filenames
[
1
:
3
]:
expected_filenames
.
append
(
compat
.
as_bytes
(
os
.
path
.
join
(
self
.
tmp_dir
,
filename
)))
actual_filenames
.
append
(
compat
.
as_bytes
(
sess
.
run
(
next_element
)))
self
.
assertItemsEqual
(
expected_filenames
,
actual_filenames
)
with
self
.
assertRaises
(
errors
.
OutOfRangeError
):
sess
.
run
(
next_element
)
@
test_util
.
run_deprecated_v1
self
.
assertDatasetProduces
(
dataset
,
expected_output
=
[
compat
.
as_bytes
(
os
.
path
.
join
(
self
.
tmp_dir
,
filename
))
for
filename
in
filenames
[
1
:
3
]
],
assert_items_equal
=
True
)
def
testNestedDirectories
(
self
):
"""Test the MatchingFiles dataset with nested directories."""
...
...
@@ -156,21 +130,20 @@ class MatchingFilesTest(test_base.DatasetTestBase):
]
dataset
=
matching_files
.
MatchingFilesDataset
(
patterns
)
with
self
.
cached_session
()
as
sess
:
next_element
=
dataset_ops
.
make_one_shot_iterator
(
dataset
).
get_next
()
expected_filenames
=
[
compat
.
as_bytes
(
filename
)
for
filename
in
filenames
if
filename
.
endswith
(
'.txt'
)
or
filename
.
endswith
(
'.log'
)
]
actual_filenames
=
[]
while
True
:
try
:
actual_filenames
.
append
(
compat
.
as_bytes
(
sess
.
run
(
next_element
)))
except
errors
.
OutOfRangeError
:
break
self
.
assertItemsEqual
(
expected_filenames
,
actual_filenames
)
next_element
=
self
.
getNext
(
dataset
)
expected_filenames
=
[
compat
.
as_bytes
(
filename
)
for
filename
in
filenames
if
filename
.
endswith
(
'.txt'
)
or
filename
.
endswith
(
'.log'
)
]
actual_filenames
=
[]
while
True
:
try
:
actual_filenames
.
append
(
compat
.
as_bytes
(
self
.
evaluate
(
next_element
())))
except
errors
.
OutOfRangeError
:
break
self
.
assertItemsEqual
(
expected_filenames
,
actual_filenames
)
if
__name__
==
'__main__'
:
...
...
tensorflow/python/data/experimental/kernel_tests/override_threadpool_test.py
浏览文件 @
e8f8e219
...
...
@@ -35,6 +35,7 @@ from tensorflow.python.ops import script_ops
from
tensorflow.python.platform
import
test
@
test_util
.
run_all_in_graph_and_eager_modes
class
OverrideThreadpoolTest
(
test_base
.
DatasetTestBase
,
parameterized
.
TestCase
):
...
...
@@ -53,14 +54,12 @@ class OverrideThreadpoolTest(test_base.DatasetTestBase,
lambda
x
:
script_ops
.
py_func
(
get_thread_id
,
[
x
],
dtypes
.
int64
),
num_parallel_calls
=
32
).
apply
(
unique
.
unique
()))
dataset
=
override_threadpool_fn
(
dataset
)
iterator
=
dataset_ops
.
make_initializable_iterator
(
dataset
)
next_element
=
iterator
.
get_next
()
next_element
=
self
.
getNext
(
dataset
,
requires_initialization
=
True
)
self
.
evaluate
(
iterator
.
initializer
)
thread_ids
=
[]
try
:
while
True
:
thread_ids
.
append
(
self
.
evaluate
(
next_element
))
thread_ids
.
append
(
self
.
evaluate
(
next_element
()
))
except
errors
.
OutOfRangeError
:
pass
self
.
assertLen
(
thread_ids
,
len
(
set
(
thread_ids
)))
...
...
@@ -82,7 +81,6 @@ class OverrideThreadpoolTest(test_base.DatasetTestBase,
(
"8"
,
4
,
1
),
(
"9"
,
4
,
4
),
)
@
test_util
.
run_deprecated_v1
def
testNumThreadsDeprecated
(
self
,
num_threads
,
max_intra_op_parallelism
):
def
override_threadpool_fn
(
dataset
):
...
...
@@ -109,7 +107,6 @@ class OverrideThreadpoolTest(test_base.DatasetTestBase,
(
"11"
,
4
,
4
),
(
"12"
,
None
,
None
),
)
@
test_util
.
run_deprecated_v1
def
testNumThreads
(
self
,
num_threads
,
max_intra_op_parallelism
):
def
override_threadpool_fn
(
dataset
):
...
...
tensorflow/python/data/experimental/kernel_tests/prefetch_to_device_test.py
浏览文件 @
e8f8e219
...
...
@@ -29,6 +29,7 @@ from tensorflow.python.framework import test_util
from
tensorflow.python.platform
import
test
# TODO(b/119837791): add eager coverage when supported.
class
PrefetchToDeviceTest
(
test_base
.
DatasetTestBase
):
@
test_util
.
run_deprecated_v1
...
...
tensorflow/python/data/experimental/kernel_tests/rejection_resample_test.py
浏览文件 @
e8f8e219
...
...
@@ -17,11 +17,9 @@ from __future__ import absolute_import
from
__future__
import
division
from
__future__
import
print_function
import
time
from
absl.testing
import
parameterized
import
numpy
as
np
from
six.moves
import
xrange
# pylint: disable=redefined-builtin
from
tensorflow.python.data.experimental.ops
import
resampling
from
tensorflow.python.data.kernel_tests
import
test_base
...
...
@@ -36,35 +34,12 @@ from tensorflow.python.platform import test
from
tensorflow.python.util
import
compat
def
_time_resampling
(
test_obj
,
data_np
,
target_dist
,
init_dist
,
num_to_sample
):
dataset
=
dataset_ops
.
Dataset
.
from_tensor_slices
(
data_np
).
repeat
()
# Reshape distribution via rejection sampling.
dataset
=
dataset
.
apply
(
resampling
.
rejection_resample
(
class_func
=
lambda
x
:
x
,
target_dist
=
target_dist
,
initial_dist
=
init_dist
,
seed
=
142
))
get_next
=
dataset_ops
.
make_one_shot_iterator
(
dataset
).
get_next
()
with
test_obj
.
test_session
()
as
sess
:
start_time
=
time
.
time
()
for
_
in
xrange
(
num_to_sample
):
sess
.
run
(
get_next
)
end_time
=
time
.
time
()
return
end_time
-
start_time
@
test_util
.
run_all_in_graph_and_eager_modes
class
RejectionResampleTest
(
test_base
.
DatasetTestBase
,
parameterized
.
TestCase
):
@
parameterized
.
named_parameters
(
(
"InitialDistributionKnown"
,
True
),
(
"InitialDistributionUnknown"
,
False
))
@
test_util
.
run_deprecated_v1
def
testDistribution
(
self
,
initial_known
):
classes
=
np
.
random
.
randint
(
5
,
size
=
(
20000
,))
# Uniformly sampled
target_dist
=
[
0.9
,
0.05
,
0.05
,
0.0
,
0.0
]
...
...
@@ -73,17 +48,17 @@ class RejectionResampleTest(test_base.DatasetTestBase, parameterized.TestCase):
dataset
=
dataset_ops
.
Dataset
.
from_tensor_slices
(
classes
).
shuffle
(
200
,
seed
=
21
).
map
(
lambda
c
:
(
c
,
string_ops
.
as_string
(
c
))).
repeat
()
get_next
=
dataset_ops
.
make_one_shot_iterator
(
dataset
.
apply
(
resampling
.
rejection_resample
(
target_dist
=
target_dist
,
initial_dist
=
initial_dist
,
class_func
=
lambda
c
,
_
:
c
,
seed
=
27
))).
get_next
()
get_next
=
self
.
getNext
(
dataset
.
apply
(
resampling
.
rejection_resample
(
target_dist
=
target_dist
,
initial_dist
=
initial_dist
,
class_func
=
lambda
c
,
_
:
c
,
seed
=
27
)))
with
self
.
cached_session
()
as
sess
:
returned
=
[]
while
len
(
returned
)
<
4000
:
returned
.
append
(
sess
.
run
(
get_next
))
returned
=
[]
while
len
(
returned
)
<
4000
:
returned
.
append
(
self
.
evaluate
(
get_next
()))
returned_classes
,
returned_classes_and_data
=
zip
(
*
returned
)
_
,
returned_data
=
zip
(
*
returned_classes_and_data
)
...
...
@@ -99,7 +74,6 @@ class RejectionResampleTest(test_base.DatasetTestBase, parameterized.TestCase):
@
parameterized
.
named_parameters
(
(
"OnlyInitial"
,
True
),
(
"NotInitial"
,
False
))
@
test_util
.
run_deprecated_v1
def
testEdgeCasesSampleFromInitialDataset
(
self
,
only_initial_dist
):
init_dist
=
[
0.5
,
0.5
]
target_dist
=
[
0.5
,
0.5
]
if
only_initial_dist
else
[
0.0
,
1.0
]
...
...
@@ -117,15 +91,13 @@ class RejectionResampleTest(test_base.DatasetTestBase, parameterized.TestCase):
target_dist
=
target_dist
,
initial_dist
=
init_dist
))
get_next
=
dataset_ops
.
make_one_shot_iterator
(
dataset
).
get_next
(
)
get_next
=
self
.
getNext
(
dataset
)
with
self
.
cached_session
()
as
sess
:
returned
=
[]
with
self
.
assertRaises
(
errors
.
OutOfRangeError
):
while
True
:
returned
.
append
(
sess
.
run
(
get_next
))
returned
=
[]
with
self
.
assertRaises
(
errors
.
OutOfRangeError
):
while
True
:
returned
.
append
(
self
.
evaluate
(
get_next
()))
@
test_util
.
run_deprecated_v1
def
testRandomClasses
(
self
):
init_dist
=
[
0.25
,
0.25
,
0.25
,
0.25
]
target_dist
=
[
0.0
,
0.0
,
0.0
,
1.0
]
...
...
@@ -149,13 +121,12 @@ class RejectionResampleTest(test_base.DatasetTestBase, parameterized.TestCase):
target_dist
=
target_dist
,
initial_dist
=
init_dist
))
get_next
=
dataset_ops
.
make_one_shot_iterator
(
dataset
).
get_next
(
)
get_next
=
self
.
getNext
(
dataset
)
with
self
.
cached_session
()
as
sess
:
returned
=
[]
with
self
.
assertRaises
(
errors
.
OutOfRangeError
):
while
True
:
returned
.
append
(
sess
.
run
(
get_next
))
returned
=
[]
with
self
.
assertRaises
(
errors
.
OutOfRangeError
):
while
True
:
returned
.
append
(
self
.
evaluate
(
get_next
()))
classes
,
_
=
zip
(
*
returned
)
bincount
=
np
.
bincount
(
...
...
@@ -165,22 +136,5 @@ class RejectionResampleTest(test_base.DatasetTestBase, parameterized.TestCase):
self
.
assertAllClose
(
target_dist
,
bincount
,
atol
=
1e-2
)
class
ResampleDatasetBenchmark
(
test
.
Benchmark
):
def
benchmarkResamplePerformance
(
self
):
init_dist
=
[
0.25
,
0.25
,
0.25
,
0.25
]
target_dist
=
[
0.0
,
0.0
,
0.0
,
1.0
]
num_classes
=
len
(
init_dist
)
# We don't need many samples to test a dirac-delta target distribution
num_samples
=
1000
data_np
=
np
.
random
.
choice
(
num_classes
,
num_samples
,
p
=
init_dist
)
resample_time
=
_time_resampling
(
self
,
data_np
,
target_dist
,
init_dist
,
num_to_sample
=
1000
)
self
.
report_benchmark
(
iters
=
1000
,
wall_time
=
resample_time
,
name
=
"benchmark_resample"
)
if
__name__
==
"__main__"
:
test
.
main
()
tensorflow/python/data/experimental/kernel_tests/restructured_dataset_test.py
浏览文件 @
e8f8e219
...
...
@@ -27,6 +27,7 @@ from tensorflow.python.ops import array_ops
from
tensorflow.python.platform
import
test
# TODO(b/119837791): Add eager coverage
class
RestructuredDatasetTest
(
test_base
.
DatasetTestBase
):
@
test_util
.
run_deprecated_v1
...
...
tensorflow/python/data/experimental/kernel_tests/scan_test.py
浏览文件 @
e8f8e219
...
...
@@ -24,7 +24,6 @@ import numpy as np
from
tensorflow.python.data.experimental.ops
import
scan_ops
from
tensorflow.python.data.kernel_tests
import
test_base
from
tensorflow.python.data.ops
import
dataset_ops
from
tensorflow.python.eager
import
context
from
tensorflow.python.framework
import
constant_op
from
tensorflow.python.framework
import
dtypes
from
tensorflow.python.framework
import
errors
...
...
@@ -35,48 +34,34 @@ from tensorflow.python.ops import script_ops
from
tensorflow.python.platform
import
test
@
test_util
.
run_all_in_graph_and_eager_modes
class
ScanTest
(
test_base
.
DatasetTestBase
):
def
_counting_dataset
(
self
,
start
,
scan_fn
):
return
dataset_ops
.
Dataset
.
from_tensors
(
0
).
repeat
().
apply
(
scan_ops
.
scan
(
start
,
scan_fn
))
@
test_util
.
run_deprecated_v1
def
testCount
(
self
):
def
make_scan_fn
(
step
):
return
lambda
state
,
_
:
(
state
+
step
,
state
)
start
=
array_ops
.
placeholder
(
dtypes
.
int32
,
shape
=
[])
step
=
array_ops
.
placeholder
(
dtypes
.
int32
,
shape
=
[])
take
=
array_ops
.
placeholder
(
dtypes
.
int64
,
shape
=
[])
iterator
=
dataset_ops
.
make_initializable_iterator
(
self
.
_counting_dataset
(
start
,
make_scan_fn
(
step
)).
take
(
take
))
next_element
=
iterator
.
get_next
()
with
self
.
cached_session
()
as
sess
:
for
start_val
,
step_val
,
take_val
in
[(
0
,
1
,
10
),
(
0
,
1
,
0
),
(
10
,
1
,
10
),
(
10
,
2
,
10
),
(
10
,
-
1
,
10
),
(
10
,
-
2
,
10
)]:
sess
.
run
(
iterator
.
initializer
,
feed_dict
=
{
start
:
start_val
,
step
:
step_val
,
take
:
take_val
})
for
expected
,
_
in
zip
(
itertools
.
count
(
start_val
,
step_val
),
range
(
take_val
)):
self
.
assertEqual
(
expected
,
self
.
evaluate
(
next_element
))
with
self
.
assertRaises
(
errors
.
OutOfRangeError
):
self
.
evaluate
(
next_element
)
@
test_util
.
run_in_graph_and_eager_modes
def
testFibonacci
(
self
):
iterator
=
dataset_ops
.
make_one_shot_iterator
(
dataset_ops
.
Dataset
.
from_tensors
(
1
).
repeat
(
None
).
apply
(
scan_ops
.
scan
([
0
,
1
],
lambda
a
,
_
:
([
a
[
1
],
a
[
0
]
+
a
[
1
]],
a
[
1
]))))
def
dataset_fn
(
start
,
step
,
take
):
return
self
.
_counting_dataset
(
start
,
make_scan_fn
(
step
)).
take
(
take
)
if
context
.
executing_eagerly
():
next_element
=
iterator
.
get_next
else
:
get_next
=
iterator
.
get_next
()
next_element
=
lambda
:
get_next
for
start_val
,
step_val
,
take_val
in
[(
0
,
1
,
10
),
(
0
,
1
,
0
),
(
10
,
1
,
10
),
(
10
,
2
,
10
),
(
10
,
-
1
,
10
),
(
10
,
-
2
,
10
)]:
next_element
=
self
.
getNext
(
dataset_fn
(
start_val
,
step_val
,
take_val
))
for
expected
,
_
in
zip
(
itertools
.
count
(
start_val
,
step_val
),
range
(
take_val
)):
self
.
assertEqual
(
expected
,
self
.
evaluate
(
next_element
()))
with
self
.
assertRaises
(
errors
.
OutOfRangeError
):
self
.
evaluate
(
next_element
())
def
testFibonacci
(
self
):
data
=
dataset_ops
.
Dataset
.
from_tensors
(
1
).
repeat
(
None
).
apply
(
scan_ops
.
scan
([
0
,
1
],
lambda
a
,
_
:
([
a
[
1
],
a
[
0
]
+
a
[
1
]],
a
[
1
])))
next_element
=
self
.
getNext
(
data
)
self
.
assertEqual
(
1
,
self
.
evaluate
(
next_element
()))
self
.
assertEqual
(
1
,
self
.
evaluate
(
next_element
()))
...
...
@@ -85,8 +70,10 @@ class ScanTest(test_base.DatasetTestBase):
self
.
assertEqual
(
5
,
self
.
evaluate
(
next_element
()))
self
.
assertEqual
(
8
,
self
.
evaluate
(
next_element
()))
# TODO(b/119837791): Add coverage for eager.
@
test_util
.
run_deprecated_v1
def
testSparseCount
(
self
):
def
testSkipEagerSparseCount
(
self
):
def
_sparse
(
i
):
return
sparse_tensor
.
SparseTensorValue
(
indices
=
np
.
array
([[
0
,
0
]]),
...
...
@@ -96,27 +83,20 @@ class ScanTest(test_base.DatasetTestBase):
def
make_scan_fn
(
step
):
return
lambda
state
,
_
:
(
_sparse
(
state
.
values
[
0
]
+
step
),
state
)
start
=
array_ops
.
placeholder
(
dtypes
.
int32
,
shape
=
[])
step
=
array_ops
.
placeholder
(
dtypes
.
int32
,
shape
=
[])
take
=
array_ops
.
placeholder
(
dtypes
.
int64
,
shape
=
[])
iterator
=
dataset_ops
.
make_initializable_iterator
(
self
.
_counting_dataset
(
_sparse
(
start
),
make_scan_fn
(
step
)).
take
(
take
))
next_element
=
iterator
.
get_next
()
with
self
.
cached_session
()
as
sess
:
for
start_val
,
step_val
,
take_val
in
[(
0
,
1
,
10
),
(
0
,
1
,
0
),
(
10
,
1
,
10
),
(
10
,
2
,
10
),
(
10
,
-
1
,
10
),
(
10
,
-
2
,
10
)]:
sess
.
run
(
iterator
.
initializer
,
feed_dict
=
{
start
:
start_val
,
step
:
step_val
,
take
:
take_val
})
for
expected
,
_
in
zip
(
itertools
.
count
(
start_val
,
step_val
),
range
(
take_val
)):
self
.
assertEqual
(
expected
,
self
.
evaluate
(
next_element
).
values
[
0
])
with
self
.
assertRaises
(
errors
.
OutOfRangeError
):
self
.
evaluate
(
next_element
)
def
dataset_fn
(
start
,
step
,
take
):
return
self
.
_counting_dataset
(
_sparse
(
start
),
make_scan_fn
(
step
)).
take
(
take
)
for
start_val
,
step_val
,
take_val
in
[(
0
,
1
,
10
),
(
0
,
1
,
0
),
(
10
,
1
,
10
),
(
10
,
2
,
10
),
(
10
,
-
1
,
10
),
(
10
,
-
2
,
10
)]:
next_element
=
self
.
getNext
(
dataset_fn
(
start_val
,
step_val
,
take_val
))
for
expected
,
_
in
zip
(
itertools
.
count
(
start_val
,
step_val
),
range
(
take_val
)):
self
.
assertEqual
(
expected
,
self
.
evaluate
(
next_element
()).
values
[
0
])
with
self
.
assertRaises
(
errors
.
OutOfRangeError
):
self
.
evaluate
(
next_element
())
@
test_util
.
run_deprecated_v1
def
testChangingStateShape
(
self
):
# Test the fixed-point shape invariant calculations: start with
# initial values with known shapes, and use a scan function that
...
...
@@ -134,16 +114,14 @@ class ScanTest(test_base.DatasetTestBase):
self
.
assertIs
(
None
,
dataset
.
output_shapes
[
0
][
1
].
ndims
)
self
.
assertEqual
([],
dataset
.
output_shapes
[
1
].
as_list
())
iterator
=
dataset_ops
.
make_one_shot_iterator
(
dataset
)
next_element
=
iterator
.
get_next
()
next_element
=
self
.
getNext
(
dataset
)
with
self
.
cached_session
()
as
sess
:
for
i
in
range
(
5
):
(
longer_vector_val
,
larger_rank_val
),
_
=
self
.
evaluate
(
next_element
)
self
.
assertAllEqual
([
0
]
*
(
2
**
i
),
longer_vector_val
)
self
.
assertAllEqual
(
np
.
array
(
1
,
ndmin
=
i
),
larger_rank_val
)
with
self
.
assertRaises
(
errors
.
OutOfRangeError
):
self
.
evaluate
(
next_element
)
for
i
in
range
(
5
):
(
longer_vector_val
,
larger_rank_val
),
_
=
self
.
evaluate
(
next_element
())
self
.
assertAllEqual
([
0
]
*
(
2
**
i
),
longer_vector_val
)
self
.
assertAllEqual
(
np
.
array
(
1
,
ndmin
=
i
),
larger_rank_val
)
with
self
.
assertRaises
(
errors
.
OutOfRangeError
):
self
.
evaluate
(
next_element
())
def
testIncorrectStateType
(
self
):
...
...
tensorflow/python/data/experimental/kernel_tests/shuffle_and_repeat_test.py
浏览文件 @
e8f8e219
...
...
@@ -23,11 +23,11 @@ from tensorflow.python.data.experimental.ops import shuffle_ops
from
tensorflow.python.data.kernel_tests
import
test_base
from
tensorflow.python.data.ops
import
dataset_ops
from
tensorflow.python.framework
import
errors
from
tensorflow.python.framework
import
ops
from
tensorflow.python.framework
import
test_util
from
tensorflow.python.platform
import
test
@
test_util
.
run_all_in_graph_and_eager_modes
class
ShuffleAndRepeatTest
(
test_base
.
DatasetTestBase
):
def
_build_ds
(
self
,
seed
,
count
=
5
,
num_elements
=
20
):
...
...
@@ -35,17 +35,15 @@ class ShuffleAndRepeatTest(test_base.DatasetTestBase):
shuffle_ops
.
shuffle_and_repeat
(
buffer_size
=
5
,
count
=
count
,
seed
=
seed
))
def
_gen_outputs
(
self
,
ds_fn
,
num_outputs
,
verify_exhausted
=
True
):
get_next
=
dataset_ops
.
make_one_shot_iterator
(
ds_fn
()).
get_next
(
)
get_next
=
self
.
getNext
(
ds_fn
()
)
outputs
=
[]
with
self
.
cached_session
()
as
sess
:
for
_
in
range
(
num_outputs
):
outputs
.
append
(
self
.
evaluate
(
get_next
))
if
verify_exhausted
:
with
self
.
assertRaises
(
errors
.
OutOfRangeError
):
self
.
evaluate
(
get_next
)
for
_
in
range
(
num_outputs
):
outputs
.
append
(
self
.
evaluate
(
get_next
()))
if
verify_exhausted
:
with
self
.
assertRaises
(
errors
.
OutOfRangeError
):
self
.
evaluate
(
get_next
())
return
outputs
@
test_util
.
run_deprecated_v1
def
testCorrectOutput
(
self
):
output
=
self
.
_gen_outputs
(
lambda
:
self
.
_build_ds
(
10
),
100
)
self
.
assertSequenceEqual
(
...
...
@@ -54,7 +52,6 @@ class ShuffleAndRepeatTest(test_base.DatasetTestBase):
for
i
in
range
(
5
):
self
.
assertSequenceEqual
(
sorted
(
output
[
i
*
20
:(
i
+
1
)
*
20
]),
range
(
20
))
@
test_util
.
run_deprecated_v1
def
testReshuffling
(
self
):
# Check that the output orders of different epochs are indeed different.
output
=
self
.
_gen_outputs
(
lambda
:
self
.
_build_ds
(
10
),
100
)
...
...
@@ -63,20 +60,17 @@ class ShuffleAndRepeatTest(test_base.DatasetTestBase):
epoch2
=
output
[(
i
+
1
)
*
20
:(
i
+
2
)
*
20
]
self
.
assertNotEqual
(
epoch1
,
epoch2
)
@
test_util
.
run_deprecated_v1
def
testSameOrderForSameSeeds
(
self
):
output1
=
self
.
_gen_outputs
(
lambda
:
self
.
_build_ds
(
10
),
100
)
output2
=
self
.
_gen_outputs
(
lambda
:
self
.
_build_ds
(
10
),
100
)
self
.
assertEqual
(
output1
,
output2
)
@
test_util
.
run_deprecated_v1
def
testDifferentOrderForDifferentSeeds
(
self
):
output1
=
self
.
_gen_outputs
(
lambda
:
self
.
_build_ds
(
10
),
100
)
output2
=
self
.
_gen_outputs
(
lambda
:
self
.
_build_ds
(
20
),
100
)
self
.
assertNotEqual
(
output1
,
output2
)
self
.
assertEqual
(
sorted
(
output1
),
sorted
(
output2
))
@
test_util
.
run_deprecated_v1
def
testCountNone
(
self
):
output1
=
self
.
_gen_outputs
(
lambda
:
self
.
_build_ds
(
10
,
count
=
None
),
100
,
verify_exhausted
=
False
)
...
...
@@ -85,7 +79,6 @@ class ShuffleAndRepeatTest(test_base.DatasetTestBase):
self
.
assertNotEqual
(
output1
,
output2
)
self
.
assertEqual
(
sorted
(
output1
),
sorted
(
output2
))
@
test_util
.
run_deprecated_v1
def
testCountMinusOne
(
self
):
output1
=
self
.
_gen_outputs
(
lambda
:
self
.
_build_ds
(
10
,
count
=-
1
),
100
,
verify_exhausted
=
False
)
...
...
@@ -110,12 +103,10 @@ class ShuffleAndRepeatTest(test_base.DatasetTestBase):
100
)
def
testLargeBufferSize
(
self
):
with
ops
.
Graph
().
as_default
()
as
g
:
ds
=
dataset_ops
.
Dataset
.
range
(
20
).
apply
(
shuffle_ops
.
shuffle_and_repeat
(
buffer_size
=
21
))
get_next_op
=
ds
.
make_one_shot_iterator
().
get_next
()
with
self
.
session
(
graph
=
g
)
as
sess
:
self
.
evaluate
(
get_next_op
)
ds
=
dataset_ops
.
Dataset
.
range
(
20
).
apply
(
shuffle_ops
.
shuffle_and_repeat
(
buffer_size
=
21
))
get_next
=
self
.
getNext
(
ds
)
self
.
evaluate
(
get_next
())
if
__name__
==
"__main__"
:
...
...
tensorflow/python/data/experimental/kernel_tests/sleep_test.py
浏览文件 @
e8f8e219
...
...
@@ -29,25 +29,21 @@ from tensorflow.python.platform import test
_NUMPY_RANDOM_SEED
=
42
@
test_util
.
run_all_in_graph_and_eager_modes
class
SleepTest
(
test_base
.
DatasetTestBase
):
@
test_util
.
run_deprecated_v1
def
testSleep
(
self
):
sleep_microseconds
=
100
dataset
=
dataset_ops
.
Dataset
.
range
(
10
).
apply
(
sleep
.
sleep
(
sleep_microseconds
))
iterator
=
dataset_ops
.
make_initializable_iterator
(
dataset
)
next_element
=
iterator
.
get_next
()
with
self
.
cached_session
()
as
sess
:
self
.
evaluate
(
iterator
.
initializer
)
start_time
=
time
.
time
()
for
i
in
range
(
10
):
self
.
assertEqual
(
i
,
self
.
evaluate
(
next_element
))
end_time
=
time
.
time
()
self
.
assertGreater
(
end_time
-
start_time
,
(
10
*
sleep_microseconds
)
/
1e6
)
with
self
.
assertRaises
(
errors
.
OutOfRangeError
):
self
.
evaluate
(
next_element
)
next_element
=
self
.
getNext
(
dataset
)
start_time
=
time
.
time
()
for
i
in
range
(
10
):
self
.
assertEqual
(
i
,
self
.
evaluate
(
next_element
()))
end_time
=
time
.
time
()
self
.
assertGreater
(
end_time
-
start_time
,
(
10
*
sleep_microseconds
)
/
1e6
)
with
self
.
assertRaises
(
errors
.
OutOfRangeError
):
self
.
evaluate
(
next_element
())
if
__name__
==
"__main__"
:
...
...
tensorflow/python/data/experimental/kernel_tests/stats_dataset_ops_test.py
浏览文件 @
e8f8e219
此差异已折叠。
点击以展开。
tensorflow/python/data/experimental/kernel_tests/stats_dataset_test_base.py
浏览文件 @
e8f8e219
...
...
@@ -22,7 +22,6 @@ import numpy as np
from
tensorflow.core.framework
import
summary_pb2
from
tensorflow.python.data.experimental.ops
import
stats_aggregator
from
tensorflow.python.data.kernel_tests
import
test_base
from
tensorflow.python.data.ops
import
dataset_ops
from
tensorflow.python.framework
import
errors
...
...
@@ -94,27 +93,23 @@ class StatsDatasetTestBase(test_base.DatasetTestBase):
aggregator
=
stats_aggregator
.
StatsAggregator
()
dataset
=
dataset_fn
()
dataset
=
dataset_transformation
(
dataset
,
aggregator
)
iterator
=
dataset_ops
.
make_initializable_iterator
(
dataset
)
next_element
=
iterator
.
get_next
()
summary_t
=
aggregator
.
get_summary
()
next_element
=
self
.
getNext
(
dataset
,
requires_initialization
=
True
)
with
self
.
cached_session
()
as
sess
:
sess
.
run
(
iterator
.
initializer
)
for
i
in
range
(
num_output
):
next_
=
sess
.
run
(
next_element
)
if
check_elements
:
self
.
assertAllEqual
(
np
.
array
([
i
]
*
i
,
dtype
=
np
.
int64
),
next_
)
summary_str
=
sess
.
run
(
summary_t
)
if
function_processing_time
:
self
.
_assertSummaryHasCountMoreOrEqualGeneralisedTag
(
summary_str
,
"::execution_time"
,
float
(
i
+
1
))
self
.
_assertSummaryContains
(
summary_str
,
dataset_name
+
"::num_parallel_calls"
)
self
.
_assertSummaryContains
(
summary_str
,
dataset_name
+
"::active_parallel_calls"
)
with
self
.
assertRaises
(
errors
.
OutOfRangeError
):
sess
.
run
(
next_element
)
for
i
in
range
(
num_output
):
next_
=
self
.
evaluate
(
next_element
())
if
check_elements
:
self
.
assertAllEqual
(
np
.
array
([
i
]
*
i
,
dtype
=
np
.
int64
),
next_
)
summary_str
=
self
.
evaluate
(
aggregator
.
get_summary
())
if
function_processing_time
:
summary_str
=
sess
.
run
(
summary_t
)
self
.
_assertSummaryHasCountMoreOrEqualGeneralisedTag
(
summary_str
,
"::execution_time"
,
float
(
num_output
))
summary_str
,
"::execution_time"
,
float
(
i
+
1
))
self
.
_assertSummaryContains
(
summary_str
,
dataset_name
+
"::num_parallel_calls"
)
self
.
_assertSummaryContains
(
summary_str
,
dataset_name
+
"::active_parallel_calls"
)
with
self
.
assertRaises
(
errors
.
OutOfRangeError
):
self
.
evaluate
(
next_element
())
if
function_processing_time
:
summary_str
=
self
.
evaluate
(
aggregator
.
get_summary
())
self
.
_assertSummaryHasCountMoreOrEqualGeneralisedTag
(
summary_str
,
"::execution_time"
,
float
(
num_output
))
tensorflow/python/data/experimental/kernel_tests/tf_record_writer_test.py
浏览文件 @
e8f8e219
...
...
@@ -23,26 +23,24 @@ from tensorflow.python.data.experimental.ops import writers
from
tensorflow.python.data.kernel_tests
import
test_base
from
tensorflow.python.data.ops
import
dataset_ops
from
tensorflow.python.data.ops
import
readers
from
tensorflow.python.framework
import
dtypes
from
tensorflow.python.framework
import
test_util
from
tensorflow.python.lib.io
import
python_io
from
tensorflow.python.lib.io
import
tf_record
from
tensorflow.python.ops
import
array_ops
from
tensorflow.python.platform
import
test
from
tensorflow.python.util
import
compat
@
test_util
.
run_all_in_graph_and_eager_modes
class
TFRecordWriterTest
(
test_base
.
DatasetTestBase
):
def
setUp
(
self
):
super
(
TFRecordWriterTest
,
self
).
setUp
()
self
.
_num_records
=
7
self
.
filename
=
array_ops
.
placeholder
(
dtypes
.
string
,
shape
=
[])
self
.
compression_type
=
array_ops
.
placeholder_with_default
(
""
,
shape
=
[])
input_dataset
=
readers
.
TFRecordDataset
([
self
.
filename
],
self
.
compression_type
)
self
.
writer
=
writers
.
TFRecordWriter
(
self
.
_outputFilename
(),
self
.
compression_type
).
write
(
input_dataset
)
def
writer_fn
(
self
,
filename
,
compression_type
=
""
):
input_dataset
=
readers
.
TFRecordDataset
([
filename
],
compression_type
)
return
writers
.
TFRecordWriter
(
self
.
_outputFilename
(),
compression_type
).
write
(
input_dataset
)
def
_record
(
self
,
i
):
return
compat
.
as_bytes
(
"Record %d"
%
(
i
))
...
...
@@ -62,56 +60,39 @@ class TFRecordWriterTest(test_base.DatasetTestBase):
return
os
.
path
.
join
(
self
.
get_temp_dir
(),
"tf_record.out.txt"
)
def
testWrite
(
self
):
with
self
.
cached_session
()
as
sess
:
sess
.
run
(
self
.
writer
,
feed_dict
=
{
self
.
filename
:
self
.
_createFile
(),
})
self
.
evaluate
(
self
.
writer_fn
(
self
.
_createFile
()))
for
i
,
r
in
enumerate
(
tf_record
.
tf_record_iterator
(
self
.
_outputFilename
())):
self
.
assertAllEqual
(
self
.
_record
(
i
),
r
)
def
testWriteZLIB
(
self
):
options
=
tf_record
.
TFRecordOptions
(
tf_record
.
TFRecordCompressionType
.
ZLIB
)
with
self
.
cached_session
()
as
sess
:
sess
.
run
(
self
.
writer
,
feed_dict
=
{
self
.
filename
:
self
.
_createFile
(
options
),
self
.
compression_type
:
"ZLIB"
,
})
self
.
evaluate
(
self
.
writer_fn
(
self
.
_createFile
(
options
),
compression_type
=
"ZLIB"
))
for
i
,
r
in
enumerate
(
tf_record
.
tf_record_iterator
(
self
.
_outputFilename
(),
options
=
options
)):
self
.
assertAllEqual
(
self
.
_record
(
i
),
r
)
def
testWriteGZIP
(
self
):
options
=
tf_record
.
TFRecordOptions
(
tf_record
.
TFRecordCompressionType
.
GZIP
)
with
self
.
cached_session
()
as
sess
:
sess
.
run
(
self
.
writer
,
feed_dict
=
{
self
.
filename
:
self
.
_createFile
(
options
),
self
.
compression_type
:
"GZIP"
,
})
self
.
evaluate
(
self
.
writer_fn
(
self
.
_createFile
(
options
),
compression_type
=
"GZIP"
))
for
i
,
r
in
enumerate
(
tf_record
.
tf_record_iterator
(
self
.
_outputFilename
(),
options
=
options
)):
self
.
assertAllEqual
(
self
.
_record
(
i
),
r
)
def
testFailDataset
(
self
):
with
self
.
assertRaises
(
TypeError
):
writers
.
TFRecordWriter
(
self
.
_outputFilename
(),
self
.
compression_type
).
write
(
"whoops"
)
writers
.
TFRecordWriter
(
self
.
_outputFilename
(),
""
).
write
(
"whoops"
)
def
testFailDType
(
self
):
input_dataset
=
dataset_ops
.
Dataset
.
from_tensors
(
10
)
with
self
.
assertRaises
(
TypeError
):
writers
.
TFRecordWriter
(
self
.
_outputFilename
(),
self
.
compression_type
).
write
(
input_dataset
)
writers
.
TFRecordWriter
(
self
.
_outputFilename
(),
""
).
write
(
input_dataset
)
def
testFailShape
(
self
):
input_dataset
=
dataset_ops
.
Dataset
.
from_tensors
([[
"hello"
],
[
"world"
]])
with
self
.
assertRaises
(
TypeError
):
writers
.
TFRecordWriter
(
self
.
_outputFilename
(),
self
.
compression_type
).
write
(
input_dataset
)
writers
.
TFRecordWriter
(
self
.
_outputFilename
(),
""
).
write
(
input_dataset
)
if
__name__
==
"__main__"
:
...
...
tensorflow/python/data/experimental/kernel_tests/unbatch_test.py
浏览文件 @
e8f8e219
...
...
@@ -36,24 +36,14 @@ from tensorflow.python.platform import test
from
tensorflow.python.util
import
compat
@
test_util
.
run_all_in_graph_and_eager_modes
class
UnbatchTest
(
test_base
.
DatasetTestBase
,
parameterized
.
TestCase
):
@
test_util
.
run_deprecated_v1
def
testUnbatchWithUnknownRankInput
(
self
):
placeholder
=
array_ops
.
placeholder
(
dtypes
.
int32
)
dataset
=
dataset_ops
.
Dataset
.
from_tensors
(
placeholder
).
apply
(
batching
.
unbatch
())
iterator
=
dataset_ops
.
make_initializable_iterator
(
dataset
)
next_elem
=
iterator
.
get_next
()
with
self
.
cached_session
()
as
sess
:
sess
.
run
(
iterator
.
initializer
,
feed_dict
=
{
placeholder
:
[
0
,
1
,
2
,
3
]})
for
i
in
range
(
4
):
self
.
assertEqual
(
i
,
self
.
evaluate
(
next_elem
))
with
self
.
assertRaises
(
errors
.
OutOfRangeError
):
self
.
evaluate
(
next_elem
)
dataset
=
dataset_ops
.
Dataset
.
from_tensors
([
0
,
1
,
2
,
3
]).
apply
(
batching
.
unbatch
())
self
.
assertDatasetProduces
(
dataset
,
range
(
4
))
@
test_util
.
run_deprecated_v1
def
testUnbatchScalarDataset
(
self
):
data
=
tuple
([
math_ops
.
range
(
10
)
for
_
in
range
(
3
)])
data
=
dataset_ops
.
Dataset
.
from_tensor_slices
(
data
)
...
...
@@ -63,17 +53,8 @@ class UnbatchTest(test_base.DatasetTestBase, parameterized.TestCase):
data
=
data
.
apply
(
batching
.
unbatch
())
self
.
assertEqual
(
expected_types
,
data
.
output_types
)
iterator
=
dataset_ops
.
make_one_shot_iterator
(
data
)
op
=
iterator
.
get_next
()
self
.
assertDatasetProduces
(
data
,
[(
i
,)
*
3
for
i
in
range
(
10
)])
with
self
.
cached_session
()
as
sess
:
for
i
in
range
(
10
):
self
.
assertEqual
((
i
,)
*
3
,
self
.
evaluate
(
op
))
with
self
.
assertRaises
(
errors
.
OutOfRangeError
):
self
.
evaluate
(
op
)
@
test_util
.
run_deprecated_v1
def
testUnbatchDatasetWithStrings
(
self
):
data
=
tuple
([
math_ops
.
range
(
10
)
for
_
in
range
(
3
)])
data
=
dataset_ops
.
Dataset
.
from_tensor_slices
(
data
)
...
...
@@ -84,18 +65,12 @@ class UnbatchTest(test_base.DatasetTestBase, parameterized.TestCase):
data
=
data
.
apply
(
batching
.
unbatch
())
self
.
assertEqual
(
expected_types
,
data
.
output_types
)
iterator
=
dataset_ops
.
make_one_shot_iterator
(
data
)
op
=
iterator
.
get_next
()
with
self
.
cached_session
()
as
sess
:
for
i
in
range
(
10
):
self
.
assertEqual
((
i
,
compat
.
as_bytes
(
str
(
i
)),
i
),
self
.
evaluate
(
op
))
with
self
.
assertRaises
(
errors
.
OutOfRangeError
):
self
.
evaluate
(
op
)
self
.
assertDatasetProduces
(
data
,
[(
i
,
compat
.
as_bytes
(
str
(
i
)),
i
)
for
i
in
range
(
10
)])
# TODO(b/119837791): Add eager coverage.
@
test_util
.
run_deprecated_v1
def
testUnbatchDatasetWithSparseTensor
(
self
):
def
test
SkipEager
UnbatchDatasetWithSparseTensor
(
self
):
st
=
sparse_tensor
.
SparseTensorValue
(
indices
=
[[
i
,
i
]
for
i
in
range
(
10
)],
values
=
list
(
range
(
10
)),
...
...
@@ -107,17 +82,17 @@ class UnbatchTest(test_base.DatasetTestBase, parameterized.TestCase):
iterator
=
dataset_ops
.
make_one_shot_iterator
(
data
)
next_element
=
iterator
.
get_next
()
with
self
.
cached_session
()
as
sess
:
for
i
in
range
(
10
):
st_row
=
self
.
evaluate
(
next_element
)
self
.
assertEqual
([
i
],
st_row
.
indices
)
self
.
assertEqual
([
i
],
st_row
.
values
)
self
.
assertEqual
([
10
],
st_row
.
dense_shape
)
with
self
.
assertRaises
(
errors
.
OutOfRangeError
):
self
.
evaluate
(
next_element
)
for
i
in
range
(
10
):
st_row
=
self
.
evaluate
(
next_element
)
self
.
assertEqual
([
i
],
st_row
.
indices
)
self
.
assertEqual
([
i
],
st_row
.
values
)
self
.
assertEqual
([
10
],
st_row
.
dense_shape
)
with
self
.
assertRaises
(
errors
.
OutOfRangeError
):
self
.
evaluate
(
next_element
)
# TODO(b/119837791): Add eager coverage.
@
test_util
.
run_deprecated_v1
def
testUnbatchDatasetWithDenseAndSparseTensor
(
self
):
def
test
SkipEager
UnbatchDatasetWithDenseAndSparseTensor
(
self
):
st
=
sparse_tensor
.
SparseTensorValue
(
indices
=
[[
i
,
i
]
for
i
in
range
(
10
)],
values
=
list
(
range
(
10
)),
...
...
@@ -126,20 +101,17 @@ class UnbatchTest(test_base.DatasetTestBase, parameterized.TestCase):
data
=
data
.
apply
(
batching
.
unbatch
())
data
=
data
.
batch
(
5
)
data
=
data
.
apply
(
batching
.
unbatch
())
iterator
=
dataset_ops
.
make_one_shot_iterator
(
data
)
next_element
=
iterator
.
get_next
()
next_element
=
self
.
getNext
(
data
)
with
self
.
cached_session
()
as
sess
:
for
i
in
range
(
10
):
dense_elem
,
st_row
=
self
.
evaluate
(
next_element
)
self
.
assertEqual
(
i
,
dense_elem
)
self
.
assertEqual
([
i
],
st_row
.
indices
)
self
.
assertEqual
([
i
],
st_row
.
values
)
self
.
assertEqual
([
10
],
st_row
.
dense_shape
)
with
self
.
assertRaises
(
errors
.
OutOfRangeError
):
self
.
evaluate
(
next_element
)
for
i
in
range
(
10
):
dense_elem
,
st_row
=
self
.
evaluate
(
next_element
())
self
.
assertEqual
(
i
,
dense_elem
)
self
.
assertEqual
([
i
],
st_row
.
indices
)
self
.
assertEqual
([
i
],
st_row
.
values
)
self
.
assertEqual
([
10
],
st_row
.
dense_shape
)
with
self
.
assertRaises
(
errors
.
OutOfRangeError
):
self
.
evaluate
(
next_element
())
@
test_util
.
run_deprecated_v1
def
testUnbatchSingleElementTupleDataset
(
self
):
data
=
tuple
([(
math_ops
.
range
(
10
),)
for
_
in
range
(
3
)])
data
=
dataset_ops
.
Dataset
.
from_tensor_slices
(
data
)
...
...
@@ -149,17 +121,8 @@ class UnbatchTest(test_base.DatasetTestBase, parameterized.TestCase):
data
=
data
.
apply
(
batching
.
unbatch
())
self
.
assertEqual
(
expected_types
,
data
.
output_types
)
iterator
=
dataset_ops
.
make_one_shot_iterator
(
data
)
op
=
iterator
.
get_next
()
with
self
.
cached_session
()
as
sess
:
for
i
in
range
(
10
):
self
.
assertEqual
(((
i
,),)
*
3
,
self
.
evaluate
(
op
))
self
.
assertDatasetProduces
(
data
,
[((
i
,),)
*
3
for
i
in
range
(
10
)])
with
self
.
assertRaises
(
errors
.
OutOfRangeError
):
self
.
evaluate
(
op
)
@
test_util
.
run_deprecated_v1
def
testUnbatchMultiElementTupleDataset
(
self
):
data
=
tuple
([(
math_ops
.
range
(
10
*
i
,
10
*
i
+
10
),
array_ops
.
fill
([
10
],
"hi"
))
for
i
in
range
(
3
)])
...
...
@@ -170,29 +133,16 @@ class UnbatchTest(test_base.DatasetTestBase, parameterized.TestCase):
data
=
data
.
apply
(
batching
.
unbatch
())
self
.
assertAllEqual
(
expected_types
,
data
.
output_types
)
iterator
=
dataset_ops
.
make_one_shot_iterator
(
data
)
op
=
iterator
.
get_next
()
self
.
assertDatasetProduces
(
data
,
[((
i
,
b
"hi"
),
(
10
+
i
,
b
"hi"
),
(
20
+
i
,
b
"hi"
))
for
i
in
range
(
10
)])
with
self
.
cached_session
()
as
sess
:
for
i
in
range
(
10
):
self
.
assertEqual
(((
i
,
b
"hi"
),
(
10
+
i
,
b
"hi"
),
(
20
+
i
,
b
"hi"
)),
self
.
evaluate
(
op
))
with
self
.
assertRaises
(
errors
.
OutOfRangeError
):
self
.
evaluate
(
op
)
@
test_util
.
run_deprecated_v1
def
testUnbatchEmpty
(
self
):
data
=
dataset_ops
.
Dataset
.
from_tensors
(
(
constant_op
.
constant
([]),
constant_op
.
constant
([],
shape
=
[
0
,
4
]),
constant_op
.
constant
([],
shape
=
[
0
,
4
,
0
])))
data
=
data
.
apply
(
batching
.
unbatch
())
iterator
=
dataset_ops
.
make_one_shot_iterator
(
data
)
next_element
=
iterator
.
get_next
()
with
self
.
cached_session
()
as
sess
:
with
self
.
assertRaises
(
errors
.
OutOfRangeError
):
self
.
evaluate
(
next_element
)
self
.
assertDatasetProduces
(
data
,
[])
def
testUnbatchStaticShapeMismatch
(
self
):
data
=
dataset_ops
.
Dataset
.
from_tensors
((
np
.
arange
(
7
),
np
.
arange
(
8
),
...
...
@@ -200,8 +150,9 @@ class UnbatchTest(test_base.DatasetTestBase, parameterized.TestCase):
with
self
.
assertRaises
(
ValueError
):
data
.
apply
(
batching
.
unbatch
())
# TODO(b/119837791): eager mode doesnt capture raised error, debug.
@
test_util
.
run_deprecated_v1
def
testUnbatchDynamicShapeMismatch
(
self
):
def
test
SkipEager
UnbatchDynamicShapeMismatch
(
self
):
ph1
=
array_ops
.
placeholder
(
dtypes
.
int32
,
shape
=
[
None
])
ph2
=
array_ops
.
placeholder
(
dtypes
.
int32
,
shape
=
None
)
data
=
dataset_ops
.
Dataset
.
from_tensors
((
ph1
,
ph2
))
...
...
tensorflow/python/data/experimental/kernel_tests/unique_test.py
浏览文件 @
e8f8e219
...
...
@@ -21,12 +21,12 @@ from tensorflow.python.data.experimental.ops import unique
from
tensorflow.python.data.kernel_tests
import
test_base
from
tensorflow.python.data.ops
import
dataset_ops
from
tensorflow.python.framework
import
dtypes
from
tensorflow.python.framework
import
errors
from
tensorflow.python.framework
import
test_util
from
tensorflow.python.platform
import
test
from
tensorflow.python.util
import
compat
@
test_util
.
run_all_in_graph_and_eager_modes
class
UniqueTest
(
test_base
.
DatasetTestBase
):
def
_testSimpleHelper
(
self
,
dtype
,
test_cases
):
...
...
@@ -44,19 +44,13 @@ class UniqueTest(test_base.DatasetTestBase):
current_test_case
=
[]
dataset
=
dataset_ops
.
Dataset
.
from_generator
(
lambda
:
current_test_case
,
dtype
).
apply
(
unique
.
unique
())
iterator
=
dataset_ops
.
make_initializable_iterator
(
dataset
)
next_element
=
iterator
.
get_next
()
with
self
.
cached_session
()
as
sess
:
for
test_case
,
expected
in
test_cases
:
current_test_case
=
test_case
self
.
evaluate
(
iterator
.
initializer
)
for
element
in
expected
:
if
dtype
==
dtypes
.
string
:
element
=
compat
.
as_bytes
(
element
)
self
.
assertAllEqual
(
element
,
self
.
evaluate
(
next_element
))
with
self
.
assertRaises
(
errors
.
OutOfRangeError
):
self
.
evaluate
(
next_element
)
for
test_case
,
expected
in
test_cases
:
current_test_case
=
test_case
self
.
assertDatasetProduces
(
dataset
,
[
compat
.
as_bytes
(
element
)
if
dtype
==
dtypes
.
string
else
element
for
element
in
expected
])
@
test_util
.
run_deprecated_v1
def
testSimpleInt
(
self
):
...
...
tensorflow/python/data/experimental/kernel_tests/wrap_unwrap_test.py
浏览文件 @
e8f8e219
...
...
@@ -20,11 +20,13 @@ from __future__ import print_function
from
tensorflow.python.data.kernel_tests
import
test_base
from
tensorflow.python.data.ops
import
dataset_ops
from
tensorflow.python.framework
import
ops
from
tensorflow.python.framework
import
test_util
from
tensorflow.python.ops
import
array_ops
from
tensorflow.python.ops
import
gen_dataset_ops
from
tensorflow.python.platform
import
test
@
test_util
.
run_all_in_graph_and_eager_modes
class
WrapDatasetVariantTest
(
test_base
.
DatasetTestBase
):
def
testBasic
(
self
):
...
...
@@ -36,15 +38,12 @@ class WrapDatasetVariantTest(test_base.DatasetTestBase):
variant_ds
=
dataset_ops
.
_VariantDataset
(
unwrapped_variant
,
ds
.
_element_structure
)
iterator
=
dataset_ops
.
make_initializable_iterator
(
variant_ds
)
get_next
=
iterator
.
get_next
()
with
self
.
cached_session
():
self
.
evaluate
(
iterator
.
initializer
)
for
i
in
range
(
100
):
self
.
assertEqual
(
i
,
self
.
evaluate
(
get_next
))
get_next
=
self
.
getNext
(
variant_ds
,
requires_initialization
=
True
)
for
i
in
range
(
100
):
self
.
assertEqual
(
i
,
self
.
evaluate
(
get_next
()))
def
testGPU
(
self
):
# TODO(b/119837791): add eager coverage when supported.
def
testSkipEagerGPU
(
self
):
ds
=
dataset_ops
.
Dataset
.
range
(
100
)
ds_variant
=
ds
.
_as_variant_tensor
()
# pylint: disable=protected-access
wrapped_variant
=
gen_dataset_ops
.
wrap_dataset_variant
(
ds_variant
)
...
...
tensorflow/python/data/kernel_tests/test_base.py
浏览文件 @
e8f8e219
...
...
@@ -88,6 +88,7 @@ class DatasetTestBase(test.TestCase):
def
assertDatasetProduces
(
self
,
dataset
,
expected_output
=
None
,
expected_shapes
=
None
,
expected_error
=
None
,
requires_initialization
=
False
,
num_test_iterations
=
1
,
...
...
@@ -98,6 +99,8 @@ class DatasetTestBase(test.TestCase):
dataset: A dataset to check for the expected output / error.
expected_output: A list of elements that the dataset is expected to
produce.
expected_shapes: A list of `TensorShape`s that is expected to match the
`output_shapes` property of the dataset.
expected_error: A tuple `(type, predicate)` identifying the expected error
`dataset` should raise. The `type` should match the expected exception
type, while `predicate` should either be 1) a unary function that inputs
...
...
@@ -126,6 +129,8 @@ class DatasetTestBase(test.TestCase):
dataset
,
requires_initialization
=
requires_initialization
)
self
.
evaluate
(
get_next
())
return
if
expected_shapes
:
self
.
assertEqual
(
expected_shapes
,
dataset
.
output_shapes
)
self
.
assertGreater
(
num_test_iterations
,
0
)
for
_
in
range
(
num_test_iterations
):
get_next
=
self
.
getNext
(
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录