提交 dfc09701 编写于 作者: M Mahdi

Added TenCrop test

Added unit tests for both testing the functinality of the TenCrop and
its error messages. Due to the similarity of this method to FiveCrop the
test cases are similar to FiveCrop test cases.
Signed-off-by: NMahdi <mahdi.rahmani.hanzaki@huawei.com>

added error_msg function call in the main method

refactored the test and added visual representation of the results

Separated the two error cases into two different functions and used the
visualize function in util.py to plot the result of TenCrop.
Signed-off-by: NMahdi <mahdi.rahmani.hanzaki@huawei.com>

Added new test cases

Added new test cases including test case for checking the error message
when the size variable is not a positive integer, test case for
rectangle crop, test case for vertical flip setting, and testing for
similarity of the result of TenCrop for the same input data in different runs.
Signed-off-by: NMahdi <mahdi.rahmani.hanzaki@huawei.com>

changed visualize in test_five_crop

Changed the visualize function in test_five_crop to use the already
existing function in util.py
Signed-off-by: NMahdi <mahdi.rahmani.hanzaki@huawei.com>

made generate_golden variable global
上级 4ce1cf45
......@@ -21,29 +21,13 @@ import pytest
import mindspore.dataset as ds
import mindspore.dataset.transforms.vision.py_transforms as vision
from mindspore import log as logger
from util import visualize
DATA_DIR = ["../data/dataset/test_tf_file_3_images/train-0000-of-0001.data"]
SCHEMA_DIR = "../data/dataset/test_tf_file_3_images/datasetSchema.json"
def visualize(image_1, image_2):
"""
visualizes the image using FiveCrop
"""
plt.subplot(161)
plt.imshow(image_1)
plt.title("Original")
for i, image in enumerate(image_2):
image = (image.transpose(1, 2, 0) * 255).astype(np.uint8)
plt.subplot(162 + i)
plt.imshow(image)
plt.title("image {} in FiveCrop".format(i + 1))
plt.show()
def test_five_crop_op():
def test_five_crop_op(plot=False):
"""
Test FiveCrop
"""
......@@ -79,8 +63,8 @@ def test_five_crop_op():
logger.info("dtype of image_1: {}".format(image_1.dtype))
logger.info("dtype of image_2: {}".format(image_2.dtype))
# visualize(image_1, image_2)
if plot:
visualize(np.array([image_1]*10), (image_2 * 255).astype(np.uint8).transpose(0, 2, 3, 1))
# The output data should be of a 4D tensor shape, a stack of 5 images.
assert len(image_2.shape) == 4
......@@ -111,5 +95,5 @@ def test_five_crop_error_msg():
if __name__ == "__main__":
test_five_crop_op()
test_five_crop_op(plot=True)
test_five_crop_error_msg()
# Copyright 2020 Huawei Technologies Co., Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Testing TenCrop in DE
"""
import pytest
import numpy as np
import mindspore.dataset as ds
import mindspore.dataset.transforms.vision.py_transforms as vision
from util import visualize, save_and_check_md5
from mindspore import log as logger
GENERATE_GOLDEN = False
DATA_DIR = ["../data/dataset/test_tf_file_3_images/train-0000-of-0001.data"]
SCHEMA_DIR = "../data/dataset/test_tf_file_3_images/datasetSchema.json"
def util_test_ten_crop(crop_size, vertical_flip=False, plot=False):
"""
Utility function for testing TenCrop. Input arguments are given by other tests
"""
data1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], shuffle=False)
transforms_1 = [
vision.Decode(),
vision.ToTensor(),
]
transform_1 = vision.ComposeOp(transforms_1)
data1 = data1.map(input_columns=["image"], operations=transform_1())
# Second dataset
data2 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], shuffle=False)
transforms_2 = [
vision.Decode(),
vision.TenCrop(crop_size, use_vertical_flip=vertical_flip),
lambda images: np.stack([vision.ToTensor()(image) for image in images]) # 4D stack of 10 images
]
transform_2 = vision.ComposeOp(transforms_2)
data2 = data2.map(input_columns=["image"], operations=transform_2())
num_iter = 0
for item1, item2 in zip(data1.create_dict_iterator(), data2.create_dict_iterator()):
num_iter += 1
image_1 = (item1["image"].transpose(1, 2, 0) * 255).astype(np.uint8)
image_2 = item2["image"]
logger.info("shape of image_1: {}".format(image_1.shape))
logger.info("shape of image_2: {}".format(image_2.shape))
logger.info("dtype of image_1: {}".format(image_1.dtype))
logger.info("dtype of image_2: {}".format(image_2.dtype))
if plot:
visualize(np.array([image_1]*10), (image_2 * 255).astype(np.uint8).transpose(0, 2, 3, 1))
# The output data should be of a 4D tensor shape, a stack of 10 images.
assert len(image_2.shape) == 4
assert image_2.shape[0] == 10
def test_ten_crop_op_square(plot=False):
"""
Tests TenCrop for a square crop
"""
logger.info("test_ten_crop_op_square")
util_test_ten_crop(200, plot=plot)
def test_ten_crop_op_rectangle(plot=False):
"""
Tests TenCrop for a rectangle crop
"""
logger.info("test_ten_crop_op_rectangle")
util_test_ten_crop((200, 150), plot=plot)
def test_ten_crop_op_vertical_flip(plot=False):
"""
Tests TenCrop with vertical flip set to True
"""
logger.info("test_ten_crop_op_vertical_flip")
util_test_ten_crop(200, vertical_flip=True, plot=plot)
def test_ten_crop_md5():
"""
Tests TenCrops for giving the same results in multiple runs.
Since TenCrop is a deterministic function, we expect it to return the same result for a specific input every time
"""
logger.info("test_ten_crop_md5")
data2 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], shuffle=False)
transforms_2 = [
vision.Decode(),
vision.TenCrop((200, 100), use_vertical_flip=True),
lambda images: np.stack([vision.ToTensor()(image) for image in images]) # 4D stack of 10 images
]
transform_2 = vision.ComposeOp(transforms_2)
data2 = data2.map(input_columns=["image"], operations=transform_2())
# Compare with expected md5 from images
filename = "ten_crop_01_result.npz"
save_and_check_md5(data2, filename, generate_golden=GENERATE_GOLDEN)
def test_ten_crop_list_size_error_msg():
"""
Tests TenCrop error message when the size arg has more than 2 elements
"""
logger.info("test_ten_crop_list_size_error_msg")
with pytest.raises(TypeError) as info:
transforms = [
vision.Decode(),
vision.TenCrop([200, 200, 200]),
lambda images: np.stack([vision.ToTensor()(image) for image in images]) # 4D stack of 10 images
]
error_msg = "Size should be a single integer or a list/tuple (h, w) of length 2."
assert error_msg == str(info.value)
def test_ten_crop_invalid_size_error_msg():
"""
Tests TenCrop error message when the size arg is not positive
"""
logger.info("test_ten_crop_invalid_size_error_msg")
with pytest.raises(ValueError) as info:
transforms = [
vision.Decode(),
vision.TenCrop(0),
lambda images: np.stack([vision.ToTensor()(image) for image in images]) # 4D stack of 10 images
]
error_msg = "Input is not within the required range"
assert error_msg == str(info.value)
with pytest.raises(ValueError) as info:
transforms = [
vision.Decode(),
vision.TenCrop(-10),
lambda images: np.stack([vision.ToTensor()(image) for image in images]) # 4D stack of 10 images
]
assert error_msg == str(info.value)
def test_ten_crop_wrong_img_error_msg():
"""
Tests TenCrop error message when the image is not in the correct format.
"""
logger.info("test_ten_crop_wrong_img_error_msg")
data = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], shuffle=False)
transforms = [
vision.Decode(),
vision.TenCrop(200),
vision.ToTensor()
]
transform = vision.ComposeOp(transforms)
data = data.map(input_columns=["image"], operations=transform())
with pytest.raises(RuntimeError) as info:
data.create_tuple_iterator().get_next()
error_msg = "TypeError: img should be PIL Image or Numpy array. Got <class 'tuple'>"
# error msg comes from ToTensor()
assert error_msg in str(info.value)
if __name__ == "__main__":
test_ten_crop_op_square(plot=True)
test_ten_crop_op_rectangle(plot=True)
test_ten_crop_op_vertical_flip(plot=True)
test_ten_crop_md5()
test_ten_crop_list_size_error_msg()
test_ten_crop_invalid_size_error_msg()
test_ten_crop_wrong_img_error_msg()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册