未验证 提交 399b83cb 编写于 作者: I imcgraw 提交者: GitHub

Micro conformer: concatenate should support bool.

BUG=http://b/238904420
上级 07d03bfb
......@@ -133,7 +133,7 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_ENSURE(context,
input_type == kTfLiteFloat32 || input_type == kTfLiteInt8 ||
input_type == kTfLiteInt16 || input_type == kTfLiteInt32 ||
input_type == kTfLiteInt64);
input_type == kTfLiteInt64 || input_type == kTfLiteBool);
// Output type must match input type
TF_LITE_ENSURE_EQ(context, output_type, input_type);
......@@ -167,6 +167,7 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_ENSURE(context, output != nullptr);
switch (output_type) { // Already know in/outtypes are same.
case kTfLiteBool:
case kTfLiteFloat32:
case kTfLiteInt16:
case kTfLiteInt32:
......@@ -236,6 +237,9 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
case kTfLiteInt16:
EvalUnquantized<int16_t>(context, node);
break;
case kTfLiteBool:
EvalUnquantized<bool>(context, node);
break;
default:
MicroPrintf("Op Concatenation does not currently support Type '%s'.",
......
......@@ -24,11 +24,42 @@ namespace tflite {
namespace testing {
namespace {
void TestConcatenateTwoInputs(int* input1_dims_data, const float* input1_data,
int* input2_dims_data, const float* input2_data,
int axis, int* output_dims_data,
const float* expected_output_data,
float* output_data) {
template <typename T>
void TestConcatenateOneInput(int* input1_dims_data, const T* input1_data,
int axis, int* output_dims_data, T* output_data) {
TfLiteIntArray* input1_dims = IntArrayFromInts(input1_dims_data);
TfLiteIntArray* output_dims = IntArrayFromInts(output_dims_data);
constexpr int input_size = 1;
constexpr int output_size = 1;
constexpr int tensors_size = input_size + output_size;
TfLiteTensor tensors[tensors_size] = {CreateTensor(input1_data, input1_dims),
CreateTensor(output_data, output_dims)};
int inputs_array_data[] = {1, 0};
TfLiteIntArray* inputs_array = IntArrayFromInts(inputs_array_data);
int outputs_array_data[] = {1, 1};
TfLiteIntArray* outputs_array = IntArrayFromInts(outputs_array_data);
TfLiteConcatenationParams builtin_data = {
.axis = axis,
.activation = kTfLiteActNone // Only activation supported in this impl
};
const TfLiteRegistration registration =
tflite::ops::micro::Register_CONCATENATION();
micro::KernelRunner runner(registration, tensors, tensors_size, inputs_array,
outputs_array,
reinterpret_cast<void*>(&builtin_data));
TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, runner.InitAndPrepare());
TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, runner.Invoke());
}
template <typename T>
void TestConcatenateTwoInputs(int* input1_dims_data, const T* input1_data,
int* input2_dims_data, const T* input2_data,
int axis, int* output_dims_data, T* output_data) {
TfLiteIntArray* input1_dims = IntArrayFromInts(input1_dims_data);
TfLiteIntArray* input2_dims = IntArrayFromInts(input2_dims_data);
TfLiteIntArray* output_dims = IntArrayFromInts(output_dims_data);
......@@ -58,8 +89,17 @@ void TestConcatenateTwoInputs(int* input1_dims_data, const float* input1_data,
TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, runner.InitAndPrepare());
TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, runner.Invoke());
}
const int output_dims_count = ElementCount(*output_dims);
void TestConcatenateTwoFloatInputs(
int* input1_dims_data, const float* input1_data, int* input2_dims_data,
const float* input2_data, int axis, int* output_dims_data,
const float* expected_output_data, float* output_data) {
TestConcatenateTwoInputs(input1_dims_data, input1_data, input2_dims_data,
input2_data, axis, output_dims_data, output_data);
TfLiteIntArray* dims = tflite::testing::IntArrayFromInts(output_dims_data);
const int output_dims_count = ElementCount(*dims);
for (int i = 0; i < output_dims_count; ++i) {
TF_LITE_MICRO_EXPECT_NEAR(expected_output_data[i], output_data[i], 1e-5f);
}
......@@ -117,6 +157,49 @@ void TestConcatenateQuantizedTwoInputs(
TF_LITE_MICRO_TESTS_BEGIN
TF_LITE_MICRO_TEST(BoolTypeOneInput) {
int input_shape[] = {3, 2, 1, 2};
int output_shape[] = {3, 2, 1, 2};
const bool input_value[] = {true, false, false, true};
int axis = 1;
bool output_data[4];
tflite::testing::TestConcatenateOneInput(input_shape, input_value, axis,
output_shape, output_data);
TfLiteIntArray* dims = tflite::testing::IntArrayFromInts(output_shape);
const int output_dims_count = tflite::ElementCount(*dims);
for (int i = 0; i < output_dims_count; ++i) {
TF_LITE_MICRO_EXPECT_EQ(input_value[i], output_data[i]);
}
}
TF_LITE_MICRO_TEST(BoolTypeTwoInputs) {
int input1_shape[] = {3, 2, 1, 2};
const bool input1_value[] = {false, false, false, false};
int input2_shape[] = {3, 2, 3, 2};
const bool input2_value[] = {true, true, true, true, true, true,
true, true, true, true, true, true};
const bool expected_output[] = {false, false, true, true, true, true,
true, true, false, false, true, true,
true, true, true, true};
const int axis = 1;
int output_shape[] = {3, 2, 4, 2};
bool output_data[16];
tflite::testing::TestConcatenateTwoInputs(input1_shape, input1_value,
input2_shape, input2_value, axis,
output_shape, output_data);
TfLiteIntArray* dims = tflite::testing::IntArrayFromInts(output_shape);
const int output_dims_count = tflite::ElementCount(*dims);
for (int i = 0; i < output_dims_count; ++i) {
TF_LITE_MICRO_EXPECT_EQ(expected_output[i], output_data[i]);
}
}
TF_LITE_MICRO_TEST(TwoInputsAllAxesCombinations) {
// Concatenate the same two input tensors along all possible axes.
......@@ -137,22 +220,22 @@ TF_LITE_MICRO_TEST(TwoInputsAllAxesCombinations) {
float output_data[12];
// Axis = 0
tflite::testing::TestConcatenateTwoInputs(
tflite::testing::TestConcatenateTwoFloatInputs(
input_shape, input1_value, input_shape, input2_value, /* axis */ 0,
output_shape_axis0, output_value_axis0, output_data);
// Axis = -2 (equivalent to axis = 0)
tflite::testing::TestConcatenateTwoInputs(
tflite::testing::TestConcatenateTwoFloatInputs(
input_shape, input1_value, input_shape, input2_value, /* axis */ -2,
output_shape_axis0, output_value_axis0, output_data);
// Axis = 1
tflite::testing::TestConcatenateTwoInputs(
tflite::testing::TestConcatenateTwoFloatInputs(
input_shape, input1_value, input_shape, input2_value, /* axis */ 1,
output_shape_axis1, output_value_axis1, output_data);
// Axis = -1 (equivalent to axis = 1)
tflite::testing::TestConcatenateTwoInputs(
tflite::testing::TestConcatenateTwoFloatInputs(
input_shape, input1_value, input_shape, input2_value, /* axis */ -1,
output_shape_axis1, output_value_axis1, output_data);
}
......@@ -218,7 +301,7 @@ TF_LITE_MICRO_TEST(ThreeDimensionalTwoInputsDifferentShapes) {
9.0f, 10.0f, 11.0f, 12.0f};
float output_data[16];
tflite::testing::TestConcatenateTwoInputs(
tflite::testing::TestConcatenateTwoFloatInputs(
input1_shape, input1_values, input2_shape, input2_values, axis,
output_shape, output_values, output_data);
}
......@@ -240,7 +323,7 @@ TF_LITE_MICRO_TEST(TwoInputsFiveDimensionsAllAxesCombinations) {
1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f,
9.0f, 10.0f, 11.0f, 12.0f, 13.0f, 14.0f, 15.0f, 16.0f,
17.0f, 18.0f, 19.0f, 20.0f, 21.0f, 22.0f, 23.0f, 24.0f};
tflite::testing::TestConcatenateTwoInputs(
tflite::testing::TestConcatenateTwoFloatInputs(
input_shape, input1_value, input_shape, input2_value, /* axis */ 0,
output_shape_axis0, output_value_axis0, output_data);
......@@ -250,7 +333,7 @@ TF_LITE_MICRO_TEST(TwoInputsFiveDimensionsAllAxesCombinations) {
1.0f, 2.0f, 3.0f, 13.0f, 14.0f, 15.0f, 4.0f, 5.0f,
6.0f, 16.0f, 17.0f, 18.0f, 7.0f, 8.0f, 9.0f, 19.0f,
20.0f, 21.0f, 10.0f, 11.0f, 12.0f, 22.0f, 23.0f, 24.0f};
tflite::testing::TestConcatenateTwoInputs(
tflite::testing::TestConcatenateTwoFloatInputs(
input_shape, input1_value, input_shape, input2_value, /* axis */ 4,
output_shape_axis4, output_value_axis4, output_data);
......@@ -260,7 +343,7 @@ TF_LITE_MICRO_TEST(TwoInputsFiveDimensionsAllAxesCombinations) {
1.0f, 2.0f, 3.0f, 13.0f, 14.0f, 15.0f, 4.0f, 5.0f,
6.0f, 16.0f, 17.0f, 18.0f, 7.0f, 8.0f, 9.0f, 19.0f,
20.0f, 21.0f, 10.0f, 11.0f, 12.0f, 22.0f, 23.0f, 24.0f};
tflite::testing::TestConcatenateTwoInputs(
tflite::testing::TestConcatenateTwoFloatInputs(
input_shape, input1_value, input_shape, input2_value, /* axis */ -2,
output_shape_axis_minus2, output_value_axis_minus2, output_data);
}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册