diff --git a/mindspore/ccsrc/dataset/kernels/image/uniform_aug_op.cc b/mindspore/ccsrc/dataset/kernels/image/uniform_aug_op.cc index 5725c10908031e3f550caaf49c822a890d806107..1214345c37014572b00e8a6848b2688d8ed1853c 100644 --- a/mindspore/ccsrc/dataset/kernels/image/uniform_aug_op.cc +++ b/mindspore/ccsrc/dataset/kernels/image/uniform_aug_op.cc @@ -42,34 +42,28 @@ Status UniformAugOp::Compute(const std::vector> &input, std::vector> *output) { IO_CHECK_VECTOR(input, output); - // variables to generate random number to select ops from the list - std::vector random_indexes; - // variables to copy the result to output if it is not already std::vector> even_out; std::vector> *even_out_ptr = &even_out; int count = 1; - // select random indexes for candidates to be applied - for (int i = 0; i < num_ops_; ++i) { - random_indexes.insert(random_indexes.end(), - std::uniform_int_distribution(0, tensor_op_list_.size() - 1)(rnd_)); - } + // randomly select ops to be applied + std::vector> selected_tensor_ops; + std::sample(tensor_op_list_.begin(), tensor_op_list_.end(), std::back_inserter(selected_tensor_ops), num_ops_, rnd_); - for (auto it = random_indexes.begin(); it != random_indexes.end(); ++it) { + for (auto tensor_op = selected_tensor_ops.begin(); tensor_op != selected_tensor_ops.end(); ++tensor_op) { // Do NOT apply the op, if second random generator returned zero if (std::uniform_int_distribution(0, 1)(rnd_)) { continue; } - std::shared_ptr tensor_op = tensor_op_list_[*it]; // apply python/C++ op if (count == 1) { - (*tensor_op).Compute(input, output); + (**tensor_op).Compute(input, output); } else if (count % 2 == 0) { - (*tensor_op).Compute(*output, even_out_ptr); + (**tensor_op).Compute(*output, even_out_ptr); } else { - (*tensor_op).Compute(even_out, output); + (**tensor_op).Compute(even_out, output); } count++; } diff --git a/mindspore/dataset/transforms/vision/validators.py b/mindspore/dataset/transforms/vision/validators.py index 713d9c5714111e1dbe7dad27ed409bfc398e03d0..2c299b077bdd25ca7d36585a71e3c362beb242da 100644 --- a/mindspore/dataset/transforms/vision/validators.py +++ b/mindspore/dataset/transforms/vision/validators.py @@ -17,11 +17,12 @@ import numbers from functools import wraps +from mindspore._c_dataengine import TensorOp + from .utils import Inter, Border from ...transforms.validators import check_pos_int32, check_pos_float32, check_value, check_uint8, FLOAT_MAX_INTEGER, \ check_bool, check_2tuple, check_range, check_list, check_type, check_positive, INT32_MAX - def check_inter_mode(mode): if not isinstance(mode, Inter): raise ValueError("Invalid interpolation mode.") @@ -836,7 +837,7 @@ def check_uniform_augmentation(method): if not isinstance(operations, list): raise ValueError("operations is not a python list") for op in operations: - if not callable(op): + if not callable(op) and not isinstance(op, TensorOp): raise ValueError("non-callable op in operations list") kwargs["num_ops"] = num_ops