# test_random_solarize_op.py
# Copyright 2019 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""
Testing RandomSolarizeOp op in DE
"""
import pytest
import mindspore.dataset as ds
import mindspore.dataset.engine as de
import mindspore.dataset.transforms.vision.c_transforms as vision
from mindspore import log as logger
from util import visualize_list, save_and_check_md5, config_get_set_seed, config_get_set_num_parallel_workers, \
    visualize_one_channel_dataset

# Passed to save_and_check_md5(generate_golden=...) by the tests in this file;
# presumably set to True only when the golden .npz files need regenerating.
GENERATE_GOLDEN = False

# Relative dataset paths; presumably resolved against the directory the tests
# are launched from (TODO confirm against the test runner setup).
MNIST_DATA_DIR = "../data/dataset/testMnistData"

# TFRecord file and schema used by test_random_solarize_op.
DATA_DIR = ["../data/dataset/test_tf_file_3_images/train-0000-of-0001.data"]
SCHEMA_DIR = "../data/dataset/test_tf_file_3_images/datasetSchema.json"


def test_random_solarize_op(threshold=(10, 150), plot=False, run_golden=True):
    """
    Test RandomSolarize on the TFRecord test images.

    Two pipelines are built over the same TFRecord file: one decodes and then
    applies RandomSolarize, the other only decodes. Their outputs are paired so
    the solarized images can be compared against the originals.

    Args:
        threshold (tuple, optional): (min, max) threshold passed to
            RandomSolarize. If None, RandomSolarize is built with no arguments
            (its own default).
        plot (bool): If True, show original and solarized images via visualize_list.
        run_golden (bool): If True, pass the solarized pipeline to
            save_and_check_md5 with "random_solarize_01_result.npz".
    """
    logger.info("Test RandomSolarize")

    # First dataset
    data1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], shuffle=False)
    decode_op = vision.Decode()

    # Pin the global seed and worker count (presumably so the random threshold
    # chosen by RandomSolarize is reproducible for the golden check). The
    # originals are restored in the finally block so a failing check cannot
    # leak this configuration into other tests.
    original_seed = config_get_set_seed(0)
    original_num_parallel_workers = config_get_set_num_parallel_workers(1)

    try:
        if threshold is None:
            solarize_op = vision.RandomSolarize()
        else:
            solarize_op = vision.RandomSolarize(threshold)

        data1 = data1.map(input_columns=["image"], operations=decode_op)
        data1 = data1.map(input_columns=["image"], operations=solarize_op)

        # Second dataset: decode only, used as the reference images
        data2 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], shuffle=False)
        data2 = data2.map(input_columns=["image"], operations=decode_op)

        if run_golden:
            filename = "random_solarize_01_result.npz"
            save_and_check_md5(data1, filename, generate_golden=GENERATE_GOLDEN)

        image_solarized = []
        image = []
        iter_solarized = data1.create_dict_iterator(num_epochs=1)
        iter_original = data2.create_dict_iterator(num_epochs=1)
        for item1, item2 in zip(iter_solarized, iter_original):
            # Solarizing only remaps pixel values; the image shape must not change.
            assert item1["image"].shape == item2["image"].shape
            image_solarized.append(item1["image"].copy())
            image.append(item2["image"].copy())
        if plot:
            visualize_list(image, image_solarized)
    finally:
        ds.config.set_seed(original_seed)
        ds.config.set_num_parallel_workers(original_num_parallel_workers)


def test_random_solarize_mnist(plot=False, run_golden=True):
    """
    Test RandomSolarize op with MNIST dataset (Grayscale images)

    Two identical 2-sample MNIST pipelines are built; only the second applies
    RandomSolarize((0, 255)) to the "image" column.

    Args:
        plot (bool): If True, show original and solarized images with their
            labels via visualize_one_channel_dataset.
        run_golden (bool): If True, pass the solarized pipeline to
            save_and_check_md5 with "random_solarize_02_result.npz".
    """
    mnist_1 = de.MnistDataset(dataset_dir=MNIST_DATA_DIR, num_samples=2, shuffle=False)
    mnist_2 = de.MnistDataset(dataset_dir=MNIST_DATA_DIR, num_samples=2, shuffle=False)
    mnist_2 = mnist_2.map(input_columns="image", operations=vision.RandomSolarize((0, 255)))

    images = []
    images_trans = []
    labels = []

    # Explicit single-epoch iterators (matching test_random_solarize_op) instead
    # of iterating the datasets directly; each row is an (image, label) pair.
    iter_orig = mnist_1.create_tuple_iterator(num_epochs=1)
    iter_trans = mnist_2.create_tuple_iterator(num_epochs=1)
    for (image_orig, label_orig), (image_trans, _) in zip(iter_orig, iter_trans):
        # Solarizing only remaps pixel values; the image shape must not change.
        assert image_trans.shape == image_orig.shape
        images.append(image_orig)
        labels.append(label_orig)
        images_trans.append(image_trans)

    if plot:
        visualize_one_channel_dataset(images, images_trans, labels)

    # NOTE(review): unlike test_random_solarize_op, no seed is pinned here even
    # though RandomSolarize((0, 255)) picks a random threshold; confirm the golden
    # md5 comparison is reproducible under the default seed configuration.
    if run_golden:
        filename = "random_solarize_02_result.npz"
        save_and_check_md5(mnist_2, filename, generate_golden=GENERATE_GOLDEN)


def test_random_solarize_errors():
    """
    Check that RandomSolarize rejects malformed thresholds.

    Each case constructs the op with a bad threshold and verifies both the
    exception type and a fragment of its message.
    """
    # (expected exception, threshold argument, expected message fragment)
    bad_thresholds = [
        (ValueError, (12, 1), "threshold must be in min max format numbers"),
        (ValueError, (12, 1000), "Input is not within the required interval of (0 to 255)."),
        (TypeError, (122.1, 140),
         "Argument threshold[0] with value 122.1 is not of type (<class 'int'>,)."),
        (ValueError, (122, 100, 30), "threshold must be a sequence of two numbers"),
        (ValueError, (120,), "threshold must be a sequence of two numbers"),
    ]

    for expected_exc, bad_threshold, fragment in bad_thresholds:
        with pytest.raises(expected_exc) as raised:
            vision.RandomSolarize(bad_threshold)
        assert fragment in str(raised.value)


if __name__ == "__main__":
    # Manual run: plots are enabled so the results can be inspected visually.
    test_random_solarize_op((10, 150), plot=True, run_golden=True)
    test_random_solarize_op((12, 120), plot=True, run_golden=False)
    # threshold=None exercises RandomSolarize's own default threshold; omitting
    # the argument would silently reuse (10, 150) from the function signature.
    test_random_solarize_op(threshold=None, plot=True, run_golden=False)
    test_random_solarize_mnist(plot=True, run_golden=True)
    test_random_solarize_errors()