提交 59a714c6 编写于 作者: C Cathy Wong

Correct shuffle UT buffer_size > #dataset-row as valid

上级 7f8c9ebf
......@@ -98,6 +98,25 @@ def test_shuffle_04():
save_and_check(data1, parameters, filename, generate_golden=GENERATE_GOLDEN)
def test_shuffle_05():
"""
Test shuffle: buffer_size > number-of-rows-in-dataset
"""
logger.info("test_shuffle_05")
# define parameters
buffer_size = 13
seed = 1
parameters = {"params": {'buffer_size': buffer_size, "seed": seed}}
# apply dataset operations
data1 = ds.TFRecordDataset(DATA_DIR, shuffle=ds.Shuffle.FILES)
ds.config.set_seed(seed)
data1 = data1.shuffle(buffer_size=buffer_size)
filename = "shuffle_05_result.npz"
save_and_check(data1, parameters, filename, generate_golden=GENERATE_GOLDEN)
def test_shuffle_exception_01():
"""
Test shuffle exception: buffer_size<0
......@@ -152,24 +171,6 @@ def test_shuffle_exception_03():
assert "buffer_size" in str(e)
def test_shuffle_exception_04():
"""
Test shuffle exception: buffer_size > number-of-rows-in-dataset
"""
logger.info("test_shuffle_exception_04")
# apply dataset operations
data1 = ds.TFRecordDataset(DATA_DIR)
ds.config.set_seed(1)
try:
data1 = data1.shuffle(buffer_size=13)
sum([1 for _ in data1])
except BaseException as e:
logger.info("Got an exception in DE: {}".format(str(e)))
assert "buffer_size" in str(e)
def test_shuffle_exception_05():
"""
Test shuffle exception: Missing mandatory buffer_size input parameter
......@@ -229,10 +230,10 @@ if __name__ == '__main__':
test_shuffle_02()
test_shuffle_03()
test_shuffle_04()
test_shuffle_05()
test_shuffle_exception_01()
test_shuffle_exception_02()
test_shuffle_exception_03()
test_shuffle_exception_04()
test_shuffle_exception_05()
test_shuffle_exception_06()
test_shuffle_exception_07()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册