// unidirectional_sequence_lstm.cc
/* Copyright 2023 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// Integer version of unidirectional sequence LSTM. Only the standard LSTM
// (defined in the keras LSTM layer, e.g., no peephole etc.) is supported here.
// Currently used by the 8 bits activation case only, except for fallbacks.

#include <algorithm>
#include <limits>

#include "Include/arm_nnfunctions.h"
#include "tensorflow/lite/kernels/internal/quantization_util.h"
#include "tensorflow/lite/kernels/kernel_util.h"
#include "tensorflow/lite/micro/kernels/fully_connected.h"
#include "tensorflow/lite/micro/kernels/kernel_util.h"
#include "tensorflow/lite/micro/kernels/lstm_eval.h"
#include "tensorflow/lite/micro/kernels/lstm_shared.h"
#include "tensorflow/lite/micro/kernels/micro_tensor_utils.h"

namespace tflite {

namespace {

// Per-node state for this kernel. The storage is obtained in
// UnidirectionalSequenceLstmInit() and populated in
// UnidirectionalSequenceLstmPrepare().
struct OpData {
  // Parameters consumed by the generic LSTM helpers (size info, cell state
  // info, scratch buffer indices, gate parameters).
  OpDataLSTM params_ref;
  // Parameters passed to the CMSIS-NN LSTM routine on the int8 activation
  // path; filled in by CalculateOpData().
  cmsis_nn_lstm_params params_cmsis_nn;
};

/*Helper Functions*/
// Pre-calculates "bias + zero_point * weight" for one gate and stores the
// result in a newly allocated persistent buffer.
//
// Args:
//   context:       kernel context, used for allocation and error reporting.
//   zero_point:    scalar handed to MatrixScalarMultiplyAccumulate together
//                  with the weights; skipped entirely when it is 0.
//   weight_tensor: 2-D int8 weight tensor [row, col]. When nullptr the
//                  function is a no-op and `*output` is left untouched.
//   bias_tensor:   optional int32 bias; when nullptr the buffer starts at 0.
//                  NOTE(review): assumed to hold at least `row` elements --
//                  this function does not verify it.
//   output:        receives the pointer to the `row`-element result buffer.
//
// Returns kTfLiteOk on success, kTfLiteError when the weight tensor is not
// 2-D or the persistent allocation fails.
TfLiteStatus PrecomputeZeroPointTimesWeightWithBias(
    TfLiteContext* context, int32_t zero_point,
    const TfLiteTensor* weight_tensor, const TfLiteTensor* bias_tensor,
    int32_t** output) {
  if (weight_tensor == nullptr) {
    return kTfLiteOk;
  }

  const RuntimeShape weight_shape = GetTensorShape(weight_tensor);
  TF_LITE_ENSURE_EQ(context, weight_shape.DimensionsCount(), 2);
  const int row = weight_shape.Dims(0);
  const int col = weight_shape.Dims(1);

  TFLITE_DCHECK(context->AllocatePersistentBuffer != nullptr);
  int32_t* effective_bias = static_cast<int32_t*>(
      context->AllocatePersistentBuffer(context, row * sizeof(int32_t)));
  // Fail cleanly instead of writing through a null pointer when the arena is
  // exhausted.
  TF_LITE_ENSURE(context, effective_bias != nullptr);
  *output = effective_bias;

  // Seed the buffer with the bias (or zeros when there is none).
  if (bias_tensor == nullptr) {
    std::fill_n(effective_bias, row, int32_t{0});
  } else {
    const int32_t* bias = GetTensorData<int32_t>(bias_tensor);
    std::copy_n(bias, row, effective_bias);
  }

  // Accumulate the zero-point contribution of the weights on top of the bias.
  if (zero_point != 0) {
    const int8_t* weight = GetTensorData<int8_t>(weight_tensor);
    tflite::tensor_utils::MatrixScalarMultiplyAccumulate(
        weight, zero_point, row, col, effective_bias);
  }
  return kTfLiteOk;
}

// Decomposes `effective_scale` with QuantizeMultiplier() and stores the
// resulting multiplier/shift pair in `scaling`.
template <typename ScalingT>
void QuantizeIntoCmsisNnScaling(float effective_scale, ScalingT* scaling) {
  int shift = 0;
  QuantizeMultiplier(static_cast<double>(effective_scale),
                     &scaling->multiplier, &shift);
  scaling->shift = static_cast<int32_t>(shift);
}

// Fills `op_data->params_cmsis_nn` for the int8 activation path:
//   1. records the output state zero point,
//   2. rejects configurations this kernel does not handle (layer norm,
//      peephole, projection, CIFG),
//   3. pre-computes the effective biases (bias + zero_point * weight),
//   4. derives and quantizes the effective scales of every gate and of the
//      hidden state,
//   5. sets the hidden offset and the int16 activation range.
//
// Returns kTfLiteError when the input is not int8, the configuration is
// unsupported, the hidden intermediate tensor is missing or unquantized, a
// bias pre-computation fails, or the cell state scale is not accepted by
// CheckedLog2() / is larger than 2^-9.
TfLiteStatus CalculateOpData(TfLiteContext* context, TfLiteNode* node,
                             const LstmTensors& lstm_tensors, OpData* op_data) {
  const TfLiteTensor* input = lstm_tensors.GetInternalTensor(kLstmInputTensor);
  const TfLiteTensor* output_state =
      lstm_tensors.GetInternalTensor(tflite::kLstmOutputStateTensor);

  TF_LITE_ENSURE(context, input->type == kTfLiteInt8);

  cmsis_nn_lstm_params& cmsis_params = op_data->params_cmsis_nn;
  cmsis_params.output_state_offset = output_state->params.zero_point;

  // Input-to-gate weights.
  const TfLiteTensor* input_to_input_weights =
      lstm_tensors.GetInternalTensor(kLstmInputToInputWeightsTensor);
  const TfLiteTensor* input_to_forget_weights =
      lstm_tensors.GetInternalTensor(kLstmInputToForgetWeightsTensor);
  const TfLiteTensor* input_to_cell_weights =
      lstm_tensors.GetInternalTensor(kLstmInputToCellWeightsTensor);
  const TfLiteTensor* input_to_output_weights =
      lstm_tensors.GetInternalTensor(kLstmInputToOutputWeightsTensor);

  // Recurrent-to-gate weights.
  const TfLiteTensor* recurrent_to_input_weights =
      lstm_tensors.GetInternalTensor(kLstmRecurrentToInputWeightsTensor);
  const TfLiteTensor* recurrent_to_forget_weights =
      lstm_tensors.GetInternalTensor(kLstmRecurrentToForgetWeightsTensor);
  const TfLiteTensor* recurrent_to_cell_weights =
      lstm_tensors.GetInternalTensor(kLstmRecurrentToCellWeightsTensor);
  const TfLiteTensor* recurrent_to_output_weights =
      lstm_tensors.GetInternalTensor(kLstmRecurrentToOutputWeightsTensor);

  // Gate biases.
  const TfLiteTensor* input_gate_bias =
      lstm_tensors.GetInternalTensor(kLstmInputGateBiasTensor);
  const TfLiteTensor* forget_gate_bias =
      lstm_tensors.GetInternalTensor(kLstmForgetGateBiasTensor);
  const TfLiteTensor* cell_gate_bias =
      lstm_tensors.GetInternalTensor(kLstmCellGateBiasTensor);
  const TfLiteTensor* output_gate_bias =
      lstm_tensors.GetInternalTensor(kLstmOutputGateBiasTensor);

  const TfLiteTensor* cell_state =
      lstm_tensors.GetInternalTensor(kLstmCellStateTensor);

  // Tensors whose presence marks an unsupported LSTM variant.
  const TfLiteTensor* cell_to_output_weights =
      lstm_tensors.GetInternalTensor(kLstmCellToOutputWeightsTensor);
  const TfLiteTensor* forget_layer_norm_coefficients =
      lstm_tensors.GetInternalTensor(kLstmForgetLayerNormCoefficientsTensor);
  const TfLiteTensor* projection_weights =
      lstm_tensors.GetInternalTensor(kLstmProjectionWeightsTensor);

  const bool use_layer_norm = (forget_layer_norm_coefficients != nullptr);
  const bool use_peephole = (cell_to_output_weights != nullptr);
  const bool use_projection = (projection_weights != nullptr);
  const bool use_cifg = (input_to_input_weights == nullptr);
  const bool lstm_unsupported_config =
      use_layer_norm || use_peephole || use_projection || use_cifg;
  // Checked at runtime (not only in debug builds): with CIFG the
  // input-to-input weights are null and are dereferenced further down.
  TF_LITE_ENSURE(context, !lstm_unsupported_config);

  // Pre-calculate bias + zero_point * weight.
  int32_t* input_to_forget_effective_bias = nullptr;
  int32_t* recurrent_to_forget_effective_bias = nullptr;
  int32_t* input_to_cell_effective_bias = nullptr;
  int32_t* recurrent_to_cell_effective_bias = nullptr;
  int32_t* input_to_output_effective_bias = nullptr;
  int32_t* recurrent_to_output_effective_bias = nullptr;
  int32_t* input_to_input_effective_bias = nullptr;
  int32_t* recurrent_to_input_effective_bias = nullptr;

  // Both zero points are passed negated to the pre-computation.
  const int32_t output_state_zero_point = -cmsis_params.output_state_offset;
  const int32_t input_zero_point = -input->params.zero_point;

  // Forget gate.
  TF_LITE_ENSURE_OK(context,
                    PrecomputeZeroPointTimesWeightWithBias(
                        context, input_zero_point, input_to_forget_weights,
                        forget_gate_bias, &input_to_forget_effective_bias));
  TF_LITE_ENSURE_OK(context, PrecomputeZeroPointTimesWeightWithBias(
                                 context, output_state_zero_point,
                                 recurrent_to_forget_weights, nullptr,
                                 &recurrent_to_forget_effective_bias));

  // Modulation gate.
  TF_LITE_ENSURE_OK(context,
                    PrecomputeZeroPointTimesWeightWithBias(
                        context, input_zero_point, input_to_cell_weights,
                        cell_gate_bias, &input_to_cell_effective_bias));
  TF_LITE_ENSURE_OK(
      context, PrecomputeZeroPointTimesWeightWithBias(
                   context, output_state_zero_point, recurrent_to_cell_weights,
                   nullptr, &recurrent_to_cell_effective_bias));

  // Output gate.
  TF_LITE_ENSURE_OK(context,
                    PrecomputeZeroPointTimesWeightWithBias(
                        context, input_zero_point, input_to_output_weights,
                        output_gate_bias, &input_to_output_effective_bias));
  TF_LITE_ENSURE_OK(context, PrecomputeZeroPointTimesWeightWithBias(
                                 context, output_state_zero_point,
                                 recurrent_to_output_weights, nullptr,
                                 &recurrent_to_output_effective_bias));

  // Input gate. The calculation is only meaningful for non-cifg case.
  TF_LITE_ENSURE_OK(context,
                    PrecomputeZeroPointTimesWeightWithBias(
                        context, input_zero_point, input_to_input_weights,
                        input_gate_bias, &input_to_input_effective_bias));
  TF_LITE_ENSURE_OK(
      context, PrecomputeZeroPointTimesWeightWithBias(
                   context, output_state_zero_point, recurrent_to_input_weights,
                   nullptr, &recurrent_to_input_effective_bias));

  cmsis_params.i2f_effective_bias = input_to_forget_effective_bias;
  cmsis_params.r2f_effective_bias = recurrent_to_forget_effective_bias;
  cmsis_params.i2c_effective_bias = input_to_cell_effective_bias;
  cmsis_params.r2c_effective_bias = recurrent_to_cell_effective_bias;
  cmsis_params.i2o_effective_bias = input_to_output_effective_bias;
  cmsis_params.r2o_effective_bias = recurrent_to_output_effective_bias;
  cmsis_params.i2i_effective_bias = input_to_input_effective_bias;
  cmsis_params.r2i_effective_bias = recurrent_to_input_effective_bias;

  // Q3.12 for activation functions: all four gate intermediates share the
  // scale 2^-12 (with zero point 0).
  const float gate_scale = std::pow(2.0f, -12.0f);

  // Read the quantization of the hidden-state intermediate (index 4).
  // In the absence of projection, hidden becomes output and this intermediate
  // is ignored.
  MicroContext* micro_context = GetMicroContext(context);
  TfLiteTensor* hidden = micro_context->AllocateTempIntermediateTensor(node, 4);
  TF_LITE_ENSURE(context, hidden != nullptr);
  const bool hidden_is_quantized =
      (hidden->quantization.type != kTfLiteNoQuantization);
  float hidden_scale = 1.0f;
  int32_t hidden_zero_point = 0;
  if (hidden_is_quantized) {
    const auto* hidden_params = static_cast<const TfLiteAffineQuantization*>(
        hidden->quantization.params);
    hidden_scale = hidden_params->scale->data[0];
    hidden_zero_point = hidden_params->zero_point->data[0];
  }
  // Release the temp tensor before any early return so it cannot leak.
  micro_context->DeallocateTempTfLiteTensor(hidden);
  TF_LITE_ENSURE(context, hidden_is_quantized);

  // Cell state: its scale must be accepted by CheckedLog2() and must not
  // exceed 2^-9; the exponent is forwarded as the cell state shift.
  int cell_scale = 1;
  TF_LITE_ENSURE(context, CheckedLog2(cell_state->params.scale, &cell_scale));
  TF_LITE_ENSURE(context, cell_scale <= -9);
  cmsis_params.cell_state_shift = cell_scale;

  const float input_scale = input->params.scale;
  const float output_state_scale = output_state->params.scale;

  // Effective scale of each matmul: weight scale * activation scale, divided
  // by the scale of the gate intermediate.
  const float effective_input_to_input_scale =
      input_to_input_weights->params.scale * input_scale / gate_scale;
  const float effective_recurrent_to_input_scale =
      recurrent_to_input_weights->params.scale * output_state_scale /
      gate_scale;

  const float effective_input_to_forget_scale =
      input_to_forget_weights->params.scale * input_scale / gate_scale;
  const float effective_recurrent_to_forget_scale =
      recurrent_to_forget_weights->params.scale * output_state_scale /
      gate_scale;

  const float effective_input_to_cell_scale =
      input_to_cell_weights->params.scale * input_scale / gate_scale;
  const float effective_recurrent_to_cell_scale =
      recurrent_to_cell_weights->params.scale * output_state_scale /
      gate_scale;

  const float effective_input_to_output_scale =
      input_to_output_weights->params.scale * input_scale / gate_scale;
  const float effective_recurrent_to_output_scale =
      recurrent_to_output_weights->params.scale * output_state_scale /
      gate_scale;

  // Peephole connections are rejected above, so the cell-to-gate scales keep
  // the neutral value 1.0.
  const float effective_cell_to_input_scale = 1.0f;
  const float effective_cell_to_forget_scale = 1.0f;
  const float effective_cell_to_output_scale = 1.0f;

  const float effective_hidden_scale =
      std::pow(2.0f, -15.0f) / hidden_scale * std::pow(2.0f, -15.0f);

  // Decompose scales into multiplier/shift pairs.
  QuantizeIntoCmsisNnScaling(effective_input_to_input_scale,
                             &cmsis_params.input_to_input_scaling);
  QuantizeIntoCmsisNnScaling(effective_recurrent_to_input_scale,
                             &cmsis_params.recurrent_to_input_scaling);
  QuantizeIntoCmsisNnScaling(effective_cell_to_input_scale,
                             &cmsis_params.cell_to_input_scaling);
  QuantizeIntoCmsisNnScaling(effective_input_to_forget_scale,
                             &cmsis_params.input_to_forget_scaling);
  QuantizeIntoCmsisNnScaling(effective_recurrent_to_forget_scale,
                             &cmsis_params.recurrent_to_forget_scaling);
  QuantizeIntoCmsisNnScaling(effective_cell_to_forget_scale,
                             &cmsis_params.cell_to_forget_scaling);
  QuantizeIntoCmsisNnScaling(effective_input_to_cell_scale,
                             &cmsis_params.input_to_cell_scaling);
  QuantizeIntoCmsisNnScaling(effective_recurrent_to_cell_scale,
                             &cmsis_params.recurrent_to_cell_scaling);
  QuantizeIntoCmsisNnScaling(effective_input_to_output_scale,
                             &cmsis_params.input_to_output_scaling);
  QuantizeIntoCmsisNnScaling(effective_recurrent_to_output_scale,
                             &cmsis_params.recurrent_to_output_scaling);
  QuantizeIntoCmsisNnScaling(effective_cell_to_output_scale,
                             &cmsis_params.cell_to_output_scaling);

  // NOTE(review): the projection shift reuses the cell-to-output shift and
  // projection_scaling.multiplier is never written here. Projection is
  // rejected above, so this is presumably unused -- confirm.
  cmsis_params.projection_scaling.shift =
      cmsis_params.cell_to_output_scaling.shift;

  QuantizeIntoCmsisNnScaling(effective_hidden_scale,
                             &cmsis_params.hidden_scaling);

  cmsis_params.hidden_offset = hidden_zero_point;

  cmsis_params.activation.min = std::numeric_limits<int16_t>::min();
  cmsis_params.activation.max = std::numeric_limits<int16_t>::max();

  return kTfLiteOk;
}

// Runs the 8-bit activation / 8-bit weight / 16-bit cell LSTM through the
// CMSIS-NN routine arm_lstm_unidirectional_s16_s8().
//
// Args:
//   op_data:        per-node data; params_cmsis_nn is copied and completed
//                   with the runtime fields (time_major, cell clip, biases).
//   kernel_content: evaluation tensors of the node.
//   buffers:        four scratch buffers, reinterpreted as int16 gate
//                   buffers.
//
// Always returns kTfLiteOk.
// NOTE(review): the return value of arm_lstm_unidirectional_s16_s8() is
// discarded -- confirm whether it reports a status worth propagating.
template <typename CellType>
TfLiteStatus CMSIS_NN_EvalInteger8x8_16Lstm(
    const OpData& op_data, const LSTMKernelContents& kernel_content,
    const LSTMBuffers<CellType>& buffers) {
  const OpDataLSTM& op_data_lstm = op_data.params_ref;

  // Activations and state.
  const TfLiteEvalTensor* input =
      kernel_content.GetInternalTensor(tflite::kLstmInputTensor);
  const TfLiteEvalTensor* output_state =
      kernel_content.GetInternalTensor(tflite::kLstmOutputStateTensor);
  const TfLiteEvalTensor* cell_state =
      kernel_content.GetInternalTensor(tflite::kLstmCellStateTensor);
  const TfLiteEvalTensor* output = kernel_content.output_tensor;

  // Input-to-gate weights.
  const TfLiteEvalTensor* input_to_input_weights =
      kernel_content.GetInternalTensor(tflite::kLstmInputToInputWeightsTensor);
  const TfLiteEvalTensor* input_to_forget_weights =
      kernel_content.GetInternalTensor(tflite::kLstmInputToForgetWeightsTensor);
  const TfLiteEvalTensor* input_to_cell_weights =
      kernel_content.GetInternalTensor(tflite::kLstmInputToCellWeightsTensor);
  const TfLiteEvalTensor* input_to_output_weights =
      kernel_content.GetInternalTensor(tflite::kLstmInputToOutputWeightsTensor);

  // Recurrent-to-gate weights.
  const TfLiteEvalTensor* recurrent_to_input_weights =
      kernel_content.GetInternalTensor(
          tflite::kLstmRecurrentToInputWeightsTensor);
  const TfLiteEvalTensor* recurrent_to_forget_weights =
      kernel_content.GetInternalTensor(
          tflite::kLstmRecurrentToForgetWeightsTensor);
  const TfLiteEvalTensor* recurrent_to_cell_weights =
      kernel_content.GetInternalTensor(
          tflite::kLstmRecurrentToCellWeightsTensor);
  const TfLiteEvalTensor* recurrent_to_output_weights =
      kernel_content.GetInternalTensor(
          tflite::kLstmRecurrentToOutputWeightsTensor);

  // Cell-to-gate (peephole) weights; fetched with the optional accessor.
  const TfLiteEvalTensor* cell_to_input_weights =
      kernel_content.GetInternalTensor(tflite::kLstmCellToInputWeightsTensor);
  const TfLiteEvalTensor* cell_to_forget_weights =
      kernel_content.GetInternalTensor(tflite::kLstmCellToForgetWeightsTensor);
  const TfLiteEvalTensor* cell_to_output_weights =
      kernel_content.GetInternalTensor(tflite::kLstmCellToOutputWeightsTensor);

  // Gate biases.
  const TfLiteEvalTensor* input_gate_bias =
      kernel_content.GetInternalTensor(tflite::kLstmInputGateBiasTensor);
  const TfLiteEvalTensor* forget_gate_bias =
      kernel_content.GetInternalTensor(tflite::kLstmForgetGateBiasTensor);
  const TfLiteEvalTensor* cell_gate_bias =
      kernel_content.GetInternalTensor(tflite::kLstmCellGateBiasTensor);
  const TfLiteEvalTensor* output_gate_bias =
      kernel_content.GetInternalTensor(tflite::kLstmOutputGateBiasTensor);

  const int input_rank = input->dims->size;
  TFLITE_DCHECK(input_rank >= 2 && input_rank <= 3);

  // The CMSIS-NN call takes mutable pointers; these helpers strip constness
  // from the (possibly null) tensor data.
  const auto optional_s8 = [](const TfLiteEvalTensor* tensor) {
    return const_cast<int8_t*>(
        tflite::micro::GetOptionalTensorData<int8_t>(tensor));
  };
  const auto optional_s16 = [](const TfLiteEvalTensor* tensor) {
    return const_cast<int16_t*>(
        tflite::micro::GetOptionalTensorData<int16_t>(tensor));
  };
  const auto optional_s32 = [](const TfLiteEvalTensor* tensor) {
    return const_cast<int32_t*>(
        tflite::micro::GetOptionalTensorData<int32_t>(tensor));
  };

  // One scratch buffer per gate.
  cmsis_nn_lstm_context gate_scratch;
  gate_scratch.input_gate = reinterpret_cast<int16_t*>(buffers.buffer0);
  gate_scratch.forget_gate = reinterpret_cast<int16_t*>(buffers.buffer1);
  gate_scratch.cell_gate = reinterpret_cast<int16_t*>(buffers.buffer2);
  gate_scratch.output_gate = reinterpret_cast<int16_t*>(buffers.buffer3);

  // Start from the values prepared ahead of time and add the runtime fields.
  const bool is_time_major = op_data_lstm.size_info.time_major;
  cmsis_nn_lstm_params lstm_params = op_data.params_cmsis_nn;
  lstm_params.time_major = op_data_lstm.size_info.time_major;
  lstm_params.clip.cell = op_data_lstm.cell_state_info.quantized_cell_clip;
  lstm_params.input_gate_bias = optional_s32(input_gate_bias);
  lstm_params.forget_gate_bias = optional_s32(forget_gate_bias);
  lstm_params.cell_gate_bias = optional_s32(cell_gate_bias);
  lstm_params.output_gate_bias = optional_s32(output_gate_bias);

  // A rank-2 input is handled as a single time step; for rank 3 the first
  // two dimensions are (time, batch) or (batch, time) depending on
  // time_major.
  int time_steps = 1;
  int batches = input->dims->data[0];
  if (input_rank != 2) {
    const int time_axis = is_time_major ? 0 : 1;
    const int batch_axis = is_time_major ? 1 : 0;
    time_steps = input->dims->data[time_axis];
    batches = input->dims->data[batch_axis];
  }

  cmsis_nn_lstm_dims dims;
  dims.num_inputs = input->dims->data[input_rank - 1];
  dims.num_outputs = recurrent_to_output_weights->dims->data[1];
  dims.num_batches = batches;
  dims.max_time = time_steps;

  arm_lstm_unidirectional_s16_s8(
      &gate_scratch,
      const_cast<int8_t*>(tflite::micro::GetTensorData<int8_t>(input)), &dims,
      optional_s8(input_to_input_weights),
      optional_s8(input_to_forget_weights),
      optional_s8(input_to_cell_weights),
      optional_s8(input_to_output_weights),
      optional_s8(recurrent_to_input_weights),
      optional_s8(recurrent_to_forget_weights),
      optional_s8(recurrent_to_cell_weights),
      optional_s8(recurrent_to_output_weights),
      optional_s16(cell_to_input_weights),
      optional_s16(cell_to_forget_weights),
      optional_s16(cell_to_output_weights),
      nullptr, &lstm_params,
      const_cast<int8_t*>(tflite::micro::GetTensorData<int8_t>(output_state)),
      const_cast<int16_t*>(tflite::micro::GetTensorData<int16_t>(cell_state)),
      const_cast<int8_t*>(tflite::micro::GetTensorData<int8_t>(output)));

  return kTfLiteOk;
}

/*Kernel functions*/

// Init hook: reserves persistent storage for this node's OpData and returns
// it. `buffer` and `length` are not used.
void* UnidirectionalSequenceLstmInit(TfLiteContext* context, const char* buffer,
                                     size_t length) {
  TFLITE_DCHECK(context->AllocatePersistentBuffer != nullptr);
  void* op_data_storage =
      context->AllocatePersistentBuffer(context, sizeof(OpData));
  return op_data_storage;
}

// Prepare hook: validates the node, fills both halves of OpData and requests
// the four scratch buffers used during evaluation.
//
// Returns kTfLiteError when the node does not have 24 inputs / 1 output, any
// tensor validation fails, the CMSIS-NN parameter calculation fails, the cell
// state type is neither float32 nor int16, or a scratch buffer request fails.
TfLiteStatus UnidirectionalSequenceLstmPrepare(TfLiteContext* context,
                                               TfLiteNode* node) {
  TF_LITE_ENSURE_EQ(context, node->outputs->size, 1);
  TF_LITE_ENSURE_EQ(context, node->inputs->size, 24);

  TFLITE_DCHECK(node->builtin_data != nullptr);
  TFLITE_DCHECK(node->user_data != nullptr);

  OpData* op_data = reinterpret_cast<OpData*>(node->user_data);
  OpDataLSTM* op_data_lstm = &op_data->params_ref;

  const auto* builtin_data =
      static_cast<TfLiteUnidirectionalSequenceLSTMParams*>(node->builtin_data);
  // All TempTfLiteTensors will be deallocated through the destructor.
  LstmTensors lstm_tensors(context, node);
  TF_LITE_ENSURE_OK(context, lstm_tensors.ValidateTensorStatus(context));

  const TfLiteTensor* input = lstm_tensors.GetInternalTensor(kLstmInputTensor);

  op_data_lstm->cell_gate_nonlinear_type = builtin_data->activation;
  op_data_lstm->size_info = CreateLstmSizeInfo(
      builtin_data->time_major, input->dims,
      lstm_tensors.HiddenStateTensor()->dims);

  // Validate tensor sizes before CalculateOpData() reads weight and bias data
  // based on those sizes.
  TF_LITE_ENSURE_OK(context, ValidateTensorSize(context, lstm_tensors,
                                                op_data_lstm->size_info));

  // The CMSIS-NN parameters are only needed for int8 activations.
  if (kTfLiteInt8 == input->type) {
    TF_LITE_ENSURE_STATUS(
        CalculateOpData(context, node, lstm_tensors, op_data));
  }

  // Create cell state information and gate parameters (Fully Connected and Mul)
  const auto cell_state_type =
      lstm_tensors.GetInternalTensor(kLstmCellStateTensor)->type;
  if (cell_state_type == kTfLiteFloat32) {
    op_data_lstm->cell_state_info =
        CreateLstmCellStateInfoFloat(builtin_data->cell_clip);
    TF_LITE_ENSURE_OK(context, PrepareGateParametersFloat(context, lstm_tensors,
                                                          op_data_lstm));
  } else if (cell_state_type == kTfLiteInt16) {
    op_data_lstm->cell_state_info = CreateLstmCellStateInfo(
        lstm_tensors.CellStateTensor()->params.scale, builtin_data->cell_clip);
    TF_LITE_ENSURE_OK(context, PrepareGateParametersInteger(
                                   context, lstm_tensors, op_data_lstm));
  } else {
    MicroPrintf(
        "Cell state type %s (%d) not supported. The quantized Unidirectional "
        "Sequence LSTM Op only support int16 cell state",
        TfLiteTypeGetName(cell_state_type), cell_state_type);
    return kTfLiteError;
  }

  // Request four scratch buffers, each holding batch_size * state_dimension
  // elements of the cell state type.
  const auto scratch_buffer_bytes = op_data_lstm->size_info.batch_size *
                                    op_data_lstm->size_info.state_dimension *
                                    TfLiteTypeGetSize(cell_state_type);
  for (size_t i = 0; i < 4; i++) {
    TF_LITE_ENSURE_OK(context,
                      context->RequestScratchBufferInArena(
                          context, scratch_buffer_bytes,
                          &(op_data_lstm->buffer_indices[i])));
  }

  return kTfLiteOk;
}

// Eval hook for the general registration. Dispatches on the input
// (activation) type and the input-to-input weight type:
//   float32 input               -> EvalLstm<float, float, float, float>
//   int8 input,  int8 weights   -> CMSIS-NN 8x8->16 path
//   int16 input, int8 weights   -> EvalLstm<int16_t, int8_t, int16_t, int64_t>
// Any other combination is reported through MicroPrintf and yields
// kTfLiteError.
TfLiteStatus UnidirectionalSequenceLstmEval(TfLiteContext* context,
                                            TfLiteNode* node) {
  TFLITE_DCHECK(node->user_data != nullptr);
  const OpData& op_data = *reinterpret_cast<const OpData*>(node->user_data);
  const OpDataLSTM& op_data_lstm = op_data.params_ref;

  auto kernel_content = CreateLSTMKernelContent(context, node);

  const auto activation_type =
      kernel_content.internal_tensors[kLstmInputTensor]->type;
  const auto weight_type =
      kernel_content.internal_tensors[kLstmInputToInputWeightsTensor]->type;

  switch (activation_type) {
    case kTfLiteFloat32: {
      LSTMBuffers<float> buffers =
          CreateLSTMBuffers<float>(context, op_data_lstm.buffer_indices);
      EvalLstm<float, float, float, float>(op_data_lstm, kernel_content,
                                           buffers);
      break;
    }
    case kTfLiteInt8: {
      if (weight_type != kTfLiteInt8) {
        // Report the id of the offending weight type (not the activation
        // type), matching the int16 branch below.
        MicroPrintf("Filter type %s (%d) not supported.",
                    TfLiteTypeGetName(weight_type), weight_type);
        return kTfLiteError;
      }
      // 8(activation)x8(weight)->16(cell) LSTM with 32 bits bias
      LSTMBuffers<int16_t> buffers =
          CreateLSTMBuffers<int16_t>(context, op_data_lstm.buffer_indices);
      return CMSIS_NN_EvalInteger8x8_16Lstm<int16_t>(op_data, kernel_content,
                                                     buffers);
    }
    case kTfLiteInt16: {
      if (weight_type != kTfLiteInt8) {
        MicroPrintf("Filter type %s (%d) not supported.",
                    TfLiteTypeGetName(weight_type), weight_type);
        return kTfLiteError;
      }
      // 16(activation)x8(weight)->16(cell) LSTM with 64 bits bias
      LSTMBuffers<int16_t> buffers =
          CreateLSTMBuffers<int16_t>(context, op_data_lstm.buffer_indices);
      EvalLstm<int16_t, int8_t, int16_t, int64_t>(op_data_lstm, kernel_content,
                                                  buffers);
      break;
    }
    default: {
      MicroPrintf("Input type %s (%d) not supported.",
                  TfLiteTypeGetName(activation_type), activation_type);
      return kTfLiteError;
    }
  }
  return kTfLiteOk;
}

// Eval hook for the int8-only registration: accepts int8 activations with
// int8 weights and runs them through the CMSIS-NN 8x8->16 path. Any other
// type combination is reported through MicroPrintf and yields kTfLiteError.
TfLiteStatus UnidirectionalSequenceLstmEvalInt8(TfLiteContext* context,
                                                TfLiteNode* node) {
  TFLITE_DCHECK(node->user_data != nullptr);
  const OpData& op_data = *reinterpret_cast<const OpData*>(node->user_data);
  const OpDataLSTM& op_data_lstm = op_data.params_ref;
  auto kernel_content = CreateLSTMKernelContent(context, node);
  const auto activation_type =
      kernel_content.internal_tensors[kLstmInputTensor]->type;
  const auto weight_type =
      kernel_content.internal_tensors[kLstmInputToInputWeightsTensor]->type;

  // The CMSIS-NN routine reads the weights as int8 data, so int8 is the only
  // acceptable weight type here (same requirement as the int8 branch of
  // UnidirectionalSequenceLstmEval).
  if (weight_type != kTfLiteInt8) {
    MicroPrintf("Filter type %s (%d) not supported.",
                TfLiteTypeGetName(weight_type), weight_type);
    return kTfLiteError;
  }

  if (activation_type != kTfLiteInt8) {
    MicroPrintf("Input type %s (%d) not supported.",
                TfLiteTypeGetName(activation_type), activation_type);
    return kTfLiteError;
  }

  LSTMBuffers<int16_t> buffers =
      CreateLSTMBuffers<int16_t>(context, op_data_lstm.buffer_indices);
  return CMSIS_NN_EvalInteger8x8_16Lstm<int16_t>(op_data, kernel_content,
                                                 buffers);
}

}  // namespace

671
TFLMRegistration Register_UNIDIRECTIONAL_SEQUENCE_LSTM() {
672 673 674 675 676
  return tflite::micro::RegisterOp(UnidirectionalSequenceLstmInit,
                                   UnidirectionalSequenceLstmPrepare,
                                   UnidirectionalSequenceLstmEval);
}

677
TFLMRegistration Register_UNIDIRECTIONAL_SEQUENCE_LSTM_INT8() {
678 679 680 681 682 683
  return tflite::micro::RegisterOp(UnidirectionalSequenceLstmInit,
                                   UnidirectionalSequenceLstmPrepare,
                                   UnidirectionalSequenceLstmEvalInt8);
}

}  // namespace tflite