提交 b0da1232 编写于 作者: A Alan Kelly 提交者: TensorFlow Release Automation

Fix strided slice bug where at least one element is always written to the output.

a[1:1:1] should be empty, but one element is written.

PiperOrigin-RevId: 482436216
上级 115d1cb7
......@@ -1073,7 +1073,6 @@ cc_test(
srcs = [
"strided_slice_logic_test.cc",
],
shard_count = 4,
deps = [
":strided_slice_logic",
"@com_google_googletest//:gtest_main",
......
......@@ -4673,108 +4673,6 @@ inline void Slice(const tflite::SliceParams& op_params,
return Slice(op_params, input_shape, output_shape, &writer);
}
// Note: This implementation is only optimized for the case where the inner
// stride == 1.
//
// Writer-based StridedSlice over shapes extended to exactly 5 dimensions.
// Iterates the (reversed, padded) start/stop/stride ranges per axis and emits
// the selected elements through `writer`. The loops fold each axis index into
// a single linear offset by pre-multiplying with the sizes of the remaining
// inner dimensions, so `offset_3 + start_4` indexes directly into the data.
// NOTE(review): when the inner stride is 1 the innermost loop degenerates to
// a single WriteN of `stop_4 - start_4` contiguous elements; the `len > 0`
// guard is what prevents writing when the inner slice is empty — TODO confirm
// the outer LoopCondition checks also reject empty outer ranges (this is the
// code path the enclosing commit replaces).
template <typename T>
inline void StridedSlice(const tflite::StridedSliceParams& op_params,
const RuntimeShape& unextended_input_shape,
const RuntimeShape& unextended_output_shape,
SequentialTensorWriter<T>* writer) {
using strided_slice::LoopCondition;
using strided_slice::StartForAxis;
using strided_slice::StopForAxis;
ruy::profiler::ScopeLabel label("StridedSlice");
// Note that the output_shape is not used herein.
tflite::StridedSliceParams params_copy = op_params;
// Implementation supports at most 5 dimensions; both shapes are padded up.
TFLITE_DCHECK_LE(unextended_input_shape.DimensionsCount(), 5);
TFLITE_DCHECK_LE(unextended_output_shape.DimensionsCount(), 5);
const RuntimeShape input_shape =
RuntimeShape::ExtendedShape(5, unextended_input_shape);
const RuntimeShape output_shape =
RuntimeShape::ExtendedShape(5, unextended_output_shape);
// Reverse and pad to 5 dimensions because that is what the runtime code
// requires (ie. all shapes must be 5D and are given backwards).
strided_slice::StridedSlicePadIndices(&params_copy, 5);
// Resolve the concrete [start, stop) range for each of the 5 axes; the stop
// computation depends on the resolved start (e.g. for shrink_axis handling).
const int start_0 = StartForAxis(params_copy, input_shape, 0);
const int stop_0 = StopForAxis(params_copy, input_shape, 0, start_0);
const int start_1 = StartForAxis(params_copy, input_shape, 1);
const int stop_1 = StopForAxis(params_copy, input_shape, 1, start_1);
const int start_2 = StartForAxis(params_copy, input_shape, 2);
const int stop_2 = StopForAxis(params_copy, input_shape, 2, start_2);
const int start_3 = StartForAxis(params_copy, input_shape, 3);
const int stop_3 = StopForAxis(params_copy, input_shape, 3, start_3);
const int start_4 = StartForAxis(params_copy, input_shape, 4);
const int stop_4 = StopForAxis(params_copy, input_shape, 4, start_4);
// Inner stride of 1 lets us copy whole contiguous runs with WriteN.
const bool inner_stride_is_1 = params_copy.strides[4] == 1;
// Each loop level carries `offset`, `end`, and `step` already scaled by the
// product of the inner dimension sizes, so no per-element multiply is needed.
for (int offset_0 = start_0 * input_shape.Dims(1),
end_0 = stop_0 * input_shape.Dims(1),
step_0 = params_copy.strides[0] * input_shape.Dims(1);
!LoopCondition(offset_0, end_0, params_copy.strides[0]);
offset_0 += step_0) {
for (int offset_1 = (offset_0 + start_1) * input_shape.Dims(2),
end_1 = (offset_0 + stop_1) * input_shape.Dims(2),
step_1 = params_copy.strides[1] * input_shape.Dims(2);
!LoopCondition(offset_1, end_1, params_copy.strides[1]);
offset_1 += step_1) {
for (int offset_2 = (offset_1 + start_2) * input_shape.Dims(3),
end_2 = (offset_1 + stop_2) * input_shape.Dims(3),
step_2 = params_copy.strides[2] * input_shape.Dims(3);
!LoopCondition(offset_2, end_2, params_copy.strides[2]);
offset_2 += step_2) {
for (int offset_3 = (offset_2 + start_3) * input_shape.Dims(4),
end_3 = (offset_2 + stop_3) * input_shape.Dims(4),
step_3 = params_copy.strides[3] * input_shape.Dims(4);
!LoopCondition(offset_3, end_3, params_copy.strides[3]);
offset_3 += step_3) {
// When the stride is 1, the inner loop is equivalent to the
// optimized slice inner loop. Otherwise, it is identical to the
// strided_slice reference implementation inner loop.
if (inner_stride_is_1) {
const int len = stop_4 - start_4;
// Guard against an empty inner slice (len <= 0): write nothing.
if (len > 0) {
writer->WriteN(offset_3 + start_4, len);
}
} else {
for (int offset_4 = offset_3 + start_4, end_4 = offset_3 + stop_4;
!LoopCondition(offset_4, end_4, params_copy.strides[4]);
offset_4 += params_copy.strides[4]) {
writer->Write(offset_4);
}
}
}
}
}
}
}
// Convenience overload: wraps raw input/output buffers in a
// SequentialTensorWriter and forwards to the writer-based StridedSlice above.
template <typename T>
inline void StridedSlice(const tflite::StridedSliceParams& op_params,
const RuntimeShape& unextended_input_shape,
const T* input_data,
const RuntimeShape& unextended_output_shape,
T* output_data) {
SequentialTensorWriter<T> writer(input_data, output_data);
StridedSlice<T>(op_params, unextended_input_shape, unextended_output_shape,
&writer);
}
// Convenience overload for TfLiteTensor inputs/outputs (e.g. string tensors):
// wraps the tensors in a SequentialTensorWriter and forwards to the
// writer-based StridedSlice implementation.
template <typename T>
inline void StridedSlice(const tflite::StridedSliceParams& op_params,
const RuntimeShape& unextended_input_shape,
const TfLiteTensor* input,
const RuntimeShape& unextended_output_shape,
TfLiteTensor* output) {
SequentialTensorWriter<T> writer(input, output);
StridedSlice<T>(op_params, unextended_input_shape, unextended_output_shape,
&writer);
}
template <typename T>
void Minimum(const RuntimeShape& input1_shape, const T* input1_data,
const T* input2_data, const RuntimeShape& output_shape,
......
......@@ -31,10 +31,6 @@ inline void StridedSlice(const tflite::StridedSliceParams& op_params,
const RuntimeShape& unextended_input_shape,
const RuntimeShape& unextended_output_shape,
SequentialTensorWriter<T>* writer) {
using strided_slice::LoopCondition;
using strided_slice::StartForAxis;
using strided_slice::StopForAxis;
ruy::profiler::ScopeLabel label("StridedSlice");
// Note that the output_shape is not used herein.
......@@ -51,41 +47,67 @@ inline void StridedSlice(const tflite::StridedSliceParams& op_params,
// requires (ie. all shapes must be 5D and are given backwards).
strided_slice::StridedSlicePadIndices(&params_copy, 5);
const int start_0 = StartForAxis(params_copy, input_shape, 0);
const int stop_0 = StopForAxis(params_copy, input_shape, 0, start_0);
const int start_1 = StartForAxis(params_copy, input_shape, 1);
const int stop_1 = StopForAxis(params_copy, input_shape, 1, start_1);
const int start_2 = StartForAxis(params_copy, input_shape, 2);
const int stop_2 = StopForAxis(params_copy, input_shape, 2, start_2);
const int start_3 = StartForAxis(params_copy, input_shape, 3);
const int stop_3 = StopForAxis(params_copy, input_shape, 3, start_3);
const int start_4 = StartForAxis(params_copy, input_shape, 4);
const int stop_4 = StopForAxis(params_copy, input_shape, 4, start_4);
for (int offset_0 = start_0 * input_shape.Dims(1),
end_0 = stop_0 * input_shape.Dims(1),
step_0 = params_copy.strides[0] * input_shape.Dims(1);
!LoopCondition(offset_0, end_0, params_copy.strides[0]);
offset_0 += step_0) {
for (int offset_1 = (offset_0 + start_1) * input_shape.Dims(2),
end_1 = (offset_0 + stop_1) * input_shape.Dims(2),
step_1 = params_copy.strides[1] * input_shape.Dims(2);
!LoopCondition(offset_1, end_1, params_copy.strides[1]);
offset_1 += step_1) {
for (int offset_2 = (offset_1 + start_2) * input_shape.Dims(3),
end_2 = (offset_1 + stop_2) * input_shape.Dims(3),
step_2 = params_copy.strides[2] * input_shape.Dims(3);
!LoopCondition(offset_2, end_2, params_copy.strides[2]);
offset_2 += step_2) {
for (int offset_3 = (offset_2 + start_3) * input_shape.Dims(4),
end_3 = (offset_2 + stop_3) * input_shape.Dims(4),
step_3 = params_copy.strides[3] * input_shape.Dims(4);
!LoopCondition(offset_3, end_3, params_copy.strides[3]);
offset_3 += step_3) {
for (int offset_4 = offset_3 + start_4, end_4 = offset_3 + stop_4;
!LoopCondition(offset_4, end_4, params_copy.strides[4]);
offset_4 += params_copy.strides[4]) {
writer->Write(offset_4);
const int start_0 =
strided_slice::StridedSliceStartForAxis(params_copy, input_shape, 0);
const int stop_0 = strided_slice::StridedSliceEndForAxis(
params_copy, input_shape, 0, start_0);
const int start_1 =
strided_slice::StridedSliceStartForAxis(params_copy, input_shape, 1);
const int stop_1 = strided_slice::StridedSliceEndForAxis(
params_copy, input_shape, 1, start_1);
const int start_2 =
strided_slice::StridedSliceStartForAxis(params_copy, input_shape, 2);
const int stop_2 = strided_slice::StridedSliceEndForAxis(
params_copy, input_shape, 2, start_2);
const int start_3 =
strided_slice::StridedSliceStartForAxis(params_copy, input_shape, 3);
const int stop_3 = strided_slice::StridedSliceEndForAxis(
params_copy, input_shape, 3, start_3);
const int start_4 =
strided_slice::StridedSliceStartForAxis(params_copy, input_shape, 4);
const int stop_4 = strided_slice::StridedSliceEndForAxis(
params_copy, input_shape, 4, start_4);
auto lc = [&](int end, int stride, int index) {
if (stride < 0) {
return index > end;
} else {
return index < end;
}
};
const int* shape = input_shape.DimsData();
const int* stride = params_copy.strides;
const bool inner_stride_is_1 = params_copy.strides[4] == 1;
for (int offset_0 = start_0; lc(stop_0, stride[0], offset_0);
offset_0 += stride[0]) {
for (int offset_1 = start_1; lc(stop_1, stride[1], offset_1);
offset_1 += stride[1]) {
for (int offset_2 = start_2; lc(stop_2, stride[2], offset_2);
offset_2 += stride[2]) {
for (int offset_3 = start_3; lc(stop_3, stride[3], offset_3);
offset_3 += stride[3]) {
// When the stride is 1, the inner loop is equivalent to the
// optimized slice inner loop. Otherwise, it is identical to the
// strided_slice reference implementation inner loop.
if (inner_stride_is_1) {
const int len = stop_4 - start_4;
int index = start_4 + offset_3 * shape[4] +
offset_2 * shape[3] * shape[4] +
offset_1 * shape[2] * shape[3] * shape[4] +
offset_0 * shape[1] * shape[2] * shape[3] * shape[4];
if (len > 0) {
writer->WriteN(index, len);
}
} else {
for (int offset_4 = start_4; lc(stop_4, stride[4], offset_4);
offset_4 += stride[4]) {
int index = offset_4 + offset_3 * shape[4] +
offset_2 * shape[3] * shape[4] +
offset_1 * shape[2] * shape[3] * shape[4] +
offset_0 * shape[1] * shape[2] * shape[3] * shape[4];
writer->Write(index);
}
}
}
}
......
......@@ -69,6 +69,69 @@ inline void StridedSlicePadIndices(tflite::StridedSliceParams* p,
p->strides_count = dim_count;
}
// Return the index for the first element along that axis. This index will be a
// positive integer between [0, axis_size] (or [-1, axis_size -1] if stride < 0)
// that can be used to index directly into the data.
//
// Resolution order: (1) negative indices are wrapped once by adding axis_size;
// (2) the result is clamped to the valid range for the stride direction —
// [0, axis_size] for forward strides so that start == stop can denote an
// empty slice, [-1, axis_size - 1] for backward strides where -1 is the
// exclusive lower bound; (3) begin_mask overrides everything and selects the
// directional default (first element forwards, last element backwards).
inline int StridedSliceStartForAxis(const tflite::StridedSliceParams& params,
const RuntimeShape& input_shape,
int32_t axis) {
const int32_t axis_size = input_shape.Dims(axis);
int32_t start = params.start_indices[axis];
const int32_t stride = params.strides[axis];
// Non-zero iff the begin index for this axis should be ignored.
const int32_t begin_mask = (params.begin_mask & 1 << axis);
if (start < 0) {
// Wrap a negative index to count from the end of the axis.
start += axis_size;
}
if (stride > 0) {
start = Clamp(start, 0, axis_size);
} else {
start = Clamp(start, -1, axis_size - 1);
}
if (begin_mask) {
// Masked begin: use the full-range default for the stride direction.
if (stride > 0) {
start = 0;
} else {
start = axis_size - 1;
}
}
return start;
}
// Return the exclusive stop index along `axis`, given the already-resolved
// `start` index. With shrink_axis set, the slice is exactly one element:
// the stop is start + 1 (or start itself when start is already out of
// bounds, yielding an empty slice). Otherwise the stop index is wrapped,
// clamped to the stride-direction range — [0, axis_size] forwards,
// [-1, axis_size - 1] backwards — and end_mask replaces it with the
// directional extreme (axis_size forwards, -1 backwards).
inline int StridedSliceEndForAxis(const tflite::StridedSliceParams& params,
const RuntimeShape& input_shape, int axis,
int start) {
const auto shrink_axis_mask = params.shrink_axis_mask;
const bool shrink_axis = shrink_axis_mask & (1 << axis);
const int axis_size = input_shape.Dims(axis);
if (shrink_axis) {
// Shrinking selects a single element; an OOB start produces an empty
// range instead of reading past the end.
if (start >= axis_size) {
return start;
} else {
return start + 1;
}
}
const auto* indices = params.stop_indices;
int end = indices[axis];
const int32_t stride = params.strides[axis];
// Non-zero iff the stop index for this axis should be ignored.
const int32_t end_mask = (params.end_mask & 1 << axis);
if (end < 0) {
// Wrap a negative index to count from the end of the axis.
end += axis_size;
}
if (stride > 0) {
end = Clamp(end, 0, axis_size);
} else {
end = Clamp(end, -1, axis_size - 1);
}
if (end_mask) {
// Masked end: run to the directional extreme of the axis.
if (stride > 0) {
end = axis_size;
} else {
end = -1;
}
}
return end;
}
// Return the index for the first element along that axis. This index will be a
// positive integer between [0, axis_size] (or [-1, axis_size -1] if stride < 0)
// that can be used to index directly into the data.
......
......@@ -78,5 +78,119 @@ TEST(RunStridedSlicePadIndices, Pad3) {
);
}
// A negative start index further out of bounds than -axis_size wraps to a
// negative value and is then clamped to 0 for a positive stride.
TEST(StridedSliceStartForAxis, NegativeOOBIndex) {
StridedSliceParams params{};
params.begin_mask = 0;
params.end_mask = 0;
params.start_indices[0] = -11;
params.strides[0] = 1;
int start = strided_slice::StridedSliceStartForAxis(
params, RuntimeShape({10}), /*axis=*/0);
EXPECT_EQ(start, 0);
}
// start == -axis_size wraps exactly to the first element (index 0).
TEST(StridedSliceStartForAxis, NegativeOneTheBoundaryIndex) {
StridedSliceParams params{};
params.begin_mask = 0;
params.end_mask = 0;
params.start_indices[0] = -10;
params.strides[0] = 1;
int start = strided_slice::StridedSliceStartForAxis(
params, RuntimeShape({10}), /*axis=*/0);
EXPECT_EQ(start, 0);
}
// An in-bounds negative start index wraps to axis_size + start (here 1).
TEST(StridedSliceStartForAxis, NegativeWithinBoundsIndex) {
StridedSliceParams params{};
params.begin_mask = 0;
params.end_mask = 0;
params.start_indices[0] = -9;
params.strides[0] = 1;
int start = strided_slice::StridedSliceStartForAxis(
params, RuntimeShape({10}), /*axis=*/0);
EXPECT_EQ(start, 1);
}
// start == -1 wraps to the last element of the axis.
TEST(StridedSliceStartForAxis, MinusOneIndex) {
StridedSliceParams params{};
params.begin_mask = 0;
params.end_mask = 0;
params.start_indices[0] = -1;
params.strides[0] = 1;
int start = strided_slice::StridedSliceStartForAxis(
params, RuntimeShape({10}), /*axis=*/0);
EXPECT_EQ(start, 9);
}
// A zero start index passes through unchanged.
TEST(StridedSliceStartForAxis, ZeroIndex) {
StridedSliceParams params{};
params.begin_mask = 0;
params.end_mask = 0;
params.start_indices[0] = 0;
params.strides[0] = 1;
int start = strided_slice::StridedSliceStartForAxis(
params, RuntimeShape({10}), /*axis=*/0);
EXPECT_EQ(start, 0);
}
// An in-bounds positive start index passes through unchanged.
TEST(StridedSliceStartForAxis, OneIndex) {
StridedSliceParams params{};
params.begin_mask = 0;
params.end_mask = 0;
params.start_indices[0] = 1;
params.strides[0] = 1;
int start = strided_slice::StridedSliceStartForAxis(
params, RuntimeShape({10}), /*axis=*/0);
EXPECT_EQ(start, 1);
}
// The last valid element index (axis_size - 1) passes through unchanged.
TEST(StridedSliceStartForAxis, PositiveBoundaryIndex) {
StridedSliceParams params{};
params.begin_mask = 0;
params.end_mask = 0;
params.start_indices[0] = 9;
params.strides[0] = 1;
int start = strided_slice::StridedSliceStartForAxis(
params, RuntimeShape({10}), /*axis=*/0);
EXPECT_EQ(start, 9);
}
// start == axis_size is allowed for positive strides (empty slice): the
// clamp range is [0, axis_size], inclusive of axis_size.
TEST(StridedSliceStartForAxis, PositiveOOBIndexSizeofArray) {
StridedSliceParams params{};
params.begin_mask = 0;
params.end_mask = 0;
params.start_indices[0] = 10;
params.strides[0] = 1;
int start = strided_slice::StridedSliceStartForAxis(
params, RuntimeShape({10}), /*axis=*/0);
EXPECT_EQ(start, 10);
}
// A start beyond axis_size is clamped down to axis_size.
TEST(StridedSliceStartForAxis, PositiveOOBIndex) {
StridedSliceParams params{};
params.begin_mask = 0;
params.end_mask = 0;
params.start_indices[0] = 11;
params.strides[0] = 1;
int start = strided_slice::StridedSliceStartForAxis(
params, RuntimeShape({10}), /*axis=*/0);
EXPECT_EQ(start, 10);
}
// Negative stride on a size-4 axis: start 5 clamps to 3 (last element) and
// the stop index 2 is the exclusive lower bound of the backward walk.
TEST(StridedSliceStartForAxis, TenFourMinus1) {
StridedSliceParams params{};
params.begin_mask = 0;
params.end_mask = 0;
params.start_indices[0] = 5;
params.stop_indices[0] = 2;
params.strides[0] = -1;
int start = strided_slice::StridedSliceStartForAxis(params, RuntimeShape({4}),
/*axis=*/0);
int stop = strided_slice::StridedSliceEndForAxis(params, RuntimeShape({4}),
/*axis=*/0, start);
EXPECT_EQ(start, 3);
EXPECT_EQ(stop, 2);
}
} // namespace
} // namespace tflite
......@@ -24,7 +24,6 @@ limitations under the License.
#include "tensorflow/lite/c/builtin_op_data.h"
#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/kernels/internal/compatibility.h"
#include "tensorflow/lite/kernels/internal/optimized/optimized_ops.h"
#include "tensorflow/lite/kernels/internal/strided_slice_logic.h"
#include "tensorflow/lite/kernels/internal/tensor.h"
#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
......@@ -70,7 +69,7 @@ struct StridedSliceContext {
};
StridedSliceParams BuildStridedSliceParams(StridedSliceContext* op_context) {
StridedSliceParams op_params;
StridedSliceParams op_params{};
// The ellipsis_mask and new_axis_mask in op_params are not used. Those masks
// are processed here to update begin_mask, end_mask and the index range.
......@@ -196,9 +195,9 @@ TfLiteStatus ResizeOutputTensor(TfLiteContext* context,
int32_t stride = op_params.strides[idx];
TF_LITE_ENSURE_MSG(context, stride != 0, "stride value has to be non-zero");
int32_t begin = ::tflite::strided_slice::StartForAxis(
int32_t begin = ::tflite::strided_slice::StridedSliceStartForAxis(
op_params, effective_input_shape, idx);
int32_t end = ::tflite::strided_slice::StopForAxis(
int32_t end = ::tflite::strided_slice::StridedSliceEndForAxis(
op_params, effective_input_shape, idx, begin);
// When shrinking an axis, the end position does not matter (and can be
......@@ -272,43 +271,46 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
}
StridedSliceParams op_params = BuildStridedSliceParams(&op_context);
#define TF_LITE_STRIDED_SLICE(data_type) \
{ \
if (kernel_type == kGenericOptimized) { \
optimized_ops::StridedSlice<data_type>( \
op_params, op_context.effective_input_shape, op_context.input, \
GetTensorShape(op_context.output), op_context.output); \
} else { \
reference_ops::StridedSlice<data_type>( \
op_params, op_context.effective_input_shape, op_context.input, \
GetTensorShape(op_context.output), op_context.output); \
} \
}
switch (op_context.input->type) {
case kTfLiteFloat32:
TF_LITE_STRIDED_SLICE(float);
reference_ops::StridedSlice<float>(
op_params, op_context.effective_input_shape, op_context.input,
GetTensorShape(op_context.output), op_context.output);
break;
case kTfLiteInt32:
TF_LITE_STRIDED_SLICE(int32_t);
reference_ops::StridedSlice<int32_t>(
op_params, op_context.effective_input_shape, op_context.input,
GetTensorShape(op_context.output), op_context.output);
break;
case kTfLiteInt64:
TF_LITE_STRIDED_SLICE(int64_t);
reference_ops::StridedSlice<int64_t>(
op_params, op_context.effective_input_shape, op_context.input,
GetTensorShape(op_context.output), op_context.output);
break;
case kTfLiteUInt8:
TF_LITE_STRIDED_SLICE(uint8_t);
reference_ops::StridedSlice<uint8_t>(
op_params, op_context.effective_input_shape, op_context.input,
GetTensorShape(op_context.output), op_context.output);
break;
case kTfLiteInt8:
TF_LITE_STRIDED_SLICE(int8_t);
reference_ops::StridedSlice<int8_t>(
op_params, op_context.effective_input_shape, op_context.input,
GetTensorShape(op_context.output), op_context.output);
break;
case kTfLiteInt16:
TF_LITE_STRIDED_SLICE(int16_t);
reference_ops::StridedSlice<int16_t>(
op_params, op_context.effective_input_shape, op_context.input,
GetTensorShape(op_context.output), op_context.output);
break;
case kTfLiteBool:
TF_LITE_STRIDED_SLICE(bool);
reference_ops::StridedSlice<bool>(
op_params, op_context.effective_input_shape, op_context.input,
GetTensorShape(op_context.output), op_context.output);
break;
case kTfLiteString:
TF_LITE_STRIDED_SLICE(string);
reference_ops::StridedSlice<string>(
op_params, op_context.effective_input_shape, op_context.input,
GetTensorShape(op_context.output), op_context.output);
break;
default:
TF_LITE_KERNEL_LOG(context,
......
......@@ -27,6 +27,7 @@ namespace tflite {
namespace {
using ::testing::ElementsAreArray;
using ::testing::IsEmpty;
template <typename input_type>
class StridedSliceOpModel : public SingleOpModel {
......@@ -36,7 +37,7 @@ class StridedSliceOpModel : public SingleOpModel {
std::initializer_list<int> end_shape,
std::initializer_list<int> strides_shape, int begin_mask,
int end_mask, int ellipsis_mask, int new_axis_mask,
int shrink_axis_mask) {
int shrink_axis_mask, bool use_simple_allocator = true) {
input_ = AddInput(GetTensorType<input_type>());
begin_ = AddInput(TensorType_INT32);
end_ = AddInput(TensorType_INT32);
......@@ -47,7 +48,8 @@ class StridedSliceOpModel : public SingleOpModel {
CreateStridedSliceOptions(builder_, begin_mask, end_mask, ellipsis_mask,
new_axis_mask, shrink_axis_mask)
.Union());
BuildInterpreter({input_shape, begin_shape, end_shape, strides_shape});
BuildInterpreter({input_shape, begin_shape, end_shape, strides_shape},
use_simple_allocator);
}
void SetInput(std::initializer_list<input_type> data) {
......@@ -670,7 +672,7 @@ TYPED_TEST(StridedSliceOpTest, In3D_SmallBeginWithhrinkAxis1) {
EXPECT_THAT(m.GetOutput(), ElementsAreArray({1, 2, 3, 4, 5, 6}));
}
TYPED_TEST(StridedSliceOpTest, In3D_BackwardSmallBegin) {
TYPED_TEST(StridedSliceOpTest, In3D_BackwardSmallBeginEndMask) {
StridedSliceOpModel<TypeParam> m({1, 1, 2}, {1}, {1}, {1}, 0, 1, 0, 0, 0);
m.SetInput({1, 2});
m.SetBegin({1});
......@@ -680,6 +682,16 @@ TYPED_TEST(StridedSliceOpTest, In3D_BackwardSmallBegin) {
EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({0, 1, 2}));
}
// begin=1, end=0, stride=1 with no masks on a size-1 axis: the forward range
// is empty, so the output's first dimension must be 0 (no elements written).
TYPED_TEST(StridedSliceOpTest, In3D_BackwardSmallBegin) {
StridedSliceOpModel<TypeParam> m({1, 1, 2}, {1}, {1}, {1}, 0, 0, 0, 0, 0);
m.SetInput({1, 2});
m.SetBegin({1});
m.SetEnd({0});
m.SetStrides({1});
ASSERT_EQ(m.Invoke(), kTfLiteOk);
EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({0, 1, 2}));
}
TYPED_TEST(StridedSliceOpTest, In3D_Backward) {
StridedSliceOpModel<TypeParam> m({1, 1, 2}, {3}, {3}, {3}, 6, 7, 0, 0, 0);
m.SetInput({1, 2});
......@@ -854,5 +866,86 @@ TYPED_TEST(StridedSliceOpTest, NoInfiniteLoop) {
ASSERT_EQ(m.Invoke(), kTfLiteOk);
}
// a[-3:-4:-1] on [1,2,3,4]: one backward step from index 1, yielding {2}.
TYPED_TEST(StridedSliceOpTest, MinusThreeMinusFourMinusOne) {
StridedSliceOpModel<TypeParam> m({4}, {1}, {1}, {1}, 0, 0, 0, 0, 0);
m.SetInput({1, 2, 3, 4});
m.SetBegin({-3});
m.SetEnd({-4});
m.SetStrides({-1});
ASSERT_EQ(m.Invoke(), kTfLiteOk);
EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1}));
EXPECT_THAT(m.GetOutput(), ElementsAreArray({2}));
}
// a[-4:-3:1] on [1,2,3,4]: one forward step from index 0, yielding {1}.
TYPED_TEST(StridedSliceOpTest, MinusFourMinusThreeOne) {
StridedSliceOpModel<TypeParam> m({4}, {1}, {1}, {1}, 0, 0, 0, 0, 0);
m.SetInput({1, 2, 3, 4});
m.SetBegin({-4});
m.SetEnd({-3});
m.SetStrides({1});
ASSERT_EQ(m.Invoke(), kTfLiteOk);
EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1}));
EXPECT_THAT(m.GetOutput(), ElementsAreArray({1}));
}
// Regression test for the commit's headline bug: a[1:1:1] must be empty —
// previously one element was still written.
TYPED_TEST(StridedSliceOpTest, OneOneOne) {
StridedSliceOpModel<TypeParam> m({1}, {1}, {1}, {1}, 0, 0, 0, 0, 0);
m.SetInput({2});
m.SetBegin({1});
m.SetEnd({1});
m.SetStrides({1});
ASSERT_EQ(m.Invoke(), kTfLiteOk);
EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({0}));
}
// shrink_axis on an in-bounds start selects exactly one element and removes
// the dimension (scalar output shape).
TYPED_TEST(StridedSliceOpTest, OneOneOneShrinkAxis) {
StridedSliceOpModel<TypeParam> m({3}, {1}, {1}, {1}, 0, 0, 0, 0, 1);
m.SetInput({1, 2, 3});
m.SetBegin({1});
m.SetEnd({1});
m.SetStrides({1});
ASSERT_EQ(m.Invoke(), kTfLiteOk);
EXPECT_THAT(m.GetOutputShape(), IsEmpty());
EXPECT_THAT(m.GetOutput(), ElementsAreArray({2}));
}
// shrink_axis with an out-of-bounds start must not read or write anything.
TYPED_TEST(StridedSliceOpTest, OneOneOneShrinkAxisOOB) {
StridedSliceOpModel<TypeParam> m({1}, {1}, {1}, {1}, 0, 0, 0, 0, 1);
m.SetInput({2});
m.SetBegin({1});
m.SetEnd({1});
m.SetStrides({1});
ASSERT_EQ(m.Invoke(), kTfLiteOk);
EXPECT_THAT(m.GetOutputShape(), IsEmpty());
}
// Out-of-bounds shrink-axis slice on a size-1 axis: must succeed and produce
// an empty (scalar) output without touching input data.
// NOTE(review): no SetInput here — presumably intentional since nothing is
// read; confirm the model tolerates uninitialized input.
TYPED_TEST(StridedSliceOpTest, OutOfBounds) {
StridedSliceOpModel<TypeParam> m({1}, {1}, {1}, {1}, 0, 0, 0, 0, 1);
m.SetBegin({1});
m.SetEnd({2});
m.SetStrides({1});
ASSERT_EQ(m.Invoke(), kTfLiteOk);
EXPECT_THAT(m.GetOutputShape(), IsEmpty());
}
// A stride larger than the axis combined with an OOB begin must also yield an
// empty result rather than stepping past the end.
TYPED_TEST(StridedSliceOpTest, StrideOutOfBounds) {
StridedSliceOpModel<TypeParam> m({1}, {1}, {1}, {1}, 0, 0, 0, 0, 1);
m.SetBegin({1});
m.SetEnd({4});
m.SetStrides({7});
ASSERT_EQ(m.Invoke(), kTfLiteOk);
EXPECT_THAT(m.GetOutputShape(), IsEmpty());
}
// end_mask on axis 1 (0b10) with a negative stride: the masked end becomes -1,
// so the second axis is traversed fully in reverse ({3,2,1} then {6,5,4}).
TYPED_TEST(StridedSliceOpTest, NegEndMask) {
StridedSliceOpModel<TypeParam> m({2, 3}, {2}, {2}, {2}, 0, 0b10, 0, 0, 0);
m.SetInput({1, 2, 3, 4, 5, 6});
m.SetBegin({0, -1});
m.SetEnd({2, -3});
m.SetStrides({1, -1});
ASSERT_EQ(m.Invoke(), kTfLiteOk);
EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({2, 3}));
EXPECT_THAT(m.GetOutput(), ElementsAreArray({3, 2, 1, 6, 5, 4}));
}
} // namespace
} // namespace tflite
......@@ -178,7 +178,8 @@ void SingleOpModel::BuildInterpreter(std::vector<std::vector<int>> input_shapes,
int num_threads,
bool allow_fp32_relax_to_fp16,
bool apply_delegate,
bool allocate_and_delegate) {
bool allocate_and_delegate,
bool use_simple_allocator) {
input_shapes_ = input_shapes;
allow_fp32_relax_to_fp16_ = allow_fp32_relax_to_fp16;
apply_delegate_ = apply_delegate;
......@@ -203,7 +204,7 @@ void SingleOpModel::BuildInterpreter(std::vector<std::vector<int>> input_shapes,
uint8_t* buffer_pointer = builder_.GetBufferPointer();
UpdateOpVersion(buffer_pointer);
bool use_simple_allocator =
use_simple_allocator |=
tflite::KernelTestDelegateProviders::Get()->ConstParams().Get<bool>(
tflite::KernelTestDelegateProviders::kUseSimpleAllocator);
......@@ -288,11 +289,12 @@ TfLiteStatus SingleOpModel::ApplyDelegate() {
TfLiteStatus SingleOpModel::Invoke() { return interpreter_->Invoke(); }
void SingleOpModel::BuildInterpreter(
std::vector<std::vector<int>> input_shapes) {
void SingleOpModel::BuildInterpreter(std::vector<std::vector<int>> input_shapes,
bool use_simple_allocator) {
BuildInterpreter(input_shapes, /*num_threads=*/-1,
/*allow_fp32_relax_to_fp16=*/false,
/*apply_delegate=*/true, /*allocate_and_delegate=*/true);
/*apply_delegate=*/true, /*allocate_and_delegate=*/true,
use_simple_allocator);
}
// static
......
......@@ -538,9 +538,11 @@ class SingleOpModel {
// `apply_delegate` is ignored.
void BuildInterpreter(std::vector<std::vector<int>> input_shapes,
int num_threads, bool allow_fp32_relax_to_fp16,
bool apply_delegate, bool allocate_and_delegate = true);
bool apply_delegate, bool allocate_and_delegate = true,
bool use_simple_allocator = false);
void BuildInterpreter(std::vector<std::vector<int>> input_shapes);
void BuildInterpreter(std::vector<std::vector<int>> input_shapes,
bool use_simple_allocator = false);
// Executes inference and return status code.
TfLiteStatus Invoke();
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册