diff --git a/mace/ops/one_hot.cc b/mace/ops/one_hot.cc index 38f42893a7c482407a2d87f514b17ae35b98b5a4..2077c6861ebffbecbf0f84572221ce0370db1a0c 100644 --- a/mace/ops/one_hot.cc +++ b/mace/ops/one_hot.cc @@ -57,7 +57,7 @@ class OneHotOp : public OneHotOpBase { const std::vector &input_shape = input->shape(); std::vector output_shape(input_shape.size() + 1); - MACE_CHECK(input->dim_size() < 100); // prevents too deep recursion + MACE_CHECK(input->dim_size() < 100); // prevents too deep recursion MACE_CHECK(axis >= 0 && axis <= input->dim_size()); for (size_t in = 0, out = 0; out < output_shape.size(); ++out) { @@ -98,34 +98,33 @@ class OneHotOp : public OneHotOpBase { } } } else { - run(input, input_ptr, output_ptr, axis, 0, 0, input_shape.size(), 0); + run(input, &input_ptr, &output_ptr, axis, 0, 0, input_shape.size(), 0); } return MaceStatus::MACE_SUCCESS; } private: - void run(const Tensor *input, const T *&input_ptr, - T *&output_ptr, const index_t axis, + void run(const Tensor *input, const T **input_ptr, + T **output_ptr, const index_t axis, const index_t current_in, const index_t current_out, const index_t left, const index_t test) const { - if (current_out == axis) { const index_t length = depth_; if (left == 0) { for (index_t i = 0; i < length; ++i) { - *output_ptr = *input_ptr == i ? on_value_ : off_value_; - ++output_ptr; + **output_ptr = **input_ptr == i ? on_value_ : off_value_; + ++(*output_ptr); } - ++input_ptr; + ++(*input_ptr); } else { - const T *in = input_ptr; + const T *in = *input_ptr; for (index_t i = 0; i < length; ++i) { - input_ptr = in; + *input_ptr = in; run(input, input_ptr, output_ptr, axis, current_in, current_out + 1, left - 1, i); } @@ -135,9 +134,9 @@ class OneHotOp : public OneHotOpBase { if (left == 0) { for (index_t i = 0; i < length; ++i) { - *output_ptr = *input_ptr == test ? on_value_ : off_value_; - ++output_ptr; - ++input_ptr; + **output_ptr = **input_ptr == test ? on_value_ : off_value_; + ++(*output_ptr); + ++(*input_ptr); } } else { for (index_t i = 0; i < length; ++i) {