提交 42ed6ac8 编写于 作者: M Mihai Maruseac

[tflite] Test for `kTfLiteOptionalTensor` in `GetInput`.

`GetInput`, `GetVariableInput` and `GetOutput` all fail to check for the case where `node->inputs->data[index]` is the special `kTfLiteOptionalTensor` value (-1) which then causes `context->tensors[node->inputs->data[index]]` to read from invalid memory location.

This fix makes `GetInput` and related return `nullptr` in those cases, asking the caller to check for `nullptr`. This is better than having `GetOptionalInputTensor` and `GetOptionalOutputTensor` (does not exist but could be added) as using the patched `GetInput` in error would be caught by a sanitizer test in the default optimized build (due to the `-fsanitize=null` option).

PiperOrigin-RevId: 332512190
Change-Id: Iabca54da2f2de02b6ece3c38b54f76d4277d689e
上级 00c7ed7c
......@@ -30,27 +30,49 @@ inline int SizeOfDimension(const TfLiteTensor* t, int dim) {
}
inline const TfLiteTensor* GetInput(const TfLiteContext* context,
const TfLiteNode* node, int index) {
return &context->tensors[node->inputs->data[index]];
const int tensor_index = node->inputs->data[index];
if (tensor_index < 0) {
return nullptr;
}
return &context->tensors[tensor_index];
}
// Note: You must check if result is not null:
// TfLiteTensor* my_tensor = GetVariableInput(context, node, kMyTensorIdx);
// TF_LITE_ENSURE(context, my_tensor != nullptr);
inline TfLiteTensor* GetVariableInput(TfLiteContext* context,
const TfLiteNode* node, int index) {
TfLiteTensor* tensor = &context->tensors[node->inputs->data[index]];
const int tensor_index = node->inputs->data[index];
if (tensor_index < 0) {
return nullptr;
}
TfLiteTensor* tensor = &context->tensors[tensor_index];
>>>>>>> d8f8236c29 ([tflite] Test for `kTfLiteOptionalTensor` in `GetInput`.)
return (tensor->is_variable) ? tensor : nullptr;
}
inline TfLiteTensor* GetOutput(TfLiteContext* context, const TfLiteNode* node,
int index) {
return &context->tensors[node->outputs->data[index]];
const int tensor_index = node->outputs->data[index];
if (tensor_index < 0) {
return nullptr;
}
return &context->tensors[tensor_index];
}
inline TfLiteTensor* GetTemporary(TfLiteContext* context,
const TfLiteNode* node, int index) {
return &context->tensors[node->temporaries->data[index]];
const int tensor_index = node->temporaries->data[index];
if (tensor_index < 0) {
return nullptr;
}
return &context->tensors[tensor_index];
}
inline const TfLiteTensor* GetIntermediates(TfLiteContext* context,
const TfLiteNode* node, int index) {
return &context->tensors[node->intermediates->data[index]];
const int tensor_index = node->intermediates->data[index];
if (tensor_index < 0) {
return nullptr;
}
return &context->tensors[tensor_index];
}
inline int NumInputs(const TfLiteNode* node) { return node->inputs->size; }
inline int NumOutputs(const TfLiteNode* node) { return node->outputs->size; }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册