未验证 提交 78b4040c 编写于 作者: D David Davis 提交者: GitHub

Fix SELECT_V2 Prepare deallocation error (#2013)

@tensorflow/micro

Fix logic error in Prepare that prevented deallocation of temporary tensors, when scalars were present in the inputs. 

Added 3 tests from TfLite for scalar inputs.

bug=fixes #2010
上级 2b7f86c8
/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
/* Copyright 2023 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
......@@ -101,16 +101,15 @@ TfLiteStatus SelectPrepare(TfLiteContext* context, TfLiteNode* node) {
// Respect the original output shape when there are mixed shapes to represent
// a scalar data.
if (GetTensorShape(input_condition).FlatSize() == 1 &&
bool possible_mixed_scaler =
GetTensorShape(input_condition).FlatSize() == 1 &&
GetTensorShape(input_x).FlatSize() == 1 &&
GetTensorShape(input_y).FlatSize() == 1 &&
GetTensorShape(output).FlatSize() == 1) {
return kTfLiteOk;
}
GetTensorShape(output).FlatSize() == 1;
bool same_shape = HaveSameShapes(input_condition, input_x) &&
HaveSameShapes(input_x, input_y);
if (!same_shape) {
if (!same_shape && !possible_mixed_scaler) {
TF_LITE_ENSURE_OK(
context, CheckBroadcastShape(context, input_condition, input_x, input_y,
output->dims));
......
/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
/* Copyright 2023 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
......@@ -13,6 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <type_traits>
#include "tensorflow/lite/micro/kernels/kernel_runner.h"
#include "tensorflow/lite/micro/test_helpers.h"
#include "tensorflow/lite/micro/testing/micro_test.h"
......@@ -207,4 +209,55 @@ TF_LITE_MICRO_TEST(BroadcastSelectInt16OneDimensionConditionWithTwoValues) {
tflite::testing::ExpectEqual(output_shape, expected_output, output_data);
}
TF_LITE_MICRO_TEST(MixedFlatSizeOneInputsWithScalarInputConditionTensor) {
int input1_shape[] = {0}; // conditional data is a scalar
int input_shape[] = {1, 1};
int output_shape[] = {0}; // output data is a scalar
const bool input1_data[] = {false};
const int16_t input2_data[] = {1};
const int16_t input3_data[] = {5};
const int16_t expected_output[] = {5};
int16_t output_data[std::extent<decltype(expected_output)>::value];
tflite::testing::TestSelect(input1_shape, input1_data, input_shape,
input2_data, input_shape, input3_data,
output_shape, output_data);
tflite::testing::ExpectEqual(output_shape, expected_output, output_data);
}
TF_LITE_MICRO_TEST(MixedFlatSizeOneInputsWithScalarInputXTensor) {
int input2_shape[] = {0}; // x data is a scalar
int input_shape[] = {1, 1};
int output_shape[] = {0}; // output data is a scalar
const bool input1_data[] = {true};
const int16_t input2_data[] = {1};
const int16_t input3_data[] = {5};
const int16_t expected_output[] = {1};
int16_t output_data[std::extent<decltype(expected_output)>::value];
tflite::testing::TestSelect(input_shape, input1_data, input2_shape,
input2_data, input_shape, input3_data,
output_shape, output_data);
tflite::testing::ExpectEqual(output_shape, expected_output, output_data);
}
TF_LITE_MICRO_TEST(MixedFlatSizeOneInputsWithScalarInputYTensor) {
int input3_shape[] = {0}; // y data is a scalar
int input_shape[] = {1, 1};
int output_shape[] = {0}; // output data is a scalar
const bool input1_data[] = {false};
const int16_t input2_data[] = {1};
const int16_t input3_data[] = {5};
const int16_t expected_output[] = {5};
int16_t output_data[std::extent<decltype(expected_output)>::value];
tflite::testing::TestSelect(input_shape, input1_data, input_shape,
input2_data, input3_shape, input3_data,
output_shape, output_data);
tflite::testing::ExpectEqual(output_shape, expected_output, output_data);
}
TF_LITE_MICRO_TESTS_END
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册