...
 
Commits (4)
* Port EMBEDDING_LOOKUP to TFLM (#1872)
  https://gitcode.net/xusiwei1236/tflite-micro/-/commit/9dc65ef8ae8c2ea38f56e3a74ccf4356f6a8df58
  2023-05-29T01:25:28+00:00, David Davis <ddavis-2015@users.noreply.github.com>, @tensorflow/micro
  Port the EMBEDDING_LOOKUP kernel operator to TFLM. Recreate tests for the
  kernel operator (INT8, FLOAT32). Add a SymmetricScaleFromMinMax method to the
  test_helpers.h file. Update Makefiles. Update the Bazel build script.
  bug=fixes #1871

* Remove dead code from Dequantize (#2011)
  https://gitcode.net/xusiwei1236/tflite-micro/-/commit/2b7f86c86ef15b6363efabf3e7163b3d478ea8b3
  2023-05-31T02:01:17+00:00, RJ Ascani <rjascani@google.com>
  The Dequantize op only supports an output type of Float32. This change
  removes the output_multiplier adjustments for Int32 outputs, as that path
  cannot be reached due to an earlier check. It also eliminates logging for
  unsupported output types in DequantizeEval, as those are likewise caught by
  an earlier check.
  BUG=http://b/230890286

* Fix SELECT_V2 Prepare deallocation error (#2013)
  https://gitcode.net/xusiwei1236/tflite-micro/-/commit/78b4040c8a09798abedd557208b83f330217824d
  2023-06-01T03:28:11+00:00, David Davis <ddavis-2015@users.noreply.github.com>, @tensorflow/micro
  Fix a logic error in Prepare that prevented deallocation of temporary
  tensors when scalars were present in the inputs. Added 3 tests from TfLite
  for scalar inputs.
  bug=fixes #2010

* Update SELECT_V2 to match TFLM kernel code (#2015)
  https://gitcode.net/xusiwei1236/tflite-micro/-/commit/2b7a4ad621d0ff95db639bc98e39fb47d0363c21
  2023-06-01T18:42:25+00:00, David Davis <ddavis-2015@users.noreply.github.com>, @tensorflow/micro
  The Eval method now uses TfLiteEvalTensor. The `has_low_rank_input_condition`
  variable has been removed from OpData, as SELECT_V2 does not use it. All code
  outside of the registration has been placed into the anonymous namespace.
  bug=fixes #2014
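For the first commit, some context on how an application would pick up the new
kernel: TFLM applications register ops through a resolver. A minimal sketch,
assuming a MicroMutableOpResolver with an AddEmbeddingLookup() convenience
method (the helper name is an assumption; the diff below only adds
Register_EMBEDDING_LOOKUP() to micro_ops.h):

#include "tensorflow/lite/micro/micro_mutable_op_resolver.h"

// Sketch only, not part of the diff. AddEmbeddingLookup() is assumed.
tflite::MicroMutableOpResolver<2> op_resolver;

TfLiteStatus RegisterOps() {
  TF_LITE_ENSURE_STATUS(op_resolver.AddEmbeddingLookup());  // assumed helper
  TF_LITE_ENSURE_STATUS(op_resolver.AddFullyConnected());   // model-dependent
  return kTfLiteOk;
}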
......@@ -207,6 +207,7 @@ tflm_kernel_cc_library(
"div.cc",
"elementwise.cc",
"elu.cc",
"embedding_lookup.cc",
"ethosu.cc",
"exp.cc",
"expand_dims.cc",
......@@ -672,6 +673,21 @@ cc_test(
],
)
cc_test(
name = "embedding_lookup_test",
srcs = [
"embedding_lookup_test.cc",
],
deps = [
":kernel_runner",
"//tensorflow/lite/c:common",
"//tensorflow/lite/micro:debug_log",
"//tensorflow/lite/micro:op_resolvers",
"//tensorflow/lite/micro:test_helpers",
"//tensorflow/lite/micro/testing:micro_test",
],
)
cc_test(
name = "exp_test",
srcs = ["exp_test.cc"],
......
......@@ -69,6 +69,7 @@ $(TENSORFLOW_ROOT)tensorflow/lite/micro/kernels/dequantize_test.cc \
$(TENSORFLOW_ROOT)tensorflow/lite/micro/kernels/div_test.cc \
$(TENSORFLOW_ROOT)tensorflow/lite/micro/kernels/elementwise_test.cc \
$(TENSORFLOW_ROOT)tensorflow/lite/micro/kernels/elu_test.cc \
$(TENSORFLOW_ROOT)tensorflow/lite/micro/kernels/embedding_lookup_test.cc \
$(TENSORFLOW_ROOT)tensorflow/lite/micro/kernels/exp_test.cc \
$(TENSORFLOW_ROOT)tensorflow/lite/micro/kernels/expand_dims_test.cc \
$(TENSORFLOW_ROOT)tensorflow/lite/micro/kernels/fill_test.cc \
......
......@@ -41,40 +41,36 @@ TfLiteStatus DequantizeEval(TfLiteContext* context, TfLiteNode* node) {
const TfLiteEvalTensor* input = tflite::micro::GetEvalInput(context, node, 0);
TfLiteEvalTensor* output = tflite::micro::GetEvalOutput(context, node, 0);
if (output->type == kTfLiteFloat32) {
switch (input->type) {
case kTfLiteInt8:
reference_ops::Dequantize(data->quantization_params,
tflite::micro::GetTensorShape(input),
tflite::micro::GetTensorData<int8_t>(input),
tflite::micro::GetTensorShape(output),
tflite::micro::GetTensorData<float>(output));
break;
case kTfLiteInt16:
reference_ops::Dequantize(data->quantization_params,
tflite::micro::GetTensorShape(input),
tflite::micro::GetTensorData<int16_t>(input),
tflite::micro::GetTensorShape(output),
tflite::micro::GetTensorData<float>(output));
break;
case kTfLiteUInt8:
reference_ops::Dequantize(data->quantization_params,
tflite::micro::GetTensorShape(input),
tflite::micro::GetTensorData<uint8_t>(input),
tflite::micro::GetTensorShape(output),
tflite::micro::GetTensorData<float>(output));
break;
default:
MicroPrintf("Input %s, output %s not supported.",
TfLiteTypeGetName(input->type),
TfLiteTypeGetName(output->type));
return kTfLiteError;
}
} else {
MicroPrintf("Input %s, output %s not supported.",
TfLiteTypeGetName(input->type),
TfLiteTypeGetName(output->type));
return kTfLiteError;
// Output type is ensured to be kTfLiteFloat32 at the Prepare stage
TFLITE_DCHECK(output->type == kTfLiteFloat32);
switch (input->type) {
case kTfLiteInt8:
reference_ops::Dequantize(data->quantization_params,
tflite::micro::GetTensorShape(input),
tflite::micro::GetTensorData<int8_t>(input),
tflite::micro::GetTensorShape(output),
tflite::micro::GetTensorData<float>(output));
break;
case kTfLiteInt16:
reference_ops::Dequantize(data->quantization_params,
tflite::micro::GetTensorShape(input),
tflite::micro::GetTensorData<int16_t>(input),
tflite::micro::GetTensorShape(output),
tflite::micro::GetTensorData<float>(output));
break;
case kTfLiteUInt8:
reference_ops::Dequantize(data->quantization_params,
tflite::micro::GetTensorShape(input),
tflite::micro::GetTensorData<uint8_t>(input),
tflite::micro::GetTensorShape(output),
tflite::micro::GetTensorData<float>(output));
break;
default:
MicroPrintf("Input %s, output %s not supported.",
TfLiteTypeGetName(input->type),
TfLiteTypeGetName(output->type));
return kTfLiteError;
}
return kTfLiteOk;
......
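Both the removed branch and the retained switch call into
reference_ops::Dequantize, which applies standard affine dequantization per
element. A scalar sketch of that computation, for illustration only (not part
of the diff):

#include <cstdint>

// Per-element affine dequantization as performed by
// reference_ops::Dequantize: result = scale * (q - zero_point).
inline float DequantizeValue(int8_t q, float scale, int32_t zero_point) {
  return scale * static_cast<float>(static_cast<int32_t>(q) - zero_point);
}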
......@@ -15,7 +15,6 @@ limitations under the License.
#include "tensorflow/lite/c/builtin_op_data.h"
#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/kernels/internal/quantization_util.h"
#include "tensorflow/lite/kernels/internal/reference/dequantize.h"
#include "tensorflow/lite/kernels/internal/reference/quantize.h"
#include "tensorflow/lite/kernels/internal/reference/requantize.h"
......@@ -46,14 +45,6 @@ TfLiteStatus DequantizePrepare(TfLiteContext* context, TfLiteNode* node) {
input->type == kTfLiteUInt8);
TF_LITE_ENSURE(context, output->type == kTfLiteFloat32);
if (output->type == kTfLiteInt32) {
const double effective_output_scale =
static_cast<double>(input->params.scale) /
static_cast<double>(output->params.scale);
QuantizeMultiplier(effective_output_scale, &data->output_multiplier,
&data->output_shift);
}
data->quantization_params.zero_point = input->params.zero_point;
data->quantization_params.scale = static_cast<double>(input->params.scale);
data->output_zero_point = output->params.zero_point;
......
/* Copyright 2023 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// Op that looks up items from a matrix.
//
// Input:
//     Tensor[0]: Row numbers to look up, dim.size == 1, int32
//     Tensor[1]: 2-dimensional matrix of multi-dimensional items
//                dim.size >= 2, all items are INT8 or FLOAT32.
//                The first dimension is the row, the second is the column.
//
// Output:
//   Output.dim[0] == Tensor[0].dim[0], num of lookups
//   Output.dim[1] == Tensor[1].dim[1], num of items per row
//   Each item in the output is a raw-bytes copy of the corresponding item in
//   the input, or a dequantized value in the case of an INT8 input.
//   When indices are out of bounds, the op will not succeed.
//
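// Worked example (mirrored by the tests below): with lookup = [1, 0, 2] and
// a value tensor of shape {3, 8}, the output has shape {3, 8} and holds row
// 1, then row 0, then row 2 of the value tensor. For an INT8 value tensor
// with FLOAT32 output, each row is dequantized as scale * quantized_value
// (the zero point is required to be 0).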
#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/kernels/internal/types.h"
#include "tensorflow/lite/kernels/kernel_util.h"
#include "tensorflow/lite/micro/kernels/kernel_util.h"
#include "tensorflow/lite/micro/micro_log.h"
#include "tensorflow/lite/micro/micro_utils.h"
namespace tflite {
namespace {
constexpr int kInputTensor_0 = 0;
constexpr int kInputTensor_1 = 1;
constexpr int kOutputTensor = 0;
struct OpData {
float scale; // quantization scale for tensor 1
size_t num_columns; // number of columns after flattening tensor 1 into 2D
};
TfLiteStatus CalculateOpData(TfLiteContext* context, TfLiteNode* node,
const TfLiteTensor* tensor_1,
const TfLiteTensor* output) {
node->user_data = context->AllocatePersistentBuffer(context, sizeof(OpData));
OpData* op_data = static_cast<OpData*>(node->user_data);
TF_LITE_ENSURE(context, op_data != nullptr);
if (tensor_1->type == kTfLiteInt8 && output->type == kTfLiteFloat32) {
TF_LITE_ENSURE_EQ(context, tensor_1->params.zero_point, 0);
op_data->scale = tensor_1->params.scale;
}
op_data->num_columns = NumElements(tensor_1) / tensor_1->dims->data[0];
return kTfLiteOk;
}
TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_ENSURE_EQ(context, NumInputs(node), 2);
TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
MicroContext* micro_context = GetMicroContext(context);
TfLiteTensor* lookup =
micro_context->AllocateTempInputTensor(node, kInputTensor_0);
TF_LITE_ENSURE(context, lookup != nullptr);
TF_LITE_ENSURE_EQ(context, NumDimensions(lookup), 1);
TF_LITE_ENSURE_EQ(context, lookup->type, kTfLiteInt32);
TfLiteTensor* value =
micro_context->AllocateTempInputTensor(node, kInputTensor_1);
TF_LITE_ENSURE(context, value != nullptr);
TF_LITE_ENSURE(context, NumDimensions(value) >= 2);
TF_LITE_ENSURE(context,
value->type == kTfLiteFloat32 || value->type == kTfLiteInt8);
TfLiteTensor* output =
micro_context->AllocateTempOutputTensor(node, kOutputTensor);
TF_LITE_ENSURE(context, output != nullptr);
if (value->type == kTfLiteFloat32) {
TF_LITE_ENSURE(context, output->type == kTfLiteFloat32);
} else {
TF_LITE_ENSURE(
context, output->type == kTfLiteFloat32 || output->type == kTfLiteInt8);
}
// make sure output dimensions size can hold the new dimension data
TF_LITE_ENSURE(context, output->dims->size >= NumDimensions(value));
// make the output tensor dimensions mutable
TfLiteEvalTensor* output_eval =
tflite::micro::GetEvalOutput(context, node, kOutputTensor);
TF_LITE_ENSURE_OK(context, tflite::micro::CreateWritableTensorDimsWithCopy(
context, output, output_eval));
// set the new output dimensions
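// e.g. lookup dims {3} with value dims {8, 2, 4} yield output dims
// {3, 2, 4}: dim 0 is the number of lookups and the remaining dims follow
// the value tensor's per-row shape.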
output->dims->data[0] = SizeOfDimension(lookup, 0);
output->dims->data[1] = SizeOfDimension(value, 1);
for (int i = 2; i < NumDimensions(value); i++) {
output->dims->data[i] = SizeOfDimension(value, i);
}
// check the new output dimensions do not exceed the output data buffer size
size_t new_dims_size = NumElements(output) * TfLiteTypeGetSize(output->type);
TF_LITE_ENSURE(context, new_dims_size <= output->bytes);
TF_LITE_ENSURE_OK(context, CalculateOpData(context, node, value, output));
micro_context->DeallocateTempTfLiteTensor(lookup);
micro_context->DeallocateTempTfLiteTensor(value);
micro_context->DeallocateTempTfLiteTensor(output);
return kTfLiteOk;
}
TfLiteStatus EvalSimple(const OpData& op_data, const TfLiteEvalTensor* lookup,
const TfLiteEvalTensor* value,
TfLiteEvalTensor* output) {
const int num_rows = value->dims->data[0];
if (num_rows == 0) {
// Propagate empty tensor if input is empty
return kTfLiteOk;
}
const size_t row_bytes = op_data.num_columns * TfLiteTypeGetSize(value->type);
int8_t* output_raw = tflite::micro::GetTensorData<int8_t>(output);
const int8_t* value_raw = tflite::micro::GetTensorData<int8_t>(value);
const int32_t* lookup_data = tflite::micro::GetTensorData<int32_t>(lookup);
for (int i = 0; i < lookup->dims->data[0]; i++) {
int32_t idx = lookup_data[i];
if (idx >= num_rows || idx < 0) {
MicroPrintf(
"EMBEDDING_LOOKUP: index out of bounds. "
"Got %d, and bounds are [0, %d]",
idx, num_rows - 1);
return kTfLiteError;
} else {
std::memcpy(output_raw + i * row_bytes, value_raw + idx * row_bytes,
row_bytes);
}
}
return kTfLiteOk;
}
TfLiteStatus EvalHybrid(const OpData& op_data, const TfLiteEvalTensor* lookup,
const TfLiteEvalTensor* value,
TfLiteEvalTensor* output) {
const int num_rows = value->dims->data[0];
  const size_t num_columns = op_data.num_columns;
float* output_ptr = tflite::micro::GetTensorData<float>(output);
const int8_t* value_ptr = tflite::micro::GetTensorData<int8_t>(value);
const int32_t* lookup_data = tflite::micro::GetTensorData<int32_t>(lookup);
for (int i = 0; i < lookup->dims->data[0]; i++) {
int32_t idx = lookup_data[i];
if (idx >= num_rows || idx < 0) {
MicroPrintf(
"EMBEDDING_LOOKUP: index out of bounds. "
"Got %d, and bounds are [0, %d]",
idx, num_rows - 1);
return kTfLiteError;
} else {
// Dequantize embedding values.
      Dequantize(&value_ptr[idx * num_columns], num_columns, op_data.scale, 0,
                 &output_ptr[i * num_columns]);
}
}
return kTfLiteOk;
}
TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
const TfLiteEvalTensor* lookup =
tflite::micro::GetEvalInput(context, node, kInputTensor_0);
const TfLiteEvalTensor* value =
tflite::micro::GetEvalInput(context, node, kInputTensor_1);
TfLiteEvalTensor* output =
tflite::micro::GetEvalOutput(context, node, kOutputTensor);
OpData& op_data = *static_cast<OpData*>(node->user_data);
switch (value->type) {
case kTfLiteFloat32:
return EvalSimple(op_data, lookup, value, output);
case kTfLiteInt8:
if (output->type == kTfLiteFloat32) {
return EvalHybrid(op_data, lookup, value, output);
} else {
return EvalSimple(op_data, lookup, value, output);
}
default:
MicroPrintf("EMBEDDING_LOOKUP only supports FLOAT32 and INT8, got %s.",
TfLiteTypeGetName(output->type));
return kTfLiteError;
}
}
} // namespace
TFLMRegistration Register_EMBEDDING_LOOKUP() {
return tflite::micro::RegisterOp(nullptr, Prepare, Eval);
}
} // namespace tflite
/* Copyright 2023 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <algorithm>
#include <iterator>
#include <type_traits>
#include "tensorflow/lite/c/builtin_op_data.h"
#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/micro/kernels/kernel_runner.h"
#include "tensorflow/lite/micro/test_helpers.h"
#include "tensorflow/lite/micro/testing/micro_test.h"
namespace tflite {
namespace testing {
namespace {
constexpr float kTestTolerance = 7.41e-03;
constexpr int kNumInputs = 2;
constexpr int kNumOutputs = 1;
constexpr int kInputTensorIndex_0 = 0;
constexpr int kInputTensorIndex_1 = 1;
constexpr int kOutputTensorIndex = 2;
// min/max are used to compute scale, zero-point is 0
template <size_t kInputSize>
struct TestEmbeddingLookupParams {
// quantization parameters
float data_min; // input data minimum value
float data_max; // input data maximum value
int8_t input_data[kInputSize]; // quantized input storage
};
void ExecuteEmbeddingLookupTest(TfLiteTensor* tensors, int tensors_count) {
int kInputArrayData[] = {kNumInputs, kInputTensorIndex_0,
kInputTensorIndex_1};
TfLiteIntArray* inputs_array = IntArrayFromInts(kInputArrayData);
int kOutputArrayData[] = {kNumOutputs, kOutputTensorIndex};
TfLiteIntArray* outputs_array = IntArrayFromInts(kOutputArrayData);
const TFLMRegistration registration = tflite::Register_EMBEDDING_LOOKUP();
micro::KernelRunner runner(registration, tensors, tensors_count, inputs_array,
outputs_array, nullptr);
TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, runner.InitAndPrepare());
TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, runner.Invoke());
}
template <size_t N>
void TestEmbeddingLookupQuantized(TestEmbeddingLookupParams<N>& params,
int* input_dims_data[kNumInputs],
const int32_t* input_data_0,
const float* input_data_1, int* expected_dims,
const float* expected_data,
float* output_data) {
TfLiteIntArray* input_dims_0 = IntArrayFromInts(input_dims_data[0]);
TfLiteIntArray* input_dims_1 = IntArrayFromInts(input_dims_data[1]);
TfLiteIntArray* output_dims = IntArrayFromInts(expected_dims);
const int output_count = ElementCount(*output_dims);
const float scale =
SymmetricScaleFromMinMax<int8_t>(params.data_min, params.data_max);
TfLiteTensor tensors[] = {
CreateTensor(input_data_0, input_dims_0),
CreateQuantizedTensor(input_data_1, params.input_data, input_dims_1,
scale, 0),
CreateTensor(output_data, output_dims),
};
constexpr int tensors_count = std::extent<decltype(tensors)>::value;
ExecuteEmbeddingLookupTest(tensors, tensors_count);
// check output data against expected
for (int i = 0; i < output_count; i++) {
TF_LITE_MICRO_EXPECT_NEAR(expected_data[i], output_data[i], kTestTolerance);
}
// check output dimensions (relocated) against original dimensions
TF_LITE_MICRO_EXPECT_EQ(output_dims->size,
tensors[kOutputTensorIndex].dims->size);
for (int i = 0; i < output_dims->size; i++) {
TF_LITE_MICRO_EXPECT_EQ(output_dims->data[i],
tensors[kOutputTensorIndex].dims->data[i]);
}
}
template <typename T>
void TestEmbeddingLookup(int* input_dims_data[kNumInputs],
const int32_t* input_data_0, const T* input_data_1,
int* expected_dims, const T* expected_data,
T* output_data) {
TfLiteIntArray* input_dims_0 = IntArrayFromInts(input_dims_data[0]);
TfLiteIntArray* input_dims_1 = IntArrayFromInts(input_dims_data[1]);
TfLiteIntArray* output_dims = IntArrayFromInts(expected_dims);
const int output_count = ElementCount(*output_dims);
TfLiteTensor tensors[] = {
CreateTensor(input_data_0, input_dims_0),
CreateTensor(input_data_1, input_dims_1),
CreateTensor(output_data, output_dims),
};
constexpr int tensors_count = std::extent<decltype(tensors)>::value;
ExecuteEmbeddingLookupTest(tensors, tensors_count);
// check output data against expected
for (int i = 0; i < output_count; i++) {
TF_LITE_MICRO_EXPECT_NEAR(expected_data[i], output_data[i], kTestTolerance);
}
// check output dimensions (relocated) against original dimensions
TF_LITE_MICRO_EXPECT_EQ(output_dims->size,
tensors[kOutputTensorIndex].dims->size);
for (int i = 0; i < output_dims->size; i++) {
TF_LITE_MICRO_EXPECT_EQ(output_dims->data[i],
tensors[kOutputTensorIndex].dims->data[i]);
}
}
} // namespace
} // namespace testing
} // namespace tflite
TF_LITE_MICRO_TESTS_BEGIN
TF_LITE_MICRO_TEST(EmbeddingLookupOpTestSimpleFloat) {
int kInputDims_0[] = {1, 3};
int kInputDims_1[] = {3, 3, 2, 4};
int* kInputDims[tflite::testing::kNumInputs] = {kInputDims_0, kInputDims_1};
int kOutputDims[] = {3, 3, 2, 4};
constexpr int32_t kInput_0[] = {1, 0, 2};
constexpr float kInput_1[] = {
0.00, 0.01, 0.02, 0.03, 0.10, 0.11, 0.12, 0.13, // Row 0
1.00, 1.01, 1.02, 1.03, 1.10, 1.11, 1.12, 1.13, // Row 1
2.00, 2.01, 2.02, 2.03, 2.10, 2.11, 2.12, 2.13, // Row 2
};
constexpr float kExpect[] = {
1.00, 1.01, 1.02, 1.03, 1.10, 1.11, 1.12, 1.13, // Row 1
0.00, 0.01, 0.02, 0.03, 0.10, 0.11, 0.12, 0.13, // Row 0
2.00, 2.01, 2.02, 2.03, 2.10, 2.11, 2.12, 2.13, // Row 2
};
constexpr int kOutputCount = std::extent<decltype(kExpect)>::value;
float output_data[kOutputCount];
tflite::testing::TestEmbeddingLookup(kInputDims, kInput_0, kInput_1,
kOutputDims, kExpect, output_data);
}
TF_LITE_MICRO_TEST(HybridEmbeddingLookupHybridOpTestSimple2DTestInt8) {
int kInputDims_0[] = {1, 3};
int kInputDims_1[] = {2, 3, 8};
int* kInputDims[tflite::testing::kNumInputs] = {kInputDims_0, kInputDims_1};
int kOutputDims[] = {2, 3, 8};
constexpr int32_t kInput_0[] = {1, 0, 2};
constexpr float kInput_1[] = {
0.00, 0.01, 0.02, 0.03, 0.10, 0.11, 0.12, 0.13, // Row 0
1.00, -1.01, 1.02, 1.03, 1.10, 1.11, 1.12, 1.13, // Row 1
2.00, 2.01, 2.02, 2.03, 2.10, 2.11, 2.12, 2.13, // Row 2
};
constexpr int kInputCount_1 = std::extent<decltype(kInput_1)>::value;
constexpr float kExpect[] = {
1.00, -1.01, 1.02, 1.03, 1.10, 1.11, 1.12, 1.13, // Row 1
0.00, 0.01, 0.02, 0.03, 0.10, 0.11, 0.12, 0.13, // Row 0
2.00, 2.01, 2.02, 2.03, 2.10, 2.11, 2.12, 2.13, // Row 2
};
constexpr int kOutputCount = std::extent<decltype(kExpect)>::value;
float output_data[kOutputCount];
tflite::testing::TestEmbeddingLookupParams<kInputCount_1> params = {};
auto minmax = std::minmax_element(std::begin(kInput_1), std::end(kInput_1));
params.data_max = *minmax.second;
params.data_min = *minmax.first;
tflite::testing::TestEmbeddingLookupQuantized(params, kInputDims, kInput_0,
kInput_1, kOutputDims, kExpect,
output_data);
}
TF_LITE_MICRO_TEST(HybridEmbeddingLookupHybridOpTestSimple3DTestInt8) {
int kInputDims_0[] = {1, 3};
int kInputDims_1[] = {3, 3, 2, 4};
int* kInputDims[tflite::testing::kNumInputs] = {kInputDims_0, kInputDims_1};
int kOutputDims[] = {3, 3, 2, 4};
constexpr int32_t kInput_0[] = {1, 0, 2};
constexpr float kInput_1[] = {
0.00, 0.01, 0.02, 0.03, 0.10, 0.11, 0.12, 0.13, // Row 0
1.00, -1.01, 1.02, 1.03, 1.10, 1.11, 1.12, 1.13, // Row 1
2.00, 2.01, 2.02, 2.03, 2.10, 2.11, 2.12, 2.13, // Row 2
};
constexpr int kInputCount_1 = std::extent<decltype(kInput_1)>::value;
constexpr float kExpect[] = {
1.00, -1.01, 1.02, 1.03, 1.10, 1.11, 1.12, 1.13, // Row 1
0.00, 0.01, 0.02, 0.03, 0.10, 0.11, 0.12, 0.13, // Row 0
2.00, 2.01, 2.02, 2.03, 2.10, 2.11, 2.12, 2.13, // Row 2
};
constexpr int kOutputCount = std::extent<decltype(kExpect)>::value;
float output_data[kOutputCount];
tflite::testing::TestEmbeddingLookupParams<kInputCount_1> params = {};
auto minmax = std::minmax_element(std::begin(kInput_1), std::end(kInput_1));
params.data_max = *minmax.second;
params.data_min = *minmax.first;
tflite::testing::TestEmbeddingLookupQuantized(params, kInputDims, kInput_0,
kInput_1, kOutputDims, kExpect,
output_data);
}
TF_LITE_MICRO_TEST(HybridEmbeddingLookupHybridOpTestSimple4DTestInt8) {
int kInputDims_0[] = {1, 3};
int kInputDims_1[] = {4, 3, 2, 2, 2};
int* kInputDims[tflite::testing::kNumInputs] = {kInputDims_0, kInputDims_1};
int kOutputDims[] = {4, 3, 2, 2, 2};
constexpr int32_t kInput_0[] = {1, 0, 2};
constexpr float kInput_1[] = {
0.00, 0.01, 0.02, 0.03, 0.10, 0.11, 0.12, 0.13, // Row 0
1.00, -1.01, 1.02, 1.03, 1.10, 1.11, 1.12, 1.13, // Row 1
2.00, 2.01, 2.02, 2.03, 2.10, 2.11, 2.12, 2.13, // Row 2
};
constexpr int kInputCount_1 = std::extent<decltype(kInput_1)>::value;
constexpr float kExpect[] = {
1.00, -1.01, 1.02, 1.03, 1.10, 1.11, 1.12, 1.13, // Row 1
0.00, 0.01, 0.02, 0.03, 0.10, 0.11, 0.12, 0.13, // Row 0
2.00, 2.01, 2.02, 2.03, 2.10, 2.11, 2.12, 2.13, // Row 2
};
constexpr int kOutputCount = std::extent<decltype(kExpect)>::value;
float output_data[kOutputCount];
tflite::testing::TestEmbeddingLookupParams<kInputCount_1> params = {};
auto minmax = std::minmax_element(std::begin(kInput_1), std::end(kInput_1));
params.data_max = *minmax.second;
params.data_min = *minmax.first;
tflite::testing::TestEmbeddingLookupQuantized(params, kInputDims, kInput_0,
kInput_1, kOutputDims, kExpect,
output_data);
}
TF_LITE_MICRO_TEST(EmbeddingLookupOpTestSimpleInt8) {
int kInputDims_0[] = {1, 3};
int kInputDims_1[] = {3, 3, 2, 4};
int* kInputDims[tflite::testing::kNumInputs] = {kInputDims_0, kInputDims_1};
int kOutputDims[] = {3, 3, 2, 4};
constexpr int32_t kInput_0[] = {1, 0, 2};
constexpr int8_t kInput_1[] = {
0, 1, 2, 3, 10, 11, 12, 13, // Row 0
100, 101, 102, 103, 110, 111, 112, 113, // Row 1
-56, -55, -54, -53, -46, -45, -44, -43, // Row 2
};
constexpr int8_t kExpect[] = {
100, 101, 102, 103, 110, 111, 112, 113, // Row 1
0, 1, 2, 3, 10, 11, 12, 13, // Row 0
-56, -55, -54, -53, -46, -45, -44, -43, // Row 2
};
constexpr int kOutputCount = std::extent<decltype(kExpect)>::value;
int8_t output_data[kOutputCount];
tflite::testing::TestEmbeddingLookup(kInputDims, kInput_0, kInput_1,
kOutputDims, kExpect, output_data);
}
TF_LITE_MICRO_TESTS_END
......@@ -55,6 +55,7 @@ TFLMRegistration Register_DEPTHWISE_CONV_2D();
TFLMRegistration Register_DEQUANTIZE();
TFLMRegistration Register_DIV();
TFLMRegistration Register_ELU();
TFLMRegistration Register_EMBEDDING_LOOKUP();
TFLMRegistration Register_EQUAL();
TFLMRegistration* Register_ETHOSU();
TFLMRegistration Register_EXP();
......
/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
/* Copyright 2023 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
......@@ -24,6 +24,7 @@ limitations under the License.
#include "tensorflow/lite/micro/micro_log.h"
namespace tflite {
namespace {
constexpr int kInputTensorCondition = 0;
constexpr int kInputTensorX = 1;
......@@ -32,9 +33,6 @@ constexpr int kOutputTensor = 0;
struct OpData {
bool requires_broadcast;
// True if input condition is scalar or input condition has rank one and
// matches the first dimension of other inputs.
bool has_low_rank_input_condition;
};
void* SelectInit(TfLiteContext* context, const char* buffer, size_t length) {
......@@ -42,7 +40,6 @@ void* SelectInit(TfLiteContext* context, const char* buffer, size_t length) {
auto* data = static_cast<OpData*>(
context->AllocatePersistentBuffer(context, sizeof(OpData)));
data->requires_broadcast = false;
data->has_low_rank_input_condition = false;
return data;
}
......@@ -101,16 +98,15 @@ TfLiteStatus SelectPrepare(TfLiteContext* context, TfLiteNode* node) {
// Respect the original output shape when there are mixed shapes to represent
// a scalar data.
  if (GetTensorShape(input_condition).FlatSize() == 1 &&
      GetTensorShape(input_x).FlatSize() == 1 &&
      GetTensorShape(input_y).FlatSize() == 1 &&
      GetTensorShape(output).FlatSize() == 1) {
    return kTfLiteOk;
  }
  bool same_shape = HaveSameShapes(input_condition, input_x) &&
                    HaveSameShapes(input_x, input_y);
  if (!same_shape) {
  bool possible_mixed_scaler =
      GetTensorShape(input_condition).FlatSize() == 1 &&
      GetTensorShape(input_x).FlatSize() == 1 &&
      GetTensorShape(input_y).FlatSize() == 1 &&
      GetTensorShape(output).FlatSize() == 1;
  bool same_shape = HaveSameShapes(input_condition, input_x) &&
                    HaveSameShapes(input_x, input_y);
  if (!same_shape && !possible_mixed_scaler) {
TF_LITE_ENSURE_OK(
context, CheckBroadcastShape(context, input_condition, input_x, input_y,
output->dims));
......@@ -125,65 +121,68 @@ TfLiteStatus SelectPrepare(TfLiteContext* context, TfLiteNode* node) {
return kTfLiteOk;
}
TfLiteStatus SelectEval(TfLiteContext* context, TfLiteNode* node) {
  OpData* data = static_cast<OpData*>(node->user_data);
  MicroContext* micro_context = GetMicroContext(context);
  TfLiteTensor* input_condition =
      micro_context->AllocateTempInputTensor(node, kInputTensorCondition);
  TfLiteTensor* input_x =
      micro_context->AllocateTempInputTensor(node, kInputTensorX);
  TfLiteTensor* input_y =
      micro_context->AllocateTempInputTensor(node, kInputTensorY);
  TfLiteTensor* output =
      micro_context->AllocateTempOutputTensor(node, kOutputTensor);
#define TF_LITE_SELECT(type, op) \
  reference_ops::op(GetTensorShape(input_condition), \
                    GetTensorData<bool>(input_condition), \
                    GetTensorShape(input_x), GetTensorData<type>(input_x), \
                    GetTensorShape(input_y), GetTensorData<type>(input_y), \
                    GetTensorShape(output), GetTensorData<type>(output));
#define TF_LITE_SWITCH(type, op) \
  switch (type) { \
    case kTfLiteFloat32: \
      TF_LITE_SELECT(float, op); \
      break; \
    case kTfLiteInt8: \
      TF_LITE_SELECT(int8_t, op); \
      break; \
    case kTfLiteInt16: \
      TF_LITE_SELECT(int16_t, op); \
      break; \
    default: \
      MicroPrintf("Does not support type other than %s, but got %s", \
                  "int8|int16|float32", TfLiteTypeGetName(type)); \
      return kTfLiteError; \
  }
  if (data->has_low_rank_input_condition) {
    MicroPrintf("Not yet implemented.");
    return kTfLiteError;
  } else if (data->requires_broadcast) {
    TF_LITE_SWITCH(input_x->type, BroadcastSelect5DSlow);
  } else {
    TF_LITE_SWITCH(input_x->type, Select);
  }
#undef TF_LITE_SELECT
#undef TF_LITE_SWITCH
  micro_context->DeallocateTempTfLiteTensor(input_condition);
  micro_context->DeallocateTempTfLiteTensor(input_x);
  micro_context->DeallocateTempTfLiteTensor(input_y);
  micro_context->DeallocateTempTfLiteTensor(output);
  return kTfLiteOk;
}
template <typename T>
void CallSelect(const TfLiteEvalTensor* input_condition,
                const TfLiteEvalTensor* input_x,
                const TfLiteEvalTensor* input_y, TfLiteEvalTensor* output,
                bool need_broadcast) {
  using Func = decltype(reference_ops::Select<bool, T>)*;
  Func select_func;
  if (need_broadcast) {
    select_func = reference_ops::BroadcastSelect5DSlow<bool, T>;
  } else {
    select_func = reference_ops::Select<bool, T>;
  }
  select_func(tflite::micro::GetTensorShape(input_condition),
              tflite::micro::GetTensorData<bool>(input_condition),
              tflite::micro::GetTensorShape(input_x),
              tflite::micro::GetTensorData<T>(input_x),
              tflite::micro::GetTensorShape(input_y),
              tflite::micro::GetTensorData<T>(input_y),
              tflite::micro::GetTensorShape(output),
              tflite::micro::GetTensorData<T>(output));
}
TfLiteStatus SelectEval(TfLiteContext* context, TfLiteNode* node) {
  OpData* data = static_cast<OpData*>(node->user_data);
  const TfLiteEvalTensor* input_condition =
      tflite::micro::GetEvalInput(context, node, kInputTensorCondition);
  const TfLiteEvalTensor* input_x =
      tflite::micro::GetEvalInput(context, node, kInputTensorX);
  const TfLiteEvalTensor* input_y =
      tflite::micro::GetEvalInput(context, node, kInputTensorY);
  TfLiteEvalTensor* output =
      tflite::micro::GetEvalOutput(context, node, kOutputTensor);
  switch (input_x->type) {
    case kTfLiteFloat32:
      CallSelect<float>(input_condition, input_x, input_y, output,
                        data->requires_broadcast);
      break;
    case kTfLiteInt8:
      CallSelect<int8_t>(input_condition, input_x, input_y, output,
                         data->requires_broadcast);
      break;
    case kTfLiteInt16:
      CallSelect<int16_t>(input_condition, input_x, input_y, output,
                          data->requires_broadcast);
      break;
    default:
      MicroPrintf("Does not support type other than %s, but got %s",
                  "int8|int16|float32", TfLiteTypeGetName(input_x->type));
      return kTfLiteError;
  }
  return kTfLiteOk;
}
} // namespace
// SelectV2 op selects values of 'x' if the corresponding value of 'condition'
// is true or the value of 'y' if false. There are valid condition input sizes:
//
......
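The refactor above swaps the TF_LITE_SELECT/TF_LITE_SWITCH macros for a
templated dispatcher: decltype on an instantiated reference kernel yields a
function-pointer type that both the broadcast and non-broadcast kernels share,
so a single call site serves both paths. A stripped-down, self-contained
sketch of the pattern (generic placeholder kernels, not the TFLM ones):

#include <cstdio>

template <typename T>
void Plain(const T* in, T* out, int n) {
  for (int i = 0; i < n; ++i) out[i] = in[i];  // element-wise copy
}

template <typename T>
void Broadcast(const T* in, T* out, int n) {
  for (int i = 0; i < n; ++i) out[i] = in[0];  // broadcast first element
}

template <typename T>
void Call(const T* in, T* out, int n, bool need_broadcast) {
  // decltype(Plain<T>)* names "pointer to a function with Plain<T>'s
  // signature"; both kernels share it, so one call site covers both paths.
  using Func = decltype(Plain<T>)*;
  Func f = need_broadcast ? Broadcast<T> : Plain<T>;
  f(in, out, n);
}

int main() {
  const int in[3] = {7, 8, 9};
  int out[3];
  Call(in, out, 3, /*need_broadcast=*/true);
  std::printf("%d %d %d\n", out[0], out[1], out[2]);  // prints: 7 7 7
}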
/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
/* Copyright 2023 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
......@@ -13,6 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <type_traits>
#include "tensorflow/lite/micro/kernels/kernel_runner.h"
#include "tensorflow/lite/micro/test_helpers.h"
#include "tensorflow/lite/micro/testing/micro_test.h"
......@@ -207,4 +209,55 @@ TF_LITE_MICRO_TEST(BroadcastSelectInt16OneDimensionConditionWithTwoValues) {
tflite::testing::ExpectEqual(output_shape, expected_output, output_data);
}
TF_LITE_MICRO_TEST(MixedFlatSizeOneInputsWithScalarInputConditionTensor) {
int input1_shape[] = {0}; // conditional data is a scalar
int input_shape[] = {1, 1};
int output_shape[] = {0}; // output data is a scalar
const bool input1_data[] = {false};
const int16_t input2_data[] = {1};
const int16_t input3_data[] = {5};
const int16_t expected_output[] = {5};
int16_t output_data[std::extent<decltype(expected_output)>::value];
tflite::testing::TestSelect(input1_shape, input1_data, input_shape,
input2_data, input_shape, input3_data,
output_shape, output_data);
tflite::testing::ExpectEqual(output_shape, expected_output, output_data);
}
TF_LITE_MICRO_TEST(MixedFlatSizeOneInputsWithScalarInputXTensor) {
int input2_shape[] = {0}; // x data is a scalar
int input_shape[] = {1, 1};
int output_shape[] = {0}; // output data is a scalar
const bool input1_data[] = {true};
const int16_t input2_data[] = {1};
const int16_t input3_data[] = {5};
const int16_t expected_output[] = {1};
int16_t output_data[std::extent<decltype(expected_output)>::value];
tflite::testing::TestSelect(input_shape, input1_data, input2_shape,
input2_data, input_shape, input3_data,
output_shape, output_data);
tflite::testing::ExpectEqual(output_shape, expected_output, output_data);
}
TF_LITE_MICRO_TEST(MixedFlatSizeOneInputsWithScalarInputYTensor) {
int input3_shape[] = {0}; // y data is a scalar
int input_shape[] = {1, 1};
int output_shape[] = {0}; // output data is a scalar
const bool input1_data[] = {false};
const int16_t input2_data[] = {1};
const int16_t input3_data[] = {5};
const int16_t expected_output[] = {5};
int16_t output_data[std::extent<decltype(expected_output)>::value];
tflite::testing::TestSelect(input_shape, input1_data, input_shape,
input2_data, input3_shape, input3_data,
output_shape, output_data);
tflite::testing::ExpectEqual(output_shape, expected_output, output_data);
}
TF_LITE_MICRO_TESTS_END
......@@ -16,8 +16,10 @@ limitations under the License.
#ifndef TENSORFLOW_LITE_MICRO_TEST_HELPERS_H_
#define TENSORFLOW_LITE_MICRO_TEST_HELPERS_H_
#include <algorithm>
#include <cstdint>
#include <limits>
#include <type_traits>
#include "flatbuffers/flatbuffers.h" // from @flatbuffers
#include "tensorflow/lite/c/common.h"
......@@ -298,7 +300,7 @@ TfLiteTensor CreateSymmetricPerChannelQuantizedTensor(
// Returns the number of tensors in the default subgraph for a tflite::Model.
size_t GetModelTensorCount(const Model* model);
// Derives the quantization scaling factor from a min and max range.
// Derives the asymmetric quantization scaling factor from a min and max range.
template <typename T>
inline float ScaleFromMinMax(const float min, const float max) {
return (max - min) /
......@@ -306,6 +308,19 @@ inline float ScaleFromMinMax(const float min, const float max) {
std::numeric_limits<T>::min());
}
// Derives the symmetric quantization scaling factor from a min and max range.
template <typename T>
inline float SymmetricScaleFromMinMax(const float min, const float max) {
const int32_t kScale =
std::numeric_limits<typename std::make_signed<T>::type>::max();
const float range = std::max(std::abs(min), std::abs(max));
if (range == 0) {
return 1.0f;
} else {
return range / kScale;
}
}
// Derives the quantization zero point from a min and max range.
template <typename T>
inline int ZeroPointFromMinMax(const float min, const float max) {
......
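As a quick sanity check of the new SymmetricScaleFromMinMax helper: for
int8_t, kScale is 127, so the data range used by the hybrid embedding-lookup
tests above gives

// range = max(|-1.01|, |2.13|) = 2.13
// scale = 2.13 / 127 ≈ 0.01677
// i.e. SymmetricScaleFromMinMax<int8_t>(-1.01f, 2.13f) ≈ 0.01677f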
......@@ -331,6 +331,7 @@ $(TENSORFLOW_ROOT)tensorflow/lite/micro/kernels/detection_postprocess.cc \
$(TENSORFLOW_ROOT)tensorflow/lite/micro/kernels/div.cc \
$(TENSORFLOW_ROOT)tensorflow/lite/micro/kernels/elementwise.cc \
$(TENSORFLOW_ROOT)tensorflow/lite/micro/kernels/elu.cc \
$(TENSORFLOW_ROOT)tensorflow/lite/micro/kernels/embedding_lookup.cc \
$(TENSORFLOW_ROOT)tensorflow/lite/micro/kernels/ethosu.cc \
$(TENSORFLOW_ROOT)tensorflow/lite/micro/kernels/exp.cc \
$(TENSORFLOW_ROOT)tensorflow/lite/micro/kernels/expand_dims.cc \
......