提交 b0da1232 编写于 作者: A Alan Kelly 提交者: TensorFlow Release Automation

Fix strided slice bug where output is always writes at least one element.

a[1:1:1] should be empty, but one element is written.

PiperOrigin-RevId: 482436216
上级 115d1cb7
...@@ -1073,7 +1073,6 @@ cc_test( ...@@ -1073,7 +1073,6 @@ cc_test(
srcs = [ srcs = [
"strided_slice_logic_test.cc", "strided_slice_logic_test.cc",
], ],
shard_count = 4,
deps = [ deps = [
":strided_slice_logic", ":strided_slice_logic",
"@com_google_googletest//:gtest_main", "@com_google_googletest//:gtest_main",
......
...@@ -4673,108 +4673,6 @@ inline void Slice(const tflite::SliceParams& op_params, ...@@ -4673,108 +4673,6 @@ inline void Slice(const tflite::SliceParams& op_params,
return Slice(op_params, input_shape, output_shape, &writer); return Slice(op_params, input_shape, output_shape, &writer);
} }
// Note: This implementation is only optimized for the case where the inner
// stride == 1.
template <typename T>
inline void StridedSlice(const tflite::StridedSliceParams& op_params,
const RuntimeShape& unextended_input_shape,
const RuntimeShape& unextended_output_shape,
SequentialTensorWriter<T>* writer) {
using strided_slice::LoopCondition;
using strided_slice::StartForAxis;
using strided_slice::StopForAxis;
ruy::profiler::ScopeLabel label("StridedSlice");
// Note that the output_shape is not used herein.
tflite::StridedSliceParams params_copy = op_params;
TFLITE_DCHECK_LE(unextended_input_shape.DimensionsCount(), 5);
TFLITE_DCHECK_LE(unextended_output_shape.DimensionsCount(), 5);
const RuntimeShape input_shape =
RuntimeShape::ExtendedShape(5, unextended_input_shape);
const RuntimeShape output_shape =
RuntimeShape::ExtendedShape(5, unextended_output_shape);
// Reverse and pad to 5 dimensions because that is what the runtime code
// requires (ie. all shapes must be 5D and are given backwards).
strided_slice::StridedSlicePadIndices(&params_copy, 5);
const int start_0 = StartForAxis(params_copy, input_shape, 0);
const int stop_0 = StopForAxis(params_copy, input_shape, 0, start_0);
const int start_1 = StartForAxis(params_copy, input_shape, 1);
const int stop_1 = StopForAxis(params_copy, input_shape, 1, start_1);
const int start_2 = StartForAxis(params_copy, input_shape, 2);
const int stop_2 = StopForAxis(params_copy, input_shape, 2, start_2);
const int start_3 = StartForAxis(params_copy, input_shape, 3);
const int stop_3 = StopForAxis(params_copy, input_shape, 3, start_3);
const int start_4 = StartForAxis(params_copy, input_shape, 4);
const int stop_4 = StopForAxis(params_copy, input_shape, 4, start_4);
const bool inner_stride_is_1 = params_copy.strides[4] == 1;
for (int offset_0 = start_0 * input_shape.Dims(1),
end_0 = stop_0 * input_shape.Dims(1),
step_0 = params_copy.strides[0] * input_shape.Dims(1);
!LoopCondition(offset_0, end_0, params_copy.strides[0]);
offset_0 += step_0) {
for (int offset_1 = (offset_0 + start_1) * input_shape.Dims(2),
end_1 = (offset_0 + stop_1) * input_shape.Dims(2),
step_1 = params_copy.strides[1] * input_shape.Dims(2);
!LoopCondition(offset_1, end_1, params_copy.strides[1]);
offset_1 += step_1) {
for (int offset_2 = (offset_1 + start_2) * input_shape.Dims(3),
end_2 = (offset_1 + stop_2) * input_shape.Dims(3),
step_2 = params_copy.strides[2] * input_shape.Dims(3);
!LoopCondition(offset_2, end_2, params_copy.strides[2]);
offset_2 += step_2) {
for (int offset_3 = (offset_2 + start_3) * input_shape.Dims(4),
end_3 = (offset_2 + stop_3) * input_shape.Dims(4),
step_3 = params_copy.strides[3] * input_shape.Dims(4);
!LoopCondition(offset_3, end_3, params_copy.strides[3]);
offset_3 += step_3) {
// When the stride is 1, the inner loop is equivalent to the
// optimized slice inner loop. Otherwise, it is identical to the
// strided_slice reference implementation inner loop.
if (inner_stride_is_1) {
const int len = stop_4 - start_4;
if (len > 0) {
writer->WriteN(offset_3 + start_4, len);
}
} else {
for (int offset_4 = offset_3 + start_4, end_4 = offset_3 + stop_4;
!LoopCondition(offset_4, end_4, params_copy.strides[4]);
offset_4 += params_copy.strides[4]) {
writer->Write(offset_4);
}
}
}
}
}
}
}
template <typename T>
inline void StridedSlice(const tflite::StridedSliceParams& op_params,
const RuntimeShape& unextended_input_shape,
const T* input_data,
const RuntimeShape& unextended_output_shape,
T* output_data) {
SequentialTensorWriter<T> writer(input_data, output_data);
StridedSlice<T>(op_params, unextended_input_shape, unextended_output_shape,
&writer);
}
template <typename T>
inline void StridedSlice(const tflite::StridedSliceParams& op_params,
const RuntimeShape& unextended_input_shape,
const TfLiteTensor* input,
const RuntimeShape& unextended_output_shape,
TfLiteTensor* output) {
SequentialTensorWriter<T> writer(input, output);
StridedSlice<T>(op_params, unextended_input_shape, unextended_output_shape,
&writer);
}
template <typename T> template <typename T>
void Minimum(const RuntimeShape& input1_shape, const T* input1_data, void Minimum(const RuntimeShape& input1_shape, const T* input1_data,
const T* input2_data, const RuntimeShape& output_shape, const T* input2_data, const RuntimeShape& output_shape,
......
...@@ -31,10 +31,6 @@ inline void StridedSlice(const tflite::StridedSliceParams& op_params, ...@@ -31,10 +31,6 @@ inline void StridedSlice(const tflite::StridedSliceParams& op_params,
const RuntimeShape& unextended_input_shape, const RuntimeShape& unextended_input_shape,
const RuntimeShape& unextended_output_shape, const RuntimeShape& unextended_output_shape,
SequentialTensorWriter<T>* writer) { SequentialTensorWriter<T>* writer) {
using strided_slice::LoopCondition;
using strided_slice::StartForAxis;
using strided_slice::StopForAxis;
ruy::profiler::ScopeLabel label("StridedSlice"); ruy::profiler::ScopeLabel label("StridedSlice");
// Note that the output_shape is not used herein. // Note that the output_shape is not used herein.
...@@ -51,41 +47,67 @@ inline void StridedSlice(const tflite::StridedSliceParams& op_params, ...@@ -51,41 +47,67 @@ inline void StridedSlice(const tflite::StridedSliceParams& op_params,
// requires (ie. all shapes must be 5D and are given backwards). // requires (ie. all shapes must be 5D and are given backwards).
strided_slice::StridedSlicePadIndices(&params_copy, 5); strided_slice::StridedSlicePadIndices(&params_copy, 5);
const int start_0 = StartForAxis(params_copy, input_shape, 0); const int start_0 =
const int stop_0 = StopForAxis(params_copy, input_shape, 0, start_0); strided_slice::StridedSliceStartForAxis(params_copy, input_shape, 0);
const int start_1 = StartForAxis(params_copy, input_shape, 1); const int stop_0 = strided_slice::StridedSliceEndForAxis(
const int stop_1 = StopForAxis(params_copy, input_shape, 1, start_1); params_copy, input_shape, 0, start_0);
const int start_2 = StartForAxis(params_copy, input_shape, 2); const int start_1 =
const int stop_2 = StopForAxis(params_copy, input_shape, 2, start_2); strided_slice::StridedSliceStartForAxis(params_copy, input_shape, 1);
const int start_3 = StartForAxis(params_copy, input_shape, 3); const int stop_1 = strided_slice::StridedSliceEndForAxis(
const int stop_3 = StopForAxis(params_copy, input_shape, 3, start_3); params_copy, input_shape, 1, start_1);
const int start_4 = StartForAxis(params_copy, input_shape, 4); const int start_2 =
const int stop_4 = StopForAxis(params_copy, input_shape, 4, start_4); strided_slice::StridedSliceStartForAxis(params_copy, input_shape, 2);
const int stop_2 = strided_slice::StridedSliceEndForAxis(
for (int offset_0 = start_0 * input_shape.Dims(1), params_copy, input_shape, 2, start_2);
end_0 = stop_0 * input_shape.Dims(1), const int start_3 =
step_0 = params_copy.strides[0] * input_shape.Dims(1); strided_slice::StridedSliceStartForAxis(params_copy, input_shape, 3);
!LoopCondition(offset_0, end_0, params_copy.strides[0]); const int stop_3 = strided_slice::StridedSliceEndForAxis(
offset_0 += step_0) { params_copy, input_shape, 3, start_3);
for (int offset_1 = (offset_0 + start_1) * input_shape.Dims(2), const int start_4 =
end_1 = (offset_0 + stop_1) * input_shape.Dims(2), strided_slice::StridedSliceStartForAxis(params_copy, input_shape, 4);
step_1 = params_copy.strides[1] * input_shape.Dims(2); const int stop_4 = strided_slice::StridedSliceEndForAxis(
!LoopCondition(offset_1, end_1, params_copy.strides[1]); params_copy, input_shape, 4, start_4);
offset_1 += step_1) {
for (int offset_2 = (offset_1 + start_2) * input_shape.Dims(3), auto lc = [&](int end, int stride, int index) {
end_2 = (offset_1 + stop_2) * input_shape.Dims(3), if (stride < 0) {
step_2 = params_copy.strides[2] * input_shape.Dims(3); return index > end;
!LoopCondition(offset_2, end_2, params_copy.strides[2]); } else {
offset_2 += step_2) { return index < end;
for (int offset_3 = (offset_2 + start_3) * input_shape.Dims(4), }
end_3 = (offset_2 + stop_3) * input_shape.Dims(4), };
step_3 = params_copy.strides[3] * input_shape.Dims(4); const int* shape = input_shape.DimsData();
!LoopCondition(offset_3, end_3, params_copy.strides[3]); const int* stride = params_copy.strides;
offset_3 += step_3) { const bool inner_stride_is_1 = params_copy.strides[4] == 1;
for (int offset_4 = offset_3 + start_4, end_4 = offset_3 + stop_4;
!LoopCondition(offset_4, end_4, params_copy.strides[4]); for (int offset_0 = start_0; lc(stop_0, stride[0], offset_0);
offset_4 += params_copy.strides[4]) { offset_0 += stride[0]) {
writer->Write(offset_4); for (int offset_1 = start_1; lc(stop_1, stride[1], offset_1);
offset_1 += stride[1]) {
for (int offset_2 = start_2; lc(stop_2, stride[2], offset_2);
offset_2 += stride[2]) {
for (int offset_3 = start_3; lc(stop_3, stride[3], offset_3);
offset_3 += stride[3]) {
// When the stride is 1, the inner loop is equivalent to the
// optimized slice inner loop. Otherwise, it is identical to the
// strided_slice reference implementation inner loop.
if (inner_stride_is_1) {
const int len = stop_4 - start_4;
int index = start_4 + offset_3 * shape[4] +
offset_2 * shape[3] * shape[4] +
offset_1 * shape[2] * shape[3] * shape[4] +
offset_0 * shape[1] * shape[2] * shape[3] * shape[4];
if (len > 0) {
writer->WriteN(index, len);
}
} else {
for (int offset_4 = start_4; lc(stop_4, stride[4], offset_4);
offset_4 += stride[4]) {
int index = offset_4 + offset_3 * shape[4] +
offset_2 * shape[3] * shape[4] +
offset_1 * shape[2] * shape[3] * shape[4] +
offset_0 * shape[1] * shape[2] * shape[3] * shape[4];
writer->Write(index);
}
} }
} }
} }
......
...@@ -69,6 +69,69 @@ inline void StridedSlicePadIndices(tflite::StridedSliceParams* p, ...@@ -69,6 +69,69 @@ inline void StridedSlicePadIndices(tflite::StridedSliceParams* p,
p->strides_count = dim_count; p->strides_count = dim_count;
} }
// Return the index for the first element along that axis. This index will be a
// positive integer between [0, axis_size] (or [-1, axis_size -1] if stride < 0)
// that can be used to index directly into the data.
inline int StridedSliceStartForAxis(const tflite::StridedSliceParams& params,
const RuntimeShape& input_shape,
int32_t axis) {
const int32_t axis_size = input_shape.Dims(axis);
int32_t start = params.start_indices[axis];
const int32_t stride = params.strides[axis];
const int32_t begin_mask = (params.begin_mask & 1 << axis);
if (start < 0) {
start += axis_size;
}
if (stride > 0) {
start = Clamp(start, 0, axis_size);
} else {
start = Clamp(start, -1, axis_size - 1);
}
if (begin_mask) {
if (stride > 0) {
start = 0;
} else {
start = axis_size - 1;
}
}
return start;
}
inline int StridedSliceEndForAxis(const tflite::StridedSliceParams& params,
const RuntimeShape& input_shape, int axis,
int start) {
const auto shrink_axis_mask = params.shrink_axis_mask;
const bool shrink_axis = shrink_axis_mask & (1 << axis);
const int axis_size = input_shape.Dims(axis);
if (shrink_axis) {
if (start >= axis_size) {
return start;
} else {
return start + 1;
}
}
const auto* indices = params.stop_indices;
int end = indices[axis];
const int32_t stride = params.strides[axis];
const int32_t end_mask = (params.end_mask & 1 << axis);
if (end < 0) {
end += axis_size;
}
if (stride > 0) {
end = Clamp(end, 0, axis_size);
} else {
end = Clamp(end, -1, axis_size - 1);
}
if (end_mask) {
if (stride > 0) {
end = axis_size;
} else {
end = -1;
}
}
return end;
}
// Return the index for the first element along that axis. This index will be a // Return the index for the first element along that axis. This index will be a
// positive integer between [0, axis_size] (or [-1, axis_size -1] if stride < 0) // positive integer between [0, axis_size] (or [-1, axis_size -1] if stride < 0)
// that can be used to index directly into the data. // that can be used to index directly into the data.
......
...@@ -78,5 +78,119 @@ TEST(RunStridedSlicePadIndices, Pad3) { ...@@ -78,5 +78,119 @@ TEST(RunStridedSlicePadIndices, Pad3) {
); );
} }
TEST(StridedSliceStartForAxis, NegativeOOBIndex) {
StridedSliceParams params{};
params.begin_mask = 0;
params.end_mask = 0;
params.start_indices[0] = -11;
params.strides[0] = 1;
int start = strided_slice::StridedSliceStartForAxis(
params, RuntimeShape({10}), /*axis=*/0);
EXPECT_EQ(start, 0);
}
TEST(StridedSliceStartForAxis, NegativeOneTheBoundaryIndex) {
StridedSliceParams params{};
params.begin_mask = 0;
params.end_mask = 0;
params.start_indices[0] = -10;
params.strides[0] = 1;
int start = strided_slice::StridedSliceStartForAxis(
params, RuntimeShape({10}), /*axis=*/0);
EXPECT_EQ(start, 0);
}
TEST(StridedSliceStartForAxis, NegativeWithinBoundsIndex) {
StridedSliceParams params{};
params.begin_mask = 0;
params.end_mask = 0;
params.start_indices[0] = -9;
params.strides[0] = 1;
int start = strided_slice::StridedSliceStartForAxis(
params, RuntimeShape({10}), /*axis=*/0);
EXPECT_EQ(start, 1);
}
TEST(StridedSliceStartForAxis, MinusOneIndex) {
StridedSliceParams params{};
params.begin_mask = 0;
params.end_mask = 0;
params.start_indices[0] = -1;
params.strides[0] = 1;
int start = strided_slice::StridedSliceStartForAxis(
params, RuntimeShape({10}), /*axis=*/0);
EXPECT_EQ(start, 9);
}
TEST(StridedSliceStartForAxis, ZeroIndex) {
StridedSliceParams params{};
params.begin_mask = 0;
params.end_mask = 0;
params.start_indices[0] = 0;
params.strides[0] = 1;
int start = strided_slice::StridedSliceStartForAxis(
params, RuntimeShape({10}), /*axis=*/0);
EXPECT_EQ(start, 0);
}
TEST(StridedSliceStartForAxis, OneIndex) {
StridedSliceParams params{};
params.begin_mask = 0;
params.end_mask = 0;
params.start_indices[0] = 1;
params.strides[0] = 1;
int start = strided_slice::StridedSliceStartForAxis(
params, RuntimeShape({10}), /*axis=*/0);
EXPECT_EQ(start, 1);
}
TEST(StridedSliceStartForAxis, PositiveBoundaryIndex) {
StridedSliceParams params{};
params.begin_mask = 0;
params.end_mask = 0;
params.start_indices[0] = 9;
params.strides[0] = 1;
int start = strided_slice::StridedSliceStartForAxis(
params, RuntimeShape({10}), /*axis=*/0);
EXPECT_EQ(start, 9);
}
TEST(StridedSliceStartForAxis, PositiveOOBIndexSizeofArray) {
StridedSliceParams params{};
params.begin_mask = 0;
params.end_mask = 0;
params.start_indices[0] = 10;
params.strides[0] = 1;
int start = strided_slice::StridedSliceStartForAxis(
params, RuntimeShape({10}), /*axis=*/0);
EXPECT_EQ(start, 10);
}
TEST(StridedSliceStartForAxis, PositiveOOBIndex) {
StridedSliceParams params{};
params.begin_mask = 0;
params.end_mask = 0;
params.start_indices[0] = 11;
params.strides[0] = 1;
int start = strided_slice::StridedSliceStartForAxis(
params, RuntimeShape({10}), /*axis=*/0);
EXPECT_EQ(start, 10);
}
TEST(StridedSliceStartForAxis, TenFourMinus1) {
StridedSliceParams params{};
params.begin_mask = 0;
params.end_mask = 0;
params.start_indices[0] = 5;
params.stop_indices[0] = 2;
params.strides[0] = -1;
int start = strided_slice::StridedSliceStartForAxis(params, RuntimeShape({4}),
/*axis=*/0);
int stop = strided_slice::StridedSliceEndForAxis(params, RuntimeShape({4}),
/*axis=*/0, start);
EXPECT_EQ(start, 3);
EXPECT_EQ(stop, 2);
}
} // namespace } // namespace
} // namespace tflite } // namespace tflite
...@@ -24,7 +24,6 @@ limitations under the License. ...@@ -24,7 +24,6 @@ limitations under the License.
#include "tensorflow/lite/c/builtin_op_data.h" #include "tensorflow/lite/c/builtin_op_data.h"
#include "tensorflow/lite/c/common.h" #include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/kernels/internal/compatibility.h" #include "tensorflow/lite/kernels/internal/compatibility.h"
#include "tensorflow/lite/kernels/internal/optimized/optimized_ops.h"
#include "tensorflow/lite/kernels/internal/strided_slice_logic.h" #include "tensorflow/lite/kernels/internal/strided_slice_logic.h"
#include "tensorflow/lite/kernels/internal/tensor.h" #include "tensorflow/lite/kernels/internal/tensor.h"
#include "tensorflow/lite/kernels/internal/tensor_ctypes.h" #include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
...@@ -70,7 +69,7 @@ struct StridedSliceContext { ...@@ -70,7 +69,7 @@ struct StridedSliceContext {
}; };
StridedSliceParams BuildStridedSliceParams(StridedSliceContext* op_context) { StridedSliceParams BuildStridedSliceParams(StridedSliceContext* op_context) {
StridedSliceParams op_params; StridedSliceParams op_params{};
// The ellipsis_mask and new_axis_mask in op_params are not used. Those masks // The ellipsis_mask and new_axis_mask in op_params are not used. Those masks
// are processed here to update begin_mask, end_mask and the index range. // are processed here to update begin_mask, end_mask and the index range.
...@@ -196,9 +195,9 @@ TfLiteStatus ResizeOutputTensor(TfLiteContext* context, ...@@ -196,9 +195,9 @@ TfLiteStatus ResizeOutputTensor(TfLiteContext* context,
int32_t stride = op_params.strides[idx]; int32_t stride = op_params.strides[idx];
TF_LITE_ENSURE_MSG(context, stride != 0, "stride value has to be non-zero"); TF_LITE_ENSURE_MSG(context, stride != 0, "stride value has to be non-zero");
int32_t begin = ::tflite::strided_slice::StartForAxis( int32_t begin = ::tflite::strided_slice::StridedSliceStartForAxis(
op_params, effective_input_shape, idx); op_params, effective_input_shape, idx);
int32_t end = ::tflite::strided_slice::StopForAxis( int32_t end = ::tflite::strided_slice::StridedSliceEndForAxis(
op_params, effective_input_shape, idx, begin); op_params, effective_input_shape, idx, begin);
// When shrinking an axis, the end position does not matter (and can be // When shrinking an axis, the end position does not matter (and can be
...@@ -272,43 +271,46 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { ...@@ -272,43 +271,46 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
} }
StridedSliceParams op_params = BuildStridedSliceParams(&op_context); StridedSliceParams op_params = BuildStridedSliceParams(&op_context);
#define TF_LITE_STRIDED_SLICE(data_type) \
{ \
if (kernel_type == kGenericOptimized) { \
optimized_ops::StridedSlice<data_type>( \
op_params, op_context.effective_input_shape, op_context.input, \
GetTensorShape(op_context.output), op_context.output); \
} else { \
reference_ops::StridedSlice<data_type>( \
op_params, op_context.effective_input_shape, op_context.input, \
GetTensorShape(op_context.output), op_context.output); \
} \
}
switch (op_context.input->type) { switch (op_context.input->type) {
case kTfLiteFloat32: case kTfLiteFloat32:
TF_LITE_STRIDED_SLICE(float); reference_ops::StridedSlice<float>(
op_params, op_context.effective_input_shape, op_context.input,
GetTensorShape(op_context.output), op_context.output);
break; break;
case kTfLiteInt32: case kTfLiteInt32:
TF_LITE_STRIDED_SLICE(int32_t); reference_ops::StridedSlice<int32_t>(
op_params, op_context.effective_input_shape, op_context.input,
GetTensorShape(op_context.output), op_context.output);
break; break;
case kTfLiteInt64: case kTfLiteInt64:
TF_LITE_STRIDED_SLICE(int64_t); reference_ops::StridedSlice<int64_t>(
op_params, op_context.effective_input_shape, op_context.input,
GetTensorShape(op_context.output), op_context.output);
break; break;
case kTfLiteUInt8: case kTfLiteUInt8:
TF_LITE_STRIDED_SLICE(uint8_t); reference_ops::StridedSlice<uint8_t>(
op_params, op_context.effective_input_shape, op_context.input,
GetTensorShape(op_context.output), op_context.output);
break; break;
case kTfLiteInt8: case kTfLiteInt8:
TF_LITE_STRIDED_SLICE(int8_t); reference_ops::StridedSlice<int8_t>(
op_params, op_context.effective_input_shape, op_context.input,
GetTensorShape(op_context.output), op_context.output);
break; break;
case kTfLiteInt16: case kTfLiteInt16:
TF_LITE_STRIDED_SLICE(int16_t); reference_ops::StridedSlice<int16_t>(
op_params, op_context.effective_input_shape, op_context.input,
GetTensorShape(op_context.output), op_context.output);
break; break;
case kTfLiteBool: case kTfLiteBool:
TF_LITE_STRIDED_SLICE(bool); reference_ops::StridedSlice<bool>(
op_params, op_context.effective_input_shape, op_context.input,
GetTensorShape(op_context.output), op_context.output);
break; break;
case kTfLiteString: case kTfLiteString:
TF_LITE_STRIDED_SLICE(string); reference_ops::StridedSlice<string>(
op_params, op_context.effective_input_shape, op_context.input,
GetTensorShape(op_context.output), op_context.output);
break; break;
default: default:
TF_LITE_KERNEL_LOG(context, TF_LITE_KERNEL_LOG(context,
......
...@@ -27,6 +27,7 @@ namespace tflite { ...@@ -27,6 +27,7 @@ namespace tflite {
namespace { namespace {
using ::testing::ElementsAreArray; using ::testing::ElementsAreArray;
using ::testing::IsEmpty;
template <typename input_type> template <typename input_type>
class StridedSliceOpModel : public SingleOpModel { class StridedSliceOpModel : public SingleOpModel {
...@@ -36,7 +37,7 @@ class StridedSliceOpModel : public SingleOpModel { ...@@ -36,7 +37,7 @@ class StridedSliceOpModel : public SingleOpModel {
std::initializer_list<int> end_shape, std::initializer_list<int> end_shape,
std::initializer_list<int> strides_shape, int begin_mask, std::initializer_list<int> strides_shape, int begin_mask,
int end_mask, int ellipsis_mask, int new_axis_mask, int end_mask, int ellipsis_mask, int new_axis_mask,
int shrink_axis_mask) { int shrink_axis_mask, bool use_simple_allocator = true) {
input_ = AddInput(GetTensorType<input_type>()); input_ = AddInput(GetTensorType<input_type>());
begin_ = AddInput(TensorType_INT32); begin_ = AddInput(TensorType_INT32);
end_ = AddInput(TensorType_INT32); end_ = AddInput(TensorType_INT32);
...@@ -47,7 +48,8 @@ class StridedSliceOpModel : public SingleOpModel { ...@@ -47,7 +48,8 @@ class StridedSliceOpModel : public SingleOpModel {
CreateStridedSliceOptions(builder_, begin_mask, end_mask, ellipsis_mask, CreateStridedSliceOptions(builder_, begin_mask, end_mask, ellipsis_mask,
new_axis_mask, shrink_axis_mask) new_axis_mask, shrink_axis_mask)
.Union()); .Union());
BuildInterpreter({input_shape, begin_shape, end_shape, strides_shape}); BuildInterpreter({input_shape, begin_shape, end_shape, strides_shape},
use_simple_allocator);
} }
void SetInput(std::initializer_list<input_type> data) { void SetInput(std::initializer_list<input_type> data) {
...@@ -670,7 +672,7 @@ TYPED_TEST(StridedSliceOpTest, In3D_SmallBeginWithhrinkAxis1) { ...@@ -670,7 +672,7 @@ TYPED_TEST(StridedSliceOpTest, In3D_SmallBeginWithhrinkAxis1) {
EXPECT_THAT(m.GetOutput(), ElementsAreArray({1, 2, 3, 4, 5, 6})); EXPECT_THAT(m.GetOutput(), ElementsAreArray({1, 2, 3, 4, 5, 6}));
} }
TYPED_TEST(StridedSliceOpTest, In3D_BackwardSmallBegin) { TYPED_TEST(StridedSliceOpTest, In3D_BackwardSmallBeginEndMask) {
StridedSliceOpModel<TypeParam> m({1, 1, 2}, {1}, {1}, {1}, 0, 1, 0, 0, 0); StridedSliceOpModel<TypeParam> m({1, 1, 2}, {1}, {1}, {1}, 0, 1, 0, 0, 0);
m.SetInput({1, 2}); m.SetInput({1, 2});
m.SetBegin({1}); m.SetBegin({1});
...@@ -680,6 +682,16 @@ TYPED_TEST(StridedSliceOpTest, In3D_BackwardSmallBegin) { ...@@ -680,6 +682,16 @@ TYPED_TEST(StridedSliceOpTest, In3D_BackwardSmallBegin) {
EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({0, 1, 2})); EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({0, 1, 2}));
} }
TYPED_TEST(StridedSliceOpTest, In3D_BackwardSmallBegin) {
StridedSliceOpModel<TypeParam> m({1, 1, 2}, {1}, {1}, {1}, 0, 0, 0, 0, 0);
m.SetInput({1, 2});
m.SetBegin({1});
m.SetEnd({0});
m.SetStrides({1});
ASSERT_EQ(m.Invoke(), kTfLiteOk);
EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({0, 1, 2}));
}
TYPED_TEST(StridedSliceOpTest, In3D_Backward) { TYPED_TEST(StridedSliceOpTest, In3D_Backward) {
StridedSliceOpModel<TypeParam> m({1, 1, 2}, {3}, {3}, {3}, 6, 7, 0, 0, 0); StridedSliceOpModel<TypeParam> m({1, 1, 2}, {3}, {3}, {3}, 6, 7, 0, 0, 0);
m.SetInput({1, 2}); m.SetInput({1, 2});
...@@ -854,5 +866,86 @@ TYPED_TEST(StridedSliceOpTest, NoInfiniteLoop) { ...@@ -854,5 +866,86 @@ TYPED_TEST(StridedSliceOpTest, NoInfiniteLoop) {
ASSERT_EQ(m.Invoke(), kTfLiteOk); ASSERT_EQ(m.Invoke(), kTfLiteOk);
} }
TYPED_TEST(StridedSliceOpTest, MinusThreeMinusFourMinusOne) {
StridedSliceOpModel<TypeParam> m({4}, {1}, {1}, {1}, 0, 0, 0, 0, 0);
m.SetInput({1, 2, 3, 4});
m.SetBegin({-3});
m.SetEnd({-4});
m.SetStrides({-1});
ASSERT_EQ(m.Invoke(), kTfLiteOk);
EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1}));
EXPECT_THAT(m.GetOutput(), ElementsAreArray({2}));
}
TYPED_TEST(StridedSliceOpTest, MinusFourMinusThreeOne) {
StridedSliceOpModel<TypeParam> m({4}, {1}, {1}, {1}, 0, 0, 0, 0, 0);
m.SetInput({1, 2, 3, 4});
m.SetBegin({-4});
m.SetEnd({-3});
m.SetStrides({1});
ASSERT_EQ(m.Invoke(), kTfLiteOk);
EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1}));
EXPECT_THAT(m.GetOutput(), ElementsAreArray({1}));
}
TYPED_TEST(StridedSliceOpTest, OneOneOne) {
StridedSliceOpModel<TypeParam> m({1}, {1}, {1}, {1}, 0, 0, 0, 0, 0);
m.SetInput({2});
m.SetBegin({1});
m.SetEnd({1});
m.SetStrides({1});
ASSERT_EQ(m.Invoke(), kTfLiteOk);
EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({0}));
}
TYPED_TEST(StridedSliceOpTest, OneOneOneShrinkAxis) {
StridedSliceOpModel<TypeParam> m({3}, {1}, {1}, {1}, 0, 0, 0, 0, 1);
m.SetInput({1, 2, 3});
m.SetBegin({1});
m.SetEnd({1});
m.SetStrides({1});
ASSERT_EQ(m.Invoke(), kTfLiteOk);
EXPECT_THAT(m.GetOutputShape(), IsEmpty());
EXPECT_THAT(m.GetOutput(), ElementsAreArray({2}));
}
TYPED_TEST(StridedSliceOpTest, OneOneOneShrinkAxisOOB) {
StridedSliceOpModel<TypeParam> m({1}, {1}, {1}, {1}, 0, 0, 0, 0, 1);
m.SetInput({2});
m.SetBegin({1});
m.SetEnd({1});
m.SetStrides({1});
ASSERT_EQ(m.Invoke(), kTfLiteOk);
EXPECT_THAT(m.GetOutputShape(), IsEmpty());
}
TYPED_TEST(StridedSliceOpTest, OutOfBounds) {
StridedSliceOpModel<TypeParam> m({1}, {1}, {1}, {1}, 0, 0, 0, 0, 1);
m.SetBegin({1});
m.SetEnd({2});
m.SetStrides({1});
ASSERT_EQ(m.Invoke(), kTfLiteOk);
EXPECT_THAT(m.GetOutputShape(), IsEmpty());
}
TYPED_TEST(StridedSliceOpTest, StrideOutOfBounds) {
StridedSliceOpModel<TypeParam> m({1}, {1}, {1}, {1}, 0, 0, 0, 0, 1);
m.SetBegin({1});
m.SetEnd({4});
m.SetStrides({7});
ASSERT_EQ(m.Invoke(), kTfLiteOk);
EXPECT_THAT(m.GetOutputShape(), IsEmpty());
}
TYPED_TEST(StridedSliceOpTest, NegEndMask) {
StridedSliceOpModel<TypeParam> m({2, 3}, {2}, {2}, {2}, 0, 0b10, 0, 0, 0);
m.SetInput({1, 2, 3, 4, 5, 6});
m.SetBegin({0, -1});
m.SetEnd({2, -3});
m.SetStrides({1, -1});
ASSERT_EQ(m.Invoke(), kTfLiteOk);
EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({2, 3}));
EXPECT_THAT(m.GetOutput(), ElementsAreArray({3, 2, 1, 6, 5, 4}));
}
} // namespace } // namespace
} // namespace tflite } // namespace tflite
...@@ -178,7 +178,8 @@ void SingleOpModel::BuildInterpreter(std::vector<std::vector<int>> input_shapes, ...@@ -178,7 +178,8 @@ void SingleOpModel::BuildInterpreter(std::vector<std::vector<int>> input_shapes,
int num_threads, int num_threads,
bool allow_fp32_relax_to_fp16, bool allow_fp32_relax_to_fp16,
bool apply_delegate, bool apply_delegate,
bool allocate_and_delegate) { bool allocate_and_delegate,
bool use_simple_allocator) {
input_shapes_ = input_shapes; input_shapes_ = input_shapes;
allow_fp32_relax_to_fp16_ = allow_fp32_relax_to_fp16; allow_fp32_relax_to_fp16_ = allow_fp32_relax_to_fp16;
apply_delegate_ = apply_delegate; apply_delegate_ = apply_delegate;
...@@ -203,7 +204,7 @@ void SingleOpModel::BuildInterpreter(std::vector<std::vector<int>> input_shapes, ...@@ -203,7 +204,7 @@ void SingleOpModel::BuildInterpreter(std::vector<std::vector<int>> input_shapes,
uint8_t* buffer_pointer = builder_.GetBufferPointer(); uint8_t* buffer_pointer = builder_.GetBufferPointer();
UpdateOpVersion(buffer_pointer); UpdateOpVersion(buffer_pointer);
bool use_simple_allocator = use_simple_allocator |=
tflite::KernelTestDelegateProviders::Get()->ConstParams().Get<bool>( tflite::KernelTestDelegateProviders::Get()->ConstParams().Get<bool>(
tflite::KernelTestDelegateProviders::kUseSimpleAllocator); tflite::KernelTestDelegateProviders::kUseSimpleAllocator);
...@@ -288,11 +289,12 @@ TfLiteStatus SingleOpModel::ApplyDelegate() { ...@@ -288,11 +289,12 @@ TfLiteStatus SingleOpModel::ApplyDelegate() {
TfLiteStatus SingleOpModel::Invoke() { return interpreter_->Invoke(); } TfLiteStatus SingleOpModel::Invoke() { return interpreter_->Invoke(); }
void SingleOpModel::BuildInterpreter( void SingleOpModel::BuildInterpreter(std::vector<std::vector<int>> input_shapes,
std::vector<std::vector<int>> input_shapes) { bool use_simple_allocator) {
BuildInterpreter(input_shapes, /*num_threads=*/-1, BuildInterpreter(input_shapes, /*num_threads=*/-1,
/*allow_fp32_relax_to_fp16=*/false, /*allow_fp32_relax_to_fp16=*/false,
/*apply_delegate=*/true, /*allocate_and_delegate=*/true); /*apply_delegate=*/true, /*allocate_and_delegate=*/true,
use_simple_allocator);
} }
// static // static
......
...@@ -538,9 +538,11 @@ class SingleOpModel { ...@@ -538,9 +538,11 @@ class SingleOpModel {
// `apply_delegate` is ignored. // `apply_delegate` is ignored.
void BuildInterpreter(std::vector<std::vector<int>> input_shapes, void BuildInterpreter(std::vector<std::vector<int>> input_shapes,
int num_threads, bool allow_fp32_relax_to_fp16, int num_threads, bool allow_fp32_relax_to_fp16,
bool apply_delegate, bool allocate_and_delegate = true); bool apply_delegate, bool allocate_and_delegate = true,
bool use_simple_allocator = false);
void BuildInterpreter(std::vector<std::vector<int>> input_shapes); void BuildInterpreter(std::vector<std::vector<int>> input_shapes,
bool use_simple_allocator = false);
// Executes inference and return status code. // Executes inference and return status code.
TfLiteStatus Invoke(); TfLiteStatus Invoke();
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册