Unverified commit b0afe289 authored by Måns Nilsson, committed by GitHub

FullyConnected Quantization specific registration for CMSIS-NN fully connected int16 (#1026)

* Quantization specific registration for CMSIS-NN fully connected int16

Adds int16 support, an int16-specific registration, and a unit test (ported
from TFL) for the CMSIS-NN fully connected kernel.

Change-Id: I03fed7eef1880c0796785791225e618322c51642

* Add int16 support to fully_connected reference kernel

* CMSIS-NN: Add PopulateCommonParams in FC

* Fix formatting
Co-authored-by: mergify[bot] <37929162+mergify[bot]@users.noreply.github.com>
Parent 7ad32213
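Below is a minimal sketch (not part of the patch) of how an application could opt into the int16-specific registration added here. It assumes the MicroMutableOpResolver::AddFullyConnected overload that accepts an explicit TfLiteRegistration; the helper name RegisterOps and the resolver size are illustrative.

#include "tensorflow/lite/micro/kernels/fully_connected.h"
#include "tensorflow/lite/micro/micro_mutable_op_resolver.h"

// Registers only the int16 fully connected variant. On builds without
// CMSIS_NN defined, Register_FULLY_CONNECTED_INT16() resolves to the inline
// fallback in fully_connected.h, which simply returns
// Register_FULLY_CONNECTED().
TfLiteStatus RegisterOps(tflite::MicroMutableOpResolver<1>& resolver) {
  return resolver.AddFullyConnected(tflite::Register_FULLY_CONNECTED_INT16());
}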
@@ -39,6 +39,10 @@ struct OpData {
// Index to buffer for optimizations if applicable.
int buffer_idx;
int32_t batches;
int32_t accum_depth;
int32_t output_depth;
};
void* Init(TfLiteContext* context, const char* buffer, size_t length) {
@@ -68,8 +72,20 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_ENSURE(context, output != nullptr);
TF_LITE_ENSURE_TYPES_EQ(context, input->type, output->type);
TF_LITE_ENSURE_MSG(context, input->type == filter->type,
"Hybrid models are not supported on TFLite Micro.");
const RuntimeShape filter_shape = GetTensorShape(filter);
const RuntimeShape output_shape = GetTensorShape(output);
const int filter_dim_count = filter_shape.DimensionsCount();
const int output_dim_count = output_shape.DimensionsCount();
cmsis_nn_dims filter_dims;
filter_dims.n = filter_shape.Dims(filter_dim_count - 1);
filter_dims.h = 1;
filter_dims.w = 1;
filter_dims.c = output_shape.Dims(output_dim_count - 1);
data->accum_depth = filter_shape.Dims(filter_dim_count - 1);
data->batches = FlatSizeSkipDim(output_shape, output_dim_count - 1);
data->output_depth = output_shape.Dims(output_dim_count - 1);
// Set buffer index to a reset value
data->buffer_idx = -1;
@@ -77,51 +93,39 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
context, params->activation, input->type, input, filter, bias, output,
&(data->reference_op_data)));
if (input->type == kTfLiteInt8) {
int32_t buf_size = 0;
if (input->type == kTfLiteInt16) {
buf_size = arm_fully_connected_s16_get_buffer_size(&filter_dims);
} else if (input->type == kTfLiteInt8) {
const RuntimeShape input_shape = GetTensorShape(input);
const RuntimeShape filter_shape = GetTensorShape(filter);
const RuntimeShape output_shape = GetTensorShape(output);
const int filter_dim_count = filter_shape.DimensionsCount();
const int output_dim_count = output_shape.DimensionsCount();
const int accum_depth = filter_shape.Dims(filter_dim_count - 1);
TFLITE_DCHECK_GE(output_dim_count, 2);
TFLITE_DCHECK_LE(output_dim_count, 4);
int32_t buf_size = 0;
if (output_dim_count > 2 && accum_depth % 4 == 0) {
const int batches = FlatSizeSkipDim(output_shape, output_dim_count - 1);
const int output_depth = output_shape.Dims(output_dim_count - 1);
if (output_dim_count > 2 && data->accum_depth % 4 == 0) {
data->per_channel_output_multiplier =
static_cast<int32_t*>(context->AllocatePersistentBuffer(
context, output_depth * sizeof(int32_t)));
context, data->output_depth * sizeof(int32_t)));
data->per_channel_output_shift =
static_cast<int32_t*>(context->AllocatePersistentBuffer(
context, output_depth * sizeof(int32_t)));
context, data->output_depth * sizeof(int32_t)));
cmsis_nn_dims input_dims;
input_dims.n = batches;
input_dims.n = data->batches;
input_dims.h = 1;
input_dims.w = 1;
input_dims.c = accum_depth;
input_dims.c = data->accum_depth;
buf_size = arm_convolve_1x1_s8_fast_get_buffer_size(&input_dims);
} else {
cmsis_nn_dims filter_dims;
filter_dims.n = filter_shape.Dims(filter_dim_count - 1);
filter_dims.h = 1;
filter_dims.w = 1;
filter_dims.c = output_shape.Dims(output_dim_count - 1);
buf_size = arm_fully_connected_s8_get_buffer_size(&filter_dims);
}
}
if (buf_size > 0) {
TF_LITE_ENSURE_STATUS(context->RequestScratchBufferInArena(
context, buf_size, &data->buffer_idx));
} else {
data->buffer_idx = -1;
}
if (buf_size > 0) {
TF_LITE_ENSURE_STATUS(context->RequestScratchBufferInArena(
context, buf_size, &data->buffer_idx));
}
micro_context->DeallocateTempTfLiteTensor(output);
@@ -134,63 +138,66 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
return kTfLiteOk;
}
TfLiteStatus EvalQuantizedInt8(TfLiteContext* context, TfLiteNode* node,
const OpData& data,
const TfLiteEvalTensor* input,
const TfLiteEvalTensor* filter,
const TfLiteEvalTensor* bias,
TfLiteEvalTensor* output) {
const RuntimeShape output_shape = tflite::micro::GetTensorShape(output);
const RuntimeShape input_shape = tflite::micro::GetTensorShape(input);
const int output_dim_count = output_shape.DimensionsCount();
TFLITE_DCHECK_GE(output_dim_count, 2);
TFLITE_DCHECK_LE(output_dim_count, 4);
const int output_depth = output_shape.Dims(output_dim_count - 1);
const RuntimeShape filter_shape = tflite::micro::GetTensorShape(filter);
const int filter_dim_count = filter_shape.DimensionsCount();
const int accum_depth = filter_shape.Dims(filter_dim_count - 1);
const int batches = FlatSizeSkipDim(output_shape, output_dim_count - 1);
cmsis_nn_per_tensor_quant_params quant_params;
void PopulateCommonParams(TfLiteContext* context,
cmsis_nn_per_tensor_quant_params& quant_params,
cmsis_nn_dims& input_dims, cmsis_nn_dims& filter_dims,
cmsis_nn_dims& bias_dims, cmsis_nn_dims& output_dims,
cmsis_nn_context& ctx, const OpData& data) {
quant_params.multiplier = data.reference_op_data.output_multiplier;
quant_params.shift = data.reference_op_data.output_shift;
cmsis_nn_dims input_dims;
input_dims.n = batches;
input_dims.n = data.batches;
input_dims.h = 1;
input_dims.w = 1;
input_dims.c = accum_depth;
input_dims.c = data.accum_depth;
cmsis_nn_dims filter_dims;
filter_dims.n = accum_depth;
filter_dims.n = data.accum_depth;
filter_dims.h = 1;
filter_dims.w = 1;
filter_dims.c = output_depth;
filter_dims.c = data.output_depth;
cmsis_nn_dims bias_dims;
bias_dims.n = 1;
bias_dims.h = 1;
bias_dims.w = 1;
bias_dims.c = output_depth;
bias_dims.c = data.output_depth;
cmsis_nn_dims output_dims;
output_dims.n = batches;
output_dims.n = data.batches;
output_dims.h = 1;
output_dims.w = 1;
output_dims.c = output_depth;
output_dims.c = data.output_depth;
cmsis_nn_context ctx;
ctx.buf = nullptr;
ctx.size = 0;
if (data.buffer_idx > -1) {
ctx.buf = context->GetScratchBuffer(context, data.buffer_idx);
}
}
TfLiteStatus EvalQuantizedInt8(TfLiteContext* context, TfLiteNode* node,
const OpData& data,
const TfLiteEvalTensor* input,
const TfLiteEvalTensor* filter,
const TfLiteEvalTensor* bias,
TfLiteEvalTensor* output) {
const RuntimeShape output_shape = tflite::micro::GetTensorShape(output);
const int output_dim_count = output_shape.DimensionsCount();
TFLITE_DCHECK_GE(output_dim_count, 2);
TFLITE_DCHECK_LE(output_dim_count, 4);
cmsis_nn_per_tensor_quant_params quant_params;
cmsis_nn_dims input_dims;
cmsis_nn_dims filter_dims;
cmsis_nn_dims bias_dims;
cmsis_nn_dims output_dims;
cmsis_nn_context ctx;
PopulateCommonParams(context, quant_params, input_dims, filter_dims,
bias_dims, output_dims, ctx, data);
const int32_t* bias_data =
nullptr != bias ? tflite::micro::GetTensorData<int32_t>(bias) : nullptr;
if (output_dim_count > 2 && accum_depth % 4 == 0) {
if (output_dim_count > 2 && data.accum_depth % 4 == 0) {
cmsis_nn_conv_params conv_params;
conv_params.dilation.h = 1;
conv_params.dilation.w = 1;
@@ -209,7 +216,7 @@ TfLiteStatus EvalQuantizedInt8(TfLiteContext* context, TfLiteNode* node,
per_channel_quant_params.shift =
const_cast<int32_t*>(data.per_channel_output_shift);
for (int i = 0; i < output_depth; i++) {
for (int i = 0; i < data.output_depth; i++) {
per_channel_quant_params.multiplier[i] = quant_params.multiplier;
per_channel_quant_params.shift[i] = quant_params.shift;
}
@@ -242,6 +249,44 @@ TfLiteStatus EvalQuantizedInt8(TfLiteContext* context, TfLiteNode* node,
return kTfLiteOk;
}
TfLiteStatus EvalQuantizedInt16(TfLiteContext* context, TfLiteNode* node,
const OpData& data,
const TfLiteEvalTensor* input,
const TfLiteEvalTensor* filter,
const TfLiteEvalTensor* bias,
TfLiteEvalTensor* output) {
cmsis_nn_per_tensor_quant_params quant_params;
cmsis_nn_dims input_dims;
cmsis_nn_dims filter_dims;
cmsis_nn_dims bias_dims;
cmsis_nn_dims output_dims;
cmsis_nn_context ctx;
PopulateCommonParams(context, quant_params, input_dims, filter_dims,
bias_dims, output_dims, ctx, data);
const int64_t* bias_data =
nullptr != bias ? tflite::micro::GetTensorData<int64_t>(bias) : nullptr;
cmsis_nn_fc_params fc_params;
fc_params.input_offset = -data.reference_op_data.input_zero_point;
fc_params.output_offset = data.reference_op_data.output_zero_point;
fc_params.filter_offset = 0;
fc_params.activation.min = data.reference_op_data.output_activation_min;
fc_params.activation.max = data.reference_op_data.output_activation_max;
TF_LITE_ENSURE_EQ(
context,
arm_fully_connected_s16(
&ctx, &fc_params, &quant_params, &input_dims,
tflite::micro::GetTensorData<int16_t>(input), &filter_dims,
tflite::micro::GetTensorData<int8_t>(filter), &bias_dims, bias_data,
&output_dims, tflite::micro::GetTensorData<int16_t>(output)),
ARM_CMSIS_NN_SUCCESS);
return kTfLiteOk;
}
TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
TFLITE_DCHECK(node->builtin_data != nullptr);
const auto* params =
@@ -280,6 +325,10 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
return EvalQuantizedInt8(context, node, data, input, filter, bias,
output);
}
case kTfLiteInt16: {
return EvalQuantizedInt16(context, node, data, input, filter, bias,
output);
}
default: {
TF_LITE_KERNEL_LOG(context, "Type %s (%d) not supported.",
TfLiteTypeGetName(input->type), input->type);
@@ -318,6 +367,29 @@ TfLiteStatus EvalInt8(TfLiteContext* context, TfLiteNode* node) {
return EvalQuantizedInt8(context, node, data, input, filter, bias, output);
}
TfLiteStatus EvalInt16(TfLiteContext* context, TfLiteNode* node) {
const TfLiteEvalTensor* input =
tflite::micro::GetEvalInput(context, node, kFullyConnectedInputTensor);
const TfLiteEvalTensor* filter =
tflite::micro::GetEvalInput(context, node, kFullyConnectedWeightsTensor);
const TfLiteEvalTensor* bias =
tflite::micro::GetEvalInput(context, node, kFullyConnectedBiasTensor);
TfLiteEvalTensor* output =
tflite::micro::GetEvalOutput(context, node, kFullyConnectedOutputTensor);
TFLITE_DCHECK(node->user_data != nullptr);
const OpData& data = *(static_cast<const OpData*>(node->user_data));
// Checks in Prepare ensure input, output and filter types are all the same.
if (input->type != kTfLiteInt16) {
TF_LITE_KERNEL_LOG(context, "Type %s (%d) not supported.",
TfLiteTypeGetName(input->type), input->type);
return kTfLiteError;
}
return EvalQuantizedInt16(context, node, data, input, filter, bias, output);
}
} // namespace
TfLiteRegistration Register_FULLY_CONNECTED() {
@@ -328,4 +400,8 @@ TfLiteRegistration Register_FULLY_CONNECTED_INT8() {
return tflite::micro::RegisterOp(Init, Prepare, EvalInt8);
}
TfLiteRegistration Register_FULLY_CONNECTED_INT16() {
return tflite::micro::RegisterOp(Init, Prepare, EvalInt16);
}
} // namespace tflite
/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
@@ -55,10 +55,7 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
TfLiteTensor* output = micro_context->AllocateTempOutputTensor(
node, kFullyConnectedOutputTensor);
TF_LITE_ENSURE(context, output != nullptr);
TF_LITE_ENSURE_TYPES_EQ(context, input->type, output->type);
TF_LITE_ENSURE_MSG(context, input->type == filter->type,
"Hybrid models are not supported on TFLite Micro.");
TF_LITE_ENSURE_OK(context, CalculateOpDataFullyConnected(
context, params->activation, input->type,
@@ -126,6 +123,23 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
break;
}
case kTfLiteInt16: {
const int64_t* bias_data =
nullptr != bias ? tflite::micro::GetTensorData<int64_t>(bias)
: nullptr;
tflite::reference_integer_ops::FullyConnected(
FullyConnectedParamsQuantized(data),
tflite::micro::GetTensorShape(input),
tflite::micro::GetTensorData<int16_t>(input),
tflite::micro::GetTensorShape(filter),
tflite::micro::GetTensorData<int8_t>(filter),
tflite::micro::GetTensorShape(bias), bias_data,
tflite::micro::GetTensorShape(output),
tflite::micro::GetTensorData<int16_t>(output));
break;
}
default: {
TF_LITE_KERNEL_LOG(context, "Type %s (%d) not supported.",
TfLiteTypeGetName(input->type), input->type);
......
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
@@ -81,6 +81,24 @@ inline TfLiteRegistration Register_FULLY_CONNECTED_INT8() {
}
#endif
#if defined(CMSIS_NN)
// Returns a TfLiteRegistration struct for kernel variant that only supports
// int16.
TfLiteRegistration Register_FULLY_CONNECTED_INT16();
#else
// Note that while this block gets used for both reference and optimized kernels
// that do not have any specialized implementations, the only goal here is to
// define fallback implementations that allow reference kernels to still be used
// from applications that call a more specific kernel variant.
inline TfLiteRegistration Register_FULLY_CONNECTED_INT16() {
return Register_FULLY_CONNECTED();
}
#endif
} // namespace tflite
#endif // TENSORFLOW_LITE_MICRO_KERNELS_FULLY_CONNECTED_H_
/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
@@ -316,16 +316,16 @@ TfLiteStatus TestFullyConnectedFloat(
}
#endif
template <typename T>
template <typename dataT, typename weightT, typename biasT>
TfLiteStatus TestFullyConnectedQuantized(
int* input_dims_data, const float* input_data, T* input_quantized,
int* input_dims_data, const float* input_data, dataT* input_quantized,
const float input_scale, const int input_zero_point, int* weights_dims_data,
const float* weights_data, T* weights_quantized, const float weights_scale,
const int weights_zero_point, int* bias_dims_data, const float* bias_data,
int32_t* bias_quantized, const float* golden, T* golden_quantized,
int* output_dims_data, const float output_scale,
const int output_zero_point, TfLiteFusedActivation activation,
T* output_data) {
const float* weights_data, weightT* weights_quantized,
const float weights_scale, const int weights_zero_point,
int* bias_dims_data, const float* bias_data, biasT* bias_quantized,
const float* golden, dataT* golden_quantized, int* output_dims_data,
const float output_scale, const int output_zero_point,
TfLiteFusedActivation activation, dataT* output_data) {
TfLiteIntArray* input_dims = IntArrayFromInts(input_dims_data);
TfLiteIntArray* weights_dims = IntArrayFromInts(weights_dims_data);
TfLiteIntArray* bias_dims = IntArrayFromInts(bias_dims_data);
@@ -435,6 +435,36 @@ TF_LITE_MICRO_TEST(SimpleTestQuantizedInt8) {
kTfLiteOk);
}
#if !(defined(XTENSA) || defined(HEXAGON))
TF_LITE_MICRO_TEST(SimpleTestQuantizedInt16) {
const float input_scale = 128.0 / 65536;
const int input_zero_point = 0;
const float weights_scale = 1.0f;
const int weights_zero_point = 0;
const float output_scale = 128.0 / 65536;
const int output_zero_point = 0;
const float simple_golden[] = {24, 25, 26, 58, 59, 60};
int16_t input_quantized[tflite::testing::simple_input_size];
int8_t weights_quantized[tflite::testing::simple_weights_size];
int64_t bias_quantized[tflite::testing::simple_output_size];
int16_t golden_quantized[tflite::testing::simple_output_size];
int16_t output_data[tflite::testing::simple_output_size];
TF_LITE_MICRO_EXPECT_EQ(
tflite::testing::TestFullyConnectedQuantized(
tflite::testing::simple_input_dims,
tflite::testing::simple_input_data, input_quantized, input_scale,
input_zero_point, tflite::testing::simple_weights_dims,
tflite::testing::simple_weights_data, weights_quantized,
weights_scale, weights_zero_point, tflite::testing::simple_bias_dims,
tflite::testing::simple_bias_data, bias_quantized, simple_golden,
golden_quantized, tflite::testing::simple_output_dims, output_scale,
output_zero_point, kTfLiteActNone, output_data),
kTfLiteOk);
}
#endif
TF_LITE_MICRO_TEST(SimpleTest4DInputQuantizedInt8) {
const float input_scale = 1.0f;
const int input_zero_point = -1;
@@ -582,7 +612,8 @@ TF_LITE_MICRO_TEST(SimpleTestQuantizedInt8NullBias) {
tflite::testing::simple_input_data, input_quantized, input_scale,
input_zero_point, tflite::testing::simple_weights_dims,
tflite::testing::simple_weights_data, weights_quantized,
weights_scale, weights_zero_point, nullptr, nullptr, nullptr,
weights_scale, weights_zero_point, nullptr, nullptr,
static_cast<int32_t*>(nullptr),
tflite::testing::simple_golden_null_bias, golden_quantized,
tflite::testing::simple_output_dims, output_scale, output_zero_point,
kTfLiteActNone, output_data),
......
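For reference, a small standalone sketch (not from the patch) of the arithmetic behind the quantization parameters in the new SimpleTestQuantizedInt16 test: int16 activations are symmetric (zero point 0), the scale 128.0 / 65536 = 1/512 maps the int16 range to roughly [-64, 64), and the bias is quantized to int64 with the combined scale input_scale * weights_scale. The helper names below are illustrative.

#include <cmath>
#include <cstdint>

// Symmetric int16 activation quantization: zero point is 0; clamping to the
// int16 range is omitted for brevity.
int16_t QuantizeInt16(float value, float scale) {
  return static_cast<int16_t>(std::lround(value / scale));
}

// Bias is quantized to int64 with scale input_scale * weights_scale.
int64_t QuantizeBiasInt64(float value, float input_scale, float weights_scale) {
  return static_cast<int64_t>(std::lround(value / (input_scale * weights_scale)));
}

// Example: with scale = 128.0f / 65536 (= 1/512), the golden output 24.0f
// quantizes to 24 * 512 = 12288, well inside the int16 range.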