From 48bfdf6824b6a7134c74727f549bf9ff3fed2018 Mon Sep 17 00:00:00 2001 From: felix-johnny <48442848+felix-johnny@users.noreply.github.com> Date: Sat, 13 Aug 2022 03:36:01 +0200 Subject: [PATCH] CMSIS-NN quantization specific registration for ADD int8 input data type registration is split into its own function. The PR is a work towards RFC https://github.com/tensorflow/tflite-micro/blob/main/tensorflow/lite/micro/docs/rfc/002_16x8_quantization_port.md BUG=quantization specific registration for add to reduce library size --- tensorflow/lite/micro/kernels/add.h | 15 +- tensorflow/lite/micro/kernels/cmsis_nn/add.cc | 195 ++++++++++++------ .../lite/micro/micro_mutable_op_resolver.h | 1 + 3 files changed, 152 insertions(+), 59 deletions(-) diff --git a/tensorflow/lite/micro/kernels/add.h b/tensorflow/lite/micro/kernels/add.h index 88526153..e2e5d23b 100644 --- a/tensorflow/lite/micro/kernels/add.h +++ b/tensorflow/lite/micro/kernels/add.h @@ -1,4 +1,4 @@ -/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -59,6 +59,19 @@ TfLiteStatus CalculateOpDataAdd(TfLiteContext* context, TfLiteAddParams* params, TfLiteStatus AddPrepare(TfLiteContext* context, TfLiteNode* node); +// Generic must define registration function. +TfLiteRegistration Register_ADD(); + +#if defined(CMSIS_NN) +TfLiteRegistration Register_ADD_INT8(); + +TfLiteRegistration Register_ADD_INT16(); +#else +// Fallback registration +inline TfLiteRegistration Register_ADD_INT8() { return Register_ADD(); } + +inline TfLiteRegistration Register_ADD_INT16() { return Register_ADD(); } +#endif } // namespace tflite #endif // TENSORFLOW_LITE_MICRO_KERNELS_ADD_H_ diff --git a/tensorflow/lite/micro/kernels/cmsis_nn/add.cc b/tensorflow/lite/micro/kernels/cmsis_nn/add.cc index a205ff6c..a0e5fca0 100644 --- a/tensorflow/lite/micro/kernels/cmsis_nn/add.cc +++ b/tensorflow/lite/micro/kernels/cmsis_nn/add.cc @@ -1,4 +1,4 @@ -/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -102,6 +102,91 @@ TfLiteStatus CalculateOpData(TfLiteContext* context, TfLiteAddParams* params, return kTfLiteOk; } +void UpdateOpParams(tflite::ArithmeticParams& op_params, const OpData* data) { + op_params.left_shift = data->left_shift; + op_params.input1_offset = data->input1_offset; + op_params.input1_multiplier = data->input1_multiplier; + op_params.input1_shift = data->input1_shift; + op_params.input2_offset = data->input2_offset; + op_params.input2_multiplier = data->input2_multiplier; + op_params.input2_shift = data->input2_shift; + op_params.output_offset = data->output_offset; + op_params.output_multiplier = data->output_multiplier; + op_params.output_shift = data->output_shift; + SetActivationParams(data->output_activation_min, data->output_activation_max, + &op_params); +} + +TfLiteStatus EvalAddQuantizedInt8(TfLiteContext* context, TfLiteNode* node, + TfLiteAddParams* params, const OpData* data, + const TfLiteEvalTensor* input1, + const TfLiteEvalTensor* input2, + TfLiteEvalTensor* output) { + tflite::ArithmeticParams op_params; + UpdateOpParams(op_params, data); + + bool need_broadcast = reference_ops::ProcessBroadcastShapes( + tflite::micro::GetTensorShape(input1), + tflite::micro::GetTensorShape(input2), &op_params); + + if (need_broadcast) { + reference_integer_ops::BroadcastAdd4DSlow( + op_params, tflite::micro::GetTensorShape(input1), + tflite::micro::GetTensorData(input1), + tflite::micro::GetTensorShape(input2), + tflite::micro::GetTensorData(input2), + tflite::micro::GetTensorShape(output), + tflite::micro::GetTensorData(output)); + } else { + arm_elementwise_add_s8( + tflite::micro::GetTensorData(input1), + tflite::micro::GetTensorData(input2), op_params.input1_offset, + op_params.input1_multiplier, op_params.input1_shift, + op_params.input2_offset, op_params.input2_multiplier, + op_params.input2_shift, op_params.left_shift, + tflite::micro::GetTensorData(output), op_params.output_offset, + op_params.output_multiplier, op_params.output_shift, + op_params.quantized_activation_min, op_params.quantized_activation_max, + MatchingElementsSize(tflite::micro::GetTensorShape(input1), + tflite::micro::GetTensorShape(input2), + tflite::micro::GetTensorShape(output))); + } + + return kTfLiteOk; +} + +TfLiteStatus EvalAddQuantizedInt16(TfLiteContext* context, TfLiteNode* node, + TfLiteAddParams* params, const OpData* data, + const TfLiteEvalTensor* input1, + const TfLiteEvalTensor* input2, + TfLiteEvalTensor* output) { + tflite::ArithmeticParams op_params; + UpdateOpParams(op_params, data); + + bool need_broadcast = reference_ops::ProcessBroadcastShapes( + tflite::micro::GetTensorShape(input1), + tflite::micro::GetTensorShape(input2), &op_params); + + if (need_broadcast) { + reference_ops::BroadcastAdd4DSlow( + op_params, tflite::micro::GetTensorShape(input1), + tflite::micro::GetTensorData(input1), + tflite::micro::GetTensorShape(input2), + tflite::micro::GetTensorData(input2), + tflite::micro::GetTensorShape(output), + tflite::micro::GetTensorData(output)); + } else { + reference_ops::Add(op_params, tflite::micro::GetTensorShape(input1), + tflite::micro::GetTensorData(input1), + tflite::micro::GetTensorShape(input2), + tflite::micro::GetTensorData(input2), + tflite::micro::GetTensorShape(output), + tflite::micro::GetTensorData(output), false); + } + + return kTfLiteOk; +} + void EvalAddFloat(TfLiteContext* context, TfLiteNode* node, TfLiteAddParams* params, const OpData* data, const TfLiteEvalTensor* input1, @@ -132,68 +217,14 @@ TfLiteStatus EvalAddQuantized(TfLiteContext* context, TfLiteNode* node, const TfLiteEvalTensor* input1, const TfLiteEvalTensor* input2, TfLiteEvalTensor* output) { - tflite::ArithmeticParams op_params; - op_params.left_shift = data->left_shift; - op_params.input1_offset = data->input1_offset; - op_params.input1_multiplier = data->input1_multiplier; - op_params.input1_shift = data->input1_shift; - op_params.input2_offset = data->input2_offset; - op_params.input2_multiplier = data->input2_multiplier; - op_params.input2_shift = data->input2_shift; - op_params.output_offset = data->output_offset; - op_params.output_multiplier = data->output_multiplier; - op_params.output_shift = data->output_shift; - SetActivationParams(data->output_activation_min, data->output_activation_max, - &op_params); - bool need_broadcast = reference_ops::ProcessBroadcastShapes( - tflite::micro::GetTensorShape(input1), - tflite::micro::GetTensorShape(input2), &op_params); - switch (output->type) { case kTfLiteInt8: { - if (need_broadcast) { - reference_integer_ops::BroadcastAdd4DSlow( - op_params, tflite::micro::GetTensorShape(input1), - tflite::micro::GetTensorData(input1), - tflite::micro::GetTensorShape(input2), - tflite::micro::GetTensorData(input2), - tflite::micro::GetTensorShape(output), - tflite::micro::GetTensorData(output)); - } else { - arm_elementwise_add_s8( - tflite::micro::GetTensorData(input1), - tflite::micro::GetTensorData(input2), - op_params.input1_offset, op_params.input1_multiplier, - op_params.input1_shift, op_params.input2_offset, - op_params.input2_multiplier, op_params.input2_shift, - op_params.left_shift, tflite::micro::GetTensorData(output), - op_params.output_offset, op_params.output_multiplier, - op_params.output_shift, op_params.quantized_activation_min, - op_params.quantized_activation_max, - MatchingElementsSize(tflite::micro::GetTensorShape(input1), - tflite::micro::GetTensorShape(input2), - tflite::micro::GetTensorShape(output))); - } + EvalAddQuantizedInt8(context, node, params, data, input1, input2, output); break; } case kTfLiteInt16: { - if (need_broadcast) { - reference_ops::BroadcastAdd4DSlow( - op_params, tflite::micro::GetTensorShape(input1), - tflite::micro::GetTensorData(input1), - tflite::micro::GetTensorShape(input2), - tflite::micro::GetTensorData(input2), - tflite::micro::GetTensorShape(output), - tflite::micro::GetTensorData(output)); - } else { - reference_ops::Add(op_params, tflite::micro::GetTensorShape(input1), - tflite::micro::GetTensorData(input1), - tflite::micro::GetTensorShape(input2), - tflite::micro::GetTensorData(input2), - tflite::micro::GetTensorShape(output), - tflite::micro::GetTensorData(output), - false); - } + EvalAddQuantizedInt16(context, node, params, data, input1, input2, + output); break; } default: @@ -268,8 +299,56 @@ TfLiteStatus EvalAdd(TfLiteContext* context, TfLiteNode* node) { return kTfLiteOk; } +TfLiteStatus EvalAddInt8(TfLiteContext* context, TfLiteNode* node) { + auto* params = reinterpret_cast(node->builtin_data); + + const TfLiteEvalTensor* input1 = + tflite::micro::GetEvalInput(context, node, kInputTensor1); + const TfLiteEvalTensor* input2 = + tflite::micro::GetEvalInput(context, node, kInputTensor2); + TfLiteEvalTensor* output = + tflite::micro::GetEvalOutput(context, node, kOutputTensor); + + TFLITE_DCHECK(node->user_data != nullptr); + TFLITE_DCHECK(output->type == kTfLiteInt8); + const OpData* data = static_cast(node->user_data); + + TF_LITE_ENSURE_OK(context, EvalAddQuantizedInt8(context, node, params, data, + input1, input2, output)); + + return kTfLiteOk; +} + +TfLiteStatus EvalAddInt16(TfLiteContext* context, TfLiteNode* node) { + auto* params = reinterpret_cast(node->builtin_data); + + const TfLiteEvalTensor* input1 = + tflite::micro::GetEvalInput(context, node, kInputTensor1); + const TfLiteEvalTensor* input2 = + tflite::micro::GetEvalInput(context, node, kInputTensor2); + TfLiteEvalTensor* output = + tflite::micro::GetEvalOutput(context, node, kOutputTensor); + + TFLITE_DCHECK(node->user_data != nullptr); + TFLITE_DCHECK(output->type == kTfLiteInt16); + const OpData* data = static_cast(node->user_data); + + TF_LITE_ENSURE_OK(context, EvalAddQuantizedInt16(context, node, params, data, + input1, input2, output)); + + return kTfLiteOk; +} + TfLiteRegistration Register_ADD() { return tflite::micro::RegisterOp(InitAdd, PrepareAdd, EvalAdd); } +TfLiteRegistration Register_ADD_INT8() { + return tflite::micro::RegisterOp(InitAdd, PrepareAdd, EvalAddInt8); +} + +TfLiteRegistration Register_ADD_INT16() { + return tflite::micro::RegisterOp(InitAdd, PrepareAdd, EvalAddInt16); +} + } // namespace tflite diff --git a/tensorflow/lite/micro/micro_mutable_op_resolver.h b/tensorflow/lite/micro/micro_mutable_op_resolver.h index 78cae3f4..1e005c5a 100644 --- a/tensorflow/lite/micro/micro_mutable_op_resolver.h +++ b/tensorflow/lite/micro/micro_mutable_op_resolver.h @@ -24,6 +24,7 @@ limitations under the License. #include "tensorflow/lite/kernels/internal/compatibility.h" #include "tensorflow/lite/kernels/op_macros.h" #include "tensorflow/lite/micro/compatibility.h" +#include "tensorflow/lite/micro/kernels/add.h" #include "tensorflow/lite/micro/kernels/conv.h" #include "tensorflow/lite/micro/kernels/depthwise_conv.h" #include "tensorflow/lite/micro/kernels/ethosu.h" -- GitLab