Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
magicwindyyd
mindspore
提交
5969769b
M
mindspore
项目概览
magicwindyyd
/
mindspore
与 Fork 源项目一致
Fork自
MindSpore / mindspore
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
M
mindspore
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
5969769b
编写于
8月 04, 2020
作者:
S
sunsuodong
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
topk_int8
上级
c17ed236
变更
12
隐藏空白更改
内联
并排
Showing
12 changed file
with
367 addition
and
34 deletion
+367
-34
mindspore/lite/src/ops/topk.cc
mindspore/lite/src/ops/topk.cc
+7
-5
mindspore/lite/src/populate_parameter.cc
mindspore/lite/src/populate_parameter.cc
+1
-1
mindspore/lite/src/runtime/kernel/arm/fp32/topk.cc
mindspore/lite/src/runtime/kernel/arm/fp32/topk.cc
+12
-12
mindspore/lite/src/runtime/kernel/arm/fp32/topk.h
mindspore/lite/src/runtime/kernel/arm/fp32/topk.h
+5
-6
mindspore/lite/src/runtime/kernel/arm/int8/topk_int8.cc
mindspore/lite/src/runtime/kernel/arm/int8/topk_int8.cc
+76
-0
mindspore/lite/src/runtime/kernel/arm/int8/topk_int8.h
mindspore/lite/src/runtime/kernel/arm/int8/topk_int8.h
+42
-0
mindspore/lite/src/runtime/kernel/arm/opclib/fp32/topk.cc
mindspore/lite/src/runtime/kernel/arm/opclib/fp32/topk.cc
+6
-6
mindspore/lite/src/runtime/kernel/arm/opclib/fp32/topk.h
mindspore/lite/src/runtime/kernel/arm/opclib/fp32/topk.h
+4
-4
mindspore/lite/src/runtime/kernel/arm/opclib/int8/topk_int8.cc
...pore/lite/src/runtime/kernel/arm/opclib/int8/topk_int8.cc
+54
-0
mindspore/lite/src/runtime/kernel/arm/opclib/int8/topk_int8.h
...spore/lite/src/runtime/kernel/arm/opclib/int8/topk_int8.h
+30
-0
mindspore/lite/test/ut/src/runtime/kernel/arm/fp32/topk_fp32_tests.cc
...te/test/ut/src/runtime/kernel/arm/fp32/topk_fp32_tests.cc
+65
-0
mindspore/lite/test/ut/src/runtime/kernel/arm/int8/topk_int8_tests.cc
...te/test/ut/src/runtime/kernel/arm/int8/topk_int8_tests.cc
+65
-0
未找到文件。
mindspore/lite/src/ops/topk.cc
浏览文件 @
5969769b
...
...
@@ -35,13 +35,15 @@ int TopK::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<tensor::
auto
topk_prim
=
this
->
primitive
->
value_as_TopK
();
MS_ASSERT
(
topk_prim
!=
nullptr
);
output0
->
set_shape
(
input
->
shape
());
auto
out_shape
=
input
->
shape
();
out_shape
[
out_shape
.
size
()
-
1
]
=
topk_prim
->
k
();
output0
->
set_shape
(
out_shape
);
output0
->
set_data_type
(
input
->
data_type
());
// output0->shape().back() = topk_prim->k(
);
output0
->
SetFormat
(
input
->
GetFormat
()
);
output1
->
set_shape
(
input
->
shape
());
output1
->
set_data_type
(
input
->
data_type
());
// output1->shape().back() = topk_prim->k();
output1
->
set_shape
(
out_shape
);
output1
->
set_data_type
(
kNumberTypeInt32
);
output1
->
SetFormat
(
input
->
GetFormat
());
return
RET_OK
;
...
...
mindspore/lite/src/populate_parameter.cc
浏览文件 @
5969769b
...
...
@@ -34,7 +34,7 @@
#include "src/runtime/kernel/arm/opclib/matmul.h"
#include "src/runtime/kernel/arm/opclib/fp32/softmax.h"
#include "src/runtime/kernel/arm/opclib/tile.h"
#include "src/runtime/kernel/arm/opclib/topk.h"
#include "src/runtime/kernel/arm/opclib/
fp32/
topk.h"
#include "src/runtime/kernel/arm/opclib/fp32/reduce.h"
#include "src/runtime/kernel/arm/opclib/fp32/activation.h"
#include "src/runtime/kernel/arm/opclib/fp32/arithmetic.h"
...
...
mindspore/lite/src/runtime/kernel/arm/fp32/topk.cc
浏览文件 @
5969769b
...
...
@@ -25,11 +25,18 @@ using mindspore::schema::PrimitiveType_TopK;
namespace
mindspore
::
kernel
{
int
TopKCPUKernel
::
Init
()
{
TopkParameter
*
parameter
=
reinterpret_cast
<
TopkParameter
*>
(
opParameter
);
lite
::
tensor
::
Tensor
*
input
=
inputs_
.
at
(
0
);
topk_parameter_
->
last_dim_size_
=
input
->
shape
()[
input
->
shape
().
size
()
-
1
];
topk_parameter_
->
loop_num_
=
1
;
parameter
->
last_dim_size_
=
input
->
shape
()[
input
->
shape
().
size
()
-
1
];
parameter
->
loop_num_
=
1
;
for
(
int
i
=
0
;
i
<
input
->
shape
().
size
()
-
1
;
++
i
)
{
topk_parameter_
->
loop_num_
*=
input
->
shape
()[
i
];
parameter
->
loop_num_
*=
input
->
shape
()[
i
];
}
parameter
->
topk_node_list_
=
malloc
(
sizeof
(
TopkNode
)
*
parameter
->
last_dim_size_
);
if
(
parameter
->
topk_node_list_
==
nullptr
)
{
MS_LOG
(
ERROR
)
<<
"malloc fail."
;
return
RET_ERROR
;
}
return
RET_OK
;
}
...
...
@@ -39,14 +46,9 @@ int TopKCPUKernel::ReSize() { return RET_OK; }
int
TopKCPUKernel
::
Run
()
{
auto
input_data
=
reinterpret_cast
<
float
*>
(
inputs_
.
at
(
0
)
->
Data
());
auto
output_data
=
reinterpret_cast
<
float
*>
(
outputs_
.
at
(
0
)
->
Data
());
auto
output_index
=
reinterpret_cast
<
floa
t
*>
(
outputs_
.
at
(
1
)
->
Data
());
auto
output_index
=
reinterpret_cast
<
int32_
t
*>
(
outputs_
.
at
(
1
)
->
Data
());
Node
*
top_map
=
reinterpret_cast
<
Node
*>
(
malloc
(
sizeof
(
Node
)
*
topk_parameter_
->
last_dim_size_
));
MS_EXCEPTION_IF_NULL
(
top_map
);
topk_parameter_
->
topk_node_list_
=
top_map
;
Topk
(
input_data
,
output_data
,
output_index
,
topk_parameter_
);
free
(
top_map
);
topk_parameter_
->
topk_node_list_
=
nullptr
;
Topk
(
input_data
,
output_data
,
output_index
,
reinterpret_cast
<
TopkParameter
*>
(
opParameter
));
return
RET_OK
;
}
...
...
@@ -54,7 +56,6 @@ kernel::LiteKernel *CpuTopKFp32KernelCreator(const std::vector<lite::tensor::Ten
const
std
::
vector
<
lite
::
tensor
::
Tensor
*>
&
outputs
,
OpParameter
*
parameter
,
const
lite
::
Context
*
ctx
,
const
KernelKey
&
desc
)
{
MS_ASSERT
(
parameter
!=
nullptr
);
MS_ASSERT
(
desc
.
type
==
PrimitiveType_Tile
);
auto
*
kernel
=
new
(
std
::
nothrow
)
TopKCPUKernel
(
parameter
,
inputs
,
outputs
);
if
(
kernel
==
nullptr
)
{
MS_LOG
(
ERROR
)
<<
"new TopKCPUKernel fail!"
;
...
...
@@ -73,4 +74,3 @@ kernel::LiteKernel *CpuTopKFp32KernelCreator(const std::vector<lite::tensor::Ten
REG_KERNEL
(
kCPU
,
kNumberTypeFloat32
,
PrimitiveType_TopK
,
CpuTopKFp32KernelCreator
)
}
// namespace mindspore::kernel
mindspore/lite/src/runtime/kernel/arm/fp32/topk.h
浏览文件 @
5969769b
...
...
@@ -18,26 +18,25 @@
#include <vector>
#include "src/lite_kernel.h"
#include "src/runtime/kernel/arm/opclib/topk.h"
#include "src/runtime/kernel/arm/opclib/
fp32/
topk.h"
namespace
mindspore
::
kernel
{
class
TopKCPUKernel
:
public
LiteKernel
{
public:
explicit
TopKCPUKernel
(
OpParameter
*
parameter
,
const
std
::
vector
<
lite
::
tensor
::
Tensor
*>
&
inputs
,
const
std
::
vector
<
lite
::
tensor
::
Tensor
*>
&
outputs
)
:
LiteKernel
(
parameter
,
inputs
,
outputs
)
{
topk_parameter_
=
reinterpret_cast
<
TopkParameter
*>
(
parameter
);
:
LiteKernel
(
parameter
,
inputs
,
outputs
)
{}
~
TopKCPUKernel
()
override
{
TopkParameter
*
parameter
=
reinterpret_cast
<
TopkParameter
*>
(
opParameter
);
free
(
parameter
->
topk_node_list_
);
}
~
TopKCPUKernel
()
override
{}
int
Init
()
override
;
int
ReSize
()
override
;
int
Run
()
override
;
private:
TopkParameter
*
topk_parameter_
;
};
}
// namespace mindspore::kernel
#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_TOPK_H_
mindspore/lite/src/runtime/kernel/arm/int8/topk_int8.cc
0 → 100644
浏览文件 @
5969769b
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "src/runtime/kernel/arm/int8/topk_int8.h"
#include "src/kernel_registry.h"
#include "include/errorcode.h"
using
mindspore
::
lite
::
KernelRegistrar
;
using
mindspore
::
lite
::
RET_ERROR
;
using
mindspore
::
lite
::
RET_OK
;
using
mindspore
::
schema
::
PrimitiveType_TopK
;
namespace
mindspore
::
kernel
{
int
TopKInt8CPUKernel
::
Init
()
{
TopkParameter
*
parameter
=
reinterpret_cast
<
TopkParameter
*>
(
opParameter
);
lite
::
tensor
::
Tensor
*
input
=
inputs_
.
at
(
0
);
parameter
->
last_dim_size_
=
input
->
shape
()[
input
->
shape
().
size
()
-
1
];
parameter
->
loop_num_
=
1
;
for
(
int
i
=
0
;
i
<
input
->
shape
().
size
()
-
1
;
++
i
)
{
parameter
->
loop_num_
*=
input
->
shape
()[
i
];
}
parameter
->
topk_node_list_
=
malloc
(
sizeof
(
TopkNodeInt8
)
*
parameter
->
last_dim_size_
);
if
(
parameter
->
topk_node_list_
==
nullptr
)
{
MS_LOG
(
ERROR
)
<<
"malloc fail."
;
return
RET_ERROR
;
}
return
RET_OK
;
}
int
TopKInt8CPUKernel
::
ReSize
()
{
return
RET_OK
;
}
int
TopKInt8CPUKernel
::
Run
()
{
int8_t
*
input_data
=
reinterpret_cast
<
int8_t
*>
(
inputs_
.
at
(
0
)
->
Data
());
int8_t
*
output_data
=
reinterpret_cast
<
int8_t
*>
(
outputs_
.
at
(
0
)
->
Data
());
int32_t
*
output_index
=
reinterpret_cast
<
int32_t
*>
(
outputs_
.
at
(
1
)
->
Data
());
TopkInt8
(
input_data
,
output_data
,
output_index
,
reinterpret_cast
<
TopkParameter
*>
(
opParameter
));
return
RET_OK
;
}
kernel
::
LiteKernel
*
CpuTopKInt8KernelCreator
(
const
std
::
vector
<
lite
::
tensor
::
Tensor
*>
&
inputs
,
const
std
::
vector
<
lite
::
tensor
::
Tensor
*>
&
outputs
,
OpParameter
*
parameter
,
const
lite
::
Context
*
ctx
,
const
KernelKey
&
desc
)
{
MS_ASSERT
(
parameter
!=
nullptr
);
auto
*
kernel
=
new
(
std
::
nothrow
)
TopKInt8CPUKernel
(
parameter
,
inputs
,
outputs
);
if
(
kernel
==
nullptr
)
{
MS_LOG
(
ERROR
)
<<
"new TopKInt8CPUKernel fail!"
;
return
nullptr
;
}
auto
ret
=
kernel
->
Init
();
if
(
ret
!=
RET_OK
)
{
MS_LOG
(
ERROR
)
<<
"Init kernel failed, name: "
<<
parameter
->
name_
<<
", type: "
<<
schema
::
EnumNamePrimitiveType
(
static_cast
<
schema
::
PrimitiveType
>
(
parameter
->
type_
));
delete
kernel
;
return
nullptr
;
}
return
kernel
;
}
REG_KERNEL
(
kCPU
,
kNumberTypeInt8
,
PrimitiveType_TopK
,
CpuTopKInt8KernelCreator
)
}
// namespace mindspore::kernel
mindspore/lite/src/runtime/kernel/arm/int8/topk_int8.h
0 → 100644
浏览文件 @
5969769b
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_INT8_TOPK_INT8_H_
#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_INT8_TOPK_INT8_H_
#include <vector>
#include "src/lite_kernel.h"
#include "src/runtime/kernel/arm/opclib/int8/topk_int8.h"
namespace
mindspore
::
kernel
{
class
TopKInt8CPUKernel
:
public
LiteKernel
{
public:
explicit
TopKInt8CPUKernel
(
OpParameter
*
parameter
,
const
std
::
vector
<
lite
::
tensor
::
Tensor
*>
&
inputs
,
const
std
::
vector
<
lite
::
tensor
::
Tensor
*>
&
outputs
)
:
LiteKernel
(
parameter
,
inputs
,
outputs
)
{}
~
TopKInt8CPUKernel
()
override
{
TopkParameter
*
parameter
=
reinterpret_cast
<
TopkParameter
*>
(
opParameter
);
free
(
parameter
->
topk_node_list_
);
}
int
Init
()
override
;
int
ReSize
()
override
;
int
Run
()
override
;
private:
};
}
// namespace mindspore::kernel
#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_INT8_TOPK_INT8_H_
mindspore/lite/src/runtime/kernel/arm/opclib/topk.cc
→
mindspore/lite/src/runtime/kernel/arm/opclib/
fp32/
topk.cc
浏览文件 @
5969769b
...
...
@@ -14,25 +14,25 @@
* limitations under the License.
*/
#include "src/runtime/kernel/arm/opclib/topk.h"
#include "src/runtime/kernel/arm/opclib/
fp32/
topk.h"
int
DescendCmp
(
const
void
*
a
,
const
void
*
b
)
{
return
((
const
Node
*
)
b
)
->
element
-
((
const
Node
*
)
a
)
->
element
;
return
((
const
TopkNode
*
)
b
)
->
element
-
((
const
Topk
Node
*
)
a
)
->
element
;
}
int
AscendCmp
(
const
void
*
a
,
const
void
*
b
)
{
return
((
const
Node
*
)
a
)
->
element
-
((
const
Node
*
)
b
)
->
element
;
return
((
const
TopkNode
*
)
a
)
->
element
-
((
const
Topk
Node
*
)
b
)
->
element
;
}
void
Topk
(
float
*
input_data
,
float
*
output_data
,
floa
t
*
output_index
,
TopkParameter
*
parameter
)
{
void
Topk
(
float
*
input_data
,
float
*
output_data
,
int32_
t
*
output_index
,
TopkParameter
*
parameter
)
{
int
last_dim_size
=
parameter
->
last_dim_size_
;
int
loop_num
=
parameter
->
loop_num_
;
int
k
=
parameter
->
k_
;
Node
*
top_map
=
parameter
->
topk_node_list_
;
TopkNode
*
top_map
=
(
TopkNode
*
)
parameter
->
topk_node_list_
;
float
*
cur_input_data
=
input_data
;
float
*
cur_output_data
=
output_data
;
floa
t
*
cur_output_index
=
output_index
;
int32_
t
*
cur_output_index
=
output_index
;
for
(
int
i
=
0
;
i
<
loop_num
;
i
++
)
{
for
(
int
j
=
0
;
j
<
last_dim_size
;
j
++
)
{
top_map
[
j
].
element
=
*
(
cur_input_data
+
j
);
...
...
mindspore/lite/src/runtime/kernel/arm/opclib/topk.h
→
mindspore/lite/src/runtime/kernel/arm/opclib/
fp32/
topk.h
浏览文件 @
5969769b
...
...
@@ -19,9 +19,9 @@
#include "src/runtime/kernel/arm/opclib/op_base.h"
struct
Node
{
struct
Topk
Node
{
float
element
;
floa
t
index
;
int32_
t
index
;
};
struct
TopkParameter
{
...
...
@@ -30,10 +30,10 @@ struct TopkParameter {
int
loop_num_
;
int
k_
;
bool
sorted_
;
Node
*
topk_node_list_
;
void
*
topk_node_list_
;
};
void
Topk
(
float
*
input_data
,
float
*
output_data
,
floa
t
*
output_index
,
TopkParameter
*
parameter
);
void
Topk
(
float
*
input_data
,
float
*
output_data
,
int32_
t
*
output_index
,
TopkParameter
*
parameter
);
#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_OPCLIB_TOPK_H_
mindspore/lite/src/runtime/kernel/arm/opclib/int8/topk_int8.cc
0 → 100644
浏览文件 @
5969769b
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "src/runtime/kernel/arm/opclib/int8/topk_int8.h"
int
DescendCmpInt8
(
const
void
*
a
,
const
void
*
b
)
{
return
((
const
TopkNodeInt8
*
)
b
)
->
element
-
((
const
TopkNodeInt8
*
)
a
)
->
element
;
}
int
AscendCmpInt8
(
const
void
*
a
,
const
void
*
b
)
{
return
((
const
TopkNodeInt8
*
)
a
)
->
element
-
((
const
TopkNodeInt8
*
)
b
)
->
element
;
}
void
TopkInt8
(
int8_t
*
input_data
,
int8_t
*
output_data
,
int32_t
*
output_index
,
TopkParameter
*
parameter
)
{
int
last_dim_size
=
parameter
->
last_dim_size_
;
int
loop_num
=
parameter
->
loop_num_
;
int
k
=
parameter
->
k_
;
TopkNodeInt8
*
top_map
=
(
TopkNodeInt8
*
)
parameter
->
topk_node_list_
;
int8_t
*
cur_input_data
=
input_data
;
int8_t
*
cur_output_data
=
output_data
;
int32_t
*
cur_output_index
=
output_index
;
for
(
int
i
=
0
;
i
<
loop_num
;
i
++
)
{
for
(
int
j
=
0
;
j
<
last_dim_size
;
j
++
)
{
top_map
[
j
].
element
=
*
(
cur_input_data
+
j
);
top_map
[
j
].
index
=
j
;
}
if
(
parameter
->
sorted_
)
{
qsort
(
top_map
,
last_dim_size
,
sizeof
(
top_map
[
0
]),
DescendCmpInt8
);
}
else
{
qsort
(
top_map
,
last_dim_size
,
sizeof
(
top_map
[
0
]),
AscendCmpInt8
);
}
for
(
int
m
=
0
;
m
<
k
;
m
++
)
{
cur_output_data
[
m
]
=
top_map
[
m
].
element
;
cur_output_index
[
m
]
=
top_map
[
m
].
index
;
}
cur_input_data
+=
last_dim_size
;
cur_output_data
+=
k
;
cur_output_index
+=
k
;
}
}
mindspore/lite/src/runtime/kernel/arm/opclib/int8/topk_int8.h
0 → 100644
浏览文件 @
5969769b
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_OPCLIB_INT8_TOPK_INT8_H_
#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_OPCLIB_INT8_TOPK_INT8_H_
#include "src/runtime/kernel/arm/opclib/op_base.h"
#include "src/runtime/kernel/arm/opclib/fp32/topk.h"
struct
TopkNodeInt8
{
int8_t
element
;
int32_t
index
;
};
void
TopkInt8
(
int8_t
*
input_data
,
int8_t
*
output_data
,
int32_t
*
output_index
,
TopkParameter
*
parameter
);
#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_OPCLIB_INT8_TOPK_INT8_H_
mindspore/lite/test/ut/src/runtime/kernel/arm/fp32/topk_fp32_tests.cc
0 → 100644
浏览文件 @
5969769b
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <iostream>
#include <memory>
#include "common/common_test.h"
#include "mindspore/lite/src/runtime/kernel/arm/opclib/fp32/topk.h"
#include "mindspore/lite/src/kernel_registry.h"
namespace
mindspore
{
class
TestTopKFp32
:
public
mindspore
::
Common
{
public:
TestTopKFp32
()
{}
};
TEST_F
(
TestTopKFp32
,
TopK
)
{
lite
::
tensor
::
Tensor
in_tensor
(
kNumberTypeFloat32
,
{
2
,
2
,
3
});
lite
::
tensor
::
Tensor
out_tensor0
(
kNumberTypeFloat32
,
{
2
,
2
,
2
});
lite
::
tensor
::
Tensor
out_tensor1
(
kNumberTypeInt32
,
{
2
,
2
,
2
});
float
input_data
[]
=
{
1
,
2
,
3
,
6
,
5
,
4
,
9
,
8
,
7
,
10
,
12
,
11
};
float
output_data0
[
8
]
=
{
0
};
int32_t
output_data1
[
8
]
=
{
0
};
in_tensor
.
SetData
(
input_data
);
out_tensor0
.
SetData
(
output_data0
);
out_tensor1
.
SetData
(
output_data1
);
std
::
vector
<
lite
::
tensor
::
Tensor
*>
inputs
=
{
&
in_tensor
};
std
::
vector
<
lite
::
tensor
::
Tensor
*>
outputs
=
{
&
out_tensor0
,
&
out_tensor1
};
TopkParameter
parameter
=
{{},
3
,
4
,
2
,
true
};
kernel
::
KernelKey
desc
=
{
kernel
::
KERNEL_ARCH
::
kCPU
,
kNumberTypeFloat32
,
schema
::
PrimitiveType_TopK
};
auto
creator
=
lite
::
KernelRegistry
::
GetInstance
()
->
GetCreator
(
desc
);
ASSERT_NE
(
creator
,
nullptr
);
auto
kernel
=
creator
(
inputs
,
outputs
,
reinterpret_cast
<
OpParameter
*>
(
&
parameter
),
nullptr
,
desc
);
ASSERT_NE
(
kernel
,
nullptr
);
auto
ret
=
kernel
->
Run
();
EXPECT_EQ
(
0
,
ret
);
float
expect0
[]
=
{
3
,
2
,
6
,
5
,
9
,
8
,
12
,
11
};
int32_t
expect1
[]
=
{
2
,
1
,
0
,
1
,
0
,
1
,
1
,
2
};
for
(
int
i
=
0
;
i
<
8
;
++
i
)
{
EXPECT_EQ
(
output_data0
[
i
],
expect0
[
i
]);
EXPECT_EQ
(
output_data1
[
i
],
expect1
[
i
]);
}
in_tensor
.
SetData
(
nullptr
);
out_tensor0
.
SetData
(
nullptr
);
out_tensor1
.
SetData
(
nullptr
);
}
}
// namespace mindspore
mindspore/lite/test/ut/src/runtime/kernel/arm/int8/topk_int8_tests.cc
0 → 100644
浏览文件 @
5969769b
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <iostream>
#include <memory>
#include "common/common_test.h"
#include "mindspore/lite/src/runtime/kernel/arm/opclib/fp32/topk.h"
#include "mindspore/lite/src/kernel_registry.h"
namespace
mindspore
{
class
TestTopKInt8
:
public
mindspore
::
Common
{
public:
TestTopKInt8
()
{}
};
TEST_F
(
TestTopKInt8
,
TopK
)
{
lite
::
tensor
::
Tensor
in_tensor
(
kNumberTypeInt8
,
{
2
,
2
,
3
});
lite
::
tensor
::
Tensor
out_tensor0
(
kNumberTypeInt8
,
{
2
,
2
,
2
});
lite
::
tensor
::
Tensor
out_tensor1
(
kNumberTypeInt32
,
{
2
,
2
,
2
});
int8_t
input_data
[]
=
{
1
,
2
,
3
,
6
,
5
,
4
,
9
,
8
,
7
,
10
,
12
,
11
};
int8_t
output_data0
[
8
]
=
{
0
};
int32_t
output_data1
[
8
]
=
{
0
};
in_tensor
.
SetData
(
input_data
);
out_tensor0
.
SetData
(
output_data0
);
out_tensor1
.
SetData
(
output_data1
);
std
::
vector
<
lite
::
tensor
::
Tensor
*>
inputs
=
{
&
in_tensor
};
std
::
vector
<
lite
::
tensor
::
Tensor
*>
outputs
=
{
&
out_tensor0
,
&
out_tensor1
};
TopkParameter
parameter
=
{{},
3
,
4
,
2
,
true
};
kernel
::
KernelKey
desc
=
{
kernel
::
KERNEL_ARCH
::
kCPU
,
kNumberTypeInt8
,
schema
::
PrimitiveType_TopK
};
auto
creator
=
lite
::
KernelRegistry
::
GetInstance
()
->
GetCreator
(
desc
);
ASSERT_NE
(
creator
,
nullptr
);
auto
kernel
=
creator
(
inputs
,
outputs
,
reinterpret_cast
<
OpParameter
*>
(
&
parameter
),
nullptr
,
desc
);
ASSERT_NE
(
kernel
,
nullptr
);
auto
ret
=
kernel
->
Run
();
EXPECT_EQ
(
0
,
ret
);
int8_t
expect0
[]
=
{
3
,
2
,
6
,
5
,
9
,
8
,
12
,
11
};
int32_t
expect1
[]
=
{
2
,
1
,
0
,
1
,
0
,
1
,
1
,
2
};
for
(
int
i
=
0
;
i
<
8
;
++
i
)
{
EXPECT_EQ
(
output_data0
[
i
],
expect0
[
i
]);
EXPECT_EQ
(
output_data1
[
i
],
expect1
[
i
]);
}
in_tensor
.
SetData
(
nullptr
);
out_tensor0
.
SetData
(
nullptr
);
out_tensor1
.
SetData
(
nullptr
);
}
}
// namespace mindspore
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录