提交 6a41b447 编写于 作者: E eric 提交者: Eric

Added wrapper around color change function

上级 85364d59
......@@ -35,6 +35,8 @@
#include "minddata/dataset/kernels/image/random_solarize_op.h"
#include "minddata/dataset/kernels/image/random_vertical_flip_op.h"
#include "minddata/dataset/kernels/image/resize_op.h"
#include "minddata/dataset/kernels/image/rgba_to_bgr_op.h"
#include "minddata/dataset/kernels/image/rgba_to_rgb_op.h"
#include "minddata/dataset/kernels/image/swap_red_blue_op.h"
#include "minddata/dataset/kernels/image/uniform_aug_op.h"
......@@ -240,6 +242,26 @@ std::shared_ptr<ResizeOperation> Resize(std::vector<int32_t> size, Interpolation
return op;
}
// Function to create RgbaToBgrOperation.
std::shared_ptr<RgbaToBgrOperation> RGBA2BGR() {
auto op = std::make_shared<RgbaToBgrOperation>();
// Input validation
if (!op->ValidateParams()) {
return nullptr;
}
return op;
}
// Function to create RgbaToRgbOperation.
std::shared_ptr<RgbaToRgbOperation> RGBA2RGB() {
auto op = std::make_shared<RgbaToRgbOperation>();
// Input validation
if (!op->ValidateParams()) {
return nullptr;
}
return op;
}
// Function to create SwapRedBlueOperation.
std::shared_ptr<SwapRedBlueOperation> SwapRedBlue() {
auto op = std::make_shared<SwapRedBlueOperation>();
......@@ -743,6 +765,26 @@ std::shared_ptr<TensorOp> ResizeOperation::Build() {
return std::make_shared<ResizeOp>(height, width, interpolation_);
}
// RgbaToBgrOperation.
RgbaToBgrOperation::RgbaToBgrOperation() {}
bool RgbaToBgrOperation::ValidateParams() { return true; }
std::shared_ptr<TensorOp> RgbaToBgrOperation::Build() {
std::shared_ptr<RgbaToBgrOp> tensor_op = std::make_shared<RgbaToBgrOp>();
return tensor_op;
}
// RgbaToRgbOperation.
RgbaToRgbOperation::RgbaToRgbOperation() {}
bool RgbaToRgbOperation::ValidateParams() { return true; }
std::shared_ptr<TensorOp> RgbaToRgbOperation::Build() {
std::shared_ptr<RgbaToRgbOp> tensor_op = std::make_shared<RgbaToRgbOp>();
return tensor_op;
}
// SwapRedBlueOperation.
SwapRedBlueOperation::SwapRedBlueOperation() {}
......
......@@ -65,14 +65,16 @@ class RandomSharpnessOperation;
class RandomSolarizeOperation;
class RandomVerticalFlipOperation;
class ResizeOperation;
class RgbaToBgrOperation;
class RgbaToRgbOperation;
class SwapRedBlueOperation;
class UniformAugOperation;
/// \brief Function to create a CenterCrop TensorOperation.
/// \notes Crops the input image at the center to the given size.
/// \param[in] size - a vector representing the output size of the cropped image.
/// If size is a single value, a square crop of size (size, size) is returned.
/// If size has 2 values, it should be (height, width).
/// If size is a single value, a square crop of size (size, size) is returned.
/// If size has 2 values, it should be (height, width).
/// \return Shared pointer to the current TensorOperation.
std::shared_ptr<CenterCropOperation> CenterCrop(std::vector<int32_t> size);
......@@ -103,15 +105,15 @@ std::shared_ptr<HwcToChwOperation> HWC2CHW();
/// \brief Function to create a MixUpBatch TensorOperation.
/// \notes Apply MixUp transformation on an input batch of images and labels. The labels must be in one-hot format and
/// Batch must be called before calling this function.
/// Batch must be called before calling this function.
/// \param[in] alpha hyperparameter of beta distribution (default = 1.0)
/// \return Shared pointer to the current TensorOperation.
std::shared_ptr<MixUpBatchOperation> MixUpBatch(float alpha = 1);
/// \brief Function to create a Normalize TensorOperation.
/// \notes Normalize the input image with respect to mean and standard deviation.
/// \param[in] mean - a vector of mean values for each channel, w.r.t channel order.
/// \param[in] std - a vector of standard deviations for each channel, w.r.t. channel order.
/// \param[in] mean A vector of mean values for each channel, w.r.t channel order.
/// \param[in] std A vector of standard deviations for each channel, w.r.t. channel order.
/// \return Shared pointer to the current TensorOperation.
std::shared_ptr<NormalizeOperation> Normalize(std::vector<float> mean, std::vector<float> std);
......@@ -230,8 +232,18 @@ std::shared_ptr<RandomSolarizeOperation> RandomSolarize(uint8_t threshold_min =
/// \return Shared pointer to the current TensorOperation.
std::shared_ptr<RandomVerticalFlipOperation> RandomVerticalFlip(float prob = 0.5);
/// \brief Function to create a RgbaToBgr TensorOperation.
/// \notes Changes the input 4 channel RGBA tensor to 3 channel BGR.
/// \return Shared pointer to the current TensorOperation.
std::shared_ptr<RgbaToBgrOperation> RGBA2BGR();
/// \brief Function to create a RgbaToRgb TensorOperation.
/// \notes Changes the input 4 channel RGBA tensor to 3 channel RGB.
/// \return Shared pointer to the current TensorOperation.
std::shared_ptr<RgbaToRgbOperation> RGBA2RGB();
/// \brief Function to create a Resize TensorOperation.
/// \notes Resize the input image to the given size..
/// \notes Resize the input image to the given size.
/// \param[in] size - a vector representing the output size of the resized image.
/// If size is a single value, the image will be resized to this value with
/// the same image aspect ratio. If size has 2 values, it should be (height, width).
......@@ -520,6 +532,28 @@ class ResizeOperation : public TensorOperation {
InterpolationMode interpolation_;
};
class RgbaToBgrOperation : public TensorOperation {
public:
RgbaToBgrOperation();
~RgbaToBgrOperation() = default;
std::shared_ptr<TensorOp> Build() override;
bool ValidateParams() override;
};
class RgbaToRgbOperation : public TensorOperation {
public:
RgbaToRgbOperation();
~RgbaToRgbOperation() = default;
std::shared_ptr<TensorOp> Build() override;
bool ValidateParams() override;
};
class UniformAugOperation : public TensorOperation {
public:
explicit UniformAugOperation(std::vector<std::shared_ptr<TensorOperation>> transforms, int32_t num_ops = 2);
......
......@@ -36,6 +36,8 @@ add_library(kernels-image OBJECT
rescale_op.cc
resize_bilinear_op.cc
resize_op.cc
rgba_to_bgr_op.cc
rgba_to_rgb_op.cc
sharpness_op.cc
solarize_op.cc
swap_red_blue_op.cc
......
......@@ -856,6 +856,44 @@ Status Pad(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output
RETURN_STATUS_UNEXPECTED("Unexpected error in pad");
}
}
Status RgbaToRgb(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output) {
try {
std::shared_ptr<CVTensor> input_cv = CVTensor::AsCVTensor(std::move(input));
int num_channels = input_cv->shape()[2];
if (input_cv->shape().Size() != 3 || num_channels != 4) {
std::string err_msg = "Number of channels does not equal 4, got : " + std::to_string(num_channels);
RETURN_STATUS_UNEXPECTED(err_msg);
}
TensorShape out_shape = TensorShape({input_cv->shape()[0], input_cv->shape()[1], 3});
std::shared_ptr<CVTensor> output_cv;
RETURN_IF_NOT_OK(CVTensor::CreateEmpty(out_shape, input_cv->type(), &output_cv));
cv::cvtColor(input_cv->mat(), output_cv->mat(), static_cast<int>(cv::COLOR_RGBA2RGB));
*output = std::static_pointer_cast<Tensor>(output_cv);
return Status::OK();
} catch (const cv::Exception &e) {
RETURN_STATUS_UNEXPECTED("Unexpected error in RgbaToRgb.");
}
}
Status RgbaToBgr(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output) {
try {
std::shared_ptr<CVTensor> input_cv = CVTensor::AsCVTensor(std::move(input));
int num_channels = input_cv->shape()[2];
if (input_cv->shape().Size() != 3 || num_channels != 4) {
std::string err_msg = "Number of channels does not equal 4, got : " + std::to_string(num_channels);
RETURN_STATUS_UNEXPECTED(err_msg);
}
TensorShape out_shape = TensorShape({input_cv->shape()[0], input_cv->shape()[1], 3});
std::shared_ptr<CVTensor> output_cv;
RETURN_IF_NOT_OK(CVTensor::CreateEmpty(out_shape, input_cv->type(), &output_cv));
cv::cvtColor(input_cv->mat(), output_cv->mat(), static_cast<int>(cv::COLOR_RGBA2BGR));
*output = std::static_pointer_cast<Tensor>(output_cv);
return Status::OK();
} catch (const cv::Exception &e) {
RETURN_STATUS_UNEXPECTED("Unexpected error in RgbaToBgr.");
}
}
// -------- BBOX OPERATIONS -------- //
Status UpdateBBoxesForCrop(std::shared_ptr<Tensor> *bboxList, size_t *bboxCount, int CB_Xmin, int CB_Ymin, int CB_Xmax,
int CB_Ymax) {
......
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "minddata/dataset/kernels/image/rgba_to_bgr_op.h"
#include "minddata/dataset/kernels/image/image_utils.h"
#include "minddata/dataset/util/status.h"
namespace mindspore {
namespace dataset {
Status RgbaToBgrOp::Compute(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output) {
IO_CHECK(input, output);
return RgbaToBgr(input, output);
}
} // namespace dataset
} // namespace mindspore
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_KERNELS_RGBA_TO_BGR_OP_H_
#define MINDSPORE_CCSRC_MINDDATA_DATASET_KERNELS_RGBA_TO_BGR_OP_H_
#include <memory>
#include <vector>
#include <string>
#include "minddata/dataset/core/tensor.h"
#include "minddata/dataset/kernels/image/image_utils.h"
#include "minddata/dataset/kernels/tensor_op.h"
#include "minddata/dataset/util/status.h"
namespace mindspore {
namespace dataset {
class RgbaToBgrOp : public TensorOp {
public:
RgbaToBgrOp() {}
~RgbaToBgrOp() override = default;
Status Compute(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output) override;
std::string Name() const override { return kRgbaToBgrOp; }
};
} // namespace dataset
} // namespace mindspore
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_KERNELS_RGBA_TO_BGR_OP_H_
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "minddata/dataset/kernels/image/rgba_to_rgb_op.h"
#include "minddata/dataset/kernels/image/image_utils.h"
#include "minddata/dataset/util/status.h"
namespace mindspore {
namespace dataset {
Status RgbaToRgbOp::Compute(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output) {
IO_CHECK(input, output);
return RgbaToRgb(input, output);
}
} // namespace dataset
} // namespace mindspore
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_KERNELS_RGBA_TO_RGB_OP_H_
#define MINDSPORE_CCSRC_MINDDATA_DATASET_KERNELS_RGBA_TO_RGB_OP_H_
#include <memory>
#include <vector>
#include <string>
#include "minddata/dataset/core/tensor.h"
#include "minddata/dataset/kernels/image/image_utils.h"
#include "minddata/dataset/kernels/tensor_op.h"
#include "minddata/dataset/util/status.h"
namespace mindspore {
namespace dataset {
class RgbaToRgbOp : public TensorOp {
public:
RgbaToRgbOp() {}
~RgbaToRgbOp() override = default;
Status Compute(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output) override;
std::string Name() const override { return kRgbaToRgbOp; }
};
} // namespace dataset
} // namespace mindspore
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_KERNELS_RGBA_TO_RGB_OP_H_
......@@ -29,14 +29,7 @@ namespace mindspore {
namespace dataset {
class SwapRedBlueOp : public TensorOp {
public:
// SwapRedBlues the image to the output specified size. If only one value is provided,
// the it will crop the smaller size and maintains the aspect ratio.
// @param size1: the first size of output. If only this parameter is provided
// the smaller dimension will be cropd to this and then the other dimension changes
// such that the aspect ratio is maintained.
// @param size2: the second size of output. If this is also provided, the output size
// will be (size1, size2)
// @param InterpolationMode: the interpolation mode being used.
/// \brief Constructor
SwapRedBlueOp() {}
SwapRedBlueOp(const SwapRedBlueOp &rhs) = default;
......@@ -45,7 +38,7 @@ class SwapRedBlueOp : public TensorOp {
~SwapRedBlueOp() override = default;
void Print(std::ostream &out) const override { out << "SwapRedBlueOp x"; }
void Print(std::ostream &out) const override { out << "SwapRedBlueOp"; }
Status Compute(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output) override;
......
......@@ -121,8 +121,10 @@ constexpr char kRescaleOp[] = "RescaleOp";
constexpr char kResizeBilinearOp[] = "ResizeBilinearOp";
constexpr char kResizeOp[] = "ResizeOp";
constexpr char kResizeWithBBoxOp[] = "ResizeWithBBoxOp";
constexpr char kSolarizeOp[] = "SolarizeOp";
constexpr char kRgbaToBgrOp[] = "RgbaToBgrOp";
constexpr char kRgbaToRgbOp[] = "RgbaToRgbOp";
constexpr char kSharpnessOp[] = "SharpnessOp";
constexpr char kSolarizeOp[] = "SolarizeOp";
constexpr char kSwapRedBlueOp[] = "SwapRedBlueOp";
constexpr char kUniformAugOp[] = "UniformAugOp";
constexpr char kSoftDvppDecodeRandomCropResizeJpegOp[] = "SoftDvppDecodeRandomCropResizeJpegOp";
......
......@@ -55,12 +55,14 @@ SET(DE_UT_SRCS
random_vertical_flip_with_bbox_op_test.cc
rename_op_test.cc
repeat_op_test.cc
skip_op_test.cc
rescale_op_test.cc
resize_bilinear_op_test.cc
resize_op_test.cc
resize_with_bbox_op_test.cc
rgba_to_bgr_op_test.cc
rgba_to_rgb_op_test.cc
schema_test.cc
skip_op_test.cc
shuffle_op_test.cc
stand_alone_samplers_test.cc
status_test.cc
......
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <opencv2/imgcodecs.hpp>
#include "common/common.h"
#include "common/cvop_common.h"
#include "minddata/dataset/kernels/image/rgba_to_bgr_op.h"
#include "minddata/dataset/core/cv_tensor.h"
#include "utils/log_adapter.h"
using namespace mindspore::dataset;
using mindspore::MsLogLevel::INFO;
using mindspore::ExceptionType::NoExceptionType;
using mindspore::LogStream;
class MindDataTestRgbaToBgrOp : public UT::CVOP::CVOpCommon {
protected:
MindDataTestRgbaToBgrOp() : CVOpCommon() {}
std::shared_ptr<Tensor> output_tensor_;
};
TEST_F(MindDataTestRgbaToBgrOp, TestOp1) {
MS_LOG(INFO) << "Doing testRGBA2BGR.";
std::unique_ptr<RgbaToBgrOp> op(new RgbaToBgrOp());
EXPECT_TRUE(op->OneToOne());
// prepare 4 channel image
cv::Mat rgba_image;
// First create the image with alpha channel
cv::cvtColor(raw_cv_image_, rgba_image, cv::COLOR_BGR2RGBA);
std::vector<cv::Mat>channels(4);
cv::split(rgba_image, channels);
channels[3] = cv::Mat::zeros(rgba_image.rows, rgba_image.cols, CV_8UC1);
cv::merge(channels, rgba_image);
// create new tensor to test conversion
std::shared_ptr<Tensor> rgba_input;
std::shared_ptr<CVTensor> input_cv_tensor;
CVTensor::CreateFromMat(rgba_image, &input_cv_tensor);
rgba_input = std::dynamic_pointer_cast<Tensor>(input_cv_tensor);
Status s = op->Compute(rgba_input, &output_tensor_);
size_t actual = 0;
if (s == Status::OK()) {
actual = output_tensor_->shape()[0] * output_tensor_->shape()[1] * output_tensor_->shape()[2];
}
EXPECT_EQ(actual, input_tensor_->shape()[0] * input_tensor_->shape()[1] * 3);
EXPECT_EQ(s, Status::OK());
}
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <opencv2/imgcodecs.hpp>
#include "common/common.h"
#include "common/cvop_common.h"
#include "minddata/dataset/kernels/image/rgba_to_rgb_op.h"
#include "minddata/dataset/core/cv_tensor.h"
#include "utils/log_adapter.h"
using namespace mindspore::dataset;
using mindspore::MsLogLevel::INFO;
using mindspore::ExceptionType::NoExceptionType;
using mindspore::LogStream;
class MindDataTestRgbaToRgbOp : public UT::CVOP::CVOpCommon {
protected:
MindDataTestRgbaToRgbOp() : CVOpCommon() {}
std::shared_ptr<Tensor> output_tensor_;
};
TEST_F(MindDataTestRgbaToRgbOp, TestOp1) {
MS_LOG(INFO) << "Doing testRGBA2RGB.";
std::unique_ptr<RgbaToRgbOp> op(new RgbaToRgbOp());
EXPECT_TRUE(op->OneToOne());
// prepare 4 channel image
cv::Mat rgba_image;
// First create the image with alpha channel
cv::cvtColor(raw_cv_image_, rgba_image, cv::COLOR_BGR2RGBA);
std::vector<cv::Mat>channels(4);
cv::split(rgba_image, channels);
channels[3] = cv::Mat::zeros(rgba_image.rows, rgba_image.cols, CV_8UC1);
cv::merge(channels, rgba_image);
// create new tensor to test conversion
std::shared_ptr<Tensor> rgba_input;
std::shared_ptr<CVTensor> input_cv_tensor;
CVTensor::CreateFromMat(rgba_image, &input_cv_tensor);
rgba_input = std::dynamic_pointer_cast<Tensor>(input_cv_tensor);
Status s = op->Compute(rgba_input, &output_tensor_);
size_t actual = 0;
if (s == Status::OK()) {
actual = output_tensor_->shape()[0] * output_tensor_->shape()[1] * output_tensor_->shape()[2];
}
EXPECT_EQ(actual, input_tensor_->shape()[0] * input_tensor_->shape()[1] * 3);
EXPECT_EQ(s, Status::OK());
}
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册