From b0da123274a3230d0f52d45fa419c611db4354f0 Mon Sep 17 00:00:00 2001 From: Alan Kelly Date: Thu, 20 Oct 2022 02:52:30 -0700 Subject: [PATCH] Fix strided slice bug where output is always writes at least one element. a[1:1:1] should be empty, but one element is written. PiperOrigin-RevId: 482436216 --- tensorflow/lite/kernels/internal/BUILD | 1 - .../internal/optimized/optimized_ops.h | 102 ---------------- .../internal/reference/strided_slice.h | 100 +++++++++------ .../kernels/internal/strided_slice_logic.h | 63 ++++++++++ .../internal/strided_slice_logic_test.cc | 114 ++++++++++++++++++ tensorflow/lite/kernels/strided_slice.cc | 52 ++++---- tensorflow/lite/kernels/strided_slice_test.cc | 99 ++++++++++++++- tensorflow/lite/kernels/test_util.cc | 12 +- tensorflow/lite/kernels/test_util.h | 6 +- 9 files changed, 372 insertions(+), 177 deletions(-) diff --git a/tensorflow/lite/kernels/internal/BUILD b/tensorflow/lite/kernels/internal/BUILD index 095c75c626e..e10eb9e2b9d 100644 --- a/tensorflow/lite/kernels/internal/BUILD +++ b/tensorflow/lite/kernels/internal/BUILD @@ -1073,7 +1073,6 @@ cc_test( srcs = [ "strided_slice_logic_test.cc", ], - shard_count = 4, deps = [ ":strided_slice_logic", "@com_google_googletest//:gtest_main", diff --git a/tensorflow/lite/kernels/internal/optimized/optimized_ops.h b/tensorflow/lite/kernels/internal/optimized/optimized_ops.h index d76fa112045..154340d2c76 100644 --- a/tensorflow/lite/kernels/internal/optimized/optimized_ops.h +++ b/tensorflow/lite/kernels/internal/optimized/optimized_ops.h @@ -4673,108 +4673,6 @@ inline void Slice(const tflite::SliceParams& op_params, return Slice(op_params, input_shape, output_shape, &writer); } -// Note: This implementation is only optimized for the case where the inner -// stride == 1. -template -inline void StridedSlice(const tflite::StridedSliceParams& op_params, - const RuntimeShape& unextended_input_shape, - const RuntimeShape& unextended_output_shape, - SequentialTensorWriter* writer) { - using strided_slice::LoopCondition; - using strided_slice::StartForAxis; - using strided_slice::StopForAxis; - - ruy::profiler::ScopeLabel label("StridedSlice"); - - // Note that the output_shape is not used herein. - tflite::StridedSliceParams params_copy = op_params; - - TFLITE_DCHECK_LE(unextended_input_shape.DimensionsCount(), 5); - TFLITE_DCHECK_LE(unextended_output_shape.DimensionsCount(), 5); - const RuntimeShape input_shape = - RuntimeShape::ExtendedShape(5, unextended_input_shape); - const RuntimeShape output_shape = - RuntimeShape::ExtendedShape(5, unextended_output_shape); - - // Reverse and pad to 5 dimensions because that is what the runtime code - // requires (ie. all shapes must be 5D and are given backwards). - strided_slice::StridedSlicePadIndices(¶ms_copy, 5); - - const int start_0 = StartForAxis(params_copy, input_shape, 0); - const int stop_0 = StopForAxis(params_copy, input_shape, 0, start_0); - const int start_1 = StartForAxis(params_copy, input_shape, 1); - const int stop_1 = StopForAxis(params_copy, input_shape, 1, start_1); - const int start_2 = StartForAxis(params_copy, input_shape, 2); - const int stop_2 = StopForAxis(params_copy, input_shape, 2, start_2); - const int start_3 = StartForAxis(params_copy, input_shape, 3); - const int stop_3 = StopForAxis(params_copy, input_shape, 3, start_3); - const int start_4 = StartForAxis(params_copy, input_shape, 4); - const int stop_4 = StopForAxis(params_copy, input_shape, 4, start_4); - const bool inner_stride_is_1 = params_copy.strides[4] == 1; - - for (int offset_0 = start_0 * input_shape.Dims(1), - end_0 = stop_0 * input_shape.Dims(1), - step_0 = params_copy.strides[0] * input_shape.Dims(1); - !LoopCondition(offset_0, end_0, params_copy.strides[0]); - offset_0 += step_0) { - for (int offset_1 = (offset_0 + start_1) * input_shape.Dims(2), - end_1 = (offset_0 + stop_1) * input_shape.Dims(2), - step_1 = params_copy.strides[1] * input_shape.Dims(2); - !LoopCondition(offset_1, end_1, params_copy.strides[1]); - offset_1 += step_1) { - for (int offset_2 = (offset_1 + start_2) * input_shape.Dims(3), - end_2 = (offset_1 + stop_2) * input_shape.Dims(3), - step_2 = params_copy.strides[2] * input_shape.Dims(3); - !LoopCondition(offset_2, end_2, params_copy.strides[2]); - offset_2 += step_2) { - for (int offset_3 = (offset_2 + start_3) * input_shape.Dims(4), - end_3 = (offset_2 + stop_3) * input_shape.Dims(4), - step_3 = params_copy.strides[3] * input_shape.Dims(4); - !LoopCondition(offset_3, end_3, params_copy.strides[3]); - offset_3 += step_3) { - // When the stride is 1, the inner loop is equivalent to the - // optimized slice inner loop. Otherwise, it is identical to the - // strided_slice reference implementation inner loop. - if (inner_stride_is_1) { - const int len = stop_4 - start_4; - if (len > 0) { - writer->WriteN(offset_3 + start_4, len); - } - } else { - for (int offset_4 = offset_3 + start_4, end_4 = offset_3 + stop_4; - !LoopCondition(offset_4, end_4, params_copy.strides[4]); - offset_4 += params_copy.strides[4]) { - writer->Write(offset_4); - } - } - } - } - } - } -} - -template -inline void StridedSlice(const tflite::StridedSliceParams& op_params, - const RuntimeShape& unextended_input_shape, - const T* input_data, - const RuntimeShape& unextended_output_shape, - T* output_data) { - SequentialTensorWriter writer(input_data, output_data); - StridedSlice(op_params, unextended_input_shape, unextended_output_shape, - &writer); -} - -template -inline void StridedSlice(const tflite::StridedSliceParams& op_params, - const RuntimeShape& unextended_input_shape, - const TfLiteTensor* input, - const RuntimeShape& unextended_output_shape, - TfLiteTensor* output) { - SequentialTensorWriter writer(input, output); - StridedSlice(op_params, unextended_input_shape, unextended_output_shape, - &writer); -} - template void Minimum(const RuntimeShape& input1_shape, const T* input1_data, const T* input2_data, const RuntimeShape& output_shape, diff --git a/tensorflow/lite/kernels/internal/reference/strided_slice.h b/tensorflow/lite/kernels/internal/reference/strided_slice.h index 40dc2e91022..ff367cf95f1 100644 --- a/tensorflow/lite/kernels/internal/reference/strided_slice.h +++ b/tensorflow/lite/kernels/internal/reference/strided_slice.h @@ -31,10 +31,6 @@ inline void StridedSlice(const tflite::StridedSliceParams& op_params, const RuntimeShape& unextended_input_shape, const RuntimeShape& unextended_output_shape, SequentialTensorWriter* writer) { - using strided_slice::LoopCondition; - using strided_slice::StartForAxis; - using strided_slice::StopForAxis; - ruy::profiler::ScopeLabel label("StridedSlice"); // Note that the output_shape is not used herein. @@ -51,41 +47,67 @@ inline void StridedSlice(const tflite::StridedSliceParams& op_params, // requires (ie. all shapes must be 5D and are given backwards). strided_slice::StridedSlicePadIndices(¶ms_copy, 5); - const int start_0 = StartForAxis(params_copy, input_shape, 0); - const int stop_0 = StopForAxis(params_copy, input_shape, 0, start_0); - const int start_1 = StartForAxis(params_copy, input_shape, 1); - const int stop_1 = StopForAxis(params_copy, input_shape, 1, start_1); - const int start_2 = StartForAxis(params_copy, input_shape, 2); - const int stop_2 = StopForAxis(params_copy, input_shape, 2, start_2); - const int start_3 = StartForAxis(params_copy, input_shape, 3); - const int stop_3 = StopForAxis(params_copy, input_shape, 3, start_3); - const int start_4 = StartForAxis(params_copy, input_shape, 4); - const int stop_4 = StopForAxis(params_copy, input_shape, 4, start_4); - - for (int offset_0 = start_0 * input_shape.Dims(1), - end_0 = stop_0 * input_shape.Dims(1), - step_0 = params_copy.strides[0] * input_shape.Dims(1); - !LoopCondition(offset_0, end_0, params_copy.strides[0]); - offset_0 += step_0) { - for (int offset_1 = (offset_0 + start_1) * input_shape.Dims(2), - end_1 = (offset_0 + stop_1) * input_shape.Dims(2), - step_1 = params_copy.strides[1] * input_shape.Dims(2); - !LoopCondition(offset_1, end_1, params_copy.strides[1]); - offset_1 += step_1) { - for (int offset_2 = (offset_1 + start_2) * input_shape.Dims(3), - end_2 = (offset_1 + stop_2) * input_shape.Dims(3), - step_2 = params_copy.strides[2] * input_shape.Dims(3); - !LoopCondition(offset_2, end_2, params_copy.strides[2]); - offset_2 += step_2) { - for (int offset_3 = (offset_2 + start_3) * input_shape.Dims(4), - end_3 = (offset_2 + stop_3) * input_shape.Dims(4), - step_3 = params_copy.strides[3] * input_shape.Dims(4); - !LoopCondition(offset_3, end_3, params_copy.strides[3]); - offset_3 += step_3) { - for (int offset_4 = offset_3 + start_4, end_4 = offset_3 + stop_4; - !LoopCondition(offset_4, end_4, params_copy.strides[4]); - offset_4 += params_copy.strides[4]) { - writer->Write(offset_4); + const int start_0 = + strided_slice::StridedSliceStartForAxis(params_copy, input_shape, 0); + const int stop_0 = strided_slice::StridedSliceEndForAxis( + params_copy, input_shape, 0, start_0); + const int start_1 = + strided_slice::StridedSliceStartForAxis(params_copy, input_shape, 1); + const int stop_1 = strided_slice::StridedSliceEndForAxis( + params_copy, input_shape, 1, start_1); + const int start_2 = + strided_slice::StridedSliceStartForAxis(params_copy, input_shape, 2); + const int stop_2 = strided_slice::StridedSliceEndForAxis( + params_copy, input_shape, 2, start_2); + const int start_3 = + strided_slice::StridedSliceStartForAxis(params_copy, input_shape, 3); + const int stop_3 = strided_slice::StridedSliceEndForAxis( + params_copy, input_shape, 3, start_3); + const int start_4 = + strided_slice::StridedSliceStartForAxis(params_copy, input_shape, 4); + const int stop_4 = strided_slice::StridedSliceEndForAxis( + params_copy, input_shape, 4, start_4); + + auto lc = [&](int end, int stride, int index) { + if (stride < 0) { + return index > end; + } else { + return index < end; + } + }; + const int* shape = input_shape.DimsData(); + const int* stride = params_copy.strides; + const bool inner_stride_is_1 = params_copy.strides[4] == 1; + + for (int offset_0 = start_0; lc(stop_0, stride[0], offset_0); + offset_0 += stride[0]) { + for (int offset_1 = start_1; lc(stop_1, stride[1], offset_1); + offset_1 += stride[1]) { + for (int offset_2 = start_2; lc(stop_2, stride[2], offset_2); + offset_2 += stride[2]) { + for (int offset_3 = start_3; lc(stop_3, stride[3], offset_3); + offset_3 += stride[3]) { + // When the stride is 1, the inner loop is equivalent to the + // optimized slice inner loop. Otherwise, it is identical to the + // strided_slice reference implementation inner loop. + if (inner_stride_is_1) { + const int len = stop_4 - start_4; + int index = start_4 + offset_3 * shape[4] + + offset_2 * shape[3] * shape[4] + + offset_1 * shape[2] * shape[3] * shape[4] + + offset_0 * shape[1] * shape[2] * shape[3] * shape[4]; + if (len > 0) { + writer->WriteN(index, len); + } + } else { + for (int offset_4 = start_4; lc(stop_4, stride[4], offset_4); + offset_4 += stride[4]) { + int index = offset_4 + offset_3 * shape[4] + + offset_2 * shape[3] * shape[4] + + offset_1 * shape[2] * shape[3] * shape[4] + + offset_0 * shape[1] * shape[2] * shape[3] * shape[4]; + writer->Write(index); + } } } } diff --git a/tensorflow/lite/kernels/internal/strided_slice_logic.h b/tensorflow/lite/kernels/internal/strided_slice_logic.h index bfe84050dca..2efdcf26fe0 100644 --- a/tensorflow/lite/kernels/internal/strided_slice_logic.h +++ b/tensorflow/lite/kernels/internal/strided_slice_logic.h @@ -69,6 +69,69 @@ inline void StridedSlicePadIndices(tflite::StridedSliceParams* p, p->strides_count = dim_count; } +// Return the index for the first element along that axis. This index will be a +// positive integer between [0, axis_size] (or [-1, axis_size -1] if stride < 0) +// that can be used to index directly into the data. +inline int StridedSliceStartForAxis(const tflite::StridedSliceParams& params, + const RuntimeShape& input_shape, + int32_t axis) { + const int32_t axis_size = input_shape.Dims(axis); + int32_t start = params.start_indices[axis]; + const int32_t stride = params.strides[axis]; + const int32_t begin_mask = (params.begin_mask & 1 << axis); + if (start < 0) { + start += axis_size; + } + if (stride > 0) { + start = Clamp(start, 0, axis_size); + } else { + start = Clamp(start, -1, axis_size - 1); + } + if (begin_mask) { + if (stride > 0) { + start = 0; + } else { + start = axis_size - 1; + } + } + return start; +} + +inline int StridedSliceEndForAxis(const tflite::StridedSliceParams& params, + const RuntimeShape& input_shape, int axis, + int start) { + const auto shrink_axis_mask = params.shrink_axis_mask; + const bool shrink_axis = shrink_axis_mask & (1 << axis); + const int axis_size = input_shape.Dims(axis); + if (shrink_axis) { + if (start >= axis_size) { + return start; + } else { + return start + 1; + } + } + const auto* indices = params.stop_indices; + int end = indices[axis]; + const int32_t stride = params.strides[axis]; + const int32_t end_mask = (params.end_mask & 1 << axis); + if (end < 0) { + end += axis_size; + } + if (stride > 0) { + end = Clamp(end, 0, axis_size); + } else { + end = Clamp(end, -1, axis_size - 1); + } + if (end_mask) { + if (stride > 0) { + end = axis_size; + } else { + end = -1; + } + } + return end; +} + // Return the index for the first element along that axis. This index will be a // positive integer between [0, axis_size] (or [-1, axis_size -1] if stride < 0) // that can be used to index directly into the data. diff --git a/tensorflow/lite/kernels/internal/strided_slice_logic_test.cc b/tensorflow/lite/kernels/internal/strided_slice_logic_test.cc index b48c647278b..d4aa8b94e14 100644 --- a/tensorflow/lite/kernels/internal/strided_slice_logic_test.cc +++ b/tensorflow/lite/kernels/internal/strided_slice_logic_test.cc @@ -78,5 +78,119 @@ TEST(RunStridedSlicePadIndices, Pad3) { ); } +TEST(StridedSliceStartForAxis, NegativeOOBIndex) { + StridedSliceParams params{}; + params.begin_mask = 0; + params.end_mask = 0; + params.start_indices[0] = -11; + params.strides[0] = 1; + int start = strided_slice::StridedSliceStartForAxis( + params, RuntimeShape({10}), /*axis=*/0); + EXPECT_EQ(start, 0); +} + +TEST(StridedSliceStartForAxis, NegativeOneTheBoundaryIndex) { + StridedSliceParams params{}; + params.begin_mask = 0; + params.end_mask = 0; + params.start_indices[0] = -10; + params.strides[0] = 1; + int start = strided_slice::StridedSliceStartForAxis( + params, RuntimeShape({10}), /*axis=*/0); + EXPECT_EQ(start, 0); +} + +TEST(StridedSliceStartForAxis, NegativeWithinBoundsIndex) { + StridedSliceParams params{}; + params.begin_mask = 0; + params.end_mask = 0; + params.start_indices[0] = -9; + params.strides[0] = 1; + int start = strided_slice::StridedSliceStartForAxis( + params, RuntimeShape({10}), /*axis=*/0); + EXPECT_EQ(start, 1); +} + +TEST(StridedSliceStartForAxis, MinusOneIndex) { + StridedSliceParams params{}; + params.begin_mask = 0; + params.end_mask = 0; + params.start_indices[0] = -1; + params.strides[0] = 1; + int start = strided_slice::StridedSliceStartForAxis( + params, RuntimeShape({10}), /*axis=*/0); + EXPECT_EQ(start, 9); +} + +TEST(StridedSliceStartForAxis, ZeroIndex) { + StridedSliceParams params{}; + params.begin_mask = 0; + params.end_mask = 0; + params.start_indices[0] = 0; + params.strides[0] = 1; + int start = strided_slice::StridedSliceStartForAxis( + params, RuntimeShape({10}), /*axis=*/0); + EXPECT_EQ(start, 0); +} + +TEST(StridedSliceStartForAxis, OneIndex) { + StridedSliceParams params{}; + params.begin_mask = 0; + params.end_mask = 0; + params.start_indices[0] = 1; + params.strides[0] = 1; + int start = strided_slice::StridedSliceStartForAxis( + params, RuntimeShape({10}), /*axis=*/0); + EXPECT_EQ(start, 1); +} + +TEST(StridedSliceStartForAxis, PositiveBoundaryIndex) { + StridedSliceParams params{}; + params.begin_mask = 0; + params.end_mask = 0; + params.start_indices[0] = 9; + params.strides[0] = 1; + int start = strided_slice::StridedSliceStartForAxis( + params, RuntimeShape({10}), /*axis=*/0); + EXPECT_EQ(start, 9); +} + +TEST(StridedSliceStartForAxis, PositiveOOBIndexSizeofArray) { + StridedSliceParams params{}; + params.begin_mask = 0; + params.end_mask = 0; + params.start_indices[0] = 10; + params.strides[0] = 1; + int start = strided_slice::StridedSliceStartForAxis( + params, RuntimeShape({10}), /*axis=*/0); + EXPECT_EQ(start, 10); +} + +TEST(StridedSliceStartForAxis, PositiveOOBIndex) { + StridedSliceParams params{}; + params.begin_mask = 0; + params.end_mask = 0; + params.start_indices[0] = 11; + params.strides[0] = 1; + int start = strided_slice::StridedSliceStartForAxis( + params, RuntimeShape({10}), /*axis=*/0); + EXPECT_EQ(start, 10); +} + +TEST(StridedSliceStartForAxis, TenFourMinus1) { + StridedSliceParams params{}; + params.begin_mask = 0; + params.end_mask = 0; + params.start_indices[0] = 5; + params.stop_indices[0] = 2; + params.strides[0] = -1; + int start = strided_slice::StridedSliceStartForAxis(params, RuntimeShape({4}), + /*axis=*/0); + int stop = strided_slice::StridedSliceEndForAxis(params, RuntimeShape({4}), + /*axis=*/0, start); + EXPECT_EQ(start, 3); + EXPECT_EQ(stop, 2); +} + } // namespace } // namespace tflite diff --git a/tensorflow/lite/kernels/strided_slice.cc b/tensorflow/lite/kernels/strided_slice.cc index 55aecc92765..f6f5d584610 100644 --- a/tensorflow/lite/kernels/strided_slice.cc +++ b/tensorflow/lite/kernels/strided_slice.cc @@ -24,7 +24,6 @@ limitations under the License. #include "tensorflow/lite/c/builtin_op_data.h" #include "tensorflow/lite/c/common.h" #include "tensorflow/lite/kernels/internal/compatibility.h" -#include "tensorflow/lite/kernels/internal/optimized/optimized_ops.h" #include "tensorflow/lite/kernels/internal/strided_slice_logic.h" #include "tensorflow/lite/kernels/internal/tensor.h" #include "tensorflow/lite/kernels/internal/tensor_ctypes.h" @@ -70,7 +69,7 @@ struct StridedSliceContext { }; StridedSliceParams BuildStridedSliceParams(StridedSliceContext* op_context) { - StridedSliceParams op_params; + StridedSliceParams op_params{}; // The ellipsis_mask and new_axis_mask in op_params are not used. Those masks // are processed here to update begin_mask, end_mask and the index range. @@ -196,9 +195,9 @@ TfLiteStatus ResizeOutputTensor(TfLiteContext* context, int32_t stride = op_params.strides[idx]; TF_LITE_ENSURE_MSG(context, stride != 0, "stride value has to be non-zero"); - int32_t begin = ::tflite::strided_slice::StartForAxis( + int32_t begin = ::tflite::strided_slice::StridedSliceStartForAxis( op_params, effective_input_shape, idx); - int32_t end = ::tflite::strided_slice::StopForAxis( + int32_t end = ::tflite::strided_slice::StridedSliceEndForAxis( op_params, effective_input_shape, idx, begin); // When shrinking an axis, the end position does not matter (and can be @@ -272,43 +271,46 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { } StridedSliceParams op_params = BuildStridedSliceParams(&op_context); -#define TF_LITE_STRIDED_SLICE(data_type) \ - { \ - if (kernel_type == kGenericOptimized) { \ - optimized_ops::StridedSlice( \ - op_params, op_context.effective_input_shape, op_context.input, \ - GetTensorShape(op_context.output), op_context.output); \ - } else { \ - reference_ops::StridedSlice( \ - op_params, op_context.effective_input_shape, op_context.input, \ - GetTensorShape(op_context.output), op_context.output); \ - } \ - } - switch (op_context.input->type) { case kTfLiteFloat32: - TF_LITE_STRIDED_SLICE(float); + reference_ops::StridedSlice( + op_params, op_context.effective_input_shape, op_context.input, + GetTensorShape(op_context.output), op_context.output); break; case kTfLiteInt32: - TF_LITE_STRIDED_SLICE(int32_t); + reference_ops::StridedSlice( + op_params, op_context.effective_input_shape, op_context.input, + GetTensorShape(op_context.output), op_context.output); break; case kTfLiteInt64: - TF_LITE_STRIDED_SLICE(int64_t); + reference_ops::StridedSlice( + op_params, op_context.effective_input_shape, op_context.input, + GetTensorShape(op_context.output), op_context.output); break; case kTfLiteUInt8: - TF_LITE_STRIDED_SLICE(uint8_t); + reference_ops::StridedSlice( + op_params, op_context.effective_input_shape, op_context.input, + GetTensorShape(op_context.output), op_context.output); break; case kTfLiteInt8: - TF_LITE_STRIDED_SLICE(int8_t); + reference_ops::StridedSlice( + op_params, op_context.effective_input_shape, op_context.input, + GetTensorShape(op_context.output), op_context.output); break; case kTfLiteInt16: - TF_LITE_STRIDED_SLICE(int16_t); + reference_ops::StridedSlice( + op_params, op_context.effective_input_shape, op_context.input, + GetTensorShape(op_context.output), op_context.output); break; case kTfLiteBool: - TF_LITE_STRIDED_SLICE(bool); + reference_ops::StridedSlice( + op_params, op_context.effective_input_shape, op_context.input, + GetTensorShape(op_context.output), op_context.output); break; case kTfLiteString: - TF_LITE_STRIDED_SLICE(string); + reference_ops::StridedSlice( + op_params, op_context.effective_input_shape, op_context.input, + GetTensorShape(op_context.output), op_context.output); break; default: TF_LITE_KERNEL_LOG(context, diff --git a/tensorflow/lite/kernels/strided_slice_test.cc b/tensorflow/lite/kernels/strided_slice_test.cc index 1fe1974a1b3..41180435a2e 100644 --- a/tensorflow/lite/kernels/strided_slice_test.cc +++ b/tensorflow/lite/kernels/strided_slice_test.cc @@ -27,6 +27,7 @@ namespace tflite { namespace { using ::testing::ElementsAreArray; +using ::testing::IsEmpty; template class StridedSliceOpModel : public SingleOpModel { @@ -36,7 +37,7 @@ class StridedSliceOpModel : public SingleOpModel { std::initializer_list end_shape, std::initializer_list strides_shape, int begin_mask, int end_mask, int ellipsis_mask, int new_axis_mask, - int shrink_axis_mask) { + int shrink_axis_mask, bool use_simple_allocator = true) { input_ = AddInput(GetTensorType()); begin_ = AddInput(TensorType_INT32); end_ = AddInput(TensorType_INT32); @@ -47,7 +48,8 @@ class StridedSliceOpModel : public SingleOpModel { CreateStridedSliceOptions(builder_, begin_mask, end_mask, ellipsis_mask, new_axis_mask, shrink_axis_mask) .Union()); - BuildInterpreter({input_shape, begin_shape, end_shape, strides_shape}); + BuildInterpreter({input_shape, begin_shape, end_shape, strides_shape}, + use_simple_allocator); } void SetInput(std::initializer_list data) { @@ -670,7 +672,7 @@ TYPED_TEST(StridedSliceOpTest, In3D_SmallBeginWithhrinkAxis1) { EXPECT_THAT(m.GetOutput(), ElementsAreArray({1, 2, 3, 4, 5, 6})); } -TYPED_TEST(StridedSliceOpTest, In3D_BackwardSmallBegin) { +TYPED_TEST(StridedSliceOpTest, In3D_BackwardSmallBeginEndMask) { StridedSliceOpModel m({1, 1, 2}, {1}, {1}, {1}, 0, 1, 0, 0, 0); m.SetInput({1, 2}); m.SetBegin({1}); @@ -680,6 +682,16 @@ TYPED_TEST(StridedSliceOpTest, In3D_BackwardSmallBegin) { EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({0, 1, 2})); } +TYPED_TEST(StridedSliceOpTest, In3D_BackwardSmallBegin) { + StridedSliceOpModel m({1, 1, 2}, {1}, {1}, {1}, 0, 0, 0, 0, 0); + m.SetInput({1, 2}); + m.SetBegin({1}); + m.SetEnd({0}); + m.SetStrides({1}); + ASSERT_EQ(m.Invoke(), kTfLiteOk); + EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({0, 1, 2})); +} + TYPED_TEST(StridedSliceOpTest, In3D_Backward) { StridedSliceOpModel m({1, 1, 2}, {3}, {3}, {3}, 6, 7, 0, 0, 0); m.SetInput({1, 2}); @@ -854,5 +866,86 @@ TYPED_TEST(StridedSliceOpTest, NoInfiniteLoop) { ASSERT_EQ(m.Invoke(), kTfLiteOk); } +TYPED_TEST(StridedSliceOpTest, MinusThreeMinusFourMinusOne) { + StridedSliceOpModel m({4}, {1}, {1}, {1}, 0, 0, 0, 0, 0); + m.SetInput({1, 2, 3, 4}); + m.SetBegin({-3}); + m.SetEnd({-4}); + m.SetStrides({-1}); + ASSERT_EQ(m.Invoke(), kTfLiteOk); + EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1})); + EXPECT_THAT(m.GetOutput(), ElementsAreArray({2})); +} + +TYPED_TEST(StridedSliceOpTest, MinusFourMinusThreeOne) { + StridedSliceOpModel m({4}, {1}, {1}, {1}, 0, 0, 0, 0, 0); + m.SetInput({1, 2, 3, 4}); + m.SetBegin({-4}); + m.SetEnd({-3}); + m.SetStrides({1}); + ASSERT_EQ(m.Invoke(), kTfLiteOk); + EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1})); + EXPECT_THAT(m.GetOutput(), ElementsAreArray({1})); +} + +TYPED_TEST(StridedSliceOpTest, OneOneOne) { + StridedSliceOpModel m({1}, {1}, {1}, {1}, 0, 0, 0, 0, 0); + m.SetInput({2}); + m.SetBegin({1}); + m.SetEnd({1}); + m.SetStrides({1}); + ASSERT_EQ(m.Invoke(), kTfLiteOk); + EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({0})); +} + +TYPED_TEST(StridedSliceOpTest, OneOneOneShrinkAxis) { + StridedSliceOpModel m({3}, {1}, {1}, {1}, 0, 0, 0, 0, 1); + m.SetInput({1, 2, 3}); + m.SetBegin({1}); + m.SetEnd({1}); + m.SetStrides({1}); + ASSERT_EQ(m.Invoke(), kTfLiteOk); + EXPECT_THAT(m.GetOutputShape(), IsEmpty()); + EXPECT_THAT(m.GetOutput(), ElementsAreArray({2})); +} + +TYPED_TEST(StridedSliceOpTest, OneOneOneShrinkAxisOOB) { + StridedSliceOpModel m({1}, {1}, {1}, {1}, 0, 0, 0, 0, 1); + m.SetInput({2}); + m.SetBegin({1}); + m.SetEnd({1}); + m.SetStrides({1}); + ASSERT_EQ(m.Invoke(), kTfLiteOk); + EXPECT_THAT(m.GetOutputShape(), IsEmpty()); +} + +TYPED_TEST(StridedSliceOpTest, OutOfBounds) { + StridedSliceOpModel m({1}, {1}, {1}, {1}, 0, 0, 0, 0, 1); + m.SetBegin({1}); + m.SetEnd({2}); + m.SetStrides({1}); + ASSERT_EQ(m.Invoke(), kTfLiteOk); + EXPECT_THAT(m.GetOutputShape(), IsEmpty()); +} + +TYPED_TEST(StridedSliceOpTest, StrideOutOfBounds) { + StridedSliceOpModel m({1}, {1}, {1}, {1}, 0, 0, 0, 0, 1); + m.SetBegin({1}); + m.SetEnd({4}); + m.SetStrides({7}); + ASSERT_EQ(m.Invoke(), kTfLiteOk); + EXPECT_THAT(m.GetOutputShape(), IsEmpty()); +} + +TYPED_TEST(StridedSliceOpTest, NegEndMask) { + StridedSliceOpModel m({2, 3}, {2}, {2}, {2}, 0, 0b10, 0, 0, 0); + m.SetInput({1, 2, 3, 4, 5, 6}); + m.SetBegin({0, -1}); + m.SetEnd({2, -3}); + m.SetStrides({1, -1}); + ASSERT_EQ(m.Invoke(), kTfLiteOk); + EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({2, 3})); + EXPECT_THAT(m.GetOutput(), ElementsAreArray({3, 2, 1, 6, 5, 4})); +} } // namespace } // namespace tflite diff --git a/tensorflow/lite/kernels/test_util.cc b/tensorflow/lite/kernels/test_util.cc index c0242f7a911..71cb3cb1818 100644 --- a/tensorflow/lite/kernels/test_util.cc +++ b/tensorflow/lite/kernels/test_util.cc @@ -178,7 +178,8 @@ void SingleOpModel::BuildInterpreter(std::vector> input_shapes, int num_threads, bool allow_fp32_relax_to_fp16, bool apply_delegate, - bool allocate_and_delegate) { + bool allocate_and_delegate, + bool use_simple_allocator) { input_shapes_ = input_shapes; allow_fp32_relax_to_fp16_ = allow_fp32_relax_to_fp16; apply_delegate_ = apply_delegate; @@ -203,7 +204,7 @@ void SingleOpModel::BuildInterpreter(std::vector> input_shapes, uint8_t* buffer_pointer = builder_.GetBufferPointer(); UpdateOpVersion(buffer_pointer); - bool use_simple_allocator = + use_simple_allocator |= tflite::KernelTestDelegateProviders::Get()->ConstParams().Get( tflite::KernelTestDelegateProviders::kUseSimpleAllocator); @@ -288,11 +289,12 @@ TfLiteStatus SingleOpModel::ApplyDelegate() { TfLiteStatus SingleOpModel::Invoke() { return interpreter_->Invoke(); } -void SingleOpModel::BuildInterpreter( - std::vector> input_shapes) { +void SingleOpModel::BuildInterpreter(std::vector> input_shapes, + bool use_simple_allocator) { BuildInterpreter(input_shapes, /*num_threads=*/-1, /*allow_fp32_relax_to_fp16=*/false, - /*apply_delegate=*/true, /*allocate_and_delegate=*/true); + /*apply_delegate=*/true, /*allocate_and_delegate=*/true, + use_simple_allocator); } // static diff --git a/tensorflow/lite/kernels/test_util.h b/tensorflow/lite/kernels/test_util.h index 0e69c310fcc..1116dd78add 100644 --- a/tensorflow/lite/kernels/test_util.h +++ b/tensorflow/lite/kernels/test_util.h @@ -538,9 +538,11 @@ class SingleOpModel { // `apply_delegate` is ignored. void BuildInterpreter(std::vector> input_shapes, int num_threads, bool allow_fp32_relax_to_fp16, - bool apply_delegate, bool allocate_and_delegate = true); + bool apply_delegate, bool allocate_and_delegate = true, + bool use_simple_allocator = false); - void BuildInterpreter(std::vector> input_shapes); + void BuildInterpreter(std::vector> input_shapes, + bool use_simple_allocator = false); // Executes inference and return status code. TfLiteStatus Invoke(); -- GitLab