提交 c65a5777 编写于 作者: M mindspore-ci-bot 提交者: Gitee

!631 Fixed the bug on uniform augment input validation

Merge pull request !631 from AdelShafiei/uniform_augment_ut
...@@ -42,34 +42,28 @@ Status UniformAugOp::Compute(const std::vector<std::shared_ptr<Tensor>> &input, ...@@ -42,34 +42,28 @@ Status UniformAugOp::Compute(const std::vector<std::shared_ptr<Tensor>> &input,
std::vector<std::shared_ptr<Tensor>> *output) { std::vector<std::shared_ptr<Tensor>> *output) {
IO_CHECK_VECTOR(input, output); IO_CHECK_VECTOR(input, output);
// variables to generate random number to select ops from the list
std::vector<int> random_indexes;
// variables to copy the result to output if it is not already // variables to copy the result to output if it is not already
std::vector<std::shared_ptr<Tensor>> even_out; std::vector<std::shared_ptr<Tensor>> even_out;
std::vector<std::shared_ptr<Tensor>> *even_out_ptr = &even_out; std::vector<std::shared_ptr<Tensor>> *even_out_ptr = &even_out;
int count = 1; int count = 1;
// select random indexes for candidates to be applied // randomly select ops to be applied
for (int i = 0; i < num_ops_; ++i) { std::vector<std::shared_ptr<TensorOp>> selected_tensor_ops;
random_indexes.insert(random_indexes.end(), std::sample(tensor_op_list_.begin(), tensor_op_list_.end(), std::back_inserter(selected_tensor_ops), num_ops_, rnd_);
std::uniform_int_distribution<int>(0, tensor_op_list_.size() - 1)(rnd_));
}
for (auto it = random_indexes.begin(); it != random_indexes.end(); ++it) { for (auto tensor_op = selected_tensor_ops.begin(); tensor_op != selected_tensor_ops.end(); ++tensor_op) {
// Do NOT apply the op, if second random generator returned zero // Do NOT apply the op, if second random generator returned zero
if (std::uniform_int_distribution<int>(0, 1)(rnd_)) { if (std::uniform_int_distribution<int>(0, 1)(rnd_)) {
continue; continue;
} }
std::shared_ptr<TensorOp> tensor_op = tensor_op_list_[*it];
// apply python/C++ op // apply python/C++ op
if (count == 1) { if (count == 1) {
(*tensor_op).Compute(input, output); (**tensor_op).Compute(input, output);
} else if (count % 2 == 0) { } else if (count % 2 == 0) {
(*tensor_op).Compute(*output, even_out_ptr); (**tensor_op).Compute(*output, even_out_ptr);
} else { } else {
(*tensor_op).Compute(even_out, output); (**tensor_op).Compute(even_out, output);
} }
count++; count++;
} }
......
...@@ -17,11 +17,12 @@ ...@@ -17,11 +17,12 @@
import numbers import numbers
from functools import wraps from functools import wraps
from mindspore._c_dataengine import TensorOp
from .utils import Inter, Border from .utils import Inter, Border
from ...transforms.validators import check_pos_int32, check_pos_float32, check_value, check_uint8, FLOAT_MAX_INTEGER, \ from ...transforms.validators import check_pos_int32, check_pos_float32, check_value, check_uint8, FLOAT_MAX_INTEGER, \
check_bool, check_2tuple, check_range, check_list, check_type, check_positive, INT32_MAX check_bool, check_2tuple, check_range, check_list, check_type, check_positive, INT32_MAX
def check_inter_mode(mode): def check_inter_mode(mode):
if not isinstance(mode, Inter): if not isinstance(mode, Inter):
raise ValueError("Invalid interpolation mode.") raise ValueError("Invalid interpolation mode.")
...@@ -836,7 +837,7 @@ def check_uniform_augmentation(method): ...@@ -836,7 +837,7 @@ def check_uniform_augmentation(method):
if not isinstance(operations, list): if not isinstance(operations, list):
raise ValueError("operations is not a python list") raise ValueError("operations is not a python list")
for op in operations: for op in operations:
if not callable(op): if not callable(op) and not isinstance(op, TensorOp):
raise ValueError("non-callable op in operations list") raise ValueError("non-callable op in operations list")
kwargs["num_ops"] = num_ops kwargs["num_ops"] = num_ops
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册