From 1896950ae5e3db4e734a10ae0e1d292d036f78b8 Mon Sep 17 00:00:00 2001 From: Mahdi Date: Thu, 30 Jul 2020 13:24:04 -0400 Subject: [PATCH] Added Mixup --- .../dataset/kernels/image/bindings.cc | 7 + .../ccsrc/minddata/dataset/api/transforms.cc | 50 ++++ .../minddata/dataset/include/transforms.h | 43 +++ .../dataset/kernels/data/data_utils.cc | 26 ++ .../dataset/kernels/data/data_utils.h | 11 + .../dataset/kernels/image/CMakeLists.txt | 1 + .../dataset/kernels/image/mixup_batch_op.cc | 108 ++++++++ .../dataset/kernels/image/mixup_batch_op.h | 51 ++++ .../minddata/dataset/kernels/tensor_op.h | 1 + .../dataset/transforms/vision/c_transforms.py | 30 ++- .../dataset/transforms/vision/validators.py | 13 + tests/ut/cpp/dataset/CMakeLists.txt | 1 + tests/ut/cpp/dataset/c_api_transforms_test.cc | 121 +++++++++ tests/ut/cpp/dataset/mixup_batch_op_test.cc | 69 +++++ .../dataset/golden/mixup_batch_c_result.npz | Bin 0 -> 713 bytes tests/ut/python/dataset/test_mixup_op.py | 247 ++++++++++++++++++ 16 files changed, 776 insertions(+), 3 deletions(-) create mode 100644 mindspore/ccsrc/minddata/dataset/kernels/image/mixup_batch_op.cc create mode 100644 mindspore/ccsrc/minddata/dataset/kernels/image/mixup_batch_op.h create mode 100644 tests/ut/cpp/dataset/mixup_batch_op_test.cc create mode 100644 tests/ut/data/dataset/golden/mixup_batch_c_result.npz create mode 100644 tests/ut/python/dataset/test_mixup_op.py diff --git a/mindspore/ccsrc/minddata/dataset/api/python/bindings/dataset/kernels/image/bindings.cc b/mindspore/ccsrc/minddata/dataset/api/python/bindings/dataset/kernels/image/bindings.cc index b2873cd5a..ae351625d 100644 --- a/mindspore/ccsrc/minddata/dataset/api/python/bindings/dataset/kernels/image/bindings.cc +++ b/mindspore/ccsrc/minddata/dataset/api/python/bindings/dataset/kernels/image/bindings.cc @@ -28,6 +28,7 @@ #include "minddata/dataset/kernels/image/hwc_to_chw_op.h" #include "minddata/dataset/kernels/image/image_utils.h" #include 
"minddata/dataset/kernels/image/invert_op.h" +#include "minddata/dataset/kernels/image/mixup_batch_op.h" #include "minddata/dataset/kernels/image/normalize_op.h" #include "minddata/dataset/kernels/image/pad_op.h" #include "minddata/dataset/kernels/image/random_color_adjust_op.h" @@ -92,6 +93,12 @@ PYBIND_REGISTER(CenterCropOp, 1, ([](const py::module *m) { .def(py::init(), py::arg("height"), py::arg("width") = CenterCropOp::kDefWidth); })); +PYBIND_REGISTER(MixUpBatchOp, 1, ([](const py::module *m) { + (void)py::class_>( + *m, "MixUpBatchOp", "Tensor operation to mixup a batch of images") + .def(py::init(), py::arg("alpha")); + })); + PYBIND_REGISTER(ResizeOp, 1, ([](const py::module *m) { (void)py::class_>( *m, "ResizeOp", "Tensor operation to resize an image. Takes height, width and mode") diff --git a/mindspore/ccsrc/minddata/dataset/api/transforms.cc b/mindspore/ccsrc/minddata/dataset/api/transforms.cc index a68fc7747..6847f74f0 100644 --- a/mindspore/ccsrc/minddata/dataset/api/transforms.cc +++ b/mindspore/ccsrc/minddata/dataset/api/transforms.cc @@ -21,7 +21,9 @@ #include "minddata/dataset/kernels/image/crop_op.h" #include "minddata/dataset/kernels/image/cut_out_op.h" #include "minddata/dataset/kernels/image/decode_op.h" +#include "minddata/dataset/kernels/image/mixup_batch_op.h" #include "minddata/dataset/kernels/image/normalize_op.h" +#include "minddata/dataset/kernels/data/one_hot_op.h" #include "minddata/dataset/kernels/image/pad_op.h" #include "minddata/dataset/kernels/image/random_color_adjust_op.h" #include "minddata/dataset/kernels/image/random_crop_op.h" @@ -81,6 +83,16 @@ std::shared_ptr Decode(bool rgb) { return op; } +// Function to create MixUpBatchOperation. +std::shared_ptr MixUpBatch(float alpha) { + auto op = std::make_shared(alpha); + // Input validation + if (!op->ValidateParams()) { + return nullptr; + } + return op; +} + // Function to create NormalizeOperation. 
std::shared_ptr Normalize(std::vector mean, std::vector std) { auto op = std::make_shared(mean, std); @@ -91,6 +103,16 @@ std::shared_ptr Normalize(std::vector mean, std::vect return op; } +// Function to create OneHotOperation. +std::shared_ptr OneHot(int32_t num_classes) { + auto op = std::make_shared(num_classes); + // Input validation + if (!op->ValidateParams()) { + return nullptr; + } + return op; +} + // Function to create PadOperation. std::shared_ptr Pad(std::vector padding, std::vector fill_value, BorderType padding_mode) { @@ -271,6 +293,20 @@ bool DecodeOperation::ValidateParams() { return true; } std::shared_ptr DecodeOperation::Build() { return std::make_shared(rgb_); } +// MixUpBatchOperation +MixUpBatchOperation::MixUpBatchOperation(float alpha) : alpha_(alpha) {} + +bool MixUpBatchOperation::ValidateParams() { + if (alpha_ <= 0) { + MS_LOG(ERROR) << "MixUpBatch: alpha must be a positive floating value however it is: " << alpha_; + return false; + } + + return true; +} + +std::shared_ptr MixUpBatchOperation::Build() { return std::make_shared(alpha_); } + // NormalizeOperation NormalizeOperation::NormalizeOperation(std::vector mean, std::vector std) : mean_(mean), std_(std) {} @@ -292,6 +328,20 @@ std::shared_ptr NormalizeOperation::Build() { return std::make_shared(mean_[0], mean_[1], mean_[2], std_[0], std_[1], std_[2]); } +// OneHotOperation +OneHotOperation::OneHotOperation(int32_t num_classes) : num_classes_(num_classes) {} + +bool OneHotOperation::ValidateParams() { + if (num_classes_ < 0) { + MS_LOG(ERROR) << "OneHot: Number of classes cannot be negative. 
Number of classes: " << num_classes_; + return false; + } + + return true; +} + +std::shared_ptr OneHotOperation::Build() { return std::make_shared(num_classes_); } + // PadOperation PadOperation::PadOperation(std::vector padding, std::vector fill_value, BorderType padding_mode) : padding_(padding), fill_value_(fill_value), padding_mode_(padding_mode) {} diff --git a/mindspore/ccsrc/minddata/dataset/include/transforms.h b/mindspore/ccsrc/minddata/dataset/include/transforms.h index 1788f9ce5..dcec3763e 100644 --- a/mindspore/ccsrc/minddata/dataset/include/transforms.h +++ b/mindspore/ccsrc/minddata/dataset/include/transforms.h @@ -51,7 +51,9 @@ class CenterCropOperation; class CropOperation; class CutOutOperation; class DecodeOperation; +class MixUpBatchOperation; class NormalizeOperation; +class OneHotOperation; class PadOperation; class RandomColorAdjustOperation; class RandomCropOperation; @@ -90,6 +92,13 @@ std::shared_ptr CutOut(int32_t length, int32_t num_patches = 1) /// \return Shared pointer to the current TensorOperation. std::shared_ptr Decode(bool rgb = true); +/// \brief Function to create a MixUpBatch TensorOperation. +/// \notes Apply MixUp transformation on an input batch of images and labels. The labels must be in one-hot format and +/// Batch must be called before calling this function. +/// \param[in] alpha hyperparameter of beta distribution (default = 1.0) +/// \return Shared pointer to the current TensorOperation. +std::shared_ptr MixUpBatch(float alpha = 1); + /// \brief Function to create a Normalize TensorOperation. /// \notes Normalize the input image with respect to mean and standard deviation. /// \param[in] mean - a vector of mean values for each channel, w.r.t channel order. @@ -97,6 +106,12 @@ std::shared_ptr Decode(bool rgb = true); /// \return Shared pointer to the current TensorOperation. std::shared_ptr Normalize(std::vector mean, std::vector std); +/// \brief Function to create a OneHot TensorOperation. 
+/// \notes Convert the labels into OneHot format. +/// \param[in] num_classes number of classes. +/// \return Shared pointer to the current TensorOperation. +std::shared_ptr OneHot(int32_t num_classes); + /// \brief Function to create a Pad TensorOp /// \notes Pads the image according to padding parameters /// \param[in] padding A vector representing the number of pixels to pad the image @@ -258,6 +273,20 @@ class DecodeOperation : public TensorOperation { bool rgb_; }; +class MixUpBatchOperation : public TensorOperation { + public: + explicit MixUpBatchOperation(float alpha = 1); + + ~MixUpBatchOperation() = default; + + std::shared_ptr Build() override; + + bool ValidateParams() override; + + private: + float alpha_; +}; + class NormalizeOperation : public TensorOperation { public: NormalizeOperation(std::vector mean, std::vector std); @@ -273,6 +302,20 @@ class NormalizeOperation : public TensorOperation { std::vector std_; }; +class OneHotOperation : public TensorOperation { + public: + explicit OneHotOperation(int32_t num_classes); + + ~OneHotOperation() = default; + + std::shared_ptr Build() override; + + bool ValidateParams() override; + + private: + int32_t num_classes_; +}; + class PadOperation : public TensorOperation { public: PadOperation(std::vector padding, std::vector fill_value = {0}, diff --git a/mindspore/ccsrc/minddata/dataset/kernels/data/data_utils.cc b/mindspore/ccsrc/minddata/dataset/kernels/data/data_utils.cc index 5632dddee..29fd5ada8 100644 --- a/mindspore/ccsrc/minddata/dataset/kernels/data/data_utils.cc +++ b/mindspore/ccsrc/minddata/dataset/kernels/data/data_utils.cc @@ -20,6 +20,7 @@ #include #include #include +#include #include "minddata/dataset/core/constants.h" #include "minddata/dataset/core/data_type.h" @@ -648,5 +649,30 @@ Status Concatenate(const TensorRow &input, TensorRow *output, int8_t axis, std:: return Status::OK(); } +Status BatchTensorToCVTensorVector(const std::shared_ptr &input, + std::vector> *output) { + std::vector 
tensor_shape = input->shape().AsVector(); + TensorShape remaining({-1}); + std::vector index(tensor_shape.size(), 0); + if (tensor_shape.size() <= 1) { + RETURN_STATUS_UNEXPECTED("Tensor must be at least 2-D in order to unpack"); + } + TensorShape element_shape(std::vector(tensor_shape.begin() + 1, tensor_shape.end())); + + for (; index[0] < tensor_shape[0]; index[0]++) { + uchar *start_addr_of_index = nullptr; + std::shared_ptr out; + + RETURN_IF_NOT_OK(input->StartAddrOfIndex(index, &start_addr_of_index, &remaining)); + RETURN_IF_NOT_OK(input->CreateFromMemory(element_shape, input->type(), start_addr_of_index, &out)); + std::shared_ptr cv_out = CVTensor::AsCVTensor(std::move(out)); + if (!cv_out->mat().data) { + RETURN_STATUS_UNEXPECTED("Could not convert to CV Tensor"); + } + output->push_back(cv_out); + } + return Status::OK(); +} + } // namespace dataset } // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/kernels/data/data_utils.h b/mindspore/ccsrc/minddata/dataset/kernels/data/data_utils.h index 5e82b4102..4fba6aef9 100644 --- a/mindspore/ccsrc/minddata/dataset/kernels/data/data_utils.h +++ b/mindspore/ccsrc/minddata/dataset/kernels/data/data_utils.h @@ -152,6 +152,17 @@ Status Mask(const std::shared_ptr &input, std::shared_ptr *outpu Status Concatenate(const TensorRow &input, TensorRow *output, int8_t axis, std::shared_ptr prepend, std::shared_ptr append); + +// helper for concat, always append to the input, and pass that to the output +Status ConcatenateHelper(const std::shared_ptr &input, std::shared_ptr *output, int8_t axis, + std::shared_ptr append); + +/// Convert an n-dimensional Tensor to a vector of (n-1)-dimensional CVTensors +/// @param input[in] input tensor +/// @param output[out] output tensor +/// @return Status ok/error +Status BatchTensorToCVTensorVector(const std::shared_ptr &input, + std::vector> *output); } // namespace dataset } // namespace mindspore diff --git 
a/mindspore/ccsrc/minddata/dataset/kernels/image/CMakeLists.txt b/mindspore/ccsrc/minddata/dataset/kernels/image/CMakeLists.txt index fc4a6790b..9f55aee32 100644 --- a/mindspore/ccsrc/minddata/dataset/kernels/image/CMakeLists.txt +++ b/mindspore/ccsrc/minddata/dataset/kernels/image/CMakeLists.txt @@ -10,6 +10,7 @@ add_library(kernels-image OBJECT hwc_to_chw_op.cc image_utils.cc invert_op.cc + mixup_batch_op.cc normalize_op.cc pad_op.cc random_color_adjust_op.cc diff --git a/mindspore/ccsrc/minddata/dataset/kernels/image/mixup_batch_op.cc b/mindspore/ccsrc/minddata/dataset/kernels/image/mixup_batch_op.cc new file mode 100644 index 000000000..6386ba121 --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/kernels/image/mixup_batch_op.cc @@ -0,0 +1,108 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include +#include +#include "minddata/dataset/core/cv_tensor.h" +#include "minddata/dataset/kernels/image/mixup_batch_op.h" +#include "minddata/dataset/kernels/data/data_utils.h" +#include "minddata/dataset/util/random.h" +#include "minddata/dataset/util/status.h" + +namespace mindspore { +namespace dataset { + +MixUpBatchOp::MixUpBatchOp(float alpha) : alpha_(alpha) { rnd_.seed(GetSeed()); } + +Status MixUpBatchOp::Compute(const TensorRow &input, TensorRow *output) { + if (input.size() < 2) { + RETURN_STATUS_UNEXPECTED("Both images and labels columns are required for this operation"); + } + + std::vector> images; + std::vector image_shape = input.at(0)->shape().AsVector(); + std::vector label_shape = input.at(1)->shape().AsVector(); + + // Check inputs + if (label_shape.size() != 2 || image_shape.size() != 4 || image_shape[0] != label_shape[0]) { + RETURN_STATUS_UNEXPECTED("You must batch before calling MixUpBatch"); + } + + if ((image_shape[1] != 1 && image_shape[1] != 3) && (image_shape[3] != 1 && image_shape[3] != 3)) { + RETURN_STATUS_UNEXPECTED("MixUpBatch: Images must be in the shape of HWC or CHW"); + } + + // Move images into a vector of CVTensors + RETURN_IF_NOT_OK(BatchTensorToCVTensorVector(input.at(0), &images)); + + // Calculating lambda + // If x1 is a random variable from Gamma(a1, 1) and x2 is a random variable from Gamma(a2, 1) + // then x = x1 / (x1+x2) is a random variable from Beta(a1, a2) + std::gamma_distribution distribution(alpha_, 1); + float x1 = distribution(rnd_); + float x2 = distribution(rnd_); + float lam = x1 / (x1 + x2); + + // Calculate random labels + std::vector rand_indx; + for (int64_t i = 0; i < images.size(); i++) rand_indx.push_back(i); + std::shuffle(rand_indx.begin(), rand_indx.end(), rnd_); + + // Compute labels + std::shared_ptr out_labels; + RETURN_IF_NOT_OK(TypeCast(std::move(input.at(1)), &out_labels, DataType("float32"))); + for (int64_t i = 0; i < label_shape[0]; i++) { + for (int64_t j = 0; j < 
label_shape[1]; j++) { + uint64_t first_value, second_value; + RETURN_IF_NOT_OK(input.at(1)->GetItemAt(&first_value, {i, j})); + RETURN_IF_NOT_OK(input.at(1)->GetItemAt(&second_value, {rand_indx[i], j})); + RETURN_IF_NOT_OK(out_labels->SetItemAt({i, j}, lam * first_value + (1 - lam) * second_value)); + } + } + + // Compute images + for (int64_t i = 0; i < images.size(); i++) { + TensorShape remaining({-1}); + uchar *start_addr_of_index = nullptr; + std::shared_ptr out; + RETURN_IF_NOT_OK(input.at(0)->StartAddrOfIndex({rand_indx[i], 0, 0, 0}, &start_addr_of_index, &remaining)); + RETURN_IF_NOT_OK(input.at(0)->CreateFromMemory(TensorShape({image_shape[1], image_shape[2], image_shape[3]}), + input.at(0)->type(), start_addr_of_index, &out)); + std::shared_ptr rand_image = CVTensor::AsCVTensor(std::move(out)); + if (!rand_image->mat().data) { + RETURN_STATUS_UNEXPECTED("Could not convert to CV Tensor"); + } + images[i]->mat() = lam * images[i]->mat() + (1 - lam) * rand_image->mat(); + } + + // Move the output into a TensorRow + std::shared_ptr output_image; + RETURN_IF_NOT_OK(Tensor::CreateEmpty(input.at(0)->shape(), input.at(0)->type(), &output_image)); + for (int64_t i = 0; i < images.size(); i++) { + RETURN_IF_NOT_OK(output_image->InsertTensor({i}, images[i])); + } + output->push_back(output_image); + output->push_back(out_labels); + + return Status::OK(); +} + +void MixUpBatchOp::Print(std::ostream &out) const { + out << "MixUpBatchOp: " + << "alpha: " << alpha_ << "\n"; +} +} // namespace dataset +} // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/kernels/image/mixup_batch_op.h b/mindspore/ccsrc/minddata/dataset/kernels/image/mixup_batch_op.h new file mode 100644 index 000000000..de6b21223 --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/kernels/image/mixup_batch_op.h @@ -0,0 +1,51 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file 
except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_KERNELS_IMAGE_MIXUPBATCH_OP_H_ +#define MINDSPORE_CCSRC_MINDDATA_DATASET_KERNELS_IMAGE_MIXUPBATCH_OP_H_ + +#include +#include +#include +#include + +#include "minddata/dataset/core/tensor.h" +#include "minddata/dataset/kernels/tensor_op.h" +#include "minddata/dataset/util/status.h" + +namespace mindspore { +namespace dataset { +class MixUpBatchOp : public TensorOp { + public: + // Default values, also used by python_bindings.cc + + explicit MixUpBatchOp(float alpha); + + ~MixUpBatchOp() override = default; + + void Print(std::ostream &out) const override; + + Status Compute(const TensorRow &input, TensorRow *output) override; + + std::string Name() const override { return kMixUpBatchOp; } + + private: + float alpha_; + std::mt19937 rnd_; +}; +} // namespace dataset +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_KERNELS_IMAGE_MIXUPBATCH_OP_H_ diff --git a/mindspore/ccsrc/minddata/dataset/kernels/tensor_op.h b/mindspore/ccsrc/minddata/dataset/kernels/tensor_op.h index b6fad3133..f414890ef 100644 --- a/mindspore/ccsrc/minddata/dataset/kernels/tensor_op.h +++ b/mindspore/ccsrc/minddata/dataset/kernels/tensor_op.h @@ -99,6 +99,7 @@ constexpr char kCropOp[] = "CropOp"; constexpr char kEqualizeOp[] = "EqualizeOp"; constexpr char kHwcToChwOp[] = "HwcToChwOp"; constexpr char kInvertOp[] = "InvertOp"; +constexpr char kMixUpBatchOp[] = "MixUpBatchOp"; constexpr char kNormalizeOp[] = "NormalizeOp"; constexpr char kPadOp[] = "PadOp"; constexpr char 
kRandomColorAdjustOp[] = "RandomColorAdjustOp"; diff --git a/mindspore/dataset/transforms/vision/c_transforms.py b/mindspore/dataset/transforms/vision/c_transforms.py index 9a07c58a1..07e2952fd 100644 --- a/mindspore/dataset/transforms/vision/c_transforms.py +++ b/mindspore/dataset/transforms/vision/c_transforms.py @@ -45,9 +45,9 @@ import mindspore._c_dataengine as cde from .utils import Inter, Border from .validators import check_prob, check_crop, check_resize_interpolation, check_random_resize_crop, \ - check_normalize_c, check_random_crop, check_random_color_adjust, check_random_rotation, check_range, \ - check_resize, check_rescale, check_pad, check_cutout, check_uniform_augment_cpp, check_bounding_box_augment_cpp, \ - check_random_select_subpolicy_op, check_auto_contrast, FLOAT_MAX_INTEGER + check_mix_up_batch_c, check_normalize_c, check_random_crop, check_random_color_adjust, check_random_rotation, \ + check_range, check_resize, check_rescale, check_pad, check_cutout, check_uniform_augment_cpp, \ + check_bounding_box_augment_cpp, check_random_select_subpolicy_op, check_auto_contrast, FLOAT_MAX_INTEGER DE_C_INTER_MODE = {Inter.NEAREST: cde.InterpolationMode.DE_INTER_NEAREST_NEIGHBOUR, Inter.LINEAR: cde.InterpolationMode.DE_INTER_LINEAR, @@ -130,6 +130,30 @@ class CutOut(cde.CutOutOp): super().__init__(length, length, num_patches, False, *fill_value) +class MixUpBatch(cde.MixUpBatchOp): + """ + Apply MixUp transformation on input batch of images and labels. Each image is multiplied by a random weight (lambda) + and then added to a randomly selected image from the batch multiplied by (1 - lambda). Same formula is also applied + to the one-hot labels. + Note that you need to make labels into one-hot format and batch before calling this function. + + Args: + alpha (float): hyperparameter of beta distribution (default = 1.0). 
+ + Examples: + >>> one_hot_op = data.OneHot(num_classes=10) + >>> data = data.map(input_columns=["label"], operations=one_hot_op) + >>> mixup_batch_op = vision.MixUpBatch() + >>> data = data.batch(5) + >>> data = data.map(input_columns=["image", "label"], operations=mixup_batch_op) + """ + + @check_mix_up_batch_c + def __init__(self, alpha=1.0): + self.alpha = alpha + super().__init__(alpha) + + class Normalize(cde.NormalizeOp): """ Normalize the input image with respect to mean and standard deviation. diff --git a/mindspore/dataset/transforms/vision/validators.py b/mindspore/dataset/transforms/vision/validators.py index f140673f3..ad0f428d8 100644 --- a/mindspore/dataset/transforms/vision/validators.py +++ b/mindspore/dataset/transforms/vision/validators.py @@ -47,6 +47,19 @@ def check_resize_size(size): raise TypeError("Size should be a single integer or a list/tuple (h, w) of length 2.") +def check_mix_up_batch_c(method): + """Wrapper method to check the parameters of MixUpBatch.""" + + @wraps(method) + def new_method(self, *args, **kwargs): + [alpha], _ = parse_user_args(method, *args, **kwargs) + check_pos_float32(alpha) + + return method(self, *args, **kwargs) + + return new_method + + def check_normalize_c_param(mean, std): if len(mean) != len(std): raise ValueError("Length of mean and std must be equal") diff --git a/tests/ut/cpp/dataset/CMakeLists.txt b/tests/ut/cpp/dataset/CMakeLists.txt index 799856fd6..fedbe408b 100644 --- a/tests/ut/cpp/dataset/CMakeLists.txt +++ b/tests/ut/cpp/dataset/CMakeLists.txt @@ -27,6 +27,7 @@ SET(DE_UT_SRCS main_test.cc map_op_test.cc mind_record_op_test.cc + mixup_batch_op_test.cc memory_pool_test.cc normalize_op_test.cc one_hot_op_test.cc diff --git a/tests/ut/cpp/dataset/c_api_transforms_test.cc b/tests/ut/cpp/dataset/c_api_transforms_test.cc index 343ef473e..75e803170 100644 --- a/tests/ut/cpp/dataset/c_api_transforms_test.cc +++ b/tests/ut/cpp/dataset/c_api_transforms_test.cc @@ -146,6 +146,127 @@ 
TEST_F(MindDataTestPipeline, TestRandomFlip) { iter->Stop(); } +TEST_F(MindDataTestPipeline, TestMixUpBatchSuccess1) { + // Create a Cifar10 Dataset + std::string folder_path = datasets_root_path_ + "/testCifar10Data/"; + std::shared_ptr ds = Cifar10(folder_path, RandomSampler(false, 10)); + EXPECT_NE(ds, nullptr); + + // Create a Batch operation on ds + int32_t batch_size = 5; + ds = ds->Batch(batch_size); + EXPECT_NE(ds, nullptr); + + // Create objects for the tensor ops + std::shared_ptr one_hot_op = vision::OneHot(10); + EXPECT_NE(one_hot_op, nullptr); + + // Create a Map operation on ds + ds = ds->Map({one_hot_op},{"label"}); + EXPECT_NE(ds, nullptr); + + std::shared_ptr mixup_batch_op = vision::MixUpBatch(0.5); + EXPECT_NE(mixup_batch_op, nullptr); + + // Create a Map operation on ds + ds = ds->Map({mixup_batch_op}, {"image", "label"}); + EXPECT_NE(ds, nullptr); + + // Create an iterator over the result of the above dataset + // This will trigger the creation of the Execution Tree and launch it. 
+ std::shared_ptr iter = ds->CreateIterator(); + EXPECT_NE(iter, nullptr); + + // Iterate the dataset and get each row + std::unordered_map> row; + iter->GetNextRow(&row); + + uint64_t i = 0; + while (row.size() != 0) { + i++; + auto image = row["image"]; + MS_LOG(INFO) << "Tensor image shape: " << image->shape(); + iter->GetNextRow(&row); + } + + EXPECT_EQ(i, 2); + + // Manually terminate the pipeline + iter->Stop(); +} + +TEST_F(MindDataTestPipeline, TestMixUpBatchSuccess2) { + // Create a Cifar10 Dataset + std::string folder_path = datasets_root_path_ + "/testCifar10Data/"; + std::shared_ptr ds = Cifar10(folder_path, RandomSampler(false, 10)); + EXPECT_NE(ds, nullptr); + + // Create a Batch operation on ds + int32_t batch_size = 5; + ds = ds->Batch(batch_size); + EXPECT_NE(ds, nullptr); + + // Create objects for the tensor ops + std::shared_ptr one_hot_op = vision::OneHot(10); + EXPECT_NE(one_hot_op, nullptr); + + // Create a Map operation on ds + ds = ds->Map({one_hot_op},{"label"}); + EXPECT_NE(ds, nullptr); + + std::shared_ptr mixup_batch_op = vision::MixUpBatch(); + EXPECT_NE(mixup_batch_op, nullptr); + + // Create a Map operation on ds + ds = ds->Map({mixup_batch_op}, {"image", "label"}); + EXPECT_NE(ds, nullptr); + + // Create an iterator over the result of the above dataset + // This will trigger the creation of the Execution Tree and launch it. 
+ std::shared_ptr iter = ds->CreateIterator(); + EXPECT_NE(iter, nullptr); + + // Iterate the dataset and get each row + std::unordered_map> row; + iter->GetNextRow(&row); + + uint64_t i = 0; + while (row.size() != 0) { + i++; + auto image = row["image"]; + MS_LOG(INFO) << "Tensor image shape: " << image->shape(); + iter->GetNextRow(&row); + } + + EXPECT_EQ(i, 2); + + // Manually terminate the pipeline + iter->Stop(); +} + +TEST_F(MindDataTestPipeline, TestMixUpBatchFail1) { + // Create a Cifar10 Dataset + std::string folder_path = datasets_root_path_ + "/testCifar10Data/"; + std::shared_ptr ds = Cifar10(folder_path, RandomSampler(false, 10)); + EXPECT_NE(ds, nullptr); + + // Create a Batch operation on ds + int32_t batch_size = 5; + ds = ds->Batch(batch_size); + EXPECT_NE(ds, nullptr); + + // Create objects for the tensor ops + std::shared_ptr one_hot_op = vision::OneHot(10); + EXPECT_NE(one_hot_op, nullptr); + + // Create a Map operation on ds + ds = ds->Map({one_hot_op},{"label"}); + EXPECT_NE(ds, nullptr); + + std::shared_ptr mixup_batch_op = vision::MixUpBatch(-1); + EXPECT_EQ(mixup_batch_op, nullptr); +} + TEST_F(MindDataTestPipeline, TestPad) { // Create an ImageFolder Dataset std::string folder_path = datasets_root_path_ + "/testPK/data/"; diff --git a/tests/ut/cpp/dataset/mixup_batch_op_test.cc b/tests/ut/cpp/dataset/mixup_batch_op_test.cc new file mode 100644 index 000000000..844566ec7 --- /dev/null +++ b/tests/ut/cpp/dataset/mixup_batch_op_test.cc @@ -0,0 +1,69 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "common/common.h" +#include "common/cvop_common.h" +#include "minddata/dataset/kernels/image/mixup_batch_op.h" +#include "utils/log_adapter.h" + +using namespace mindspore::dataset; +using mindspore::MsLogLevel::INFO; +using mindspore::ExceptionType::NoExceptionType; +using mindspore::LogStream; + +class MindDataTestMixUpBatchOp : public UT::CVOP::CVOpCommon { + protected: + MindDataTestMixUpBatchOp() : CVOpCommon() {} + + std::shared_ptr output_tensor_; +}; + +TEST_F(MindDataTestMixUpBatchOp, TestSuccess) { + MS_LOG(INFO) << "Doing MindDataTestMixUpBatchOp success case"; + std::shared_ptr batched_tensor; + std::shared_ptr batched_labels; + Tensor::CreateEmpty(TensorShape({2, input_tensor_->shape()[0], input_tensor_->shape()[1], input_tensor_->shape()[2]}), input_tensor_->type(), &batched_tensor); + for (int i = 0; i < 2; i++) { + batched_tensor->InsertTensor({i}, input_tensor_); + } + Tensor::CreateFromVector(std::vector({0, 1, 1, 0}), TensorShape({2, 2}), &batched_labels); + std::shared_ptr op = std::make_shared(1); + TensorRow in; + in.push_back(batched_tensor); + in.push_back(batched_labels); + TensorRow out; + ASSERT_TRUE(op->Compute(in, &out).IsOk()); + + EXPECT_EQ(in.at(0)->shape()[0], out.at(0)->shape()[0]); + EXPECT_EQ(in.at(0)->shape()[1], out.at(0)->shape()[1]); + EXPECT_EQ(in.at(0)->shape()[2], out.at(0)->shape()[2]); + EXPECT_EQ(in.at(0)->shape()[3], out.at(0)->shape()[3]); + + EXPECT_EQ(in.at(1)->shape()[0], out.at(1)->shape()[0]); + EXPECT_EQ(in.at(1)->shape()[1], out.at(1)->shape()[1]); +} + 
+TEST_F(MindDataTestMixUpBatchOp, TestFail) { + // This is a fail case because our labels are not batched and are 1-dimensional + MS_LOG(INFO) << "Doing MindDataTestMixUpBatchOp fail case"; + std::shared_ptr labels; + Tensor::CreateFromVector(std::vector({0, 1, 1, 0}), TensorShape({4}), &labels); + std::shared_ptr op = std::make_shared(1); + TensorRow in; + in.push_back(input_tensor_); + in.push_back(labels); + TensorRow out; + ASSERT_FALSE(op->Compute(in, &out).IsOk()); +} diff --git a/tests/ut/data/dataset/golden/mixup_batch_c_result.npz b/tests/ut/data/dataset/golden/mixup_batch_c_result.npz new file mode 100644 index 0000000000000000000000000000000000000000..ad606d0d3c8b280d6af2d2da879819c9491ce36c GIT binary patch literal 713 zcmWIWW@Zs#fB;1XxfvUz9hn#yK$w$3gdwr0DBeIXub`5VK>#cWQV5a+fysWMz5$Vp z3}p<}>M5zk$wlf`3hFif>N*PQY57GZMTvRw`9&$IAYr$}oZ?iVcyUHzK`M~1VWgvA zq^YA&t3W>BYG6*zE6pva)Jx7UO4Z9P%_+$Qx;L?sE50Z-IX|zsq^LBxgsYGNqKYdo z1tMF>=*`et$mGnJRLI<3$P!e@s^QJ(&E(D0R>%fbno?3(kjhoa9s%;HzeOR3H-k50 zdm(2~A(w_Xa|9z$w5E{T&(F{6KM;TkZ~Kx$o}|v$LSBssR-k-lZen_BAzy4EzeWZ_ z2G~l044{32L4`sf`&e2Fg)<-)q?r_oKr9dqDiniRU{ffLY5_w@p+r)rvw-}mU7U@1 zC1)EwnZ=T}d)z9NEGd*qf>|OBvP34RP!?*5T!t7>&!@wETX!6p^7Cl&leNWl+-nQv zO9~Z|Qo%0GNlZ%3DO8LtRD!!y8Dxk`P@yW+5H+wN+56RsdF8D)2xQo*$Y^epEmSWl z)PP&83DTn#RHzNrqXX7c!k}DWU_bwrew5z6?U!$M73!80>LsP7mK5qI=>>Q*GKnzb fN`1gUgn|Z8fFsgkfHx}}NPrOt&4F|cI6VRYU0J=E literal 0 HcmV?d00001 diff --git a/tests/ut/python/dataset/test_mixup_op.py b/tests/ut/python/dataset/test_mixup_op.py new file mode 100644 index 000000000..9641a642f --- /dev/null +++ b/tests/ut/python/dataset/test_mixup_op.py @@ -0,0 +1,247 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""
Testing the MixUpBatch op in DE
"""
import numpy as np
import pytest
import mindspore.dataset as ds
import mindspore.dataset.transforms.vision.c_transforms as vision
import mindspore.dataset.transforms.c_transforms as data_trans
from mindspore import log as logger
from util import save_and_check_md5, diff_mse, visualize_list, config_get_set_seed, \
    config_get_set_num_parallel_workers

DATA_DIR = "../data/dataset/testCifar10Data"

GENERATE_GOLDEN = False


def _load_cifar10(one_hot=False):
    """
    Build the 10-sample, unshuffled Cifar10 dataset every test in this file uses.

    Args:
        one_hot (bool): when True, map the "label" column through OneHot(num_classes=10).

    Returns:
        The (unbatched) dataset object.
    """
    data = ds.Cifar10Dataset(DATA_DIR, num_samples=10, shuffle=False)
    if one_hot:
        one_hot_op = data_trans.OneHot(num_classes=10)
        data = data.map(input_columns=["label"], operations=one_hot_op)
    return data


def _collect_images(dataset):
    """
    Iterate a two-column dataset and stack the first column of every row along axis 0.

    Each row is unpacked as an (image, label) pair, exactly like the per-test loops
    this helper replaces. One np.concatenate call yields the same array as the
    repeated np.append(..., axis=0) calls used previously.
    """
    return np.concatenate([image for image, _ in dataset], axis=0)


def _drain(dataset):
    """
    Pull every row out of the dataset and discard it.

    Used by the failure tests, which only care about the error raised while iterating.
    """
    for _ in dataset:
        pass


def _run_mixup_and_log_mse(mixup_batch_op, plot):
    """
    Shared body of the success tests.

    Runs the given MixUpBatch op over batches of 5 one-hot-labelled samples, optionally
    visualizes the result next to the untouched images, and logs the mean per-sample MSE.
    No assertion is made on the MSE value; it is only logged.
    """
    # Original Images
    ds_original = _load_cifar10().batch(5, drop_remainder=True)
    images_original = _collect_images(ds_original)

    # MixUp Images
    data1 = _load_cifar10(one_hot=True)
    data1 = data1.batch(5, drop_remainder=True)
    data1 = data1.map(input_columns=["image", "label"], operations=mixup_batch_op)
    images_mixup = _collect_images(data1)

    if plot:
        visualize_list(images_original, images_mixup)

    num_samples = images_original.shape[0]
    mse = np.zeros(num_samples)
    for i in range(num_samples):
        mse[i] = diff_mse(images_mixup[i], images_original[i])
    logger.info("MSE= {}".format(str(np.mean(mse))))


def test_mixup_batch_success1(plot=False):
    """
    Test MixUpBatch op with specified alpha parameter
    """
    logger.info("test_mixup_batch_success1")
    _run_mixup_and_log_mse(vision.MixUpBatch(2), plot)


def test_mixup_batch_success2(plot=False):
    """
    Test MixUpBatch op without specified alpha parameter.
    Alpha parameter will be selected by default in this case
    """
    logger.info("test_mixup_batch_success2")
    _run_mixup_and_log_mse(vision.MixUpBatch(), plot)


def test_mixup_batch_md5():
    """
    Test MixUpBatch with MD5:
    """
    logger.info("test_mixup_batch_md5")
    original_seed = config_get_set_seed(0)
    original_num_parallel_workers = config_get_set_num_parallel_workers(1)

    try:
        # MixUp Images
        data = _load_cifar10(one_hot=True)
        mixup_batch_op = vision.MixUpBatch()
        data = data.batch(5, drop_remainder=True)
        data = data.map(input_columns=["image", "label"], operations=mixup_batch_op)

        filename = "mixup_batch_c_result.npz"
        save_and_check_md5(data, filename, generate_golden=GENERATE_GOLDEN)
    finally:
        # Restore config setting even when the MD5 comparison fails, so the
        # seed / worker override cannot leak into tests that run afterwards.
        ds.config.set_seed(original_seed)
        ds.config.set_num_parallel_workers(original_num_parallel_workers)


def test_mixup_batch_fail1():
    """
    Test MixUpBatch Fail 1
    We expect this to fail because the images and labels are not batched
    """
    logger.info("test_mixup_batch_fail1")

    # MixUp Images (deliberately NOT batched before the mixup map)
    data1 = _load_cifar10(one_hot=True)
    mixup_batch_op = vision.MixUpBatch(0.1)
    with pytest.raises(RuntimeError) as error:
        data1 = data1.map(input_columns=["image", "label"], operations=mixup_batch_op)
        _drain(data1)
    error_message = "You must batch before calling MixUp"
    assert error_message in str(error.value)


def test_mixup_batch_fail2():
    """
    Test MixUpBatch Fail 2
    We expect this to fail because alpha is negative
    """
    logger.info("test_mixup_batch_fail2")

    # The ValueError is expected from the constructor itself, so no dataset is needed.
    with pytest.raises(ValueError) as error:
        vision.MixUpBatch(-1)
    error_message = "Input is not within the required interval"
    assert error_message in str(error.value)


def test_mixup_batch_fail3():
    """
    Test MixUpBatch op
    We expect this to fail because label column is not passed to mixup_batch
    """
    logger.info("test_mixup_batch_fail3")

    # MixUp Images (only the "image" column is handed to the op)
    data1 = _load_cifar10(one_hot=True)
    mixup_batch_op = vision.MixUpBatch()
    data1 = data1.batch(5, drop_remainder=True)
    data1 = data1.map(input_columns=["image"], operations=mixup_batch_op)

    with pytest.raises(RuntimeError) as error:
        _drain(data1)
    error_message = "Both images and labels columns are required"
    assert error_message in str(error.value)


if __name__ == "__main__":
    test_mixup_batch_success1(plot=True)
    test_mixup_batch_success2(plot=True)
    test_mixup_batch_md5()
    test_mixup_batch_fail1()
    test_mixup_batch_fail2()
    test_mixup_batch_fail3()