/* Copyright 2023 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// Op that looks up rows of a matrix.
//
// Input:
//   Tensor[0]: Row numbers to look up, dim.size == 1, int32.
//   Tensor[1]: Matrix of multi-dimensional items, dim.size >= 2; all items
//              are INT8 or FLOAT32. The first dimension is the row, the
//              second is the column; any further dimensions are part of
//              each row.
//
// Output:
//   Output.dim[0] == Tensor[0].dim[0], the number of lookups.
//   Output.dim[i] == Tensor[1].dim[i] for i >= 1 (dim[1] is the number of
//   items per row).
//   Each item in the output is a raw byte copy of the corresponding item in
//   the input, or a dequantized value when the input is INT8 and the output
//   is FLOAT32.
//   When an index is out of bounds, the op fails.
// #include "tensorflow/lite/c/common.h" #include "tensorflow/lite/kernels/internal/types.h" #include "tensorflow/lite/kernels/kernel_util.h" #include "tensorflow/lite/micro/kernels/kernel_util.h" #include "tensorflow/lite/micro/micro_log.h" #include "tensorflow/lite/micro/micro_utils.h" namespace tflite { namespace { constexpr int kInputTensor_0 = 0; constexpr int kInputTensor_1 = 1; constexpr int kOutputTensor = 0; struct OpData { float scale; // quantization scale for tensor 1 size_t num_columns; // number of columns after flattening tensor 1 into 2D }; TfLiteStatus CalculateOpData(TfLiteContext* context, TfLiteNode* node, const TfLiteTensor* tensor_1, const TfLiteTensor* output) { node->user_data = context->AllocatePersistentBuffer(context, sizeof(OpData)); OpData* op_data = static_cast(node->user_data); TF_LITE_ENSURE(context, op_data != nullptr); if (tensor_1->type == kTfLiteInt8 && output->type == kTfLiteFloat32) { TF_LITE_ENSURE_EQ(context, tensor_1->params.zero_point, 0); op_data->scale = tensor_1->params.scale; } op_data->num_columns = NumElements(tensor_1) / tensor_1->dims->data[0]; return kTfLiteOk; } TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { TF_LITE_ENSURE_EQ(context, NumInputs(node), 2); TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1); MicroContext* micro_context = GetMicroContext(context); TfLiteTensor* lookup = micro_context->AllocateTempInputTensor(node, kInputTensor_0); TF_LITE_ENSURE(context, lookup != nullptr); TF_LITE_ENSURE_EQ(context, NumDimensions(lookup), 1); TF_LITE_ENSURE_EQ(context, lookup->type, kTfLiteInt32); TfLiteTensor* value = micro_context->AllocateTempInputTensor(node, kInputTensor_1); TF_LITE_ENSURE(context, value != nullptr); TF_LITE_ENSURE(context, NumDimensions(value) >= 2); TF_LITE_ENSURE(context, value->type == kTfLiteFloat32 || value->type == kTfLiteInt8); TfLiteTensor* output = micro_context->AllocateTempOutputTensor(node, kOutputTensor); TF_LITE_ENSURE(context, output != nullptr); if (value->type 
== kTfLiteFloat32) { TF_LITE_ENSURE(context, output->type == kTfLiteFloat32); } else { TF_LITE_ENSURE( context, output->type == kTfLiteFloat32 || output->type == kTfLiteInt8); } // make sure output dimensions size can hold the new dimension data TF_LITE_ENSURE(context, output->dims->size >= NumDimensions(value)); // make the output tensor dimensions mutable TfLiteEvalTensor* output_eval = tflite::micro::GetEvalOutput(context, node, kOutputTensor); TF_LITE_ENSURE_OK(context, tflite::micro::CreateWritableTensorDimsWithCopy( context, output, output_eval)); // set the new output dimensions output->dims->data[0] = SizeOfDimension(lookup, 0); output->dims->data[1] = SizeOfDimension(value, 1); for (int i = 2; i < NumDimensions(value); i++) { output->dims->data[i] = SizeOfDimension(value, i); } // check the new output dimensions do not exceed the output data buffer size size_t new_dims_size = NumElements(output) * TfLiteTypeGetSize(output->type); TF_LITE_ENSURE(context, new_dims_size <= output->bytes); TF_LITE_ENSURE_OK(context, CalculateOpData(context, node, value, output)); micro_context->DeallocateTempTfLiteTensor(lookup); micro_context->DeallocateTempTfLiteTensor(value); micro_context->DeallocateTempTfLiteTensor(output); return kTfLiteOk; } TfLiteStatus EvalSimple(const OpData& op_data, const TfLiteEvalTensor* lookup, const TfLiteEvalTensor* value, TfLiteEvalTensor* output) { const int num_rows = value->dims->data[0]; if (num_rows == 0) { // Propagate empty tensor if input is empty return kTfLiteOk; } const size_t row_bytes = op_data.num_columns * TfLiteTypeGetSize(value->type); int8_t* output_raw = tflite::micro::GetTensorData(output); const int8_t* value_raw = tflite::micro::GetTensorData(value); const int32_t* lookup_data = tflite::micro::GetTensorData(lookup); for (int i = 0; i < lookup->dims->data[0]; i++) { int32_t idx = lookup_data[i]; if (idx >= num_rows || idx < 0) { MicroPrintf( "EMBEDDING_LOOKUP: index out of bounds. 
" "Got %d, and bounds are [0, %d]", idx, num_rows - 1); return kTfLiteError; } else { std::memcpy(output_raw + i * row_bytes, value_raw + idx * row_bytes, row_bytes); } } return kTfLiteOk; } TfLiteStatus EvalHybrid(const OpData& op_data, const TfLiteEvalTensor* lookup, const TfLiteEvalTensor* value, TfLiteEvalTensor* output) { const int num_rows = value->dims->data[0]; const size_t num_colums = op_data.num_columns; float* output_ptr = tflite::micro::GetTensorData(output); const int8_t* value_ptr = tflite::micro::GetTensorData(value); const int32_t* lookup_data = tflite::micro::GetTensorData(lookup); for (int i = 0; i < lookup->dims->data[0]; i++) { int32_t idx = lookup_data[i]; if (idx >= num_rows || idx < 0) { MicroPrintf( "EMBEDDING_LOOKUP: index out of bounds. " "Got %d, and bounds are [0, %d]", idx, num_rows - 1); return kTfLiteError; } else { // Dequantize embedding values. Dequantize(&value_ptr[idx * num_colums], num_colums, op_data.scale, 0, &output_ptr[i * num_colums]); } } return kTfLiteOk; } TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { const TfLiteEvalTensor* lookup = tflite::micro::GetEvalInput(context, node, kInputTensor_0); const TfLiteEvalTensor* value = tflite::micro::GetEvalInput(context, node, kInputTensor_1); TfLiteEvalTensor* output = tflite::micro::GetEvalOutput(context, node, kOutputTensor); OpData& op_data = *static_cast(node->user_data); switch (value->type) { case kTfLiteFloat32: return EvalSimple(op_data, lookup, value, output); case kTfLiteInt8: if (output->type == kTfLiteFloat32) { return EvalHybrid(op_data, lookup, value, output); } else { return EvalSimple(op_data, lookup, value, output); } default: MicroPrintf("EMBEDDING_LOOKUP only supports FLOAT32 and INT8, got %s.", TfLiteTypeGetName(output->type)); return kTfLiteError; } } } // namespace TFLMRegistration Register_EMBEDDING_LOOKUP() { return tflite::micro::RegisterOp(nullptr, Prepare, Eval); } } // namespace tflite