From dd3db3d812b53efbf306342f2676c653ac1cd2f3 Mon Sep 17 00:00:00 2001 From: rsun-bdti <69477250+rsun-bdti@users.noreply.github.com> Date: Mon, 24 May 2021 15:59:47 -0700 Subject: [PATCH] Port TFL kernel Gather to TFL Micro (#97) * Port TFL kernel Gather to TFL Micro * Remove TODO comment in micro/kernel/gather.cc Co-authored-by: Advait Jain --- tensorflow/lite/micro/kernels/BUILD | 16 + tensorflow/lite/micro/kernels/gather.cc | 250 +++++----- tensorflow/lite/micro/kernels/gather_test.cc | 465 ++++++++++++++++++ tensorflow/lite/micro/kernels/micro_ops.h | 1 + .../lite/micro/micro_mutable_op_resolver.h | 5 + tensorflow/lite/micro/tools/make/Makefile | 2 + 6 files changed, 619 insertions(+), 120 deletions(-) create mode 100644 tensorflow/lite/micro/kernels/gather_test.cc diff --git a/tensorflow/lite/micro/kernels/BUILD b/tensorflow/lite/micro/kernels/BUILD index befc8d8f..e24f4ca1 100644 --- a/tensorflow/lite/micro/kernels/BUILD +++ b/tensorflow/lite/micro/kernels/BUILD @@ -287,6 +287,7 @@ cc_library( "floor.cc", "floor_div.cc", "floor_mod.cc", + "gather.cc", "gather_nd.cc", "if.cc", "l2norm.cc", @@ -753,6 +754,21 @@ cc_test( ], ) +cc_test( + name = "gather_test", + srcs = [ + "gather_test.cc", + ], + deps = [ + ":kernel_runner", + "//tensorflow/lite/c:common", + "//tensorflow/lite/micro:micro_utils", + "//tensorflow/lite/micro:op_resolvers", + "//tensorflow/lite/micro:test_helpers", + "//tensorflow/lite/micro/testing:micro_test", + ], +) + cc_test( name = "gather_nd_test", srcs = [ diff --git a/tensorflow/lite/micro/kernels/gather.cc b/tensorflow/lite/micro/kernels/gather.cc index 22020e55..db050626 100644 --- a/tensorflow/lite/micro/kernels/gather.cc +++ b/tensorflow/lite/micro/kernels/gather.cc @@ -12,26 +12,90 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include #include "tensorflow/lite/c/builtin_op_data.h" #include "tensorflow/lite/c/common.h" -#include "tensorflow/lite/kernels/internal/optimized/optimized_ops.h" -#include "tensorflow/lite/kernels/internal/reference/reference_ops.h" -#include "tensorflow/lite/kernels/internal/tensor.h" #include "tensorflow/lite/kernels/internal/tensor_ctypes.h" -#include "tensorflow/lite/kernels/internal/types.h" #include "tensorflow/lite/kernels/kernel_util.h" -#include "tensorflow/lite/string_util.h" +#include "tensorflow/lite/micro/kernels/kernel_util.h" +#include "tensorflow/lite/micro/micro_utils.h" namespace tflite { -namespace ops { -namespace builtin { -namespace gather { +namespace { + constexpr int kInputTensor = 0; constexpr int kInputPositions = 1; constexpr int kOutputTensor = 0; +template +TfLiteStatus Gather(const TfLiteGatherParams* params, + const TfLiteEvalTensor* input, + const TfLiteEvalTensor* coords, TfLiteEvalTensor* output) { + const InputT* input_data = tflite::micro::GetTensorData(input); + const CoordsT* coords_data = tflite::micro::GetTensorData(coords); + InputT* output_data = tflite::micro::GetTensorData(output); + const TfLiteIntArray* input_dims = input->dims; + const int input_dims_size = input_dims->size; + int axis = params->axis; + if (axis < 0) { + axis += input_dims_size; + } + TFLITE_DCHECK_GE(axis, 0); + TFLITE_DCHECK_LT(axis, input_dims_size); + + int batch_dims = params->batch_dims; + // batch_dims should be in range: [-rank(coords), rank(coords)]. + // Negative batch_dims is added with rank of coords. + const TfLiteIntArray* coords_dims = coords->dims; + const int coords_dims_size = coords_dims->size; + if (batch_dims < 0) { + batch_dims += coords_dims_size; + } + TFLITE_DCHECK_GE(batch_dims, 0); + TFLITE_DCHECK_LT(batch_dims, input_dims_size); + TFLITE_DCHECK_LE(batch_dims, coords_dims_size); + TFLITE_DCHECK_GE(axis, batch_dims); + for (int i = 0; i < batch_dims; ++i) { + TFLITE_DCHECK_EQ(input_dims->data[i], coords_dims->data[i]); + } + + const int axis_size = input_dims->data[axis]; + + int batch_size = 1; + for (int i = 0; i < batch_dims; ++i) { + batch_size *= input_dims->data[i]; + } + int outer_size = 1; + for (int i = batch_dims; i < axis; ++i) { + outer_size *= input_dims->data[i]; + } + int inner_size = 1; + for (int i = axis + 1; i < input_dims_size; ++i) { + inner_size *= input_dims->data[i]; + } + int coord_size = 1; + for (int i = batch_dims; i < coords_dims_size; ++i) { + coord_size *= coords_dims->data[i]; + } + + for (int batch = 0; batch < batch_size; ++batch) { + for (int outer = 0; outer < outer_size; ++outer) { + for (int coord = 0; coord < coord_size; ++coord) { + TFLITE_DCHECK_GE(coords_data[coord], 0); + TFLITE_DCHECK_LT(coords_data[coord], axis_size); + std::memcpy(output_data + + (((batch * outer_size) + outer) * coord_size + coord) * + inner_size, + input_data + (((batch * outer_size) + outer) * axis_size + + coords_data[batch * coord_size + coord]) * + inner_size, + sizeof(InputT) * inner_size); + } + } + } + return kTfLiteOk; +} + TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { TF_LITE_ENSURE_EQ(context, NumInputs(node), 2); TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1); @@ -40,22 +104,21 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { reinterpret_cast(node->builtin_data); const TfLiteTensor* input; TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputTensor, &input)); - const TfLiteTensor* positions; + const TfLiteTensor* coords; TF_LITE_ENSURE_OK(context, - GetInputSafe(context, node, kInputPositions, &positions)); + GetInputSafe(context, node, kInputPositions, &coords)); TfLiteTensor* output; TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, kOutputTensor, &output)); - - switch (positions->type) { - case kTfLiteInt64: + switch (coords->type) { case kTfLiteInt32: break; default: TF_LITE_KERNEL_LOG(context, "Positions of type '%s' are not supported by gather.", - TfLiteTypeGetName(positions->type)); + TfLiteTypeGetName(coords->type)); return kTfLiteError; + break; } // Assign to output the input type. @@ -64,21 +127,13 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { // Check conditions for different types. switch (input->type) { case kTfLiteFloat32: - case kTfLiteUInt8: case kTfLiteInt8: - case kTfLiteInt16: - case kTfLiteInt64: - case kTfLiteInt32: - case kTfLiteBool: break; - case kTfLiteString: { - // Only 1D input is supported. - TF_LITE_ENSURE_EQ(context, NumDimensions(input), 1); - } break; default: TF_LITE_KERNEL_LOG(context, "Type '%s' is not supported by gather.", TfLiteTypeGetName(input->type)); return kTfLiteError; + break; } int axis = params->axis; @@ -87,126 +142,81 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { } TF_LITE_ENSURE(context, 0 <= axis && axis < NumDimensions(input)); - const int num_dimensions = - NumDimensions(input) + NumDimensions(positions) - 1; - TfLiteIntArray* output_shape = TfLiteIntArrayCreate(num_dimensions); + int batch_dims = params->batch_dims; + // batch_dims should be in range: [-rank(coords), rank(coords)]. + // Negative batch_dims is added with rank of coords. + if (batch_dims < 0) { + batch_dims += NumDimensions(coords); + } + TF_LITE_ENSURE(context, batch_dims <= axis); + TF_LITE_ENSURE(context, 0 <= batch_dims && batch_dims < NumDimensions(input)); + TF_LITE_ENSURE(context, batch_dims <= NumDimensions(coords)); + for (int i = 0; i < batch_dims; ++i) { + TF_LITE_ENSURE_EQ(context, input->dims->data[i], coords->dims->data[i]); + } + + // GATHER updates the output tensor dimensions, but TfLiteTensor in the + // MicroInterpreter is a temporary allocation. We must therefore relocate the + // dims from the FlatBuffer to the persistant storage arena. + TfLiteEvalTensor* output_eval = + tflite::micro::GetEvalOutput(context, node, kOutputTensor); + TF_LITE_ENSURE_OK(context, tflite::micro::CreateWritableTensorDimsWithCopy( + context, output, output_eval)); + + TfLiteIntArray* output_shape = output->dims; + output_shape->size = + NumDimensions(input) + NumDimensions(coords) - 1 - batch_dims; int output_index = 0; for (int i = 0; i < axis; ++i) { output_shape->data[output_index++] = input->dims->data[i]; } - for (int i = 0; i < positions->dims->size; ++i) { - output_shape->data[output_index++] = positions->dims->data[i]; + for (int i = batch_dims; i < coords->dims->size; ++i) { + output_shape->data[output_index++] = coords->dims->data[i]; } for (int i = axis + 1; i < input->dims->size; ++i) { output_shape->data[output_index++] = input->dims->data[i]; } - return context->ResizeTensor(context, output, output_shape); -} - -template -TfLiteStatus Gather(const TfLiteGatherParams& params, const TfLiteTensor* input, - const TfLiteTensor* positions, TfLiteTensor* output) { - tflite::GatherParams op_params; - op_params.axis = params.axis; - optimized_ops::Gather(op_params, GetTensorShape(input), - GetTensorData(input), GetTensorShape(positions), - GetTensorData(positions), - GetTensorShape(output), GetTensorData(output)); - return kTfLiteOk; -} - -template -TfLiteStatus GatherStrings(TfLiteContext* context, const TfLiteTensor* input, - const TfLiteTensor* positions, - TfLiteTensor* output) { - DynamicBuffer buffer; - const PositionT* indexes = GetTensorData(positions); - const PositionT num_strings = GetStringCount(input); - const int num_indexes = NumElements(positions); - - for (int i = 0; i < num_indexes; ++i) { - const PositionT pos = indexes[i]; - TF_LITE_ENSURE(context, pos < num_strings); - const auto string_ref = GetString(input, pos); - buffer.AddString(string_ref.str, string_ref.len); - } - buffer.WriteToTensor(output, /*new_shape=*/nullptr); return kTfLiteOk; } TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { const auto* params = reinterpret_cast(node->builtin_data); - const TfLiteTensor* input; - TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputTensor, &input)); - const TfLiteTensor* positions; - TF_LITE_ENSURE_OK(context, - GetInputSafe(context, node, kInputPositions, &positions)); - TfLiteTensor* output; - TF_LITE_ENSURE_OK(context, - GetOutputSafe(context, node, kOutputTensor, &output)); - - if (positions->type == kTfLiteInt32) { - switch (input->type) { - case kTfLiteFloat32: - return Gather(*params, input, positions, output); - case kTfLiteUInt8: - return Gather(*params, input, positions, output); - case kTfLiteInt8: - return Gather(*params, input, positions, output); - case kTfLiteInt16: - return Gather(*params, input, positions, output); - case kTfLiteInt32: - return Gather(*params, input, positions, output); - case kTfLiteInt64: - return Gather(*params, input, positions, output); - case kTfLiteBool: - return Gather(*params, input, positions, output); - case kTfLiteString: - return GatherStrings(context, input, positions, output); - default: - TF_LITE_KERNEL_LOG(context, "Type '%s' is not supported by gather.", - TfLiteTypeGetName(input->type)); - return kTfLiteError; - } - } - if (positions->type == kTfLiteInt64) { + const TfLiteEvalTensor* input = + tflite::micro::GetEvalInput(context, node, kInputTensor); + const TfLiteEvalTensor* coords = + tflite::micro::GetEvalInput(context, node, kInputPositions); + TfLiteEvalTensor* output = + tflite::micro::GetEvalOutput(context, node, kOutputTensor); + + if (coords->type == kTfLiteInt32) { switch (input->type) { case kTfLiteFloat32: - return Gather(*params, input, positions, output); - case kTfLiteUInt8: - return Gather(*params, input, positions, output); + return Gather(params, input, coords, output); + break; case kTfLiteInt8: - return Gather(*params, input, positions, output); - case kTfLiteInt16: - return Gather(*params, input, positions, output); - case kTfLiteInt32: - return Gather(*params, input, positions, output); - case kTfLiteInt64: - return Gather(*params, input, positions, output); - case kTfLiteBool: - return Gather(*params, input, positions, output); - case kTfLiteString: - return GatherStrings(context, input, positions, output); + return Gather(params, input, coords, output); + break; default: TF_LITE_KERNEL_LOG(context, "Type '%s' is not supported by gather.", TfLiteTypeGetName(input->type)); return kTfLiteError; + break; } } - TF_LITE_KERNEL_LOG(context, - "Positions of type '%s' are not supported by gather.", - TfLiteTypeGetName(positions->type)); - return kTfLiteError; + return kTfLiteOk; } -} // namespace gather - -TfLiteRegistration* Register_GATHER() { - static TfLiteRegistration r = {nullptr, nullptr, gather::Prepare, - gather::Eval}; - return &r; +} // namespace + +TfLiteRegistration Register_GATHER() { + return {/*init=*/nullptr, + /*free=*/nullptr, + /*prepare=*/Prepare, + /*invoke=*/Eval, + /*profiling_string=*/nullptr, + /*builtin_code=*/0, + /*custom_name=*/nullptr, + /*version=*/0}; } -} // namespace builtin -} // namespace ops } // namespace tflite diff --git a/tensorflow/lite/micro/kernels/gather_test.cc b/tensorflow/lite/micro/kernels/gather_test.cc new file mode 100644 index 00000000..c5df0c87 --- /dev/null +++ b/tensorflow/lite/micro/kernels/gather_test.cc @@ -0,0 +1,465 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/lite/c/builtin_op_data.h" +#include "tensorflow/lite/c/common.h" +#include "tensorflow/lite/micro/all_ops_resolver.h" +#include "tensorflow/lite/micro/kernels/kernel_runner.h" +#include "tensorflow/lite/micro/test_helpers.h" +#include "tensorflow/lite/micro/testing/micro_test.h" + +namespace tflite { +namespace testing { +namespace { + +template +void TestGather(int* input_dims, const InType* input_data, int* positions_dims, + const PosType* positions_data, int* output_dims, + InType* output_data, const int* expected_output_dims, + const InType* expected_output_data, const int axis = 0, + const int batch_dims = 0) { + TfLiteIntArray* in_dims = IntArrayFromInts(input_dims); + TfLiteIntArray* pos_dims = IntArrayFromInts(positions_dims); + TfLiteIntArray* out_dims = IntArrayFromInts(output_dims); + TfLiteGatherParams params = {axis, batch_dims}; + + constexpr int inputs_size = 2; + constexpr int outputs_size = 1; + constexpr int tensors_size = inputs_size + outputs_size; + TfLiteTensor tensors[tensors_size] = { + CreateTensor(input_data, in_dims), + CreateTensor(positions_data, pos_dims), + CreateTensor(output_data, out_dims, true), + }; + int inputs_array_data[] = {2, 0, 1}; + TfLiteIntArray* inputs_array = IntArrayFromInts(inputs_array_data); + int outputs_array_data[] = {1, 2}; + TfLiteIntArray* outputs_array = IntArrayFromInts(outputs_array_data); + + const TfLiteRegistration registration = Register_GATHER(); + micro::KernelRunner runner(registration, tensors, tensors_size, inputs_array, + outputs_array, ¶ms); + TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, runner.InitAndPrepare()); + TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, runner.Invoke()); + + // The output tensor's data and shape have been updated by the kernel. + TfLiteTensor* actual_output_tensor = &tensors[2]; + TfLiteIntArray* actual_output_dims = actual_output_tensor->dims; + const int actual_output_dims_size = actual_output_dims->size; + const int output_size = ElementCount(*actual_output_dims); + for (int i = 0; i < output_size; ++i) { + TF_LITE_MICRO_EXPECT_EQ(expected_output_data[i], output_data[i]); + } + + // Compare output tensor's shape if expected_output_dims[] is provided. + for (int i = 0; i < actual_output_dims_size; ++i) { + TF_LITE_MICRO_EXPECT_EQ(expected_output_dims[i], + actual_output_dims->data[i]); + } +} + +} // namespace +} // namespace testing +} // namespace tflite + +TF_LITE_MICRO_TESTS_BEGIN + +// For all test functions below, dims[0] is the dimension count. +TF_LITE_MICRO_TEST(GatherOp_Shuffle) { + // For input_dims[], positions_dims[], or output_dims[], element 0 is the + // number of dimensions in that array, not the actual dimension data. + int input_dims[] = {2, 2, 2}; + int positions_dims[] = {1, 2}; + const int32_t positions_data[] = {1, 0}; + const float input_data[] = {-2.0, 0.2, 0.7, 0.8}; + const float golden_data[] = {0.7, 0.8, -2, 0.2}; + float output_data[4]; + + // The kernel under test will fill output_dims[1] onward, to be compared + // against golden_dims[0] onward. + const int golden_dims[] = {2, 2}; + int output_dims[] = {2, 0, 0}; + tflite::testing::TestGather( + input_dims, input_data, positions_dims, positions_data, output_dims, + output_data, golden_dims, golden_data); +} + +TF_LITE_MICRO_TEST(GatherOp_Test0DIndex) { + // For input_dims[], positions_dims[], or output_dims[], element 0 is the + // number of dimensions in that array, not the actual dimension data. + int input_dims[] = {2, 2, 2}; + int positions_dims[] = {0}; + const int32_t positions_data[] = {1}; + const float input_data[] = {-2.0, 0.2, 0.7, 0.8}; + const float golden_data[] = {0.7, 0.8}; + float output_data[2]; + + // The kernel under test will fill output_dims[1] onward, to be compared + // against golden_dims[0] onward. + const int golden_dims[] = {2}; + int output_dims[] = {1, 0}; + tflite::testing::TestGather( + input_dims, input_data, positions_dims, positions_data, output_dims, + output_data, golden_dims, golden_data); +} + +TF_LITE_MICRO_TEST(GatherOp_Test0DIndexWith0DResult) { + // 0D tensor is special case in current TFLite. Test it once to make sure + // existing workarounds are fine with it. + // For input_dims[], positions_dims[], or output_dims[], element 0 is the + // number of dimensions in that array, not the actual dimension data. + int input_dims[] = {1, 3}; + int positions_dims[] = {0}; + const int32_t positions_data[] = {1}; + const float input_data[] = {1.0, 2.0, 3.0}; + const float golden_data[] = {2.0}; + float output_data[1]; + + // The kernel under test will fill output_dims[1] onward, to be compared + // against golden_dims[0] onward. + const int golden_dims[] = {0}; + int output_dims[] = {1, 0}; + tflite::testing::TestGather( + input_dims, input_data, positions_dims, positions_data, output_dims, + output_data, golden_dims, golden_data); +} + +TF_LITE_MICRO_TEST(GatherOp_Test1DInput1DIndex) { + // For input_dims[], positions_dims[], or output_dims[], element 0 is the + // number of dimensions in that array, not the actual dimension data. + int input_dims[] = {1, 3}; + int positions_dims[] = {1, 1}; + const int32_t positions_data[] = {1}; + const float input_data[] = {1.0, 3.0, 5.0}; + const float golden_data[] = {3.0}; + float output_data[1]; + + // The kernel under test will fill output_dims[1] onward, to be compared + // against golden_dims[0] onward. + const int golden_dims[] = {1}; + int output_dims[] = {1, 0}; + tflite::testing::TestGather( + input_dims, input_data, positions_dims, positions_data, output_dims, + output_data, golden_dims, golden_data); +} + +TF_LITE_MICRO_TEST(GatherOp_Test2DIndexWith2DResult) { + // For input_dims[], positions_dims[], or output_dims[], element 0 is the + // number of dimensions in that array, not the actual dimension data. + int input_dims[] = {1, 3}; + int positions_dims[] = {2, 1, 2}; + const int32_t positions_data[] = {1, 0}; + const float input_data[] = {1.0, 2.0, 3.0}; + const float golden_data[] = {2.0, 1.0}; + float output_data[2]; + + // The kernel under test will fill output_dims[1] onward, to be compared + // against golden_dims[0] onward. + const int golden_dims[] = {1, 2}; + int output_dims[] = {2, 0, 0}; + tflite::testing::TestGather( + input_dims, input_data, positions_dims, positions_data, output_dims, + output_data, golden_dims, golden_data); +} + +TF_LITE_MICRO_TEST(GatherOp_Duplicate) { + // For input_dims[], positions_dims[], or output_dims[], element 0 is the + // number of dimensions in that array, not the actual dimension data. + int input_dims[] = {3, 1, 2, 2}; + int positions_dims[] = {1, 2}; + const int32_t positions_data[] = {0, 0}; + const float input_data[] = {-2.0, 0.2, 0.7, 0.8}; + const float golden_data[] = {-2, 0.2, 0.7, 0.8, -2, 0.2, 0.7, 0.8}; + float output_data[8]; + + // The kernel under test will fill output_dims[1] onward, to be compared + // against golden_dims[0] onward. + const int golden_dims[] = {2, 2, 2}; + int output_dims[] = {3, 0, 0, 0}; + tflite::testing::TestGather( + input_dims, input_data, positions_dims, positions_data, output_dims, + output_data, golden_dims, golden_data); +} + +TF_LITE_MICRO_TEST(GatherOp_Slice) { + // For input_dims[], positions_dims[], or output_dims[], element 0 is the + // number of dimensions in that array, not the actual dimension data. + int input_dims[] = {2, 4, 1}; + int positions_dims[] = {1, 2}; + const int32_t positions_data[] = {1, 3}; + const float input_data[] = {-2.0, 0.2, 0.7, 0.8}; + const float golden_data[] = {0.2, 0.8}; + float output_data[2]; + + // The kernel under test will fill output_dims[1] onward, to be compared + // against golden_dims[0] onward. + const int golden_dims[] = {2, 1}; + int output_dims[] = {2, 0, 0}; + tflite::testing::TestGather( + input_dims, input_data, positions_dims, positions_data, output_dims, + output_data, golden_dims, golden_data); +} + +TF_LITE_MICRO_TEST(GatherOp_Axis1) { + // For input_dims[], positions_dims[], or output_dims[], element 0 is the + // number of dimensions in that array, not the actual dimension data. + const int axis = 1; + int input_dims[] = {3, 1, 2, 3}; + int positions_dims[] = {1, 2}; + const int32_t positions_data[] = {1, 0}; + const float input_data[] = {1, 2, 3, 4, 5, 6}; + const float golden_data[] = {4, 5, 6, 1, 2, 3}; + float output_data[6]; + + // The kernel under test will fill output_dims[1] onward, to be compared + // against golden_dims[0] onward. + const int golden_dims[] = {1, 2, 3}; + int output_dims[] = {3, 0, 0, 0}; + tflite::testing::TestGather( + input_dims, input_data, positions_dims, positions_data, output_dims, + output_data, golden_dims, golden_data, axis); +} + +TF_LITE_MICRO_TEST(GatherOp_Axis1_0DIndex) { + // For input_dims[], positions_dims[], or output_dims[], element 0 is the + // number of dimensions in that array, not the actual dimension data. + const int axis = 1; + int input_dims[] = {3, 1, 3, 2}; + int positions_dims[] = {0}; + const int32_t positions_data[] = {1}; + const float input_data[] = {1, 2, 3, 4, 5, 6}; + const float golden_data[] = {3, 4}; + float output_data[2]; + + // The kernel under test will fill output_dims[1] onward, to be compared + // against golden_dims[0] onward. + const int golden_dims[] = {1, 2}; + int output_dims[] = {2, 0, 0}; + tflite::testing::TestGather( + input_dims, input_data, positions_dims, positions_data, output_dims, + output_data, golden_dims, golden_data, axis); +} + +TF_LITE_MICRO_TEST(GatherOp_Axis1Slice) { + // For input_dims[], positions_dims[], or output_dims[], element 0 is the + // number of dimensions in that array, not the actual dimension data. + const int axis = 1; + int input_dims[] = {3, 1, 4, 2}; + int positions_dims[] = {1, 2}; + const int32_t positions_data[] = {3, 1}; + const float input_data[] = {1, 2, 3, 4, 5, 6, 7, 8}; + const float golden_data[] = {7, 8, 3, 4}; + float output_data[4]; + + // The kernel under test will fill output_dims[1] onward, to be compared + // against golden_dims[0] onward. + const int golden_dims[] = {1, 2, 2}; + int output_dims[] = {3, 0, 0, 0}; + tflite::testing::TestGather( + input_dims, input_data, positions_dims, positions_data, output_dims, + output_data, golden_dims, golden_data, axis); +} + +TF_LITE_MICRO_TEST(GatherOp_LastAxis) { + // For input_dims[], positions_dims[], or output_dims[], element 0 is the + // number of dimensions in that array, not the actual dimension data. + const int axis = -1; + int input_dims[] = {3, 1, 2, 3}; + int positions_dims[] = {1, 2}; + const int32_t positions_data[] = {2, 0}; + const float input_data[] = {1, 2, 3, 4, 5, 6}; + const float golden_data[] = {3, 1, 6, 4}; + float output_data[4]; + + // The kernel under test will fill output_dims[1] onward, to be compared + // against golden_dims[0] onward. + const int golden_dims[] = {1, 2, 2}; + int output_dims[] = {3, 0, 0, 0}; + tflite::testing::TestGather( + input_dims, input_data, positions_dims, positions_data, output_dims, + output_data, golden_dims, golden_data, axis); +} + +TF_LITE_MICRO_TEST(GatherOp_LastAxis0DIndex) { + // For input_dims[], positions_dims[], or output_dims[], element 0 is the + // number of dimensions in that array, not the actual dimension data. + const int axis = -1; + int input_dims[] = {3, 1, 2, 3}; + int positions_dims[] = {0}; + const int32_t positions_data[] = {2}; + const float input_data[] = {1, 2, 3, 4, 5, 6}; + const float golden_data[] = {3, 6}; + float output_data[2]; + + // The kernel under test will fill output_dims[1] onward, to be compared + // against golden_dims[0] onward. + const int golden_dims[] = {1, 2}; + int output_dims[] = {2, 0, 0}; + tflite::testing::TestGather( + input_dims, input_data, positions_dims, positions_data, output_dims, + output_data, golden_dims, golden_data, axis); +} + +TF_LITE_MICRO_TEST(GatherOp_Float32Int32) { + // For input_dims[], positions_dims[], or output_dims[], element 0 is the + // number of dimensions in that array, not the actual dimension data. + int input_dims[] = {2, 2, 2}; + int positions_dims[] = {1, 2}; + const int32_t positions_data[] = {1, 0}; + const float input_data[] = {13.3, -13.4, -1.4, 1.5}; + const float golden_data[] = {-1.4, 1.5, 13.3, -13.4}; + float output_data[4]; + + // The kernel under test will fill output_dims[1] onward, to be compared + // against golden_dims[0] onward. + const int golden_dims[] = {2, 2}; + int output_dims[] = {2, 0, 0}; + tflite::testing::TestGather( + input_dims, input_data, positions_dims, positions_data, output_dims, + output_data, golden_dims, golden_data); +} + +TF_LITE_MICRO_TEST(GatherOp_Int8Int32) { + // For input_dims[], positions_dims[], or output_dims[], element 0 is the + // number of dimensions in that array, not the actual dimension data. + int input_dims[] = {2, 2, 2}; + int positions_dims[] = {1, 2}; + const int32_t positions_data[] = {1, 0}; + const int8_t input_data[] = {-13, -120, 14, 15}; + const int8_t golden_data[] = {14, 15, -13, -120}; + int8_t output_data[4]; + + // The kernel under test will fill output_dims[1] onward, to be compared + // against golden_dims[0] onward. + const int golden_dims[] = {2, 2}; + int output_dims[] = {2, 0, 0}; + tflite::testing::TestGather( + input_dims, input_data, positions_dims, positions_data, output_dims, + output_data, golden_dims, golden_data); +} + +TF_LITE_MICRO_TEST(GatherOp_BatchDims2) { + // For input_dims[], positions_dims[], or output_dims[], element 0 is the + // number of dimensions in that array, not the actual dimension data. + const int axis = 2; + const int batch_dims = 2; + int input_dims[] = {4, 2, 2, 3, 5}; + int positions_dims[] = {3, 2, 2, 2}; + const int32_t positions_data[] = {1, 0, 0, 1, 1, 0, 0, 1}; + const float input_data[] = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, + 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, + 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, + 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, + 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59}; + const float golden_data[] = {5, 6, 7, 8, 9, 0, 1, 2, 3, 4, + 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, + 35, 36, 37, 38, 39, 30, 31, 32, 33, 34, + 45, 46, 47, 48, 49, 50, 51, 52, 53, 54}; + float output_data[40]; + + // The kernel under test will fill output_dims[1] onward, to be compared + // against golden_dims[0] onward. + const int golden_dims[] = {2, 2, 2, 5}; + int output_dims[] = {4, 0, 0, 0, 0}; + tflite::testing::TestGather( + input_dims, input_data, positions_dims, positions_data, output_dims, + output_data, golden_dims, golden_data, axis, batch_dims); +} + +TF_LITE_MICRO_TEST(GatherOp_BatchDims1) { + // For input_dims[], positions_dims[], or output_dims[], element 0 is the + // number of dimensions in that array, not the actual dimension data. + const int axis = 2; + const int batch_dims = 1; + int input_dims[] = {4, 2, 2, 3, 5}; + int positions_dims[] = {3, 2, 2, 2}; + const int32_t positions_data[] = {1, 0, 0, 1, 1, 0, 0, 1}; + const int8_t input_data[] = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, + 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, + 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, + 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, + 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59}; + const int8_t golden_data[] = { + 5, 6, 7, 8, 9, 0, 1, 2, 3, 4, 0, 1, 2, 3, 4, 5, + 6, 7, 8, 9, 20, 21, 22, 23, 24, 15, 16, 17, 18, 19, 15, 16, + 17, 18, 19, 20, 21, 22, 23, 24, 35, 36, 37, 38, 39, 30, 31, 32, + 33, 34, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 50, 51, 52, 53, + 54, 45, 46, 47, 48, 49, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54}; + int8_t output_data[80]; + + // The kernel under test will fill output_dims[1] onward, to be compared + // against golden_dims[0] onward. + const int golden_dims[] = {2, 2, 2, 2, 5}; + int output_dims[] = {5, 0, 0, 0, 0, 0}; + tflite::testing::TestGather( + input_dims, input_data, positions_dims, positions_data, output_dims, + output_data, golden_dims, golden_data, axis, batch_dims); +} + +TF_LITE_MICRO_TEST(GatherOp_NegativeBatchDims) { + // For input_dims[], positions_dims[], or output_dims[], element 0 is the + // number of dimensions in that array, not the actual dimension data. + const int axis = 2; + const int batch_dims = -2; + int input_dims[] = {4, 2, 2, 3, 5}; + int positions_dims[] = {3, 2, 2, 2}; + const int32_t positions_data[] = {1, 0, 0, 1, 1, 0, 0, 1}; + const int8_t input_data[] = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, + 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, + 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, + 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, + 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59}; + const int8_t golden_data[] = { + 5, 6, 7, 8, 9, 0, 1, 2, 3, 4, 0, 1, 2, 3, 4, 5, + 6, 7, 8, 9, 20, 21, 22, 23, 24, 15, 16, 17, 18, 19, 15, 16, + 17, 18, 19, 20, 21, 22, 23, 24, 35, 36, 37, 38, 39, 30, 31, 32, + 33, 34, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 50, 51, 52, 53, + 54, 45, 46, 47, 48, 49, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54}; + int8_t output_data[80]; + + // The kernel under test will fill output_dims[1] onward, to be compared + // against golden_dims[0] onward. + const int golden_dims[] = {2, 2, 2, 2, 5}; + int output_dims[] = {5, 0, 0, 0, 0, 0}; + tflite::testing::TestGather( + input_dims, input_data, positions_dims, positions_data, output_dims, + output_data, golden_dims, golden_data, axis, batch_dims); +} + +TF_LITE_MICRO_TEST(GatherOp_BatchDimsEqualIndiceDims) { + // For input_dims[], positions_dims[], or output_dims[], element 0 is the + // number of dimensions in that array, not the actual dimension data. + const int axis = 3; + const int batch_dims = 3; + int input_dims[] = {4, 2, 2, 2, 5}; + int positions_dims[] = {3, 2, 2, 2}; + const int32_t positions_data[] = {1, 0, 0, 1, 1, 0, 0, 1}; + const int8_t input_data[] = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, + 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, + 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, + 30, 31, 32, 33, 34, 35, 36, 37, 38, 39}; + const int8_t golden_data[] = {1, 5, 10, 16, 21, 25, 30, 36}; + int8_t output_data[8]; + + // The kernel under test will fill output_dims[1] onward, to be compared + // against golden_dims[0] onward. + const int golden_dims[] = {2, 2, 2}; + int output_dims[] = {3, 0, 0, 0}; + tflite::testing::TestGather( + input_dims, input_data, positions_dims, positions_data, output_dims, + output_data, golden_dims, golden_data, axis, batch_dims); +} + +TF_LITE_MICRO_TESTS_END diff --git a/tensorflow/lite/micro/kernels/micro_ops.h b/tensorflow/lite/micro/kernels/micro_ops.h index 3a83145c..46e49370 100644 --- a/tensorflow/lite/micro/kernels/micro_ops.h +++ b/tensorflow/lite/micro/kernels/micro_ops.h @@ -45,6 +45,7 @@ TfLiteRegistration Register_EXPAND_DIMS(); TfLiteRegistration Register_FILL(); TfLiteRegistration Register_FLOOR_DIV(); TfLiteRegistration Register_FLOOR_MOD(); +TfLiteRegistration Register_GATHER(); TfLiteRegistration Register_GATHER_ND(); TfLiteRegistration Register_IF(); TfLiteRegistration Register_L2_POOL_2D(); diff --git a/tensorflow/lite/micro/micro_mutable_op_resolver.h b/tensorflow/lite/micro/micro_mutable_op_resolver.h index d2ee66d7..8fb0e8b5 100644 --- a/tensorflow/lite/micro/micro_mutable_op_resolver.h +++ b/tensorflow/lite/micro/micro_mutable_op_resolver.h @@ -255,6 +255,11 @@ class MicroMutableOpResolver : public MicroOpResolver { ParseFullyConnected); } + TfLiteStatus AddGather() { + return AddBuiltin(BuiltinOperator_GATHER, tflite::Register_GATHER(), + ParseGather); + } + TfLiteStatus AddGatherNd() { return AddBuiltin(BuiltinOperator_GATHER_ND, tflite::Register_GATHER_ND(), ParseGatherNd); diff --git a/tensorflow/lite/micro/tools/make/Makefile b/tensorflow/lite/micro/tools/make/Makefile index 40b84557..7bb4880c 100644 --- a/tensorflow/lite/micro/tools/make/Makefile +++ b/tensorflow/lite/micro/tools/make/Makefile @@ -282,6 +282,7 @@ tensorflow/lite/micro/kernels/floor_test.cc \ tensorflow/lite/micro/kernels/floor_div_test.cc \ tensorflow/lite/micro/kernels/floor_mod_test.cc \ tensorflow/lite/micro/kernels/fully_connected_test.cc \ +tensorflow/lite/micro/kernels/gather_test.cc \ tensorflow/lite/micro/kernels/gather_nd_test.cc \ tensorflow/lite/micro/kernels/hard_swish_test.cc \ tensorflow/lite/micro/kernels/l2norm_test.cc \ @@ -351,6 +352,7 @@ tensorflow/lite/micro/kernels/floor_div.cc \ tensorflow/lite/micro/kernels/floor_mod.cc \ tensorflow/lite/micro/kernels/fully_connected.cc \ tensorflow/lite/micro/kernels/fully_connected_common.cc \ +tensorflow/lite/micro/kernels/gather.cc \ tensorflow/lite/micro/kernels/gather_nd.cc \ tensorflow/lite/micro/kernels/hard_swish.cc \ tensorflow/lite/micro/kernels/kernel_runner.cc \ -- GitLab