提交 35c3a637 编写于 作者: A avakh

support cpp invert operation

上级 60927ef1
......@@ -54,6 +54,7 @@
#include "minddata/dataset/kernels/image/decode_op.h"
#include "minddata/dataset/kernels/image/hwc_to_chw_op.h"
#include "minddata/dataset/kernels/image/image_utils.h"
#include "minddata/dataset/kernels/image/invert_op.h"
#include "minddata/dataset/kernels/image/normalize_op.h"
#include "minddata/dataset/kernels/image/pad_op.h"
#include "minddata/dataset/kernels/image/random_color_adjust_op.h"
......@@ -362,6 +363,10 @@ void bindTensorOps1(py::module *m) {
.def(py::init<float, float, float, float, float, float>(), py::arg("meanR"), py::arg("meanG"), py::arg("meanB"),
py::arg("stdR"), py::arg("stdG"), py::arg("stdB"));
(void)py::class_<InvertOp, TensorOp, std::shared_ptr<InvertOp>>(*m, "InvertOp",
"Tensor operation to apply invert on RGB images.")
.def(py::init<>());
(void)py::class_<RescaleOp, TensorOp, std::shared_ptr<RescaleOp>>(
*m, "RescaleOp", "Tensor operation to rescale an image. Takes scale and shift.")
.def(py::init<float, float>(), py::arg("rescale"), py::arg("shift"));
......
......@@ -6,6 +6,7 @@ add_library(kernels-image OBJECT
decode_op.cc
hwc_to_chw_op.cc
image_utils.cc
invert_op.cc
normalize_op.cc
pad_op.cc
random_color_adjust_op.cc
......
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "minddata/dataset/kernels/image/invert_op.h"
#include "minddata/dataset/kernels/image/image_utils.h"
#include "minddata/dataset/core/cv_tensor.h"
#include "minddata/dataset/util/status.h"
namespace mindspore {
namespace dataset {
// only supports RGB images
Status InvertOp::Compute(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output) {
IO_CHECK(input, output);
try {
std::shared_ptr<CVTensor> input_cv = CVTensor::AsCVTensor(input);
cv::Mat input_img = input_cv->mat();
if (!input_cv->mat().data) {
RETURN_STATUS_UNEXPECTED("Could not convert to CV Tensor");
}
if (input_cv->Rank() != 3) {
RETURN_STATUS_UNEXPECTED("Shape not <H,W,C>");
}
int num_channels = input_cv->shape()[2];
if (num_channels != 3) {
RETURN_STATUS_UNEXPECTED("The shape is incorrect: num of channels != 3");
}
auto output_cv = std::make_shared<CVTensor>(input_cv->shape(), input_cv->type());
RETURN_UNEXPECTED_IF_NULL(output_cv);
output_cv->mat() = cv::Scalar::all(255) - input_img;
*output = std::static_pointer_cast<Tensor>(output_cv);
}
catch (const cv::Exception &e) {
RETURN_STATUS_UNEXPECTED("Error in invert");
}
return Status::OK();
}
} // namespace dataset
} // namespace mindspore
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef DATASET_KERNELS_IMAGE_INVERT_OP_H
#define DATASET_KERNELS_IMAGE_INVERT_OP_H
#include <memory>
#include <string>
#include "minddata/dataset/core/tensor.h"
#include "minddata/dataset/kernels/tensor_op.h"
#include "minddata/dataset/util/status.h"
namespace mindspore {
namespace dataset {
class InvertOp : public TensorOp {
public:
InvertOp() {}
~InvertOp() = default;
// Description: A function that prints info about the node
void Print(std::ostream &out) const override { out << Name(); }
Status Compute(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output) override;
std::string Name() const override { return kInvertOp; }
};
} // namespace dataset
} // namespace mindspore
#endif // DATASET_KERNELS_IMAGE_INVERT_OP_H
......@@ -92,6 +92,7 @@ constexpr char kDecodeOp[] = "DecodeOp";
constexpr char kCenterCropOp[] = "CenterCropOp";
constexpr char kCutOutOp[] = "CutOutOp";
constexpr char kHwcToChwOp[] = "HwcToChwOp";
constexpr char kInvertOp[] = "InvertOp";
constexpr char kNormalizeOp[] = "NormalizeOp";
constexpr char kPadOp[] = "PadOp";
constexpr char kRandomColorAdjustOp[] = "RandomColorAdjustOp";
......
......@@ -71,6 +71,13 @@ def parse_padding(padding):
return padding
class Invert(cde.InvertOp):
"""
Apply invert on input image in RGB mode.
does not have input arguments.
"""
class Decode(cde.DecodeOp):
"""
Decode the input image in RGB mode.
......
......@@ -19,18 +19,20 @@ import numpy as np
import mindspore.dataset.engine as de
import mindspore.dataset.transforms.vision.py_transforms as F
import mindspore.dataset.transforms.vision.c_transforms as C
from mindspore import log as logger
from util import visualize_list, save_and_check_md5
from util import visualize_list, save_and_check_md5, diff_mse
DATA_DIR = "../data/dataset/testImageNetData/train/"
GENERATE_GOLDEN = False
def test_invert(plot=False):
def test_invert_py(plot=False):
"""
Test Invert
Test Invert python op
"""
logger.info("Test Invert")
logger.info("Test Invert Python op")
# Original Images
ds = de.ImageFolderDatasetV2(dataset_dir=DATA_DIR, shuffle=False)
......@@ -52,7 +54,7 @@ def test_invert(plot=False):
np.transpose(image, (0, 2, 3, 1)),
axis=0)
# Color Inverted Images
# Color Inverted Images
ds = de.ImageFolderDatasetV2(dataset_dir=DATA_DIR, shuffle=False)
transforms_invert = F.ComposeOp([F.Decode(),
......@@ -83,11 +85,143 @@ def test_invert(plot=False):
visualize_list(images_original, images_invert)
def test_invert_md5():
def test_invert_c(plot=False):
"""
Test Invert Cpp op
"""
logger.info("Test Invert cpp op")
# Original Images
ds = de.ImageFolderDatasetV2(dataset_dir=DATA_DIR, shuffle=False)
transforms_original = [C.Decode(), C.Resize(size=[224, 224])]
ds_original = ds.map(input_columns="image",
operations=transforms_original)
ds_original = ds_original.batch(512)
for idx, (image, _) in enumerate(ds_original):
if idx == 0:
images_original = image
else:
images_original = np.append(images_original,
image,
axis=0)
# Invert Images
ds = de.ImageFolderDatasetV2(dataset_dir=DATA_DIR, shuffle=False)
transform_invert = [C.Decode(), C.Resize(size=[224, 224]),
C.Invert()]
ds_invert = ds.map(input_columns="image",
operations=transform_invert)
ds_invert = ds_invert.batch(512)
for idx, (image, _) in enumerate(ds_invert):
if idx == 0:
images_invert = image
else:
images_invert = np.append(images_invert,
image,
axis=0)
if plot:
visualize_list(images_original, images_invert)
num_samples = images_original.shape[0]
mse = np.zeros(num_samples)
for i in range(num_samples):
mse[i] = diff_mse(images_invert[i], images_original[i])
logger.info("MSE= {}".format(str(np.mean(mse))))
def test_invert_py_c(plot=False):
"""
Test Invert Cpp op and python op
"""
logger.info("Test Invert cpp and python op")
# Invert Images in cpp
ds = de.ImageFolderDatasetV2(dataset_dir=DATA_DIR, shuffle=False)
ds = ds.map(input_columns=["image"],
operations=[C.Decode(), C.Resize((224, 224))])
ds_c_invert = ds.map(input_columns="image",
operations=C.Invert())
ds_c_invert = ds_c_invert.batch(512)
for idx, (image, _) in enumerate(ds_c_invert):
if idx == 0:
images_c_invert = image
else:
images_c_invert = np.append(images_c_invert,
image,
axis=0)
# invert images in python
ds = de.ImageFolderDatasetV2(dataset_dir=DATA_DIR, shuffle=False)
ds = ds.map(input_columns=["image"],
operations=[C.Decode(), C.Resize((224, 224))])
transforms_p_invert = F.ComposeOp([lambda img: img.astype(np.uint8),
F.ToPIL(),
F.Invert(),
np.array])
ds_p_invert = ds.map(input_columns="image",
operations=transforms_p_invert())
ds_p_invert = ds_p_invert.batch(512)
for idx, (image, _) in enumerate(ds_p_invert):
if idx == 0:
images_p_invert = image
else:
images_p_invert = np.append(images_p_invert,
image,
axis=0)
num_samples = images_c_invert.shape[0]
mse = np.zeros(num_samples)
for i in range(num_samples):
mse[i] = diff_mse(images_p_invert[i], images_c_invert[i])
logger.info("MSE= {}".format(str(np.mean(mse))))
if plot:
visualize_list(images_c_invert, images_p_invert, visualize_mode=2)
def test_invert_one_channel():
"""
Test Invert with md5 check
Test Invert cpp op with one channel image
"""
logger.info("Test Invert C Op With One Channel Images")
c_op = C.Invert()
try:
ds = de.ImageFolderDatasetV2(dataset_dir=DATA_DIR, shuffle=False)
ds = ds.map(input_columns=["image"],
operations=[C.Decode(),
C.Resize((224, 224)),
lambda img: np.array(img[:, :, 0])])
ds.map(input_columns="image",
operations=c_op)
except RuntimeError as e:
logger.info("Got an exception in DE: {}".format(str(e)))
assert "The shape" in str(e)
def test_invert_md5_py():
"""
logger.info("Test Invert with md5 check")
Test Invert python op with md5 check
"""
logger.info("Test Invert python op with md5 check")
# Generate dataset
ds = de.ImageFolderDatasetV2(dataset_dir=DATA_DIR, shuffle=False)
......@@ -98,10 +232,34 @@ def test_invert_md5():
data = ds.map(input_columns="image", operations=transforms_invert())
# Compare with expected md5 from images
filename = "invert_01_result.npz"
filename = "invert_01_result_py.npz"
save_and_check_md5(data, filename, generate_golden=GENERATE_GOLDEN)
def test_invert_md5_c():
"""
Test Invert cpp op with md5 check
"""
logger.info("Test Invert cpp op with md5 check")
# Generate dataset
ds = de.ImageFolderDatasetV2(dataset_dir=DATA_DIR, shuffle=False)
transforms_invert = [C.Decode(),
C.Resize(size=[224, 224]),
C.Invert(),
F.ToTensor()]
data = ds.map(input_columns="image", operations=transforms_invert)
# Compare with expected md5 from images
filename = "invert_01_result_c.npz"
save_and_check_md5(data, filename, generate_golden=GENERATE_GOLDEN)
if __name__ == "__main__":
test_invert(plot=True)
test_invert_md5()
test_invert_py(plot=False)
test_invert_c(plot=False)
test_invert_py_c(plot=False)
test_invert_one_channel()
test_invert_md5_py()
test_invert_md5_c()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册