Unverified commit ff246bcc, authored by Advait Jain, committed by GitHub

Manual changes needed to fix the automated TF --> TFLM sync. (#1164)

* Updated from upstream TF

* Manually copied changes from http://cl/449226814 and http://cl/449192972

* fix missing include (erroneously removed in previous commit).

* fix the hexagon build.
Parent 90d62ab6
@@ -15,6 +15,8 @@ limitations under the License.
#ifndef TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_REDUCE_H_
#define TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_REDUCE_H_
#include <algorithm>
#include "ruy/profiler/instrumentation.h" // from @ruy
#include "tensorflow/lite/kernels/internal/common.h"
#include "tensorflow/lite/kernels/internal/cppmath.h"
@@ -44,6 +46,55 @@ inline bool IsFirstReduction(const int* index, const int num_axis,
namespace tflite {
enum ReduceType {
kSum,
kProd,
kMax,
kMin,
kAny,
kAll,
};
template <typename T>
struct SumOp {
inline T operator()(const T& a, const T& b) { return a + b; }
static constexpr T kNeutralElement = T(0);
};
template <typename T, typename U>
struct CastSumOp {
inline U operator()(const U& a, const T& b) { return a + static_cast<U>(b); }
static constexpr U kNeutralElement = U(0);
};
template <typename T>
struct ProdOp {
inline T operator()(const T& a, const T& b) { return a * b; }
static constexpr T kNeutralElement = T(1);
};
template <typename T>
struct MaxOp {
inline T operator()(const T& a, const T& b) { return (a > b) ? a : b; }
static constexpr T kNeutralElement = std::numeric_limits<T>::lowest();
};
template <typename T>
struct MinOp {
inline T operator()(const T& a, const T& b) { return (a < b) ? a : b; }
static constexpr T kNeutralElement = std::numeric_limits<T>::max();
};
struct AndOp {
inline bool operator()(bool a, bool b) { return a && b; }
static constexpr bool kNeutralElement = true;
};
struct OrOp {
inline bool operator()(bool a, bool b) { return a || b; }
static constexpr bool kNeutralElement = false;
};
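For illustration only (not part of this commit): each functor above pairs a binary operation with the neutral element that starts its accumulator, so a reduction loop can be written once and specialized per op. A minimal sketch, assuming the structs above are in scope; FoldExample is a hypothetical helper:

template <typename T, typename Op>
T FoldExample(const T* data, int size) {
  // Start from the op's identity: 0 for SumOp, 1 for ProdOp,
  // numeric_limits lowest()/max() for MaxOp/MinOp.
  T acc = Op::kNeutralElement;
  for (int i = 0; i < size; ++i) {
    acc = Op()(acc, data[i]);
  }
  return acc;
}

// Example: FoldExample<float, tflite::MaxOp<float>>(values, 4) returns the
// largest of four values; FoldExample<float, tflite::SumOp<float>>(values, 4)
// returns their sum.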
namespace reference_ops {
// When the number of axis is zero, the reduction is simply a copy.
@@ -60,12 +111,11 @@ void ReduceIsCopy(const T* input_data, const int* input_dims,
// A generic reduce method that can be used for reduce_sum, reduce_mean, etc.
// This method iterates through input data and reduce elements along the
// dimensions given in axis.
template <typename In, typename Out>
template <typename In, typename Out, typename Op>
inline bool Reduce(const In* input_data, const int* input_dims,
const int* output_dims, const int input_num_dims,
const int output_num_dims, const int* axis,
const int num_axis, int* input_iter,
Out reducer(Out current, const In in), Out* output_data) {
const int num_axis, int* input_iter, Out* output_data) {
// Reset input iterator.
for (int idx = 0; idx < input_num_dims; ++idx) {
input_iter[idx] = 0;
@@ -77,7 +127,7 @@ inline bool Reduce(const In* input_data, const int* input_dims,
size_t output_offset = ReducedOutputOffset(input_num_dims, input_dims,
input_iter, num_axis, axis);
output_data[output_offset] =
reducer(output_data[output_offset], input_data[input_offset]);
Op()(output_data[output_offset], input_data[input_offset]);
} while (NextIndex(input_num_dims, input_dims, input_iter));
return true;
}
@@ -113,12 +163,28 @@ inline bool Reduce(const In* input_data, const int* input_dims,
return true;
}
// This method parses the input 'axis' to remove duplicates and handle negative
// values, and returns a valid 'out_axis'
// Bubble sort for sorting small inputs. std::sort may dynamically allocate
// memory so is not suitable for use in TFLM.
static void sort(int* input, int size) {
for (int i = 0; i < size - 1; ++i) {
for (int j = 0; j < size - i - 1; ++j) {
if (input[j] > input[j + 1]) {
std::swap(input[j], input[j + 1]);
}
}
}
}
// This method parses the input 'axis' to remove duplicates, handle negative
// values and remove redundant dimensions. It returns a valid 'out_axis' and
// 'shape_out' contains the flattened input shape. 'out_num_dims' contains the
// reduced number of dimensions.
inline bool ResolveAxis(const int num_dims, const int* axis,
const int64_t num_axis, int* out_axis,
int* out_num_axis) {
*out_num_axis = 0; // Just in case.
int* out_num_axis, const int* shape_in, int* shape_out,
int* out_num_dims) {
int num_out_axis = 0;
int dims_out = num_dims;
// Short-circuit axis resolution for scalars; the axis will go unused.
if (num_dims == 0) {
return true;
@@ -129,22 +195,64 @@ inline bool ResolveAxis(const int num_dims, const int* axis,
// negative index 'n_idx' as: n_idx = p_idx-num_dims
// eg: For num_dims=3, [0, 1, 2] is the same as [-3, -2, -1] */
int current = axis[idx] < 0 ? (axis[idx] + num_dims) : axis[idx];
TFLITE_DCHECK(current >= 0 && current < num_dims);
if (current < 0 || current >= num_dims) {
return false;
}
bool is_dup = false;
for (int j = 0; j < *out_num_axis; ++j) {
for (int j = 0; j < num_out_axis; ++j) {
if (out_axis[j] == current) {
is_dup = true;
break;
}
}
if (!is_dup) {
out_axis[*out_num_axis] = current;
*out_num_axis += 1;
out_axis[num_out_axis] = current;
num_out_axis += 1;
}
}
// If two or more adjacent dimensions are either reduced
// over or not, then the second and subsequent dimensions may be flattened.
memcpy(shape_out, shape_in, num_dims * sizeof(int));
if (num_out_axis > 0) {
sort(out_axis, num_out_axis);
int64_t j = num_out_axis - 1;
// true if the previous index is present in out_axis.
bool previous_here = (out_axis[j] == num_dims - 1);
if (previous_here) {
j -= 1;
}
for (int64_t i = num_dims - 2; i >= 0; --i) {
// true if the current index is present in out_axis.
bool current_here = j >= 0 ? (out_axis[j] == i) : false;
if (current_here == previous_here) {
shape_out[i] *= shape_out[i + 1];
for (int64_t k = i + 1; k + 1 < num_dims; ++k) {
shape_out[k] = shape_out[k + 1];
}
// All axis bigger than this need to be reduced by 1.
for (int64_t k = 0; k < num_out_axis; ++k) {
if (out_axis[k] > i) {
out_axis[k] -= 1;
}
}
if (current_here) {
for (int64_t k = j + 1; k + 1 < num_out_axis; ++k) {
out_axis[k] = out_axis[k + 1];
}
num_out_axis -= 1;
}
dims_out -= 1;
}
if (current_here) {
j -= 1;
}
previous_here = current_here;
}
}
*out_num_axis = num_out_axis;
*out_num_dims = dims_out;
return true;
}
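For illustration only (not part of this commit): a worked example of the flattening performed above, mirroring the MeanFloatFlatten2ReduceDims test added further down. Reducing a {4, 3, 2} input over axes {1, 2} collapses the two adjacent reduced dimensions into a single dimension of size 6:

// Hypothetical caller; scratch sizes follow the same rule the kernels use
// (one int per axis entry and one int per input dimension).
const int shape_in[] = {4, 3, 2};
const int axis[] = {2, 1};
int out_axis[2];
int num_out_axis = 0;
int shape_out[3];
int out_num_dims = 0;
tflite::reference_ops::ResolveAxis(/*num_dims=*/3, axis, /*num_axis=*/2,
                                   out_axis, &num_out_axis, shape_in,
                                   shape_out, &out_num_dims);
// Result: shape_out = {4, 6}, out_num_dims = 2,
//         out_axis = {1}, num_out_axis = 1.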
@@ -155,13 +263,9 @@ inline bool ReduceSumImpl(const In* input_data, const int* input_dims,
const int output_num_dims, const int* axis,
const int num_axis, int* input_iter,
Out* output_data) {
auto reducer = [](const Out current, const In in) -> Out {
const Out actual_in = static_cast<Out>(in);
return current + actual_in;
};
return Reduce<In, Out>(input_data, input_dims, output_dims, input_num_dims,
output_num_dims, axis, num_axis, input_iter, reducer,
output_data);
return Reduce<In, Out, CastSumOp<In, Out>>(
input_data, input_dims, output_dims, input_num_dims, output_num_dims,
axis, num_axis, input_iter, output_data);
}
template <typename T>
@@ -184,39 +288,94 @@ inline bool InitTensorDataForReduce(const int* dims, const int num_dims,
}
// Computes the generic value (i.e., sum/max/min/prod) of elements across
// dimensions given in axis. It needs to pass in init_value and reducer.
// dimensions given in axis. The caller passes in the desired reduce_type.
template <typename T>
inline bool ReduceGeneric(const T* input_data, const int* input_dims,
const int input_num_dims, T* output_data,
const int* output_dims, const int output_num_dims,
const int* axis, const int64_t num_axis_dimensions,
bool keep_dims, int* temp_index, int* resolved_axis,
T init_value,
T reducer(const T current, const T in)) {
int* normalized_dims, ReduceType reduce_type) {
T init_value;
switch (reduce_type) {
case kProd:
init_value = ProdOp<T>::kNeutralElement;
break;
case kSum:
init_value = SumOp<T>::kNeutralElement;
break;
case kMin:
init_value = MinOp<T>::kNeutralElement;
break;
case kMax:
init_value = MaxOp<T>::kNeutralElement;
break;
case kAny:
init_value = OrOp::kNeutralElement;
break;
case kAll:
init_value = AndOp::kNeutralElement;
break;
default:
return false;
}
// Reset output data.
if (!InitTensorDataForReduce(output_dims, output_num_dims, init_value,
output_data)) {
return false;
}
// Resolve axis.
int num_resolved_axis = 0;
int normalized_num_dims = 0;
if (!ResolveAxis(input_num_dims, axis, num_axis_dimensions, resolved_axis,
&num_resolved_axis, input_dims, normalized_dims,
&normalized_num_dims)) {
return false;
}
// Return early when input shape has zero dim. This is done after initializing
// data for output tensor because there are cases that the input tensor is
// empty but output tensor is not. In that case, output tensor should be
// filled with init_value.
// filled with Op::kNeutralElement.
for (int i = 0; i < input_num_dims; ++i) {
if (input_dims[i] == 0) return true;
}
// Resolve axis.
int num_resolved_axis = 0;
if (!ResolveAxis(input_num_dims, axis, num_axis_dimensions, resolved_axis,
&num_resolved_axis)) {
return false;
switch (reduce_type) {
case kProd:
return Reduce<T, T, ProdOp<T>>(input_data, normalized_dims, output_dims,
normalized_num_dims, output_num_dims,
resolved_axis, num_resolved_axis,
temp_index, output_data);
case kSum:
return Reduce<T, T, SumOp<T>>(input_data, normalized_dims, output_dims,
normalized_num_dims, output_num_dims,
resolved_axis, num_resolved_axis,
temp_index, output_data);
case kMin:
return Reduce<T, T, MinOp<T>>(input_data, normalized_dims, output_dims,
normalized_num_dims, output_num_dims,
resolved_axis, num_resolved_axis,
temp_index, output_data);
case kMax:
return Reduce<T, T, MaxOp<T>>(input_data, normalized_dims, output_dims,
normalized_num_dims, output_num_dims,
resolved_axis, num_resolved_axis,
temp_index, output_data);
case kAll:
return Reduce<T, T, AndOp>(input_data, normalized_dims, output_dims,
normalized_num_dims, output_num_dims,
resolved_axis, num_resolved_axis, temp_index,
output_data);
case kAny:
return Reduce<T, T, OrOp>(input_data, normalized_dims, output_dims,
normalized_num_dims, output_num_dims,
resolved_axis, num_resolved_axis, temp_index,
output_data);
default:
return false;
}
return Reduce<T, T>(input_data, input_dims, output_dims, input_num_dims,
output_num_dims, resolved_axis, num_resolved_axis,
temp_index, reducer, output_data);
}
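For illustration only (not part of this commit): a minimal sketch of calling the new ReduceType-based entry point directly, with hypothetical data and the same scratch-buffer sizing that EvalMaxHelper below uses:

const float input[] = {1.f, 5.f, 2.f, 4.f, 0.f, 3.f};  // shape {2, 3}
const int input_dims[] = {2, 3};
const int output_dims[] = {2};
const int axis[] = {1};
float output[2];
int temp_index[2];       // one entry per input dimension
int resolved_axis[1];    // one entry per axis value
int normalized_dims[2];  // scratch for the flattened input shape
tflite::reference_ops::ReduceGeneric<float>(
    input, input_dims, /*input_num_dims=*/2, output, output_dims,
    /*output_num_dims=*/1, axis, /*num_axis_dimensions=*/1,
    /*keep_dims=*/false, temp_index, resolved_axis, normalized_dims,
    tflite::kMax);
// Expected: output = {5.f, 4.f}, the maximum of each row.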
// Computes the mean of elements across dimensions given in axis.
@@ -227,7 +386,8 @@ inline bool Mean(const T* input_data, const int* input_dims,
const int input_num_dims, T* output_data,
const int* output_dims, const int output_num_dims,
const int* axis, const int num_axis_dimensions, bool keep_dims,
int* temp_index, int* resolved_axis, U* temp_sum) {
int* temp_index, int* resolved_axis, int* normalized_dims,
U* temp_sum) {
if (num_axis_dimensions == 0) {
ReduceIsCopy(input_data, input_dims, input_num_dims, output_data);
return true;
@@ -250,21 +410,23 @@ inline bool Mean(const T* input_data, const int* input_dims,
// Resolve axis.
int num_resolved_axis = 0;
int normalized_num_dims = 0;
if (!ResolveAxis(input_num_dims, axis, num_axis_dimensions, resolved_axis,
&num_resolved_axis)) {
&num_resolved_axis, input_dims, normalized_dims,
&normalized_num_dims)) {
return false;
}
if (!ReduceSumImpl<T, U>(input_data, input_dims, output_dims, input_num_dims,
output_num_dims, resolved_axis, num_resolved_axis,
temp_index, temp_sum)) {
if (!ReduceSumImpl<T, U>(input_data, normalized_dims, output_dims,
normalized_num_dims, output_num_dims, resolved_axis,
num_resolved_axis, temp_index, temp_sum)) {
return false;
}
// Calculate the mean by dividing output_data by the number of aggregated elements.
size_t num_elements_in_axis = 1;
for (int idx = 0; idx < num_resolved_axis; ++idx) {
size_t current = static_cast<size_t>(input_dims[resolved_axis[idx]]);
size_t current = static_cast<size_t>(normalized_dims[resolved_axis[idx]]);
// Overflow prevention.
if (current > (std::numeric_limits<size_t>::max() / num_elements_in_axis)) {
return false;
@@ -388,15 +550,13 @@ inline void Mean(const tflite::MeanParams& op_params,
// It does so in two stages, first calculates the sum of elements along the axis
// then divides it by the number of elements in the axis for quantized values.
template <typename T, typename U>
inline bool QuantizedMeanOrSum(const T* input_data, int32_t input_zero_point,
float input_scale, const int* input_dims,
const int input_num_dims, T* output_data,
int32_t output_zero_point, float output_scale,
const int* output_dims,
const int output_num_dims, const int* axis,
const int num_axis_dimensions, bool keep_dims,
int* temp_index, int* resolved_axis, U* temp_sum,
bool compute_sum) {
inline bool QuantizedMeanOrSum(
const T* input_data, int32_t input_zero_point, float input_scale,
const int* input_dims, const int input_num_dims, T* output_data,
int32_t output_zero_point, float output_scale, const int* output_dims,
const int output_num_dims, const int* axis, const int num_axis_dimensions,
bool keep_dims, int* temp_index, int* resolved_axis, int* normalized_dims,
U* temp_sum, bool compute_sum) {
if (num_axis_dimensions == 0) {
ReduceIsCopy(input_data, input_dims, input_num_dims, output_data);
return true;
@@ -435,21 +595,23 @@ inline bool QuantizedMeanOrSum(const T* input_data, int32_t input_zero_point,
// Resolve axis.
int num_resolved_axis = 0;
int normalized_num_dims = 0;
if (!ResolveAxis(input_num_dims, axis, num_axis_dimensions, resolved_axis,
&num_resolved_axis)) {
&num_resolved_axis, input_dims, normalized_dims,
&normalized_num_dims)) {
return false;
}
if (!ReduceSumImpl<T, U>(input_data, input_dims, output_dims, input_num_dims,
output_num_dims, resolved_axis, num_resolved_axis,
temp_index, temp_sum)) {
if (!ReduceSumImpl<T, U>(input_data, normalized_dims, output_dims,
normalized_num_dims, output_num_dims, resolved_axis,
num_resolved_axis, temp_index, temp_sum)) {
return false;
}
// Calculate the mean by dividing output_data by the number of aggregated elements.
size_t num_elements_in_axis = 1;
for (int idx = 0; idx < num_resolved_axis; ++idx) {
size_t current = static_cast<size_t>(input_dims[resolved_axis[idx]]);
size_t current = static_cast<size_t>(normalized_dims[resolved_axis[idx]]);
// Overflow prevention.
if (current > (std::numeric_limits<size_t>::max() / num_elements_in_axis)) {
return false;
@@ -486,22 +648,23 @@ inline bool QuantizedMeanOrSum(const T* input_data, int32_t input_zero_point,
}
template <typename T>
inline bool QuantizedReduceProd(const T* input_data, int32_t input_zero_point,
const RuntimeShape& input_shape, T* output_data,
int32_t output_zero_point,
const RuntimeShape& output_shape,
const int* axis,
const int64_t num_axis_dimensions,
bool keep_dims, int* temp_index,
int* resolved_axis, int32_t* temp_prod,
int32_t scaling_multiplier, int scaling_shift) {
inline bool QuantizedReduceProd(
const T* input_data, int32_t input_zero_point,
const RuntimeShape& input_shape, T* output_data, int32_t output_zero_point,
const RuntimeShape& output_shape, const int* axis,
const int64_t num_axis_dimensions, bool keep_dims, int* temp_index,
int* resolved_axis, int* normalized_dims, int32_t* temp_prod,
int32_t scaling_multiplier, int scaling_shift) {
const int32_t kMinValue = std::numeric_limits<T>::min();
const int32_t kMaxValue = std::numeric_limits<T>::max();
// Resolve axis.
int num_resolved_axis = 0;
int normalized_num_dims = 0;
if (!ResolveAxis(input_shape.DimensionsCount(), axis, num_axis_dimensions,
resolved_axis, &num_resolved_axis)) {
resolved_axis, &num_resolved_axis,
reinterpret_cast<const int*>(input_shape.DimsData()),
normalized_dims, &normalized_num_dims)) {
return false;
}
@@ -516,11 +679,10 @@ inline bool QuantizedReduceProd(const T* input_data, int32_t input_zero_point,
scaling_shift);
};
if (!Reduce<T, int32_t>(
input_data, input_shape.DimsData(), output_shape.DimsData(),
input_shape.DimensionsCount(), output_shape.DimensionsCount(),
resolved_axis, num_resolved_axis, temp_index, reducer_first,
reducer_next, temp_prod)) {
if (!Reduce<T, int32_t>(input_data, normalized_dims, output_shape.DimsData(),
normalized_num_dims, output_shape.DimensionsCount(),
resolved_axis, num_resolved_axis, temp_index,
reducer_first, reducer_next, temp_prod)) {
return false;
}
@@ -31,7 +31,9 @@ struct OpDataReduce {
int32_t multiplier;
int shift;
int temp_buffer_idx;
int temp_index_idx;
int resolved_axis_idx;
int normalized_dims_idx;
int input_zp;
float input_scale;
int output_zp;
@@ -80,10 +80,15 @@ TfLiteStatus PrepareMaxHelper(TfLiteContext* context, TfLiteNode* node,
context->RequestScratchBufferInArena(context, sizeof(int) * input->dims->size,
&op_data->temp_buffer_idx);
context->RequestScratchBufferInArena(
context, sizeof(int) * static_cast<int>(ElementCount(*axis->dims)),
&op_data->resolved_axis_idx);
context->RequestScratchBufferInArena(
context, sizeof(int) * static_cast<int>(ElementCount(*input->dims)),
&op_data->normalized_dims_idx);
micro_context->DeallocateTempTfLiteTensor(input);
micro_context->DeallocateTempTfLiteTensor(output);
micro_context->DeallocateTempTfLiteTensor(axis);
@@ -94,6 +99,7 @@ TfLiteStatus PrepareMeanOrSumHelper(TfLiteContext* context, TfLiteNode* node,
OpDataReduce* op_data) {
MicroContext* micro_context = GetMicroContext(context);
TfLiteTensor* input = micro_context->AllocateTempInputTensor(node, 0);
TfLiteTensor* axis = micro_context->AllocateTempInputTensor(node, 1);
TfLiteTensor* output = micro_context->AllocateTempOutputTensor(node, 0);
if (input->type == kTfLiteInt8 || input->type == kTfLiteInt16) {
const double real_multiplier = static_cast<double>(input->params.scale) /
@@ -111,11 +117,23 @@ TfLiteStatus PrepareMeanOrSumHelper(TfLiteContext* context, TfLiteNode* node,
op_data->output_scale = output->params.scale;
}
context->RequestScratchBufferInArena(context, output_size * sizeof(int32_t),
&op_data->temp_index_idx);
context->RequestScratchBufferInArena(context, output_size * sizeof(int32_t),
&op_data->temp_buffer_idx);
context->RequestScratchBufferInArena(
context, sizeof(int) * static_cast<int>(ElementCount(*axis->dims)),
&op_data->resolved_axis_idx);
context->RequestScratchBufferInArena(
context, sizeof(int) * static_cast<int>(ElementCount(*input->dims)),
&op_data->normalized_dims_idx);
TF_LITE_ENSURE_OK(
context,
PrepareSimple(context, node, &(op_data->multiplier), &(op_data->shift)));
// TODO(b/144955155): Support uint8_t(b/144955155) and int8_t(b/144955018)
micro_context->DeallocateTempTfLiteTensor(input);
micro_context->DeallocateTempTfLiteTensor(axis);
micro_context->DeallocateTempTfLiteTensor(output);
return kTfLiteOk;
}
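For illustration only (not part of this commit): the scratch-buffer pattern these kernels rely on. An index is reserved in the arena during Prepare() and resolved to a pointer during Eval(); the helper names below are hypothetical:

// Prepare(): reserve arena space and remember the returned index.
TfLiteStatus RequestNormalizedDims(TfLiteContext* context,
                                   const TfLiteTensor* input,
                                   OpDataReduce* op_data) {
  return context->RequestScratchBufferInArena(
      context, sizeof(int) * input->dims->size,
      &op_data->normalized_dims_idx);
}

// Eval(): resolve the index to a pointer; the buffer is only valid while
// Eval() is running.
int* GetNormalizedDims(TfLiteContext* context, OpDataReduce* op_data) {
  return static_cast<int*>(
      context->GetScratchBuffer(context, op_data->normalized_dims_idx));
}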
@@ -141,8 +159,14 @@ TfLiteStatus EvalMeanHelper(TfLiteContext* context, TfLiteNode* node,
reinterpret_cast<TfLiteReducerParams*>(node->builtin_data);
int num_axis = static_cast<int>(ElementCount(*axis->dims));
int temp_index[kMaxNumberOfAxis];
int resolved_axis[kMaxNumberOfReducedAxis];
int* temp_index = static_cast<int*>(
context->GetScratchBuffer(context, op_data->temp_index_idx));
int* temp_buffer = static_cast<int*>(
context->GetScratchBuffer(context, op_data->temp_buffer_idx));
int* resolved_axis = static_cast<int*>(
context->GetScratchBuffer(context, op_data->resolved_axis_idx));
int* normalized_dims = static_cast<int*>(
context->GetScratchBuffer(context, op_data->normalized_dims_idx));
tflite::MeanParams op_params;
ResolveAxis(tflite::micro::GetTensorData<int>(axis), num_axis, &op_params);
@@ -169,7 +193,7 @@ TfLiteStatus EvalMeanHelper(TfLiteContext* context, TfLiteNode* node,
input->dims->size, tflite::micro::GetTensorData<float>(output),
output->dims->data, output->dims->size,
tflite::micro::GetTensorData<int>(axis), num_axis,
params->keep_dims, temp_index, resolved_axis,
params->keep_dims, temp_index, resolved_axis, normalized_dims,
tflite::micro::GetTensorData<float>(output)));
}
} break;
@@ -184,19 +208,16 @@ TfLiteStatus EvalMeanHelper(TfLiteContext* context, TfLiteNode* node,
tflite::micro::GetTensorData<int8_t>(output), op_data->output_zp);
} else if (op_data->input_zp == op_data->output_zp &&
op_data->input_scale == op_data->output_scale) {
int32_t* temp_buffer = static_cast<int32_t*>(
context->GetScratchBuffer(context, op_data->temp_buffer_idx));
TF_LITE_ENSURE(
context,
reference_ops::Mean(
tflite::micro::GetTensorData<int8_t>(input), input->dims->data,
input->dims->size, tflite::micro::GetTensorData<int8_t>(output),
output->dims->data, output->dims->size,
tflite::micro::GetTensorData<int>(axis), num_axis,
params->keep_dims, temp_index, resolved_axis, temp_buffer));
reference_ops::Mean(tflite::micro::GetTensorData<int8_t>(input),
input->dims->data, input->dims->size,
tflite::micro::GetTensorData<int8_t>(output),
output->dims->data, output->dims->size,
tflite::micro::GetTensorData<int>(axis),
num_axis, params->keep_dims, temp_index,
resolved_axis, normalized_dims, temp_buffer));
} else {
int32_t* temp_buffer = static_cast<int32_t*>(
context->GetScratchBuffer(context, op_data->temp_buffer_idx));
TF_LITE_ENSURE(
context,
reference_ops::QuantizedMeanOrSum(
@@ -206,7 +227,7 @@ TfLiteStatus EvalMeanHelper(TfLiteContext* context, TfLiteNode* node,
op_data->output_zp, op_data->output_scale, output->dims->data,
output->dims->size, tflite::micro::GetTensorData<int>(axis),
num_axis, params->keep_dims, temp_index, resolved_axis,
temp_buffer, false));
normalized_dims, temp_buffer, false));
}
} break;
case kTfLiteInt16: {
@@ -220,8 +241,6 @@ TfLiteStatus EvalMeanHelper(TfLiteContext* context, TfLiteNode* node,
tflite::micro::GetTensorData<int16_t>(output), op_data->output_zp);
} else if (op_data->input_zp == op_data->output_zp &&
op_data->input_scale == op_data->output_scale) {
int32_t* temp_buffer = static_cast<int32_t*>(
context->GetScratchBuffer(context, op_data->temp_buffer_idx));
TF_LITE_ENSURE(
context,
reference_ops::Mean(tflite::micro::GetTensorData<int16_t>(input),
@@ -230,10 +249,8 @@ TfLiteStatus EvalMeanHelper(TfLiteContext* context, TfLiteNode* node,
output->dims->data, output->dims->size,
tflite::micro::GetTensorData<int>(axis),
num_axis, params->keep_dims, temp_index,
resolved_axis, temp_buffer));
resolved_axis, normalized_dims, temp_buffer));
} else {
int32_t* temp_buffer = static_cast<int32_t*>(
context->GetScratchBuffer(context, op_data->temp_buffer_idx));
TF_LITE_ENSURE(
context,
reference_ops::QuantizedMeanOrSum(
@@ -243,7 +260,7 @@ TfLiteStatus EvalMeanHelper(TfLiteContext* context, TfLiteNode* node,
op_data->output_zp, op_data->output_scale, output->dims->data,
output->dims->size, tflite::micro::GetTensorData<int>(axis),
num_axis, params->keep_dims, temp_index, resolved_axis,
temp_buffer, false));
normalized_dims, temp_buffer, false));
}
} break;
default:
@@ -269,37 +286,32 @@ TfLiteStatus EvalMaxHelper(TfLiteContext* context, TfLiteNode* node,
context->GetScratchBuffer(context, op_data->temp_buffer_idx));
int* resolved_axis = static_cast<int*>(
context->GetScratchBuffer(context, op_data->resolved_axis_idx));
int* normalized_dims = static_cast<int*>(
context->GetScratchBuffer(context, op_data->normalized_dims_idx));
switch (input->type) {
case kTfLiteFloat32:
TF_LITE_ENSURE(
context,
reference_ops::ReduceGeneric<float>(
tflite::micro::GetTensorData<float>(input), input->dims->data,
input->dims->size, tflite::micro::GetTensorData<float>(output),
output->dims->data, output->dims->size,
tflite::micro::GetTensorData<int>(axis), num_axis,
params->keep_dims, temp_buffer, resolved_axis,
std::numeric_limits<float>::lowest(),
[](const float current, const float in) -> float {
return (in > current) ? in : current;
}));
TF_LITE_ENSURE(context, reference_ops::ReduceGeneric<float>(
tflite::micro::GetTensorData<float>(input),
input->dims->data, input->dims->size,
tflite::micro::GetTensorData<float>(output),
output->dims->data, output->dims->size,
tflite::micro::GetTensorData<int>(axis),
num_axis, params->keep_dims, temp_buffer,
resolved_axis, normalized_dims, kMax));
break;
case kTfLiteInt8:
TF_LITE_ENSURE_EQ(context, static_cast<double>(op_data->input_scale),
static_cast<double>(op_data->output_scale));
TF_LITE_ENSURE_EQ(context, op_data->input_zp, op_data->output_zp);
TF_LITE_ENSURE(
context,
reference_ops::ReduceGeneric<int8_t>(
tflite::micro::GetTensorData<int8_t>(input), input->dims->data,
input->dims->size, tflite::micro::GetTensorData<int8_t>(output),
output->dims->data, output->dims->size,
tflite::micro::GetTensorData<int>(axis), num_axis,
params->keep_dims, temp_buffer, resolved_axis,
std::numeric_limits<int8_t>::lowest(),
[](const int8_t current, const int8_t in) -> int8_t {
return (in > current) ? in : current;
}));
TF_LITE_ENSURE(context, reference_ops::ReduceGeneric<int8_t>(
tflite::micro::GetTensorData<int8_t>(input),
input->dims->data, input->dims->size,
tflite::micro::GetTensorData<int8_t>(output),
output->dims->data, output->dims->size,
tflite::micro::GetTensorData<int>(axis),
num_axis, params->keep_dims, temp_buffer,
resolved_axis, normalized_dims, kMax));
break;
default:
MicroPrintf("Only float32 and int8 types are supported.");
@@ -156,7 +156,6 @@ void TestReduceOpQuantized(int* input_dims_data, const float* input_data,
float output_scale, int output_zero_point,
const TfLiteRegistration& registration,
TfLiteReducerParams* params) {
// Convert dimension arguments to TfLiteArrays
TfLiteIntArray* input_dims = IntArrayFromInts(input_dims_data);
TfLiteIntArray* axis_dims = IntArrayFromInts(axis_dims_data);
TfLiteIntArray* output_dims = IntArrayFromInts(output_dims_data);
@@ -229,6 +228,61 @@ void TestMeanOpQuantized(int* input_dims_data, const float* input_data,
TF_LITE_MICRO_TESTS_BEGIN
TF_LITE_MICRO_TEST(MeanFloatFlatten2ReduceDims) {
int input_shape[] = {3, 4, 3, 2};
int output_shape[] = {1, 4};
int axis_shape[] = {1, 2};
int32_t axis_data[] = {2, 1};
float input_data[] = {1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0,
9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0,
17.0, 18.0, 19.0, 20.0, 21.0, 22.0, 23.0, 24.0};
float output_data[] = {3.5, 9.5, 15.5, 21.5};
float actual_output_data[4];
TfLiteReducerParams params = {false};
tflite::testing::TestMeanFloatInput4D(input_shape, input_data, axis_shape,
axis_data, output_shape, output_data,
actual_output_data, &params);
}
TF_LITE_MICRO_TEST(MeanFloatFlatten2NonReduceDims) {
int input_shape[] = {3, 4, 3, 2};
int output_shape[] = {2, 4, 3};
int axis_shape[] = {1, 1};
int32_t axis_data[] = {2};
float input_data[] = {1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0,
9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0,
17.0, 18.0, 19.0, 20.0, 21.0, 22.0, 23.0, 24.0};
float output_data[] = {1.5, 3.5, 5.5, 7.5, 9.5, 11.5,
13.5, 15.5, 17.5, 19.5, 21.5, 23.5};
float actual_output_data[12];
TfLiteReducerParams params = {false};
tflite::testing::TestMeanFloatInput4D(input_shape, input_data, axis_shape,
axis_data, output_shape, output_data,
actual_output_data, &params);
}
TF_LITE_MICRO_TEST(MeanFloatFlatten2MiddleDims) {
int input_shape[] = {4, 2, 2, 3, 2};
int output_shape[] = {2, 2, 2};
int axis_shape[] = {1, 2};
int32_t axis_data[] = {1, 2};
float input_data[] = {1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0,
9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0,
17.0, 18.0, 19.0, 20.0, 21.0, 22.0, 23.0, 24.0};
float output_data[] = {6, 7, 18, 19};
float actual_output_data[4];
TfLiteReducerParams params = {false};
tflite::testing::TestMeanFloatInput4D(input_shape, input_data, axis_shape,
axis_data, output_shape, output_data,
actual_output_data, &params);
}
TF_LITE_MICRO_TEST(MeanFloat2DKeepDims) {
float output_data[tflite::testing::kOutputElements2D];