diff --git a/mindspore/ccsrc/minddata/dataset/api/python_bindings.cc b/mindspore/ccsrc/minddata/dataset/api/python_bindings.cc
index 08016ee061302c3fa9fd8118c78628bb12c2f4a7..457ff8b1b5ebd6ac8414ea483a2a9847b7e4195e 100644
--- a/mindspore/ccsrc/minddata/dataset/api/python_bindings.cc
+++ b/mindspore/ccsrc/minddata/dataset/api/python_bindings.cc
@@ -53,6 +53,7 @@
 #include "minddata/dataset/kernels/image/center_crop_op.h"
 #include "minddata/dataset/kernels/image/cut_out_op.h"
 #include "minddata/dataset/kernels/image/decode_op.h"
+#include "minddata/dataset/kernels/image/equalize_op.h"
 #include "minddata/dataset/kernels/image/hwc_to_chw_op.h"
 #include "minddata/dataset/kernels/image/image_utils.h"
 #include "minddata/dataset/kernels/image/invert_op.h"
@@ -374,6 +375,10 @@ void bindTensorOps1(py::module *m) {
     .def(py::init<float, float, float, float, float, float>(), py::arg("meanR"), py::arg("meanG"), py::arg("meanB"),
          py::arg("stdR"), py::arg("stdG"), py::arg("stdB"));
 
+  (void)py::class_<EqualizeOp, TensorOp, std::shared_ptr<EqualizeOp>>(
+    *m, "EqualizeOp", "Tensor operation to apply histogram equalization on images.")
+    .def(py::init<>());
+
   (void)py::class_<InvertOp, TensorOp, std::shared_ptr<InvertOp>>(*m, "InvertOp",
                                                                  "Tensor operation to apply invert on RGB images.")
     .def(py::init<>());
diff --git a/mindspore/ccsrc/minddata/dataset/kernels/image/CMakeLists.txt b/mindspore/ccsrc/minddata/dataset/kernels/image/CMakeLists.txt
index 743fc83c1493e287d83a451e39cd071e3eb2cc14..a7777302847bdc7e91311513dd5e4f69b2e08473 100644
--- a/mindspore/ccsrc/minddata/dataset/kernels/image/CMakeLists.txt
+++ b/mindspore/ccsrc/minddata/dataset/kernels/image/CMakeLists.txt
@@ -5,6 +5,7 @@ add_library(kernels-image OBJECT
         center_crop_op.cc
         cut_out_op.cc
         decode_op.cc
+        equalize_op.cc
         hwc_to_chw_op.cc
         image_utils.cc
         invert_op.cc
diff --git a/mindspore/ccsrc/minddata/dataset/kernels/image/equalize_op.cc b/mindspore/ccsrc/minddata/dataset/kernels/image/equalize_op.cc
new file mode 100644
index 0000000000000000000000000000000000000000..e5bf0fd6282412e7f63b2884398ee026aee2b6f9
--- /dev/null
+++ 
b/mindspore/ccsrc/minddata/dataset/kernels/image/equalize_op.cc
@@ -0,0 +1,29 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include "minddata/dataset/kernels/image/equalize_op.h"
+#include "minddata/dataset/kernels/image/image_utils.h"
+
+namespace mindspore {
+namespace dataset {
+
+// Checks the input/output pointers and delegates to Equalize() from image_utils.
+// Equalize() accepts greyscale (<H,W> or <H,W,1>) as well as 3-channel (<H,W,3>) images, not only RGB.
+Status EqualizeOp::Compute(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output) {
+  IO_CHECK(input, output);
+  return Equalize(input, output);
+}
+}  // namespace dataset
+}  // namespace mindspore
diff --git a/mindspore/ccsrc/minddata/dataset/kernels/image/equalize_op.h b/mindspore/ccsrc/minddata/dataset/kernels/image/equalize_op.h
new file mode 100644
index 0000000000000000000000000000000000000000..9fd030f5852d396ca344a373f182893851ad022d
--- /dev/null
+++ b/mindspore/ccsrc/minddata/dataset/kernels/image/equalize_op.h
@@ -0,0 +1,45 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_KERNELS_IMAGE_EQUALIZE_OP_H_
+#define MINDSPORE_CCSRC_MINDDATA_DATASET_KERNELS_IMAGE_EQUALIZE_OP_H_
+
+#include <memory>
+#include <ostream>
+#include <string>
+
+#include "minddata/dataset/core/tensor.h"
+#include "minddata/dataset/kernels/tensor_op.h"
+#include "minddata/dataset/util/status.h"
+
+namespace mindspore {
+namespace dataset {
+class EqualizeOp : public TensorOp {
+ public:
+  EqualizeOp() = default;
+  ~EqualizeOp() = default;
+
+  // Description: Writes the name of this op to the given stream
+  void Print(std::ostream &out) const override { out << Name(); }
+
+  Status Compute(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output) override;
+
+  std::string Name() const override { return kEqualizeOp; }
+};
+}  // namespace dataset
+}  // namespace mindspore
+
+#endif  // MINDSPORE_CCSRC_MINDDATA_DATASET_KERNELS_IMAGE_EQUALIZE_OP_H_
diff --git a/mindspore/ccsrc/minddata/dataset/kernels/image/image_utils.cc b/mindspore/ccsrc/minddata/dataset/kernels/image/image_utils.cc
index dac076a5f435760f7c3e796a6d5ea058c847425c..97e735256404e6c4f0dba1293b11c19a6eefba1a 100644
--- a/mindspore/ccsrc/minddata/dataset/kernels/image/image_utils.cc
+++ b/mindspore/ccsrc/minddata/dataset/kernels/image/image_utils.cc
@@ -749,6 +749,55 @@ Status AdjustHue(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *
   return Status::OK();
 }
 
+// Applies histogram equalization to every channel of the image independently.
+// Accepts rank-2 <H,W> input or rank-3 input with 1 or 3 channels; anything else returns an error status.
+Status Equalize(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output) {
+  try {
+    std::shared_ptr<CVTensor> input_cv = CVTensor::AsCVTensor(input);
+    if (!input_cv->mat().data) {
+      RETURN_STATUS_UNEXPECTED("Could not convert to CV Tensor");
+    }
+    if (input_cv->Rank() != 3 && input_cv->Rank() != 2) {
+      RETURN_STATUS_UNEXPECTED("Shape not <H,W,C> or <H,W>");
+    }
+    // Record the original rank now: after ExpandDim() the rank is 3, so testing it later would never
+    // detect a greyscale <H,W> input and the output would not be squeezed back.
+    const bool is_rank_two = input_cv->Rank() == 2;
+    // For greyscale images, extend dimension if rank is 2 and reshape output to be of rank 2.
+    // NOTE(review): ExpandDim() changes input_cv in place; if AsCVTensor() shares the object with the
+    // caller, the caller's tensor shape changes as well - confirm this is acceptable.
+    if (is_rank_two) {
+      RETURN_IF_NOT_OK(input_cv->ExpandDim(2));
+    }
+    // Get number of channels and image matrix
+    std::size_t num_of_channels = input_cv->shape()[2];
+    if (num_of_channels != 1 && num_of_channels != 3) {
+      RETURN_STATUS_UNEXPECTED("Number of channels is not 1 or 3.");
+    }
+    cv::Mat image = input_cv->mat();
+    // Separate the image to channels
+    std::vector<cv::Mat> planes(num_of_channels);
+    cv::split(image, planes);
+    // Equalize each channel separately
+    std::vector<cv::Mat> image_result;
+    image_result.reserve(planes.size());
+    for (const cv::Mat &plane : planes) {
+      cv::Mat channel_result;
+      cv::equalizeHist(plane, channel_result);
+      image_result.push_back(channel_result);
+    }
+    cv::Mat result;
+    cv::merge(image_result, result);
+    std::shared_ptr<CVTensor> output_cv = std::make_shared<CVTensor>(result);
+    if (is_rank_two) output_cv->Squeeze();
+    (*output) = std::static_pointer_cast<Tensor>(output_cv);
+  } catch (const cv::Exception &e) {
+    // Keep the OpenCV diagnostic so the failure can be traced.
+    RETURN_STATUS_UNEXPECTED(std::string("Error in equalize: ") + e.what());
+  }
+  return Status::OK();
+}
+
 Status Erase(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output, int32_t box_height,
              int32_t box_width, int32_t num_patches, bool bounded, bool random_color, std::mt19937 *rnd, uint8_t fill_r,
              uint8_t fill_g, uint8_t fill_b) {
diff --git a/mindspore/ccsrc/minddata/dataset/kernels/image/image_utils.h b/mindspore/ccsrc/minddata/dataset/kernels/image/image_utils.h
index c1426338954fc7cb3fb7f1436a52ff4b7767bed1..9a90bec61eb8a833c661d78b69f6274f8c5dcbf4 100644
--- a/mindspore/ccsrc/minddata/dataset/kernels/image/image_utils.h
+++ b/mindspore/ccsrc/minddata/dataset/kernels/image/image_utils.h
@@ -200,6 +200,12 @@ Status AdjustSaturation(const std::shared_ptr &input, std::shared_ptr &input, std::shared_ptr *output, const float &hue);
 
+/// \brief Returns image with equalized histogram.
+/// \param[in] input: Tensor of shape <H,W,C>/<H,W,1>/<H,W> in RGB/Grayscale and
+///                   any OpenCv compatible type, see CVTensor.
+/// \param[out] output: Equalized image of same shape and type.
+Status Equalize(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output);
+
 // Masks out a random section from the image with set dimension
 // @param input: input Tensor
 // @param output: cutOut Tensor
diff --git a/mindspore/ccsrc/minddata/dataset/kernels/tensor_op.h b/mindspore/ccsrc/minddata/dataset/kernels/tensor_op.h
index 638ce49dbf59ca381d3e8d524b7290199e0ff432..9c12759422e85148957fdeeab208a3e4e37aedb2 100644
--- a/mindspore/ccsrc/minddata/dataset/kernels/tensor_op.h
+++ b/mindspore/ccsrc/minddata/dataset/kernels/tensor_op.h
@@ -92,6 +92,7 @@ constexpr char kBoundingBoxAugmentOp[] = "BoundingBoxAugmentOp";
 constexpr char kDecodeOp[] = "DecodeOp";
 constexpr char kCenterCropOp[] = "CenterCropOp";
 constexpr char kCutOutOp[] = "CutOutOp";
+constexpr char kEqualizeOp[] = "EqualizeOp";
 constexpr char kHwcToChwOp[] = "HwcToChwOp";
 constexpr char kInvertOp[] = "InvertOp";
 constexpr char kNormalizeOp[] = "NormalizeOp";
diff --git a/mindspore/dataset/transforms/vision/c_transforms.py b/mindspore/dataset/transforms/vision/c_transforms.py
index 9a73a9007395fd3a02c4e7627dd49e25cbe5cd5d..3c2e7aeecb36e1c2676a383ed5b4eb2f92d61904 100644
--- a/mindspore/dataset/transforms/vision/c_transforms.py
+++ b/mindspore/dataset/transforms/vision/c_transforms.py
@@ -89,6 +89,13 @@ class AutoContrast(cde.AutoContrastOp):
         super().__init__(cutoff, ignore)
 
 
+class Equalize(cde.EqualizeOp):
+    """
+    Apply histogram equalization on input image.
+    This operation takes no arguments.
+    """
+
+
 class Invert(cde.InvertOp):
     """
     Apply invert on input image in RGB mode.
diff --git a/tests/ut/data/dataset/golden/equalize_01_result_c.npz b/tests/ut/data/dataset/golden/equalize_01_result_c.npz
new file mode 100644
index 0000000000000000000000000000000000000000..2c3a37eb4dcf3b010e926530e3ca11a3e74c5e53
Binary files /dev/null and b/tests/ut/data/dataset/golden/equalize_01_result_c.npz differ
diff --git a/tests/ut/python/dataset/test_equalize.py b/tests/ut/python/dataset/test_equalize.py
index 0a5f2f93d50f03dd236beacd9fff177cfd07b819..26102ae809fb66199c06d3b3b877aa534c406765 100644
--- a/tests/ut/python/dataset/test_equalize.py
+++ b/tests/ut/python/dataset/test_equalize.py
@@ -18,6 +18,7 @@ Testing Equalize op in DE
 import numpy as np
 
 import mindspore.dataset.engine as de
+import mindspore.dataset.transforms.vision.c_transforms as C
 import mindspore.dataset.transforms.vision.py_transforms as F
 from mindspore import log as logger
 from util import visualize_list, diff_mse, save_and_check_md5
@@ -26,9 +27,9 @@ DATA_DIR = "../data/dataset/testImageNetData/train/"
 GENERATE_GOLDEN = False
 
 
-def test_equalize(plot=False):
+def test_equalize_py(plot=False):
     """
-    Test Equalize
+    Test Equalize py op
     """
     logger.info("Test Equalize")
 
@@ -83,9 +84,141 @@ def test_equalize(plot=False):
         visualize_list(images_original, images_equalize)
 
 
-def test_equalize_md5():
+def test_equalize_c(plot=False):
     """
-    Test Equalize with md5 check
+    Test Equalize Cpp op
+    """
+    logger.info("Test Equalize cpp op")
+
+    # Original Images
+    ds = de.ImageFolderDatasetV2(dataset_dir=DATA_DIR, shuffle=False)
+
+    transforms_original = [C.Decode(), C.Resize(size=[224, 224])]
+
+    ds_original = ds.map(input_columns="image",
+                         operations=transforms_original)
+
+    ds_original = ds_original.batch(512)
+
+    for idx, (image, _) in enumerate(ds_original):
+        if idx == 0:
+            images_original = image
+        else:
+            images_original = np.append(images_original,
+                                        image,
+                                        axis=0)
+
+    # Equalize Images
+    ds = de.ImageFolderDatasetV2(dataset_dir=DATA_DIR, shuffle=False)
+
+    transform_equalize = [C.Decode(), C.Resize(size=[224, 224]),
+                          C.Equalize()]
+
+    ds_equalize = ds.map(input_columns="image",
+                         operations=transform_equalize)
+
+    ds_equalize = ds_equalize.batch(512)
+
+    for idx, (image, _) in enumerate(ds_equalize):
+        if idx == 0:
+            images_equalize = image
+        else:
+            images_equalize = np.append(images_equalize,
+                                        image,
+                                        axis=0)
+    if plot:
+        visualize_list(images_original, images_equalize)
+
+    num_samples = images_original.shape[0]
+    mse = np.zeros(num_samples)
+    for i in range(num_samples):
+        mse[i] = diff_mse(images_equalize[i], images_original[i])
+    logger.info("MSE= {}".format(str(np.mean(mse))))
+
+
+def test_equalize_py_c(plot=False):
+    """
+    Test Equalize Cpp op and python op
+    """
+    logger.info("Test Equalize cpp and python op")
+
+    # equalize Images in cpp
+    ds = de.ImageFolderDatasetV2(dataset_dir=DATA_DIR, shuffle=False)
+    ds = ds.map(input_columns=["image"],
+                operations=[C.Decode(), C.Resize((224, 224))])
+
+    ds_c_equalize = ds.map(input_columns="image",
+                           operations=C.Equalize())
+
+    ds_c_equalize = ds_c_equalize.batch(512)
+
+    for idx, (image, _) in enumerate(ds_c_equalize):
+        if idx == 0:
+            images_c_equalize = image
+        else:
+            images_c_equalize = np.append(images_c_equalize,
+                                          image,
+                                          axis=0)
+
+    # Equalize images in python
+    ds = de.ImageFolderDatasetV2(dataset_dir=DATA_DIR, shuffle=False)
+    ds = ds.map(input_columns=["image"],
+                operations=[C.Decode(), C.Resize((224, 224))])
+
+    transforms_p_equalize = F.ComposeOp([lambda img: img.astype(np.uint8),
+                                         F.ToPIL(),
+                                         F.Equalize(),
+                                         np.array])
+
+    ds_p_equalize = ds.map(input_columns="image",
+                           operations=transforms_p_equalize())
+
+    ds_p_equalize = ds_p_equalize.batch(512)
+
+    for idx, (image, _) in enumerate(ds_p_equalize):
+        if idx == 0:
+            images_p_equalize = image
+        else:
+            images_p_equalize = np.append(images_p_equalize,
+                                          image,
+                                          axis=0)
+
+    num_samples = images_c_equalize.shape[0]
+    mse = np.zeros(num_samples)
+    for i in range(num_samples):
+        mse[i] = diff_mse(images_p_equalize[i], images_c_equalize[i])
+    logger.info("MSE= {}".format(str(np.mean(mse))))
+
+    if plot:
+        visualize_list(images_c_equalize, images_p_equalize, visualize_mode=2)
+
+
+def test_equalize_one_channel():
+    """
+    Test Equalize cpp op with one channel image
+    """
+    logger.info("Test Equalize C Op With One Channel Images")
+
+    c_op = C.Equalize()
+
+    ds = de.ImageFolderDatasetV2(dataset_dir=DATA_DIR, shuffle=False)
+    ds = ds.map(input_columns=["image"],
+                operations=[C.Decode(),
+                            C.Resize((224, 224)),
+                            lambda img: np.array(img[:, :, 0])])
+
+    ds = ds.map(input_columns="image",
+                operations=c_op)
+
+    # map() is lazy: iterate so that the op really runs on the single channel images.
+    # The cpp Equalize accepts <H,W> input, so no error is expected and the image size is kept.
+    for image, _ in ds:
+        assert image.shape[:2] == (224, 224)
+
+
+def test_equalize_md5_py():
+    """
+    Test Equalize py op with md5 check
     """
     logger.info("Test Equalize")
 
@@ -101,6 +234,30 @@ def test_equalize_md5():
     save_and_check_md5(data1, filename, generate_golden=GENERATE_GOLDEN)
 
 
+def test_equalize_md5_c():
+    """
+    Test Equalize cpp op with md5 check
+    """
+    logger.info("Test Equalize cpp op with md5 check")
+
+    # Generate dataset
+    ds = de.ImageFolderDatasetV2(dataset_dir=DATA_DIR, shuffle=False)
+
+    transforms_equalize = [C.Decode(),
+                           C.Resize(size=[224, 224]),
+                           C.Equalize(),
+                           F.ToTensor()]
+
+    data = ds.map(input_columns="image", operations=transforms_equalize)
+    # Compare with expected md5 from images
+    filename = "equalize_01_result_c.npz"
+    save_and_check_md5(data, filename, generate_golden=GENERATE_GOLDEN)
+
+
 if __name__ == "__main__":
-    test_equalize(plot=True)
-    test_equalize_md5()
+    test_equalize_py(plot=False)
+    test_equalize_c(plot=False)
+    test_equalize_py_c(plot=False)
+    test_equalize_one_channel()
+    test_equalize_md5_py()
+    test_equalize_md5_c()