提交 04eb4f89 编写于 作者: C Cathy Wong

Cleanup dataset UT: Replace save_and_check

上级 cda333f7
{
"deviceNum":4,
"deviceId": 2,
"shardConfig":"ALL",
"shuffle":"ON",
"seed": 0,
"epoch": 2
}
{
"deviceNum":4,
"deviceId": 2,
"shardConfig":"RANDOM",
"shuffle":"ON",
"seed": 0,
"epoch": 1
}
{
"deviceNum":4,
"deviceId": 2,
"shardConfig":"UNIQUE",
"shuffle":"ON",
"seed": 0,
"epoch": 3
}
{
"deviceNum":1,
"deviceId": 0,
"shardConfig":"RANDOM",
"shuffle":"OFF",
"seed": 0
}
......@@ -12,15 +12,12 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
from util import save_and_check
import mindspore.dataset as ds
from mindspore import log as logger
from util import save_and_check_dict
DATA_DIR = ["../data/dataset/testTFTestAllTypes/test.data"]
SCHEMA_DIR = "../data/dataset/testTFTestAllTypes/datasetSchema.json"
COLUMNS = ["col_1d", "col_2d", "col_3d", "col_binary", "col_float",
"col_sint16", "col_sint32", "col_sint64"]
GENERATE_GOLDEN = False
......@@ -33,9 +30,6 @@ def test_2ops_repeat_shuffle():
repeat_count = 2
buffer_size = 5
seed = 0
parameters = {"params": {'repeat_count': repeat_count,
'buffer_size': buffer_size,
'seed': seed}}
# apply dataset operations
data1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, shuffle=False)
......@@ -44,7 +38,7 @@ def test_2ops_repeat_shuffle():
data1 = data1.shuffle(buffer_size=buffer_size)
filename = "test_2ops_repeat_shuffle.npz"
save_and_check(data1, parameters, filename, generate_golden=GENERATE_GOLDEN)
save_and_check_dict(data1, filename, generate_golden=GENERATE_GOLDEN)
def test_2ops_shuffle_repeat():
......@@ -56,10 +50,6 @@ def test_2ops_shuffle_repeat():
repeat_count = 2
buffer_size = 5
seed = 0
parameters = {"params": {'repeat_count': repeat_count,
'buffer_size': buffer_size,
'reshuffle_each_iteration': False,
'seed': seed}}
# apply dataset operations
data1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, shuffle=False)
......@@ -68,7 +58,7 @@ def test_2ops_shuffle_repeat():
data1 = data1.repeat(repeat_count)
filename = "test_2ops_shuffle_repeat.npz"
save_and_check(data1, parameters, filename, generate_golden=GENERATE_GOLDEN)
save_and_check_dict(data1, filename, generate_golden=GENERATE_GOLDEN)
def test_2ops_repeat_batch():
......@@ -79,8 +69,6 @@ def test_2ops_repeat_batch():
# define parameters
repeat_count = 2
batch_size = 5
parameters = {"params": {'repeat_count': repeat_count,
'batch_size': batch_size}}
# apply dataset operations
data1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, shuffle=False)
......@@ -88,7 +76,7 @@ def test_2ops_repeat_batch():
data1 = data1.batch(batch_size, drop_remainder=True)
filename = "test_2ops_repeat_batch.npz"
save_and_check(data1, parameters, filename, generate_golden=GENERATE_GOLDEN)
save_and_check_dict(data1, filename, generate_golden=GENERATE_GOLDEN)
def test_2ops_batch_repeat():
......@@ -99,8 +87,6 @@ def test_2ops_batch_repeat():
# define parameters
repeat_count = 2
batch_size = 5
parameters = {"params": {'repeat_count': repeat_count,
'batch_size': batch_size}}
# apply dataset operations
data1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, shuffle=False)
......@@ -108,7 +94,7 @@ def test_2ops_batch_repeat():
data1 = data1.repeat(repeat_count)
filename = "test_2ops_batch_repeat.npz"
save_and_check(data1, parameters, filename, generate_golden=GENERATE_GOLDEN)
save_and_check_dict(data1, filename, generate_golden=GENERATE_GOLDEN)
def test_2ops_batch_shuffle():
......@@ -120,9 +106,6 @@ def test_2ops_batch_shuffle():
buffer_size = 5
seed = 0
batch_size = 2
parameters = {"params": {'buffer_size': buffer_size,
'seed': seed,
'batch_size': batch_size}}
# apply dataset operations
data1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, shuffle=False)
......@@ -131,7 +114,7 @@ def test_2ops_batch_shuffle():
data1 = data1.shuffle(buffer_size=buffer_size)
filename = "test_2ops_batch_shuffle.npz"
save_and_check(data1, parameters, filename, generate_golden=GENERATE_GOLDEN)
save_and_check_dict(data1, filename, generate_golden=GENERATE_GOLDEN)
def test_2ops_shuffle_batch():
......@@ -143,9 +126,6 @@ def test_2ops_shuffle_batch():
buffer_size = 5
seed = 0
batch_size = 2
parameters = {"params": {'buffer_size': buffer_size,
'seed': seed,
'batch_size': batch_size}}
# apply dataset operations
data1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, shuffle=False)
......@@ -154,7 +134,7 @@ def test_2ops_shuffle_batch():
data1 = data1.batch(batch_size, drop_remainder=True)
filename = "test_2ops_shuffle_batch.npz"
save_and_check(data1, parameters, filename, generate_golden=GENERATE_GOLDEN)
save_and_check_dict(data1, filename, generate_golden=GENERATE_GOLDEN)
if __name__ == '__main__':
......
......@@ -14,7 +14,7 @@
# ==============================================================================
import mindspore.dataset as ds
from mindspore import log as logger
from util import save_and_check
from util import save_and_check_dict
# Note: Number of rows in test.data dataset: 12
DATA_DIR = ["../data/dataset/testTFTestAllTypes/test.data"]
......@@ -29,8 +29,6 @@ def test_batch_01():
# define parameters
batch_size = 2
drop_remainder = True
parameters = {"params": {'batch_size': batch_size,
'drop_remainder': drop_remainder}}
# apply dataset operations
data1 = ds.TFRecordDataset(DATA_DIR, shuffle=ds.Shuffle.FILES)
......@@ -38,7 +36,7 @@ def test_batch_01():
assert sum([1 for _ in data1]) == 6
filename = "batch_01_result.npz"
save_and_check(data1, parameters, filename, generate_golden=GENERATE_GOLDEN)
save_and_check_dict(data1, filename, generate_golden=GENERATE_GOLDEN)
def test_batch_02():
......@@ -49,8 +47,6 @@ def test_batch_02():
# define parameters
batch_size = 5
drop_remainder = True
parameters = {"params": {'batch_size': batch_size,
'drop_remainder': drop_remainder}}
# apply dataset operations
data1 = ds.TFRecordDataset(DATA_DIR, shuffle=ds.Shuffle.FILES)
......@@ -58,7 +54,7 @@ def test_batch_02():
assert sum([1 for _ in data1]) == 2
filename = "batch_02_result.npz"
save_and_check(data1, parameters, filename, generate_golden=GENERATE_GOLDEN)
save_and_check_dict(data1, filename, generate_golden=GENERATE_GOLDEN)
def test_batch_03():
......@@ -69,8 +65,6 @@ def test_batch_03():
# define parameters
batch_size = 3
drop_remainder = False
parameters = {"params": {'batch_size': batch_size,
'drop_remainder': drop_remainder}}
# apply dataset operations
data1 = ds.TFRecordDataset(DATA_DIR, shuffle=ds.Shuffle.FILES)
......@@ -78,7 +72,7 @@ def test_batch_03():
assert sum([1 for _ in data1]) == 4
filename = "batch_03_result.npz"
save_and_check(data1, parameters, filename, generate_golden=GENERATE_GOLDEN)
save_and_check_dict(data1, filename, generate_golden=GENERATE_GOLDEN)
def test_batch_04():
......@@ -89,8 +83,6 @@ def test_batch_04():
# define parameters
batch_size = 7
drop_remainder = False
parameters = {"params": {'batch_size': batch_size,
'drop_remainder': drop_remainder}}
# apply dataset operations
data1 = ds.TFRecordDataset(DATA_DIR, shuffle=ds.Shuffle.FILES)
......@@ -98,7 +90,7 @@ def test_batch_04():
assert sum([1 for _ in data1]) == 2
filename = "batch_04_result.npz"
save_and_check(data1, parameters, filename, generate_golden=GENERATE_GOLDEN)
save_and_check_dict(data1, filename, generate_golden=GENERATE_GOLDEN)
def test_batch_05():
......@@ -108,7 +100,6 @@ def test_batch_05():
logger.info("test_batch_05")
# define parameters
batch_size = 1
parameters = {"params": {'batch_size': batch_size}}
# apply dataset operations
data1 = ds.TFRecordDataset(DATA_DIR, shuffle=ds.Shuffle.FILES)
......@@ -116,7 +107,7 @@ def test_batch_05():
assert sum([1 for _ in data1]) == 12
filename = "batch_05_result.npz"
save_and_check(data1, parameters, filename, generate_golden=GENERATE_GOLDEN)
save_and_check_dict(data1, filename, generate_golden=GENERATE_GOLDEN)
def test_batch_06():
......@@ -127,8 +118,6 @@ def test_batch_06():
# define parameters
batch_size = 12
drop_remainder = False
parameters = {"params": {'batch_size': batch_size,
'drop_remainder': drop_remainder}}
# apply dataset operations
data1 = ds.TFRecordDataset(DATA_DIR, shuffle=ds.Shuffle.FILES)
......@@ -136,7 +125,7 @@ def test_batch_06():
assert sum([1 for _ in data1]) == 1
filename = "batch_06_result.npz"
save_and_check(data1, parameters, filename, generate_golden=GENERATE_GOLDEN)
save_and_check_dict(data1, filename, generate_golden=GENERATE_GOLDEN)
def test_batch_07():
......@@ -148,9 +137,6 @@ def test_batch_07():
batch_size = 4
drop_remainder = False
num_parallel_workers = 2
parameters = {"params": {'batch_size': batch_size,
'drop_remainder': drop_remainder,
'num_parallel_workers': num_parallel_workers}}
# apply dataset operations
data1 = ds.TFRecordDataset(DATA_DIR, shuffle=ds.Shuffle.FILES)
......@@ -159,7 +145,7 @@ def test_batch_07():
assert sum([1 for _ in data1]) == 3
filename = "batch_07_result.npz"
save_and_check(data1, parameters, filename, generate_golden=GENERATE_GOLDEN)
save_and_check_dict(data1, filename, generate_golden=GENERATE_GOLDEN)
def test_batch_08():
......@@ -170,8 +156,6 @@ def test_batch_08():
# define parameters
batch_size = 6
num_parallel_workers = 1
parameters = {"params": {'batch_size': batch_size,
'num_parallel_workers': num_parallel_workers}}
# apply dataset operations
data1 = ds.TFRecordDataset(DATA_DIR, shuffle=ds.Shuffle.FILES)
......@@ -179,7 +163,7 @@ def test_batch_08():
assert sum([1 for _ in data1]) == 2
filename = "batch_08_result.npz"
save_and_check(data1, parameters, filename, generate_golden=GENERATE_GOLDEN)
save_and_check_dict(data1, filename, generate_golden=GENERATE_GOLDEN)
def test_batch_09():
......@@ -190,8 +174,6 @@ def test_batch_09():
# define parameters
batch_size = 13
drop_remainder = False
parameters = {"params": {'batch_size': batch_size,
'drop_remainder': drop_remainder}}
# apply dataset operations
data1 = ds.TFRecordDataset(DATA_DIR, shuffle=ds.Shuffle.FILES)
......@@ -199,7 +181,7 @@ def test_batch_09():
assert sum([1 for _ in data1]) == 1
filename = "batch_09_result.npz"
save_and_check(data1, parameters, filename, generate_golden=GENERATE_GOLDEN)
save_and_check_dict(data1, filename, generate_golden=GENERATE_GOLDEN)
def test_batch_10():
......@@ -210,8 +192,6 @@ def test_batch_10():
# define parameters
batch_size = 99
drop_remainder = True
parameters = {"params": {'batch_size': batch_size,
'drop_remainder': drop_remainder}}
# apply dataset operations
data1 = ds.TFRecordDataset(DATA_DIR, shuffle=ds.Shuffle.FILES)
......@@ -219,7 +199,7 @@ def test_batch_10():
assert sum([1 for _ in data1]) == 0
filename = "batch_10_result.npz"
save_and_check(data1, parameters, filename, generate_golden=GENERATE_GOLDEN)
save_and_check_dict(data1, filename, generate_golden=GENERATE_GOLDEN)
def test_batch_11():
......@@ -229,7 +209,6 @@ def test_batch_11():
logger.info("test_batch_11")
# define parameters
batch_size = 1
parameters = {"params": {'batch_size': batch_size}}
# apply dataset operations
# Use schema file with 1 row
......@@ -239,7 +218,7 @@ def test_batch_11():
assert sum([1 for _ in data1]) == 1
filename = "batch_11_result.npz"
save_and_check(data1, parameters, filename, generate_golden=GENERATE_GOLDEN)
save_and_check_dict(data1, filename, generate_golden=GENERATE_GOLDEN)
def test_batch_12():
......@@ -249,7 +228,6 @@ def test_batch_12():
logger.info("test_batch_12")
# define parameters
batch_size = True
parameters = {"params": {'batch_size': batch_size}}
# apply dataset operations
data1 = ds.TFRecordDataset(DATA_DIR, shuffle=ds.Shuffle.FILES)
......@@ -257,7 +235,7 @@ def test_batch_12():
assert sum([1 for _ in data1]) == 12
filename = "batch_12_result.npz"
save_and_check(data1, parameters, filename, generate_golden=GENERATE_GOLDEN)
save_and_check_dict(data1, filename, generate_golden=GENERATE_GOLDEN)
def test_batch_exception_01():
......
......@@ -356,9 +356,13 @@ def test_clue_to_device():
if __name__ == "__main__":
test_clue()
test_clue_num_shards()
test_clue_num_samples()
test_textline_dataset_get_datasetsize()
test_clue_afqmc()
test_clue_cmnli()
test_clue_csl()
test_clue_iflytek()
test_clue_tnews()
test_clue_wsc()
test_clue_to_device()
......@@ -26,7 +26,7 @@ def generator_1d():
yield (np.array([i]),)
def test_case_0():
def test_generator_0():
"""
Test 1D Generator
"""
......@@ -48,7 +48,7 @@ def generator_md():
yield (np.array([[i, i + 1], [i + 2, i + 3]]),)
def test_case_1():
def test_generator_1():
"""
Test MD Generator
"""
......@@ -70,7 +70,7 @@ def generator_mc(maxid=64):
yield (np.array([i]), np.array([[i, i + 1], [i + 2, i + 3]]))
def test_case_2():
def test_generator_2():
"""
Test multi column generator
"""
......@@ -88,7 +88,7 @@ def test_case_2():
i = i + 1
def test_case_3():
def test_generator_3():
"""
Test 1D Generator + repeat(4)
"""
......@@ -108,7 +108,7 @@ def test_case_3():
i = 0
def test_case_4():
def test_generator_4():
"""
Test fixed size 1D Generator + batch
"""
......@@ -146,7 +146,7 @@ def type_tester(t):
i = i + 4
def test_case_5():
def test_generator_5():
"""
Test 1D Generator on different data type
"""
......@@ -173,7 +173,7 @@ def type_tester_with_type_check(t, c):
i = i + 4
def test_case_6():
def test_generator_6():
"""
Test 1D Generator on different data type with type check
"""
......@@ -208,7 +208,7 @@ def type_tester_with_type_check_2c(t, c):
i = i + 4
def test_case_7():
def test_generator_7():
"""
Test 2 column Generator on different data type with type check
"""
......@@ -223,7 +223,7 @@ def test_case_7():
type_tester_with_type_check_2c(np_types[i], [None, de_types[i]])
def test_case_8():
def test_generator_8():
"""
Test multi column generator with few mapops
"""
......@@ -249,7 +249,7 @@ def test_case_8():
i = i + 1
def test_case_9():
def test_generator_9():
"""
Test map column order when len(input_columns) == len(output_columns).
"""
......@@ -280,7 +280,7 @@ def test_case_9():
i = i + 1
def test_case_10():
def test_generator_10():
"""
Test map column order when len(input_columns) != len(output_columns).
"""
......@@ -303,7 +303,7 @@ def test_case_10():
i = i + 1
def test_case_11():
def test_generator_11():
"""
Test map column order when len(input_columns) != len(output_columns).
"""
......@@ -327,7 +327,7 @@ def test_case_11():
i = i + 1
def test_case_12():
def test_generator_12():
"""
Test map column order when input_columns and output_columns are None.
"""
......@@ -361,7 +361,7 @@ def test_case_12():
i = i + 1
def test_case_13():
def test_generator_13():
"""
Test map column order when input_columns is None.
"""
......@@ -391,7 +391,7 @@ def test_case_13():
i = i + 1
def test_case_14():
def test_generator_14():
"""
Test 1D Generator MP + CPP sampler
"""
......@@ -408,7 +408,7 @@ def test_case_14():
i = 0
def test_case_15():
def test_generator_15():
"""
Test 1D Generator MP + Python sampler
"""
......@@ -426,7 +426,7 @@ def test_case_15():
i = 0
def test_case_16():
def test_generator_16():
"""
Test multi column generator Mp + CPP sampler
"""
......@@ -445,7 +445,7 @@ def test_case_16():
i = i + 1
def test_case_17():
def test_generator_17():
"""
Test multi column generator Mp + Python sampler
"""
......@@ -465,7 +465,7 @@ def test_case_17():
i = i + 1
def test_case_error_1():
def test_generator_error_1():
def generator_np():
for i in range(64):
yield (np.array([{i}]),)
......@@ -477,7 +477,7 @@ def test_case_error_1():
assert "Invalid data type" in str(info.value)
def test_case_error_2():
def test_generator_error_2():
def generator_np():
for i in range(64):
yield ({i},)
......@@ -489,7 +489,7 @@ def test_case_error_2():
assert "Generator should return a tuple of numpy arrays" in str(info.value)
def test_case_error_3():
def test_generator_error_3():
with pytest.raises(ValueError) as info:
# apply dataset operations
data1 = ds.GeneratorDataset(generator_mc(2048), ["label", "image"])
......@@ -501,7 +501,7 @@ def test_case_error_3():
assert "When (len(input_columns) != len(output_columns)), columns_order must be specified." in str(info.value)
def test_case_error_4():
def test_generator_error_4():
with pytest.raises(RuntimeError) as info:
# apply dataset operations
data1 = ds.GeneratorDataset(generator_mc(2048), ["label", "image"])
......@@ -513,7 +513,7 @@ def test_case_error_4():
assert "Unexpected error. Result of a tensorOp doesn't match output column names" in str(info.value)
def test_sequential_sampler():
def test_generator_sequential_sampler():
source = [(np.array([x]),) for x in range(64)]
ds1 = ds.GeneratorDataset(source, ["data"], sampler=ds.SequentialSampler())
i = 0
......@@ -523,14 +523,14 @@ def test_sequential_sampler():
i = i + 1
def test_random_sampler():
def test_generator_random_sampler():
source = [(np.array([x]),) for x in range(64)]
ds1 = ds.GeneratorDataset(source, ["data"], shuffle=True)
for _ in ds1.create_dict_iterator(): # each data is a dictionary
pass
def test_distributed_sampler():
def test_generator_distributed_sampler():
source = [(np.array([x]),) for x in range(64)]
for sid in range(8):
ds1 = ds.GeneratorDataset(source, ["data"], shuffle=False, num_shards=8, shard_id=sid)
......@@ -541,7 +541,7 @@ def test_distributed_sampler():
i = i + 8
def test_num_samples():
def test_generator_num_samples():
source = [(np.array([x]),) for x in range(64)]
num_samples = 32
ds1 = ds.GeneratorDataset(source, ["data"], sampler=ds.SequentialSampler(num_samples=num_samples))
......@@ -564,7 +564,7 @@ def test_num_samples():
assert count == num_samples
def test_num_samples_underflow():
def test_generator_num_samples_underflow():
source = [(np.array([x]),) for x in range(64)]
num_samples = 256
ds2 = ds.GeneratorDataset(source, ["data"], sampler=[i for i in range(64)], num_samples=num_samples)
......@@ -600,7 +600,7 @@ def type_tester_with_type_check_2c_schema(t, c):
i = i + 4
def test_schema():
def test_generator_schema():
"""
Test 2 column Generator on different data type with type check with schema input
"""
......@@ -615,9 +615,9 @@ def test_schema():
type_tester_with_type_check_2c_schema(np_types[i], [de_types[i], de_types[i]])
def manual_test_keyborad_interrupt():
def manual_test_generator_keyboard_interrupt():
"""
Test keyborad_interrupt
Test keyboard_interrupt
"""
logger.info("Test 1D Generator MP : 0 - 63")
......@@ -635,31 +635,31 @@ def manual_test_keyborad_interrupt():
if __name__ == "__main__":
test_case_0()
test_case_1()
test_case_2()
test_case_3()
test_case_4()
test_case_5()
test_case_6()
test_case_7()
test_case_8()
test_case_9()
test_case_10()
test_case_11()
test_case_12()
test_case_13()
test_case_14()
test_case_15()
test_case_16()
test_case_17()
test_case_error_1()
test_case_error_2()
test_case_error_3()
test_case_error_4()
test_sequential_sampler()
test_distributed_sampler()
test_random_sampler()
test_num_samples()
test_num_samples_underflow()
test_schema()
test_generator_0()
test_generator_1()
test_generator_2()
test_generator_3()
test_generator_4()
test_generator_5()
test_generator_6()
test_generator_7()
test_generator_8()
test_generator_9()
test_generator_10()
test_generator_11()
test_generator_12()
test_generator_13()
test_generator_14()
test_generator_15()
test_generator_16()
test_generator_17()
test_generator_error_1()
test_generator_error_2()
test_generator_error_3()
test_generator_error_4()
test_generator_sequential_sampler()
test_generator_distributed_sampler()
test_generator_random_sampler()
test_generator_num_samples()
test_generator_num_samples_underflow()
test_generator_schema()
......@@ -33,7 +33,7 @@ def check(project_columns):
assert all([np.array_equal(d1, d2) for d1, d2 in zip(data_actual, data_expected)])
def test_case_iterator():
def test_iterator_create_tuple():
"""
Test creating tuple iterator
"""
......@@ -95,7 +95,9 @@ class MyDict(dict):
def test_tree_copy():
# Testing copying the tree with a pyfunc that cannot be pickled
"""
Testing copying the tree with a pyfunc that cannot be pickled
"""
data = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=COLUMNS)
data1 = data.map(operations=[MyDict()])
......@@ -110,4 +112,6 @@ def test_tree_copy():
if __name__ == '__main__':
test_iterator_create_tuple()
test_iterator_weak_ref()
test_tree_copy()
......@@ -13,10 +13,9 @@
# limitations under the License.
# ==============================================================================
import numpy as np
from util import save_and_check
import mindspore.dataset as ds
from mindspore import log as logger
from util import save_and_check_dict
# Note: Number of rows in test.data dataset: 12
DATA_DIR = ["../data/dataset/testTFTestAllTypes/test.data"]
......@@ -31,7 +30,6 @@ def test_shuffle_01():
# define parameters
buffer_size = 5
seed = 1
parameters = {"params": {'buffer_size': buffer_size, "seed": seed}}
# apply dataset operations
data1 = ds.TFRecordDataset(DATA_DIR, shuffle=ds.Shuffle.FILES)
......@@ -39,7 +37,7 @@ def test_shuffle_01():
data1 = data1.shuffle(buffer_size=buffer_size)
filename = "shuffle_01_result.npz"
save_and_check(data1, parameters, filename, generate_golden=GENERATE_GOLDEN)
save_and_check_dict(data1, filename, generate_golden=GENERATE_GOLDEN)
def test_shuffle_02():
......@@ -50,7 +48,6 @@ def test_shuffle_02():
# define parameters
buffer_size = 12
seed = 1
parameters = {"params": {'buffer_size': buffer_size, "seed": seed}}
# apply dataset operations
data1 = ds.TFRecordDataset(DATA_DIR, shuffle=ds.Shuffle.FILES)
......@@ -58,7 +55,7 @@ def test_shuffle_02():
data1 = data1.shuffle(buffer_size=buffer_size)
filename = "shuffle_02_result.npz"
save_and_check(data1, parameters, filename, generate_golden=GENERATE_GOLDEN)
save_and_check_dict(data1, filename, generate_golden=GENERATE_GOLDEN)
def test_shuffle_03():
......@@ -69,7 +66,6 @@ def test_shuffle_03():
# define parameters
buffer_size = 2
seed = 1
parameters = {"params": {'buffer_size': buffer_size, "seed": seed}}
# apply dataset operations
data1 = ds.TFRecordDataset(DATA_DIR, shuffle=ds.Shuffle.FILES)
......@@ -77,7 +73,7 @@ def test_shuffle_03():
data1 = data1.shuffle(buffer_size)
filename = "shuffle_03_result.npz"
save_and_check(data1, parameters, filename, generate_golden=GENERATE_GOLDEN)
save_and_check_dict(data1, filename, generate_golden=GENERATE_GOLDEN)
def test_shuffle_04():
......@@ -88,7 +84,6 @@ def test_shuffle_04():
# define parameters
buffer_size = 2
seed = 1
parameters = {"params": {'buffer_size': buffer_size, "seed": seed}}
# apply dataset operations
data1 = ds.TFRecordDataset(DATA_DIR, num_samples=2)
......@@ -96,7 +91,7 @@ def test_shuffle_04():
data1 = data1.shuffle(buffer_size=buffer_size)
filename = "shuffle_04_result.npz"
save_and_check(data1, parameters, filename, generate_golden=GENERATE_GOLDEN)
save_and_check_dict(data1, filename, generate_golden=GENERATE_GOLDEN)
def test_shuffle_05():
......@@ -107,7 +102,6 @@ def test_shuffle_05():
# define parameters
buffer_size = 13
seed = 1
parameters = {"params": {'buffer_size': buffer_size, "seed": seed}}
# apply dataset operations
data1 = ds.TFRecordDataset(DATA_DIR, shuffle=ds.Shuffle.FILES)
......@@ -115,7 +109,7 @@ def test_shuffle_05():
data1 = data1.shuffle(buffer_size=buffer_size)
filename = "shuffle_05_result.npz"
save_and_check(data1, parameters, filename, generate_golden=GENERATE_GOLDEN)
save_and_check_dict(data1, filename, generate_golden=GENERATE_GOLDEN)
def test_shuffle_06():
......
......@@ -24,9 +24,6 @@ import numpy as np
import mindspore.dataset as ds
from mindspore import log as logger
# These are the column names defined in the testTFTestAllTypes dataset
COLUMNS = ["col_1d", "col_2d", "col_3d", "col_binary", "col_float",
"col_sint16", "col_sint32", "col_sint64"]
# These are list of plot title in different visualize modes
PLOT_TITLE_DICT = {
1: ["Original image", "Transformed image"],
......@@ -82,39 +79,6 @@ def _save_json(filename, parameters, result_dict):
fout.write(jsbeautifier.beautify(json.dumps(out_dict), options))
def save_and_check(data, parameters, filename, generate_golden=False):
"""
Save the dataset dictionary and compare (as numpy array) with golden file.
Use create_dict_iterator to access the dataset.
Note: save_and_check() is deprecated; use save_and_check_dict().
"""
num_iter = 0
result_dict = {}
for column_name in COLUMNS:
result_dict[column_name] = []
for item in data.create_dict_iterator(): # each data is a dictionary
for data_key in list(item.keys()):
if data_key not in result_dict:
result_dict[data_key] = []
result_dict[data_key].append(item[data_key].tolist())
num_iter += 1
logger.info("Number of data in data1: {}".format(num_iter))
cur_dir = os.path.dirname(os.path.realpath(__file__))
golden_ref_dir = os.path.join(cur_dir, "../../data/dataset", 'golden', filename)
if generate_golden:
# Save as the golden result
_save_golden(cur_dir, golden_ref_dir, result_dict)
_compare_to_golden(golden_ref_dir, result_dict)
if SAVE_JSON:
# Save result to a json file for inspection
_save_json(filename, parameters, result_dict)
def save_and_check_dict(data, filename, generate_golden=False):
"""
Save the dataset dictionary and compare (as dictionary) with golden file.
......@@ -203,6 +167,29 @@ def save_and_check_tuple(data, parameters, filename, generate_golden=False):
_save_json(filename, parameters, result_dict)
def config_get_set_seed(seed_new):
"""
Get and return the original configuration seed value.
Set the new configuration seed value.
"""
seed_original = ds.config.get_seed()
ds.config.set_seed(seed_new)
logger.info("seed: original = {} new = {} ".format(seed_original, seed_new))
return seed_original
def config_get_set_num_parallel_workers(num_parallel_workers_new):
"""
Get and return the original configuration num_parallel_workers value.
Set the new configuration num_parallel_workers value.
"""
num_parallel_workers_original = ds.config.get_num_parallel_workers()
ds.config.set_num_parallel_workers(num_parallel_workers_new)
logger.info("num_parallel_workers: original = {} new = {} ".format(num_parallel_workers_original,
num_parallel_workers_new))
return num_parallel_workers_original
def diff_mse(in1, in2):
mse = (np.square(in1.astype(float) / 255 - in2.astype(float) / 255)).mean()
return mse * 100
......@@ -265,29 +252,6 @@ def visualize_image(image_original, image_de, mse=None, image_lib=None):
plt.show()
def config_get_set_seed(seed_new):
"""
Get and return the original configuration seed value.
Set the new configuration seed value.
"""
seed_original = ds.config.get_seed()
ds.config.set_seed(seed_new)
logger.info("seed: original = {} new = {} ".format(seed_original, seed_new))
return seed_original
def config_get_set_num_parallel_workers(num_parallel_workers_new):
"""
Get and return the original configuration num_parallel_workers value.
Set the new configuration num_parallel_workers value.
"""
num_parallel_workers_original = ds.config.get_num_parallel_workers()
ds.config.set_num_parallel_workers(num_parallel_workers_new)
logger.info("num_parallel_workers: original = {} new = {} ".format(num_parallel_workers_original,
num_parallel_workers_new))
return num_parallel_workers_original
def visualize_with_bounding_boxes(orig, aug, annot_name="annotation", plot_rows=3):
"""
Take a list of un-augmented and augmented images with "annotation" bounding boxes
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册