未验证 提交 9dc65ef8 编写于 作者: D David Davis 提交者: GitHub

Port EMBEDDING_LOOKUP to TFLM (#1872)

@tensorflow/micro

Port the EMBEDDING_LOOKUP kernel operator to TFLM.
Recreate tests for kernel operator (INT8, FLOAT32). 
Add SymmetricScaleFromMinMax method to test_helpers.h file. 
Update Makefiles.
Update Bazel build script.

bug=fixes #1871
上级 148af67d
......@@ -207,6 +207,7 @@ tflm_kernel_cc_library(
"div.cc",
"elementwise.cc",
"elu.cc",
"embedding_lookup.cc",
"ethosu.cc",
"exp.cc",
"expand_dims.cc",
......@@ -672,6 +673,21 @@ cc_test(
],
)
# Unit test target for the EMBEDDING_LOOKUP kernel; builds
# embedding_lookup_test.cc against the kernel runner and micro test helpers.
cc_test(
name = "embedding_lookup_test",
srcs = [
"embedding_lookup_test.cc",
],
deps = [
":kernel_runner",
"//tensorflow/lite/c:common",
"//tensorflow/lite/micro:debug_log",
"//tensorflow/lite/micro:op_resolvers",
"//tensorflow/lite/micro:test_helpers",
"//tensorflow/lite/micro/testing:micro_test",
],
)
cc_test(
name = "exp_test",
srcs = ["exp_test.cc"],
......
......@@ -69,6 +69,7 @@ $(TENSORFLOW_ROOT)tensorflow/lite/micro/kernels/dequantize_test.cc \
$(TENSORFLOW_ROOT)tensorflow/lite/micro/kernels/div_test.cc \
$(TENSORFLOW_ROOT)tensorflow/lite/micro/kernels/elementwise_test.cc \
$(TENSORFLOW_ROOT)tensorflow/lite/micro/kernels/elu_test.cc \
$(TENSORFLOW_ROOT)tensorflow/lite/micro/kernels/embedding_lookup_test.cc \
$(TENSORFLOW_ROOT)tensorflow/lite/micro/kernels/exp_test.cc \
$(TENSORFLOW_ROOT)tensorflow/lite/micro/kernels/expand_dims_test.cc \
$(TENSORFLOW_ROOT)tensorflow/lite/micro/kernels/fill_test.cc \
......
/* Copyright 2023 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// Op that looks up items from a matrix.
//
// Input:
// Tensor[0]: Row numbers to lookup, dim.size == 1, int32
// Tensor[1]: 2-dimensional matrix of multi-dimensional items
// dim.size >= 2, all items are INT8 or FLOAT32.
// first dimension is row, second dimension is column.
//
// Output:
// Output.dim[0] == Tensor[0].dim[0], num of lookups
// Output.dim[1] == Tensor[1].dim[1], num of items per row
// Each item in output is a raw byte copy of the corresponding item in input,
// or a dequantized value in the case of an INT8 input.
// When indices are out of bounds, the op will fail.
//
#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/kernels/internal/types.h"
#include "tensorflow/lite/kernels/kernel_util.h"
#include "tensorflow/lite/micro/kernels/kernel_util.h"
#include "tensorflow/lite/micro/micro_log.h"
#include "tensorflow/lite/micro/micro_utils.h"
namespace tflite {
namespace {
// Position of the lookup (row index) tensor in the node's input list.
constexpr int kInputTensor_0 = 0;
// Position of the value (embedding table) tensor in the node's input list.
constexpr int kInputTensor_1 = 1;
// Position of the result tensor in the node's output list.
constexpr int kOutputTensor = 0;
// Per-node state filled in by CalculateOpData() during Prepare and read by
// the Eval* functions.
struct OpData {
// Only meaningful on the hybrid path (INT8 tensor 1, FLOAT32 output).
float scale; // quantization scale for tensor 1
// Number of elements in one row of tensor 1.
size_t num_columns; // number of columns after flattening tensor 1 into 2D
};
// Allocates the persistent OpData for this node and fills it in.
//
// Stores the quantization scale of `tensor_1` (the embedding table) when the
// op runs in hybrid mode (INT8 table, FLOAT32 output), and the number of
// elements in one table row, i.e. the product of every dimension after the
// first.
//
// Returns kTfLiteError if the persistent allocation fails or if a hybrid
// table has a non-zero zero point; kTfLiteOk otherwise.
TfLiteStatus CalculateOpData(TfLiteContext* context, TfLiteNode* node,
                             const TfLiteTensor* tensor_1,
                             const TfLiteTensor* output) {
  node->user_data = context->AllocatePersistentBuffer(context, sizeof(OpData));
  OpData* op_data = static_cast<OpData*>(node->user_data);
  TF_LITE_ENSURE(context, op_data != nullptr);

  // Give scale a defined value even on the non-hybrid paths, where nothing
  // else in this function writes it.
  op_data->scale = 0.0f;
  if (tensor_1->type == kTfLiteInt8 && output->type == kTfLiteFloat32) {
    TF_LITE_ENSURE_EQ(context, tensor_1->params.zero_point, 0);
    op_data->scale = tensor_1->params.scale;
  }

  // Multiply the trailing dimensions rather than dividing the element count
  // by the row count: the division is undefined (integer divide by zero) for
  // a table with zero rows, a case EvalSimple explicitly accepts.
  size_t num_columns = 1;
  for (int i = 1; i < tensor_1->dims->size; i++) {
    num_columns *= static_cast<size_t>(tensor_1->dims->data[i]);
  }
  op_data->num_columns = num_columns;
  return kTfLiteOk;
}
// Validates the node's tensors, rewrites the output shape, and stores the
// per-node OpData.
//
// Checks performed:
//   - exactly two inputs and one output;
//   - the lookup tensor is 1-D and INT32;
//   - the value tensor has at least two dimensions and is FLOAT32 or INT8;
//   - a FLOAT32 value tensor requires a FLOAT32 output, an INT8 value tensor
//     allows a FLOAT32 or an INT8 output;
//   - the output has at least as many dimensions as the value tensor and its
//     data buffer is large enough for the rewritten shape.
// Returns kTfLiteOk on success; a failed check returns early through the
// TF_LITE_ENSURE* macros.
// NOTE(review): those early returns happen before the
// DeallocateTempTfLiteTensor() calls at the end of the function - confirm
// that is acceptable for the micro context in use.
TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_ENSURE_EQ(context, NumInputs(node), 2);
TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
MicroContext* micro_context = GetMicroContext(context);
// Input 0: the row indices to look up.
TfLiteTensor* lookup =
micro_context->AllocateTempInputTensor(node, kInputTensor_0);
TF_LITE_ENSURE(context, lookup != nullptr);
TF_LITE_ENSURE_EQ(context, NumDimensions(lookup), 1);
TF_LITE_ENSURE_EQ(context, lookup->type, kTfLiteInt32);
// Input 1: the table whose rows are looked up.
TfLiteTensor* value =
micro_context->AllocateTempInputTensor(node, kInputTensor_1);
TF_LITE_ENSURE(context, value != nullptr);
TF_LITE_ENSURE(context, NumDimensions(value) >= 2);
TF_LITE_ENSURE(context,
value->type == kTfLiteFloat32 || value->type == kTfLiteInt8);
TfLiteTensor* output =
micro_context->AllocateTempOutputTensor(node, kOutputTensor);
TF_LITE_ENSURE(context, output != nullptr);
// Allowed type pairs: FLOAT32 -> FLOAT32, INT8 -> FLOAT32 or INT8.
if (value->type == kTfLiteFloat32) {
TF_LITE_ENSURE(context, output->type == kTfLiteFloat32);
} else {
TF_LITE_ENSURE(
context, output->type == kTfLiteFloat32 || output->type == kTfLiteInt8);
}
// make sure output dimensions size can hold the new dimension data
TF_LITE_ENSURE(context, output->dims->size >= NumDimensions(value));
// make the output tensor dimensions mutable
TfLiteEvalTensor* output_eval =
tflite::micro::GetEvalOutput(context, node, kOutputTensor);
TF_LITE_ENSURE_OK(context, tflite::micro::CreateWritableTensorDimsWithCopy(
context, output, output_eval));
// set the new output dimensions:
// dim 0 is the number of lookups, every later dim is copied from the value
// tensor. Output dims beyond NumDimensions(value), if any, are left as-is.
output->dims->data[0] = SizeOfDimension(lookup, 0);
output->dims->data[1] = SizeOfDimension(value, 1);
for (int i = 2; i < NumDimensions(value); i++) {
output->dims->data[i] = SizeOfDimension(value, i);
}
// check the new output dimensions do not exceed the output data buffer size
size_t new_dims_size = NumElements(output) * TfLiteTypeGetSize(output->type);
TF_LITE_ENSURE(context, new_dims_size <= output->bytes);
// Allocate and fill the OpData read later by the Eval* functions.
TF_LITE_ENSURE_OK(context, CalculateOpData(context, node, value, output));
micro_context->DeallocateTempTfLiteTensor(lookup);
micro_context->DeallocateTempTfLiteTensor(value);
micro_context->DeallocateTempTfLiteTensor(output);
return kTfLiteOk;
}
// Copies each looked-up row of `value` into `output` byte for byte, with no
// type conversion. Row i of the output receives row lookup[i] of the table.
//
// Returns kTfLiteOk without touching the output when the table has zero
// rows, and kTfLiteError (after logging) as soon as an index falls outside
// [0, number of rows).
TfLiteStatus EvalSimple(const OpData& op_data, const TfLiteEvalTensor* lookup,
                        const TfLiteEvalTensor* value,
                        TfLiteEvalTensor* output) {
  const int row_count = value->dims->data[0];
  if (row_count == 0) {
    // An empty table yields an empty result; there is nothing to copy.
    return kTfLiteOk;
  }

  const size_t bytes_per_row =
      op_data.num_columns * TfLiteTypeGetSize(value->type);
  const int32_t* indices = tflite::micro::GetTensorData<int32_t>(lookup);
  const int8_t* table_bytes = tflite::micro::GetTensorData<int8_t>(value);
  int8_t* result_bytes = tflite::micro::GetTensorData<int8_t>(output);
  const int lookup_count = lookup->dims->data[0];

  for (int n = 0; n < lookup_count; ++n) {
    const int32_t row = indices[n];
    if (row < 0 || row >= row_count) {
      MicroPrintf(
          "EMBEDDING_LOOKUP: index out of bounds. "
          "Got %d, and bounds are [0, %d]",
          row, row_count - 1);
      return kTfLiteError;
    }
    std::memcpy(result_bytes + n * bytes_per_row,
                table_bytes + row * bytes_per_row, bytes_per_row);
  }
  return kTfLiteOk;
}
// Hybrid path: reads rows of an INT8 table and writes them to a FLOAT32
// output, dequantizing each row with op_data.scale and a zero point of 0.
//
// Returns kTfLiteError (after logging) as soon as an index falls outside
// [0, number of rows); kTfLiteOk otherwise.
TfLiteStatus EvalHybrid(const OpData& op_data, const TfLiteEvalTensor* lookup,
                        const TfLiteEvalTensor* value,
                        TfLiteEvalTensor* output) {
  const int row_count = value->dims->data[0];
  const size_t columns = op_data.num_columns;

  const int32_t* indices = tflite::micro::GetTensorData<int32_t>(lookup);
  const int8_t* table = tflite::micro::GetTensorData<int8_t>(value);
  float* result = tflite::micro::GetTensorData<float>(output);
  const int lookup_count = lookup->dims->data[0];

  for (int n = 0; n < lookup_count; ++n) {
    const int32_t row = indices[n];
    if (row < 0 || row >= row_count) {
      MicroPrintf(
          "EMBEDDING_LOOKUP: index out of bounds. "
          "Got %d, and bounds are [0, %d]",
          row, row_count - 1);
      return kTfLiteError;
    }
    // Convert one whole table row from INT8 to float.
    Dequantize(table + row * columns, columns, op_data.scale, 0,
               result + n * columns);
  }
  return kTfLiteOk;
}
// Kernel entry point: dispatches on the element type of the value (table)
// tensor.
//   FLOAT32 table                -> EvalSimple (raw row copy)
//   INT8 table, FLOAT32 output   -> EvalHybrid (dequantize rows)
//   INT8 table, any other output -> EvalSimple (raw row copy)
// Any other table type is logged and rejected with kTfLiteError.
TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
  const TfLiteEvalTensor* lookup =
      tflite::micro::GetEvalInput(context, node, kInputTensor_0);
  const TfLiteEvalTensor* value =
      tflite::micro::GetEvalInput(context, node, kInputTensor_1);
  TfLiteEvalTensor* output =
      tflite::micro::GetEvalOutput(context, node, kOutputTensor);
  OpData& op_data = *static_cast<OpData*>(node->user_data);

  switch (value->type) {
    case kTfLiteFloat32:
      return EvalSimple(op_data, lookup, value, output);
    case kTfLiteInt8:
      if (output->type == kTfLiteFloat32) {
        return EvalHybrid(op_data, lookup, value, output);
      } else {
        return EvalSimple(op_data, lookup, value, output);
      }
    default:
      // Report the type that was actually switched on (the table's type);
      // printing the output type here would name the wrong tensor.
      MicroPrintf("EMBEDDING_LOOKUP only supports FLOAT32 and INT8, got %s.",
                  TfLiteTypeGetName(value->type));
      return kTfLiteError;
  }
}
} // namespace
// Returns the registration for the EMBEDDING_LOOKUP op, built from this
// file's Prepare and Eval functions. The first RegisterOp argument is
// nullptr - presumably the init hook, which this op does not need; confirm
// against the RegisterOp declaration.
TFLMRegistration Register_EMBEDDING_LOOKUP() {
return tflite::micro::RegisterOp(nullptr, Prepare, Eval);
}
} // namespace tflite
/* Copyright 2023 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <algorithm>
#include <iterator>
#include <type_traits>
#include "tensorflow/lite/c/builtin_op_data.h"
#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/micro/kernels/kernel_runner.h"
#include "tensorflow/lite/micro/test_helpers.h"
#include "tensorflow/lite/micro/testing/micro_test.h"
namespace tflite {
namespace testing {
namespace {
// Tolerance handed to TF_LITE_MICRO_EXPECT_NEAR when comparing outputs.
constexpr float kTestTolerance = 7.41e-03;
// Number of input / output tensors wired to the kernel under test.
constexpr int kNumInputs = 2;
constexpr int kNumOutputs = 1;
// Positions in the test's tensors[] array: lookup indices, value table, and
// the output tensor, in that order.
constexpr int kInputTensorIndex_0 = 0;
constexpr int kInputTensorIndex_1 = 1;
constexpr int kOutputTensorIndex = 2;
// Quantization inputs for the hybrid (INT8 table, float output) tests.
// min/max are used to compute scale, zero-point is 0
// kInputSize is the number of elements in the table being quantized.
template <size_t kInputSize>
struct TestEmbeddingLookupParams {
// quantization parameters
float data_min; // input data minimum value
float data_max; // input data maximum value
int8_t input_data[kInputSize]; // quantized input storage
};
// Runs the EMBEDDING_LOOKUP kernel once over `tensors` (lookup indices,
// value table, output - see the k*TensorIndex constants) and expects both
// the prepare step and the invoke step to return kTfLiteOk.
void ExecuteEmbeddingLookupTest(TfLiteTensor* tensors, int tensors_count) {
  // Length-prefixed index lists: element count first, then tensor indices.
  int output_indices[] = {kNumOutputs, kOutputTensorIndex};
  int input_indices[] = {kNumInputs, kInputTensorIndex_0,
                         kInputTensorIndex_1};
  TfLiteIntArray* node_inputs = IntArrayFromInts(input_indices);
  TfLiteIntArray* node_outputs = IntArrayFromInts(output_indices);

  // Kept as a named local so it outlives the runner that receives it.
  const TFLMRegistration registration = tflite::Register_EMBEDDING_LOOKUP();
  micro::KernelRunner runner(registration, tensors, tensors_count, node_inputs,
                             node_outputs, nullptr);

  TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, runner.InitAndPrepare());
  TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, runner.Invoke());
}
// Runs the hybrid variant of the op (INT8 table, float output) and checks
// both the output values and the output shape.
//
//   params          min/max used to derive the symmetric scale, plus the
//                   int8 storage handed to CreateQuantizedTensor
//   input_dims_data shapes of the two inputs (lookup, table)
//   input_data_0    row indices to look up
//   input_data_1    float table data that the quantized table is built from
//   expected_dims   expected output shape; also used as the initial shape of
//                   the output tensor
//   expected_data   expected float output, compared within kTestTolerance
//   output_data     caller-provided output buffer
template <size_t N>
void TestEmbeddingLookupQuantized(TestEmbeddingLookupParams<N>& params,
int* input_dims_data[kNumInputs],
const int32_t* input_data_0,
const float* input_data_1, int* expected_dims,
const float* expected_data,
float* output_data) {
TfLiteIntArray* input_dims_0 = IntArrayFromInts(input_dims_data[0]);
TfLiteIntArray* input_dims_1 = IntArrayFromInts(input_dims_data[1]);
TfLiteIntArray* output_dims = IntArrayFromInts(expected_dims);
const int output_count = ElementCount(*output_dims);
// Symmetric quantization: scale from the data range, zero point fixed at 0.
const float scale =
SymmetricScaleFromMinMax<int8_t>(params.data_min, params.data_max);
// Order must match kInputTensorIndex_0, kInputTensorIndex_1 and
// kOutputTensorIndex.
TfLiteTensor tensors[] = {
CreateTensor(input_data_0, input_dims_0),
CreateQuantizedTensor(input_data_1, params.input_data, input_dims_1,
scale, 0),
CreateTensor(output_data, output_dims),
};
constexpr int tensors_count = std::extent<decltype(tensors)>::value;
ExecuteEmbeddingLookupTest(tensors, tensors_count);
// check output data against expected
for (int i = 0; i < output_count; i++) {
TF_LITE_MICRO_EXPECT_NEAR(expected_data[i], output_data[i], kTestTolerance);
}
// check output dimensions (relocated) against original dimensions
TF_LITE_MICRO_EXPECT_EQ(output_dims->size,
tensors[kOutputTensorIndex].dims->size);
for (int i = 0; i < output_dims->size; i++) {
TF_LITE_MICRO_EXPECT_EQ(output_dims->data[i],
tensors[kOutputTensorIndex].dims->data[i]);
}
} // TestEmbeddingLookupQuantized
// Runs the non-hybrid variant of the op, where the table and the output
// share element type T, and checks both the output values and the output
// shape.
//
//   input_dims_data shapes of the two inputs (lookup, table)
//   input_data_0    row indices to look up
//   input_data_1    table data
//   expected_dims   expected output shape; also used as the initial shape of
//                   the output tensor
//   expected_data   expected output, compared within kTestTolerance
//   output_data     caller-provided output buffer
template <typename T>
void TestEmbeddingLookup(int* input_dims_data[kNumInputs],
const int32_t* input_data_0, const T* input_data_1,
int* expected_dims, const T* expected_data,
T* output_data) {
TfLiteIntArray* input_dims_0 = IntArrayFromInts(input_dims_data[0]);
TfLiteIntArray* input_dims_1 = IntArrayFromInts(input_dims_data[1]);
TfLiteIntArray* output_dims = IntArrayFromInts(expected_dims);
const int output_count = ElementCount(*output_dims);
// Order must match kInputTensorIndex_0, kInputTensorIndex_1 and
// kOutputTensorIndex.
TfLiteTensor tensors[] = {
CreateTensor(input_data_0, input_dims_0),
CreateTensor(input_data_1, input_dims_1),
CreateTensor(output_data, output_dims),
};
constexpr int tensors_count = std::extent<decltype(tensors)>::value;
ExecuteEmbeddingLookupTest(tensors, tensors_count);
// check output data against expected
for (int i = 0; i < output_count; i++) {
TF_LITE_MICRO_EXPECT_NEAR(expected_data[i], output_data[i], kTestTolerance);
}
// check output dimensions (relocated) against original dimensions
TF_LITE_MICRO_EXPECT_EQ(output_dims->size,
tensors[kOutputTensorIndex].dims->size);
for (int i = 0; i < output_dims->size; i++) {
TF_LITE_MICRO_EXPECT_EQ(output_dims->data[i],
tensors[kOutputTensorIndex].dims->data[i]);
}
}
} // namespace
} // namespace testing
} // namespace tflite
TF_LITE_MICRO_TESTS_BEGIN
// FLOAT32 table, FLOAT32 output: rows 1, 0, 2 of a 3x2x4 table are expected
// back unchanged, in lookup order.
TF_LITE_MICRO_TEST(EmbeddingLookupOpTestSimpleFloat) {
  // Shape arrays are length-prefixed: rank first, then the extents.
  int lookup_shape[] = {1, 3};
  int table_shape[] = {3, 3, 2, 4};
  int expected_shape[] = {3, 3, 2, 4};
  int* input_shapes[tflite::testing::kNumInputs] = {lookup_shape, table_shape};

  constexpr int32_t kLookupRows[] = {1, 0, 2};
  constexpr float kTable[] = {
      0.00, 0.01, 0.02, 0.03, 0.10, 0.11, 0.12, 0.13,  // Row 0
      1.00, 1.01, 1.02, 1.03, 1.10, 1.11, 1.12, 1.13,  // Row 1
      2.00, 2.01, 2.02, 2.03, 2.10, 2.11, 2.12, 2.13,  // Row 2
  };
  constexpr float kExpected[] = {
      1.00, 1.01, 1.02, 1.03, 1.10, 1.11, 1.12, 1.13,  // Row 1
      0.00, 0.01, 0.02, 0.03, 0.10, 0.11, 0.12, 0.13,  // Row 0
      2.00, 2.01, 2.02, 2.03, 2.10, 2.11, 2.12, 2.13,  // Row 2
  };
  constexpr int kExpectedCount = std::extent<decltype(kExpected)>::value;
  float actual[kExpectedCount];

  tflite::testing::TestEmbeddingLookup(input_shapes, kLookupRows, kTable,
                                       expected_shape, kExpected, actual);
}
// Hybrid path with a 2-D (3x8) table: the float table is quantized to INT8
// by the test helper and the float output is compared against the original
// float rows 1, 0, 2.
TF_LITE_MICRO_TEST(HybridEmbeddingLookupHybridOpTestSimple2DTestInt8) {
  // Shape arrays are length-prefixed: rank first, then the extents.
  int lookup_shape[] = {1, 3};
  int table_shape[] = {2, 3, 8};
  int expected_shape[] = {2, 3, 8};
  int* input_shapes[tflite::testing::kNumInputs] = {lookup_shape, table_shape};

  constexpr int32_t kLookupRows[] = {1, 0, 2};
  constexpr float kTable[] = {
      0.00, 0.01,  0.02, 0.03, 0.10, 0.11, 0.12, 0.13,  // Row 0
      1.00, -1.01, 1.02, 1.03, 1.10, 1.11, 1.12, 1.13,  // Row 1
      2.00, 2.01,  2.02, 2.03, 2.10, 2.11, 2.12, 2.13,  // Row 2
  };
  constexpr int kTableCount = std::extent<decltype(kTable)>::value;
  constexpr float kExpected[] = {
      1.00, -1.01, 1.02, 1.03, 1.10, 1.11, 1.12, 1.13,  // Row 1
      0.00, 0.01,  0.02, 0.03, 0.10, 0.11, 0.12, 0.13,  // Row 0
      2.00, 2.01,  2.02, 2.03, 2.10, 2.11, 2.12, 2.13,  // Row 2
  };
  constexpr int kExpectedCount = std::extent<decltype(kExpected)>::value;
  float actual[kExpectedCount];

  // The quantization range is taken from the table data itself.
  tflite::testing::TestEmbeddingLookupParams<kTableCount> quant = {};
  const auto extremes =
      std::minmax_element(std::begin(kTable), std::end(kTable));
  quant.data_min = *extremes.first;
  quant.data_max = *extremes.second;

  tflite::testing::TestEmbeddingLookupQuantized(quant, input_shapes,
                                                kLookupRows, kTable,
                                                expected_shape, kExpected,
                                                actual);
}
// Hybrid path with a 3-D (3x2x4) table: the float table is quantized to INT8
// by the test helper and the float output is compared against the original
// float rows 1, 0, 2.
TF_LITE_MICRO_TEST(HybridEmbeddingLookupHybridOpTestSimple3DTestInt8) {
  // Shape arrays are length-prefixed: rank first, then the extents.
  int lookup_shape[] = {1, 3};
  int table_shape[] = {3, 3, 2, 4};
  int expected_shape[] = {3, 3, 2, 4};
  int* input_shapes[tflite::testing::kNumInputs] = {lookup_shape, table_shape};

  constexpr int32_t kLookupRows[] = {1, 0, 2};
  constexpr float kTable[] = {
      0.00, 0.01,  0.02, 0.03, 0.10, 0.11, 0.12, 0.13,  // Row 0
      1.00, -1.01, 1.02, 1.03, 1.10, 1.11, 1.12, 1.13,  // Row 1
      2.00, 2.01,  2.02, 2.03, 2.10, 2.11, 2.12, 2.13,  // Row 2
  };
  constexpr int kTableCount = std::extent<decltype(kTable)>::value;
  constexpr float kExpected[] = {
      1.00, -1.01, 1.02, 1.03, 1.10, 1.11, 1.12, 1.13,  // Row 1
      0.00, 0.01,  0.02, 0.03, 0.10, 0.11, 0.12, 0.13,  // Row 0
      2.00, 2.01,  2.02, 2.03, 2.10, 2.11, 2.12, 2.13,  // Row 2
  };
  constexpr int kExpectedCount = std::extent<decltype(kExpected)>::value;
  float actual[kExpectedCount];

  // The quantization range is taken from the table data itself.
  tflite::testing::TestEmbeddingLookupParams<kTableCount> quant = {};
  const auto extremes =
      std::minmax_element(std::begin(kTable), std::end(kTable));
  quant.data_min = *extremes.first;
  quant.data_max = *extremes.second;

  tflite::testing::TestEmbeddingLookupQuantized(quant, input_shapes,
                                                kLookupRows, kTable,
                                                expected_shape, kExpected,
                                                actual);
}
// Hybrid path with a 4-D (3x2x2x2) table: the float table is quantized to
// INT8 by the test helper and the float output is compared against the
// original float rows 1, 0, 2.
TF_LITE_MICRO_TEST(HybridEmbeddingLookupHybridOpTestSimple4DTestInt8) {
  // Shape arrays are length-prefixed: rank first, then the extents.
  int lookup_shape[] = {1, 3};
  int table_shape[] = {4, 3, 2, 2, 2};
  int expected_shape[] = {4, 3, 2, 2, 2};
  int* input_shapes[tflite::testing::kNumInputs] = {lookup_shape, table_shape};

  constexpr int32_t kLookupRows[] = {1, 0, 2};
  constexpr float kTable[] = {
      0.00, 0.01,  0.02, 0.03, 0.10, 0.11, 0.12, 0.13,  // Row 0
      1.00, -1.01, 1.02, 1.03, 1.10, 1.11, 1.12, 1.13,  // Row 1
      2.00, 2.01,  2.02, 2.03, 2.10, 2.11, 2.12, 2.13,  // Row 2
  };
  constexpr int kTableCount = std::extent<decltype(kTable)>::value;
  constexpr float kExpected[] = {
      1.00, -1.01, 1.02, 1.03, 1.10, 1.11, 1.12, 1.13,  // Row 1
      0.00, 0.01,  0.02, 0.03, 0.10, 0.11, 0.12, 0.13,  // Row 0
      2.00, 2.01,  2.02, 2.03, 2.10, 2.11, 2.12, 2.13,  // Row 2
  };
  constexpr int kExpectedCount = std::extent<decltype(kExpected)>::value;
  float actual[kExpectedCount];

  // The quantization range is taken from the table data itself.
  tflite::testing::TestEmbeddingLookupParams<kTableCount> quant = {};
  const auto extremes =
      std::minmax_element(std::begin(kTable), std::end(kTable));
  quant.data_min = *extremes.first;
  quant.data_max = *extremes.second;

  tflite::testing::TestEmbeddingLookupQuantized(quant, input_shapes,
                                                kLookupRows, kTable,
                                                expected_shape, kExpected,
                                                actual);
}
// INT8 table, INT8 output: rows 1, 0, 2 of a 3x2x4 table are expected back
// unchanged, in lookup order.
TF_LITE_MICRO_TEST(EmbeddingLookupOpTestSimpleInt8) {
  // Shape arrays are length-prefixed: rank first, then the extents.
  int lookup_shape[] = {1, 3};
  int table_shape[] = {3, 3, 2, 4};
  int expected_shape[] = {3, 3, 2, 4};
  int* input_shapes[tflite::testing::kNumInputs] = {lookup_shape, table_shape};

  constexpr int32_t kLookupRows[] = {1, 0, 2};
  constexpr int8_t kTable[] = {
      0,   1,   2,   3,   10,  11,  12,  13,   // Row 0
      100, 101, 102, 103, 110, 111, 112, 113,  // Row 1
      -56, -55, -54, -53, -46, -45, -44, -43,  // Row 2
  };
  constexpr int8_t kExpected[] = {
      100, 101, 102, 103, 110, 111, 112, 113,  // Row 1
      0,   1,   2,   3,   10,  11,  12,  13,   // Row 0
      -56, -55, -54, -53, -46, -45, -44, -43,  // Row 2
  };
  constexpr int kExpectedCount = std::extent<decltype(kExpected)>::value;
  int8_t actual[kExpectedCount];

  tflite::testing::TestEmbeddingLookup(input_shapes, kLookupRows, kTable,
                                       expected_shape, kExpected, actual);
}
TF_LITE_MICRO_TESTS_END
......@@ -55,6 +55,7 @@ TFLMRegistration Register_DEPTHWISE_CONV_2D();
TFLMRegistration Register_DEQUANTIZE();
TFLMRegistration Register_DIV();
TFLMRegistration Register_ELU();
TFLMRegistration Register_EMBEDDING_LOOKUP();
TFLMRegistration Register_EQUAL();
TFLMRegistration* Register_ETHOSU();
TFLMRegistration Register_EXP();
......
......@@ -16,8 +16,10 @@ limitations under the License.
#ifndef TENSORFLOW_LITE_MICRO_TEST_HELPERS_H_
#define TENSORFLOW_LITE_MICRO_TEST_HELPERS_H_
#include <algorithm>
#include <cstdint>
#include <limits>
#include <type_traits>
#include "flatbuffers/flatbuffers.h" // from @flatbuffers
#include "tensorflow/lite/c/common.h"
......@@ -298,7 +300,7 @@ TfLiteTensor CreateSymmetricPerChannelQuantizedTensor(
// Returns the number of tensors in the default subgraph for a tflite::Model.
size_t GetModelTensorCount(const Model* model);
// Derives the quantization scaling factor from a min and max range.
// Derives the asymmetric quantization scaling factor from a min and max range.
template <typename T>
inline float ScaleFromMinMax(const float min, const float max) {
return (max - min) /
......@@ -306,6 +308,19 @@ inline float ScaleFromMinMax(const float min, const float max) {
std::numeric_limits<T>::min());
}
// Derives the symmetric quantization scaling factor from a min and max range.
// The scale maps the larger of |min| and |max| onto the largest value of the
// signed counterpart of T. A range of exactly zero yields a scale of 1.0.
template <typename T>
inline float SymmetricScaleFromMinMax(const float min, const float max) {
  using SignedT = typename std::make_signed<T>::type;

  const float magnitude = std::max(std::abs(min), std::abs(max));
  if (magnitude == 0) {
    return 1.0f;
  }
  const int32_t quantized_max = std::numeric_limits<SignedT>::max();
  return magnitude / quantized_max;
}
// Derives the quantization zero point from a min and max range.
template <typename T>
inline int ZeroPointFromMinMax(const float min, const float max) {
......
......@@ -331,6 +331,7 @@ $(TENSORFLOW_ROOT)tensorflow/lite/micro/kernels/detection_postprocess.cc \
$(TENSORFLOW_ROOT)tensorflow/lite/micro/kernels/div.cc \
$(TENSORFLOW_ROOT)tensorflow/lite/micro/kernels/elementwise.cc \
$(TENSORFLOW_ROOT)tensorflow/lite/micro/kernels/elu.cc \
$(TENSORFLOW_ROOT)tensorflow/lite/micro/kernels/embedding_lookup.cc \
$(TENSORFLOW_ROOT)tensorflow/lite/micro/kernels/ethosu.cc \
$(TENSORFLOW_ROOT)tensorflow/lite/micro/kernels/exp.cc \
$(TENSORFLOW_ROOT)tensorflow/lite/micro/kernels/expand_dims.cc \
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册