提交 587e2602 编写于 作者: A avakh

addressing comments

上级 d7a312d0
......@@ -63,12 +63,14 @@
#include "dataset/kernels/image/random_horizontal_flip_bbox_op.h"
#include "dataset/kernels/image/random_horizontal_flip_op.h"
#include "dataset/kernels/image/random_resize_op.h"
#include "dataset/kernels/image/random_resize_with_bbox_op.h"
#include "dataset/kernels/image/random_rotation_op.h"
#include "dataset/kernels/image/random_vertical_flip_op.h"
#include "dataset/kernels/image/random_vertical_flip_with_bbox_op.h"
#include "dataset/kernels/image/rescale_op.h"
#include "dataset/kernels/image/resize_bilinear_op.h"
#include "dataset/kernels/image/resize_op.h"
#include "dataset/kernels/image/resize_with_bbox_op.h"
#include "dataset/kernels/image/uniform_aug_op.h"
#include "dataset/kernels/no_op.h"
#include "dataset/text/kernels/jieba_tokenizer_op.h"
......@@ -348,6 +350,18 @@ void bindTensorOps1(py::module *m) {
.def(py::init<int32_t, int32_t, InterpolationMode>(), py::arg("targetHeight"),
py::arg("targetWidth") = ResizeOp::kDefWidth, py::arg("interpolation") = ResizeOp::kDefInterpolation);
(void)py::class_<ResizeWithBBoxOp, TensorOp, std::shared_ptr<ResizeWithBBoxOp>>(
*m, "ResizeWithBBoxOp", "Tensor operation to resize an image. Takes height, width and mode.")
.def(py::init<int32_t, int32_t, InterpolationMode>(), py::arg("targetHeight"),
py::arg("targetWidth") = ResizeWithBBoxOp::kDefWidth,
py::arg("interpolation") = ResizeWithBBoxOp::kDefInterpolation);
(void)py::class_<RandomResizeWithBBoxOp, TensorOp, std::shared_ptr<RandomResizeWithBBoxOp>>(
*m, "RandomResizeWithBBoxOp",
"Tensor operation to resize an image using a randomly selected interpolation. Takes height and width.")
.def(py::init<int32_t, int32_t>(), py::arg("targetHeight"),
py::arg("targetWidth") = RandomResizeWithBBoxOp::kDefTargetWidth);
(void)py::class_<UniformAugOp, TensorOp, std::shared_ptr<UniformAugOp>>(
*m, "UniformAugOp", "Tensor operation to apply random augmentation(s).")
.def(py::init<std::vector<std::shared_ptr<TensorOp>>, int32_t>(), py::arg("operations"),
......
......@@ -25,4 +25,6 @@ add_library(kernels-image OBJECT
resize_bilinear_op.cc
resize_op.cc
uniform_aug_op.cc
resize_with_bbox_op.cc
random_resize_with_bbox_op.cc
)
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "dataset/kernels/image/random_resize_with_bbox_op.h"
#include "dataset/kernels/image/resize_with_bbox_op.h"
#include "dataset/util/status.h"
namespace mindspore {
namespace dataset {
// Value used for the second size argument when the caller passes only one size.
const int32_t RandomResizeWithBBoxOp::kDefTargetWidth = 0;

// Chooses an interpolation mode for this call and then hands the row to
// ResizeWithBBoxOp::Compute, which performs the resize and the bounding-box update.
Status RandomResizeWithBBoxOp::Compute(const TensorRow &input, TensorRow *output) {
  // distribution_ produces an integer that is cast to InterpolationMode.
  // Mapping noted by the original author: 0-bilinear, 1-nearest_neighbor, 2-bicubic, 3-area
  const int mode_index = distribution_(random_generator_);
  interpolation_ = static_cast<InterpolationMode>(mode_index);

  // Propagate any failure reported by the base implementation.
  RETURN_IF_NOT_OK(ResizeWithBBoxOp::Compute(input, output));
  return Status::OK();
}
} // namespace dataset
} // namespace mindspore
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef DATASET_KERNELS_IMAGE_RANDOM_RESIZE_WITH_BBOX_OP_H
#define DATASET_KERNELS_IMAGE_RANDOM_RESIZE_WITH_BBOX_OP_H
#include <memory>
#include <random>
#include "dataset/core/tensor.h"
#include "dataset/kernels/image/resize_op.h"
#include "dataset/kernels/image/resize_with_bbox_op.h"
#include "dataset/kernels/tensor_op.h"
#include "dataset/util/random.h"
#include "dataset/util/status.h"
namespace mindspore {
namespace dataset {
class RandomResizeWithBBoxOp : public ResizeWithBBoxOp {
public:
// Default values, also used by python_bindings.cc
static const int32_t kDefTargetWidth;
explicit RandomResizeWithBBoxOp(int32_t size_1, int32_t size_2 = kDefTargetWidth) : ResizeWithBBoxOp(size_1, size_2) {
random_generator_.seed(GetSeed());
}
~RandomResizeWithBBoxOp() = default;
// Description: A function that prints info about the node
void Print(std::ostream &out) const override {
out << "RandomResizeWithBBoxOp: " << ResizeWithBBoxOp::size1_ << " " << ResizeWithBBoxOp::size2_;
}
Status Compute(const TensorRow &input, TensorRow *output) override;
private:
std::mt19937 random_generator_;
std::uniform_int_distribution<int> distribution_{0, 3};
};
} // namespace dataset
} // namespace mindspore
#endif // DATASET_KERNELS_IMAGE_RANDOM_RESIZE_WITH_BBOX_OP_H
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "dataset/kernels/image/resize_with_bbox_op.h"
#include <utility>
#include <memory>
#include "dataset/kernels/image/resize_op.h"
#include "dataset/kernels/image/image_utils.h"
#include "dataset/core/cv_tensor.h"
#include "dataset/core/pybind_support.h"
#include "dataset/core/tensor.h"
#include "dataset/kernels/tensor_op.h"
#include "dataset/util/status.h"
namespace mindspore {
namespace dataset {
// Resizes the image at input[0] with ResizeOp::Compute and rescales the bounding boxes
// at input[1] with UpdateBBoxesForResize.
// @param input - row with the image tensor at index 0 and the bounding-box tensor at index 1
// @param output - row resized to two entries: resized image at index 0, boxes at index 1
// @return Status - the error code returned
Status ResizeWithBBoxOp::Compute(const TensorRow &input, TensorRow *output) {
  // Validation macros defined elsewhere; presumably they return an error Status on bad input.
  IO_CHECK_VECTOR(input, output);
  BOUNDING_BOX_CHECK(input);

  // Capture the original image dimensions before resizing; the box update needs both
  // the old and the new size.
  const int32_t input_h = input[0]->shape()[0];
  const int32_t input_w = input[0]->shape()[1];

  output->resize(2);

  // `input` is const, so its elements cannot be moved from: std::move here would
  // silently degrade to a copy while suggesting that input[1] is no longer usable.
  // Copy the pointers explicitly instead; input and output then refer to the same
  // bounding-box tensor.
  (*output)[1] = input[1];

  std::shared_ptr<CVTensor> input_cv = CVTensor::AsCVTensor(input[0]);

  // Single-tensor resize inherited from ResizeOp writes the resized image into slot 0.
  RETURN_IF_NOT_OK(ResizeOp::Compute(std::static_pointer_cast<Tensor>(input_cv), &(*output)[0]));

  const int32_t output_h = (*output)[0]->shape()[0];  // height after the resize
  const int32_t output_w = (*output)[0]->shape()[1];  // width after the resize

  const size_t bboxCount = input[1]->shape()[0];  // number of rows in bbox tensor
  RETURN_IF_NOT_OK(UpdateBBoxesForResize((*output)[1], bboxCount, output_w, output_h, input_w, input_h));
  return Status::OK();
}
} // namespace dataset
} // namespace mindspore
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef DATASET_KERNELS_IMAGE_RESIZE_WITH_BBOX_OP_H
#define DATASET_KERNELS_IMAGE_RESIZE_WITH_BBOX_OP_H
#include "dataset/core/tensor.h"
#include "dataset/kernels/image/image_utils.h"
#include "dataset/kernels/tensor_op.h"
#include "dataset/util/status.h"
#include "dataset/kernels/image/resize_op.h"
namespace mindspore {
namespace dataset {
// ResizeOp specialisation whose Compute works on a TensorRow rather than a single tensor.
class ResizeWithBBoxOp : public ResizeOp {
 public:
  // Constructor; every argument is passed straight through to ResizeOp.
  // @param size_1 - first size argument
  // @param size_2 - second size argument (kDefWidth when omitted)
  // @param mInterpolation - interpolation mode (kDefInterpolation when omitted)
  explicit ResizeWithBBoxOp(int32_t size_1, int32_t size_2 = kDefWidth,
                            InterpolationMode mInterpolation = kDefInterpolation)
      : ResizeOp(size_1, size_2, mInterpolation) {}

  ~ResizeWithBBoxOp() override = default;

  // Writes the op name followed by the two configured sizes to `out`.
  void Print(std::ostream &out) const override {
    out << "ResizeWithBBoxOp: ";
    out << size1_ << " " << size2_;
  }

  // Row-based compute; defined in the .cc file.
  // @param input - input row
  // @param output - output row
  // @return Status - the error code returned
  Status Compute(const TensorRow &input, TensorRow *output) override;
};
} // namespace dataset
} // namespace mindspore
#endif // DATASET_KERNELS_IMAGE_RESIZE_WITH_BBOX_OP_H
......@@ -265,6 +265,7 @@ class BoundingBoxAugment(cde.BoundingBoxAugmentOp):
ratio (float, optional): Ratio of bounding boxes to apply augmentation on.
Range: [0,1] (default=0.3).
"""
@check_bounding_box_augment_cpp
def __init__(self, transform, ratio=0.3):
self.ratio = ratio
......@@ -302,6 +303,36 @@ class Resize(cde.ResizeOp):
super().__init__(*size, interpoltn)
class ResizeWithBBox(cde.ResizeWithBBoxOp):
    """
    Resize the input image to the given size and adjust the bounding boxes accordingly.

    Args:
        size (int or sequence): The output size of the resized image.
            If size is an int, smaller edge of the image will be resized to this value with
            the same image aspect ratio.
            If size is a sequence of length 2, it should be (height, width).
        interpolation (Inter mode, optional): Image interpolation mode (default=Inter.LINEAR).
            It can be any of [Inter.LINEAR, Inter.NEAREST, Inter.BICUBIC].

            - Inter.LINEAR, means interpolation method is bilinear interpolation.

            - Inter.NEAREST, means interpolation method is nearest-neighbor interpolation.

            - Inter.BICUBIC, means interpolation method is bicubic interpolation.
    """

    @check_resize_interpolation
    def __init__(self, size, interpolation=Inter.LINEAR):
        # Keep the Python-level arguments on the instance.
        self.size = size
        self.interpolation = interpolation

        # Translate the Python enum member into the value the C++ op takes.
        c_inter_mode = DE_C_INTER_MODE[interpolation]

        if not isinstance(size, int):
            # Sequence: unpack it into the leading positional arguments.
            super().__init__(*size, c_inter_mode)
        else:
            # Single int: the second size is left at the C++ default.
            super().__init__(size, interpolation=c_inter_mode)
class RandomResizedCropWithBBox(cde.RandomCropAndResizeWithBBoxOp):
"""
Crop the input image to a random size and aspect ratio and adjust the Bounding Boxes accordingly
......@@ -326,6 +357,7 @@ class RandomResizedCropWithBBox(cde.RandomCropAndResizeWithBBoxOp):
max_attempts (int, optional): The maximum number of attempts to propose a valid
crop_area (default=10). If exceeded, fall back to use center_crop instead.
"""
@check_random_resize_crop
def __init__(self, size, scale=(0.08, 1.0), ratio=(3. / 4., 4. / 3.),
interpolation=Inter.BILINEAR, max_attempts=10):
......@@ -499,6 +531,27 @@ class RandomResize(cde.RandomResizeOp):
super().__init__(*size)
class RandomResizeWithBBox(cde.RandomResizeWithBBoxOp):
    """
    Tensor operation to resize the input image using a randomly selected interpolation mode and adjust
    the bounding boxes accordingly.

    Args:
        size (int or sequence): The output size of the resized image.
            If size is an int, smaller edge of the image will be resized to this value with
            the same image aspect ratio.
            If size is a sequence of length 2, it should be (height, width).
    """

    @check_resize
    def __init__(self, size):
        # Keep the Python-level argument on the instance.
        self.size = size

        if not isinstance(size, int):
            # Sequence: unpack it into positional arguments for the C++ op.
            super().__init__(*size)
        else:
            # Single int: the second size is left at the C++ default.
            super().__init__(size)
class HWC2CHW(cde.ChannelSwapOp):
"""
Transpose the input image; shape (H, W, C) to shape (C, H, W).
......
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""
Testing the random resize with bounding boxes op in DE
"""
from enum import Enum
import matplotlib.pyplot as plt
import matplotlib.patches as patches
import numpy as np
import mindspore.dataset as ds
from mindspore import log as logger
import mindspore.dataset.transforms.vision.c_transforms as c_vision
# Left at False; nothing in the visible part of this module reads this flag.
GENERATE_GOLDEN = False
# Dataset directory handed to ds.VOCDataset by the tests in this module.
DATA_DIR = "../data/dataset/testVOC2012"
def fix_annotate(bboxes):
    """
    Move the label from the first column of each row to the fifth, shifting the four
    box values one position to the left. Rows are updated in place.

    :param bboxes: in [label, x_min, y_min, w, h, truncate, difficult] format
    :return: annotation in [x_min, y_min, w, h, label, truncate, difficult] format
    """
    for row in bboxes:
        label = row[0]
        # shift x_min, y_min, w, h down by one column
        for col in range(4):
            row[col] = row[col + 1]
        row[4] = label
    return bboxes
class BoxType(Enum):
    """
    Kinds of bounding box used by the bad-annotation test cases
    """
    WidthOverflow = 1   # box wider than the image
    HeightOverflow = 2  # box taller than the image
    NegativeXY = 3      # box whose x_min / y_min are written as negative numbers
    OnEdge = 4          # box spanning the whole image
    WrongShape = 5      # box row with too few values
class AddBadAnnotation:  # pylint: disable=too-few-public-methods
    """
    Callable map operation that swaps the annotation of an image for a single
    deliberately chosen bounding box, selected by a BoxType member.

    Usage:
    >>> # Adds a box that covers the whole image. Good for testing edge cases
    >>> de = de.map(input_columns=["image", "annotation"],
    >>>             output_columns=["image", "annotation"],
    >>>             operations=AddBadAnnotation(BoxType.OnEdge))
    """

    def __init__(self, box_type):
        self.box_type = box_type

    def __call__(self, img, bboxes):
        """
        Build the replacement annotation for the given image.

        :param img: image where the bounding boxes are.
        :param bboxes: in [x_min, y_min, w, h, label, truncate, difficult] format
        :return: (img, annotation); the annotation is a one-row uint32 array for a known
                 box type, otherwise the incoming bboxes unchanged
        """
        height = img.shape[0]
        width = img.shape[1]
        kind = self.box_type

        if kind == BoxType.WidthOverflow:
            # box is one pixel wider than the image
            bad_row = [0, 0, width + 1, height - 1, 0, 0, 0]
        elif kind == BoxType.HeightOverflow:
            # box is one pixel taller than the image
            bad_row = [0, 0, width - 1, height + 1, 0, 0, 0]
        elif kind == BoxType.NegativeXY:
            # negative x_min / y_min, written before the uint32 conversion below
            bad_row = [-10, -10, width - 1, height - 1, 0, 0, 0]
        elif kind == BoxType.OnEdge:
            # box that covers the whole image
            bad_row = [0, 0, width - 1, height - 1, 0, 0, 0]
        elif kind == BoxType.WrongShape:
            # row with only three values instead of seven
            bad_row = [0, 0, width - 1]
        else:
            # unknown box type: leave the annotation untouched
            return img, bboxes

        return img, np.array([bad_row]).astype(np.uint32)
def check_bad_box(data, box_type, expected_error):
    """
    Feed RandomResizeWithBBox an image whose annotation was replaced by a bad box and
    check that the pipeline fails with the expected message.

    :param data: dataset with "image" and "annotation" columns
    :param box_type: BoxType member telling AddBadAnnotation which box to inject
    :param expected_error: text that must occur in the RuntimeError message
    :raises AssertionError: if the message does not contain expected_error, or if no
                            RuntimeError is raised at all
    """
    try:
        test_op = c_vision.RandomResizeWithBBox(100)
        data = data.map(input_columns=["annotation"],
                        output_columns=["annotation"],
                        operations=fix_annotate)
        # replace the annotation with the requested bad box
        data = data.map(input_columns=["image", "annotation"],
                        output_columns=["image", "annotation"],
                        columns_order=["image", "annotation"],
                        operations=AddBadAnnotation(box_type))
        # apply the op under test
        data = data.map(input_columns=["image", "annotation"],
                        output_columns=["image", "annotation"],
                        columns_order=["image", "annotation"],
                        operations=[test_op])
        # pulling the first row is enough to run the pipeline once
        for _ in data.create_dict_iterator():
            break
    except RuntimeError as e:
        logger.info("Got an exception in DE: {}".format(str(e)))
        assert expected_error in str(e)
    else:
        # Without this branch the check passes silently when the op accepts the bad box.
        raise AssertionError(
            "RandomResizeWithBBox did not reject {}; expected a RuntimeError containing '{}'".format(
                box_type, expected_error))
def add_bounding_boxes(axis, bboxes):
    """
    Draw every bounding box as an unfilled red rectangle on the given axis.

    :param axis: axis to modify
    :param bboxes: bounding boxes to draw on the axis
    :return: None
    """
    for box in bboxes:
        # first two values are passed as the rectangle's anchor, the next two as its size
        anchor = (box[0], box[1])
        outline = patches.Rectangle(anchor, box[2], box[3],
                                    linewidth=1, edgecolor='r', facecolor='none')
        axis.add_patch(outline)
def visualize(unaugmented_data, augment_data):
    """
    Plot each original image next to its augmented counterpart, with bounding boxes drawn
    on both, and log the two annotations.

    :param unaugmented_data: dataset yielding the original "image" / "annotation" rows
    :param augment_data: dataset yielding the augmented "image" / "annotation" rows
    :return: None
    """
    paired_rows = zip(unaugmented_data.create_dict_iterator(), augment_data.create_dict_iterator())
    for idx, (un_aug_item, aug_item) in enumerate(paired_rows):
        axis = plt.subplot(141)
        plt.imshow(un_aug_item["image"])
        add_bounding_boxes(axis, un_aug_item["annotation"])  # add Orig BBoxes
        plt.title("Original" + str(idx + 1))
        # Build the message up front: the logger takes a format string plus arguments,
        # not print-style positional pieces.
        logger.info("Original {} : {}".format(idx + 1, un_aug_item["annotation"]))

        axis = plt.subplot(142)
        plt.imshow(aug_item["image"])
        add_bounding_boxes(axis, aug_item["annotation"])  # add AugBBoxes
        plt.title("Augmented" + str(idx + 1))
        logger.info("Augmented {} : {}\n".format(idx + 1, aug_item["annotation"]))

        # one figure per image pair
        plt.show()
def test_random_resize_with_bbox_op(plot=False):
    """
    Test random_resize_with_bbox_op

    NOTE(review): when plot is False nothing in this function iterates data_augmented -
    confirm whether the op is expected to actually run here.
    """
    logger.info("Test random resize with bbox")

    voc_kwargs = dict(task="Detection", mode="train", decode=True, shuffle=False)
    annotation_only = ["annotation"]
    image_and_annotation = ["image", "annotation"]

    # two independent pipelines over the same data: one left as is, one augmented
    data_original = ds.VOCDataset(DATA_DIR, **voc_kwargs)
    data_augmented = ds.VOCDataset(DATA_DIR, **voc_kwargs)

    # reorder the annotation columns on both pipelines
    data_original = data_original.map(input_columns=annotation_only,
                                      output_columns=annotation_only,
                                      operations=fix_annotate)
    data_augmented = data_augmented.map(input_columns=annotation_only,
                                        output_columns=annotation_only,
                                        operations=fix_annotate)

    # op under test; 100 is the target size
    test_op = c_vision.RandomResizeWithBBox(100)
    data_augmented = data_augmented.map(input_columns=image_and_annotation,
                                        output_columns=image_and_annotation,
                                        columns_order=image_and_annotation,
                                        operations=[test_op])

    if plot:
        visualize(data_original, data_augmented)
def test_random_resize_with_bbox_invalid_bounds():
    """
    Run check_bad_box once per bad box type, each time on a freshly created dataset.
    """
    bad_box_cases = (
        (BoxType.WidthOverflow, "bounding boxes is out of bounds of the image"),
        (BoxType.HeightOverflow, "bounding boxes is out of bounds of the image"),
        (BoxType.NegativeXY, "min_x"),
        (BoxType.WrongShape, "4 features"),
    )
    for box_type, expected_error in bad_box_cases:
        data_voc2 = ds.VOCDataset(DATA_DIR, task="Detection", mode="train", decode=True, shuffle=False)
        check_bad_box(data_voc2, box_type, expected_error)
def test_random_resize_with_bbox_invalid_size():
    """
    Test random_resize_with_bbox_op with invalid target sizes: a negative int, zero, and
    a sequence of length three. Each case must raise the listed exception type with the
    listed text in its message.
    """
    logger.info("Test random resize with bbox with invalid target size")

    # original images
    data = ds.VOCDataset(DATA_DIR, task="Detection", mode="train", decode=True, shuffle=False)

    data = data.map(input_columns=["annotation"],
                    output_columns=["annotation"],
                    operations=fix_annotate)

    invalid_cases = (
        (-10, ValueError, "Input is not"),            # negative target size
        (0, ValueError, "Input is not"),              # zero target size
        ((10, 10, 10), TypeError, "Size should be"),  # invalid input shape
    )

    for size, error_type, expected_error in invalid_cases:
        try:
            test_op = c_vision.RandomResizeWithBBox(size)

            # map to apply ops
            data = data.map(input_columns=["image", "annotation"],
                            output_columns=["image", "annotation"],
                            columns_order=["image", "annotation"],
                            operations=[test_op])

            # pulling the first row is enough to run the pipeline once
            for _ in data.create_dict_iterator():
                break
        except error_type as e:
            logger.info("Got an exception in DE: {}".format(str(e)))
            assert expected_error in str(e)
        else:
            # Without this branch an accepted invalid size would go unnoticed.
            raise AssertionError(
                "RandomResizeWithBBox accepted size {!r}; expected {} containing '{}'".format(
                    size, error_type.__name__, expected_error))
if __name__ == "__main__":
    # Script entry point: run the tests of this module in the order they are defined.
    test_random_resize_with_bbox_op(plot=False)
    test_random_resize_with_bbox_invalid_bounds()
    test_random_resize_with_bbox_invalid_size()
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""
Testing the resize with bounding boxes op in DE
"""
from enum import Enum
import numpy as np
import matplotlib.patches as patches
import matplotlib.pyplot as plt
import mindspore.dataset.transforms.vision.c_transforms as c_vision
from mindspore import log as logger
import mindspore.dataset as ds
# Left at False; nothing in the visible part of this module reads this flag.
GENERATE_GOLDEN = False
# Dataset directory handed to ds.VOCDataset by the tests in this module.
DATA_DIR = "../data/dataset/testVOC2012"
def fix_annotate(bboxes):
    """
    Rotate the first five values of every row one step to the left, so that the label
    ends up behind the four box values. The rows are changed in place.

    :param bboxes: in [label, x_min, y_min, w, h, truncate, difficult] format
    :return: annotation in [x_min, y_min, w, h, label, truncate, difficult] format
    """
    for entry in bboxes:
        leading_label = entry[0]
        position = 0
        # pull x_min, y_min, w, h one slot forward
        while position < 4:
            entry[position] = entry[position + 1]
            position += 1
        entry[4] = leading_label
    return bboxes
class BoxType(Enum):
    """
    Selects which bounding box the bad-annotation test helper generates
    """
    WidthOverflow = 1   # width exceeds the image width
    HeightOverflow = 2  # height exceeds the image height
    NegativeXY = 3      # x_min and y_min given as negative numbers
    OnEdge = 4          # box covering the whole image
    WrongShape = 5      # row that has too few values
class AddBadAnnotation:  # pylint: disable=too-few-public-methods
    """
    Map operation for object detection pipelines that replaces an image's annotation
    with one bounding box chosen by a BoxType member.

    Usage:
    >>> # Adds a box that covers the whole image. Good for testing edge cases
    >>> de = de.map(input_columns=["image", "annotation"],
    >>>             output_columns=["image", "annotation"],
    >>>             operations=AddBadAnnotation(BoxType.OnEdge))
    """

    def __init__(self, box_type):
        self.box_type = box_type

    def __call__(self, img, bboxes):
        """
        Produce the annotation for the configured box type.

        :param img: image where the bounding boxes are.
        :param bboxes: in [x_min, y_min, w, h, label, truncate, difficult] format
        :return: (img, annotation); a one-row uint32 array for a known box type,
                 otherwise the bboxes that were passed in
        """
        img_h = img.shape[0]
        img_w = img.shape[1]
        selected = self.box_type

        if selected == BoxType.WidthOverflow:
            # width reaches one pixel past the image
            values = [[0, 0, img_w + 1, img_h - 1, 0, 0, 0]]
        elif selected == BoxType.HeightOverflow:
            # height reaches one pixel past the image
            values = [[0, 0, img_w - 1, img_h + 1, 0, 0, 0]]
        elif selected == BoxType.NegativeXY:
            # negative origin, written before the cast to uint32 below
            values = [[-10, -10, img_w - 1, img_h - 1, 0, 0, 0]]
        elif selected == BoxType.OnEdge:
            # box that covers the whole image
            values = [[0, 0, img_w - 1, img_h - 1, 0, 0, 0]]
        elif selected == BoxType.WrongShape:
            # row made of three values only
            values = [[0, 0, img_w - 1]]
        else:
            # no matching box type: pass the annotation through unchanged
            return img, bboxes

        return img, np.array(values).astype(np.uint32)
def check_bad_box(data, box_type, expected_error):
    """
    Feed ResizeWithBBox an image whose annotation was replaced by a bad box and check
    that the pipeline fails with the expected message.

    :param data: dataset with "image" and "annotation" columns
    :param box_type: BoxType member telling AddBadAnnotation which box to inject
    :param expected_error: text that must occur in the RuntimeError message
    :raises AssertionError: if the message does not contain expected_error, or if no
                            RuntimeError is raised at all
    """
    try:
        test_op = c_vision.ResizeWithBBox(100)
        data = data.map(input_columns=["annotation"],
                        output_columns=["annotation"],
                        operations=fix_annotate)
        # replace the annotation with the requested bad box
        data = data.map(input_columns=["image", "annotation"],
                        output_columns=["image", "annotation"],
                        columns_order=["image", "annotation"],
                        operations=AddBadAnnotation(box_type))
        # apply the op under test
        data = data.map(input_columns=["image", "annotation"],
                        output_columns=["image", "annotation"],
                        columns_order=["image", "annotation"],
                        operations=[test_op])
        # pulling the first row is enough to run the pipeline once
        for _ in data.create_dict_iterator():
            break
    except RuntimeError as e:
        logger.info("Got an exception in DE: {}".format(str(e)))
        assert expected_error in str(e)
    else:
        # Without this branch the check passes silently when the op accepts the bad box.
        raise AssertionError(
            "ResizeWithBBox did not reject {}; expected a RuntimeError containing '{}'".format(
                box_type, expected_error))
def add_bounding_boxes(axis, bboxes):
    """
    Add one unfilled red rectangle patch to the axis for each bounding box.

    :param axis: axis to modify
    :param bboxes: bounding boxes to draw on the axis
    :return: None
    """
    patch_style = dict(linewidth=1, edgecolor='r', facecolor='none')
    for bbox in bboxes:
        # values 0 and 1 form the rectangle's anchor point, values 2 and 3 its size
        axis.add_patch(patches.Rectangle((bbox[0], bbox[1]), bbox[2], bbox[3], **patch_style))
def visualize(unaugmented_data, augment_data):
    """
    Plot each original image next to its augmented counterpart, with bounding boxes drawn
    on both, and log the two annotations.

    :param unaugmented_data: dataset yielding the original "image" / "annotation" rows
    :param augment_data: dataset yielding the augmented "image" / "annotation" rows
    :return: None
    """
    for idx, (un_aug_item, aug_item) in enumerate(
            zip(unaugmented_data.create_dict_iterator(), augment_data.create_dict_iterator())):
        axis = plt.subplot(141)
        plt.imshow(un_aug_item["image"])
        add_bounding_boxes(axis, un_aug_item["annotation"])  # add Orig BBoxes
        plt.title("Original" + str(idx + 1))
        # Format the message before logging: the logger expects a format string plus
        # arguments, not print-style positional pieces.
        logger.info("Original {} : {}".format(idx + 1, un_aug_item["annotation"]))

        axis = plt.subplot(142)
        plt.imshow(aug_item["image"])
        add_bounding_boxes(axis, aug_item["annotation"])  # add AugBBoxes
        plt.title("Augmented" + str(idx + 1))
        logger.info("Augmented {} : {}\n".format(idx + 1, aug_item["annotation"]))

        # one figure per image pair
        plt.show()
def test_resize_with_bbox_op(plot=False):
    """
    Test resize_with_bbox_op

    NOTE(review): when plot is False nothing in this function iterates data_augmented -
    confirm whether the op is expected to actually run here.
    """
    logger.info("Test resize with bbox")

    dataset_options = {"task": "Detection", "mode": "train", "decode": True, "shuffle": False}
    both_columns = ["image", "annotation"]

    # one untouched pipeline and one that will receive the op, both over the same data
    data_original = ds.VOCDataset(DATA_DIR, **dataset_options)
    data_augmented = ds.VOCDataset(DATA_DIR, **dataset_options)

    # bring the annotation columns into the expected order on both pipelines
    data_original = data_original.map(input_columns=["annotation"],
                                      output_columns=["annotation"],
                                      operations=fix_annotate)
    data_augmented = data_augmented.map(input_columns=["annotation"],
                                        output_columns=["annotation"],
                                        operations=fix_annotate)

    # op under test; 100 is the target size
    test_op = c_vision.ResizeWithBBox(100)
    data_augmented = data_augmented.map(input_columns=both_columns,
                                        output_columns=both_columns,
                                        columns_order=both_columns,
                                        operations=[test_op])

    if plot:
        visualize(data_original, data_augmented)
def test_resize_with_bbox_invalid_bounds():
    """
    Call check_bad_box for each bad box type, creating a new dataset for every call.
    """
    expected_by_box_type = [
        (BoxType.WidthOverflow, "bounding boxes is out of bounds of the image"),
        (BoxType.HeightOverflow, "bounding boxes is out of bounds of the image"),
        (BoxType.NegativeXY, "min_x"),
        (BoxType.WrongShape, "4 features"),
    ]
    for box_type, expected_error in expected_by_box_type:
        data_voc2 = ds.VOCDataset(DATA_DIR, task="Detection", mode="train", decode=True, shuffle=False)
        check_bad_box(data_voc2, box_type, expected_error)
def test_resize_with_bbox_invalid_size():
    """
    Test resize_with_bbox_op with invalid target sizes: a negative int, zero, and a
    sequence of length three. Each case must raise the listed exception type with the
    listed text in its message.
    """
    logger.info("Test resize with bbox with invalid target size")

    # original images
    data = ds.VOCDataset(DATA_DIR, task="Detection", mode="train", decode=True, shuffle=False)

    data = data.map(input_columns=["annotation"],
                    output_columns=["annotation"],
                    operations=fix_annotate)

    invalid_cases = (
        (-10, ValueError, "Input is not"),            # negative target size
        (0, ValueError, "Input is not"),              # zero target size
        ((10, 10, 10), TypeError, "Size should be"),  # invalid input shape
    )

    for size, error_type, expected_error in invalid_cases:
        try:
            test_op = c_vision.ResizeWithBBox(size)

            # map to apply ops
            data = data.map(input_columns=["image", "annotation"],
                            output_columns=["image", "annotation"],
                            columns_order=["image", "annotation"],
                            operations=[test_op])

            # pulling the first row is enough to run the pipeline once
            for _ in data.create_dict_iterator():
                break
        except error_type as e:
            logger.info("Got an exception in DE: {}".format(str(e)))
            assert expected_error in str(e)
        else:
            # Without this branch an accepted invalid size would go unnoticed.
            raise AssertionError(
                "ResizeWithBBox accepted size {!r}; expected {} containing '{}'".format(
                    size, error_type.__name__, expected_error))
def test_resize_with_bbox_invalid_interpolation():
    """
    Test resize_with_bbox_op with an interpolation argument that is not an Inter mode;
    a ValueError mentioning "interpolation" must be raised.
    """
    logger.info("Test resize with bbox with invalid interpolation size")

    # original images
    data = ds.VOCDataset(DATA_DIR, task="Detection", mode="train", decode=True, shuffle=False)

    data = data.map(input_columns=["annotation"],
                    output_columns=["annotation"],
                    operations=fix_annotate)

    # invalid interpolation
    try:
        test_op = c_vision.ResizeWithBBox(100, interpolation="invalid")

        # map to apply ops
        data = data.map(input_columns=["image", "annotation"],
                        output_columns=["image", "annotation"],
                        columns_order=["image", "annotation"],
                        operations=[test_op])

        # pulling the first row is enough to run the pipeline once
        for _ in data.create_dict_iterator():
            break
    except ValueError as e:
        logger.info("Got an exception in DE: {}".format(str(e)))
        assert "interpolation" in str(e)
    else:
        # Without this branch an accepted invalid interpolation would go unnoticed.
        raise AssertionError(
            "ResizeWithBBox accepted interpolation='invalid'; expected ValueError containing 'interpolation'")
if __name__ == "__main__":
    # Script entry point: run the tests of this module in the order they are defined.
    test_resize_with_bbox_op(plot=False)
    test_resize_with_bbox_invalid_bounds()
    test_resize_with_bbox_invalid_size()
    test_resize_with_bbox_invalid_interpolation()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册