From 5647889c0dfe7d71d3a4d6c7b6ce9825088c8714 Mon Sep 17 00:00:00 2001 From: islam_amin Date: Sun, 12 Jul 2020 12:06:56 -0400 Subject: [PATCH] Added AutoContrast Op --- .../minddata/dataset/api/python_bindings.cc | 6 + .../dataset/kernels/image/CMakeLists.txt | 1 + .../dataset/kernels/image/auto_contrast_op.cc | 34 ++++ .../dataset/kernels/image/auto_contrast_op.h | 61 +++++++ .../dataset/kernels/image/image_utils.cc | 103 ++++++++++++ .../dataset/kernels/image/image_utils.h | 8 + .../minddata/dataset/kernels/tensor_op.h | 1 + .../dataset/transforms/vision/c_transforms.py | 20 ++- .../dataset/transforms/vision/validators.py | 21 +++ .../golden/autcontrast_01_result_py.npz | Bin 0 -> 607 bytes tests/ut/python/dataset/test_autocontrast.py | 159 +++++++++++++++++- 11 files changed, 408 insertions(+), 6 deletions(-) create mode 100644 mindspore/ccsrc/minddata/dataset/kernels/image/auto_contrast_op.cc create mode 100644 mindspore/ccsrc/minddata/dataset/kernels/image/auto_contrast_op.h create mode 100644 tests/ut/data/dataset/golden/autcontrast_01_result_py.npz diff --git a/mindspore/ccsrc/minddata/dataset/api/python_bindings.cc b/mindspore/ccsrc/minddata/dataset/api/python_bindings.cc index a20c5c80c..b5a6dc59e 100644 --- a/mindspore/ccsrc/minddata/dataset/api/python_bindings.cc +++ b/mindspore/ccsrc/minddata/dataset/api/python_bindings.cc @@ -48,6 +48,7 @@ #include "minddata/dataset/kernels/data/slice_op.h" #include "minddata/dataset/kernels/data/to_float16_op.h" #include "minddata/dataset/kernels/data/type_cast_op.h" +#include "minddata/dataset/kernels/image/auto_contrast_op.h" #include "minddata/dataset/kernels/image/bounding_box_augment_op.h" #include "minddata/dataset/kernels/image/center_crop_op.h" #include "minddata/dataset/kernels/image/cut_out_op.h" @@ -362,6 +363,11 @@ void bindTensorOps1(py::module *m) { (void)py::class_>(*m, "TensorOp") .def("__deepcopy__", [](py::object &t, py::dict memo) { return t; }); + (void)py::class_>( + *m, "AutoContrastOp", "Tensor operation to apply autocontrast on an image.") + .def(py::init>(), py::arg("cutoff") = AutoContrastOp::kCutOff, + py::arg("ignore") = AutoContrastOp::kIgnore); + (void)py::class_>( *m, "NormalizeOp", "Tensor operation to normalize an image. Takes mean and std.") .def(py::init(), py::arg("meanR"), py::arg("meanG"), py::arg("meanB"), diff --git a/mindspore/ccsrc/minddata/dataset/kernels/image/CMakeLists.txt b/mindspore/ccsrc/minddata/dataset/kernels/image/CMakeLists.txt index 402989af0..743fc83c1 100644 --- a/mindspore/ccsrc/minddata/dataset/kernels/image/CMakeLists.txt +++ b/mindspore/ccsrc/minddata/dataset/kernels/image/CMakeLists.txt @@ -1,6 +1,7 @@ file(GLOB_RECURSE _CURRENT_SRC_FILES RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "*.cc") set_property(SOURCE ${_CURRENT_SRC_FILES} PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_MD) add_library(kernels-image OBJECT + auto_contrast_op.cc center_crop_op.cc cut_out_op.cc decode_op.cc diff --git a/mindspore/ccsrc/minddata/dataset/kernels/image/auto_contrast_op.cc b/mindspore/ccsrc/minddata/dataset/kernels/image/auto_contrast_op.cc new file mode 100644 index 000000000..417d16783 --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/kernels/image/auto_contrast_op.cc @@ -0,0 +1,34 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include +#include +#include + +#include "minddata/dataset/kernels/image/auto_contrast_op.h" +#include "minddata/dataset/kernels/image/image_utils.h" + +namespace mindspore { +namespace dataset { + +const float AutoContrastOp::kCutOff = 0.0; +const std::vector AutoContrastOp::kIgnore = {}; + +Status AutoContrastOp::Compute(const std::shared_ptr &input, std::shared_ptr *output) { + IO_CHECK(input, output); + return AutoContrast(input, output, cutoff_, ignore_); +} +} // namespace dataset +} // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/kernels/image/auto_contrast_op.h b/mindspore/ccsrc/minddata/dataset/kernels/image/auto_contrast_op.h new file mode 100644 index 000000000..94b3b23df --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/kernels/image/auto_contrast_op.h @@ -0,0 +1,61 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef DATASET_KERNELS_IMAGE_AUTO_CONTRAST_OP_H_ +#define DATASET_KERNELS_IMAGE_AUTO_CONTRAST_OP_H_ + +#include +#include +#include + +#include "minddata/dataset/core/tensor.h" +#include "minddata/dataset/core/cv_tensor.h" +#include "minddata/dataset/kernels/tensor_op.h" +#include "minddata/dataset/util/status.h" + +namespace mindspore { +namespace dataset { +class AutoContrastOp : public TensorOp { + public: + /// Default cutoff to be used + static const float kCutOff; + /// Default ignore to be used + static const std::vector kIgnore; + + AutoContrastOp(const float &cutoff, const std::vector &ignore) : cutoff_(cutoff), ignore_(ignore) {} + + ~AutoContrastOp() override = default; + + /// Provide stream operator for displaying it + friend std::ostream &operator<<(std::ostream &out, const AutoContrastOp &so) { + so.Print(out); + return out; + } + + void Print(std::ostream &out) const override { out << Name(); } + + std::string Name() const override { return kAutoContrastOp; } + + Status Compute(const std::shared_ptr &input, std::shared_ptr *output) override; + + private: + float cutoff_; + std::vector ignore_; +}; +} // namespace dataset +} // namespace mindspore + +#endif // DATASET_KERNELS_IMAGE_AUTO_CONTRAST_OP_H_ diff --git a/mindspore/ccsrc/minddata/dataset/kernels/image/image_utils.cc b/mindspore/ccsrc/minddata/dataset/kernels/image/image_utils.cc index ddbce3e23..dac076a5f 100644 --- a/mindspore/ccsrc/minddata/dataset/kernels/image/image_utils.cc +++ b/mindspore/ccsrc/minddata/dataset/kernels/image/image_utils.cc @@ -585,6 +585,109 @@ Status AdjustContrast(const std::shared_ptr &input, std::shared_ptr &input, std::shared_ptr *output, const float &cutoff, + const std::vector &ignore) { + try { + std::shared_ptr input_cv = CVTensor::AsCVTensor(input); + if (!input_cv->mat().data) { + RETURN_STATUS_UNEXPECTED("Could not convert to CV Tensor"); + } + if (input_cv->Rank() != 3 && input_cv->Rank() != 2) { + RETURN_STATUS_UNEXPECTED("Shape not or "); + } + // Reshape to extend dimension if rank is 2 for algorithm to work. then reshape output to be of rank 2 like input + if (input_cv->Rank() == 2) { + RETURN_IF_NOT_OK(input_cv->ExpandDim(2)); + } + // Get number of channels and image matrix + std::size_t num_of_channels = input_cv->shape()[2]; + if (num_of_channels != 1 && num_of_channels != 3) { + RETURN_STATUS_UNEXPECTED("Number of channels is not 1 or 3."); + } + cv::Mat image = input_cv->mat(); + // Separate the image to channels + std::vector planes(num_of_channels); + cv::split(image, planes); + cv::Mat b_hist, g_hist, r_hist; + // Establish the number of bins and set variables for histogram + int32_t hist_size = 256; + int32_t channels = 0; + float range[] = {0, 256}; + const float *hist_range[] = {range}; + bool uniform = true, accumulate = false; + // Set up lookup table for LUT(Look up table algorithm) + std::vector table; + std::vector image_result; + for (std::size_t layer = 0; layer < planes.size(); layer++) { + // Reset lookup table + table = std::vector{}; + // Calculate Histogram for channel + cv::Mat hist; + cv::calcHist(&planes[layer], 1, &channels, cv::Mat(), hist, 1, &hist_size, hist_range, uniform, accumulate); + hist.convertTo(hist, CV_32SC1); + std::vector hist_vec; + hist.col(0).copyTo(hist_vec); + // Ignore values in ignore + for (const auto &item : ignore) hist_vec[item] = 0; + int32_t n = std::accumulate(hist_vec.begin(), hist_vec.end(), 0); + // Find pixel values that are in the low cutoff and high cutoff. + int32_t cut = static_cast((cutoff / 100.0) * n); + if (cut != 0) { + for (int32_t lo = 0; lo < 256 && cut > 0; lo++) { + if (cut > hist_vec[lo]) { + cut -= hist_vec[lo]; + hist_vec[lo] = 0; + } else { + hist_vec[lo] -= cut; + cut = 0; + } + } + cut = static_cast((cutoff / 100.0) * n); + for (int32_t hi = 255; hi >= 0 && cut > 0; hi--) { + if (cut > hist_vec[hi]) { + cut -= hist_vec[hi]; + hist_vec[hi] = 0; + } else { + hist_vec[hi] -= cut; + cut = 0; + } + } + } + int32_t lo = 0; + int32_t hi = 255; + for (; lo < 256 && !hist_vec[lo]; lo++) { + } + for (; hi >= 0 && !hist_vec[hi]; hi--) { + } + if (hi <= lo) { + for (int32_t i = 0; i < 256; i++) { + table.push_back(i); + } + } else { + float scale = 255.0 / (hi - lo); + float offset = -1 * lo * scale; + for (int32_t i = 0; i < 256; i++) { + int32_t ix = static_cast(i * scale + offset); + ix = std::max(ix, 0); + ix = std::min(ix, 255); + table.push_back(ix); + } + } + cv::Mat result_layer; + cv::LUT(planes[layer], table, result_layer); + image_result.push_back(result_layer); + } + cv::Mat result; + cv::merge(image_result, result); + std::shared_ptr output_cv = std::make_shared(result); + if (input_cv->Rank() == 2) output_cv->Squeeze(); + (*output) = std::static_pointer_cast(output_cv); + } catch (const cv::Exception &e) { + RETURN_STATUS_UNEXPECTED("Error in auto contrast"); + } + return Status::OK(); +} + Status AdjustSaturation(const std::shared_ptr &input, std::shared_ptr *output, const float &alpha) { try { std::shared_ptr input_cv = CVTensor::AsCVTensor(input); diff --git a/mindspore/ccsrc/minddata/dataset/kernels/image/image_utils.h b/mindspore/ccsrc/minddata/dataset/kernels/image/image_utils.h index f489c7367..c14263389 100644 --- a/mindspore/ccsrc/minddata/dataset/kernels/image/image_utils.h +++ b/mindspore/ccsrc/minddata/dataset/kernels/image/image_utils.h @@ -175,6 +175,14 @@ Status AdjustBrightness(const std::shared_ptr &input, std::shared_ptr &input, std::shared_ptr *output, const float &alpha); +// Returns image with contrast maximized. +// @param input: Tensor of shape // in RGB/Grayscale and any OpenCv compatible type, see CVTensor. +// @param cutoff: Cutoff percentage of how many pixels are to be removed (high pixels change to 255 and low change to 0) +// from the high and low ends of the histogram. +// @param ignore: Pixel values to be ignored in the algorithm. +Status AutoContrast(const std::shared_ptr &input, std::shared_ptr *output, const float &cutoff, + const std::vector &ignore); + // Returns image with adjusted saturation. // @param input: Tensor of shape in RGB order and any OpenCv compatible type, see CVTensor. // @param alpha: Alpha value to adjust saturation by. Should be a positive number. diff --git a/mindspore/ccsrc/minddata/dataset/kernels/tensor_op.h b/mindspore/ccsrc/minddata/dataset/kernels/tensor_op.h index 27bcfed00..cae28fe6f 100644 --- a/mindspore/ccsrc/minddata/dataset/kernels/tensor_op.h +++ b/mindspore/ccsrc/minddata/dataset/kernels/tensor_op.h @@ -87,6 +87,7 @@ namespace mindspore { namespace dataset { // image +constexpr char kAutoContrastOp[] = "AutoContrastOp"; constexpr char kBoundingBoxAugmentOp[] = "BoundingBoxAugmentOp"; constexpr char kDecodeOp[] = "DecodeOp"; constexpr char kCenterCropOp[] = "CenterCropOp"; diff --git a/mindspore/dataset/transforms/vision/c_transforms.py b/mindspore/dataset/transforms/vision/c_transforms.py index ca356dd79..0715ec8e1 100644 --- a/mindspore/dataset/transforms/vision/c_transforms.py +++ b/mindspore/dataset/transforms/vision/c_transforms.py @@ -47,7 +47,7 @@ from .utils import Inter, Border from .validators import check_prob, check_crop, check_resize_interpolation, check_random_resize_crop, \ check_normalize_c, check_random_crop, check_random_color_adjust, check_random_rotation, check_range, \ check_resize, check_rescale, check_pad, check_cutout, check_uniform_augment_cpp, check_bounding_box_augment_cpp, \ - check_random_select_subpolicy_op, FLOAT_MAX_INTEGER + check_random_select_subpolicy_op, check_auto_contrast, FLOAT_MAX_INTEGER DE_C_INTER_MODE = {Inter.NEAREST: cde.InterpolationMode.DE_INTER_NEAREST_NEIGHBOUR, Inter.LINEAR: cde.InterpolationMode.DE_INTER_LINEAR, @@ -71,6 +71,24 @@ def parse_padding(padding): return padding +class AutoContrast(cde.AutoContrastOp): + """ + Apply auto contrast on input image. + + Args: + cutoff (float, optional): Percent of pixels to cut off from the histogram (default=0.0). + ignore (int or sequence, optional): Pixel values to ignore (default=None). + """ + + @check_auto_contrast + def __init__(self, cutoff=0.0, ignore=None): + if ignore is None: + ignore = [] + if isinstance(ignore, int): + ignore = [ignore] + super().__init__(cutoff, ignore) + + class Invert(cde.InvertOp): """ Apply invert on input image in RGB mode. diff --git a/mindspore/dataset/transforms/vision/validators.py b/mindspore/dataset/transforms/vision/validators.py index 0f2bc2ce2..b4ac03488 100644 --- a/mindspore/dataset/transforms/vision/validators.py +++ b/mindspore/dataset/transforms/vision/validators.py @@ -530,6 +530,27 @@ def check_bounding_box_augment_cpp(method): return new_method +def check_auto_contrast(method): + """Wrapper method to check the parameters of AutoContrast ops (python and cpp).""" + + @wraps(method) + def new_method(self, *args, **kwargs): + [cutoff, ignore], _ = parse_user_args(method, *args, **kwargs) + type_check(cutoff, (int, float), "cutoff") + check_value(cutoff, [0, 100], "cutoff") + if ignore is not None: + type_check(ignore, (list, tuple, int), "ignore") + if isinstance(ignore, int): + check_value(ignore, [0, 255], "ignore") + if isinstance(ignore, (list, tuple)): + for item in ignore: + type_check(item, (int,), "item") + check_value(item, [0, 255], "ignore") + return method(self, *args, **kwargs) + + return new_method + + def check_uniform_augment_py(method): """Wrapper method to check the parameters of python UniformAugment op.""" diff --git a/tests/ut/data/dataset/golden/autcontrast_01_result_py.npz b/tests/ut/data/dataset/golden/autcontrast_01_result_py.npz new file mode 100644 index 0000000000000000000000000000000000000000..6408ebf25080fb5aa983e5a97409c3b897158283 GIT binary patch literal 607 zcmWIWW@Zs#fB;2?b#dl4*BKcYK$w$3gdwr0DBeIXub`5VK>#cWQV5a+fysWMz5$Vp z3}p<}>M5zk$wlf`3hFif>N*PQY57GZMTvRw`9&$IAYr$}oZ?iVcyUHzK`M~1VWgvA zq^YA&t3W>BYG6*zE6pva)Jx7UO4Z9P%_+$Qx;L?sE50Z-IX|zsq^LBxgsYGNqKYdo z1tMF>=*`et$mGnJRLI<3$P!e@s^QJ(&E(D0R>%fbno?3(kjhoa9>E0kroTlYhc|;a zV|yWIP$8FwH**BY|22i&etv#l|A7EZc-xm0@+5V}7V>IDuma_C5|dJM3i)CS`7;`Rmcznn(=3%@~3l0 zALNP1Rh*31wv#PXEGbkTfr;q?|RyL3fBM_Pc KX(3S9F#rIHd!A(g literal 0 HcmV?d00001 diff --git a/tests/ut/python/dataset/test_autocontrast.py b/tests/ut/python/dataset/test_autocontrast.py index d212994e6..fd390b548 100644 --- a/tests/ut/python/dataset/test_autocontrast.py +++ b/tests/ut/python/dataset/test_autocontrast.py @@ -16,20 +16,22 @@ Testing AutoContrast op in DE """ import numpy as np - import mindspore.dataset.engine as de import mindspore.dataset.transforms.vision.py_transforms as F +import mindspore.dataset.transforms.vision.c_transforms as C from mindspore import log as logger -from util import visualize_list, diff_mse +from util import visualize_list, diff_mse, save_and_check_md5 DATA_DIR = "../data/dataset/testImageNetData/train/" +GENERATE_GOLDEN = False + -def test_auto_contrast(plot=False): +def test_auto_contrast_py(plot=False): """ Test AutoContrast """ - logger.info("Test AutoContrast") + logger.info("Test AutoContrast Python Op") # Original Images ds = de.ImageFolderDatasetV2(dataset_dir=DATA_DIR, shuffle=False) @@ -78,9 +80,156 @@ def test_auto_contrast(plot=False): mse[i] = diff_mse(images_auto_contrast[i], images_original[i]) logger.info("MSE= {}".format(str(np.mean(mse)))) + # Compare with expected md5 from images + filename = "autcontrast_01_result_py.npz" + save_and_check_md5(ds_auto_contrast, filename, generate_golden=GENERATE_GOLDEN) + if plot: visualize_list(images_original, images_auto_contrast) +def test_auto_contrast_c(plot=False): + """ + Test AutoContrast C Op + """ + logger.info("Test AutoContrast C Op") + + # AutoContrast Images + ds = de.ImageFolderDatasetV2(dataset_dir=DATA_DIR, shuffle=False) + ds = ds.map(input_columns=["image"], + operations=[C.Decode(), + C.Resize((224, 224))]) + python_op = F.AutoContrast() + c_op = C.AutoContrast() + transforms_op = F.ComposeOp([lambda img: F.ToPIL()(img.astype(np.uint8)), + python_op, + np.array])() + + ds_auto_contrast_py = ds.map(input_columns="image", + operations=transforms_op) + + ds_auto_contrast_py = ds_auto_contrast_py.batch(512) + + for idx, (image, _) in enumerate(ds_auto_contrast_py): + if idx == 0: + images_auto_contrast_py = image + else: + images_auto_contrast_py = np.append(images_auto_contrast_py, + image, + axis=0) + + ds = de.ImageFolderDatasetV2(dataset_dir=DATA_DIR, shuffle=False) + ds = ds.map(input_columns=["image"], + operations=[C.Decode(), + C.Resize((224, 224))]) + + ds_auto_contrast_c = ds.map(input_columns="image", + operations=c_op) + + ds_auto_contrast_c = ds_auto_contrast_c.batch(512) + + for idx, (image, _) in enumerate(ds_auto_contrast_c): + if idx == 0: + images_auto_contrast_c = image + else: + images_auto_contrast_c = np.append(images_auto_contrast_c, + image, + axis=0) + + num_samples = images_auto_contrast_c.shape[0] + mse = np.zeros(num_samples) + for i in range(num_samples): + mse[i] = diff_mse(images_auto_contrast_c[i], images_auto_contrast_py[i]) + logger.info("MSE= {}".format(str(np.mean(mse)))) + np.testing.assert_equal(np.mean(mse), 0.0) + + if plot: + visualize_list(images_auto_contrast_c, images_auto_contrast_py, visualize_mode=2) + + +def test_auto_contrast_one_channel_c(plot=False): + """ + Test AutoContrast C op with one channel + """ + logger.info("Test AutoContrast C Op With One Channel Images") + + # AutoContrast Images + ds = de.ImageFolderDatasetV2(dataset_dir=DATA_DIR, shuffle=False) + ds = ds.map(input_columns=["image"], + operations=[C.Decode(), + C.Resize((224, 224))]) + python_op = F.AutoContrast() + c_op = C.AutoContrast() + # not using F.ToTensor() since it converts to floats + transforms_op = F.ComposeOp([lambda img: (np.array(img)[:, :, 0]).astype(np.uint8), + F.ToPIL(), + python_op, + np.array])() + + ds_auto_contrast_py = ds.map(input_columns="image", + operations=transforms_op) + + ds_auto_contrast_py = ds_auto_contrast_py.batch(512) + + for idx, (image, _) in enumerate(ds_auto_contrast_py): + if idx == 0: + images_auto_contrast_py = image + else: + images_auto_contrast_py = np.append(images_auto_contrast_py, + image, + axis=0) + + ds = de.ImageFolderDatasetV2(dataset_dir=DATA_DIR, shuffle=False) + ds = ds.map(input_columns=["image"], + operations=[C.Decode(), + C.Resize((224, 224)), + lambda img: np.array(img[:, :, 0])]) + + ds_auto_contrast_c = ds.map(input_columns="image", + operations=c_op) + + ds_auto_contrast_c = ds_auto_contrast_c.batch(512) + + for idx, (image, _) in enumerate(ds_auto_contrast_c): + if idx == 0: + images_auto_contrast_c = image + else: + images_auto_contrast_c = np.append(images_auto_contrast_c, + image, + axis=0) + + num_samples = images_auto_contrast_c.shape[0] + mse = np.zeros(num_samples) + for i in range(num_samples): + mse[i] = diff_mse(images_auto_contrast_c[i], images_auto_contrast_py[i]) + logger.info("MSE= {}".format(str(np.mean(mse)))) + np.testing.assert_equal(np.mean(mse), 0.0) + + if plot: + visualize_list(images_auto_contrast_c, images_auto_contrast_py, visualize_mode=2) + + +def test_auto_contrast_invalid_input_c(): + """ + Test AutoContrast C Op with invalid params + """ + logger.info("Test AutoContrast C Op with invalid params") + try: + ds = de.ImageFolderDatasetV2(dataset_dir=DATA_DIR, shuffle=False) + ds = ds.map(input_columns=["image"], + operations=[C.Decode(), + C.Resize((224, 224)), + lambda img: np.array(img[:, :, 0])]) + # invalid ignore + ds = ds.map(input_columns="image", + operations=C.AutoContrast(ignore=255.5)) + except TypeError as error: + logger.info("Got an exception in DE: {}".format(str(error))) + assert "Argument ignore with value 255.5 is not of type" in str(error) + + if __name__ == "__main__": - test_auto_contrast(plot=True) + test_auto_contrast_py(plot=True) + test_auto_contrast_c(plot=True) + test_auto_contrast_one_channel_c(plot=True) + test_auto_contrast_invalid_input_c() -- GitLab