提交 5fcd3f01 编写于 作者: A Adel Shafiei

Added C++ UniformAugOp support

上级 822a3160
......@@ -40,6 +40,7 @@
#include "dataset/kernels/image/rescale_op.h"
#include "dataset/kernels/image/resize_bilinear_op.h"
#include "dataset/kernels/image/resize_op.h"
#include "dataset/kernels/image/uniform_aug_op.h"
#include "dataset/kernels/data/type_cast_op.h"
#include "dataset/engine/datasetops/source/cifar_op.h"
#include "dataset/engine/datasetops/source/image_folder_op.h"
......@@ -264,6 +265,10 @@ void bindTensorOps1(py::module *m) {
.def(py::init<int32_t, int32_t, InterpolationMode>(), py::arg("targetHeight"),
py::arg("targetWidth") = ResizeOp::kDefWidth, py::arg("interpolation") = ResizeOp::kDefInterpolation);
(void)py::class_<UniformAugOp, TensorOp, std::shared_ptr<UniformAugOp>>(
*m, "UniformAugOp", "Tensor operation to apply random augmentation(s).")
.def(py::init<py::list, int32_t>(), py::arg("operations"), py::arg("NumOps") = UniformAugOp::kDefNumOps);
(void)py::class_<ResizeBilinearOp, TensorOp, std::shared_ptr<ResizeBilinearOp>>(
*m, "ResizeBilinearOp",
"Tensor operation to resize an image using "
......
......@@ -19,6 +19,7 @@ if (WIN32)
rescale_op.cc
resize_bilinear_op.cc
resize_op.cc
uniform_aug_op.cc
)
else()
add_library(kernels-image OBJECT
......@@ -42,5 +43,6 @@ else()
rescale_op.cc
resize_bilinear_op.cc
resize_op.cc
uniform_aug_op.cc
)
endif()
/**
* Copyright 2020 Huawei Technologies Co., Ltd
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
* http://www.apache.org/licenses/LICENSE-2.0
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "dataset/kernels/image/uniform_aug_op.h"
#include "dataset/kernels/py_func_op.h"
#include "dataset/util/random.h"
namespace mindspore {
namespace dataset {
const int UniformAugOp::kDefNumOps = 2;
UniformAugOp::UniformAugOp(py::list op_list, int32_t num_ops) : num_ops_(num_ops) {
std::shared_ptr<TensorOp> tensor_op;
// iterate over the op list, cast them to TensorOp and add them to tensor_op_list_
for (auto op : op_list) {
if (py::isinstance<py::function>(op)) {
// python op
tensor_op = std::make_shared<PyFuncOp>(op.cast<py::function>());
} else if (py::isinstance<TensorOp>(op)) {
// C++ op
tensor_op = op.cast<std::shared_ptr<TensorOp>>();
}
tensor_op_list_.insert(tensor_op_list_.begin(), tensor_op);
}
rnd_.seed(GetSeed());
}
// compute method to apply uniformly random selected augmentations from a list
Status UniformAugOp::Compute(const std::vector<std::shared_ptr<Tensor>> &input,
std::vector<std::shared_ptr<Tensor>> *output) {
IO_CHECK_VECTOR(input, output);
// variables to generate random number to select ops from the list
std::vector<int> random_indexes;
// variables to copy the result to output if it is not already
std::vector<std::shared_ptr<Tensor>> even_out;
std::vector<std::shared_ptr<Tensor>> *even_out_ptr = &even_out;
int count = 1;
// select random indexes for candidates to be applied
for (int i = 0; i < num_ops_; ++i) {
random_indexes.insert(random_indexes.end(),
std::uniform_int_distribution<int>(0, tensor_op_list_.size() - 1)(rnd_));
}
for (auto it = random_indexes.begin(); it != random_indexes.end(); ++it) {
// Do NOT apply the op, if second random generator returned zero
if (std::uniform_int_distribution<int>(0, 1)(rnd_)) {
continue;
}
std::shared_ptr<TensorOp> tensor_op = tensor_op_list_[*it];
// apply python/C++ op
if (count == 1) {
(*tensor_op).Compute(input, output);
} else if (count % 2 == 0) {
(*tensor_op).Compute(*output, even_out_ptr);
} else {
(*tensor_op).Compute(even_out, output);
}
count++;
}
// copy the result to output if it is not in output
if (count == 1) {
*output = input;
} else if ((count % 2 == 1)) {
(*output).swap(even_out);
}
return Status::OK();
}
} // namespace dataset
} // namespace mindspore
/**
* Copyright 2020 Huawei Technologies Co., Ltd
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
* http://www.apache.org/licenses/LICENSE-2.0
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef DATASET_KERNELS_IMAGE_UNIFORM_AUG_OP_H_
#define DATASET_KERNELS_IMAGE_UNIFORM_AUG_OP_H_
#include <memory>
#include <random>
#include <string>
#include <vector>
#include "dataset/core/tensor.h"
#include "dataset/kernels/tensor_op.h"
#include "dataset/util/status.h"
#include "dataset/kernels/py_func_op.h"
#include "pybind11/stl.h"
namespace mindspore {
namespace dataset {
class UniformAugOp : public TensorOp {
public:
// Default number of Operations to be applied
static const int kDefNumOps;
// Constructor for UniformAugOp
// @param list op_list: list of candidate python operations
// @param list num_ops: number of augemtation operations to applied
UniformAugOp(py::list op_list, int32_t num_ops);
~UniformAugOp() override = default;
void Print(std::ostream &out) const override { out << "UniformAugOp:: number of ops " << num_ops_; }
// Overrides the base class compute function
// @return Status - The error code return
Status Compute(const std::vector<std::shared_ptr<Tensor>> &input,
std::vector<std::shared_ptr<Tensor>> *output) override;
private:
int32_t num_ops_;
std::vector<std::shared_ptr<TensorOp>> tensor_op_list_;
std::mt19937 rnd_;
};
} // namespace dataset
} // namespace mindspore
#endif // DATASET_KERNELS_IMAGE_UNIFORM_AUG_OP_H_
......@@ -45,7 +45,7 @@ import mindspore._c_dataengine as cde
from .utils import Inter, Border
from .validators import check_prob, check_crop, check_resize_interpolation, check_random_resize_crop, \
check_normalize_c, check_random_crop, check_random_color_adjust, check_random_rotation, \
check_resize, check_rescale, check_pad, check_cutout
check_resize, check_rescale, check_pad, check_cutout, check_uniform_augmentation
DE_C_INTER_MODE = {Inter.NEAREST: cde.InterpolationMode.DE_INTER_NEAREST_NEIGHBOUR,
Inter.LINEAR: cde.InterpolationMode.DE_INTER_LINEAR,
......@@ -447,3 +447,19 @@ class Pad(cde.PadOp):
fill_value = tuple([fill_value] * 3)
padding_mode = DE_C_BORDER_TYPE[padding_mode]
super().__init__(*padding, padding_mode, *fill_value)
class UniformAugment(cde.UniformAugOp):
"""
Tensor operation to perform randomly selected augmentation
Args:
operations: list of python operations.
NumOps (int): number of OPs to be selected and applied.
"""
@check_uniform_augmentation
def __init__(self, operations, num_ops=2):
self.operations = operations
self.num_ops = num_ops
super().__init__(operations, num_ops)
......@@ -812,3 +812,36 @@ def check_rescale(method):
return method(self, **kwargs)
return new_method
def check_uniform_augmentation(method):
"""Wrapper method to check the parameters of UniformAugmentation."""
@wraps(method)
def new_method(self, *args, **kwargs):
operations, num_ops = (list(args) + 2 * [None])[:2]
if "operations" in kwargs:
operations = kwargs.get("operations")
else:
raise ValueError("operations list required")
if "num_ops" in kwargs:
num_ops = kwargs.get("num_ops")
else:
num_ops = 2
if num_ops <= 0:
raise ValueError("num_ops should be greater than zero")
if num_ops > len(operations):
raise ValueError("num_ops is greater than operations list size")
if not isinstance(operations, list):
raise ValueError("operations is not a python list")
for op in operations:
if not callable(op):
raise ValueError("non-callable op in operations list")
kwargs["num_ops"] = num_ops
kwargs["operations"] = operations
return method(self, **kwargs)
return new_method
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册