Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
magicwindyyd
mindspore
提交
756b8346
M
mindspore
项目概览
magicwindyyd
/
mindspore
与 Fork 源项目一致
Fork自
MindSpore / mindspore
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
M
mindspore
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
756b8346
编写于
8月 11, 2020
作者:
M
mindspore-ci-bot
提交者:
Gitee
8月 11, 2020
浏览文件
操作
浏览文件
下载
差异文件
!4263 [MS][LITE][Develop]argmax,argmin support keepdim
Merge pull request !4263 from chenjianping/lite_dev
上级
e203503b
746cd9cf
变更
5
隐藏空白更改
内联
并排
Showing
5 changed file
with
40 addition
and
6 deletion
+40
-6
mindspore/lite/src/ops/argmax.cc
mindspore/lite/src/ops/argmax.cc
+2
-2
mindspore/lite/src/ops/argmin.cc
mindspore/lite/src/ops/argmin.cc
+2
-2
mindspore/lite/src/runtime/kernel/arm/base/arg_min_max_base.cc
...pore/lite/src/runtime/kernel/arm/base/arg_min_max_base.cc
+5
-1
mindspore/lite/src/runtime/kernel/arm/nnacl/arg_min_max.cc
mindspore/lite/src/runtime/kernel/arm/nnacl/arg_min_max.cc
+1
-1
mindspore/lite/test/ut/src/runtime/kernel/arm/fp32/argminmax_fp32_test.cc
...est/ut/src/runtime/kernel/arm/fp32/argminmax_fp32_test.cc
+30
-0
未找到文件。
mindspore/lite/src/ops/argmax.cc
浏览文件 @
756b8346
...
...
@@ -38,9 +38,9 @@ int ArgMax::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<tensor
MS_LOG
(
ERROR
)
<<
"Invalid axis "
<<
argmax_prim
->
axis
()
<<
", input shape size: "
<<
input_shape_size
;
return
RET_PARAM_INVALID
;
}
if
(
argmax_prim
->
topK
()
==
1
)
{
if
(
argmax_prim
->
topK
()
==
1
&&
!
argmax_prim
->
keepDims
()
)
{
output_shape
.
erase
(
output_shape
.
begin
()
+
axis
);
}
else
if
(
argmax_prim
->
axisType
()
==
1
)
{
}
else
{
output_shape
[
axis
]
=
argmax_prim
->
topK
();
}
...
...
mindspore/lite/src/ops/argmin.cc
浏览文件 @
756b8346
...
...
@@ -37,9 +37,9 @@ int ArgMin::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<tensor
return
RET_PARAM_INVALID
;
}
std
::
vector
<
int
>
output_shape
(
input
->
shape
());
if
(
argmin_prim
->
topK
()
==
1
)
{
if
(
argmin_prim
->
topK
()
==
1
&&
!
argmin_prim
->
keepDims
()
)
{
output_shape
.
erase
(
output_shape
.
begin
()
+
axis
);
}
else
if
(
argmin_prim
->
axisType
()
==
1
)
{
}
else
{
output_shape
[
axis
]
=
argmin_prim
->
topK
();
}
...
...
mindspore/lite/src/runtime/kernel/arm/base/arg_min_max_base.cc
浏览文件 @
756b8346
...
...
@@ -17,6 +17,7 @@
#include "src/runtime/kernel/arm/nnacl/arg_min_max.h"
#include "src/runtime/kernel/arm/fp32/argminmax.h"
#include "src/runtime/kernel/arm/int8/argminmax_int8.h"
#include "src/runtime/kernel/arm/nnacl/arithmetic_common.h"
#include "schema/model_generated.h"
#include "src/kernel_factory.h"
#include "include/errorcode.h"
...
...
@@ -60,7 +61,7 @@ int ArgMinMaxBaseCPUKernel::ReSize() {
return
RET_PARAM_INVALID
;
}
param
->
topk_
=
MSMIN
(
param
->
topk_
,
in_shape
[
axis
]);
if
(
param
->
topk_
>
1
)
{
if
(
param
->
topk_
>
1
||
param
->
keep_dims_
)
{
if
(
context_
!=
nullptr
&&
context_
->
allocator
!=
nullptr
)
{
param
->
arg_elements_
=
reinterpret_cast
<
ArgElement
*>
(
context_
->
allocator
->
Malloc
(
sizeof
(
ArgElement
)
*
in_shape
[
axis
]));
...
...
@@ -73,6 +74,9 @@ int ArgMinMaxBaseCPUKernel::ReSize() {
return
RET_ERROR
;
}
}
ComputeStrides
(
in_shape
.
data
(),
param
->
in_strides_
,
in_shape
.
size
());
auto
out_shape
=
outputs_
.
at
(
0
)
->
shape
();
ComputeStrides
(
out_shape
.
data
(),
param
->
out_strides_
,
out_shape
.
size
());
return
RET_OK
;
}
...
...
mindspore/lite/src/runtime/kernel/arm/nnacl/arg_min_max.cc
浏览文件 @
756b8346
...
...
@@ -89,7 +89,7 @@ void ArgMinMaxTopknFp32(const float *input, float *output, const int *in_shape,
}
void
ArgMinMax
(
const
void
*
input
,
void
*
output
,
const
int
*
in_shape
,
ArgMinMaxParameter
*
param
)
{
if
(
param
->
topk_
==
1
)
{
if
(
param
->
topk_
==
1
&&
!
param
->
keep_dims_
)
{
ArgMinMaxTopk1
(
input
,
output
,
in_shape
,
param
);
return
;
}
...
...
mindspore/lite/test/ut/src/runtime/kernel/arm/fp32/argminmax_fp32_test.cc
浏览文件 @
756b8346
...
...
@@ -40,6 +40,34 @@ TEST_F(TestArgMinMaxTestFp32, ArgMaxTest1) {
param
.
data_type_
=
43
;
param
.
dims_size_
=
2
;
param
.
get_max_
=
true
;
param
.
keep_dims_
=
false
;
ArgMinMax
(
in
.
data
(),
out
,
shape
.
data
(),
&
param
);
for
(
size_t
i
=
0
;
i
<
except_out
.
size
();
++
i
)
{
std
::
cout
<<
out
[
i
]
<<
" "
;
}
std
::
cout
<<
"
\n
"
;
CompareOutputData
(
out
,
except_out
.
data
(),
except_out
.
size
(),
0.000001
);
}
TEST_F
(
TestArgMinMaxTestFp32
,
ArgMaxTest1_keep_dim
)
{
std
::
vector
<
float
>
in
=
{
10
,
20
,
30
,
40
,
90
,
20
,
11
,
15
,
1
,
50
,
30
,
45
,
25
,
50
,
30
};
std
::
vector
<
float
>
except_out
=
{
2
,
2
,
0
,
2
,
0
};
std
::
vector
<
int
>
shape
=
{
3
,
5
};
float
out
[
5
];
ArgMinMaxParameter
param
;
param
.
topk_
=
1
;
param
.
out_value_
=
false
;
param
.
axis_
=
0
;
param
.
data_type_
=
43
;
param
.
dims_size_
=
2
;
param
.
get_max_
=
true
;
param
.
keep_dims_
=
true
;
param
.
arg_elements_
=
reinterpret_cast
<
ArgElement
*>
(
malloc
(
shape
[
param
.
axis_
]
*
sizeof
(
ArgElement
)));
std
::
vector
<
int
>
out_shape
=
{
1
,
5
};
ComputeStrides
(
shape
.
data
(),
param
.
in_strides_
,
shape
.
size
());
ComputeStrides
(
out_shape
.
data
(),
param
.
out_strides_
,
out_shape
.
size
());
ArgMinMax
(
in
.
data
(),
out
,
shape
.
data
(),
&
param
);
for
(
size_t
i
=
0
;
i
<
except_out
.
size
();
++
i
)
{
std
::
cout
<<
out
[
i
]
<<
" "
;
...
...
@@ -62,6 +90,7 @@ TEST_F(TestArgMinMaxTestFp32, ArgMaxTest2) {
param
.
data_type_
=
43
;
param
.
dims_size_
=
2
;
param
.
get_max_
=
true
;
param
.
keep_dims_
=
false
;
ArgMinMax
(
in
.
data
(),
out
,
shape
.
data
(),
&
param
);
CompareOutputData
(
out
,
except_out
.
data
(),
except_out
.
size
(),
0.000001
);
}
...
...
@@ -80,6 +109,7 @@ TEST_F(TestArgMinMaxTestFp32, ArgMinTest2) {
param
.
data_type_
=
43
;
param
.
dims_size_
=
2
;
param
.
get_max_
=
false
;
param
.
keep_dims_
=
false
;
ArgMinMax
(
in
.
data
(),
out
,
shape
.
data
(),
&
param
);
CompareOutputData
(
out
,
except_out
.
data
(),
except_out
.
size
(),
0.000001
);
}
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录