Unverified commit 48bfdf68, authored by felix-johnny, committed by GitHub

CMSIS-NN quantization specific registration for ADD

Registration for the int8 input data type is split out into its own function, Register_ADD_INT8().

This PR is part of the work towards the RFC:
https://github.com/tensorflow/tflite-micro/blob/main/tensorflow/lite/micro/docs/rfc/002_16x8_quantization_port.md

BUG=quantization-specific registration for ADD to reduce library size
Parent 4fe601b1
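For context, application code opts into the type-specific kernel at registration time. A minimal sketch of the intended usage (this assumes MicroMutableOpResolver::AddAdd() accepts a registration override, as kernels already ported under this RFC do; that resolver change is not part of this diff):

    #include "tensorflow/lite/micro/kernels/add.h"
    #include "tensorflow/lite/micro/micro_mutable_op_resolver.h"

    TfLiteStatus RegisterOps(tflite::MicroMutableOpResolver<1>& resolver) {
      // Assumed API: AddAdd() taking a registration override. Only the int8
      // eval path is referenced here, so the float/int16 ADD code becomes
      // dead code that the linker can strip.
      return resolver.AddAdd(tflite::Register_ADD_INT8());
    }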
-/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
+/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
@@ -59,6 +59,19 @@ TfLiteStatus CalculateOpDataAdd(TfLiteContext* context, TfLiteAddParams* params,
TfLiteStatus AddPrepare(TfLiteContext* context, TfLiteNode* node);
+// Generic must define registration function.
+TfLiteRegistration Register_ADD();
+
+#if defined(CMSIS_NN)
+TfLiteRegistration Register_ADD_INT8();
+
+TfLiteRegistration Register_ADD_INT16();
+#else
+// Fallback registration
+inline TfLiteRegistration Register_ADD_INT8() { return Register_ADD(); }
+
+inline TfLiteRegistration Register_ADD_INT16() { return Register_ADD(); }
+#endif
} // namespace tflite
#endif // TENSORFLOW_LITE_MICRO_KERNELS_ADD_H_
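The #else branch is what keeps call sites portable: without CMSIS_NN, the type-specific names still exist and simply forward to the generic kernel, so application code never needs an #ifdef of its own. For instance:

    // Builds either way: with -DCMSIS_NN this is the optimized int8 kernel,
    // otherwise the inline fallback returns the generic Register_ADD().
    TfLiteRegistration add_registration = tflite::Register_ADD_INT8();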
-/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
@@ -102,6 +102,91 @@ TfLiteStatus CalculateOpData(TfLiteContext* context, TfLiteAddParams* params,
  return kTfLiteOk;
}

+void UpdateOpParams(tflite::ArithmeticParams& op_params, const OpData* data) {
+  op_params.left_shift = data->left_shift;
+  op_params.input1_offset = data->input1_offset;
+  op_params.input1_multiplier = data->input1_multiplier;
+  op_params.input1_shift = data->input1_shift;
+  op_params.input2_offset = data->input2_offset;
+  op_params.input2_multiplier = data->input2_multiplier;
+  op_params.input2_shift = data->input2_shift;
+  op_params.output_offset = data->output_offset;
+  op_params.output_multiplier = data->output_multiplier;
+  op_params.output_shift = data->output_shift;
+  SetActivationParams(data->output_activation_min, data->output_activation_max,
+                      &op_params);
+}
+
+TfLiteStatus EvalAddQuantizedInt8(TfLiteContext* context, TfLiteNode* node,
+                                  TfLiteAddParams* params, const OpData* data,
+                                  const TfLiteEvalTensor* input1,
+                                  const TfLiteEvalTensor* input2,
+                                  TfLiteEvalTensor* output) {
+  tflite::ArithmeticParams op_params;
+  UpdateOpParams(op_params, data);
+  bool need_broadcast = reference_ops::ProcessBroadcastShapes(
+      tflite::micro::GetTensorShape(input1),
+      tflite::micro::GetTensorShape(input2), &op_params);
+  if (need_broadcast) {
+    reference_integer_ops::BroadcastAdd4DSlow(
+        op_params, tflite::micro::GetTensorShape(input1),
+        tflite::micro::GetTensorData<int8_t>(input1),
+        tflite::micro::GetTensorShape(input2),
+        tflite::micro::GetTensorData<int8_t>(input2),
+        tflite::micro::GetTensorShape(output),
+        tflite::micro::GetTensorData<int8_t>(output));
+  } else {
+    arm_elementwise_add_s8(
+        tflite::micro::GetTensorData<int8_t>(input1),
+        tflite::micro::GetTensorData<int8_t>(input2), op_params.input1_offset,
+        op_params.input1_multiplier, op_params.input1_shift,
+        op_params.input2_offset, op_params.input2_multiplier,
+        op_params.input2_shift, op_params.left_shift,
+        tflite::micro::GetTensorData<int8_t>(output), op_params.output_offset,
+        op_params.output_multiplier, op_params.output_shift,
+        op_params.quantized_activation_min, op_params.quantized_activation_max,
+        MatchingElementsSize(tflite::micro::GetTensorShape(input1),
+                             tflite::micro::GetTensorShape(input2),
+                             tflite::micro::GetTensorShape(output)));
+  }
+  return kTfLiteOk;
+}
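The long argument list mirrors TFLite's standard fixed-point add: each input is offset, shifted up by left_shift, rescaled to a common resolution by its multiplier/shift pair, summed, then requantized to the output scale and clamped. A scalar sketch of that per-element arithmetic, assuming gemmlowp-style rounding (helper names here are invented for illustration, and the saturation corner case is simplified):

    #include <algorithm>
    #include <cstdint>

    #include "tensorflow/lite/kernels/internal/types.h"

    // Rounding-doubling high multiply followed by a rounding right shift:
    // the usual "multiply by quantized multiplier" step (illustrative only).
    int32_t MulByQuantMult(int32_t x, int32_t multiplier, int shift) {
      const int left = shift > 0 ? shift : 0;
      const int right = shift > 0 ? 0 : -shift;
      const int64_t prod = static_cast<int64_t>(x * (1 << left)) * multiplier;
      const int64_t nudge = prod >= 0 ? (1LL << 30) : (1 - (1LL << 30));
      int32_t v = static_cast<int32_t>((prod + nudge) >> 31);
      if (right > 0) {
        // Round-to-nearest shift, gemmlowp-style.
        const int32_t mask = (1 << right) - 1;
        const int32_t remainder = v & mask;
        const int32_t threshold = (mask >> 1) + (v < 0 ? 1 : 0);
        v = (v >> right) + (remainder > threshold ? 1 : 0);
      }
      return v;
    }

    // One element of the quantized add, mirroring the arguments passed to
    // arm_elementwise_add_s8 above.
    int8_t AddOneElement(int8_t q1, int8_t q2,
                         const tflite::ArithmeticParams& p) {
      const int32_t x1 = (q1 + p.input1_offset) * (1 << p.left_shift);
      const int32_t x2 = (q2 + p.input2_offset) * (1 << p.left_shift);
      const int32_t s1 = MulByQuantMult(x1, p.input1_multiplier, p.input1_shift);
      const int32_t s2 = MulByQuantMult(x2, p.input2_multiplier, p.input2_shift);
      int32_t out = MulByQuantMult(s1 + s2, p.output_multiplier, p.output_shift);
      out += p.output_offset;
      out = std::max(out, p.quantized_activation_min);
      out = std::min(out, p.quantized_activation_max);
      return static_cast<int8_t>(out);
    }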
+TfLiteStatus EvalAddQuantizedInt16(TfLiteContext* context, TfLiteNode* node,
+                                   TfLiteAddParams* params, const OpData* data,
+                                   const TfLiteEvalTensor* input1,
+                                   const TfLiteEvalTensor* input2,
+                                   TfLiteEvalTensor* output) {
+  tflite::ArithmeticParams op_params;
+  UpdateOpParams(op_params, data);
+  bool need_broadcast = reference_ops::ProcessBroadcastShapes(
+      tflite::micro::GetTensorShape(input1),
+      tflite::micro::GetTensorShape(input2), &op_params);
+  if (need_broadcast) {
+    reference_ops::BroadcastAdd4DSlow(
+        op_params, tflite::micro::GetTensorShape(input1),
+        tflite::micro::GetTensorData<int16_t>(input1),
+        tflite::micro::GetTensorShape(input2),
+        tflite::micro::GetTensorData<int16_t>(input2),
+        tflite::micro::GetTensorShape(output),
+        tflite::micro::GetTensorData<int16_t>(output));
+  } else {
+    reference_ops::Add(op_params, tflite::micro::GetTensorShape(input1),
+                       tflite::micro::GetTensorData<int16_t>(input1),
+                       tflite::micro::GetTensorShape(input2),
+                       tflite::micro::GetTensorData<int16_t>(input2),
+                       tflite::micro::GetTensorShape(output),
+                       tflite::micro::GetTensorData<int16_t>(output), false);
+  }
+  return kTfLiteOk;
+}
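Note the asymmetry between the two helpers: the int8 path hands the common non-broadcast case to CMSIS-NN, while this int16 path uses the reference kernels in both branches, so for int16 the split mainly buys dead-code elimination rather than a faster inner loop. In both helpers the branch is chosen by ProcessBroadcastShapes, roughly as in this illustration (shapes invented for the example):

    #include "tensorflow/lite/kernels/internal/reference/process_broadcast_shapes.h"
    #include "tensorflow/lite/kernels/internal/types.h"

    // Broadcast-compatible but unequal shapes select the slow path; identical
    // shapes select the fast elementwise path.
    bool NeedsBroadcast() {
      tflite::RuntimeShape activations({1, 8, 8, 16});
      tflite::RuntimeShape per_channel({1, 1, 1, 16});  // e.g. per-channel bias
      tflite::ArithmeticParams params;
      // Returns true here: the shapes differ but broadcast against each other.
      return tflite::reference_ops::ProcessBroadcastShapes(activations,
                                                           per_channel, &params);
    }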
void EvalAddFloat(TfLiteContext* context, TfLiteNode* node,
                  TfLiteAddParams* params, const OpData* data,
                  const TfLiteEvalTensor* input1,
@@ -132,68 +217,14 @@ TfLiteStatus EvalAddQuantized(TfLiteContext* context, TfLiteNode* node,
                              const TfLiteEvalTensor* input1,
                              const TfLiteEvalTensor* input2,
                              TfLiteEvalTensor* output) {
-  tflite::ArithmeticParams op_params;
-  op_params.left_shift = data->left_shift;
-  op_params.input1_offset = data->input1_offset;
-  op_params.input1_multiplier = data->input1_multiplier;
-  op_params.input1_shift = data->input1_shift;
-  op_params.input2_offset = data->input2_offset;
-  op_params.input2_multiplier = data->input2_multiplier;
-  op_params.input2_shift = data->input2_shift;
-  op_params.output_offset = data->output_offset;
-  op_params.output_multiplier = data->output_multiplier;
-  op_params.output_shift = data->output_shift;
-  SetActivationParams(data->output_activation_min, data->output_activation_max,
-                      &op_params);
-
-  bool need_broadcast = reference_ops::ProcessBroadcastShapes(
-      tflite::micro::GetTensorShape(input1),
-      tflite::micro::GetTensorShape(input2), &op_params);
-
  switch (output->type) {
    case kTfLiteInt8: {
-      if (need_broadcast) {
-        reference_integer_ops::BroadcastAdd4DSlow(
-            op_params, tflite::micro::GetTensorShape(input1),
-            tflite::micro::GetTensorData<int8_t>(input1),
-            tflite::micro::GetTensorShape(input2),
-            tflite::micro::GetTensorData<int8_t>(input2),
-            tflite::micro::GetTensorShape(output),
-            tflite::micro::GetTensorData<int8_t>(output));
-      } else {
-        arm_elementwise_add_s8(
-            tflite::micro::GetTensorData<int8_t>(input1),
-            tflite::micro::GetTensorData<int8_t>(input2),
-            op_params.input1_offset, op_params.input1_multiplier,
-            op_params.input1_shift, op_params.input2_offset,
-            op_params.input2_multiplier, op_params.input2_shift,
-            op_params.left_shift, tflite::micro::GetTensorData<int8_t>(output),
-            op_params.output_offset, op_params.output_multiplier,
-            op_params.output_shift, op_params.quantized_activation_min,
-            op_params.quantized_activation_max,
-            MatchingElementsSize(tflite::micro::GetTensorShape(input1),
-                                 tflite::micro::GetTensorShape(input2),
-                                 tflite::micro::GetTensorShape(output)));
-      }
+      EvalAddQuantizedInt8(context, node, params, data, input1, input2, output);
      break;
    }
    case kTfLiteInt16: {
-      if (need_broadcast) {
-        reference_ops::BroadcastAdd4DSlow(
-            op_params, tflite::micro::GetTensorShape(input1),
-            tflite::micro::GetTensorData<int16_t>(input1),
-            tflite::micro::GetTensorShape(input2),
-            tflite::micro::GetTensorData<int16_t>(input2),
-            tflite::micro::GetTensorShape(output),
-            tflite::micro::GetTensorData<int16_t>(output));
-      } else {
-        reference_ops::Add(op_params, tflite::micro::GetTensorShape(input1),
-                           tflite::micro::GetTensorData<int16_t>(input1),
-                           tflite::micro::GetTensorShape(input2),
-                           tflite::micro::GetTensorData<int16_t>(input2),
-                           tflite::micro::GetTensorShape(output),
-                           tflite::micro::GetTensorData<int16_t>(output),
-                           false);
-      }
+      EvalAddQuantizedInt16(context, node, params, data, input1, input2,
+                            output);
      break;
    }
    default:
@@ -268,8 +299,56 @@ TfLiteStatus EvalAdd(TfLiteContext* context, TfLiteNode* node) {
  return kTfLiteOk;
}

+TfLiteStatus EvalAddInt8(TfLiteContext* context, TfLiteNode* node) {
+  auto* params = reinterpret_cast<TfLiteAddParams*>(node->builtin_data);
+  const TfLiteEvalTensor* input1 =
+      tflite::micro::GetEvalInput(context, node, kInputTensor1);
+  const TfLiteEvalTensor* input2 =
+      tflite::micro::GetEvalInput(context, node, kInputTensor2);
+  TfLiteEvalTensor* output =
+      tflite::micro::GetEvalOutput(context, node, kOutputTensor);
+
+  TFLITE_DCHECK(node->user_data != nullptr);
+  TFLITE_DCHECK(output->type == kTfLiteInt8);
+  const OpData* data = static_cast<const OpData*>(node->user_data);
+
+  TF_LITE_ENSURE_OK(context, EvalAddQuantizedInt8(context, node, params, data,
+                                                  input1, input2, output));
+  return kTfLiteOk;
+}
+
+TfLiteStatus EvalAddInt16(TfLiteContext* context, TfLiteNode* node) {
+  auto* params = reinterpret_cast<TfLiteAddParams*>(node->builtin_data);
+  const TfLiteEvalTensor* input1 =
+      tflite::micro::GetEvalInput(context, node, kInputTensor1);
+  const TfLiteEvalTensor* input2 =
+      tflite::micro::GetEvalInput(context, node, kInputTensor2);
+  TfLiteEvalTensor* output =
+      tflite::micro::GetEvalOutput(context, node, kOutputTensor);
+
+  TFLITE_DCHECK(node->user_data != nullptr);
+  TFLITE_DCHECK(output->type == kTfLiteInt16);
+  const OpData* data = static_cast<const OpData*>(node->user_data);
+
+  TF_LITE_ENSURE_OK(context, EvalAddQuantizedInt16(context, node, params, data,
+                                                   input1, input2, output));
+  return kTfLiteOk;
+}
TfLiteRegistration Register_ADD() {
  return tflite::micro::RegisterOp(InitAdd, PrepareAdd, EvalAdd);
}

+TfLiteRegistration Register_ADD_INT8() {
+  return tflite::micro::RegisterOp(InitAdd, PrepareAdd, EvalAddInt8);
+}
+
+TfLiteRegistration Register_ADD_INT16() {
+  return tflite::micro::RegisterOp(InitAdd, PrepareAdd, EvalAddInt16);
+}
} // namespace tflite
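One caveat on usage: a resolver holds a single kernel per builtin op, so a graph containing both int8 and int16 ADD nodes is not covered by either type-specific variant. The generic registration remains the right choice there (same assumed resolver API as in the sketch near the top of this page):

    TfLiteStatus RegisterOpsMixedTypes(
        tflite::MicroMutableOpResolver<1>& resolver) {
      // Generic kernel: EvalAdd dispatches on the output tensor type at
      // Invoke() time, so float, int8, and int16 ADD nodes all work in one
      // graph, at the cost of linking every eval path.
      return resolver.AddAdd(tflite::Register_ADD());
    }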
@@ -24,6 +24,7 @@ limitations under the License.
#include "tensorflow/lite/kernels/internal/compatibility.h"
#include "tensorflow/lite/kernels/op_macros.h"
#include "tensorflow/lite/micro/compatibility.h"
#include "tensorflow/lite/micro/kernels/add.h"
#include "tensorflow/lite/micro/kernels/conv.h"
#include "tensorflow/lite/micro/kernels/depthwise_conv.h"
#include "tensorflow/lite/micro/kernels/ethosu.h"