提交 1896950a 编写于 作者: M Mahdi

Added Mixup

上级 729d847d
...@@ -28,6 +28,7 @@ ...@@ -28,6 +28,7 @@
#include "minddata/dataset/kernels/image/hwc_to_chw_op.h" #include "minddata/dataset/kernels/image/hwc_to_chw_op.h"
#include "minddata/dataset/kernels/image/image_utils.h" #include "minddata/dataset/kernels/image/image_utils.h"
#include "minddata/dataset/kernels/image/invert_op.h" #include "minddata/dataset/kernels/image/invert_op.h"
#include "minddata/dataset/kernels/image/mixup_batch_op.h"
#include "minddata/dataset/kernels/image/normalize_op.h" #include "minddata/dataset/kernels/image/normalize_op.h"
#include "minddata/dataset/kernels/image/pad_op.h" #include "minddata/dataset/kernels/image/pad_op.h"
#include "minddata/dataset/kernels/image/random_color_adjust_op.h" #include "minddata/dataset/kernels/image/random_color_adjust_op.h"
...@@ -92,6 +93,12 @@ PYBIND_REGISTER(CenterCropOp, 1, ([](const py::module *m) { ...@@ -92,6 +93,12 @@ PYBIND_REGISTER(CenterCropOp, 1, ([](const py::module *m) {
.def(py::init<int32_t, int32_t>(), py::arg("height"), py::arg("width") = CenterCropOp::kDefWidth); .def(py::init<int32_t, int32_t>(), py::arg("height"), py::arg("width") = CenterCropOp::kDefWidth);
})); }));
PYBIND_REGISTER(MixUpBatchOp, 1, ([](const py::module *m) {
(void)py::class_<MixUpBatchOp, TensorOp, std::shared_ptr<MixUpBatchOp>>(
*m, "MixUpBatchOp", "Tensor operation to mixup a batch of images")
.def(py::init<float>(), py::arg("alpha"));
}));
PYBIND_REGISTER(ResizeOp, 1, ([](const py::module *m) { PYBIND_REGISTER(ResizeOp, 1, ([](const py::module *m) {
(void)py::class_<ResizeOp, TensorOp, std::shared_ptr<ResizeOp>>( (void)py::class_<ResizeOp, TensorOp, std::shared_ptr<ResizeOp>>(
*m, "ResizeOp", "Tensor operation to resize an image. Takes height, width and mode") *m, "ResizeOp", "Tensor operation to resize an image. Takes height, width and mode")
......
...@@ -21,7 +21,9 @@ ...@@ -21,7 +21,9 @@
#include "minddata/dataset/kernels/image/crop_op.h" #include "minddata/dataset/kernels/image/crop_op.h"
#include "minddata/dataset/kernels/image/cut_out_op.h" #include "minddata/dataset/kernels/image/cut_out_op.h"
#include "minddata/dataset/kernels/image/decode_op.h" #include "minddata/dataset/kernels/image/decode_op.h"
#include "minddata/dataset/kernels/image/mixup_batch_op.h"
#include "minddata/dataset/kernels/image/normalize_op.h" #include "minddata/dataset/kernels/image/normalize_op.h"
#include "minddata/dataset/kernels/data/one_hot_op.h"
#include "minddata/dataset/kernels/image/pad_op.h" #include "minddata/dataset/kernels/image/pad_op.h"
#include "minddata/dataset/kernels/image/random_color_adjust_op.h" #include "minddata/dataset/kernels/image/random_color_adjust_op.h"
#include "minddata/dataset/kernels/image/random_crop_op.h" #include "minddata/dataset/kernels/image/random_crop_op.h"
...@@ -81,6 +83,16 @@ std::shared_ptr<DecodeOperation> Decode(bool rgb) { ...@@ -81,6 +83,16 @@ std::shared_ptr<DecodeOperation> Decode(bool rgb) {
return op; return op;
} }
// Function to create MixUpBatchOperation.
std::shared_ptr<MixUpBatchOperation> MixUpBatch(float alpha) {
auto op = std::make_shared<MixUpBatchOperation>(alpha);
// Input validation
if (!op->ValidateParams()) {
return nullptr;
}
return op;
}
// Function to create NormalizeOperation. // Function to create NormalizeOperation.
std::shared_ptr<NormalizeOperation> Normalize(std::vector<float> mean, std::vector<float> std) { std::shared_ptr<NormalizeOperation> Normalize(std::vector<float> mean, std::vector<float> std) {
auto op = std::make_shared<NormalizeOperation>(mean, std); auto op = std::make_shared<NormalizeOperation>(mean, std);
...@@ -91,6 +103,16 @@ std::shared_ptr<NormalizeOperation> Normalize(std::vector<float> mean, std::vect ...@@ -91,6 +103,16 @@ std::shared_ptr<NormalizeOperation> Normalize(std::vector<float> mean, std::vect
return op; return op;
} }
// Function to create OneHotOperation.
std::shared_ptr<OneHotOperation> OneHot(int32_t num_classes) {
auto op = std::make_shared<OneHotOperation>(num_classes);
// Input validation
if (!op->ValidateParams()) {
return nullptr;
}
return op;
}
// Function to create PadOperation. // Function to create PadOperation.
std::shared_ptr<PadOperation> Pad(std::vector<int32_t> padding, std::vector<uint8_t> fill_value, std::shared_ptr<PadOperation> Pad(std::vector<int32_t> padding, std::vector<uint8_t> fill_value,
BorderType padding_mode) { BorderType padding_mode) {
...@@ -271,6 +293,20 @@ bool DecodeOperation::ValidateParams() { return true; } ...@@ -271,6 +293,20 @@ bool DecodeOperation::ValidateParams() { return true; }
std::shared_ptr<TensorOp> DecodeOperation::Build() { return std::make_shared<DecodeOp>(rgb_); } std::shared_ptr<TensorOp> DecodeOperation::Build() { return std::make_shared<DecodeOp>(rgb_); }
// MixUpOperation
MixUpBatchOperation::MixUpBatchOperation(float alpha) : alpha_(alpha) {}
bool MixUpBatchOperation::ValidateParams() {
if (alpha_ < 0) {
MS_LOG(ERROR) << "MixUpBatch: alpha must be a positive floating value however it is: " << alpha_;
return false;
}
return true;
}
std::shared_ptr<TensorOp> MixUpBatchOperation::Build() { return std::make_shared<MixUpBatchOp>(alpha_); }
// NormalizeOperation // NormalizeOperation
NormalizeOperation::NormalizeOperation(std::vector<float> mean, std::vector<float> std) : mean_(mean), std_(std) {} NormalizeOperation::NormalizeOperation(std::vector<float> mean, std::vector<float> std) : mean_(mean), std_(std) {}
...@@ -292,6 +328,20 @@ std::shared_ptr<TensorOp> NormalizeOperation::Build() { ...@@ -292,6 +328,20 @@ std::shared_ptr<TensorOp> NormalizeOperation::Build() {
return std::make_shared<NormalizeOp>(mean_[0], mean_[1], mean_[2], std_[0], std_[1], std_[2]); return std::make_shared<NormalizeOp>(mean_[0], mean_[1], mean_[2], std_[0], std_[1], std_[2]);
} }
// OneHotOperation
OneHotOperation::OneHotOperation(int32_t num_classes) : num_classes_(num_classes) {}
bool OneHotOperation::ValidateParams() {
if (num_classes_ < 0) {
MS_LOG(ERROR) << "OneHot: Number of classes cannot be negative. Number of classes: " << num_classes_;
return false;
}
return true;
}
std::shared_ptr<TensorOp> OneHotOperation::Build() { return std::make_shared<OneHotOp>(num_classes_); }
// PadOperation // PadOperation
PadOperation::PadOperation(std::vector<int32_t> padding, std::vector<uint8_t> fill_value, BorderType padding_mode) PadOperation::PadOperation(std::vector<int32_t> padding, std::vector<uint8_t> fill_value, BorderType padding_mode)
: padding_(padding), fill_value_(fill_value), padding_mode_(padding_mode) {} : padding_(padding), fill_value_(fill_value), padding_mode_(padding_mode) {}
......
...@@ -51,7 +51,9 @@ class CenterCropOperation; ...@@ -51,7 +51,9 @@ class CenterCropOperation;
class CropOperation; class CropOperation;
class CutOutOperation; class CutOutOperation;
class DecodeOperation; class DecodeOperation;
class MixUpBatchOperation;
class NormalizeOperation; class NormalizeOperation;
class OneHotOperation;
class PadOperation; class PadOperation;
class RandomColorAdjustOperation; class RandomColorAdjustOperation;
class RandomCropOperation; class RandomCropOperation;
...@@ -90,6 +92,13 @@ std::shared_ptr<CutOutOperation> CutOut(int32_t length, int32_t num_patches = 1) ...@@ -90,6 +92,13 @@ std::shared_ptr<CutOutOperation> CutOut(int32_t length, int32_t num_patches = 1)
/// \return Shared pointer to the current TensorOperation. /// \return Shared pointer to the current TensorOperation.
std::shared_ptr<DecodeOperation> Decode(bool rgb = true); std::shared_ptr<DecodeOperation> Decode(bool rgb = true);
/// \brief Function to create a MixUpBatch TensorOperation.
/// \notes Apply MixUp transformation on an input batch of images and labels. The labels must be in one-hot format and
/// Batch must be called before calling this function.
/// \param[in] alpha hyperparameter of beta distribution (default = 1.0)
/// \return Shared pointer to the current TensorOperation.
std::shared_ptr<MixUpBatchOperation> MixUpBatch(float alpha = 1);
/// \brief Function to create a Normalize TensorOperation. /// \brief Function to create a Normalize TensorOperation.
/// \notes Normalize the input image with respect to mean and standard deviation. /// \notes Normalize the input image with respect to mean and standard deviation.
/// \param[in] mean - a vector of mean values for each channel, w.r.t channel order. /// \param[in] mean - a vector of mean values for each channel, w.r.t channel order.
...@@ -97,6 +106,12 @@ std::shared_ptr<DecodeOperation> Decode(bool rgb = true); ...@@ -97,6 +106,12 @@ std::shared_ptr<DecodeOperation> Decode(bool rgb = true);
/// \return Shared pointer to the current TensorOperation. /// \return Shared pointer to the current TensorOperation.
std::shared_ptr<NormalizeOperation> Normalize(std::vector<float> mean, std::vector<float> std); std::shared_ptr<NormalizeOperation> Normalize(std::vector<float> mean, std::vector<float> std);
/// \brief Function to create a OneHot TensorOperation.
/// \notes Convert the labels into OneHot format.
/// \param[in] num_classes number of classes.
/// \return Shared pointer to the current TensorOperation.
std::shared_ptr<OneHotOperation> OneHot(int32_t num_classes);
/// \brief Function to create a Pad TensorOp /// \brief Function to create a Pad TensorOp
/// \notes Pads the image according to padding parameters /// \notes Pads the image according to padding parameters
/// \param[in] padding A vector representing the number of pixels to pad the image /// \param[in] padding A vector representing the number of pixels to pad the image
...@@ -258,6 +273,20 @@ class DecodeOperation : public TensorOperation { ...@@ -258,6 +273,20 @@ class DecodeOperation : public TensorOperation {
bool rgb_; bool rgb_;
}; };
class MixUpBatchOperation : public TensorOperation {
public:
explicit MixUpBatchOperation(float alpha = 1);
~MixUpBatchOperation() = default;
std::shared_ptr<TensorOp> Build() override;
bool ValidateParams() override;
private:
float alpha_;
};
class NormalizeOperation : public TensorOperation { class NormalizeOperation : public TensorOperation {
public: public:
NormalizeOperation(std::vector<float> mean, std::vector<float> std); NormalizeOperation(std::vector<float> mean, std::vector<float> std);
...@@ -273,6 +302,20 @@ class NormalizeOperation : public TensorOperation { ...@@ -273,6 +302,20 @@ class NormalizeOperation : public TensorOperation {
std::vector<float> std_; std::vector<float> std_;
}; };
class OneHotOperation : public TensorOperation {
public:
explicit OneHotOperation(int32_t num_classes_);
~OneHotOperation() = default;
std::shared_ptr<TensorOp> Build() override;
bool ValidateParams() override;
private:
float num_classes_;
};
class PadOperation : public TensorOperation { class PadOperation : public TensorOperation {
public: public:
PadOperation(std::vector<int32_t> padding, std::vector<uint8_t> fill_value = {0}, PadOperation(std::vector<int32_t> padding, std::vector<uint8_t> fill_value = {0},
......
...@@ -20,6 +20,7 @@ ...@@ -20,6 +20,7 @@
#include <limits> #include <limits>
#include <string> #include <string>
#include <vector> #include <vector>
#include <utility>
#include "minddata/dataset/core/constants.h" #include "minddata/dataset/core/constants.h"
#include "minddata/dataset/core/data_type.h" #include "minddata/dataset/core/data_type.h"
...@@ -648,5 +649,30 @@ Status Concatenate(const TensorRow &input, TensorRow *output, int8_t axis, std:: ...@@ -648,5 +649,30 @@ Status Concatenate(const TensorRow &input, TensorRow *output, int8_t axis, std::
return Status::OK(); return Status::OK();
} }
Status BatchTensorToCVTensorVector(const std::shared_ptr<Tensor> &input,
std::vector<std::shared_ptr<CVTensor>> *output) {
std::vector<int64_t> tensor_shape = input->shape().AsVector();
TensorShape remaining({-1});
std::vector<int64_t> index(tensor_shape.size(), 0);
if (tensor_shape.size() <= 1) {
RETURN_STATUS_UNEXPECTED("Tensor must be at least 2-D in order to unpack");
}
TensorShape element_shape(std::vector<int64_t>(tensor_shape.begin() + 1, tensor_shape.end()));
for (; index[0] < tensor_shape[0]; index[0]++) {
uchar *start_addr_of_index = nullptr;
std::shared_ptr<Tensor> out;
RETURN_IF_NOT_OK(input->StartAddrOfIndex(index, &start_addr_of_index, &remaining));
RETURN_IF_NOT_OK(input->CreateFromMemory(element_shape, input->type(), start_addr_of_index, &out));
std::shared_ptr<CVTensor> cv_out = CVTensor::AsCVTensor(std::move(out));
if (!cv_out->mat().data) {
RETURN_STATUS_UNEXPECTED("Could not convert to CV Tensor");
}
output->push_back(cv_out);
}
return Status::OK();
}
} // namespace dataset } // namespace dataset
} // namespace mindspore } // namespace mindspore
...@@ -152,6 +152,17 @@ Status Mask(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *outpu ...@@ -152,6 +152,17 @@ Status Mask(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *outpu
Status Concatenate(const TensorRow &input, TensorRow *output, int8_t axis, std::shared_ptr<Tensor> prepend, Status Concatenate(const TensorRow &input, TensorRow *output, int8_t axis, std::shared_ptr<Tensor> prepend,
std::shared_ptr<Tensor> append); std::shared_ptr<Tensor> append);
// helper for concat, always append to the input, and pass that to the output
Status ConcatenateHelper(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output, int8_t axis,
std::shared_ptr<Tensor> append);
/// Convert an n-dimensional Tensor to a vector of (n-1)-dimensional CVTensors
/// @param input[in] input tensor
/// @param output[out] output tensor
/// @return Status ok/error
Status BatchTensorToCVTensorVector(const std::shared_ptr<Tensor> &input,
std::vector<std::shared_ptr<CVTensor>> *output);
} // namespace dataset } // namespace dataset
} // namespace mindspore } // namespace mindspore
......
...@@ -10,6 +10,7 @@ add_library(kernels-image OBJECT ...@@ -10,6 +10,7 @@ add_library(kernels-image OBJECT
hwc_to_chw_op.cc hwc_to_chw_op.cc
image_utils.cc image_utils.cc
invert_op.cc invert_op.cc
mixup_batch_op.cc
normalize_op.cc normalize_op.cc
pad_op.cc pad_op.cc
random_color_adjust_op.cc random_color_adjust_op.cc
......
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <string>
#include <utility>
#include "minddata/dataset/core/cv_tensor.h"
#include "minddata/dataset/kernels/image/mixup_batch_op.h"
#include "minddata/dataset/kernels/data/data_utils.h"
#include "minddata/dataset/util/random.h"
#include "minddata/dataset/util/status.h"
namespace mindspore {
namespace dataset {
MixUpBatchOp::MixUpBatchOp(float alpha) : alpha_(alpha) { rnd_.seed(GetSeed()); }
Status MixUpBatchOp::Compute(const TensorRow &input, TensorRow *output) {
if (input.size() < 2) {
RETURN_STATUS_UNEXPECTED("Both images and labels columns are required for this operation");
}
std::vector<std::shared_ptr<CVTensor>> images;
std::vector<int64_t> image_shape = input.at(0)->shape().AsVector();
std::vector<int64_t> label_shape = input.at(1)->shape().AsVector();
// Check inputs
if (label_shape.size() != 2 || image_shape.size() != 4 || image_shape[0] != label_shape[0]) {
RETURN_STATUS_UNEXPECTED("You must batch before calling MixUpBatch");
}
if ((image_shape[1] != 1 && image_shape[1] != 3) && (image_shape[3] != 1 && image_shape[3] != 3)) {
RETURN_STATUS_UNEXPECTED("MixUpBatch: Images must be in the shape of HWC or CHW");
}
// Move images into a vector of CVTensors
RETURN_IF_NOT_OK(BatchTensorToCVTensorVector(input.at(0), &images));
// Calculating lambda
// If x1 is a random variable from Gamma(a1, 1) and x2 is a random variable from Gamma(a2, 1)
// then x = x1 / (x1+x2) is a random variable from Beta(a1, a2)
std::gamma_distribution<float> distribution(alpha_, 1);
float x1 = distribution(rnd_);
float x2 = distribution(rnd_);
float lam = x1 / (x1 + x2);
// Calculate random labels
std::vector<int64_t> rand_indx;
for (int64_t i = 0; i < images.size(); i++) rand_indx.push_back(i);
std::shuffle(rand_indx.begin(), rand_indx.end(), rnd_);
// Compute labels
std::shared_ptr<Tensor> out_labels;
RETURN_IF_NOT_OK(TypeCast(std::move(input.at(1)), &out_labels, DataType("float32")));
for (int64_t i = 0; i < label_shape[0]; i++) {
for (int64_t j = 0; j < label_shape[1]; j++) {
uint64_t first_value, second_value;
RETURN_IF_NOT_OK(input.at(1)->GetItemAt(&first_value, {i, j}));
RETURN_IF_NOT_OK(input.at(1)->GetItemAt(&second_value, {rand_indx[i], j}));
RETURN_IF_NOT_OK(out_labels->SetItemAt({i, j}, lam * first_value + (1 - lam) * second_value));
}
}
// Compute images
for (int64_t i = 0; i < images.size(); i++) {
TensorShape remaining({-1});
uchar *start_addr_of_index = nullptr;
std::shared_ptr<Tensor> out;
RETURN_IF_NOT_OK(input.at(0)->StartAddrOfIndex({rand_indx[i], 0, 0, 0}, &start_addr_of_index, &remaining));
RETURN_IF_NOT_OK(input.at(0)->CreateFromMemory(TensorShape({image_shape[1], image_shape[2], image_shape[3]}),
input.at(0)->type(), start_addr_of_index, &out));
std::shared_ptr<CVTensor> rand_image = CVTensor::AsCVTensor(std::move(out));
if (!rand_image->mat().data) {
RETURN_STATUS_UNEXPECTED("Could not convert to CV Tensor");
}
images[i]->mat() = lam * images[i]->mat() + (1 - lam) * rand_image->mat();
}
// Move the output into a TensorRow
std::shared_ptr<Tensor> output_image;
RETURN_IF_NOT_OK(Tensor::CreateEmpty(input.at(0)->shape(), input.at(0)->type(), &output_image));
for (int64_t i = 0; i < images.size(); i++) {
RETURN_IF_NOT_OK(output_image->InsertTensor({i}, images[i]));
}
output->push_back(output_image);
output->push_back(out_labels);
return Status::OK();
}
void MixUpBatchOp::Print(std::ostream &out) const {
out << "MixUpBatchOp: "
<< "alpha: " << alpha_ << "\n";
}
} // namespace dataset
} // namespace mindspore
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_KERNELS_IMAGE_MIXUPBATCH_OP_H_
#define MINDSPORE_CCSRC_MINDDATA_DATASET_KERNELS_IMAGE_MIXUPBATCH_OP_H_
#include <memory>
#include <vector>
#include <random>
#include <string>
#include "minddata/dataset/core/tensor.h"
#include "minddata/dataset/kernels/tensor_op.h"
#include "minddata/dataset/util/status.h"
namespace mindspore {
namespace dataset {
class MixUpBatchOp : public TensorOp {
public:
// Default values, also used by python_bindings.cc
explicit MixUpBatchOp(float alpha);
~MixUpBatchOp() override = default;
void Print(std::ostream &out) const override;
Status Compute(const TensorRow &input, TensorRow *output) override;
std::string Name() const override { return kMixUpBatchOp; }
private:
float alpha_;
std::mt19937 rnd_;
};
} // namespace dataset
} // namespace mindspore
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_KERNELS_IMAGE_MIXUPBATCH_OP_H_
...@@ -99,6 +99,7 @@ constexpr char kCropOp[] = "CropOp"; ...@@ -99,6 +99,7 @@ constexpr char kCropOp[] = "CropOp";
constexpr char kEqualizeOp[] = "EqualizeOp"; constexpr char kEqualizeOp[] = "EqualizeOp";
constexpr char kHwcToChwOp[] = "HwcToChwOp"; constexpr char kHwcToChwOp[] = "HwcToChwOp";
constexpr char kInvertOp[] = "InvertOp"; constexpr char kInvertOp[] = "InvertOp";
constexpr char kMixUpBatchOp[] = "MixUpBatchOp";
constexpr char kNormalizeOp[] = "NormalizeOp"; constexpr char kNormalizeOp[] = "NormalizeOp";
constexpr char kPadOp[] = "PadOp"; constexpr char kPadOp[] = "PadOp";
constexpr char kRandomColorAdjustOp[] = "RandomColorAdjustOp"; constexpr char kRandomColorAdjustOp[] = "RandomColorAdjustOp";
......
...@@ -45,9 +45,9 @@ import mindspore._c_dataengine as cde ...@@ -45,9 +45,9 @@ import mindspore._c_dataengine as cde
from .utils import Inter, Border from .utils import Inter, Border
from .validators import check_prob, check_crop, check_resize_interpolation, check_random_resize_crop, \ from .validators import check_prob, check_crop, check_resize_interpolation, check_random_resize_crop, \
check_normalize_c, check_random_crop, check_random_color_adjust, check_random_rotation, check_range, \ check_mix_up_batch_c, check_normalize_c, check_random_crop, check_random_color_adjust, check_random_rotation, \
check_resize, check_rescale, check_pad, check_cutout, check_uniform_augment_cpp, check_bounding_box_augment_cpp, \ check_range, check_resize, check_rescale, check_pad, check_cutout, check_uniform_augment_cpp, \
check_random_select_subpolicy_op, check_auto_contrast, FLOAT_MAX_INTEGER check_bounding_box_augment_cpp, check_random_select_subpolicy_op, check_auto_contrast, FLOAT_MAX_INTEGER
DE_C_INTER_MODE = {Inter.NEAREST: cde.InterpolationMode.DE_INTER_NEAREST_NEIGHBOUR, DE_C_INTER_MODE = {Inter.NEAREST: cde.InterpolationMode.DE_INTER_NEAREST_NEIGHBOUR,
Inter.LINEAR: cde.InterpolationMode.DE_INTER_LINEAR, Inter.LINEAR: cde.InterpolationMode.DE_INTER_LINEAR,
...@@ -130,6 +130,30 @@ class CutOut(cde.CutOutOp): ...@@ -130,6 +130,30 @@ class CutOut(cde.CutOutOp):
super().__init__(length, length, num_patches, False, *fill_value) super().__init__(length, length, num_patches, False, *fill_value)
class MixUpBatch(cde.MixUpBatchOp):
"""
Apply MixUp transformation on input batch of images and labels. Each image is multiplied by a random weight (lambda)
and then added to a randomly selected image from the batch multiplied by (1 - lambda). Same formula is also applied
to the one-hot labels.
Note that you need to make labels into one-hot format and batch before calling this function.
Args:
alpha (float): hyperparameter of beta distribution (default = 1.0).
Examples:
>>> one_hot_op = data.OneHot(num_classes=10)
>>> data = data.map(input_columns=["label"], operations=one_hot_op)
>>> mixup_batch_op = vision.MixUpBatch()
>>> data = data.batch(5)
>>> data = data.map(input_columns=["image", "label"], operations=mixup_batch_op)
"""
@check_mix_up_batch_c
def __init__(self, alpha=1.0):
self.alpha = alpha
super().__init__(alpha)
class Normalize(cde.NormalizeOp): class Normalize(cde.NormalizeOp):
""" """
Normalize the input image with respect to mean and standard deviation. Normalize the input image with respect to mean and standard deviation.
......
...@@ -47,6 +47,19 @@ def check_resize_size(size): ...@@ -47,6 +47,19 @@ def check_resize_size(size):
raise TypeError("Size should be a single integer or a list/tuple (h, w) of length 2.") raise TypeError("Size should be a single integer or a list/tuple (h, w) of length 2.")
def check_mix_up_batch_c(method):
"""Wrapper method to check the parameters of MixUpBatch."""
@wraps(method)
def new_method(self, *args, **kwargs):
[alpha], _ = parse_user_args(method, *args, **kwargs)
check_pos_float32(alpha)
return method(self, *args, **kwargs)
return new_method
def check_normalize_c_param(mean, std): def check_normalize_c_param(mean, std):
if len(mean) != len(std): if len(mean) != len(std):
raise ValueError("Length of mean and std must be equal") raise ValueError("Length of mean and std must be equal")
......
...@@ -27,6 +27,7 @@ SET(DE_UT_SRCS ...@@ -27,6 +27,7 @@ SET(DE_UT_SRCS
main_test.cc main_test.cc
map_op_test.cc map_op_test.cc
mind_record_op_test.cc mind_record_op_test.cc
mixup_batch_op_test.cc
memory_pool_test.cc memory_pool_test.cc
normalize_op_test.cc normalize_op_test.cc
one_hot_op_test.cc one_hot_op_test.cc
......
...@@ -146,6 +146,127 @@ TEST_F(MindDataTestPipeline, TestRandomFlip) { ...@@ -146,6 +146,127 @@ TEST_F(MindDataTestPipeline, TestRandomFlip) {
iter->Stop(); iter->Stop();
} }
TEST_F(MindDataTestPipeline, TestMixUpBatchSuccess1) {
// Create a Cifar10 Dataset
std::string folder_path = datasets_root_path_ + "/testCifar10Data/";
std::shared_ptr<Dataset> ds = Cifar10(folder_path, RandomSampler(false, 10));
EXPECT_NE(ds, nullptr);
// Create a Batch operation on ds
int32_t batch_size = 5;
ds = ds->Batch(batch_size);
EXPECT_NE(ds, nullptr);
// Create objects for the tensor ops
std::shared_ptr<TensorOperation> one_hot_op = vision::OneHot(10);
EXPECT_NE(one_hot_op, nullptr);
// Create a Map operation on ds
ds = ds->Map({one_hot_op},{"label"});
EXPECT_NE(ds, nullptr);
std::shared_ptr<TensorOperation> mixup_batch_op = vision::MixUpBatch(0.5);
EXPECT_NE(mixup_batch_op, nullptr);
// Create a Map operation on ds
ds = ds->Map({mixup_batch_op}, {"image", "label"});
EXPECT_NE(ds, nullptr);
// Create an iterator over the result of the above dataset
// This will trigger the creation of the Execution Tree and launch it.
std::shared_ptr<Iterator> iter = ds->CreateIterator();
EXPECT_NE(iter, nullptr);
// Iterate the dataset and get each row
std::unordered_map<std::string, std::shared_ptr<Tensor>> row;
iter->GetNextRow(&row);
uint64_t i = 0;
while (row.size() != 0) {
i++;
auto image = row["image"];
MS_LOG(INFO) << "Tensor image shape: " << image->shape();
iter->GetNextRow(&row);
}
EXPECT_EQ(i, 2);
// Manually terminate the pipeline
iter->Stop();
}
TEST_F(MindDataTestPipeline, TestMixUpBatchSuccess2) {
// Create a Cifar10 Dataset
std::string folder_path = datasets_root_path_ + "/testCifar10Data/";
std::shared_ptr<Dataset> ds = Cifar10(folder_path, RandomSampler(false, 10));
EXPECT_NE(ds, nullptr);
// Create a Batch operation on ds
int32_t batch_size = 5;
ds = ds->Batch(batch_size);
EXPECT_NE(ds, nullptr);
// Create objects for the tensor ops
std::shared_ptr<TensorOperation> one_hot_op = vision::OneHot(10);
EXPECT_NE(one_hot_op, nullptr);
// Create a Map operation on ds
ds = ds->Map({one_hot_op},{"label"});
EXPECT_NE(ds, nullptr);
std::shared_ptr<TensorOperation> mixup_batch_op = vision::MixUpBatch();
EXPECT_NE(mixup_batch_op, nullptr);
// Create a Map operation on ds
ds = ds->Map({mixup_batch_op}, {"image", "label"});
EXPECT_NE(ds, nullptr);
// Create an iterator over the result of the above dataset
// This will trigger the creation of the Execution Tree and launch it.
std::shared_ptr<Iterator> iter = ds->CreateIterator();
EXPECT_NE(iter, nullptr);
// Iterate the dataset and get each row
std::unordered_map<std::string, std::shared_ptr<Tensor>> row;
iter->GetNextRow(&row);
uint64_t i = 0;
while (row.size() != 0) {
i++;
auto image = row["image"];
MS_LOG(INFO) << "Tensor image shape: " << image->shape();
iter->GetNextRow(&row);
}
EXPECT_EQ(i, 2);
// Manually terminate the pipeline
iter->Stop();
}
TEST_F(MindDataTestPipeline, TestMixUpBatchFail1) {
// Create a Cifar10 Dataset
std::string folder_path = datasets_root_path_ + "/testCifar10Data/";
std::shared_ptr<Dataset> ds = Cifar10(folder_path, RandomSampler(false, 10));
EXPECT_NE(ds, nullptr);
// Create a Batch operation on ds
int32_t batch_size = 5;
ds = ds->Batch(batch_size);
EXPECT_NE(ds, nullptr);
// Create objects for the tensor ops
std::shared_ptr<TensorOperation> one_hot_op = vision::OneHot(10);
EXPECT_NE(one_hot_op, nullptr);
// Create a Map operation on ds
ds = ds->Map({one_hot_op},{"label"});
EXPECT_NE(ds, nullptr);
std::shared_ptr<TensorOperation> mixup_batch_op = vision::MixUpBatch(-1);
EXPECT_EQ(mixup_batch_op, nullptr);
}
TEST_F(MindDataTestPipeline, TestPad) { TEST_F(MindDataTestPipeline, TestPad) {
// Create an ImageFolder Dataset // Create an ImageFolder Dataset
std::string folder_path = datasets_root_path_ + "/testPK/data/"; std::string folder_path = datasets_root_path_ + "/testPK/data/";
......
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "common/common.h"
#include "common/cvop_common.h"
#include "minddata/dataset/kernels/image/mixup_batch_op.h"
#include "utils/log_adapter.h"
using namespace mindspore::dataset;
using mindspore::MsLogLevel::INFO;
using mindspore::ExceptionType::NoExceptionType;
using mindspore::LogStream;
class MindDataTestMixUpBatchOp : public UT::CVOP::CVOpCommon {
protected:
MindDataTestMixUpBatchOp() : CVOpCommon() {}
std::shared_ptr<Tensor> output_tensor_;
};
TEST_F(MindDataTestMixUpBatchOp, TestSuccess) {
MS_LOG(INFO) << "Doing MindDataTestMixUpBatchOp success case";
std::shared_ptr<Tensor> batched_tensor;
std::shared_ptr<Tensor> batched_labels;
Tensor::CreateEmpty(TensorShape({2, input_tensor_->shape()[0], input_tensor_->shape()[1], input_tensor_->shape()[2]}), input_tensor_->type(), &batched_tensor);
for (int i = 0; i < 2; i++) {
batched_tensor->InsertTensor({i}, input_tensor_);
}
Tensor::CreateFromVector(std::vector<uint32_t>({0, 1, 1, 0}), TensorShape({2, 2}), &batched_labels);
std::shared_ptr<MixUpBatchOp> op = std::make_shared<MixUpBatchOp>(1);
TensorRow in;
in.push_back(batched_tensor);
in.push_back(batched_labels);
TensorRow out;
ASSERT_TRUE(op->Compute(in, &out).IsOk());
EXPECT_EQ(in.at(0)->shape()[0], out.at(0)->shape()[0]);
EXPECT_EQ(in.at(0)->shape()[1], out.at(0)->shape()[1]);
EXPECT_EQ(in.at(0)->shape()[2], out.at(0)->shape()[2]);
EXPECT_EQ(in.at(0)->shape()[3], out.at(0)->shape()[3]);
EXPECT_EQ(in.at(1)->shape()[0], out.at(1)->shape()[0]);
EXPECT_EQ(in.at(1)->shape()[1], out.at(1)->shape()[1]);
}
TEST_F(MindDataTestMixUpBatchOp, TestFail) {
// This is a fail case because our labels are not batched and are 1-dimensional
MS_LOG(INFO) << "Doing MindDataTestMixUpBatchOp fail case";
std::shared_ptr<Tensor> labels;
Tensor::CreateFromVector(std::vector<uint32_t>({0, 1, 1, 0}), TensorShape({4}), &labels);
std::shared_ptr<MixUpBatchOp> op = std::make_shared<MixUpBatchOp>(1);
TensorRow in;
in.push_back(input_tensor_);
in.push_back(labels);
TensorRow out;
ASSERT_FALSE(op->Compute(in, &out).IsOk());
}
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""
Testing the MixUpBatch op in DE
"""
import numpy as np
import pytest
import mindspore.dataset as ds
import mindspore.dataset.transforms.vision.c_transforms as vision
import mindspore.dataset.transforms.c_transforms as data_trans
from mindspore import log as logger
from util import save_and_check_md5, diff_mse, visualize_list, config_get_set_seed, \
config_get_set_num_parallel_workers
DATA_DIR = "../data/dataset/testCifar10Data"
GENERATE_GOLDEN = False
def test_mixup_batch_success1(plot=False):
"""
Test MixUpBatch op with specified alpha parameter
"""
logger.info("test_mixup_batch_success1")
# Original Images
ds_original = ds.Cifar10Dataset(DATA_DIR, num_samples=10, shuffle=False)
ds_original = ds_original.batch(5, drop_remainder=True)
images_original = None
for idx, (image, _) in enumerate(ds_original):
if idx == 0:
images_original = image
else:
images_original = np.append(images_original, image, axis=0)
# MixUp Images
data1 = ds.Cifar10Dataset(DATA_DIR, num_samples=10, shuffle=False)
one_hot_op = data_trans.OneHot(num_classes=10)
data1 = data1.map(input_columns=["label"], operations=one_hot_op)
mixup_batch_op = vision.MixUpBatch(2)
data1 = data1.batch(5, drop_remainder=True)
data1 = data1.map(input_columns=["image", "label"], operations=mixup_batch_op)
images_mixup = None
for idx, (image, _) in enumerate(data1):
if idx == 0:
images_mixup = image
else:
images_mixup = np.append(images_mixup, image, axis=0)
if plot:
visualize_list(images_original, images_mixup)
num_samples = images_original.shape[0]
mse = np.zeros(num_samples)
for i in range(num_samples):
mse[i] = diff_mse(images_mixup[i], images_original[i])
logger.info("MSE= {}".format(str(np.mean(mse))))
def test_mixup_batch_success2(plot=False):
"""
Test MixUpBatch op without specified alpha parameter.
Alpha parameter will be selected by default in this case
"""
logger.info("test_mixup_batch_success2")
# Original Images
ds_original = ds.Cifar10Dataset(DATA_DIR, num_samples=10, shuffle=False)
ds_original = ds_original.batch(5, drop_remainder=True)
images_original = None
for idx, (image, _) in enumerate(ds_original):
if idx == 0:
images_original = image
else:
images_original = np.append(images_original, image, axis=0)
# MixUp Images
data1 = ds.Cifar10Dataset(DATA_DIR, num_samples=10, shuffle=False)
one_hot_op = data_trans.OneHot(num_classes=10)
data1 = data1.map(input_columns=["label"], operations=one_hot_op)
mixup_batch_op = vision.MixUpBatch()
data1 = data1.batch(5, drop_remainder=True)
data1 = data1.map(input_columns=["image", "label"], operations=mixup_batch_op)
images_mixup = np.array([])
for idx, (image, _) in enumerate(data1):
if idx == 0:
images_mixup = image
else:
images_mixup = np.append(images_mixup, image, axis=0)
if plot:
visualize_list(images_original, images_mixup)
num_samples = images_original.shape[0]
mse = np.zeros(num_samples)
for i in range(num_samples):
mse[i] = diff_mse(images_mixup[i], images_original[i])
logger.info("MSE= {}".format(str(np.mean(mse))))
def test_mixup_batch_md5():
"""
Test MixUpBatch with MD5:
"""
logger.info("test_mixup_batch_md5")
original_seed = config_get_set_seed(0)
original_num_parallel_workers = config_get_set_num_parallel_workers(1)
# MixUp Images
data = ds.Cifar10Dataset(DATA_DIR, num_samples=10, shuffle=False)
one_hot_op = data_trans.OneHot(num_classes=10)
data = data.map(input_columns=["label"], operations=one_hot_op)
mixup_batch_op = vision.MixUpBatch()
data = data.batch(5, drop_remainder=True)
data = data.map(input_columns=["image", "label"], operations=mixup_batch_op)
filename = "mixup_batch_c_result.npz"
save_and_check_md5(data, filename, generate_golden=GENERATE_GOLDEN)
# Restore config setting
ds.config.set_seed(original_seed)
ds.config.set_num_parallel_workers(original_num_parallel_workers)
def test_mixup_batch_fail1():
"""
Test MixUpBatch Fail 1
We expect this to fail because the images and labels are not batched
"""
logger.info("test_mixup_batch_fail1")
# Original Images
ds_original = ds.Cifar10Dataset(DATA_DIR, num_samples=10, shuffle=False)
ds_original = ds_original.batch(5)
images_original = np.array([])
for idx, (image, _) in enumerate(ds_original):
if idx == 0:
images_original = image
else:
images_original = np.append(images_original, image, axis=0)
# MixUp Images
data1 = ds.Cifar10Dataset(DATA_DIR, num_samples=10, shuffle=False)
one_hot_op = data_trans.OneHot(num_classes=10)
data1 = data1.map(input_columns=["label"], operations=one_hot_op)
mixup_batch_op = vision.MixUpBatch(0.1)
with pytest.raises(RuntimeError) as error:
data1 = data1.map(input_columns=["image", "label"], operations=mixup_batch_op)
for idx, (image, _) in enumerate(data1):
if idx == 0:
images_mixup = image
else:
images_mixup = np.append(images_mixup, image, axis=0)
error_message = "You must batch before calling MixUp"
assert error_message in str(error.value)
def test_mixup_batch_fail2():
"""
Test MixUpBatch Fail 2
We expect this to fail because alpha is negative
"""
logger.info("test_mixup_batch_fail2")
# Original Images
ds_original = ds.Cifar10Dataset(DATA_DIR, num_samples=10, shuffle=False)
ds_original = ds_original.batch(5)
images_original = np.array([])
for idx, (image, _) in enumerate(ds_original):
if idx == 0:
images_original = image
else:
images_original = np.append(images_original, image, axis=0)
# MixUp Images
data1 = ds.Cifar10Dataset(DATA_DIR, num_samples=10, shuffle=False)
one_hot_op = data_trans.OneHot(num_classes=10)
data1 = data1.map(input_columns=["label"], operations=one_hot_op)
with pytest.raises(ValueError) as error:
vision.MixUpBatch(-1)
error_message = "Input is not within the required interval"
assert error_message in str(error.value)
def test_mixup_batch_fail3():
"""
Test MixUpBatch op
We expect this to fail because label column is not passed to mixup_batch
"""
# Original Images
ds_original = ds.Cifar10Dataset(DATA_DIR, num_samples=10, shuffle=False)
ds_original = ds_original.batch(5, drop_remainder=True)
images_original = None
for idx, (image, _) in enumerate(ds_original):
if idx == 0:
images_original = image
else:
images_original = np.append(images_original, image, axis=0)
# MixUp Images
data1 = ds.Cifar10Dataset(DATA_DIR, num_samples=10, shuffle=False)
one_hot_op = data_trans.OneHot(num_classes=10)
data1 = data1.map(input_columns=["label"], operations=one_hot_op)
mixup_batch_op = vision.MixUpBatch()
data1 = data1.batch(5, drop_remainder=True)
data1 = data1.map(input_columns=["image"], operations=mixup_batch_op)
with pytest.raises(RuntimeError) as error:
images_mixup = np.array([])
for idx, (image, _) in enumerate(data1):
if idx == 0:
images_mixup = image
else:
images_mixup = np.append(images_mixup, image, axis=0)
error_message = "Both images and labels columns are required"
assert error_message in str(error.value)
if __name__ == "__main__":
test_mixup_batch_success1(plot=True)
test_mixup_batch_success2(plot=True)
test_mixup_batch_md5()
test_mixup_batch_fail1()
test_mixup_batch_fail2()
test_mixup_batch_fail3()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册