未验证 提交 5a65a54c 编写于 作者: A Advait Jain 提交者: GitHub

Sync from upstream with `ci/sync_from_upstream_tf.sh` (#15)

上级 3d8b8e6c
......@@ -456,8 +456,8 @@ typedef struct TfLiteTensor {
} TfLiteTensor;
// A structure representing an instance of a node.
// This structure only exhibits the inputs, outputs and user defined data, not
// other features like the type.
// This structure only exhibits the inputs, outputs, user defined data and some
// node properties (like statefulness), not other features like the type.
typedef struct TfLiteNode {
// Inputs to this node expressed as indices into the simulator's tensors.
TfLiteIntArray* inputs;
......@@ -490,6 +490,9 @@ typedef struct TfLiteNode {
// created by calling `interpreter.ModifyGraphWithDelegate`.
// WARNING: This is an experimental interface that is subject to change.
struct TfLiteDelegate* delegate;
// Whether this op might have side effect (e.g. stateful op).
bool might_have_side_effect;
} TfLiteNode;
#else // defined(TF_LITE_STATIC_MEMORY)?
// NOTE: This flag is opt-in only at compile time.
......@@ -640,6 +643,7 @@ typedef struct TfLiteContext {
// TfLiteDelegates can traverse the current execution plan by iterating
// through each member of this array and using GetNodeAndRegistration() to
// access details about a node. i.e.
//
// TfLiteIntArray* execution_plan;
// TF_LITE_ENSURE_STATUS(context->GetExecutionPlan(context, &execution_plan));
// for (int exec_index = 0; exec_index < execution_plan->size; exec_index++) {
......@@ -648,6 +652,28 @@ typedef struct TfLiteContext {
// TfLiteRegistration* reg;
// context->GetNodeAndRegistration(context, node_index, &node, &reg);
// }
// Note: the memory pointed by '`*execution_plan` is OWNED by TfLite runtime.
// Future calls to GetExecutionPlan invalidates earlier outputs. The following
// code snippet shows the issue of such an invocation pattern. After calling
// CheckNode, subsequent access to `plan_1st` is undefined.
//
// void CheckNode(const TfLiteNode* node) {
// ...
// TfLiteIntArray* plan_2nd;
// TF_LITE_ENSURE_STATUS(context->GetExecutionPlan(context, &plan_2nd));
// ...
// }
//
// TfLiteIntArray* plan_1st;
// TF_LITE_ENSURE_STATUS(context->GetExecutionPlan(context, &plan_1st));
// for (int exec_index = 0; exec_index < plan_1st->size; exec_index++) {
// int node_index = plan_1st->data[exec_index];
// TfLiteNode* node;
// TfLiteRegistration* reg;
// context->GetNodeAndRegistration(context, node_index, &node, &reg);
// CheckNode(node);
// }
//
// WARNING: This is an experimental interface that is subject to change.
TfLiteStatus (*GetExecutionPlan)(struct TfLiteContext* context,
TfLiteIntArray** execution_plan);
......
......@@ -575,7 +575,8 @@ log_x_for_x_greater_than_or_equal_to_1_impl(
// InputIntegerBits - z_b_headroom - 0.25);
const FixedPointAccum z_a_pow_2_adj = SaturatingAddNonGemmlowp(
FixedPointAccum::FromRaw(SaturatingRoundingMultiplyByPOTParam(
InputIntegerBits - z_a_headroom_plus_1, 31 - kAccumIntegerBits)),
static_cast<int32_t>(InputIntegerBits - z_a_headroom_plus_1),
31 - kAccumIntegerBits)),
shifted_quarter);
// z_b is treated like z_a, but premultiplying by sqrt(0.5).
......@@ -585,7 +586,8 @@ log_x_for_x_greater_than_or_equal_to_1_impl(
SaturatingRoundingMultiplyByPOTParam(z_a.raw(), z_b_headroom);
const FixedPointAccum z_b_pow_2_adj = SaturatingSub(
FixedPointAccum::FromRaw(SaturatingRoundingMultiplyByPOTParam(
InputIntegerBits - z_b_headroom, 31 - kAccumIntegerBits)),
static_cast<int32_t>(InputIntegerBits - z_b_headroom),
31 - kAccumIntegerBits)),
shifted_quarter);
const FixedPoint0 r = FixedPoint0::FromRaw(std::min(r_a_raw, r_b_raw));
......
......@@ -15,6 +15,8 @@ limitations under the License.
#ifndef TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_ADD_H_
#define TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_ADD_H_
#include <type_traits>
#include "fixedpoint/fixedpoint.h"
#include "tensorflow/lite/kernels/internal/common.h"
......@@ -27,25 +29,14 @@ inline void Add(const ArithmeticParams& params,
const RuntimeShape& input1_shape, const T* input1_data,
const RuntimeShape& input2_shape, const T* input2_data,
const RuntimeShape& output_shape, T* output_data) {
const int flat_size =
MatchingElementsSize(input1_shape, input2_shape, output_shape);
for (int i = 0; i < flat_size; ++i) {
output_data[i] = ActivationFunctionWithMinMax(
input1_data[i] + input2_data[i], params.quantized_activation_min,
params.quantized_activation_max);
}
}
T activation_min, activation_max;
GetActivationParams(params, &activation_min, &activation_max);
inline void Add(const ArithmeticParams& params,
const RuntimeShape& input1_shape, const float* input1_data,
const RuntimeShape& input2_shape, const float* input2_data,
const RuntimeShape& output_shape, float* output_data) {
const int flat_size =
MatchingElementsSize(input1_shape, input2_shape, output_shape);
for (int i = 0; i < flat_size; i++) {
auto x = input1_data[i] + input2_data[i];
for (int i = 0; i < flat_size; ++i) {
output_data[i] = ActivationFunctionWithMinMax(
x, params.float_activation_min, params.float_activation_max);
input1_data[i] + input2_data[i], activation_min, activation_max);
}
}
......@@ -202,13 +193,12 @@ inline void Add(const ArithmeticParams& params,
}
}
inline void BroadcastAdd4DSlow(const ArithmeticParams& params,
const RuntimeShape& input1_shape,
const float* input1_data,
const RuntimeShape& input2_shape,
const float* input2_data,
const RuntimeShape& output_shape,
float* output_data) {
template <typename T>
inline typename std::enable_if<!is_small_integer<T>::value, void>::type
BroadcastAdd4DSlow(const ArithmeticParams& params,
const RuntimeShape& input1_shape, const T* input1_data,
const RuntimeShape& input2_shape, const T* input2_data,
const RuntimeShape& output_shape, T* output_data) {
NdArrayDesc<4> desc1;
NdArrayDesc<4> desc2;
NdArrayDescsForElementwiseBroadcast(input1_shape, input2_shape, &desc1,
......@@ -216,45 +206,8 @@ inline void BroadcastAdd4DSlow(const ArithmeticParams& params,
const RuntimeShape extended_output_shape =
RuntimeShape::ExtendedShape(4, output_shape);
// In Tensorflow, the dimensions are canonically named (batch_number, row,
// col, channel), with extents (batches, height, width, depth), with the
// trailing dimension changing most rapidly (channels has the smallest stride,
// typically 1 element).
//
// In generated C code, we store arrays with the dimensions reversed. The
// first dimension has smallest stride.
//
// We name our variables by their Tensorflow convention, but generate C code
// nesting loops such that the innermost loop has the smallest stride for the
// best cache behavior.
for (int b = 0; b < extended_output_shape.Dims(0); ++b) {
for (int y = 0; y < extended_output_shape.Dims(1); ++y) {
for (int x = 0; x < extended_output_shape.Dims(2); ++x) {
for (int c = 0; c < extended_output_shape.Dims(3); ++c) {
output_data[Offset(extended_output_shape, b, y, x, c)] =
ActivationFunctionWithMinMax(
input1_data[SubscriptToIndex(desc1, b, y, x, c)] +
input2_data[SubscriptToIndex(desc2, b, y, x, c)],
params.float_activation_min, params.float_activation_max);
}
}
}
}
}
inline void BroadcastAdd4DSlow(const ArithmeticParams& params,
const RuntimeShape& input1_shape,
const int32_t* input1_data,
const RuntimeShape& input2_shape,
const int32_t* input2_data,
const RuntimeShape& output_shape,
int32_t* output_data) {
NdArrayDesc<4> desc1;
NdArrayDesc<4> desc2;
NdArrayDescsForElementwiseBroadcast(input1_shape, input2_shape, &desc1,
&desc2);
const RuntimeShape extended_output_shape =
RuntimeShape::ExtendedShape(4, output_shape);
T activation_min, activation_max;
GetActivationParams(params, &activation_min, &activation_max);
// In Tensorflow, the dimensions are canonically named (batch_number, row,
// col, channel), with extents (batches, height, width, depth), with the
......@@ -272,11 +225,10 @@ inline void BroadcastAdd4DSlow(const ArithmeticParams& params,
for (int x = 0; x < extended_output_shape.Dims(2); ++x) {
for (int c = 0; c < extended_output_shape.Dims(3); ++c) {
output_data[Offset(extended_output_shape, b, y, x, c)] =
ActivationFunctionWithMinMax(
ActivationFunctionWithMinMax<T>(
input1_data[SubscriptToIndex(desc1, b, y, x, c)] +
input2_data[SubscriptToIndex(desc2, b, y, x, c)],
params.quantized_activation_min,
params.quantized_activation_max);
activation_min, activation_max);
}
}
}
......@@ -287,10 +239,11 @@ inline void BroadcastAdd4DSlow(const ArithmeticParams& params,
// is 32-bit for both cases. The overflow does not happen due to the
// choice of the shift (20 or 15, accordingly - see add.cc for more comments).
template <typename T>
inline void BroadcastAdd4DSlow(
const ArithmeticParams& params, const RuntimeShape& input1_shape,
const T* input1_data, const RuntimeShape& input2_shape,
const T* input2_data, const RuntimeShape& output_shape, T* output_data) {
inline typename std::enable_if<is_small_integer<T>::value, void>::type
BroadcastAdd4DSlow(const ArithmeticParams& params,
const RuntimeShape& input1_shape, const T* input1_data,
const RuntimeShape& input2_shape, const T* input2_data,
const RuntimeShape& output_shape, T* output_data) {
NdArrayDesc<4> desc1;
NdArrayDesc<4> desc2;
NdArrayDescsForElementwiseBroadcast(input1_shape, input2_shape, &desc1,
......
......@@ -15,10 +15,12 @@ limitations under the License.
#ifndef TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_CUMSUM_H_
#define TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_CUMSUM_H_
#include <algorithm>
#include <cstdint>
#include <limits>
#include "tensorflow/lite/kernels/internal/common.h"
#include "tensorflow/lite/kernels/internal/compatibility.h"
#include "tensorflow/lite/kernels/internal/types.h"
namespace tflite {
namespace reference_ops {
......@@ -79,6 +81,94 @@ inline void CumSum(const T* input_data, const RuntimeShape& shape, int32_t axis,
}
}
//
// Quantized INT8 CUMSUM
//
inline void CumSum(const ArithmeticParams& params, const int8_t* input_data,
const RuntimeShape& shape, int32_t axis, bool exclusive,
bool reverse, int8_t* output_data) {
TFLITE_DCHECK_LE(params.quantized_activation_min,
params.quantized_activation_max);
// Input offset is negative input zero point. Activation tensors are
// asymmetric quantized so they span the full int8 range.
// All inputs should have same zero-point and scale, this is checked during
// Prepare stage.
TFLITE_DCHECK_GE(-params.input1_offset, std::numeric_limits<int8_t>::min());
TFLITE_DCHECK_LE(-params.input1_offset, std::numeric_limits<int8_t>::max());
const int32_t rank = shape.DimensionsCount();
TFLITE_DCHECK_GE(rank, 1);
TFLITE_DCHECK_GE(axis, 0);
TFLITE_DCHECK_LT(axis, rank);
size_t inner = 1;
size_t outer = 1;
size_t depth = 1;
for (int32_t i = 0; i < rank; i++) {
if (i < axis)
inner *= shape.Dims(i);
else if (i > axis)
outer *= shape.Dims(i);
else
depth = shape.Dims(i);
}
for (size_t outer_index = 0; outer_index < outer; outer_index++) {
size_t outer_index_adj;
if (reverse)
outer_index_adj = (outer - 1) - outer_index;
else
outer_index_adj = outer_index;
for (size_t inner_index = 0; inner_index < inner; inner_index++) {
int32_t accumulator = params.input1_offset; // accumulator = 0
accumulator *= (1 << params.left_shift);
accumulator = MultiplyByQuantizedMultiplierSmallerThanOneExp(
accumulator, params.input1_multiplier, params.input1_shift);
size_t inner_index_adj;
if (reverse)
inner_index_adj = (inner - 1) - inner_index;
else
inner_index_adj = inner_index;
for (size_t depth_index = 0; depth_index < depth; depth_index++) {
size_t depth_index_adj;
if (reverse)
depth_index_adj = (depth - 1) - depth_index;
else
depth_index_adj = depth_index;
size_t index = outer_index_adj;
index += inner_index_adj * depth * outer;
index += depth_index_adj * outer;
const int32_t y = params.input1_offset + input_data[index];
const int32_t shifted_y = y * (1 << params.left_shift);
const int32_t scaled_y = MultiplyByQuantizedMultiplierSmallerThanOneExp(
shifted_y, params.input1_multiplier, params.input1_shift);
int32_t scaled_output;
if (exclusive) {
scaled_output = accumulator;
accumulator += scaled_y;
} else {
accumulator += scaled_y;
scaled_output = accumulator;
}
const int32_t raw_output =
MultiplyByQuantizedMultiplierSmallerThanOneExp(
scaled_output, params.output_multiplier, params.output_shift) +
params.output_offset;
const int32_t clamped_output =
std::min(params.quantized_activation_max,
std::max(params.quantized_activation_min, raw_output));
output_data[index] = static_cast<int8_t>(clamped_output);
}
}
}
}
} // namespace reference_ops
} // namespace tflite
......
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_DEPTH_TO_SPACE_H_
#define TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_DEPTH_TO_SPACE_H_
#include "tensorflow/lite/kernels/internal/types.h"
namespace tflite {
namespace reference_ops {
template <typename T>
inline void DepthToSpace(const tflite::DepthToSpaceParams& op_params,
const RuntimeShape& unextended_input_shape,
const T* input_data,
const RuntimeShape& unextended_output_shape,
T* output_data) {
TFLITE_DCHECK_LE(unextended_input_shape.DimensionsCount(), 4);
TFLITE_DCHECK_LE(unextended_output_shape.DimensionsCount(), 4);
const RuntimeShape input_shape =
RuntimeShape::ExtendedShape(4, unextended_input_shape);
const RuntimeShape output_shape =
RuntimeShape::ExtendedShape(4, unextended_output_shape);
const int input_depth = input_shape.Dims(3);
const int input_width = input_shape.Dims(2);
const int input_height = input_shape.Dims(1);
const int input_batch = input_shape.Dims(0);
const int output_depth = output_shape.Dims(3);
const int output_width = output_shape.Dims(2);
const int output_height = output_shape.Dims(1);
const int output_batch = output_shape.Dims(0);
const int32_t block_size = op_params.block_size;
TFLITE_DCHECK_EQ(input_width * block_size, output_width);
TFLITE_DCHECK_EQ(input_height * block_size, output_height);
TFLITE_DCHECK_EQ(input_depth, output_depth * block_size * block_size);
TFLITE_DCHECK_EQ(input_batch, output_batch);
for (int out_b = 0; out_b < output_batch; ++out_b) {
for (int out_h = 0; out_h < output_height; ++out_h) {
for (int out_w = 0; out_w < output_width; ++out_w) {
for (int out_d = 0; out_d < output_depth; ++out_d) {
const int in_d =
out_d + ((out_h % block_size) * block_size + out_w % block_size) *
output_depth;
const int in_w = out_w / block_size;
const int in_h = out_h / block_size;
const int in_b = out_b;
const int input_index = Offset(input_shape, in_b, in_h, in_w, in_d);
const int output_index =
Offset(output_shape, out_b, out_h, out_w, out_d);
output_data[output_index] = input_data[input_index];
}
}
}
}
}
} // namespace reference_ops
} // namespace tflite
#endif // TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_DEPTH_TO_SPACE_H_
/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_LOG_SOFTMAX_H_
#define TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_LOG_SOFTMAX_H_
#include <algorithm>
#include <cstddef>
#include <limits>
#include "fixedpoint/fixedpoint.h"
#include "tensorflow/lite/kernels/internal/common.h"
namespace tflite {
namespace reference_ops {
inline void LogSoftmax(const SoftmaxParams& params,
const RuntimeShape& input_shape, const float* input_data,
const RuntimeShape& output_shape, float* output_data) {
const int trailing_dim = input_shape.DimensionsCount() - 1;
const int outer_size =
MatchingFlatSizeSkipDim(input_shape, trailing_dim, output_shape);
const int depth =
MatchingDim(input_shape, trailing_dim, output_shape, trailing_dim);
for (int i = 0; i < outer_size; ++i) {
// Find max element value which we'll use to ensure numerical stability
// taking advantage of the following equality:
// log(exp(x[i])/sum(exp(x[i]))) == log(exp(x[i]+C)/sum(exp(x[i]+C)))
float max = std::numeric_limits<float>::lowest();
for (int c = 0; c < depth; ++c) {
max = std::max(max, input_data[i * depth + c]);
}
// Compute sum.
float sum = 0.f;
for (int c = 0; c < depth; ++c) {
sum += std::exp(input_data[i * depth + c] - max);
}
// Compute result.
const float log_sum = std::log(sum);
for (int c = 0; c < depth; ++c) {
output_data[i * depth + c] = input_data[i * depth + c] - max - log_sum;
}
}
}
inline void LogSoftmax(const SoftmaxParams& params,
const RuntimeShape& input_shape,
const uint8_t* input_data,
const RuntimeShape& output_shape, uint8_t* output_data) {
const int32_t input_multiplier = params.input_multiplier;
const int32_t input_left_shift = params.input_left_shift;
const int32_t reverse_scaling_divisor = params.reverse_scaling_divisor;
const int32_t reverse_scaling_right_shift =
params.reverse_scaling_right_shift;
const int diff_min = params.diff_min;
// The representation chosen for the input to the exp() function is Q5.26.
// We need to leave extra space since values that we skip might be as large
// as -32 before multiplying by input_beta_multiplier, and therefore as
// large as -16 afterwards. Note that exp(-8) is definitely not
// insignificant to accumulation, but exp(-16) definitely is.
static constexpr int kScaledDiffIntegerBits = 5;
static constexpr int kAccumulationIntegerBits = 12;
static constexpr int kOutputIntegerBits = 4;
using FixedPointScaledDiff =
gemmlowp::FixedPoint<int32_t, kScaledDiffIntegerBits>;
using FixedPointAccum =
gemmlowp::FixedPoint<int32_t, kAccumulationIntegerBits>;
const int trailing_dim = input_shape.DimensionsCount() - 1;
const int outer_size =
MatchingFlatSizeSkipDim(input_shape, trailing_dim, output_shape);
const int depth =
MatchingDim(input_shape, trailing_dim, output_shape, trailing_dim);
for (int i = 0; i < outer_size; ++i) {
uint8_t max_in_row = 0;
for (int c = 0; c < depth; ++c) {
max_in_row = std::max(max_in_row, input_data[i * depth + c]);
}
FixedPointAccum sum_of_exps = FixedPointAccum::Zero();
for (int c = 0; c < depth; ++c) {
int32_t input_diff =
static_cast<int32_t>(input_data[i * depth + c]) - max_in_row;
if (input_diff >= diff_min) {
const int32_t input_diff_rescaled =
MultiplyByQuantizedMultiplierGreaterThanOne(
input_diff, input_multiplier, input_left_shift);
const FixedPointScaledDiff scaled_diff_f8 =
FixedPointScaledDiff::FromRaw(input_diff_rescaled);
sum_of_exps = sum_of_exps + gemmlowp::Rescale<kAccumulationIntegerBits>(
exp_on_negative_values(scaled_diff_f8));
}
}
const int32_t fixed_log_sum_of_exps =
log_x_for_x_greater_than_or_equal_to_1<kScaledDiffIntegerBits>(
sum_of_exps)
.raw();
// rescaled_diff_min is smallest representable in
// Q(kScaledDiffIntegerBits).(31-kScaledDiffIntegerBits) plus the
// log-sub-exps that will be subtracted in the loop.
//
// The thresholds diff_min, etc are negative.
const int rescaled_diff_min =
fixed_log_sum_of_exps + std::numeric_limits<int32_t>::lowest();
const int adjusted_diff_min =
std::max(static_cast<int32_t>(
diff_min - 1), // Note use of > below instead of >= above.
MultiplyByQuantizedMultiplierSmallerThanOneExp(
rescaled_diff_min, reverse_scaling_divisor,
-reverse_scaling_right_shift));
for (int c = 0; c < depth; ++c) {
int32_t input_diff =
static_cast<int32_t>(input_data[i * depth + c]) - max_in_row;
if (input_diff > adjusted_diff_min) {
const int32_t input_diff_rescaled =
MultiplyByQuantizedMultiplierGreaterThanOne(
input_diff, input_multiplier, input_left_shift);
int32_t unsat_output =
gemmlowp::RoundingDivideByPOT(
(input_diff_rescaled - fixed_log_sum_of_exps),
31 - kScaledDiffIntegerBits - kOutputIntegerBits) +
255;
output_data[i * depth + c] = static_cast<uint8_t>(
std::max(std::min(unsat_output, static_cast<int32_t>(255)),
static_cast<int32_t>(0)));
} else {
// Set output to smallest value.
output_data[i * depth + c] = 0;
}
}
}
}
template <typename T>
inline void LogSoftmaxQuantized(const SoftmaxParams& params,
const size_t outer_size, const size_t depth,
const RuntimeShape& input_shape,
const T* input_data,
const RuntimeShape& output_shape,
T* output_data) {
const int32_t input_multiplier = params.input_multiplier;
const int32_t input_left_shift = params.input_left_shift;
const int32_t reverse_scaling_divisor = params.reverse_scaling_divisor;
const int32_t reverse_scaling_right_shift =
params.reverse_scaling_right_shift;
const int diff_min = params.diff_min;
static constexpr T kMinT8 = std::numeric_limits<T>::min();
static constexpr T kMaxT8 = std::numeric_limits<T>::max();
static constexpr int32_t kMinInt32 = std::numeric_limits<int32_t>::min();
// All IntegerBits must agree with Prepare function.
// Input is chosen as Q5.26 so exp(-1 * 2^5 * 2^-1) = exp(-16) is negligible.
static constexpr int kInputIntegerBits = 5;
static constexpr int kAccumulationIntegerBits = 12;
static constexpr int kOutputIntegerBits = 4;
using F5 = gemmlowp::FixedPoint<int32_t, kInputIntegerBits>;
using F12 = gemmlowp::FixedPoint<int32_t, kAccumulationIntegerBits>;
for (size_t outer_index = 0; outer_index < outer_size; ++outer_index) {
T max_in_row = kMinT8;
for (size_t inner_index = 0; inner_index < depth; ++inner_index) {
max_in_row =
std::max(max_in_row, input_data[outer_index * depth + inner_index]);
}
// Accumulator "sum_of_exps_in_q12" is safe from overflowing in 2^12 steps.
F12 sum_of_exps_in_q12 = F12::FromRaw(0);
for (size_t inner_index = 0; inner_index < depth; ++inner_index) {
int32_t input_diff =
static_cast<int32_t>(input_data[outer_index * depth + inner_index]) -
max_in_row;
if (input_diff >= diff_min) {
const int32_t input_diff_in_q5 = MultiplyByQuantizedMultiplier(
input_diff, input_multiplier, input_left_shift);
sum_of_exps_in_q12 =
sum_of_exps_in_q12 +
gemmlowp::Rescale<kAccumulationIntegerBits>(
exp_on_negative_values(F5::FromRaw(input_diff_in_q5)));
}
}
const int32_t log_sum_of_exps_in_q5 =
log_x_for_x_greater_than_or_equal_to_1<kInputIntegerBits>(
sum_of_exps_in_q12)
.raw();
// Potentially reduced the valid range. shifted_log_sum_of_exps_in_q5 is
// smallest representable in Q5.26 plus the log_sum_of_exps.
const int32_t shifted_log_sum_of_exps_in_q5 =
log_sum_of_exps_in_q5 + kMinInt32;
const int32_t adjusted_diff_min =
std::max(static_cast<int32_t>(diff_min - 1),
MultiplyByQuantizedMultiplier(shifted_log_sum_of_exps_in_q5,
reverse_scaling_divisor,
-reverse_scaling_right_shift));
for (size_t inner_index = 0; inner_index < depth; ++inner_index) {
int32_t input_diff =
static_cast<int32_t>(input_data[outer_index * depth + inner_index]) -
max_in_row;
// Note use of > below instead of >= above.
if (input_diff > adjusted_diff_min) {
const int32_t input_diff_in_q5 = MultiplyByQuantizedMultiplier(
input_diff, input_multiplier, input_left_shift);
// Rescale and downcast.
int32_t output_in_q27 =
gemmlowp::RoundingDivideByPOT(
(input_diff_in_q5 - log_sum_of_exps_in_q5),
31 - kInputIntegerBits - kOutputIntegerBits) +
kMaxT8;
output_in_q27 =
std::max(std::min(output_in_q27, static_cast<int32_t>(kMaxT8)),
static_cast<int32_t>(kMinT8));
output_data[outer_index * depth + inner_index] =
static_cast<T>(output_in_q27);
} else {
output_data[outer_index * depth + inner_index] = kMinT8;
}
}
}
}
inline void LogSoftmax(const SoftmaxParams& params, const size_t outer_size,
const size_t depth, const RuntimeShape& input_shape,
const int8_t* input_data,
const RuntimeShape& output_shape, int8_t* output_data) {
LogSoftmaxQuantized(params, outer_size, depth, input_shape, input_data,
output_shape, output_data);
}
} // namespace reference_ops
} // namespace tflite
#endif // TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_LOG_SOFTMAX_H_
......@@ -1196,6 +1196,23 @@ inline void GetActivationParams(const P& params, int64_t* min, int64_t* max) {
*min = params.int64_activation_min;
*max = params.int64_activation_max;
}
// Type trait to check of given type has size smaller than 4 bytes.
template <typename T>
struct is_small_integer
: public std::integral_constant<bool,
std::is_same<T, int8_t>::value ||
std::is_same<T, uint8_t>::value ||
std::is_same<T, int16_t>::value ||
std::is_same<T, uint16_t>::value> {};
// Type trait to check of given type is int32 or int64.
template <typename T>
struct is_int32_or_int64
: public std::integral_constant<bool, std::is_same<T, int32_t>::value ||
std::is_same<T, int64_t>::value> {
};
} // namespace tflite
#endif // TENSORFLOW_LITE_KERNELS_INTERNAL_TYPES_H_
......@@ -20,7 +20,6 @@ limitations under the License.
namespace tflite {
// TODO(renjieliu): Migrate others to use ComputePaddingWithLeftover.
inline int ComputePadding(int stride, int dilation_rate, int in_size,
int filter_size, int out_size) {
int effective_filter_size = (filter_size - 1) * dilation_rate + 1;
......
......@@ -69,4 +69,8 @@ project, we have additional documentation in the [docs](docs/) folder.
* [Optimized Kernel Implementations](docs/optimized_kernel_implementations.md)
* [New Platform Support](docs/new_platform_support.md)
* [Software Emulation with Renode](docs/renode.md)
* [Pre-allocated tensors](docs/preallocated_tensors.md)
# RFCs
1. [Pre-allocated tensors](docs/rfc/001_preallocated_tensors.md)
1. [TensorFlow Lite for Microcontrollers Port of 16x8 Quantized Operators](docs/rfc/002_16x8_quantization_port.md)
......@@ -33,6 +33,7 @@ AllOpsResolver::AllOpsResolver() {
AddConv2D();
AddCos();
AddCumSum();
AddDepthToSpace();
AddDepthwiseConv2D();
AddDequantize();
AddDetectionPostprocess();
......
......@@ -23,6 +23,7 @@ cc_library(
"//tensorflow/lite/micro:micro_framework",
"//tensorflow/lite/micro:micro_time",
"//tensorflow/lite/micro:op_resolvers",
"//tensorflow/lite/micro:recording_allocators",
],
)
......
......@@ -21,7 +21,6 @@ limitations under the License.
#include "tensorflow/lite/micro/benchmarks/micro_benchmark.h"
#include "tensorflow/lite/micro/kernels/fully_connected.h"
#include "tensorflow/lite/micro/micro_error_reporter.h"
#include "tensorflow/lite/micro/micro_interpreter.h"
#include "tensorflow/lite/micro/micro_mutable_op_resolver.h"
#include "tensorflow/lite/micro/micro_profiler.h"
#include "tensorflow/lite/micro/system_setup.h"
......@@ -103,4 +102,6 @@ int main(int argc, char** argv) {
tflite::KeywordRunNIerations(10, "KeywordRunNIerations(10)",
*benchmark_runner, profiler);
MicroPrintf(""); // null MicroPrintf serves as a newline.
benchmark_runner->PrintAllocations();
}
......@@ -19,10 +19,10 @@ limitations under the License.
#include <climits>
#include "tensorflow/lite/micro/micro_error_reporter.h"
#include "tensorflow/lite/micro/micro_interpreter.h"
#include "tensorflow/lite/micro/micro_op_resolver.h"
#include "tensorflow/lite/micro/micro_profiler.h"
#include "tensorflow/lite/micro/micro_time.h"
#include "tensorflow/lite/micro/recording_micro_interpreter.h"
namespace tflite {
......@@ -73,8 +73,12 @@ class MicroBenchmarkRunner {
}
}
void PrintAllocations() const {
interpreter_.GetMicroAllocator().PrintAllocations();
}
private:
tflite::MicroInterpreter interpreter_;
tflite::RecordingMicroInterpreter interpreter_;
};
} // namespace tflite
......
......@@ -166,7 +166,7 @@ if (interpreter.Invoke() != kTfLiteOk) {
}
// Print out detailed allocation information:
interpreter.PrintAllocations();
interpreter.GetMicroAllocator().PrintAllocations();
```
The output of this call will look something similar to this (output from the
......
......@@ -40,7 +40,7 @@ Makefile to `tensorflow/lite/micro/tools/make/downloads/renode`.
The Makefile internally calls the `renode_download.sh` script:
```
tensorflow/lite/micro/testing/renode_download.sh tensorflow/lite/micro/tools/make/downloads
tensorflow/lite/micro/tools/make/renode_download.sh tensorflow/lite/micro/tools/make/downloads
```
# Running Unit Tests
......
<!-- mdformat off(b/169948621#comment2) -->
# TensorFlow Lite for Microcontrollers Port of 16x8 Quantized Operators
| Status | Proposed |
:-------------- |:----------------------------------------------------------- |
| **RFC #2** | [46767](https://github.com/tensorflow/tensorflow/pull/46767)|
| **Author(s)** | Daniel Situnayake (dan@edgeimpulse.com) |
| **Sponsor** | Pete Warden (petewarden@google.com) |
| **Updated** | 2021-01-28 |
## Objective
TensorFlow Lite has kernel implementations that support 8 bit quantized weights
but use 16 bit activations. We wish to port these implementations to TensorFlow
Lite for Microcontrollers. The increased precision available for activations can
improve performance for some quantized models.
Arm have agreed to support the initiative by adding the necessary 16x8 APIs to
CMSIS-NN and porting the CMSIS-NN kernels.
### Goals
- Port a subset of 16x8 reference kernels from TensorFlow Lite to TensorFlow Lite Micro
- Avoid increasing default code size or arena size of TensorFlow Lite Micro
- Lay the groundwork for creating a CMSIS-NN port of the 16x8 kernels
### Non-goals
- Port every single operator to 16x8; we only plan to port a subset of those with existing reference implementations
## Motivation
Some networks that suffer unacceptable degradation when quantized with 8 bit weights
and 8 bit activations perform adequately when quantized with 8 bit weights and 16
bit activations. The [TensorFlow Lite documentation](https://www.tensorflow.org/lite/performance/post_training_integer_quant_16x8) states the following:
> [16x8 quantization] mode can improve accuracy of the quantized model significantly, when activations are sensitive to the quantization, while still achieving almost 3-4x reduction in model size. Moreover, this fully quantized model can be consumed by integer-only hardware accelerators.
Edge Impulse, a company that deploys TensorFlow Lite for Microcontrollers as part of its embedded
machine learning pipeline, has gathered feedback from customers with production models for which 8 bit
quantization results in unacceptable degradation but for whom 16x8 is fine.
While 16x8 quantization is well supported within TensorFlow Lite, it is not currently supported
within TensorFlow Lite for Microcontrollers. Porting the TensorFlow Lite reference kernels is
relatively straightforward and will improve adoption of TensorFlow Lite for Microcontrollers with users
for whom degradation is too severe with full 8 bit quantization.
## User Benefit
The headline would be "16x8 kernels improve accuracy for quantized models on microcontrollers without
increasing model size".
Users would benefit in the following ways:
- Improved accuracy for quantized models without increasing model size (in exchange for additional
runtime memory usage)
- Improved performance under certain conditions (for example, 16x8 CMSIS-NN kernels will run faster)
than 8 bit kernels since less unpacking is required)
## Design Proposal
We propose that the 16x8 kernels are ported from the TensorFlow Lite reference kernels to
TensorFlow Lite for Microcontrollers following the process in the [Porting TensorFlow Lite Ops to Micro](https://docs.google.com/document/d/1KLJTPWm4TUKB9YyIqFJl9VCP0ZMJDt_P8RNpRmwqMxw/edit#heading=h.5x0d5h95i329)
guide.
We wish to ensure that the following kernels are compatible with 16x8 mode:
- Conv2D
- MaxPool2D
- DepthwiseConv2D
- FullyConnected
- Relu
- Relu6
- Tanh
- Softmax
- Pad
- Reshape
- Pack
- Unpack
- Add
- Mul
Adding the 16x8 kernels directly to TFLM alongside the existing kernels would increase the default code size by an unacceptable amount. Instead, we will make use of the kernel registration API currently under development by the TFLM team. The use of this is demonstrated in the
[Keyword benchmark code](https://github.com/tensorflow/tensorflow/blob/a30d20b632b4ffbfd437ccf8ee205fef0917a3eb/tensorflow/lite/micro/benchmarks/keyword_benchmark.cc#L56).
By doing this, the end user can decide which kernels and dependencies they want to include (e.g. 8 bit, 16x8,
or float32).
For example, the following could be registered:
```
// Support for all datatypes
op_resolver->AddFullyConnected(tflite::Register_FULLY_CONNECTED);
// Support for 8 bit quantized models
op_resolver->AddFullyConnected(tflite::Register_FULLY_CONNECTED_INT8);
// Support for 16x8 quantized models
op_resolver->AddFullyConnected(tflite::Register_FULLY_CONNECTED_INT16X8());
```
This means that kernels not currently using this registration API will need to be refactored to use it. Currently only **FullyConnected** uses the API.
The following associated tasks will be required to support this work:
- Build or port unit tests for the new kernels
- Prove that code memory is not impacted by running benchmarks before and after the port
### Alternatives Considered
* An alternative would be to add the 16x8 kernels without using the new kernel registration API, but this would
result in a major increase in code size.
### Performance Implications
- Impact on memory usage for current modes (int8 and float32) will be minimal. This will be confirmed by
benchmarking of current performance against performance of the submitted changes.
- When 16x8 mode is used, RAM usage will be approximately 2x. Latency may change depending on the target
platform.
- End to end and unit tests will be updated to prove that the new implementations are operating correctly.
### Dependencies
- No additional dependencies will be added to TensorFlow
- No other parts of TensorFlow will be affected
### Engineering Impact
- Impact on binary size should be minimal
- Test times may increase due to additional kernel unit tests
- The reference kernels already exist within TensorFlow Lite so there will be minimal additional maintenance
### Platforms and Environments
- The proposed changes will work on all currently supported platforms
### Best Practices
- TensorFlow Lite for Microcontrollers should be updated to indicate that 16x8 kernels are now available
### Tutorials and Examples
- A benchmark will be added to [`tensorflow/lite/micro/benchmarks`](https://github.com/tensorflow/tensorflow/tree/975335bc83bf3cb80a71a04ed407725508709808/tensorflow/lite/micro/benchmarks) that demonstrates the use of the ops that provide a 16x8 kernel.
- A Colab will be created that demonstrates quantizing a model in 16x8 mode and exporting it as a C header file for use with TensorFlow Lite for Microcontrollers
### Compatibility
- This work will improve compatibility and feature parity between TensorFlow Lite and TensorFlow Lite for Microcontrollers
### User Impact
- Since TFLM does not have a versioning system the feature can be rolled out as any other commit
## Implementation plan
The work will be broken down into a series of pull requests, some for the benchmarks and some for each kernel.
Benchmark pull requests:
- PR1: Create a new benchmark in [`tensorflow/lite/micro/benchmarks`](https://github.com/tensorflow/tensorflow/tree/975335bc83bf3cb80a71a04ed407725508709808/tensorflow/lite/micro/benchmarks) that attempts to run a 16x8 model that includes the kernels mentioned in this RFC. The model’s weights and biases can be random. The benchmark should use the MicroMutableOpResolver. The PR should include the Colab used to generate the model.
- PR2: Port the person_detection and keyword benchmarks to use the MicroMutableOpResolver.
- PR3: Add code to both benchmarks that prints the arena size using the [`RecordingMemoryAllocator`](https://github.com/tensorflow/tensorflow/blob/ee87d58a6504375c28f21ea303f0eefa29118c38/tensorflow/lite/micro/docs/memory_management.md#recording-memory-apis).
For each kernel:
- PR1: Refactor the implementation to support the new kernel variant registration API.
- PR2: Add 16x8 support and make sure that the benchmark binary and arena sizes are unchanged.
Note that @njeffrie from the TF Lite Micro team also plans to prepare PR(s) for the kernels that are of interest internally
(without using the kernel variant registation API for binary size). This will provide some quick examples of porting the kernels.
## Questions and Discussion Topics
......@@ -162,7 +162,7 @@ def train_net(
# Convert the model to the TensorFlow Lite format with quantization
converter = tf.lite.TFLiteConverter.from_keras_model(model)
converter.optimizations = [tf.lite.Optimize.OPTIMIZE_FOR_SIZE]
converter.optimizations = [tf.lite.Optimize.DEFAULT]
tflite_model = converter.convert()
# Save the model to disk
......
......@@ -598,7 +598,7 @@ This will take a few minutes, and downloads frameworks the code uses like
finished, run:
```
make -f tensorflow/lite/micro/tools/make/Makefile test_person_detection_test
make -f tensorflow/lite/micro/tools/make/Makefile test_person_detection_test_int8
```
You should see a series of files get compiled, followed by some logging output
......
......@@ -267,6 +267,7 @@ cc_library(
"comparisons.cc",
"concatenation.cc",
"cumsum.cc",
"depth_to_space.cc",
"dequantize.cc",
"detection_postprocess.cc",
"elementwise.cc",
......@@ -282,6 +283,7 @@ cc_library(
"leaky_relu.cc",
"logical.cc",
"logistic.cc",
"log_softmax.cc",
"maximum_minimum.cc",
"mul.cc",
"neg.cc",
......@@ -555,6 +557,21 @@ cc_test(
],
)
cc_test(
name = "depth_to_space_test",
srcs = [
"depth_to_space_test.cc",
],
deps = [
":kernel_runner",
"//tensorflow/lite/c:common",
"//tensorflow/lite/micro:debug_log",
"//tensorflow/lite/micro:op_resolvers",
"//tensorflow/lite/micro:test_helpers",
"//tensorflow/lite/micro/testing:micro_test",
],
)
cc_test(
name = "depthwise_conv_test",
srcs = [
......@@ -805,6 +822,21 @@ cc_test(
],
)
cc_test(
name = "log_softmax_test",
srcs = [
"log_softmax_test.cc",
],
deps = [
":kernel_runner",
"//tensorflow/lite/c:common",
"//tensorflow/lite/micro:debug_log",
"//tensorflow/lite/micro:op_resolvers",
"//tensorflow/lite/micro:test_helpers",
"//tensorflow/lite/micro/testing:micro_test",
],
)
cc_test(
name = "maximum_minimum_test",
srcs = [
......
......@@ -16,6 +16,7 @@ limitations under the License.
#include "tensorflow/lite/kernels/internal/reference/cumsum.h"
#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/kernels/internal/quantization_util.h"
#include "tensorflow/lite/kernels/internal/types.h"
#include "tensorflow/lite/kernels/kernel_util.h"
#include "tensorflow/lite/micro/kernels/kernel_util.h"
......@@ -23,9 +24,24 @@ limitations under the License.
namespace tflite {
namespace {
static const int kInputTensor = 0;
static const int kAxisTensor = 1;
static const int kOutputTensor = 0;
constexpr int kInputTensor = 0;
constexpr int kAxisTensor = 1;
constexpr int kOutputTensor = 0;
constexpr int kCumSumIntegerShift = 20;
// only used with INT8 tensors
struct OpData {
int32_t output_activation_min;
int32_t output_activation_max;
int32_t input_offset;
int32_t output_offset;
int32_t input_multiplier;
int32_t output_multiplier;
int input_shift;
int output_shift;
int left_shift;
};
TfLiteStatus CalculateOpData(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_ENSURE_EQ(context, NumInputs(node), 2);
......@@ -34,7 +50,8 @@ TfLiteStatus CalculateOpData(TfLiteContext* context, TfLiteNode* node) {
const TfLiteTensor* input = GetInput(context, node, kInputTensor);
const TfLiteTensor* axis = GetInput(context, node, kAxisTensor);
TF_LITE_ENSURE(context, input->type == kTfLiteFloat32);
TF_LITE_ENSURE(context,
input->type == kTfLiteFloat32 || input->type == kTfLiteInt8);
TF_LITE_ENSURE_EQ(context, axis->type, kTfLiteInt32);
TF_LITE_ENSURE_EQ(context, NumElements(axis), 1);
......@@ -46,6 +63,34 @@ TfLiteStatus CalculateOpData(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_ENSURE_EQ(context, input->type, output->type);
TF_LITE_ENSURE(context, HaveSameShapes(input, output));
if (output->type == kTfLiteInt8) {
node->user_data =
context->AllocatePersistentBuffer(context, sizeof(OpData));
OpData* data = static_cast<OpData*>(node->user_data);
// 8bit -> 8bit general quantized path, with general rescalings
data->input_offset = -input->params.zero_point;
data->output_offset = output->params.zero_point;
data->left_shift = kCumSumIntegerShift;
const double twice_max_input_scale =
2 * static_cast<double>(input->params.scale);
const double real_input_multiplier =
static_cast<double>(input->params.scale) / twice_max_input_scale;
const double real_output_multiplier =
twice_max_input_scale /
((1 << data->left_shift) * static_cast<double>(output->params.scale));
QuantizeMultiplierSmallerThanOneExp(
real_input_multiplier, &data->input_multiplier, &data->input_shift);
QuantizeMultiplierSmallerThanOneExp(
real_output_multiplier, &data->output_multiplier, &data->output_shift);
TF_LITE_ENSURE_STATUS(CalculateActivationRangeQuantized(
context, kTfLiteActNone, output, &data->output_activation_min,
&data->output_activation_max));
}
return kTfLiteOk;
}
......@@ -62,7 +107,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
TfLiteEvalTensor* output =
tflite::micro::GetEvalOutput(context, node, kOutputTensor);
auto* params = static_cast<TfLiteCumsumParams*>(node->builtin_data);
auto* cs_params = static_cast<TfLiteCumsumParams*>(node->builtin_data);
auto input_shape = tflite::micro::GetTensorShape(input);
int32_t axis = *tflite::micro::GetTensorData<int32_t>(axis_tensor);
......@@ -76,14 +121,35 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
switch (input->type) {
case kTfLiteFloat32: {
reference_ops::CumSum(tflite::micro::GetTensorData<float>(input),
input_shape, axis, params->exclusive,
params->reverse,
input_shape, axis, cs_params->exclusive,
cs_params->reverse,
tflite::micro::GetTensorData<float>(output));
return kTfLiteOk;
} break;
case kTfLiteInt8: {
auto* data = static_cast<OpData*>(node->user_data);
ArithmeticParams params;
params.left_shift = data->left_shift;
params.input1_offset = data->input_offset;
params.input1_multiplier = data->input_multiplier;
params.input1_shift = data->input_shift;
params.output_offset = data->output_offset;
params.output_multiplier = data->output_multiplier;
params.output_shift = data->output_shift;
SetActivationParams(data->output_activation_min,
data->output_activation_max, &params);
reference_ops::CumSum(params, tflite::micro::GetTensorData<int8_t>(input),
input_shape, axis, cs_params->exclusive,
cs_params->reverse,
tflite::micro::GetTensorData<int8_t>(output));
return kTfLiteOk;
} break;
default: {
TF_LITE_KERNEL_LOG(
context, "Unsupported input type, CUMSUM only supports FLOAT32.");
TF_LITE_KERNEL_LOG(context,
"CUMSUM only supports FLOAT32 and INT8, got %s.",
TfLiteTypeGetName(output->type));
return kTfLiteError;
}
}
......
......@@ -77,6 +77,59 @@ void TestCumSum(const CumSumTestParams& test_params, const int* input_dims_data,
}
}
// min/max are used to compute scale, zero-point, compare tolerance
template <typename T, int kOutputSize>
struct TestQuantParams {
float data_min; // input and output data minimum value
float data_max; // input and output data maximum value
T input_data[kOutputSize]; // quantized input storage
T output_data[kOutputSize]; // quantized output storage
};
// for quantized int, the error shouldn't exceed step
template <typename T>
float GetTolerance(float min, float max) {
float kQuantizedStep =
2.0f * (max - min) /
(std::numeric_limits<T>::max() - std::numeric_limits<T>::min());
return kQuantizedStep;
}
template <typename T, int kOutputSize>
void TestCumSumQuantized(const CumSumTestParams& test_params,
TestQuantParams<T, kOutputSize>* params,
const int* input_dims_data, const float* input_data,
const int* expected_dims, const float* expected_data,
float* output_data) {
TfLiteIntArray* input_dims = IntArrayFromInts(input_dims_data);
TfLiteIntArray* output_dims = IntArrayFromInts(expected_dims);
constexpr int axis_dims_data[] = {1, 1};
TfLiteIntArray* axis_dims = IntArrayFromInts(axis_dims_data);
const int32_t axis_data[] = {test_params.axis};
const float scale = ScaleFromMinMax<T>(params->data_min, params->data_max);
const int zero_point =
ZeroPointFromMinMax<T>(params->data_min, params->data_max);
TfLiteTensor tensors[] = {
CreateQuantizedTensor(input_data, params->input_data, input_dims, scale,
zero_point),
CreateTensor(axis_data, axis_dims),
CreateQuantizedTensor(params->output_data, output_dims, scale,
zero_point),
};
constexpr int tensors_count = std::extent<decltype(tensors)>::value;
ExecuteCumSumTest(test_params, tensors, tensors_count);
Dequantize(params->output_data, kOutputSize, scale, zero_point, output_data);
const float kTolerance = GetTolerance<T>(params->data_min, params->data_max);
for (int i = 0; i < kOutputSize; i++) {
TF_LITE_MICRO_EXPECT_NEAR(expected_data[i], output_data[i], kTolerance);
}
}
} // namespace
} // namespace testing
} // namespace tflite
......@@ -177,4 +230,121 @@ TF_LITE_MICRO_TEST(CumSumOpTestSimpleReverseExclusiveTest) {
output_data);
}
TF_LITE_MICRO_TEST(CumSumOpTestSimpleTestInt8) {
constexpr int kDims[] = {2, 2, 4};
constexpr float kInput[] = {1, 2, 3, 4, 5, 6, 7, 8};
constexpr float kExpect[] = {1, 3, 6, 10, 5, 11, 18, 26};
constexpr int kOutputCount = std::extent<decltype(kExpect)>::value;
float output_data[kOutputCount];
tflite::testing::CumSumTestParams test_params;
test_params.axis = 1;
tflite::testing::TestQuantParams<int8_t, kOutputCount> params = {};
params.data_min = -26.0f;
params.data_max = 26.0f;
tflite::testing::TestCumSumQuantized<int8_t, kOutputCount>(
test_params, &params, kDims, kInput, kDims, kExpect, output_data);
}
TF_LITE_MICRO_TEST(CumSumOpTestSimpleAxis0TestInt8) {
constexpr int kDims[] = {2, 2, 4};
constexpr float kInput[] = {1, 2, 3, 4, 5, 6, 7, 8};
constexpr float kExpect[] = {1, 2, 3, 4, 6, 8, 10, 12};
constexpr int kOutputCount = std::extent<decltype(kExpect)>::value;
float output_data[kOutputCount];
tflite::testing::CumSumTestParams test_params;
test_params.axis = 0;
tflite::testing::TestQuantParams<int8_t, kOutputCount> params = {};
params.data_min = -12.0f;
params.data_max = 12.0f;
tflite::testing::TestCumSumQuantized<int8_t, kOutputCount>(
test_params, &params, kDims, kInput, kDims, kExpect, output_data);
}
TF_LITE_MICRO_TEST(CumSumOpTestSimple1DTestInt8) {
constexpr int kDims[] = {1, 8};
constexpr float kInput[] = {1, 2, 3, 4, 5, 6, 7, 8};
constexpr float kExpect[] = {1, 3, 6, 10, 15, 21, 28, 36};
constexpr int kOutputCount = std::extent<decltype(kExpect)>::value;
float output_data[kOutputCount];
tflite::testing::CumSumTestParams test_params;
test_params.axis = 0;
tflite::testing::TestQuantParams<int8_t, kOutputCount> params = {};
params.data_min = -36.0f;
params.data_max = 36.0f;
tflite::testing::TestCumSumQuantized<int8_t, kOutputCount>(
test_params, &params, kDims, kInput, kDims, kExpect, output_data);
}
TF_LITE_MICRO_TEST(CumSumOpTestSimpleReverseTestInt8) {
constexpr int kDims[] = {2, 2, 4};
constexpr float kInput[] = {1, 2, 3, 4, 5, 6, 7, 8};
constexpr float kExpect[] = {10, 9, 7, 4, 26, 21, 15, 8};
constexpr int kOutputCount = std::extent<decltype(kExpect)>::value;
float output_data[kOutputCount];
tflite::testing::CumSumTestParams test_params;
test_params.axis = 1;
test_params.reverse = true;
tflite::testing::TestQuantParams<int8_t, kOutputCount> params = {};
params.data_min = -26.0f;
params.data_max = 26.0f;
tflite::testing::TestCumSumQuantized<int8_t, kOutputCount>(
test_params, &params, kDims, kInput, kDims, kExpect, output_data);
}
TF_LITE_MICRO_TEST(CumSumOpTestSimpleExclusiveTestInt8) {
constexpr int kDims[] = {2, 2, 4};
constexpr float kInput[] = {1, 2, 3, 4, 5, 6, 7, 8};
constexpr float kExpect[] = {0, 1, 3, 6, 0, 5, 11, 18};
constexpr int kOutputCount = std::extent<decltype(kExpect)>::value;
float output_data[kOutputCount];
tflite::testing::CumSumTestParams test_params;
test_params.axis = 1;
test_params.exclusive = true;
tflite::testing::TestQuantParams<int8_t, kOutputCount> params = {};
params.data_min = -18.0f;
params.data_max = 18.0f;
tflite::testing::TestCumSumQuantized<int8_t, kOutputCount>(
test_params, &params, kDims, kInput, kDims, kExpect, output_data);
}
TF_LITE_MICRO_TEST(CumSumOpTestSimpleReverseExclusiveTestInt8) {
constexpr int kDims[] = {2, 2, 4};
constexpr float kInput[] = {1, 2, 3, 4, 5, 6, 7, 8};
constexpr float kExpect[] = {9, 7, 4, 0, 21, 15, 8, 0};
constexpr int kOutputCount = std::extent<decltype(kExpect)>::value;
float output_data[kOutputCount];
tflite::testing::CumSumTestParams test_params;
test_params.axis = -1;
test_params.exclusive = true;
test_params.reverse = true;
tflite::testing::TestQuantParams<int8_t, kOutputCount> params = {};
params.data_min = -21.0f;
params.data_max = 21.0f;
tflite::testing::TestCumSumQuantized<int8_t, kOutputCount>(
test_params, &params, kDims, kInput, kDims, kExpect, output_data);
}
TF_LITE_MICRO_TESTS_END
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
......@@ -12,33 +12,28 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/lite/kernels/internal/reference/depth_to_space.h"
#include <stdint.h>
#include "tensorflow/lite/c/builtin_op_data.h"
#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/kernels/internal/optimized/optimized_ops.h"
#include "tensorflow/lite/kernels/internal/reference/reference_ops.h"
#include "tensorflow/lite/kernels/internal/tensor.h"
#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
#include "tensorflow/lite/kernels/internal/types.h"
#include "tensorflow/lite/kernels/kernel_util.h"
#include "tensorflow/lite/micro/kernels/kernel_util.h"
namespace tflite {
namespace ops {
namespace builtin {
namespace depth_to_space {
// This file has two implementation of DepthToSpace. Note that DepthToSpace only
// works on 4D tensors.
enum KernelType {
kReference,
kGenericOptimized,
};
namespace {
constexpr int kInputTensor = 0;
constexpr int kOutputTensor = 0;
TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
// input/output tensor shape rank associations
constexpr int kBatchRank = 0;
constexpr int kHeightRank = 1;
constexpr int kWidthRank = 2;
constexpr int kDepthRank = 3;
TfLiteStatus CalculateOpData(TfLiteContext* context, TfLiteNode* node) {
auto* params =
reinterpret_cast<TfLiteDepthToSpaceParams*>(node->builtin_data);
......@@ -55,15 +50,13 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
auto data_type = output->type;
TF_LITE_ENSURE(context,
data_type == kTfLiteFloat32 || data_type == kTfLiteUInt8 ||
data_type == kTfLiteInt8 || data_type == kTfLiteInt32 ||
data_type == kTfLiteInt64);
data_type == kTfLiteFloat32 || data_type == kTfLiteInt8);
TF_LITE_ENSURE_TYPES_EQ(context, input->type, output->type);
const int block_size = params->block_size;
const int input_height = input->dims->data[1];
const int input_width = input->dims->data[2];
const int input_channels = input->dims->data[3];
const int input_height = input->dims->data[kHeightRank];
const int input_width = input->dims->data[kWidthRank];
const int input_channels = input->dims->data[kDepthRank];
int output_height = input_height * block_size;
int output_width = input_width * block_size;
int output_channels = input_channels / block_size / block_size;
......@@ -73,98 +66,77 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_ENSURE_EQ(context, input_channels,
output_channels * block_size * block_size);
TfLiteIntArray* output_size = TfLiteIntArrayCreate(4);
output_size->data[0] = input->dims->data[0];
output_size->data[1] = output_height;
output_size->data[2] = output_width;
output_size->data[3] = output_channels;
// We must update the output tensor dimensions.
// The dims storage is expected to be the same area in memory
// for both TfLiteTensor and TfLiteEvalTensor. This is important
// because TfLiteTensor in the MicroInterpreter is a temporary
// allocation. For the KernelRunner interpreter, TfLiteEvalTensor
// is a temporary allocation. We must therefore relocate the dims
// from the FlatBuffer to the persistant storage arena.
TfLiteEvalTensor* output_eval =
tflite::micro::GetEvalOutput(context, node, kOutputTensor);
TF_LITE_ENSURE_OK(context, tflite::micro::CreateWritableTensorDimsWithCopy(
context, output, output_eval));
output->dims->data[kBatchRank] = input->dims->data[kBatchRank];
output->dims->data[kHeightRank] = output_height;
output->dims->data[kWidthRank] = output_width;
output->dims->data[kDepthRank] = output_channels;
return context->ResizeTensor(context, output, output_size);
return kTfLiteOk;
}
TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
return CalculateOpData(context, node);
}
template <KernelType kernel_type>
TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
auto* params =
reinterpret_cast<TfLiteDepthToSpaceParams*>(node->builtin_data);
const TfLiteTensor* input;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputTensor, &input));
TfLiteTensor* output;
TF_LITE_ENSURE_OK(context,
GetOutputSafe(context, node, kOutputTensor, &output));
const TfLiteEvalTensor* input =
tflite::micro::GetEvalInput(context, node, kInputTensor);
TfLiteEvalTensor* output =
tflite::micro::GetEvalOutput(context, node, kOutputTensor);
tflite::DepthToSpaceParams op_params;
op_params.block_size = static_cast<int32_t>(params->block_size);
#define TF_LITE_DEPTH_TO_SPACE(type, scalar) \
tflite::DepthToSpaceParams op_params; \
op_params.block_size = params->block_size; \
type::DepthToSpace(op_params, GetTensorShape(input), \
GetTensorData<scalar>(input), GetTensorShape(output), \
GetTensorData<scalar>(output))
switch (input->type) { // Already know in/out types are same.
case kTfLiteFloat32:
if (kernel_type == kReference) {
TF_LITE_DEPTH_TO_SPACE(reference_ops, float);
} else {
TF_LITE_DEPTH_TO_SPACE(optimized_ops, float);
}
break;
case kTfLiteUInt8:
if (kernel_type == kReference) {
TF_LITE_DEPTH_TO_SPACE(reference_ops, uint8_t);
} else {
TF_LITE_DEPTH_TO_SPACE(optimized_ops, uint8_t);
}
reference_ops::DepthToSpace(op_params,
tflite::micro::GetTensorShape(input),
tflite::micro::GetTensorData<float>(input),
tflite::micro::GetTensorShape(output),
tflite::micro::GetTensorData<float>(output));
break;
case kTfLiteInt8:
if (kernel_type == kReference) {
TF_LITE_DEPTH_TO_SPACE(reference_ops, int8_t);
} else {
TF_LITE_DEPTH_TO_SPACE(optimized_ops, int8_t);
}
break;
case kTfLiteInt32:
if (kernel_type == kReference) {
TF_LITE_DEPTH_TO_SPACE(reference_ops, int32_t);
} else {
TF_LITE_DEPTH_TO_SPACE(optimized_ops, int32_t);
}
break;
case kTfLiteInt64:
if (kernel_type == kReference) {
TF_LITE_DEPTH_TO_SPACE(reference_ops, int64_t);
} else {
TF_LITE_DEPTH_TO_SPACE(optimized_ops, int64_t);
}
reference_ops::DepthToSpace(op_params,
tflite::micro::GetTensorShape(input),
tflite::micro::GetTensorData<int8_t>(input),
tflite::micro::GetTensorShape(output),
tflite::micro::GetTensorData<int8_t>(output));
break;
default:
TF_LITE_KERNEL_LOG(context, "Type '%s' not currently supported.",
TfLiteTypeGetName(input->type));
TF_LITE_KERNEL_LOG(
context, "DEPTH_TO_SPACE only supports FLOAT32 and INT8, got %s.",
TfLiteTypeGetName(output->type));
return kTfLiteError;
}
#undef TF_LITE_DEPTH_TO_SPACE
return kTfLiteOk;
}
} // namespace depth_to_space
TfLiteRegistration* Register_DEPTH_TO_SPACE_REF() {
static TfLiteRegistration r = {
nullptr, nullptr, depth_to_space::Prepare,
depth_to_space::Eval<depth_to_space::kReference>};
return &r;
}
TfLiteRegistration* Register_DEPTH_TO_SPACE_GENERIC_OPT() {
static TfLiteRegistration r = {
nullptr, nullptr, depth_to_space::Prepare,
depth_to_space::Eval<depth_to_space::kGenericOptimized>};
return &r;
}
TfLiteRegistration* Register_DEPTH_TO_SPACE() {
return Register_DEPTH_TO_SPACE_GENERIC_OPT();
} // namespace
TfLiteRegistration Register_DEPTH_TO_SPACE() {
return {/*init=*/nullptr,
/*free=*/nullptr,
/*prepare=*/Prepare,
/*invoke=*/Eval,
/*profiling_string=*/nullptr,
/*builtin_code=*/0,
/*custom_name=*/nullptr,
/*version=*/0};
}
} // namespace builtin
} // namespace ops
} // namespace tflite
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
......@@ -12,97 +12,298 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <stdint.h>
#include <type_traits>
#include <initializer_list>
#include <vector>
#include "flatbuffers/flatbuffers.h" // from @flatbuffers
#include "tensorflow/lite/kernels/test_util.h"
#include "tensorflow/lite/schema/schema_generated.h"
#include "tensorflow/lite/c/builtin_op_data.h"
#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/micro/kernels/kernel_runner.h"
#include "tensorflow/lite/micro/test_helpers.h"
#include "tensorflow/lite/micro/testing/micro_test.h"
namespace tflite {
namespace testing {
namespace {
using ::testing::ElementsAre;
using ::testing::ElementsAreArray;
class DepthToSpaceOpModel : public SingleOpModel {
public:
DepthToSpaceOpModel(const TensorData& tensor_data, int block_size) {
input_ = AddInput(tensor_data);
output_ = AddOutput(tensor_data);
SetBuiltinOp(BuiltinOperator_DEPTH_TO_SPACE,
BuiltinOptions_DepthToSpaceOptions,
CreateDepthToSpaceOptions(builder_, block_size).Union());
BuildInterpreter({GetShape(input_)});
}
constexpr int kOutputDimsCount = 4;
struct DepthToSpaceTestParams {
int block_size;
// output_dims_data is a TfLiteIntArray
int output_dims_data[kOutputDimsCount + 1] = {kOutputDimsCount, 0, 0, 0, 0};
};
void ExecuteDepthToSpaceTest(const DepthToSpaceTestParams& params,
TfLiteTensor* tensors, int tensors_count) {
constexpr int kInputArrayData[] = {1, 0};
TfLiteIntArray* inputs_array = IntArrayFromInts(kInputArrayData);
constexpr int kOutputArrayData[] = {1, 1};
TfLiteIntArray* outputs_array = IntArrayFromInts(kOutputArrayData);
TfLiteDepthToSpaceParams op_params = {};
op_params.block_size = params.block_size;
template <typename T>
void SetInput(std::initializer_list<T> data) {
PopulateTensor<T>(input_, data);
const TfLiteRegistration registration = tflite::Register_DEPTH_TO_SPACE();
micro::KernelRunner runner(registration, tensors, tensors_count, inputs_array,
outputs_array, static_cast<void*>(&op_params));
TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, runner.InitAndPrepare());
TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, runner.Invoke());
}
template <typename T>
void TestDepthToSpace(const DepthToSpaceTestParams& params,
const int* input_dims_data, const T* input_data,
const int* expected_dims_data, const T* expected_data,
T* output_data) {
TfLiteIntArray* input_dims = IntArrayFromInts(input_dims_data);
TfLiteIntArray* expected_dims = IntArrayFromInts(expected_dims_data);
TfLiteIntArray* output_dims = IntArrayFromInts(params.output_dims_data);
const int expected_count = ElementCount(*expected_dims);
TfLiteTensor tensors[] = {
CreateTensor(input_data, input_dims),
CreateTensor(output_data, output_dims),
};
constexpr int tensors_count = std::extent<decltype(tensors)>::value;
ExecuteDepthToSpaceTest(params, tensors, tensors_count);
constexpr float kTolerance = 1e-5;
for (int i = 0; i < expected_count; i++) {
TF_LITE_MICRO_EXPECT_NEAR(expected_data[i], output_data[i], kTolerance);
}
template <typename T>
std::vector<T> GetOutput() {
return ExtractVector<T>(output_);
for (int i = 0; i < expected_dims->size; i++) {
// output dims will have been relocated during prepare phase,
// so use the tensor dims pointer.
TF_LITE_MICRO_EXPECT_EQ(expected_dims->data[i], tensors[1].dims->data[i]);
}
std::vector<int> GetOutputShape() { return GetTensorShape(output_); }
}
private:
int input_;
int output_;
// min/max are used to compute scale, zero-point, compare tolerance
template <typename T, int kOutputSize>
struct TestQuantParams {
float data_min; // input and output data minimum value
float data_max; // input and output data maximum value
T input_data[kOutputSize]; // quantized input storage
T output_data[kOutputSize]; // quantized output storage
};
#ifdef GTEST_HAS_DEATH_TEST
TEST(DepthToSpaceOpModel, BadBlockSize) {
EXPECT_DEATH(DepthToSpaceOpModel({TensorType_FLOAT32, {1, 1, 1, 4}}, 4),
"Cannot allocate tensors");
// for quantized, the error shouldn't exceed step
template <typename T>
float GetTolerance(float min, float max) {
float kQuantizedStep =
2.0f * (max - min) /
(std::numeric_limits<T>::max() - std::numeric_limits<T>::min());
return kQuantizedStep;
}
#endif
TEST(DepthToSpaceOpModel, Float32) {
DepthToSpaceOpModel m({TensorType_FLOAT32, {1, 1, 1, 4}}, 2);
m.SetInput<float>({1.4, 2.3, 3.2, 4.1});
m.Invoke();
EXPECT_THAT(m.GetOutput<float>(), ElementsAreArray({1.4, 2.3, 3.2, 4.1}));
EXPECT_THAT(m.GetOutputShape(), ElementsAre(1, 2, 2, 1));
template <typename T, int kOutputSize>
void TestDepthToSpaceQuantized(const DepthToSpaceTestParams& params,
TestQuantParams<T, kOutputSize>* quant_params,
const int* input_dims_data,
const float* input_data,
const int* expected_dims_data,
const float* expected_data, float* output_data) {
TfLiteIntArray* input_dims = IntArrayFromInts(input_dims_data);
TfLiteIntArray* expected_dims = IntArrayFromInts(expected_dims_data);
TfLiteIntArray* output_dims = IntArrayFromInts(params.output_dims_data);
const float scale =
ScaleFromMinMax<T>(quant_params->data_min, quant_params->data_max);
const int zero_point =
ZeroPointFromMinMax<T>(quant_params->data_min, quant_params->data_max);
TfLiteTensor tensors[] = {
CreateQuantizedTensor(input_data, quant_params->input_data, input_dims,
scale, zero_point),
CreateQuantizedTensor(quant_params->output_data, output_dims, scale,
zero_point),
};
constexpr int kTensorsCount = std::extent<decltype(tensors)>::value;
ExecuteDepthToSpaceTest(params, tensors, kTensorsCount);
Dequantize(quant_params->output_data, kOutputSize, scale, zero_point,
output_data);
const float kTolerance =
GetTolerance<T>(quant_params->data_min, quant_params->data_max);
for (int i = 0; i < kOutputSize; i++) {
TF_LITE_MICRO_EXPECT_NEAR(expected_data[i], output_data[i], kTolerance);
}
for (int i = 0; i < expected_dims->size; i++) {
// output dims will have been relocated during prepare phase,
// so use the tensor dims pointer.
TF_LITE_MICRO_EXPECT_EQ(expected_dims->data[i], tensors[1].dims->data[i]);
}
}
TEST(DepthToSpaceOpModel, Uint8) {
DepthToSpaceOpModel m({TensorType_UINT8, {1, 1, 2, 4}}, 2);
m.SetInput<uint8_t>({1, 2, 3, 4, 5, 6, 7, 8});
m.Invoke();
EXPECT_THAT(m.GetOutput<uint8_t>(),
ElementsAreArray({1, 2, 5, 6, 3, 4, 7, 8}));
EXPECT_THAT(m.GetOutputShape(), ElementsAre(1, 2, 4, 1));
} // namespace
} // namespace testing
} // namespace tflite
TF_LITE_MICRO_TESTS_BEGIN
TF_LITE_MICRO_TEST(DepthToSpaceOpModelFloat32_1114_2) {
constexpr int kInputDims[] = {4, 1, 1, 1, 4};
constexpr float kInput[] = {1.4, 2.3, 3.2, 4.1};
constexpr int kExpectDims[] = {4, 1, 2, 2, 1};
constexpr float kExpect[] = {1.4, 2.3, 3.2, 4.1};
constexpr int kOutputCount = std::extent<decltype(kExpect)>::value;
float output_data[kOutputCount];
tflite::testing::DepthToSpaceTestParams params;
params.block_size = 2;
tflite::testing::TestDepthToSpace(params, kInputDims, kInput, kExpectDims,
kExpect, output_data);
}
TEST(DepthToSpaceOpModel, int8) {
DepthToSpaceOpModel m({TensorType_INT8, {1, 2, 1, 4}}, 2);
m.SetInput<int8_t>({1, 2, 3, 4, 5, 6, 7, 8});
m.Invoke();
EXPECT_THAT(m.GetOutput<int8_t>(),
ElementsAreArray({1, 2, 3, 4, 5, 6, 7, 8}));
EXPECT_THAT(m.GetOutputShape(), ElementsAre(1, 4, 2, 1));
TF_LITE_MICRO_TEST(DepthToSpaceOpModelFloat32_1124_2) {
constexpr int kInputDims[] = {4, 1, 1, 2, 4};
constexpr float kInput[] = {1, 2, 3, 4, 5, 6, 7, 8};
constexpr int kExpectDims[] = {4, 1, 2, 4, 1};
constexpr float kExpect[] = {1, 2, 5, 6, 3, 4, 7, 8};
constexpr int kOutputCount = std::extent<decltype(kExpect)>::value;
float output_data[kOutputCount];
tflite::testing::DepthToSpaceTestParams params;
params.block_size = 2;
tflite::testing::TestDepthToSpace(params, kInputDims, kInput, kExpectDims,
kExpect, output_data);
}
TEST(DepthToSpaceOpModel, Int32) {
DepthToSpaceOpModel m({TensorType_INT32, {1, 2, 2, 4}}, 2);
m.SetInput<int32_t>({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16});
m.Invoke();
EXPECT_THAT(m.GetOutput<int32_t>(),
ElementsAreArray(
{1, 2, 5, 6, 3, 4, 7, 8, 9, 10, 13, 14, 11, 12, 15, 16}));
EXPECT_THAT(m.GetOutputShape(), ElementsAre(1, 4, 4, 1));
TF_LITE_MICRO_TEST(DepthToSpaceOpModelFloat32_1214_2) {
constexpr int kInputDims[] = {4, 1, 2, 1, 4};
constexpr float kInput[] = {1, 2, 3, 4, 5, 6, 7, 8};
constexpr int kExpectDims[] = {4, 1, 4, 2, 1};
constexpr float kExpect[] = {1, 2, 3, 4, 5, 6, 7, 8};
constexpr int kOutputCount = std::extent<decltype(kExpect)>::value;
float output_data[kOutputCount];
tflite::testing::DepthToSpaceTestParams params;
params.block_size = 2;
tflite::testing::TestDepthToSpace(params, kInputDims, kInput, kExpectDims,
kExpect, output_data);
}
TEST(DepthToSpaceOpModel, Int64) {
DepthToSpaceOpModel m({TensorType_INT64, {1, 1, 1, 1}}, 1);
m.SetInput<int64_t>({4});
m.Invoke();
EXPECT_THAT(m.GetOutput<int64_t>(), ElementsAreArray({4}));
EXPECT_THAT(m.GetOutputShape(), ElementsAre(1, 1, 1, 1));
TF_LITE_MICRO_TEST(DepthToSpaceOpModelFloat32_1224_2) {
constexpr int kInputDims[] = {4, 1, 2, 2, 4};
constexpr float kInput[] = {1, 2, 3, 4, 5, 6, 7, 8,
9, 10, 11, 12, 13, 14, 15, 16};
constexpr int kExpectDims[] = {4, 1, 4, 4, 1};
constexpr float kExpect[] = {1, 2, 5, 6, 3, 4, 7, 8,
9, 10, 13, 14, 11, 12, 15, 16};
constexpr int kOutputCount = std::extent<decltype(kExpect)>::value;
float output_data[kOutputCount];
tflite::testing::DepthToSpaceTestParams params;
params.block_size = 2;
tflite::testing::TestDepthToSpace(params, kInputDims, kInput, kExpectDims,
kExpect, output_data);
}
} // namespace
} // namespace tflite
TF_LITE_MICRO_TEST(DepthToSpaceOpModelFloat32_1111_1) {
constexpr int kInputDims[] = {4, 1, 1, 1, 1};
constexpr float kInput[] = {4};
constexpr int kExpectDims[] = {4, 1, 1, 1, 1};
constexpr float kExpect[] = {4};
constexpr int kOutputCount = std::extent<decltype(kExpect)>::value;
float output_data[kOutputCount];
tflite::testing::DepthToSpaceTestParams params;
params.block_size = 1;
tflite::testing::TestDepthToSpace(params, kInputDims, kInput, kExpectDims,
kExpect, output_data);
}
TF_LITE_MICRO_TEST(DepthToSpaceOpModelInt8_1114_2) {
constexpr int kInputDims[] = {4, 1, 1, 1, 4};
constexpr float kInput[] = {1.4, 2.3, 3.2, 4.1};
constexpr int kExpectDims[] = {4, 1, 2, 2, 1};
constexpr float kExpect[] = {1.4, 2.3, 3.2, 4.1};
constexpr int kOutputCount = std::extent<decltype(kExpect)>::value;
float output_data[kOutputCount];
tflite::testing::DepthToSpaceTestParams params;
params.block_size = 2;
tflite::testing::TestQuantParams<int8_t, kOutputCount> quant_params = {};
quant_params.data_min = 0.0;
quant_params.data_max = 5.0;
tflite::testing::TestDepthToSpaceQuantized<int8_t, kOutputCount>(
params, &quant_params, kInputDims, kInput, kExpectDims, kExpect,
output_data);
}
TF_LITE_MICRO_TEST(DepthToSpaceOpModelInt8_1124_2) {
constexpr int kInputDims[] = {4, 1, 1, 2, 4};
constexpr float kInput[] = {1, 2, 3, 4, 5, 6, 7, 8};
constexpr int kExpectDims[] = {4, 1, 2, 4, 1};
constexpr float kExpect[] = {1, 2, 5, 6, 3, 4, 7, 8};
constexpr int kOutputCount = std::extent<decltype(kExpect)>::value;
float output_data[kOutputCount];
tflite::testing::DepthToSpaceTestParams params;
params.block_size = 2;
tflite::testing::TestQuantParams<int8_t, kOutputCount> quant_params = {};
quant_params.data_min = 0.0;
quant_params.data_max = 9.0;
tflite::testing::TestDepthToSpaceQuantized<int8_t, kOutputCount>(
params, &quant_params, kInputDims, kInput, kExpectDims, kExpect,
output_data);
}
TF_LITE_MICRO_TEST(DepthToSpaceOpModelInt8_1214_2) {
constexpr int kInputDims[] = {4, 1, 2, 1, 4};
constexpr float kInput[] = {1, 2, 3, 4, 5, 6, 7, 8};
constexpr int kExpectDims[] = {4, 1, 4, 2, 1};
constexpr float kExpect[] = {1, 2, 3, 4, 5, 6, 7, 8};
constexpr int kOutputCount = std::extent<decltype(kExpect)>::value;
float output_data[kOutputCount];
tflite::testing::DepthToSpaceTestParams params;
params.block_size = 2;
tflite::testing::TestQuantParams<int8_t, kOutputCount> quant_params = {};
quant_params.data_min = 0.0;
quant_params.data_max = 9.0;
tflite::testing::TestDepthToSpaceQuantized<int8_t, kOutputCount>(
params, &quant_params, kInputDims, kInput, kExpectDims, kExpect,
output_data);
}
TF_LITE_MICRO_TEST(DepthToSpaceOpModelInt8_1224_2) {
constexpr int kInputDims[] = {4, 1, 2, 2, 4};
constexpr float kInput[] = {1, 2, 3, 4, 5, 6, 7, 8,
9, 10, 11, 12, 13, 14, 15, 16};
constexpr int kExpectDims[] = {4, 1, 4, 4, 1};
constexpr float kExpect[] = {1, 2, 5, 6, 3, 4, 7, 8,
9, 10, 13, 14, 11, 12, 15, 16};
constexpr int kOutputCount = std::extent<decltype(kExpect)>::value;
float output_data[kOutputCount];
tflite::testing::DepthToSpaceTestParams params;
params.block_size = 2;
tflite::testing::TestQuantParams<int8_t, kOutputCount> quant_params = {};
quant_params.data_min = 0.0;
quant_params.data_max = 17.0;
tflite::testing::TestDepthToSpaceQuantized<int8_t, kOutputCount>(
params, &quant_params, kInputDims, kInput, kExpectDims, kExpect,
output_data);
}
TF_LITE_MICRO_TEST(DepthToSpaceOpModelInt8_1111_1) {
constexpr int kInputDims[] = {4, 1, 1, 1, 1};
constexpr float kInput[] = {4};
constexpr int kExpectDims[] = {4, 1, 1, 1, 1};
constexpr float kExpect[] = {4};
constexpr int kOutputCount = std::extent<decltype(kExpect)>::value;
float output_data[kOutputCount];
tflite::testing::DepthToSpaceTestParams params;
params.block_size = 1;
tflite::testing::TestQuantParams<int8_t, kOutputCount> quant_params = {};
quant_params.data_min = 3.0;
quant_params.data_max = 5.0;
tflite::testing::TestDepthToSpaceQuantized<int8_t, kOutputCount>(
params, &quant_params, kInputDims, kInput, kExpectDims, kExpect,
output_data);
}
TF_LITE_MICRO_TESTS_END
/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/lite/kernels/internal/reference/log_softmax.h"
#include <cstddef>
#include <cstdint>
#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/kernels/internal/quantization_util.h"
#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
#include "tensorflow/lite/kernels/internal/types.h"
#include "tensorflow/lite/kernels/kernel_util.h"
#include "tensorflow/lite/micro/kernels/kernel_util.h"
namespace tflite {
namespace {
// used only with quantized data
struct LogSoftmaxOpData {
int32_t input_multiplier;
int32_t input_left_shift;
int32_t reverse_scaling_divisor;
int32_t reverse_scaling_right_shift;
int diff_min;
size_t outer_size; // number of tensor elements skipping computation axis
size_t depth; // number of tensor elements on computation axis
};
// input/output tensor index
constexpr int kInputTensor = 0;
constexpr int kOutputTensor = 0;
TfLiteStatus CalculateOpData(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_ENSURE_EQ(context, NumInputs(node), 1);
TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
const TfLiteTensor* input;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputTensor, &input));
TfLiteTensor* output;
TF_LITE_ENSURE_OK(context,
GetOutputSafe(context, node, kOutputTensor, &output));
TF_LITE_ENSURE_TYPES_EQ(context, input->type, output->type);
TF_LITE_ENSURE(context, HaveSameShapes(input, output));
if (input->type == kTfLiteInt8) {
node->user_data =
context->AllocatePersistentBuffer(context, sizeof(LogSoftmaxOpData));
auto data = static_cast<LogSoftmaxOpData*>(node->user_data);
// quantization datum
constexpr int32_t kOutputZeroPoint = 127;
constexpr float kOutputScale = 16.0 / 256;
constexpr double kBeta = 1.0;
constexpr int kScaledDiffIntegerBits = 5;
TF_LITE_ENSURE(context, output->params.scale == kOutputScale);
TF_LITE_ENSURE(context, output->params.zero_point == kOutputZeroPoint);
int input_left_shift;
int reverse_scaling_right_shift;
tflite::PreprocessLogSoftmaxScalingExp(
kBeta, static_cast<double>(input->params.scale), kScaledDiffIntegerBits,
&data->input_multiplier, &input_left_shift,
&data->reverse_scaling_divisor, &reverse_scaling_right_shift);
data->input_left_shift = static_cast<int32_t>(input_left_shift);
data->reverse_scaling_right_shift =
static_cast<int32_t>(-reverse_scaling_right_shift);
// diff_min has a negative value, and is used to limit the maximum magnitude
// of the diffs, which are <= 0.
data->diff_min =
-tflite::CalculateInputRadius(kScaledDiffIntegerBits, input_left_shift);
RuntimeShape input_shape = GetTensorShape(input);
const int trailing_dim = input_shape.DimensionsCount() - 1;
data->outer_size =
static_cast<size_t>(FlatSizeSkipDim(input_shape, trailing_dim));
data->depth = static_cast<size_t>(input_shape.Dims(trailing_dim));
}
return kTfLiteOk;
}
TfLiteStatus LogSoftmaxPrepare(TfLiteContext* context, TfLiteNode* node) {
return CalculateOpData(context, node);
}
TfLiteStatus LogSoftmaxEval(TfLiteContext* context, TfLiteNode* node) {
const LogSoftmaxOpData* data =
static_cast<LogSoftmaxOpData*>(node->user_data);
const TfLiteEvalTensor* input =
tflite::micro::GetEvalInput(context, node, kInputTensor);
TfLiteEvalTensor* output =
tflite::micro::GetEvalOutput(context, node, kOutputTensor);
switch (input->type) {
case kTfLiteFloat32: {
SoftmaxParams op_params = {};
reference_ops::LogSoftmax(op_params, tflite::micro::GetTensorShape(input),
tflite::micro::GetTensorData<float>(input),
tflite::micro::GetTensorShape(output),
tflite::micro::GetTensorData<float>(output));
return kTfLiteOk;
}
case kTfLiteInt8: {
SoftmaxParams op_params = {};
op_params.input_multiplier = data->input_multiplier;
op_params.input_left_shift = data->input_left_shift;
op_params.reverse_scaling_divisor = data->reverse_scaling_divisor;
op_params.reverse_scaling_right_shift = data->reverse_scaling_right_shift;
op_params.diff_min = data->diff_min;
reference_ops::LogSoftmax(op_params, data->outer_size, data->depth,
tflite::micro::GetTensorShape(input),
tflite::micro::GetTensorData<int8_t>(input),
tflite::micro::GetTensorShape(output),
tflite::micro::GetTensorData<int8_t>(output));
return kTfLiteOk;
}
default:
TF_LITE_KERNEL_LOG(context,
"LOG_SOFTMAX only supports float32, int8, got %s.",
TfLiteTypeGetName(input->type));
return kTfLiteError;
}
}
} // namespace
TfLiteRegistration Register_LOG_SOFTMAX() {
return {/*init=*/nullptr,
/*free=*/nullptr,
/*prepare=*/LogSoftmaxPrepare,
/*invoke=*/LogSoftmaxEval,
/*profiling_string=*/nullptr,
/*builtin_code=*/0,
/*custom_name=*/nullptr,
/*version=*/0};
}
} // namespace tflite
/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <cstdint>
#include <type_traits>
#include "tensorflow/lite/c/builtin_op_data.h"
#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/micro/kernels/kernel_runner.h"
#include "tensorflow/lite/micro/test_helpers.h"
#include "tensorflow/lite/micro/testing/micro_test.h"
namespace tflite {
namespace testing {
namespace {
void ExecuteLogSoftmaxTest(int tensors_count, TfLiteTensor* tensors) {
constexpr int kInputArrayData[] = {1, 0};
TfLiteIntArray* inputs_array = IntArrayFromInts(kInputArrayData);
constexpr int kOutputArrayData[] = {1, 1};
TfLiteIntArray* outputs_array = IntArrayFromInts(kOutputArrayData);
const TfLiteRegistration registration = tflite::Register_LOG_SOFTMAX();
micro::KernelRunner runner(registration, tensors, tensors_count, inputs_array,
outputs_array, nullptr);
TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, runner.InitAndPrepare());
TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, runner.Invoke());
}
template <typename T>
void TestLogSoftmax(const float tolerance, const int* input_dims_data,
const T* input_data, const int* expected_dims,
const T* expected_data, T* output_data) {
TfLiteIntArray* input_dims = IntArrayFromInts(input_dims_data);
TfLiteIntArray* output_dims = IntArrayFromInts(expected_dims);
const int output_count = ElementCount(*output_dims);
TfLiteTensor tensors[] = {
CreateTensor(input_data, input_dims),
CreateTensor(output_data, output_dims),
};
constexpr int kTensorsCount = std::extent<decltype(tensors)>::value;
ExecuteLogSoftmaxTest(kTensorsCount, tensors);
for (int i = 0; i < output_count; i++) {
TF_LITE_MICRO_EXPECT_NEAR(expected_data[i], output_data[i], tolerance);
}
}
// min/max are used to compute scale, zero-point
template <typename T>
struct TestLogSoftmaxParams {
// quantization parameters
float data_min; // input and output data minimum value
float data_max; // input and output data maximum value
T* input_data; // quantized input storage
T* output_data; // quantized output storage
float tolerance; // maximum compare difference
};
template <typename T>
void TestLogSoftmaxQuantized(const TestLogSoftmaxParams<T>& params,
const int* input_dims_data,
const float* input_data, const int* expected_dims,
const float* expected_data,
const T* expected_data_quantized,
float* output_data) {
TfLiteIntArray* input_dims = IntArrayFromInts(input_dims_data);
TfLiteIntArray* output_dims = IntArrayFromInts(expected_dims);
const int output_count = ElementCount(*output_dims);
constexpr float kOutputScale = 16.0 / 256;
constexpr int kOutputZeroPoint = 127;
const float scale = ScaleFromMinMax<T>(params.data_min, params.data_max);
const int zero_point =
ZeroPointFromMinMax<T>(params.data_min, params.data_max);
TfLiteTensor tensors[] = {
CreateQuantizedTensor(input_data, params.input_data, input_dims, scale,
zero_point),
CreateQuantizedTensor(params.output_data, output_dims, kOutputScale,
kOutputZeroPoint),
};
constexpr int kTensorsCount = std::extent<decltype(tensors)>::value;
ExecuteLogSoftmaxTest(kTensorsCount, tensors);
for (int i = 0; i < output_count; i++) {
TF_LITE_MICRO_EXPECT_EQ(expected_data_quantized[i], params.output_data[i]);
}
Dequantize(params.output_data, output_count, kOutputScale, kOutputZeroPoint,
output_data);
for (int i = 0; i < output_count; i++) {
TF_LITE_MICRO_EXPECT_NEAR(expected_data[i], output_data[i],
params.tolerance);
}
}
} // namespace
} // namespace testing
} // namespace tflite
TF_LITE_MICRO_TESTS_BEGIN
// This contains the same test values as the Softmax test, but reference answer
// generated via the following snippet of python:
// logits1 = tf.constant([[0, -6, 2, 4],[3, -2, 10, 1]], dtype=tf.float32)
// logits2 = tf.constant([[0,-6],[2,4],[3,-2],[10,1]], dtype=tf.float32)
// lsm1 = tf.nn.log_softmax(logits1)
// lsm2 = tf.nn.log_softmax(logits2)
// with tf.Session() as sess:
// print('lsm1', sess.run(lsm1))
// print('lsm2', sess.run(lsm2))
TF_LITE_MICRO_TEST(FloatActivationsOpTestLogSoftmax) {
constexpr int kDims1[] = {2, 2, 4};
constexpr float kInput[] = {
0, -6, 2, 4, 3, -2, 10, 1,
};
constexpr float kExpect1[] = {
-4.14297, -10.14297, -2.14297, -.142971, //
-7.00104, -12.00104, -.00104087, -9.00104, //
};
constexpr int kOutputCount = std::extent<decltype(kExpect1)>::value;
float output_data[kOutputCount];
constexpr float kTolerance = 1e-5;
tflite::testing::TestLogSoftmax(kTolerance, kDims1, kInput, kDims1, kExpect1,
output_data);
// Same input, but a different shape.
constexpr int kDims2[] = {2, 4, 2};
constexpr float kExpect2[] = {
-.00247565, -6.00247, -2.12692, -.126928,
-.00671534, -5.00671, -.000123374, -9.00012,
};
tflite::testing::TestLogSoftmax(kTolerance, kDims2, kInput, kDims2, kExpect2,
output_data);
}
TF_LITE_MICRO_TEST(LogSoftmaxOpTestSimpleTest) {
constexpr int kDims[] = {2, 2, 5};
constexpr float kInput[] = {
1.0, 2.0, 3.0, 4.0, 5.0, //
-1.0, -2.0, -3.0, -4.0, -5.0, //
};
constexpr float kExpect[] = {
-4.45191431, -3.45191431, -2.45191431, -1.45191443, -0.4519144, //
-0.4519144, -1.45191443, -2.45191431, -3.45191431, -4.45191431 //
};
constexpr int kOutputCount = std::extent<decltype(kExpect)>::value;
float output_data[kOutputCount];
constexpr float kTolerance = 1e-6;
tflite::testing::TestLogSoftmax(kTolerance, kDims, kInput, kDims, kExpect,
output_data);
}
TF_LITE_MICRO_TEST(QuantizedActivationsOpTestLogSoftmaxInt8) {
constexpr int kDims[] = {2, 2, 4};
constexpr float kInput[] = {
0, -6, 2, 4, 3, -2, 10, 1,
};
constexpr float kExpect[] = {
-4.14297, -10.14297, -2.14297, -.142971,
-7.00104, -12.00104, -.00104087, -9.00104,
};
constexpr int8_t kExpectQuantized[] = {
61, -36, 93, 125, 15, -65, 127, -16,
};
constexpr int kOutputCount = std::extent<decltype(kExpect)>::value;
float output_data[kOutputCount];
// setup quantization storage and parameters
int8_t q_output_data[kOutputCount];
int8_t q_input_data[kOutputCount];
constexpr float kMin = -10;
constexpr float kMax = 10;
constexpr float kLogSoftmaxQuantizedTolerance = 0.06355;
tflite::testing::TestLogSoftmaxParams<int8_t> params = {};
params.data_min = kMin;
params.data_max = kMax;
params.input_data = q_input_data;
params.output_data = q_output_data;
params.tolerance = kLogSoftmaxQuantizedTolerance;
tflite::testing::TestLogSoftmaxQuantized(
params, kDims, kInput, kDims, kExpect, kExpectQuantized, output_data);
}
TF_LITE_MICRO_TEST(ExtraTestLogSoftmaxInt8) {
constexpr int kDims[] = {2, 3, 1};
constexpr float kInput[] = {0, -1, 1};
constexpr float kExpect[] = {0, 0, 0};
constexpr int8_t kExpectQuantized[] = {127, 127, 127};
constexpr int kOutputCount = std::extent<decltype(kExpect)>::value;
float output_data[kOutputCount];
// setup quantization storage and parameters
int8_t q_output_data[kOutputCount];
int8_t q_input_data[kOutputCount];
constexpr float kMin = -1;
constexpr float kMax = 1;
constexpr float kLogSoftmaxQuantizedTolerance = 0.06355;
tflite::testing::TestLogSoftmaxParams<int8_t> params = {};
params.data_min = kMin;
params.data_max = kMax;
params.input_data = q_input_data;
params.output_data = q_output_data;
params.tolerance = kLogSoftmaxQuantizedTolerance;
tflite::testing::TestLogSoftmaxQuantized(
params, kDims, kInput, kDims, kExpect, kExpectQuantized, output_data);
}
TF_LITE_MICRO_TESTS_END
......@@ -36,6 +36,7 @@ TfLiteRegistration Register_BATCH_TO_SPACE_ND();
TfLiteRegistration Register_CAST();
TfLiteRegistration Register_CONV_2D();
TfLiteRegistration Register_CUMSUM();
TfLiteRegistration Register_DEPTH_TO_SPACE();
TfLiteRegistration Register_DEPTHWISE_CONV_2D();
TfLiteRegistration Register_DIV();
TfLiteRegistration Register_ELU();
......@@ -46,6 +47,7 @@ TfLiteRegistration Register_FLOOR_DIV();
TfLiteRegistration Register_FLOOR_MOD();
TfLiteRegistration Register_L2_POOL_2D();
TfLiteRegistration Register_LEAKY_RELU();
TfLiteRegistration Register_LOG_SOFTMAX();
TfLiteRegistration Register_QUANTIZE();
TfLiteRegistration Register_SHAPE();
TfLiteRegistration Register_SOFTMAX();
......
......@@ -49,7 +49,6 @@ cc_library(
copts = micro_copts(),
deps = [
":memory_planner",
"//tensorflow/lite/c:common",
"//tensorflow/lite/micro:micro_compatibility",
],
)
......
......@@ -297,8 +297,6 @@ size_t GreedyMemoryPlanner::GetMaximumMemorySize() {
while (entry) {
BufferRequirements* requirements =
&requirements_[entry->requirements_index];
// TODO(b/148246793): Update all size and offset variables types from
// int to size_t
const size_t current_size = entry->offset + requirements->size;
if (current_size > max_size) {
max_size = current_size;
......
......@@ -182,6 +182,11 @@ class MicroMutableOpResolver : public MicroOpResolver {
ParseCumsum);
}
TfLiteStatus AddDepthToSpace() {
return AddBuiltin(BuiltinOperator_DEPTH_TO_SPACE,
tflite::Register_DEPTH_TO_SPACE(), ParseDepthToSpace);
}
TfLiteStatus AddDepthwiseConv2D() {
return AddBuiltin(BuiltinOperator_DEPTHWISE_CONV_2D,
Register_DEPTHWISE_CONV_2D(), ParseDepthwiseConv2D);
......
......@@ -37,11 +37,12 @@ class RecordingMicroInterpreter : public MicroInterpreter {
RecordingMicroInterpreter(const Model* model,
const MicroOpResolver& op_resolver,
uint8_t* tensor_arena, size_t tensor_arena_size,
ErrorReporter* error_reporter)
ErrorReporter* error_reporter,
MicroProfiler* profiler = nullptr)
: MicroInterpreter(model, op_resolver,
RecordingMicroAllocator::Create(
tensor_arena, tensor_arena_size, error_reporter),
error_reporter),
error_reporter, profiler),
recording_micro_allocator_(
static_cast<const RecordingMicroAllocator&>(allocator())) {}
......
......@@ -41,6 +41,13 @@ readable_run make -j8 -f tensorflow/lite/micro/tools/make/Makefile TARGET=${TARG
readable_run make -f tensorflow/lite/micro/tools/make/Makefile clean
readable_run make -j8 -f tensorflow/lite/micro/tools/make/Makefile TARGET=${TARGET} OPTIMIZATION_LEVEL=-Os test
# We have had examples where tests pass with -Os but fail without it so we run
# the unit tests with and without -Os. See
# https://github.com/tensorflow/tensorflow/issues/48516 for one such issue.
readable_run make -f tensorflow/lite/micro/tools/make/Makefile clean
readable_run make -j8 -f tensorflow/lite/micro/tools/make/Makefile TARGET=${TARGET} test
# We use Renode differently when running the full test suite (make test) vs an
# individual test. So, we test only of the kernels individually as well to have
# both of the Renode variations be part of the CI.
......
......@@ -276,6 +276,7 @@ tensorflow/lite/micro/kernels/comparisons_test.cc \
tensorflow/lite/micro/kernels/concatenation_test.cc \
tensorflow/lite/micro/kernels/conv_test.cc \
tensorflow/lite/micro/kernels/cumsum_test.cc \
tensorflow/lite/micro/kernels/depth_to_space_test.cc \
tensorflow/lite/micro/kernels/depthwise_conv_test.cc \
tensorflow/lite/micro/kernels/dequantize_test.cc \
tensorflow/lite/micro/kernels/detection_postprocess_test.cc \
......@@ -294,6 +295,7 @@ tensorflow/lite/micro/kernels/l2_pool_2d_test.cc \
tensorflow/lite/micro/kernels/leaky_relu_test.cc \
tensorflow/lite/micro/kernels/logical_test.cc \
tensorflow/lite/micro/kernels/logistic_test.cc \
tensorflow/lite/micro/kernels/log_softmax_test.cc \
tensorflow/lite/micro/kernels/maximum_minimum_test.cc \
tensorflow/lite/micro/kernels/mul_test.cc \
tensorflow/lite/micro/kernels/neg_test.cc \
......@@ -337,6 +339,7 @@ tensorflow/lite/micro/kernels/concatenation.cc \
tensorflow/lite/micro/kernels/conv.cc \
tensorflow/lite/micro/kernels/conv_common.cc \
tensorflow/lite/micro/kernels/cumsum.cc \
tensorflow/lite/micro/kernels/depth_to_space.cc \
tensorflow/lite/micro/kernels/depthwise_conv.cc \
tensorflow/lite/micro/kernels/depthwise_conv_common.cc \
tensorflow/lite/micro/kernels/dequantize.cc \
......@@ -360,6 +363,7 @@ tensorflow/lite/micro/kernels/l2_pool_2d.cc \
tensorflow/lite/micro/kernels/leaky_relu.cc \
tensorflow/lite/micro/kernels/logical.cc \
tensorflow/lite/micro/kernels/logistic.cc \
tensorflow/lite/micro/kernels/log_softmax.cc \
tensorflow/lite/micro/kernels/maximum_minimum.cc \
tensorflow/lite/micro/kernels/mul.cc \
tensorflow/lite/micro/kernels/neg.cc \
......@@ -434,6 +438,7 @@ tensorflow/lite/kernels/internal/reference/comparisons.h \
tensorflow/lite/kernels/internal/reference/concatenation.h \
tensorflow/lite/kernels/internal/reference/conv.h \
tensorflow/lite/kernels/internal/reference/cumsum.h \
tensorflow/lite/kernels/internal/reference/depth_to_space.h \
tensorflow/lite/kernels/internal/reference/depthwiseconv_float.h \
tensorflow/lite/kernels/internal/reference/depthwiseconv_uint8.h \
tensorflow/lite/kernels/internal/reference/dequantize.h \
......@@ -458,6 +463,7 @@ tensorflow/lite/kernels/internal/reference/integer_ops/tanh.h \
tensorflow/lite/kernels/internal/reference/integer_ops/transpose_conv.h \
tensorflow/lite/kernels/internal/reference/l2normalization.h \
tensorflow/lite/kernels/internal/reference/leaky_relu.h \
tensorflow/lite/kernels/internal/reference/log_softmax.h \
tensorflow/lite/kernels/internal/reference/maximum_minimum.h \
tensorflow/lite/kernels/internal/reference/mul.h \
tensorflow/lite/kernels/internal/reference/neg.h \
......
......@@ -39,7 +39,8 @@ PLATFORM_FLAGS = \
-mcoproc \
-DMAX_RFFT_PWR=9 \
-DMIN_RFFT_PWR=MAX_RFFT_PWR \
$(TARGET_ARCH_DEFINES)
$(TARGET_ARCH_DEFINES) \
-mlongcalls
ifeq ($(BUILD_TYPE), release)
PLATFORM_FLAGS += -Wno-unused-private-field
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册