未验证 提交 36cd2f85 编写于 作者: F felix-johnny 提交者: GitHub

CMSIS-NN Quantization specific registration for SVDF

int8 input data type registration is split into its own
function.

The PR is a work towards RFC
https://github.com/tensorflow/tflite-micro/blob/main/tensorflow/lite/micro/docs/rfc/002_16x8_quantization_port.md

BUG=quantization specific registration for SVDF to reduce library size

Change-Id: I0a77d98cd3fa643b530386a6ca59e41fdf2b3ce9
上级 47e6d982
......@@ -188,10 +188,43 @@ TfLiteStatus EvalSvdf(TfLiteContext* context, TfLiteNode* node) {
return kTfLiteOk;
}
TfLiteStatus EvalSvdfInt8(TfLiteContext* context, TfLiteNode* node) {
auto* params = reinterpret_cast<TfLiteSVDFParams*>(node->builtin_data);
TFLITE_DCHECK(node->user_data != nullptr);
const OpDataSvdf& data = *(static_cast<const OpDataSvdf*>(node->user_data));
const TfLiteEvalTensor* input =
tflite::micro::GetEvalInput(context, node, kSvdfInputTensor);
const TfLiteEvalTensor* weights_feature =
tflite::micro::GetEvalInput(context, node, kSvdfWeightsFeatureTensor);
const TfLiteEvalTensor* weights_time =
tflite::micro::GetEvalInput(context, node, kSvdfWeightsTimeTensor);
const TfLiteEvalTensor* bias =
(NumInputs(node) == 5)
? tflite::micro::GetEvalInput(context, node, kSvdfBiasTensor)
: nullptr;
TfLiteEvalTensor* activation_state = tflite::micro::GetMutableEvalInput(
context, node, kSvdfInputActivationStateTensor);
TfLiteEvalTensor* output =
tflite::micro::GetEvalOutput(context, node, kSvdfOutputTensor);
TFLITE_DCHECK((weights_time->type == kTfLiteInt8) ||
(weights_time->type == kTfLiteInt16));
// Because of the TODO mentioned below, the int16 weight data type is not
// split into a seperate registration.
// TODO(#523): remove 16-bit code when no longer needed.
return EvalIntegerSVDF(context, node, input, weights_feature, weights_time,
bias, params, activation_state, output, data);
}
} // namespace
TfLiteRegistration Register_SVDF() {
return tflite::micro::RegisterOp(Init, PrepareSvdf, EvalSvdf);
}
TfLiteRegistration Register_SVDF_INT8() {
return tflite::micro::RegisterOp(Init, PrepareSvdf, EvalSvdfInt8);
}
} // namespace tflite
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
......@@ -82,7 +82,7 @@ TfLiteStatus PrepareSvdf(TfLiteContext* context, TfLiteNode* node);
// (reference or optimized) must define this function.
TfLiteRegistration Register_SVDF();
#if defined(HEXAGON)
#if defined(HEXAGON) || defined(CMSIS_NN)
TfLiteRegistration Register_SVDF_INT8();
#else
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册