diff --git a/mindspore/dataset/core/configuration.py b/mindspore/dataset/core/configuration.py
index c08f47526e4acfb74f19d3729376010b3e4af7e1..38b25368b346435927c97f6452afac49ebd1e84e 100644
--- a/mindspore/dataset/core/configuration.py
+++ b/mindspore/dataset/core/configuration.py
@@ -16,6 +16,7 @@
 The configuration manager.
 """
 import random
+import numpy
 import mindspore._c_dataengine as cde
 
 INT32_MAX = 2147483647
@@ -33,10 +34,10 @@ class ConfigurationManager:
         Set the seed to be used in any random generator. This is used to produce deterministic results.
 
         Note:
-            This set_seed function sets the seed in the python random library function for deterministic
-            python augmentations using randomness. This set_seed function should be called with every
-            iterator created to reset the random seed. In our pipeline this does not guarantee
-            deterministic results with num_parallel_workers > 1.
+            This set_seed function sets the seed in the python random library and numpy.random library
+            for deterministic python augmentations using randomness. This set_seed function should
+            be called with every iterator created to reset the random seed. In our pipeline this
+            does not guarantee deterministic results with num_parallel_workers > 1.
 
         Args:
             seed(int): seed to be set
@@ -54,6 +55,8 @@ class ConfigurationManager:
             raise ValueError("Seed given is not within the required range")
         self.config.set_seed(seed)
         random.seed(seed)
+        # numpy.random isn't thread safe
+        numpy.random.seed(seed)
 
     def get_seed(self):
         """
diff --git a/tests/ut/data/dataset/golden/test_center_crop_01_result.npz b/tests/ut/data/dataset/golden/test_center_crop_01_result.npz
new file mode 100644
index 0000000000000000000000000000000000000000..2f8b2273df67b4d4ae7db8d3584eab79464b40c6
Binary files /dev/null and b/tests/ut/data/dataset/golden/test_center_crop_01_result.npz differ
diff --git a/tests/ut/python/dataset/test_center_crop.py b/tests/ut/python/dataset/test_center_crop.py
index 596c1e1c722292a9d3dd22018751b543fb487437..62d4ebb0429e66c4136e6ebb2aed433c78adf6b2 100644
--- a/tests/ut/python/dataset/test_center_crop.py
+++ b/tests/ut/python/dataset/test_center_crop.py
@@ -13,53 +13,96 @@
 # limitations under the License.
 # ==============================================================================
 import mindspore.dataset.transforms.vision.c_transforms as vision
+import mindspore.dataset.transforms.vision.py_transforms as py_vision
 import numpy as np
 import matplotlib.pyplot as plt
 import mindspore.dataset as ds
 from mindspore import log as logger
+from util import diff_mse, visualize, save_and_check_md5
+
+GENERATE_GOLDEN = False
 
 DATA_DIR = ["../data/dataset/test_tf_file_3_images/train-0000-of-0001.data"]
 SCHEMA_DIR = "../data/dataset/test_tf_file_3_images/datasetSchema.json"
 
 
-def visualize(image_original, image_cropped):
+def test_center_crop_op(height=375, width=375, plot=False):
     """
-    visualizes the image using DE op and Numpy op
+    Test CenterCrop op
     """
-    num = len(image_cropped)
-    for i in range(num):
-        plt.subplot(2, num, i + 1)
-        plt.imshow(image_original[i])
-        plt.title("Original image")
+    logger.info("Test CenterCrop")
+
+    # First dataset
+    data1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"])
+    decode_op = vision.Decode()
+    # 3 images [375, 500] [600, 500] [512, 512]
+    center_crop_op = vision.CenterCrop([height, width])
+    data1 = data1.map(input_columns=["image"], operations=decode_op)
+    data1 = data1.map(input_columns=["image"], operations=center_crop_op)
 
-        plt.subplot(2, num, i + num + 1)
-        plt.imshow(image_cropped[i])
-        plt.title("DE center_crop image")
+    # Second dataset
+    data2 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"])
+    data2 = data2.map(input_columns=["image"], operations=decode_op)
 
-    plt.show()
+    image_cropped = []
+    image = []
+    for item1, item2 in zip(data1.create_dict_iterator(), data2.create_dict_iterator()):
+        image_cropped.append(item1["image"].copy())
+        image.append(item2["image"].copy())
+    if plot:
+        visualize(image, image_cropped)
 
 
-def test_center_crop_op(height=375, width=375, plot=False):
+def test_center_crop_md5(height=375, width=375):
     """
     Test random_vertical
     """
     logger.info("Test CenterCrop")
 
     # First dataset
-    data1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"])
+    data1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], shuffle=False)
     decode_op = vision.Decode()
     # 3 images [375, 500] [600, 500] [512, 512]
-    center_crop_op = vision.CenterCrop(height, width)
+    center_crop_op = vision.CenterCrop([height, width])
+    data1 = data1.map(input_columns=["image"], operations=decode_op)
+    data1 = data1.map(input_columns=["image"], operations=center_crop_op)
+    # expected md5 from images
+
+    filename = "test_center_crop_01_result.npz"
+    parameters = {"params": {}}
+    save_and_check_md5(data1, parameters, filename, generate_golden=GENERATE_GOLDEN)
+
+
+def test_center_crop_comp(height=375, width=375, plot=False):
+    """
+    Test CenterCrop between Python and C image augmentation
+    """
+    logger.info("Test CenterCrop")
+
+    # First dataset
+    data1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], shuffle=False)
+    decode_op = vision.Decode()
+    center_crop_op = vision.CenterCrop([height, width])
     data1 = data1.map(input_columns=["image"], operations=decode_op)
     data1 = data1.map(input_columns=["image"], operations=center_crop_op)
 
     # Second dataset
-    data2 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"])
-    data2 = data2.map(input_columns=["image"], operations=decode_op)
+    data2 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], shuffle=False)
+    transforms = [
+        py_vision.Decode(),
+        py_vision.CenterCrop([height, width]),
+        py_vision.ToTensor()
+    ]
+    transform = py_vision.ComposeOp(transforms)
+    data2 = data2.map(input_columns=["image"], operations=transform())
 
     image_cropped = []
     image = []
     for item1, item2 in zip(data1.create_dict_iterator(), data2.create_dict_iterator()):
+        c_image = item1["image"]
+        py_image = (item2["image"].transpose(1, 2, 0) * 255).astype(np.uint8)
+        # the images aren't exactly the same due to rounding error
+        assert (diff_mse(py_image, c_image) < 0.001)
         image_cropped.append(item1["image"].copy())
         image.append(item2["image"].copy())
     if plot:
@@ -67,7 +110,8 @@ def test_center_crop_op(height=375, width=375, plot=False):
 
 
 if __name__ == "__main__":
-    test_center_crop_op()
     test_center_crop_op(600, 600)
     test_center_crop_op(300, 600)
     test_center_crop_op(600, 300)
+    test_center_crop_md5(600, 600)
+    test_center_crop_comp()
\ No newline at end of file
diff --git a/tests/ut/python/dataset/test_random_color_adjust.py b/tests/ut/python/dataset/test_random_color_adjust.py
index c3e7bd3d7cab326deb662982ad4dc9f9802e9862..8a650e2315543c730e3ad2fd4fe990f2f1428d80 100644
--- a/tests/ut/python/dataset/test_random_color_adjust.py
+++ b/tests/ut/python/dataset/test_random_color_adjust.py
@@ -256,9 +256,6 @@ def test_random_color_adjust_op_hue():
         mse = diff_mse(c_image, py_image)
         logger.info("mse is {}".format(mse))
         assert mse < 0.01
-        # logger.info("random_rotation_op_{}, mse: {}".format(num_iter + 1, mse))
-        # if mse != 0:
-        #     logger.info("mse is: {}".format(mse))
 
         # Uncomment below line if you want to visualize images
         # visualize(c_image, mse, py_image)
diff --git a/tests/ut/python/dataset/util.py b/tests/ut/python/dataset/util.py
index d2f35cee630b40893ca5ec81a9536df35f9cd567..753528d3f872a35d5f67b33e570a8b0a54b254b1 100644
--- a/tests/ut/python/dataset/util.py
+++ b/tests/ut/python/dataset/util.py
@@ -16,6 +16,9 @@
 import json
 import os
 import numpy as np
+import matplotlib.pyplot as plt
+import hashlib
+
 
 #import jsbeautifier
 from mindspore import log as logger
@@ -41,6 +44,15 @@ def save_golden_dict(cur_dir, golden_ref_dir, result_dict):
     np.savez(golden_ref_dir, np.array(list(result_dict.items())))
 
 
+def save_golden_md5(cur_dir, golden_ref_dir, result_dict):
+    """
+    Save the dictionary (both keys and values) as the golden result in .npz file
+    """
+    logger.info("cur_dir is {}".format(cur_dir))
+    logger.info("golden_ref_dir is {}".format(golden_ref_dir))
+    np.savez(golden_ref_dir, np.array(list(result_dict.items())))
+
+
 def compare_to_golden(golden_ref_dir, result_dict):
     """
     Compare as numpy arrays the test result to the golden result
@@ -55,7 +67,8 @@ def compare_to_golden_dict(golden_ref_dir, result_dict):
     Compare as dictionaries the test result to the golden result
     """
     golden_array = np.load(golden_ref_dir, allow_pickle=True)['arr_0']
-    assert result_dict == dict(golden_array)
+    np.testing.assert_equal(result_dict, dict(golden_array))
+    # assert result_dict == dict(golden_array)
 
 
 def save_json(filename, parameters, result_dict):
@@ -131,6 +144,33 @@ def save_and_check_dict(data, parameters, filename, generate_golden=False):
     # save_json(filename, parameters, result_dict)
 
 
+def save_and_check_md5(data, parameters, filename, generate_golden=False):
+    """
+    Save the dataset dictionary and compare (as dictionary) with golden file (md5).
+    Use create_dict_iterator to access the dataset.
+ """ + num_iter = 0 + result_dict = {} + + for item in data.create_dict_iterator(): # each data is a dictionary + for data_key in list(item.keys()): + if data_key not in result_dict: + result_dict[data_key] = [] + # save the md5 as numpy array + result_dict[data_key].append(np.frombuffer(hashlib.md5(item[data_key]).digest(), dtype='