提交 0a95223f 编写于 作者: M mindspore-ci-bot 提交者: Gitee

!1910 RandomColorAdjust error for grayscale images

Merge pull request !1910 from MahdiRahmaniHanzaki/I1J9SQ-random-color-adjust-bug
...@@ -376,8 +376,9 @@ Status HwcToChw(std::shared_ptr<Tensor> input, std::shared_ptr<Tensor> *output) ...@@ -376,8 +376,9 @@ Status HwcToChw(std::shared_ptr<Tensor> input, std::shared_ptr<Tensor> *output)
*output = input; *output = input;
return Status::OK(); return Status::OK();
} }
if (input_cv->shape().Size() != 3 && input_cv->shape()[2] != 3) { if (input_cv->shape().Size() < 2 || input_cv->shape().Size() > 3 ||
RETURN_STATUS_UNEXPECTED("The shape is incorrect: number of channels is not equal 3"); (input_cv->shape().Size() == 3 && input_cv->shape()[2] != 3 && input_cv->shape()[2] != 1)) {
RETURN_STATUS_UNEXPECTED("The shape is incorrect: number of channels does not equal 3 nor 1");
} }
cv::Mat output_img; cv::Mat output_img;
...@@ -401,8 +402,8 @@ Status HwcToChw(std::shared_ptr<Tensor> input, std::shared_ptr<Tensor> *output) ...@@ -401,8 +402,8 @@ Status HwcToChw(std::shared_ptr<Tensor> input, std::shared_ptr<Tensor> *output)
Status SwapRedAndBlue(std::shared_ptr<Tensor> input, std::shared_ptr<Tensor> *output) { Status SwapRedAndBlue(std::shared_ptr<Tensor> input, std::shared_ptr<Tensor> *output) {
try { try {
std::shared_ptr<CVTensor> input_cv = CVTensor::AsCVTensor(std::move(input)); std::shared_ptr<CVTensor> input_cv = CVTensor::AsCVTensor(std::move(input));
if (input_cv->shape().Size() != 3 && input_cv->shape()[2] != 3) { if (input_cv->shape().Size() != 3 || input_cv->shape()[2] != 3) {
RETURN_STATUS_UNEXPECTED("The shape is incorrect: number of channels is not equal 3"); RETURN_STATUS_UNEXPECTED("The shape is incorrect: number of channels does not equal 3");
} }
auto output_cv = std::make_shared<CVTensor>(input_cv->shape(), input_cv->type()); auto output_cv = std::make_shared<CVTensor>(input_cv->shape(), input_cv->type());
RETURN_UNEXPECTED_IF_NULL(output_cv); RETURN_UNEXPECTED_IF_NULL(output_cv);
...@@ -422,7 +423,7 @@ Status CropAndResize(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tenso ...@@ -422,7 +423,7 @@ Status CropAndResize(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tenso
RETURN_STATUS_UNEXPECTED("Could not convert to CV Tensor"); RETURN_STATUS_UNEXPECTED("Could not convert to CV Tensor");
} }
if (input_cv->Rank() != 3 && input_cv->Rank() != 2) { if (input_cv->Rank() != 3 && input_cv->Rank() != 2) {
RETURN_STATUS_UNEXPECTED("Ishape not <H,W,C> or <H,W>"); RETURN_STATUS_UNEXPECTED("Shape not <H,W,C> or <H,W>");
} }
// image too large or too small // image too large or too small
if (crop_height == 0 || crop_width == 0 || target_height == 0 || target_height > crop_height * 1000 || if (crop_height == 0 || crop_width == 0 || target_height == 0 || target_height > crop_height * 1000 ||
...@@ -541,8 +542,8 @@ Status AdjustBrightness(const std::shared_ptr<Tensor> &input, std::shared_ptr<Te ...@@ -541,8 +542,8 @@ Status AdjustBrightness(const std::shared_ptr<Tensor> &input, std::shared_ptr<Te
if (!input_cv->mat().data) { if (!input_cv->mat().data) {
RETURN_STATUS_UNEXPECTED("Could not convert to CV Tensor"); RETURN_STATUS_UNEXPECTED("Could not convert to CV Tensor");
} }
if (input_cv->Rank() != 3 && input_cv->shape()[2] != 3) { if (input_cv->Rank() != 3 || input_cv->shape()[2] != 3) {
RETURN_STATUS_UNEXPECTED("Shape not <H,W,3> or <H,W>"); RETURN_STATUS_UNEXPECTED("The shape is incorrect: number of channels does not equal 3");
} }
auto output_cv = std::make_shared<CVTensor>(input_cv->shape(), input_cv->type()); auto output_cv = std::make_shared<CVTensor>(input_cv->shape(), input_cv->type());
RETURN_UNEXPECTED_IF_NULL(output_cv); RETURN_UNEXPECTED_IF_NULL(output_cv);
...@@ -561,8 +562,8 @@ Status AdjustContrast(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tens ...@@ -561,8 +562,8 @@ Status AdjustContrast(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tens
if (!input_cv->mat().data) { if (!input_cv->mat().data) {
RETURN_STATUS_UNEXPECTED("Could not convert to CV Tensor"); RETURN_STATUS_UNEXPECTED("Could not convert to CV Tensor");
} }
if (input_cv->Rank() != 3 && input_cv->shape()[2] != 3) { if (input_cv->Rank() != 3 || input_cv->shape()[2] != 3) {
RETURN_STATUS_UNEXPECTED("Shape not <H,W,3> or <H,W>"); RETURN_STATUS_UNEXPECTED("The shape is incorrect: number of channels does not equal 3");
} }
cv::Mat gray, output_img; cv::Mat gray, output_img;
cv::cvtColor(input_img, gray, CV_RGB2GRAY); cv::cvtColor(input_img, gray, CV_RGB2GRAY);
...@@ -587,8 +588,8 @@ Status AdjustSaturation(const std::shared_ptr<Tensor> &input, std::shared_ptr<Te ...@@ -587,8 +588,8 @@ Status AdjustSaturation(const std::shared_ptr<Tensor> &input, std::shared_ptr<Te
if (!input_cv->mat().data) { if (!input_cv->mat().data) {
RETURN_STATUS_UNEXPECTED("Could not convert to CV Tensor"); RETURN_STATUS_UNEXPECTED("Could not convert to CV Tensor");
} }
if (input_cv->Rank() != 3 && input_cv->shape()[2] != 3) { if (input_cv->Rank() != 3 || input_cv->shape()[2] != 3) {
RETURN_STATUS_UNEXPECTED("Shape not <H,W,3> or <H,W>"); RETURN_STATUS_UNEXPECTED("The shape is incorrect: number of channels does not equal 3");
} }
auto output_cv = std::make_shared<CVTensor>(input_cv->shape(), input_cv->type()); auto output_cv = std::make_shared<CVTensor>(input_cv->shape(), input_cv->type());
RETURN_UNEXPECTED_IF_NULL(output_cv); RETURN_UNEXPECTED_IF_NULL(output_cv);
...@@ -615,8 +616,8 @@ Status AdjustHue(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> * ...@@ -615,8 +616,8 @@ Status AdjustHue(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *
if (!input_cv->mat().data) { if (!input_cv->mat().data) {
RETURN_STATUS_UNEXPECTED("Could not convert to CV Tensor"); RETURN_STATUS_UNEXPECTED("Could not convert to CV Tensor");
} }
if (input_cv->Rank() != 3 && input_cv->shape()[2] != 3) { if (input_cv->Rank() != 3 || input_cv->shape()[2] != 3) {
RETURN_STATUS_UNEXPECTED("Shape not <H,W,3> or <H,W>"); RETURN_STATUS_UNEXPECTED("The shape is incorrect: number of channels does not equal 3");
} }
auto output_cv = std::make_shared<CVTensor>(input_cv->shape(), input_cv->type()); auto output_cv = std::make_shared<CVTensor>(input_cv->shape(), input_cv->type());
RETURN_UNEXPECTED_IF_NULL(output_cv); RETURN_UNEXPECTED_IF_NULL(output_cv);
...@@ -644,7 +645,7 @@ Status Erase(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *outp ...@@ -644,7 +645,7 @@ Status Erase(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *outp
uint8_t fill_g, uint8_t fill_b) { uint8_t fill_g, uint8_t fill_b) {
try { try {
std::shared_ptr<CVTensor> input_cv = CVTensor::AsCVTensor(input); std::shared_ptr<CVTensor> input_cv = CVTensor::AsCVTensor(input);
if (input_cv->mat().data == nullptr || (input_cv->Rank() != 3 && input_cv->shape()[2] != 3)) { if (input_cv->mat().data == nullptr || input_cv->Rank() != 3 || input_cv->shape()[2] != 3) {
RETURN_STATUS_UNEXPECTED("bad CV Tensor input for erase"); RETURN_STATUS_UNEXPECTED("bad CV Tensor input for erase");
} }
cv::Mat input_img = input_cv->mat(); cv::Mat input_img = input_cv->mat();
......
...@@ -15,6 +15,7 @@ ...@@ -15,6 +15,7 @@
""" """
Testing RandomColorAdjust op in DE Testing RandomColorAdjust op in DE
""" """
import pytest
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
import numpy as np import numpy as np
from util import diff_mse from util import diff_mse
...@@ -46,71 +47,51 @@ def visualize(first, mse, second): ...@@ -46,71 +47,51 @@ def visualize(first, mse, second):
plt.show() plt.show()
def test_random_color_adjust_op_brightness(plot=False): def util_test_random_color_adjust_error(brightness=(1, 1), contrast=(1, 1), saturation=(1, 1), hue=(0, 0)):
""" """
Test RandomColorAdjust op Util function that tests the error message in case of grayscale images
""" """
logger.info("test_random_color_adjust_op")
# First dataset
data1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], shuffle=False)
decode_op = c_vision.Decode()
random_adjust_op = c_vision.RandomColorAdjust((0.8, 0.8), (1, 1), (1, 1), (0, 0))
ctrans = [decode_op,
random_adjust_op,
]
data1 = data1.map(input_columns=["image"], operations=ctrans)
# Second dataset
transforms = [ transforms = [
py_vision.Decode(), py_vision.Decode(),
py_vision.RandomColorAdjust((0.8, 0.8), (1, 1), (1, 1), (0, 0)), py_vision.Grayscale(1),
py_vision.ToTensor(), py_vision.ToTensor(),
(lambda image: (image.transpose(1, 2, 0) * 255).astype(np.uint8))
] ]
transform = py_vision.ComposeOp(transforms)
data2 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], shuffle=False)
data2 = data2.map(input_columns=["image"], operations=transform())
num_iter = 0 transform = py_vision.ComposeOp(transforms)
for item1, item2 in zip(data1.create_dict_iterator(), data2.create_dict_iterator()): data1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], shuffle=False)
num_iter += 1 data1 = data1.map(input_columns=["image"], operations=transform())
c_image = item1["image"]
py_image = (item2["image"].transpose(1, 2, 0) * 255).astype(np.uint8)
logger.info("shape of c_image: {}".format(c_image.shape))
logger.info("shape of py_image: {}".format(py_image.shape))
logger.info("dtype of c_image: {}".format(c_image.dtype)) # if input is grayscale, the output dimensions should be single channel, the following should fail
logger.info("dtype of py_image: {}".format(py_image.dtype)) random_adjust_op = c_vision.RandomColorAdjust(brightness=brightness, contrast=contrast, saturation=saturation,
hue=hue)
with pytest.raises(RuntimeError) as info:
data1 = data1.map(input_columns=["image"], operations=random_adjust_op)
dataset_shape_1 = []
for item1 in data1.create_dict_iterator():
c_image = item1["image"]
dataset_shape_1.append(c_image.shape)
mse = diff_mse(c_image, py_image) error_msg = "The shape is incorrect: number of channels does not equal 3"
logger.info("mse is {}".format(mse))
logger.info("random_rotation_op_{}, mse: {}".format(num_iter + 1, mse)) assert error_msg in str(info.value)
assert mse < 0.01
# if mse != 0:
# logger.info("mse is: {}".format(mse))
if plot:
visualize(c_image, mse, py_image)
def test_random_color_adjust_op_contrast(plot=False): def util_test_random_color_adjust_op(brightness=(1, 1), contrast=(1, 1), saturation=(1, 1), hue=(0, 0), plot=False):
""" """
Test RandomColorAdjust op Util function that tests RandomColorAdjust for a specific argument
""" """
logger.info("test_random_color_adjust_op")
# First dataset # First dataset
data1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], shuffle=False) data1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], shuffle=False)
decode_op = c_vision.Decode() decode_op = c_vision.Decode()
random_adjust_op = c_vision.RandomColorAdjust((1, 1), (0.5, 0.5), (1, 1), (0, 0)) random_adjust_op = c_vision.RandomColorAdjust(brightness=brightness, contrast=contrast, saturation=saturation,
hue=hue)
ctrans = [decode_op, ctrans = [decode_op,
random_adjust_op random_adjust_op,
] ]
data1 = data1.map(input_columns=["image"], operations=ctrans) data1 = data1.map(input_columns=["image"], operations=ctrans)
...@@ -118,8 +99,9 @@ def test_random_color_adjust_op_contrast(plot=False): ...@@ -118,8 +99,9 @@ def test_random_color_adjust_op_contrast(plot=False):
# Second dataset # Second dataset
transforms = [ transforms = [
py_vision.Decode(), py_vision.Decode(),
py_vision.RandomColorAdjust((1, 1), (0.5, 0.5), (1, 1), (0, 0)), py_vision.RandomColorAdjust(brightness=brightness, contrast=contrast, saturation=saturation,
py_vision.ToTensor(), hue=hue),
py_vision.ToTensor()
] ]
transform = py_vision.ComposeOp(transforms) transform = py_vision.ComposeOp(transforms)
data2 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], shuffle=False) data2 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], shuffle=False)
...@@ -136,161 +118,101 @@ def test_random_color_adjust_op_contrast(plot=False): ...@@ -136,161 +118,101 @@ def test_random_color_adjust_op_contrast(plot=False):
logger.info("dtype of c_image: {}".format(c_image.dtype)) logger.info("dtype of c_image: {}".format(c_image.dtype))
logger.info("dtype of py_image: {}".format(py_image.dtype)) logger.info("dtype of py_image: {}".format(py_image.dtype))
diff = c_image - py_image
logger.info("contrast difference c is : {}".format(c_image[0][0]))
logger.info("contrast difference py is : {}".format(py_image[0][0]))
diff = c_image - py_image
logger.info("contrast difference is : {}".format(diff[0][0]))
# mse = (np.sum(np.power(diff, 2))) / (c_image.shape[0] * c_image.shape[1])
mse = diff_mse(c_image, py_image) mse = diff_mse(c_image, py_image)
logger.info("mse is {}".format(mse)) logger.info("mse is {}".format(mse))
# assert mse < 0.01
# logger.info("random_rotation_op_{}, mse: {}".format(num_iter + 1, mse)) logger.info("random_rotation_op_{}, mse: {}".format(num_iter + 1, mse))
# if mse != 0: assert mse < 0.01
# logger.info("mse is: {}".format(mse))
if plot: if plot:
visualize(c_image, mse, py_image) visualize(c_image, mse, py_image)
def test_random_color_adjust_op_saturation(plot=False): def test_random_color_adjust_op_brightness(plot=False):
""" """
Test RandomColorAdjust op Test RandomColorAdjust op for brightness
""" """
logger.info("test_random_color_adjust_op")
# First dataset logger.info("test_random_color_adjust_op_brightness")
data1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], shuffle=False)
decode_op = c_vision.Decode()
random_adjust_op = c_vision.RandomColorAdjust((1, 1), (1, 1), (0.5, 0.5), (0, 0)) util_test_random_color_adjust_op(brightness=(0.5, 0.5), plot=plot)
ctrans = [decode_op,
random_adjust_op
]
data1 = data1.map(input_columns=["image"], operations=ctrans) def test_random_color_adjust_op_brightness_error():
"""
Test RandomColorAdjust error message with brightness input in case of grayscale image
"""
# Second dataset logger.info("test_random_color_adjust_op_brightness_error")
transforms = [
py_vision.Decode(),
py_vision.RandomColorAdjust((1, 1), (1, 1), (0.5, 0.5), (0, 0)),
py_vision.ToTensor(),
]
transform = py_vision.ComposeOp(transforms)
data2 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], shuffle=False)
data2 = data2.map(input_columns=["image"], operations=transform())
num_iter = 0 util_test_random_color_adjust_error(brightness=(0.5, 0.5))
for item1, item2 in zip(data1.create_dict_iterator(), data2.create_dict_iterator()):
num_iter += 1
c_image = item1["image"]
py_image = (item2["image"].transpose(1, 2, 0) * 255).astype(np.uint8)
logger.info("shape of c_image: {}".format(c_image.shape)) def test_random_color_adjust_op_contrast(plot=False):
logger.info("shape of py_image: {}".format(py_image.shape)) """
Test RandomColorAdjust op for contrast
"""
logger.info("dtype of c_image: {}".format(c_image.dtype)) logger.info("test_random_color_adjust_op_contrast")
logger.info("dtype of py_image: {}".format(py_image.dtype))
mse = diff_mse(c_image, py_image) util_test_random_color_adjust_op(contrast=(0.5, 0.5), plot=plot)
logger.info("mse is {}".format(mse))
assert mse < 0.01
# logger.info("random_rotation_op_{}, mse: {}".format(num_iter + 1, mse))
# if mse != 0:
# logger.info("mse is: {}".format(mse))
if plot:
visualize(c_image, mse, py_image)
def test_random_color_adjust_op_hue(plot=False): def test_random_color_adjust_op_contrast_error():
""" """
Test RandomColorAdjust op Test RandomColorAdjust error message with contrast input in case of grayscale image
""" """
logger.info("test_random_color_adjust_op")
# First dataset logger.info("test_random_color_adjust_op_contrast_error")
data1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], shuffle=False)
decode_op = c_vision.Decode()
random_adjust_op = c_vision.RandomColorAdjust((1, 1), (1, 1), (1, 1), (0.2, 0.2)) util_test_random_color_adjust_error(contrast=(0.5, 0.5))
ctrans = [decode_op,
random_adjust_op,
]
data1 = data1.map(input_columns=["image"], operations=ctrans) def test_random_color_adjust_op_saturation(plot=False):
"""
Test RandomColorAdjust op for saturation
"""
logger.info("test_random_color_adjust_op_saturation")
# Second dataset util_test_random_color_adjust_op(saturation=(0.5, 0.5), plot=plot)
transforms = [
py_vision.Decode(),
py_vision.RandomColorAdjust((1, 1), (1, 1), (1, 1), (0.2, 0.2)),
py_vision.ToTensor(),
]
transform = py_vision.ComposeOp(transforms)
data2 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], shuffle=False)
data2 = data2.map(input_columns=["image"], operations=transform())
num_iter = 0
for item1, item2 in zip(data1.create_dict_iterator(), data2.create_dict_iterator()):
num_iter += 1
c_image = item1["image"]
py_image = (item2["image"].transpose(1, 2, 0) * 255).astype(np.uint8)
# logger.info("shape of img: {}".format(img.shape)) def test_random_color_adjust_op_saturation_error():
logger.info("shape of c_image: {}".format(c_image.shape)) """
logger.info("shape of py_image: {}".format(py_image.shape)) Test RandomColorAdjust error message with saturation input in case of grayscale image
"""
logger.info("dtype of c_image: {}".format(c_image.dtype)) logger.info("test_random_color_adjust_op_saturation_error")
logger.info("dtype of py_image: {}".format(py_image.dtype))
# logger.info("dtype of img: {}".format(img.dtype))
# mse = (np.sum(np.power(diff, 2))) / (c_image.shape[0] * c_image.shape[1]) util_test_random_color_adjust_error(saturation=(0.5, 0.5))
mse = diff_mse(c_image, py_image)
logger.info("mse is {}".format(mse))
assert mse < 0.01
if plot:
visualize(c_image, mse, py_image)
# pylint: disable=unnecessary-lambda def test_random_color_adjust_op_hue(plot=False):
def test_random_color_adjust_grayscale():
""" """
Tests that the random color adjust works for grayscale images Test RandomColorAdjust op for hue
""" """
logger.info("test_random_color_adjust_op_hue")
def channel_swap(image): util_test_random_color_adjust_op(hue=(0.5, 0.5), plot=plot)
"""
Py func hack for our pytransforms to work with c transforms
"""
return (image.transpose(1, 2, 0) * 255).astype(np.uint8)
transforms = [
py_vision.Decode(),
py_vision.Grayscale(1),
py_vision.ToTensor(),
(lambda image: channel_swap(image))
]
transform = py_vision.ComposeOp(transforms) def test_random_color_adjust_op_hue_error():
data1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], shuffle=False) """
data1 = data1.map(input_columns=["image"], operations=transform()) Test RandomColorAdjust error message with hue input in case of grayscale image
"""
# if input is grayscale, the output dimensions should be single channel, the following should fail logger.info("test_random_color_adjust_op_hue_error")
random_adjust_op = c_vision.RandomColorAdjust((1, 1), (1, 1), (1, 1), (0.2, 0.2))
try: util_test_random_color_adjust_error(hue=(0.5, 0.5))
data1 = data1.map(input_columns=["image"], operations=random_adjust_op)
dataset_shape_1 = []
for item1 in data1.create_dict_iterator():
c_image = item1["image"]
dataset_shape_1.append(c_image.shape)
except Exception as e:
logger.info("Got an exception in DE: {}".format(str(e)))
if __name__ == "__main__": if __name__ == "__main__":
test_random_color_adjust_op_brightness() test_random_color_adjust_op_brightness(plot=True)
test_random_color_adjust_op_contrast() test_random_color_adjust_op_brightness_error()
test_random_color_adjust_op_saturation() test_random_color_adjust_op_contrast(plot=True)
test_random_color_adjust_op_hue() test_random_color_adjust_op_contrast_error()
test_random_color_adjust_grayscale() test_random_color_adjust_op_saturation(plot=True)
test_random_color_adjust_op_saturation_error()
test_random_color_adjust_op_hue(plot=True)
test_random_color_adjust_op_hue_error()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册