diff --git a/mindspore/ccsrc/dataset/api/python_bindings.cc b/mindspore/ccsrc/dataset/api/python_bindings.cc index 7bed870f1a0fdbdfc1ecba8d2c1786d9f6659fe7..ed3f993fb86571adee90b1ef031ec7763a11b44e 100644 --- a/mindspore/ccsrc/dataset/api/python_bindings.cc +++ b/mindspore/ccsrc/dataset/api/python_bindings.cc @@ -63,12 +63,14 @@ #include "dataset/kernels/image/random_horizontal_flip_bbox_op.h" #include "dataset/kernels/image/random_horizontal_flip_op.h" #include "dataset/kernels/image/random_resize_op.h" +#include "dataset/kernels/image/random_resize_with_bbox_op.h" #include "dataset/kernels/image/random_rotation_op.h" #include "dataset/kernels/image/random_vertical_flip_op.h" #include "dataset/kernels/image/random_vertical_flip_with_bbox_op.h" #include "dataset/kernels/image/rescale_op.h" #include "dataset/kernels/image/resize_bilinear_op.h" #include "dataset/kernels/image/resize_op.h" +#include "dataset/kernels/image/resize_with_bbox_op.h" #include "dataset/kernels/image/uniform_aug_op.h" #include "dataset/kernels/no_op.h" #include "dataset/text/kernels/jieba_tokenizer_op.h" @@ -348,6 +350,18 @@ void bindTensorOps1(py::module *m) { .def(py::init(), py::arg("targetHeight"), py::arg("targetWidth") = ResizeOp::kDefWidth, py::arg("interpolation") = ResizeOp::kDefInterpolation); + (void)py::class_>( + *m, "ResizeWithBBoxOp", "Tensor operation to resize an image. Takes height, width and mode.") + .def(py::init(), py::arg("targetHeight"), + py::arg("targetWidth") = ResizeWithBBoxOp::kDefWidth, + py::arg("interpolation") = ResizeWithBBoxOp::kDefInterpolation); + + (void)py::class_>( + *m, "RandomResizeWithBBoxOp", + "Tensor operation to resize an image using a randomly selected interpolation. 
Takes height and width.") + .def(py::init(), py::arg("targetHeight"), + py::arg("targetWidth") = RandomResizeWithBBoxOp::kDefTargetWidth); + (void)py::class_>( *m, "UniformAugOp", "Tensor operation to apply random augmentation(s).") .def(py::init>, int32_t>(), py::arg("operations"), diff --git a/mindspore/ccsrc/dataset/kernels/image/CMakeLists.txt b/mindspore/ccsrc/dataset/kernels/image/CMakeLists.txt index 3d88d9989cae9152a6e88c115eb09b90718c7c69..fef698912c470baecdb680f050862b171cc9b0dd 100644 --- a/mindspore/ccsrc/dataset/kernels/image/CMakeLists.txt +++ b/mindspore/ccsrc/dataset/kernels/image/CMakeLists.txt @@ -25,4 +25,6 @@ add_library(kernels-image OBJECT resize_bilinear_op.cc resize_op.cc uniform_aug_op.cc + resize_with_bbox_op.cc + random_resize_with_bbox_op.cc ) diff --git a/mindspore/ccsrc/dataset/kernels/image/random_resize_with_bbox_op.cc b/mindspore/ccsrc/dataset/kernels/image/random_resize_with_bbox_op.cc new file mode 100644 index 0000000000000000000000000000000000000000..de69c02e39e8a8ca36f6eca00b0c5450f3758dd2 --- /dev/null +++ b/mindspore/ccsrc/dataset/kernels/image/random_resize_with_bbox_op.cc @@ -0,0 +1,33 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+
+#include "dataset/kernels/image/random_resize_with_bbox_op.h"
+#include "dataset/kernels/image/resize_with_bbox_op.h"
+#include "dataset/util/status.h"
+
+namespace mindspore {
+namespace dataset {
+const int32_t RandomResizeWithBBoxOp::kDefTargetWidth = 0;
+
+// Draws a fresh interpolation mode from distribution_ on every call, then hands the row to
+// ResizeWithBBoxOp::Compute and returns whatever status that call produced.
+Status RandomResizeWithBBoxOp::Compute(const TensorRow &input, TensorRow *output) {
+  // 0-bilinear, 1-nearest_neighbor, 2-bicubic, 3-area
+  interpolation_ = static_cast<InterpolationMode>(distribution_(random_generator_));
+  return ResizeWithBBoxOp::Compute(input, output);
+}
+}  // namespace dataset
+}  // namespace mindspore
diff --git a/mindspore/ccsrc/dataset/kernels/image/random_resize_with_bbox_op.h b/mindspore/ccsrc/dataset/kernels/image/random_resize_with_bbox_op.h
new file mode 100644
index 0000000000000000000000000000000000000000..4a7614525fbd622d75d0a8210620ec5f7fad139b
--- /dev/null
+++ b/mindspore/ccsrc/dataset/kernels/image/random_resize_with_bbox_op.h
@@ -0,0 +1,56 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef DATASET_KERNELS_IMAGE_RANDOM_RESIZE_WITH_BBOX_OP_H
+#define DATASET_KERNELS_IMAGE_RANDOM_RESIZE_WITH_BBOX_OP_H
+
+#include <ostream>
+#include <random>
+
+#include "dataset/core/tensor.h"
+#include "dataset/kernels/image/resize_op.h"
+#include "dataset/kernels/image/resize_with_bbox_op.h"
+#include "dataset/kernels/tensor_op.h"
+#include "dataset/util/random.h"
+#include "dataset/util/status.h"
+
+namespace mindspore {
+namespace dataset {
+// ResizeWithBBoxOp variant whose interpolation mode is drawn at random on every Compute call.
+class RandomResizeWithBBoxOp : public ResizeWithBBoxOp {
+ public:
+  // Default values, also used by python_bindings.cc
+  static const int32_t kDefTargetWidth;
+  explicit RandomResizeWithBBoxOp(int32_t size_1, int32_t size_2 = kDefTargetWidth) : ResizeWithBBoxOp(size_1, size_2) {
+    random_generator_.seed(GetSeed());
+  }
+
+  ~RandomResizeWithBBoxOp() override = default;
+
+  // Writes the op name followed by the two size members to `out`.
+  void Print(std::ostream &out) const override { out << "RandomResizeWithBBoxOp: " << size1_ << " " << size2_; }
+
+  // Selects the interpolation mode for this call, then delegates to ResizeWithBBoxOp::Compute.
+  Status Compute(const TensorRow &input, TensorRow *output) override;
+
+ private:
+  std::mt19937 random_generator_;
+  std::uniform_int_distribution<int> distribution_{0, 3};  // inclusive bounds, one value per interpolation mode
+};
+}  // namespace dataset
+}  // namespace mindspore
+
+#endif  // DATASET_KERNELS_IMAGE_RANDOM_RESIZE_WITH_BBOX_OP_H
diff --git a/mindspore/ccsrc/dataset/kernels/image/resize_with_bbox_op.cc b/mindspore/ccsrc/dataset/kernels/image/resize_with_bbox_op.cc
new file mode 100644
index 0000000000000000000000000000000000000000..8a633d5678cfbc92148a11b83f825deb0d44fe06
--- /dev/null
+++ b/mindspore/ccsrc/dataset/kernels/image/resize_with_bbox_op.cc
@@ -0,0 +1,53 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "dataset/kernels/image/resize_with_bbox_op.h"
+#include <memory>
+#include <utility>
+#include "dataset/kernels/image/resize_op.h"
+#include "dataset/kernels/image/image_utils.h"
+#include "dataset/core/cv_tensor.h"
+#include "dataset/core/pybind_support.h"
+#include "dataset/core/tensor.h"
+#include "dataset/kernels/tensor_op.h"
+#include "dataset/util/status.h"
+
+namespace mindspore {
+namespace dataset {
+
+// Resizes the image in input[0] through ResizeOp::Compute and passes the boxes from input[1] to
+// UpdateBBoxesForResize together with the old and new image dimensions.
+Status ResizeWithBBoxOp::Compute(const TensorRow &input, TensorRow *output) {
+  IO_CHECK_VECTOR(input, output);
+  BOUNDING_BOX_CHECK(input);
+
+  // Read everything that is needed from the input row before the output row is written.
+  int32_t input_h = input[0]->shape()[0];
+  int32_t input_w = input[0]->shape()[1];
+  size_t bboxCount = input[1]->shape()[0];  // number of rows in bbox tensor
+
+  // `input` is const, so std::move on its elements could only ever copy; share the pointers explicitly.
+  output->resize(2);
+  (*output)[1] = input[1];
+  std::shared_ptr<CVTensor> input_cv = CVTensor::AsCVTensor(input[0]);
+  RETURN_IF_NOT_OK(ResizeOp::Compute(std::static_pointer_cast<Tensor>(input_cv), &(*output)[0]));
+
+  int32_t output_h = (*output)[0]->shape()[0];  // image height after the resize
+  int32_t output_w = (*output)[0]->shape()[1];  // image width after the resize
+  return UpdateBBoxesForResize((*output)[1], bboxCount, output_w, output_h, input_w, input_h);
+}
+}  // namespace dataset
+}  // namespace mindspore
diff --git a/mindspore/ccsrc/dataset/kernels/image/resize_with_bbox_op.h b/mindspore/ccsrc/dataset/kernels/image/resize_with_bbox_op.h
new file mode 100644
index 0000000000000000000000000000000000000000..17bdd01ef121f24015d8e522ff8c7453d7bce982
--- 
/dev/null
+++ b/mindspore/ccsrc/dataset/kernels/image/resize_with_bbox_op.h
@@ -0,0 +1,49 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef DATASET_KERNELS_IMAGE_RESIZE_WITH_BBOX_OP_H
+#define DATASET_KERNELS_IMAGE_RESIZE_WITH_BBOX_OP_H
+
+#include <ostream>
+
+#include "dataset/core/tensor.h"
+#include "dataset/kernels/image/image_utils.h"
+#include "dataset/kernels/tensor_op.h"
+#include "dataset/util/status.h"
+#include "dataset/kernels/image/resize_op.h"
+
+namespace mindspore {
+namespace dataset {
+// ResizeOp variant for rows that carry an image together with its bounding boxes: Compute resizes
+// the image and updates the boxes for the new image size.
+class ResizeWithBBoxOp : public ResizeOp {
+ public:
+  // All three arguments are forwarded unchanged to the ResizeOp constructor.
+  explicit ResizeWithBBoxOp(int32_t size_1, int32_t size_2 = kDefWidth,
+                            InterpolationMode mInterpolation = kDefInterpolation)
+      : ResizeOp(size_1, size_2, mInterpolation) {}
+
+  ~ResizeWithBBoxOp() override = default;
+
+  // Writes the op name followed by the two size members to `out`.
+  void Print(std::ostream &out) const override { out << "ResizeWithBBoxOp: " << size1_ << " " << size2_; }
+
+  // Expects the image in input[0] and the bounding boxes in input[1]; fills `output` in the same order.
+  Status Compute(const TensorRow &input, TensorRow *output) override;
+};
+}  // namespace dataset
+}  // namespace mindspore
+
+#endif  // DATASET_KERNELS_IMAGE_RESIZE_WITH_BBOX_OP_H
diff --git a/mindspore/dataset/transforms/vision/c_transforms.py b/mindspore/dataset/transforms/vision/c_transforms.py
index c2497f9629c6d4f291e741e37d8a99d2ba16cdd9..aef714953f06e030ba0f5815224a0ffa80677ce0 100644
--- 
a/mindspore/dataset/transforms/vision/c_transforms.py
+++ b/mindspore/dataset/transforms/vision/c_transforms.py
@@ -265,6 +265,7 @@ class BoundingBoxAugment(cde.BoundingBoxAugmentOp):
         ratio (float, optional): Ratio of bounding boxes to apply augmentation on.
             Range: [0,1] (default=0.3).
     """
+
     @check_bounding_box_augment_cpp
     def __init__(self, transform, ratio=0.3):
         self.ratio = ratio
@@ -302,6 +303,36 @@ class Resize(cde.ResizeOp):
             super().__init__(*size, interpoltn)
 
 
+class ResizeWithBBox(cde.ResizeWithBBoxOp):
+    """
+    Resize an image to the requested size and update its bounding boxes to match.
+
+    Args:
+        size (int or sequence): Target size of the output image.
+            An int resizes the smaller edge of the image to that value while keeping
+            the aspect ratio of the image.
+            A sequence of length 2 is interpreted as (height, width).
+        interpolation (Inter mode, optional): Interpolation mode for the image (default=Inter.LINEAR).
+            One of [Inter.LINEAR, Inter.NEAREST, Inter.BICUBIC].
+
+            - Inter.LINEAR, bilinear interpolation.
+
+            - Inter.NEAREST, nearest-neighbor interpolation.
+
+            - Inter.BICUBIC, bicubic interpolation.
+    """
+
+    @check_resize_interpolation
+    def __init__(self, size, interpolation=Inter.LINEAR):
+        self.size = size
+        self.interpolation = interpolation
+        c_mode = DE_C_INTER_MODE[interpolation]
+        if not isinstance(size, int):
+            super().__init__(*size, c_mode)
+        else:
+            super().__init__(size, interpolation=c_mode)
+
+
 class RandomResizedCropWithBBox(cde.RandomCropAndResizeWithBBoxOp):
     """
     Crop the input image to a random size and aspect ratio and adjust the Bounding Boxes accordingly
@@ -326,6 +357,7 @@ class RandomResizedCropWithBBox(cde.RandomCropAndResizeWithBBoxOp):
         max_attempts (int, optional): The maximum number of attempts to propose a valid crop_area (default=10).
             If exceeded, fall back to use center_crop instead.
     """
+
     @check_random_resize_crop
     def __init__(self, size, scale=(0.08, 1.0), ratio=(3. / 4., 4. / 3.),
                  interpolation=Inter.BILINEAR, max_attempts=10):
@@ -499,6 +531,27 @@ class RandomResize(cde.RandomResizeOp):
             super().__init__(*size)
 
 
+class RandomResizeWithBBox(cde.RandomResizeWithBBoxOp):
+    """
+    Resize an image with an interpolation mode that is selected at random, and update its
+    bounding boxes to match.
+
+    Args:
+        size (int or sequence): Target size of the output image.
+            An int resizes the smaller edge of the image to that value while keeping
+            the aspect ratio of the image.
+            A sequence of length 2 is interpreted as (height, width).
+    """
+
+    @check_resize
+    def __init__(self, size):
+        self.size = size
+        if not isinstance(size, int):
+            super().__init__(*size)
+        else:
+            super().__init__(size)
+
+
 class HWC2CHW(cde.ChannelSwapOp):
     """
     Transpose the input image; shape (H, W, C) to shape (C, H, W).
diff --git a/tests/ut/python/dataset/test_random_resize_with_bbox.py b/tests/ut/python/dataset/test_random_resize_with_bbox.py
new file mode 100644
index 0000000000000000000000000000000000000000..66c185d647924e1d3cb41dc89676159dbd7d750f
--- /dev/null
+++ b/tests/ut/python/dataset/test_random_resize_with_bbox.py
@@ -0,0 +1,265 @@
+# Copyright 2020 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""
+Testing the random resize with bounding boxes op in DE
+"""
+from enum import Enum
+import matplotlib.pyplot as plt
+import matplotlib.patches as patches
+import numpy as np
+import mindspore.dataset as ds
+from mindspore import log as logger
+import mindspore.dataset.transforms.vision.c_transforms as c_vision
+
+
+GENERATE_GOLDEN = False
+
+DATA_DIR = "../data/dataset/testVOC2012"
+
+
+def fix_annotate(bboxes):
+    """
+    :param bboxes: in [label, x_min, y_min, w, h, truncate, difficult] format
+    :return: annotation in [x_min, y_min, w, h, label, truncate, difficult] format
+    """
+    for bbox in bboxes:
+        tmp = bbox[0]
+        bbox[0] = bbox[1]
+        bbox[1] = bbox[2]
+        bbox[2] = bbox[3]
+        bbox[3] = bbox[4]
+        bbox[4] = tmp
+    return bboxes
+
+
+class BoxType(Enum):
+    """
+    Defines box types for test cases
+    """
+    WidthOverflow = 1
+    HeightOverflow = 2
+    NegativeXY = 3
+    OnEdge = 4
+    WrongShape = 5
+
+
+class AddBadAnnotation:  # pylint: disable=too-few-public-methods
+    """
+    Used to add erroneous bounding boxes to object detection pipelines.
+    Usage:
+    >>> # Adds a box that covers the whole image. Good for testing edge cases
+    >>> de = de.map(input_columns=["image", "annotation"],
+    >>>             output_columns=["image", "annotation"],
+    >>>             operations=AddBadAnnotation(BoxType.OnEdge))
+    """
+
+    def __init__(self, box_type):
+        self.box_type = box_type
+
+    def __call__(self, img, bboxes):
+        """
+        Used to generate erroneous bounding box examples on given img.
+        :param img: image where the bounding boxes are.
+        :param bboxes: in [x_min, y_min, w, h, label, truncate, difficult] format
+        :return: bboxes with bad examples added
+        """
+        height = img.shape[0]
+        width = img.shape[1]
+        if self.box_type == BoxType.WidthOverflow:
+            # use box that overflows on width
+            return img, np.array([[0, 0, width + 1, height - 1, 0, 0, 0]]).astype(np.uint32)
+
+        if self.box_type == BoxType.HeightOverflow:
+            # use box that overflows on height
+            return img, np.array([[0, 0, width - 1, height + 1, 0, 0, 0]]).astype(np.uint32)
+
+        if self.box_type == BoxType.NegativeXY:
+            # use box with negative xy; the uint32 cast below wraps -10 around to a large positive value
+            return img, np.array([[-10, -10, width - 1, height - 1, 0, 0, 0]]).astype(np.uint32)
+
+        if self.box_type == BoxType.OnEdge:
+            # use box that covers the whole image
+            return img, np.array([[0, 0, width - 1, height - 1, 0, 0, 0]]).astype(np.uint32)
+
+        if self.box_type == BoxType.WrongShape:
+            # use box with only 3 values, i.e. too few features for a bounding box
+            return img, np.array([[0, 0, width - 1]]).astype(np.uint32)
+        return img, bboxes
+
+
+def check_bad_box(data, box_type, expected_error):
+    """Swap in one malformed box of kind box_type and require the op to reject it with expected_error."""
+    try:
+        test_op = c_vision.RandomResizeWithBBox(100)  # DEFINE TEST OP HERE -- (PROB 1 IN CASE OF RANDOM)
+        data = data.map(input_columns=["annotation"], output_columns=["annotation"], operations=fix_annotate)
+        # replace the annotations with a single bad box of the requested kind
+        data = data.map(input_columns=["image", "annotation"],
+                        output_columns=["image", "annotation"],
+                        columns_order=["image", "annotation"],
+                        operations=AddBadAnnotation(box_type))  # Add column for "annotation"
+        # map to apply ops
+        data = data.map(input_columns=["image", "annotation"],
+                        output_columns=["image", "annotation"],
+                        columns_order=["image", "annotation"],
+                        operations=[test_op])  # Add column for "annotation"
+        for _ in data.create_dict_iterator():
+            break
+    except RuntimeError as e:
+        logger.info("Got an exception in DE: {}".format(str(e)))
+        assert expected_error in str(e)
+    else:
+        # without this branch the check passes vacuously whenever nothing is raised
+        raise AssertionError("expected a RuntimeError containing '{}'".format(expected_error))
+
+
+def add_bounding_boxes(axis, bboxes):
+    """
+    :param axis: axis to modify
+    :param bboxes: bounding boxes to draw on the axis
+    :return: None
+    """
+    for bbox in bboxes:
+        rect = patches.Rectangle((bbox[0], bbox[1]),
+                                 bbox[2], bbox[3],
+                                 linewidth=1, edgecolor='r', facecolor='none')
+        # Add the patch to the Axes
+        axis.add_patch(rect)
+
+
+def visualize(unaugmented_data, augment_data):
+    """Plot each original image next to its augmented version with the bounding boxes drawn on top."""
+    for idx, (un_aug_item, aug_item) in \
+            enumerate(zip(unaugmented_data.create_dict_iterator(), augment_data.create_dict_iterator())):
+        axis = plt.subplot(141)
+        plt.imshow(un_aug_item["image"])
+        add_bounding_boxes(axis, un_aug_item["annotation"])  # add Orig BBoxes
+        plt.title("Original" + str(idx + 1))
+        logger.info("Original {} : {}".format(idx + 1, un_aug_item["annotation"]))
+
+        axis = plt.subplot(142)
+        plt.imshow(aug_item["image"])
+        add_bounding_boxes(axis, aug_item["annotation"])  # add AugBBoxes
+        plt.title("Augmented" + str(idx + 1))
+        logger.info("Augmented {} : {}".format(idx + 1, aug_item["annotation"]))
+        plt.show()
+
+
+def test_random_resize_with_bbox_op(plot=False):
+    """
+    Test random_resize_with_bbox_op
+    """
+    logger.info("Test random resize with bbox")
+
+    # original images
+    data_original = ds.VOCDataset(DATA_DIR, task="Detection", mode="train", decode=True, shuffle=False)
+
+    # augmented images
+    data_augmented = ds.VOCDataset(DATA_DIR, task="Detection", mode="train", decode=True, shuffle=False)
+
+    data_original = data_original.map(input_columns=["annotation"],
+                                      output_columns=["annotation"],
+                                      operations=fix_annotate)
+
+    data_augmented = data_augmented.map(input_columns=["annotation"],
+                                        output_columns=["annotation"],
+                                        operations=fix_annotate)
+
+    # define map operations
+    test_op = c_vision.RandomResizeWithBBox(100)  # input value being the target size of resizeOp
+
+    data_augmented = data_augmented.map(input_columns=["image", "annotation"],
+                                        output_columns=["image", "annotation"],
+                                        columns_order=["image", "annotation"], operations=[test_op])
+    if plot:
+        visualize(data_original, data_augmented)
+
+
+def test_random_resize_with_bbox_invalid_bounds():
+    """Each malformed box kind must be rejected with the matching error message."""
+    data_voc2 = ds.VOCDataset(DATA_DIR, task="Detection", mode="train", decode=True, shuffle=False)
+    check_bad_box(data_voc2, BoxType.WidthOverflow, "bounding boxes is out of bounds of the image")
+    data_voc2 = ds.VOCDataset(DATA_DIR, task="Detection", mode="train", decode=True, shuffle=False)
+    check_bad_box(data_voc2, BoxType.HeightOverflow, "bounding boxes is out of bounds of the image")
+    data_voc2 = ds.VOCDataset(DATA_DIR, task="Detection", mode="train", decode=True, shuffle=False)
+    check_bad_box(data_voc2, BoxType.NegativeXY, "min_x")
+    data_voc2 = ds.VOCDataset(DATA_DIR, task="Detection", mode="train", decode=True, shuffle=False)
+    check_bad_box(data_voc2, BoxType.WrongShape, "4 features")
+
+
+def test_random_resize_with_bbox_invalid_size():
+    """
+    Test random_resize_with_bbox_op
+    """
+    logger.info("Test random resize with bbox with invalid target size")
+
+    # original images
+    data = ds.VOCDataset(DATA_DIR, task="Detection", mode="train", decode=True, shuffle=False)
+
+    data = data.map(input_columns=["annotation"],
+                    output_columns=["annotation"],
+                    operations=fix_annotate)
+
+    # negative target size as input
+    try:
+        test_op = c_vision.RandomResizeWithBBox(-10)  # DEFINE TEST OP HERE -- (PROB 1 IN CASE OF RANDOM)
+        # map to apply ops
+        data = data.map(input_columns=["image", "annotation"],
+                        output_columns=["image", "annotation"],
+                        columns_order=["image", "annotation"],
+                        operations=[test_op])  # Add column for "annotation"
+        for _ in data.create_dict_iterator():
+            break
+    except ValueError as e:
+        logger.info("Got an exception in DE: {}".format(str(e)))
+        assert "Input is not" in str(e)
+    else:
+        raise AssertionError("RandomResizeWithBBox accepted a negative target size")
+
+    # zero target size as input
+    try:
+        test_op = c_vision.RandomResizeWithBBox(0)  # DEFINE TEST OP HERE -- (PROB 1 IN CASE OF RANDOM)
+        # map to apply ops
+        data = data.map(input_columns=["image", "annotation"],
+                        output_columns=["image", "annotation"],
+                        columns_order=["image", "annotation"],
+                        operations=[test_op])  # Add column for "annotation"
+        for _ in data.create_dict_iterator():
+            break
+    except ValueError as e:
+        logger.info("Got an exception in DE: {}".format(str(e)))
+        assert "Input is not" in str(e)
+    else:
+        raise AssertionError("RandomResizeWithBBox accepted a zero target size")
+
+    # invalid input shape
+    try:
+        test_op = c_vision.RandomResizeWithBBox((10, 10, 10))  # DEFINE TEST OP HERE -- (PROB 1 IN CASE OF RANDOM)
+        # map to apply ops
+        data = data.map(input_columns=["image", "annotation"],
+                        output_columns=["image", "annotation"],
+                        columns_order=["image", "annotation"],
+                        operations=[test_op])  # Add column for "annotation"
+        for _ in data.create_dict_iterator():
+            break
+    except TypeError as e:
+        logger.info("Got an exception in DE: {}".format(str(e)))
+        assert "Size should be" in str(e)
+    else:
+        raise AssertionError("RandomResizeWithBBox accepted a size of length 3")
+
+if __name__ == "__main__":
+    test_random_resize_with_bbox_op(plot=False)
+    test_random_resize_with_bbox_invalid_bounds()
+    test_random_resize_with_bbox_invalid_size()
diff --git a/tests/ut/python/dataset/test_resize_with_bbox.py b/tests/ut/python/dataset/test_resize_with_bbox.py
new file mode 100644
index 0000000000000000000000000000000000000000..8b07f17f1a890209e982d79830adce0e4a7e4f08
--- /dev/null
+++ b/tests/ut/python/dataset/test_resize_with_bbox.py
@@ -0,0 +1,295 @@
+# Copyright 2020 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""
+Testing the resize with bounding boxes op in DE
+"""
+from enum import Enum
+import numpy as np
+import matplotlib.patches as patches
+import matplotlib.pyplot as plt
+import mindspore.dataset.transforms.vision.c_transforms as c_vision
+from mindspore import log as logger
+import mindspore.dataset as ds
+
+GENERATE_GOLDEN = False
+
+DATA_DIR = "../data/dataset/testVOC2012"
+
+
+def fix_annotate(bboxes):
+    """
+    :param bboxes: in [label, x_min, y_min, w, h, truncate, difficult] format
+    :return: annotation in [x_min, y_min, w, h, label, truncate, difficult] format
+    """
+    for bbox in bboxes:
+        tmp = bbox[0]
+        bbox[0] = bbox[1]
+        bbox[1] = bbox[2]
+        bbox[2] = bbox[3]
+        bbox[3] = bbox[4]
+        bbox[4] = tmp
+    return bboxes
+
+
+class BoxType(Enum):
+    """
+    Defines box types for test cases
+    """
+    WidthOverflow = 1
+    HeightOverflow = 2
+    NegativeXY = 3
+    OnEdge = 4
+    WrongShape = 5
+
+
+class AddBadAnnotation:  # pylint: disable=too-few-public-methods
+    """
+    Used to add erroneous bounding boxes to object detection pipelines.
+    Usage:
+    >>> # Adds a box that covers the whole image. Good for testing edge cases
+    >>> de = de.map(input_columns=["image", "annotation"],
+    >>>             output_columns=["image", "annotation"],
+    >>>             operations=AddBadAnnotation(BoxType.OnEdge))
+    """
+
+    def __init__(self, box_type):
+        self.box_type = box_type
+
+    def __call__(self, img, bboxes):
+        """
+        Used to generate erroneous bounding box examples on given img.
+        :param img: image where the bounding boxes are.
+        :param bboxes: in [x_min, y_min, w, h, label, truncate, difficult] format
+        :return: bboxes with bad examples added
+        """
+        height = img.shape[0]
+        width = img.shape[1]
+        if self.box_type == BoxType.WidthOverflow:
+            # use box that overflows on width
+            return img, np.array([[0, 0, width + 1, height - 1, 0, 0, 0]]).astype(np.uint32)
+
+        if self.box_type == BoxType.HeightOverflow:
+            # use box that overflows on height
+            return img, np.array([[0, 0, width - 1, height + 1, 0, 0, 0]]).astype(np.uint32)
+
+        if self.box_type == BoxType.NegativeXY:
+            # use box with negative xy; the uint32 cast below wraps -10 around to a large positive value
+            return img, np.array([[-10, -10, width - 1, height - 1, 0, 0, 0]]).astype(np.uint32)
+
+        if self.box_type == BoxType.OnEdge:
+            # use box that covers the whole image
+            return img, np.array([[0, 0, width - 1, height - 1, 0, 0, 0]]).astype(np.uint32)
+
+        if self.box_type == BoxType.WrongShape:
+            # use box with only 3 values, i.e. too few features for a bounding box
+            return img, np.array([[0, 0, width - 1]]).astype(np.uint32)
+        return img, bboxes
+
+
+def check_bad_box(data, box_type, expected_error):
+    """Swap in one malformed box of kind box_type and require the op to reject it with expected_error."""
+    try:
+        test_op = c_vision.ResizeWithBBox(100)
+        data = data.map(input_columns=["annotation"], output_columns=["annotation"], operations=fix_annotate)
+        # replace the annotations with a single bad box of the requested kind
+        data = data.map(input_columns=["image", "annotation"],
+                        output_columns=["image", "annotation"],
+                        columns_order=["image", "annotation"],
+                        operations=AddBadAnnotation(box_type))  # Add column for "annotation"
+        # map to apply ops
+        data = data.map(input_columns=["image", "annotation"],
+                        output_columns=["image", "annotation"],
+                        columns_order=["image", "annotation"],
+                        operations=[test_op])  # Add column for "annotation"
+        for _ in data.create_dict_iterator():
+            break
+    except RuntimeError as e:
+        logger.info("Got an exception in DE: {}".format(str(e)))
+        assert expected_error in str(e)
+    else:
+        # without this branch the check passes vacuously whenever nothing is raised
+        raise AssertionError("expected a RuntimeError containing '{}'".format(expected_error))
+
+
+def add_bounding_boxes(axis, bboxes):
+    """
+    :param axis: axis to modify
+    :param bboxes: bounding boxes to draw on the axis
+    :return: None
+    """
+    for bbox in bboxes:
+        rect = patches.Rectangle((bbox[0], bbox[1]),
+                                 bbox[2], bbox[3],
+                                 linewidth=1, edgecolor='r', facecolor='none')
+        # Add the patch to the Axes
+        axis.add_patch(rect)
+
+
+def visualize(unaugmented_data, augment_data):
+    """Plot each original image next to its augmented version with the bounding boxes drawn on top."""
+    for idx, (un_aug_item, aug_item) in enumerate(
+            zip(unaugmented_data.create_dict_iterator(), augment_data.create_dict_iterator())):
+        axis = plt.subplot(141)
+        plt.imshow(un_aug_item["image"])
+        add_bounding_boxes(axis, un_aug_item["annotation"])  # add Orig BBoxes
+        plt.title("Original" + str(idx + 1))
+        logger.info("Original {} : {}".format(idx + 1, un_aug_item["annotation"]))
+
+        axis = plt.subplot(142)
+        plt.imshow(aug_item["image"])
+        add_bounding_boxes(axis, aug_item["annotation"])  # add AugBBoxes
+        plt.title("Augmented" + str(idx + 1))
+        logger.info("Augmented {} : {}".format(idx + 1, aug_item["annotation"]))
+        plt.show()
+
+
+def test_resize_with_bbox_op(plot=False):
+    """
+    Test resize_with_bbox_op
+    """
+    logger.info("Test resize with bbox")
+
+    # original images
+    data_original = ds.VOCDataset(DATA_DIR, task="Detection", mode="train", decode=True, shuffle=False)
+
+    # augmented images
+    data_augmented = ds.VOCDataset(DATA_DIR, task="Detection", mode="train", decode=True, shuffle=False)
+
+    data_original = data_original.map(input_columns=["annotation"],
+                                      output_columns=["annotation"],
+                                      operations=fix_annotate)
+
+    data_augmented = data_augmented.map(input_columns=["annotation"],
+                                        output_columns=["annotation"],
+                                        operations=fix_annotate)
+
+    # define map operations
+    test_op = c_vision.ResizeWithBBox(100)  # input value being the target size of resizeOp
+
+    data_augmented = data_augmented.map(input_columns=["image", "annotation"],
+                                        output_columns=["image", "annotation"],
+                                        columns_order=["image", "annotation"], operations=[test_op])
+    if plot:
+        visualize(data_original, data_augmented)
+
+
+def test_resize_with_bbox_invalid_bounds():
+    """Each malformed box kind must be rejected with the matching error message."""
+    data_voc2 = ds.VOCDataset(DATA_DIR, task="Detection", mode="train", decode=True, shuffle=False)
+    check_bad_box(data_voc2, BoxType.WidthOverflow, "bounding boxes is out of bounds of the image")
+    data_voc2 = ds.VOCDataset(DATA_DIR, task="Detection", mode="train", decode=True, shuffle=False)
+    check_bad_box(data_voc2, BoxType.HeightOverflow, "bounding boxes is out of bounds of the image")
+    data_voc2 = ds.VOCDataset(DATA_DIR, task="Detection", mode="train", decode=True, shuffle=False)
+    check_bad_box(data_voc2, BoxType.NegativeXY, "min_x")
+    data_voc2 = ds.VOCDataset(DATA_DIR, task="Detection", mode="train", decode=True, shuffle=False)
+    check_bad_box(data_voc2, BoxType.WrongShape, "4 features")
+
+
+def test_resize_with_bbox_invalid_size():
+    """
+    Test resize_with_bbox_op
+    """
+    logger.info("Test resize with bbox with invalid target size")
+
+    # original images
+    data = ds.VOCDataset(DATA_DIR, task="Detection", mode="train", decode=True, shuffle=False)
+
+    data = data.map(input_columns=["annotation"],
+                    output_columns=["annotation"],
+                    operations=fix_annotate)
+
+    # negative target size as input
+    try:
+        test_op = c_vision.ResizeWithBBox(-10)
+        # map to apply ops
+        data = data.map(input_columns=["image", "annotation"],
+                        output_columns=["image", "annotation"],
+                        columns_order=["image", "annotation"],
+                        operations=[test_op])  # Add column for "annotation"
+        for _ in data.create_dict_iterator():
+            break
+    except ValueError as e:
+        logger.info("Got an exception in DE: {}".format(str(e)))
+        assert "Input is not" in str(e)
+    else:
+        raise AssertionError("ResizeWithBBox accepted a negative target size")
+
+    # zero target size as input
+    try:
+        test_op = c_vision.ResizeWithBBox(0)
+        # map to apply ops
+        data = data.map(input_columns=["image", "annotation"],
+                        output_columns=["image", "annotation"],
+                        columns_order=["image", "annotation"],
+                        operations=[test_op])  # Add column for "annotation"
+        for _ in data.create_dict_iterator():
+            break
+    except ValueError as e:
+        logger.info("Got an exception in DE: {}".format(str(e)))
+        assert "Input is not" in str(e)
+    else:
+        raise AssertionError("ResizeWithBBox accepted a zero target size")
+
+    # invalid input shape
+    try:
+        test_op = c_vision.ResizeWithBBox((10, 10, 10))
+        # map to apply ops
+        data = data.map(input_columns=["image", "annotation"],
+                        output_columns=["image", "annotation"],
+                        columns_order=["image", "annotation"],
+                        operations=[test_op])  # Add column for "annotation"
+        for _ in data.create_dict_iterator():
+            break
+    except TypeError as e:
+        logger.info("Got an exception in DE: {}".format(str(e)))
+        assert "Size should be" in str(e)
+    else:
+        raise AssertionError("ResizeWithBBox accepted a size of length 3")
+
+
+def test_resize_with_bbox_invalid_interpolation():
+    """
+    Test resize_with_bbox_op
+    """
+    logger.info("Test resize with bbox with invalid interpolation size")
+
+    # original images
+    data = ds.VOCDataset(DATA_DIR, task="Detection", mode="train", decode=True, shuffle=False)
+
+    data = data.map(input_columns=["annotation"],
+                    output_columns=["annotation"],
+                    operations=fix_annotate)
+
+    # invalid interpolation
+    try:
+        test_op = c_vision.ResizeWithBBox(100, interpolation="invalid")
+        # map to apply ops
+        data = data.map(input_columns=["image", "annotation"],
+                        output_columns=["image", "annotation"],
+                        columns_order=["image", "annotation"],
+                        operations=[test_op])  # Add column for "annotation"
+        for _ in data.create_dict_iterator():
+            break
+    except ValueError as e:
+        logger.info("Got an exception in DE: {}".format(str(e)))
+        assert "interpolation" in str(e)
+    else:
+        raise AssertionError("ResizeWithBBox accepted an invalid interpolation mode")
+
+if __name__ == "__main__":
+    test_resize_with_bbox_op(plot=False)
+    test_resize_with_bbox_invalid_bounds()
+    test_resize_with_bbox_invalid_size()
+    test_resize_with_bbox_invalid_interpolation()