Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
码匠许师傅
Tflite Micro
提交
399b83cb
T
Tflite Micro
项目概览
码匠许师傅
/
Tflite Micro
11 个月 前同步成功
通知
1
Star
0
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
T
Tflite Micro
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
体验新版 GitCode,发现更多精彩内容 >>
未验证
提交
399b83cb
编写于
8月 12, 2022
作者:
I
imcgraw
提交者:
GitHub
8月 12, 2022
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Micro conformer: concatenate should support bool.
BUG=
http://b/238904420
上级
07d03bfb
变更
2
隐藏空白更改
内联
并排
Showing
2 changed file
with
102 addition
and
15 deletion
+102
-15
tensorflow/lite/micro/kernels/concatenation.cc
tensorflow/lite/micro/kernels/concatenation.cc
+5
-1
tensorflow/lite/micro/kernels/concatenation_test.cc
tensorflow/lite/micro/kernels/concatenation_test.cc
+97
-14
未找到文件。
tensorflow/lite/micro/kernels/concatenation.cc
浏览文件 @
399b83cb
...
...
@@ -133,7 +133,7 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_ENSURE
(
context
,
input_type
==
kTfLiteFloat32
||
input_type
==
kTfLiteInt8
||
input_type
==
kTfLiteInt16
||
input_type
==
kTfLiteInt32
||
input_type
==
kTfLiteInt64
);
input_type
==
kTfLiteInt64
||
input_type
==
kTfLiteBool
);
// Output type must match input type
TF_LITE_ENSURE_EQ
(
context
,
output_type
,
input_type
);
...
...
@@ -167,6 +167,7 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_ENSURE
(
context
,
output
!=
nullptr
);
switch
(
output_type
)
{
// Already know in/outtypes are same.
case
kTfLiteBool
:
case
kTfLiteFloat32
:
case
kTfLiteInt16
:
case
kTfLiteInt32
:
...
...
@@ -236,6 +237,9 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
case
kTfLiteInt16
:
EvalUnquantized
<
int16_t
>
(
context
,
node
);
break
;
case
kTfLiteBool
:
EvalUnquantized
<
bool
>
(
context
,
node
);
break
;
default:
MicroPrintf
(
"Op Concatenation does not currently support Type '%s'."
,
...
...
tensorflow/lite/micro/kernels/concatenation_test.cc
浏览文件 @
399b83cb
...
...
@@ -24,11 +24,42 @@ namespace tflite {
namespace
testing
{
namespace
{
void
TestConcatenateTwoInputs
(
int
*
input1_dims_data
,
const
float
*
input1_data
,
int
*
input2_dims_data
,
const
float
*
input2_data
,
int
axis
,
int
*
output_dims_data
,
const
float
*
expected_output_data
,
float
*
output_data
)
{
template
<
typename
T
>
void
TestConcatenateOneInput
(
int
*
input1_dims_data
,
const
T
*
input1_data
,
int
axis
,
int
*
output_dims_data
,
T
*
output_data
)
{
TfLiteIntArray
*
input1_dims
=
IntArrayFromInts
(
input1_dims_data
);
TfLiteIntArray
*
output_dims
=
IntArrayFromInts
(
output_dims_data
);
constexpr
int
input_size
=
1
;
constexpr
int
output_size
=
1
;
constexpr
int
tensors_size
=
input_size
+
output_size
;
TfLiteTensor
tensors
[
tensors_size
]
=
{
CreateTensor
(
input1_data
,
input1_dims
),
CreateTensor
(
output_data
,
output_dims
)};
int
inputs_array_data
[]
=
{
1
,
0
};
TfLiteIntArray
*
inputs_array
=
IntArrayFromInts
(
inputs_array_data
);
int
outputs_array_data
[]
=
{
1
,
1
};
TfLiteIntArray
*
outputs_array
=
IntArrayFromInts
(
outputs_array_data
);
TfLiteConcatenationParams
builtin_data
=
{
.
axis
=
axis
,
.
activation
=
kTfLiteActNone
// Only activation supported in this impl
};
const
TfLiteRegistration
registration
=
tflite
::
ops
::
micro
::
Register_CONCATENATION
();
micro
::
KernelRunner
runner
(
registration
,
tensors
,
tensors_size
,
inputs_array
,
outputs_array
,
reinterpret_cast
<
void
*>
(
&
builtin_data
));
TF_LITE_MICRO_EXPECT_EQ
(
kTfLiteOk
,
runner
.
InitAndPrepare
());
TF_LITE_MICRO_EXPECT_EQ
(
kTfLiteOk
,
runner
.
Invoke
());
}
template
<
typename
T
>
void
TestConcatenateTwoInputs
(
int
*
input1_dims_data
,
const
T
*
input1_data
,
int
*
input2_dims_data
,
const
T
*
input2_data
,
int
axis
,
int
*
output_dims_data
,
T
*
output_data
)
{
TfLiteIntArray
*
input1_dims
=
IntArrayFromInts
(
input1_dims_data
);
TfLiteIntArray
*
input2_dims
=
IntArrayFromInts
(
input2_dims_data
);
TfLiteIntArray
*
output_dims
=
IntArrayFromInts
(
output_dims_data
);
...
...
@@ -58,8 +89,17 @@ void TestConcatenateTwoInputs(int* input1_dims_data, const float* input1_data,
TF_LITE_MICRO_EXPECT_EQ
(
kTfLiteOk
,
runner
.
InitAndPrepare
());
TF_LITE_MICRO_EXPECT_EQ
(
kTfLiteOk
,
runner
.
Invoke
());
}
const
int
output_dims_count
=
ElementCount
(
*
output_dims
);
void
TestConcatenateTwoFloatInputs
(
int
*
input1_dims_data
,
const
float
*
input1_data
,
int
*
input2_dims_data
,
const
float
*
input2_data
,
int
axis
,
int
*
output_dims_data
,
const
float
*
expected_output_data
,
float
*
output_data
)
{
TestConcatenateTwoInputs
(
input1_dims_data
,
input1_data
,
input2_dims_data
,
input2_data
,
axis
,
output_dims_data
,
output_data
);
TfLiteIntArray
*
dims
=
tflite
::
testing
::
IntArrayFromInts
(
output_dims_data
);
const
int
output_dims_count
=
ElementCount
(
*
dims
);
for
(
int
i
=
0
;
i
<
output_dims_count
;
++
i
)
{
TF_LITE_MICRO_EXPECT_NEAR
(
expected_output_data
[
i
],
output_data
[
i
],
1e-5
f
);
}
...
...
@@ -117,6 +157,49 @@ void TestConcatenateQuantizedTwoInputs(
TF_LITE_MICRO_TESTS_BEGIN
TF_LITE_MICRO_TEST
(
BoolTypeOneInput
)
{
int
input_shape
[]
=
{
3
,
2
,
1
,
2
};
int
output_shape
[]
=
{
3
,
2
,
1
,
2
};
const
bool
input_value
[]
=
{
true
,
false
,
false
,
true
};
int
axis
=
1
;
bool
output_data
[
4
];
tflite
::
testing
::
TestConcatenateOneInput
(
input_shape
,
input_value
,
axis
,
output_shape
,
output_data
);
TfLiteIntArray
*
dims
=
tflite
::
testing
::
IntArrayFromInts
(
output_shape
);
const
int
output_dims_count
=
tflite
::
ElementCount
(
*
dims
);
for
(
int
i
=
0
;
i
<
output_dims_count
;
++
i
)
{
TF_LITE_MICRO_EXPECT_EQ
(
input_value
[
i
],
output_data
[
i
]);
}
}
TF_LITE_MICRO_TEST
(
BoolTypeTwoInputs
)
{
int
input1_shape
[]
=
{
3
,
2
,
1
,
2
};
const
bool
input1_value
[]
=
{
false
,
false
,
false
,
false
};
int
input2_shape
[]
=
{
3
,
2
,
3
,
2
};
const
bool
input2_value
[]
=
{
true
,
true
,
true
,
true
,
true
,
true
,
true
,
true
,
true
,
true
,
true
,
true
};
const
bool
expected_output
[]
=
{
false
,
false
,
true
,
true
,
true
,
true
,
true
,
true
,
false
,
false
,
true
,
true
,
true
,
true
,
true
,
true
};
const
int
axis
=
1
;
int
output_shape
[]
=
{
3
,
2
,
4
,
2
};
bool
output_data
[
16
];
tflite
::
testing
::
TestConcatenateTwoInputs
(
input1_shape
,
input1_value
,
input2_shape
,
input2_value
,
axis
,
output_shape
,
output_data
);
TfLiteIntArray
*
dims
=
tflite
::
testing
::
IntArrayFromInts
(
output_shape
);
const
int
output_dims_count
=
tflite
::
ElementCount
(
*
dims
);
for
(
int
i
=
0
;
i
<
output_dims_count
;
++
i
)
{
TF_LITE_MICRO_EXPECT_EQ
(
expected_output
[
i
],
output_data
[
i
]);
}
}
TF_LITE_MICRO_TEST
(
TwoInputsAllAxesCombinations
)
{
// Concatenate the same two input tensors along all possible axes.
...
...
@@ -137,22 +220,22 @@ TF_LITE_MICRO_TEST(TwoInputsAllAxesCombinations) {
float
output_data
[
12
];
// Axis = 0
tflite
::
testing
::
TestConcatenateTwoInputs
(
tflite
::
testing
::
TestConcatenateTwo
Float
Inputs
(
input_shape
,
input1_value
,
input_shape
,
input2_value
,
/* axis */
0
,
output_shape_axis0
,
output_value_axis0
,
output_data
);
// Axis = -2 (equivalent to axis = 0)
tflite
::
testing
::
TestConcatenateTwoInputs
(
tflite
::
testing
::
TestConcatenateTwo
Float
Inputs
(
input_shape
,
input1_value
,
input_shape
,
input2_value
,
/* axis */
-
2
,
output_shape_axis0
,
output_value_axis0
,
output_data
);
// Axis = 1
tflite
::
testing
::
TestConcatenateTwoInputs
(
tflite
::
testing
::
TestConcatenateTwo
Float
Inputs
(
input_shape
,
input1_value
,
input_shape
,
input2_value
,
/* axis */
1
,
output_shape_axis1
,
output_value_axis1
,
output_data
);
// Axis = -1 (equivalent to axis = 1)
tflite
::
testing
::
TestConcatenateTwoInputs
(
tflite
::
testing
::
TestConcatenateTwo
Float
Inputs
(
input_shape
,
input1_value
,
input_shape
,
input2_value
,
/* axis */
-
1
,
output_shape_axis1
,
output_value_axis1
,
output_data
);
}
...
...
@@ -218,7 +301,7 @@ TF_LITE_MICRO_TEST(ThreeDimensionalTwoInputsDifferentShapes) {
9.0
f
,
10.0
f
,
11.0
f
,
12.0
f
};
float
output_data
[
16
];
tflite
::
testing
::
TestConcatenateTwoInputs
(
tflite
::
testing
::
TestConcatenateTwo
Float
Inputs
(
input1_shape
,
input1_values
,
input2_shape
,
input2_values
,
axis
,
output_shape
,
output_values
,
output_data
);
}
...
...
@@ -240,7 +323,7 @@ TF_LITE_MICRO_TEST(TwoInputsFiveDimensionsAllAxesCombinations) {
1.0
f
,
2.0
f
,
3.0
f
,
4.0
f
,
5.0
f
,
6.0
f
,
7.0
f
,
8.0
f
,
9.0
f
,
10.0
f
,
11.0
f
,
12.0
f
,
13.0
f
,
14.0
f
,
15.0
f
,
16.0
f
,
17.0
f
,
18.0
f
,
19.0
f
,
20.0
f
,
21.0
f
,
22.0
f
,
23.0
f
,
24.0
f
};
tflite
::
testing
::
TestConcatenateTwoInputs
(
tflite
::
testing
::
TestConcatenateTwo
Float
Inputs
(
input_shape
,
input1_value
,
input_shape
,
input2_value
,
/* axis */
0
,
output_shape_axis0
,
output_value_axis0
,
output_data
);
...
...
@@ -250,7 +333,7 @@ TF_LITE_MICRO_TEST(TwoInputsFiveDimensionsAllAxesCombinations) {
1.0
f
,
2.0
f
,
3.0
f
,
13.0
f
,
14.0
f
,
15.0
f
,
4.0
f
,
5.0
f
,
6.0
f
,
16.0
f
,
17.0
f
,
18.0
f
,
7.0
f
,
8.0
f
,
9.0
f
,
19.0
f
,
20.0
f
,
21.0
f
,
10.0
f
,
11.0
f
,
12.0
f
,
22.0
f
,
23.0
f
,
24.0
f
};
tflite
::
testing
::
TestConcatenateTwoInputs
(
tflite
::
testing
::
TestConcatenateTwo
Float
Inputs
(
input_shape
,
input1_value
,
input_shape
,
input2_value
,
/* axis */
4
,
output_shape_axis4
,
output_value_axis4
,
output_data
);
...
...
@@ -260,7 +343,7 @@ TF_LITE_MICRO_TEST(TwoInputsFiveDimensionsAllAxesCombinations) {
1.0
f
,
2.0
f
,
3.0
f
,
13.0
f
,
14.0
f
,
15.0
f
,
4.0
f
,
5.0
f
,
6.0
f
,
16.0
f
,
17.0
f
,
18.0
f
,
7.0
f
,
8.0
f
,
9.0
f
,
19.0
f
,
20.0
f
,
21.0
f
,
10.0
f
,
11.0
f
,
12.0
f
,
22.0
f
,
23.0
f
,
24.0
f
};
tflite
::
testing
::
TestConcatenateTwoInputs
(
tflite
::
testing
::
TestConcatenateTwo
Float
Inputs
(
input_shape
,
input1_value
,
input_shape
,
input2_value
,
/* axis */
-
2
,
output_shape_axis_minus2
,
output_value_axis_minus2
,
output_data
);
}
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录