提交 772e6c14 编写于 作者: C Cathy Wong

Cleanup dataset UT: test_batch, save_and_check support

上级 651e9081
......@@ -37,6 +37,7 @@ def test_batch_01():
data1 = ds.TFRecordDataset(DATA_DIR, shuffle=ds.Shuffle.FILES)
data1 = data1.batch(batch_size, drop_remainder)
assert sum([1 for _ in data1]) == 6
filename = "batch_01_result.npz"
save_and_check(data1, parameters, filename, generate_golden=GENERATE_GOLDEN)
......@@ -56,6 +57,7 @@ def test_batch_02():
data1 = ds.TFRecordDataset(DATA_DIR, shuffle=ds.Shuffle.FILES)
data1 = data1.batch(batch_size, drop_remainder=drop_remainder)
assert sum([1 for _ in data1]) == 2
filename = "batch_02_result.npz"
save_and_check(data1, parameters, filename, generate_golden=GENERATE_GOLDEN)
......@@ -75,6 +77,7 @@ def test_batch_03():
data1 = ds.TFRecordDataset(DATA_DIR, shuffle=ds.Shuffle.FILES)
data1 = data1.batch(batch_size=batch_size, drop_remainder=drop_remainder)
assert sum([1 for _ in data1]) == 4
filename = "batch_03_result.npz"
save_and_check(data1, parameters, filename, generate_golden=GENERATE_GOLDEN)
......@@ -94,6 +97,7 @@ def test_batch_04():
data1 = ds.TFRecordDataset(DATA_DIR, shuffle=ds.Shuffle.FILES)
data1 = data1.batch(batch_size, drop_remainder)
assert sum([1 for _ in data1]) == 2
filename = "batch_04_result.npz"
save_and_check(data1, parameters, filename, generate_golden=GENERATE_GOLDEN)
......@@ -111,6 +115,7 @@ def test_batch_05():
data1 = ds.TFRecordDataset(DATA_DIR, shuffle=ds.Shuffle.FILES)
data1 = data1.batch(batch_size)
assert sum([1 for _ in data1]) == 12
filename = "batch_05_result.npz"
save_and_check(data1, parameters, filename, generate_golden=GENERATE_GOLDEN)
......@@ -130,6 +135,7 @@ def test_batch_06():
data1 = ds.TFRecordDataset(DATA_DIR, shuffle=ds.Shuffle.FILES)
data1 = data1.batch(drop_remainder=drop_remainder, batch_size=batch_size)
assert sum([1 for _ in data1]) == 1
filename = "batch_06_result.npz"
save_and_check(data1, parameters, filename, generate_golden=GENERATE_GOLDEN)
......@@ -152,6 +158,7 @@ def test_batch_07():
data1 = data1.batch(num_parallel_workers=num_parallel_workers, drop_remainder=drop_remainder,
batch_size=batch_size)
assert sum([1 for _ in data1]) == 3
filename = "batch_07_result.npz"
save_and_check(data1, parameters, filename, generate_golden=GENERATE_GOLDEN)
......@@ -171,6 +178,7 @@ def test_batch_08():
data1 = ds.TFRecordDataset(DATA_DIR, shuffle=ds.Shuffle.FILES)
data1 = data1.batch(batch_size, num_parallel_workers=num_parallel_workers)
assert sum([1 for _ in data1]) == 2
filename = "batch_08_result.npz"
save_and_check(data1, parameters, filename, generate_golden=GENERATE_GOLDEN)
......@@ -190,6 +198,7 @@ def test_batch_09():
data1 = ds.TFRecordDataset(DATA_DIR, shuffle=ds.Shuffle.FILES)
data1 = data1.batch(batch_size, drop_remainder=drop_remainder)
assert sum([1 for _ in data1]) == 1
filename = "batch_09_result.npz"
save_and_check(data1, parameters, filename, generate_golden=GENERATE_GOLDEN)
......@@ -209,6 +218,7 @@ def test_batch_10():
data1 = ds.TFRecordDataset(DATA_DIR, shuffle=ds.Shuffle.FILES)
data1 = data1.batch(batch_size, drop_remainder=drop_remainder)
assert sum([1 for _ in data1]) == 0
filename = "batch_10_result.npz"
save_and_check(data1, parameters, filename, generate_golden=GENERATE_GOLDEN)
......@@ -228,10 +238,30 @@ def test_batch_11():
data1 = ds.TFRecordDataset(DATA_DIR, schema_file)
data1 = data1.batch(batch_size)
assert sum([1 for _ in data1]) == 1
filename = "batch_11_result.npz"
save_and_check(data1, parameters, filename, generate_golden=GENERATE_GOLDEN)
def test_batch_12():
"""
Test batch: batch_size boolean value True, treated as valid value 1
"""
logger.info("test_batch_12")
# define parameters
batch_size = True
parameters = {"params": {'batch_size': batch_size}}
# apply dataset operations
data1 = ds.TFRecordDataset(DATA_DIR, shuffle=ds.Shuffle.FILES)
data1 = data1.batch(batch_size=batch_size)
assert sum([1 for _ in data1]) == 12
filename = "batch_12_result.npz"
save_and_check(data1, parameters, filename, generate_golden=GENERATE_GOLDEN)
def test_batch_exception_01():
"""
Test batch exception: num_parallel_workers=0
......@@ -302,7 +332,7 @@ def test_batch_exception_04():
def test_batch_exception_05():
"""
Test batch exception: batch_size wrong type, boolean value False
Test batch exception: batch_size boolean value False, treated as invalid value 0
"""
logger.info("test_batch_exception_05")
......@@ -317,23 +347,6 @@ def test_batch_exception_05():
assert "batch_size" in str(e)
def skip_test_batch_exception_06():
"""
Test batch exception: batch_size wrong type, boolean value True
"""
logger.info("test_batch_exception_06")
# apply dataset operations
data1 = ds.TFRecordDataset(DATA_DIR, shuffle=ds.Shuffle.FILES)
try:
data1 = data1.batch(batch_size=True)
sum([1 for _ in data1])
except BaseException as e:
logger.info("Got an exception in DE: {}".format(str(e)))
assert "batch_size" in str(e)
def test_batch_exception_07():
"""
Test batch exception: drop_remainder wrong type
......@@ -473,12 +486,12 @@ if __name__ == '__main__':
test_batch_09()
test_batch_10()
test_batch_11()
test_batch_12()
test_batch_exception_01()
test_batch_exception_02()
test_batch_exception_03()
test_batch_exception_04()
test_batch_exception_05()
skip_test_batch_exception_06()
test_batch_exception_07()
test_batch_exception_08()
test_batch_exception_09()
......
......@@ -28,7 +28,7 @@ SCHEMA_DIR = "../data/dataset/test_tf_file_3_images/datasetSchema.json"
def test_center_crop_op(height=375, width=375, plot=False):
"""
Test random_vertical
Test CenterCrop
"""
logger.info("Test CenterCrop")
......@@ -55,7 +55,7 @@ def test_center_crop_op(height=375, width=375, plot=False):
def test_center_crop_md5(height=375, width=375):
"""
Test random_vertical
Test CenterCrop
"""
logger.info("Test CenterCrop")
......@@ -69,13 +69,12 @@ def test_center_crop_md5(height=375, width=375):
# expected md5 from images
filename = "test_center_crop_01_result.npz"
parameters = {"params": {}}
save_and_check_md5(data1, parameters, filename, generate_golden=GENERATE_GOLDEN)
save_and_check_md5(data1, filename, generate_golden=GENERATE_GOLDEN)
def test_center_crop_comp(height=375, width=375, plot=False):
"""
Test random_vertical between python and c image augmentation
Test CenterCrop between python and c image augmentation
"""
logger.info("Test CenterCrop")
......@@ -114,4 +113,4 @@ if __name__ == "__main__":
test_center_crop_op(300, 600)
test_center_crop_op(600, 300)
test_center_crop_md5(600, 600)
test_center_crop_comp()
\ No newline at end of file
test_center_crop_comp()
......@@ -18,9 +18,7 @@ Testing Decode op in DE
import cv2
import mindspore.dataset.transforms.vision.c_transforms as vision
import numpy as np
import mindspore.dataset as ds
import mindspore.dataset.transforms.vision.c_transforms as vision
from mindspore import log as logger
DATA_DIR = ["../data/dataset/test_tf_file_3_images/train-0000-of-0001.data"]
......
......@@ -13,7 +13,7 @@
# limitations under the License.
# ==============================================================================
"""
Testing RandomRotation op in DE
Testing RandomColorAdjust op in DE
"""
import matplotlib.pyplot as plt
import mindspore.dataset.transforms.vision.c_transforms as c_vision
......
......@@ -13,7 +13,7 @@
# limitations under the License.
# ==============================================================================
"""
Testing RandomRotation op in DE
Testing RandomErasing op in DE
"""
import matplotlib.pyplot as plt
import numpy as np
......
......@@ -17,9 +17,7 @@ Testing the resize op in DE
"""
import matplotlib.pyplot as plt
import mindspore.dataset.transforms.vision.c_transforms as vision
import mindspore.dataset as ds
import mindspore.dataset.transforms.vision.c_transforms as vision
from mindspore import log as logger
DATA_DIR = ["../data/dataset/test_tf_file_3_images/train-0000-of-0001.data"]
......
......@@ -13,7 +13,7 @@
# limitations under the License.
# ==============================================================================
"""
Testing RandomRotation op in DE
Testing TypeCast op in DE
"""
import mindspore.dataset.transforms.vision.c_transforms as c_vision
import mindspore.dataset.transforms.vision.py_transforms as py_vision
......@@ -31,9 +31,9 @@ SCHEMA_DIR = "../data/dataset/test_tf_file_3_images/datasetSchema.json"
def test_type_cast():
"""
Test type_cast_op
Test TypeCast op
"""
logger.info("test_type_cast_op")
logger.info("test_type_cast")
# First dataset
data1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], shuffle=False)
......@@ -71,9 +71,9 @@ def test_type_cast():
def test_type_cast_string():
"""
Test type_cast_op
Test TypeCast op
"""
logger.info("test_type_cast_op")
logger.info("test_type_cast_string")
# First dataset
data1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], shuffle=False)
......
......@@ -44,8 +44,7 @@ def test_zip_01():
dataz = ds.zip((data1, data2))
# Note: zipped dataset has 5 rows and 7 columns
filename = "zip_01_result.npz"
parameters = {"params": {}}
save_and_check_dict(dataz, parameters, filename, generate_golden=GENERATE_GOLDEN)
save_and_check_dict(dataz, filename, generate_golden=GENERATE_GOLDEN)
def test_zip_02():
......@@ -59,8 +58,7 @@ def test_zip_02():
dataz = ds.zip((data1, data2))
# Note: zipped dataset has 3 rows and 4 columns
filename = "zip_02_result.npz"
parameters = {"params": {}}
save_and_check_dict(dataz, parameters, filename, generate_golden=GENERATE_GOLDEN)
save_and_check_dict(dataz, filename, generate_golden=GENERATE_GOLDEN)
def test_zip_03():
......@@ -74,8 +72,7 @@ def test_zip_03():
dataz = ds.zip((data1, data2))
# Note: zipped dataset has 3 rows and 7 columns
filename = "zip_03_result.npz"
parameters = {"params": {}}
save_and_check_dict(dataz, parameters, filename, generate_golden=GENERATE_GOLDEN)
save_and_check_dict(dataz, filename, generate_golden=GENERATE_GOLDEN)
def test_zip_04():
......@@ -90,8 +87,7 @@ def test_zip_04():
dataz = ds.zip((data1, data2, data3))
# Note: zipped dataset has 3 rows and 9 columns
filename = "zip_04_result.npz"
parameters = {"params": {}}
save_and_check_dict(dataz, parameters, filename, generate_golden=GENERATE_GOLDEN)
save_and_check_dict(dataz, filename, generate_golden=GENERATE_GOLDEN)
def test_zip_05():
......@@ -109,8 +105,7 @@ def test_zip_05():
dataz = ds.zip((data1, data2))
# Note: zipped dataset has 5 rows and 9 columns
filename = "zip_05_result.npz"
parameters = {"params": {}}
save_and_check_dict(dataz, parameters, filename, generate_golden=GENERATE_GOLDEN)
save_and_check_dict(dataz, filename, generate_golden=GENERATE_GOLDEN)
def test_zip_06():
......@@ -129,8 +124,7 @@ def test_zip_06():
dataz = dataz.repeat(2)
# Note: resultant dataset has 10 rows and 9 columns
filename = "zip_06_result.npz"
parameters = {"params": {}}
save_and_check_dict(dataz, parameters, filename, generate_golden=GENERATE_GOLDEN)
save_and_check_dict(dataz, filename, generate_golden=GENERATE_GOLDEN)
def test_zip_exception_01():
......
......@@ -15,15 +15,15 @@
import json
import os
import hashlib
import numpy as np
import matplotlib.pyplot as plt
import hashlib
#import jsbeautifier
from mindspore import log as logger
COLUMNS = ["col_1d", "col_2d", "col_3d", "col_binary", "col_float",
"col_sint16", "col_sint32", "col_sint64"]
SAVE_JSON = False
def save_golden(cur_dir, golden_ref_dir, result_dict):
......@@ -44,15 +44,6 @@ def save_golden_dict(cur_dir, golden_ref_dir, result_dict):
np.savez(golden_ref_dir, np.array(list(result_dict.items())))
def save_golden_md5(cur_dir, golden_ref_dir, result_dict):
"""
Save the dictionary (both keys and values) as the golden result in .npz file
"""
logger.info("cur_dir is {}".format(cur_dir))
logger.info("golden_ref_dir is {}".format(golden_ref_dir))
np.savez(golden_ref_dir, np.array(list(result_dict.items())))
def compare_to_golden(golden_ref_dir, result_dict):
"""
Compare as numpy arrays the test result to the golden result
......@@ -67,7 +58,7 @@ def compare_to_golden_dict(golden_ref_dir, result_dict):
Compare as dictionaries the test result to the golden result
"""
golden_array = np.load(golden_ref_dir, allow_pickle=True)['arr_0']
np.testing.assert_equal (result_dict, dict(golden_array))
np.testing.assert_equal(result_dict, dict(golden_array))
# assert result_dict == dict(golden_array)
......@@ -83,7 +74,6 @@ def save_json(filename, parameters, result_dict):
fout.write(jsbeautifier.beautify(json.dumps(out_dict), options))
def save_and_check(data, parameters, filename, generate_golden=False):
"""
Save the dataset dictionary and compare (as numpy array) with golden file.
......@@ -111,11 +101,12 @@ def save_and_check(data, parameters, filename, generate_golden=False):
compare_to_golden(golden_ref_dir, result_dict)
# Save to a json file for inspection
# save_json(filename, parameters, result_dict)
if SAVE_JSON:
# Save result to a json file for inspection
save_json(filename, parameters, result_dict)
def save_and_check_dict(data, parameters, filename, generate_golden=False):
def save_and_check_dict(data, filename, generate_golden=False):
"""
Save the dataset dictionary and compare (as dictionary) with golden file.
Use create_dict_iterator to access the dataset.
......@@ -140,11 +131,13 @@ def save_and_check_dict(data, parameters, filename, generate_golden=False):
compare_to_golden_dict(golden_ref_dir, result_dict)
# Save to a json file for inspection
# save_json(filename, parameters, result_dict)
if SAVE_JSON:
# Save result to a json file for inspection
parameters = {"params": {}}
save_json(filename, parameters, result_dict)
def save_and_check_md5(data, parameters, filename, generate_golden=False):
def save_and_check_md5(data, filename, generate_golden=False):
"""
Save the dataset dictionary and compare (as dictionary) with golden file (md5).
Use create_dict_iterator to access the dataset.
......@@ -197,8 +190,9 @@ def ordered_save_and_check(data, parameters, filename, generate_golden=False):
compare_to_golden(golden_ref_dir, result_dict)
# Save to a json file for inspection
# save_json(filename, parameters, result_dict)
if SAVE_JSON:
# Save result to a json file for inspection
save_json(filename, parameters, result_dict)
def diff_mse(in1, in2):
......@@ -211,24 +205,18 @@ def diff_me(in1, in2):
return mse / 255 * 100
def diff_ssim(in1, in2):
from skimage.measure import compare_ssim as ssim
val = ssim(in1, in2, multichannel=True)
return (1 - val) * 100
def visualize(image_original, image_transformed):
"""
visualizes the image using DE op and Numpy op
"""
num = len(image_cropped)
num = len(image_transformed)
for i in range(num):
plt.subplot(2, num, i + 1)
plt.imshow(image_original[i])
plt.title("Original image")
plt.subplot(2, num, i + num + 1)
plt.imshow(image_cropped[i])
plt.imshow(image_transformed[i])
plt.title("Transformed image")
plt.show()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册